In [12]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm import tqdm
import random
import torch.nn as nn
from dataset import Thyroid_Dataset
from model import Eff_Unet
from HarDMSEG import HarDMSEG
from loss_metric import DiceLoss, IOU_score, StructureLoss
from LightMed.model.LightMed import LightMed
from PMFSNet.lib.models.PMFSNet import PMFSNet
from PMFSNet.lib.models.PMFSNet_FFT import PMFSNet_FFT
import torchvision.transforms.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as tx
from torchvision.transforms import GaussianBlur
# from hybrid_model_v1 import HybridSegModel
from hybrid_model_v3 import HybridSegModel

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [14]:
import wandb
wandb.login()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33ms960068sss[0m ([33ms960068sss-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [15]:
image_size = 128
batch_size = 128
def train_augmentation(image, mask, image_size):

    p = 0.5
    
    resize = T.Resize((image_size, image_size))
    image = resize(image)
    mask = resize(mask)
    
    # if(random.random() < p):
    #     jitter = T.ColorJitter(brightness=0.5)
    #     image = jitter(image)
    # if(random.random() < p):
    #     angle = random.uniform(-10, 10)  # 旋轉角度從 ±10 度
    #     translate = (random.uniform(-0.05, 0.05) * image.size[0],
    #                  random.uniform(-0.05, 0.05) * image.size[1])  # 最多平移 ±5%
    #     scale = random.uniform(0.95, 1.05)  # 尺度縮放 ±5%
    #     shear = [random.uniform(-5, 5), random.uniform(-5, 5)]  # 小幅剪切
    
    #     image = F.affine(image, angle=angle, translate=translate, scale=scale, shear=shear)
    #     mask = F.affine(mask, angle=angle, translate=translate, scale=scale, shear=shear)
    # if random.random() < p:
    #     image = F.hflip(image)
    #     mask = F.hflip(mask)
    # if(random.random() < p):
    #     transform = T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    #     image = transform(image)  # image must be a PIL image
    
    image_tensor = tx.to_tensor(image)
    mask_tensor = tx.to_tensor(mask)

    # If standardization
    mean = image_tensor.mean()
    std = image_tensor.std()
    std = std if std > 0 else 1.0  # avoid division by zero
    image_tensor = (image_tensor - mean) / std
    
    return image_tensor, mask_tensor
def test_augmentation(image, mask, image_size):
    resize = T.Resize((image_size, image_size))
    image = resize(image)
    mask = resize(mask)

    image_tensor = tx.to_tensor(image)
    mask_tensor = tx.to_tensor(mask)

    # If standardization
    mean = image_tensor.mean()
    std = image_tensor.std()
    std = std if std > 0 else 1.0  # avoid division by zero
    image_tensor = (image_tensor - mean) / std
    return image_tensor, mask_tensor

    
train_dataset = Thyroid_Dataset("train_v2.csv", transform = train_augmentation, image_size = image_size)
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)

test_dataset = Thyroid_Dataset("test_v2.csv", transform = test_augmentation, image_size = image_size, return_from_dataset = True, crop_DDTI = True, histo_match = False)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size)

image, mask, seg_type = next(iter(train_dataloader))
image = image
mask = mask

In [16]:
print("std : ", torch.std(image))
print("unique : ", torch.unique(mask))

std :  tensor(1.0000)
unique :  tensor([0., 1.])


In [17]:
# index = 39
# for index in range(5):
#     # plt.subplot(1,2,1)
#     # plt.imshow(mask[index][0], alpha = 0.3)
#     plt.imshow(image[index][0])
#     plt.imshow(mask[index][0], alpha = 0.3)
    
#     # plt.subplot(1,2,2)
#     # plt.imshow(mask[index][0])
#     plt.show()

In [18]:
# model = Eff_Unet(
#         layers=[5, 5, 15, 10],
#         embed_dims=[40, 80, 192, 384],
#         downsamples=[True, True, True, True],
#         vit_num=6,
#         drop_path_rate=0.1,
#         num_classes=1,
#         resolution = image_size).cuda()
# model = HarDMSEG(in_channels = 1)
# model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
# model = PMFSNet(in_channels = 1, out_channels = 2, dim = "2d")
# model = PMFSNet_FFT(in_channels = 1, out_channels = 2, dim = "2d")
# model = HybridSegModel(in_channels = 1, out_channels = 2, output_size = image_size, layers_num = 3)

In [19]:
# loss_fn_nodule = StructureLoss()
# loss_fn_gland = StructureLoss()

In [20]:
# model.to("cuda")
# image = image.to("cuda")
# mask = mask.to("cuda")

In [21]:
# optimizer = optim.Adam(model.parameters(), lr=0.001)

In [22]:
# for epoch in range(1000):
#     outputs = model(image)

#     nodule_output = outputs[:, 0:1, :, :][seg_type==1]
#     nodule_mask = mask[seg_type==1]
    
#     gland_output = outputs[:, 1:2, :, :][seg_type==2]
#     gland_mask = mask[seg_type==2]
#     # outputs = torch.sigmoid(logits)
#     # print(nodule_output.shape, nodule_mask.shape)
#     nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
#     gland_loss = loss_fn_gland(gland_output, gland_mask)
    
#     loss = nodule_loss + gland_loss
    
#     IOU = (IOU_score(nodule_output, nodule_mask) + IOU_score(gland_output, gland_mask)) / 2
    
#     dice_loss = DiceLoss()
#     # print(dice_loss(nodule_output, nodule_mask))
#     # print(dice_loss(gland_output, gland_mask)) 
#     DICE = ((1 - dice_loss(nodule_output, nodule_mask)) + (1 - dice_loss(gland_output, gland_mask))) / 2
    
#     # Backward and optimize
#     optimizer.zero_grad()   # clear previous gradients
#     loss.backward()         # compute gradients
#     optimizer.step()        # update weights

#     print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, IOU: {IOU.item():.4f}, DICE: {DICE.item()}")

In [23]:
# (1 - dice_loss(nodule_output, nodule_mask))

In [24]:
# index = 4
# plt.subplot(1,2,1)
# plt.imshow(nodule_mask[index][0].detach().cpu().numpy())
# nodule_output = (nodule_output > 0.5)
# plt.subplot(1,2,2)
# plt.imshow(nodule_output[index][0].detach().cpu().numpy())

In [25]:
# index = 4
# plt.subplot(1,2,1)
# plt.imshow(gland_mask[index][0].detach().cpu().numpy())
# gland_output = (gland_output > 0.5)
# plt.subplot(1,2,2)
# plt.imshow(gland_output[index][0].detach().cpu().numpy())

In [26]:
def train(dataloader, model, optimizer, loss_fn_nodule, loss_fn_gland, device):
    total_loss = 0
    total_IOU = 0
    total_DICE = 0
    
    model.train()
    model.to(device)
    count = 0
    for image, mask, seg_type in tqdm(dataloader):
        image, mask, seg_type = image.to(device), mask.to(device), seg_type.to(device)
        outputs = model(image)

        nodule_output = outputs[:, 0:1, :, :][seg_type==1]
        nodule_mask = mask[seg_type==1]
        
        gland_output = outputs[:, 1:2, :, :][seg_type==2]
        gland_mask = mask[seg_type==2]

        nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
        gland_loss = loss_fn_gland(gland_output, gland_mask)


        
        loss = nodule_loss + gland_loss
        IOU = (IOU_score(nodule_output, nodule_mask) + IOU_score(gland_output, gland_mask)) / 2

        dice_loss = DiceLoss()
        DICE = ((1 - dice_loss(nodule_output, nodule_mask)) + (1 - dice_loss(gland_output, gland_mask))) / 2

        
        # Backward and optimize
        optimizer.zero_grad()   # clear previous gradients
        loss.backward()         # compute gradients
        optimizer.step()        # update weights

        total_loss += loss.item()
        total_IOU += IOU.item()
        total_DICE += DICE.item()
        # count+=1
        # if count==5:
        #     break
    return total_loss/len(dataloader), total_IOU/len(dataloader), total_DICE/len(dataloader)


# 
# Only calculate nodule loss, IOU, DICE, because there is no gland data in the testing set
def val(dataloader, model, loss_fn_nodule, loss_fn_gland, device):
    total_loss = 0
    DDTI_total_loss = 0
    TN3K_total_loss = 0
    
    total_IOU = 0
    DDTI_total_IOU = 0
    TN3K_total_IOU = 0
    
    total_DICE = 0
    DDTI_total_DICE = 0
    TN3K_total_DICE = 0

    model.eval()
    model.to(device)
    for image, mask, seg_type, from_dataset in tqdm(dataloader):
        image, mask, seg_type = image.to(device), mask.to(device), seg_type.to(device)
        outputs = model(image)
        
        nodule_output = outputs[:, 0:1, :, :][seg_type==1]
        nodule_mask = mask[seg_type==1]


        DDTI_nodule_output = nodule_output[from_dataset==1]
        TN3K_nodule_output = nodule_output[from_dataset==3]

        DDTI_nodule_mask = nodule_mask[from_dataset==1]
        TN3K_nodule_mask = nodule_mask[from_dataset==3]
        

        nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
        DDTI_nodule_loss = loss_fn_nodule(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_nodule_loss = loss_fn_nodule(TN3K_nodule_output, TN3K_nodule_mask)
        
        IOU = IOU_score(nodule_output, nodule_mask)
        DDTI_IOU = IOU_score(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_IOU = IOU_score(TN3K_nodule_output, TN3K_nodule_mask)

        dice_loss = DiceLoss()
        DICE = 1 - dice_loss(nodule_output, nodule_mask)
        DDTI_DICE = 1 - dice_loss(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_DICE = 1 - dice_loss(TN3K_nodule_output, TN3K_nodule_mask)
    

        total_loss += nodule_loss.item()
        DDTI_total_loss += DDTI_nodule_loss.item()
        TN3K_total_loss += TN3K_nodule_loss.item()
        
        total_IOU += IOU.item()
        DDTI_total_IOU += DDTI_IOU.item()
        TN3K_total_IOU += TN3K_IOU.item()
        
        total_DICE += DICE.item()
        DDTI_total_DICE += DDTI_DICE.item()
        TN3K_total_DICE += TN3K_DICE.item()
    return total_loss/len(dataloader), total_IOU/len(dataloader), total_DICE/len(dataloader), DDTI_total_loss/len(dataloader), DDTI_total_IOU/len(dataloader), DDTI_total_DICE/len(dataloader), TN3K_total_loss/len(dataloader), TN3K_total_IOU/len(dataloader), TN3K_total_DICE/len(dataloader), 

In [27]:
epochs = 50
lr = 0.001
project = "thyroid_hybrid_model"
# name=name="PMFSNet_crop_DDTI_standardization_aug_affine(0.5)_hflip(0.5)_lr_0.001"
# name = "HarDnet_crop_DDTI_standardization_aug_affine(0.5)_lr_0.005"
# name = "LightMed_crop_DDTI"
# name = "test"
# name = "HarDnetMSEG_baseline"
name = "hybrid_v4_baseline"
wandb.init(
    project = project,  # Project name in W&B
    name = name,       # (optional) specific run name
    config={                     # (optional) hyperparameters
        "image_size": image_size,
        "learning_rate": lr,
        "epochs": epochs,
        "batch_size": batch_size,
        # "weight_decay":1e-4
    }
)

In [28]:
# model = Eff_Unet(
#         layers=[5, 5, 15, 10],
#         embed_dims=[10, 20, 48, 96],
#         # embed_dims=[40, 80, 192, 384],
#         downsamples=[True, True, True, True],
#         vit_num=6,
#         drop_path_rate=0.1,
#         num_classes=1,
#         resolution = image_size).cuda()
# model = HarDMSEG(in_channels = 1, out_channels = 2)
# model = LightMed(in_channels = 1, out_channels = 2, image_size = image_size)
# model = PMFSNet(in_channels = 1, out_channels = 2, dim = "2d")
# model = PMFSNet_FFT(in_channels = 1, out_channels = 2, dim = "2d")
model =  HybridSegModel(in_channels = 1, out_channels = 2, output_size = image_size, layers_num = 2)

In [29]:
print(f"image size : {image_size}, lr : {lr}, epochs : {epochs}, batch size : {batch_size}, model : {model}")

image size : 128, lr : 0.001, epochs : 50, batch size : 128, model : HybridSegModel(
  (backbone): HarDNetBackbone(
    (base_conv_1): ConvLayer(
      (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU6(inplace=True)
    )
    (base_conv_2): ConvLayer(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU6(inplace=True)
    )
    (base_max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (encoder_blocks): ModuleList(
      (0): EncoderBlock(
        (hardblock): HarDBlock(
          (layers): ModuleList(
            (0): ConvLayer(
              (conv): Conv2d(64, 14, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm): Batc

### If using pretrained model

In [30]:
# pretrained_name = "PMFSNet_baseline"
# # model = HarDMSEG(in_channels = 1)
# model = PMFSNet(in_channels = 1, out_channels = 1, dim = "2d")
# # model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
# checkpoint = torch.load(f"models/{pretrained_name}/best_checkpoint.pth")
# model.load_state_dict(checkpoint['model_state_dict'])

In [31]:
# loss_fn = DiceLoss()
loss_fn_nodule = StructureLoss()
loss_fn_gland = StructureLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [32]:
max_IOU = -1
for epoch in range(epochs):
    print(f"epoch : {epoch}")
    total_loss_train, total_IOU_train, total_DICE_train = train(train_dataloader, model, optimizer, loss_fn_nodule, loss_fn_gland, "cuda")
    print(f"train loss : {total_loss_train}, train IOU : {total_IOU_train}, train DICE : {total_DICE_train}")
    total_loss_val, total_IOU_val, total_DICE_val, DDTI_total_loss_val, DDTI_total_IOU_val, DDTI_total_DICE_val, TN3K_total_loss_val, TN3K_total_IOU_val, TN3K_total_DICE_val = val(test_dataloader, model, loss_fn_nodule, loss_fn_gland, "cuda")
    print(f"val loss : {total_loss_val}, val IOU : {total_IOU_val}, val DICE : {total_DICE_val}")
    
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]



    if max_IOU < total_IOU_val:
        max_IOU = total_IOU_val
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            "IOU" : total_IOU_val,
            "DICE" : total_DICE_val,
            "loss" : total_loss_val
        }
        folder = f"models/{name}"
        if not os.path.exists(folder):
            os.mkdir(folder)
        torch.save(checkpoint, f"models/{name}/best_checkpoint.pth")

    
    wandb.log({
        "epoch": epoch,
        "Learning Rate":current_lr,
        
        "train_loss": total_loss_train,
        "train_IOU": total_IOU_train,
        "train_DICE": total_DICE_train,
        
        "val_loss": total_loss_val,
        "val_IOU": total_IOU_val,
        "val_DICE": total_DICE_val,

        "DDTI_val_loss": DDTI_total_loss_val,
        "DDTI_val_IOU": DDTI_total_IOU_val,
        "DDTI_val_DICE": DDTI_total_DICE_val,

        "TN3K_val_loss": TN3K_total_loss_val,
        "TN3K_val_IOU": TN3K_total_IOU_val,
        "TN3K_val_DICE": TN3K_total_DICE_val,
        
    })




checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    "IOU" : total_IOU_val,
    "DICE" : total_DICE_val,
    "loss" : total_loss_val
}
if not os.path.exists(folder):
    os.mkdir(folder)
torch.save(checkpoint, f"models/{name}/last_checkpoint.pth")
wandb.finish()

epoch : 0


100%|██████████| 51/51 [00:21<00:00,  2.32it/s]


train loss : 1.6259741689644607, train IOU : 0.455181777769444, train DICE : 0.5526097662308637


100%|██████████| 10/10 [00:04<00:00,  2.35it/s]


val loss : 0.8756903886795044, val IOU : 0.4340920329093933, val DICE : 0.5791772603988647
epoch : 1


100%|██████████| 51/51 [00:18<00:00,  2.81it/s]


train loss : 1.0214434008972317, train IOU : 0.636589527130127, train DICE : 0.7737167606166765


100%|██████████| 10/10 [00:02<00:00,  3.65it/s]


val loss : 0.7841041922569275, val IOU : 0.5050180196762085, val DICE : 0.66857990026474
epoch : 2


100%|██████████| 51/51 [00:18<00:00,  2.80it/s]


train loss : 0.8796161109325933, train IOU : 0.6862183026238984, train DICE : 0.8156481887779984


100%|██████████| 10/10 [00:02<00:00,  3.91it/s]


val loss : 0.811805248260498, val IOU : 0.4916646659374237, val DICE : 0.6716313004493714
epoch : 3


100%|██████████| 51/51 [00:18<00:00,  2.74it/s]


train loss : 0.7786615549349317, train IOU : 0.723206100510616, train DICE : 0.843464151316998


100%|██████████| 10/10 [00:02<00:00,  3.68it/s]


val loss : 0.7068629920482635, val IOU : 0.5677817404270172, val DICE : 0.7276856660842895
epoch : 4


100%|██████████| 51/51 [00:18<00:00,  2.78it/s]


train loss : 0.7189277796184316, train IOU : 0.7437901286517873, train DICE : 0.8579946347311431


100%|██████████| 10/10 [00:02<00:00,  3.64it/s]


val loss : 0.8032983303070068, val IOU : 0.5195952862501144, val DICE : 0.681363558769226
epoch : 5


100%|██████████| 51/51 [00:18<00:00,  2.82it/s]


train loss : 0.6693949827960893, train IOU : 0.761727883535273, train DICE : 0.8701829512914022


100%|██████████| 10/10 [00:02<00:00,  3.60it/s]


val loss : 0.7146992862224579, val IOU : 0.5489626407623291, val DICE : 0.709788852930069
epoch : 6


100%|██████████| 51/51 [00:18<00:00,  2.80it/s]


train loss : 0.6258333886370939, train IOU : 0.7772558892474455, train DICE : 0.8800138024722829


100%|██████████| 10/10 [00:02<00:00,  3.62it/s]


val loss : 0.7245767295360566, val IOU : 0.5650387167930603, val DICE : 0.727181875705719
epoch : 7


100%|██████████| 51/51 [00:18<00:00,  2.81it/s]


train loss : 0.5915368526589637, train IOU : 0.7886306351306391, train DICE : 0.8876459773849038


100%|██████████| 10/10 [00:02<00:00,  3.59it/s]


val loss : 0.6486891210079193, val IOU : 0.6091369211673736, val DICE : 0.7749590575695038
epoch : 8


100%|██████████| 51/51 [00:18<00:00,  2.80it/s]


train loss : 0.5777136251038196, train IOU : 0.7932960110552171, train DICE : 0.8912825116924211


100%|██████████| 10/10 [00:02<00:00,  3.55it/s]


val loss : 0.6708660304546357, val IOU : 0.586253571510315, val DICE : 0.7513015925884247
epoch : 9


100%|██████████| 51/51 [00:18<00:00,  2.80it/s]


train loss : 0.5324719904684553, train IOU : 0.8102447273684483, train DICE : 0.9017654248312408


100%|██████████| 10/10 [00:02<00:00,  3.89it/s]


val loss : 0.7250673055648804, val IOU : 0.5789863586425781, val DICE : 0.7482457160949707
epoch : 10


100%|██████████| 51/51 [00:18<00:00,  2.74it/s]


train loss : 0.4923776581006892, train IOU : 0.8239804249183804, train DICE : 0.9097941517829895


100%|██████████| 10/10 [00:02<00:00,  3.69it/s]


val loss : 0.6103137195110321, val IOU : 0.6373073518276214, val DICE : 0.7861678600311279
epoch : 11


100%|██████████| 51/51 [00:15<00:00,  3.37it/s]


train loss : 0.46720126212811935, train IOU : 0.8315794841915953, train DICE : 0.9165112329464332


100%|██████████| 10/10 [00:02<00:00,  4.06it/s]


val loss : 0.6294735431671142, val IOU : 0.632767629623413, val DICE : 0.7989810049533844
epoch : 12


100%|██████████| 51/51 [00:14<00:00,  3.45it/s]


train loss : 0.4499847661046421, train IOU : 0.8380334505847856, train DICE : 0.919951386311475


100%|██████████| 10/10 [00:02<00:00,  4.07it/s]


val loss : 0.6024244427680969, val IOU : 0.6574016511440277, val DICE : 0.8107283294200898
epoch : 13


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.4209863870751624, train IOU : 0.8476518240629458, train DICE : 0.9260203066994163


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


val loss : 0.576460188627243, val IOU : 0.6666791677474976, val DICE : 0.8161856472492218
epoch : 14


100%|██████████| 51/51 [00:14<00:00,  3.45it/s]


train loss : 0.40079152175024446, train IOU : 0.85558441807242, train DICE : 0.9300818747165156


100%|██████████| 10/10 [00:02<00:00,  4.09it/s]


val loss : 0.6322009086608886, val IOU : 0.6458377122879029, val DICE : 0.806236183643341
epoch : 15


100%|██████████| 51/51 [00:14<00:00,  3.45it/s]


train loss : 0.37041207037720025, train IOU : 0.8664428206051097, train DICE : 0.9366544812333351


100%|██████████| 10/10 [00:02<00:00,  4.05it/s]


val loss : 0.6022589445114136, val IOU : 0.6691635251045227, val DICE : 0.8133502185344696
epoch : 16


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.3561120898115869, train IOU : 0.8716841153070038, train DICE : 0.939565099921881


100%|██████████| 10/10 [00:02<00:00,  3.65it/s]


val loss : 0.6478517591953278, val IOU : 0.6719067752361297, val DICE : 0.8235528945922852
epoch : 17


100%|██████████| 51/51 [00:14<00:00,  3.43it/s]


train loss : 0.34800341959093134, train IOU : 0.8744750911114263, train DICE : 0.9405774439082426


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


val loss : 0.7398828029632568, val IOU : 0.6376881897449493, val DICE : 0.7989899694919587
epoch : 18


100%|██████████| 51/51 [00:14<00:00,  3.45it/s]


train loss : 0.3315005886788462, train IOU : 0.8805671334266663, train DICE : 0.9440760916354609


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


val loss : 0.5797177612781524, val IOU : 0.6871443450450897, val DICE : 0.8260057091712951
epoch : 19


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.302651350404702, train IOU : 0.8909039403878006, train DICE : 0.9498343397589291


100%|██████████| 10/10 [00:02<00:00,  4.02it/s]


val loss : 0.6156498193740845, val IOU : 0.6793666303157806, val DICE : 0.8253415822982788
epoch : 20


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.2835715150131899, train IOU : 0.8974059247503093, train DICE : 0.9530825018882751


100%|██████████| 10/10 [00:02<00:00,  4.03it/s]


val loss : 0.7013170897960663, val IOU : 0.6787371933460236, val DICE : 0.8303032040596008
epoch : 21


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.2718540506035674, train IOU : 0.9015910976073321, train DICE : 0.9553935258996253


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


val loss : 0.6468005061149598, val IOU : 0.6884904205799103, val DICE : 0.8344435453414917
epoch : 22


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.26704789464380224, train IOU : 0.9024680467212901, train DICE : 0.9562007142048256


100%|██████████| 10/10 [00:02<00:00,  4.06it/s]


val loss : 0.7677481532096863, val IOU : 0.6738708853721619, val DICE : 0.8255382299423217
epoch : 23


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.25013958063780095, train IOU : 0.9087405111275467, train DICE : 0.9591578665901633


100%|██████████| 10/10 [00:02<00:00,  4.06it/s]


val loss : 0.7183911442756653, val IOU : 0.6910272061824798, val DICE : 0.8348344624042511
epoch : 24


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.23915040697537215, train IOU : 0.912124307716594, train DICE : 0.9612403476939482


100%|██████████| 10/10 [00:02<00:00,  4.05it/s]


val loss : 0.7371488094329834, val IOU : 0.6873298525810242, val DICE : 0.839188140630722
epoch : 25


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.22233281971192828, train IOU : 0.9183342819120369, train DICE : 0.9640903239156685


100%|██████████| 10/10 [00:02<00:00,  4.08it/s]


val loss : 0.7488954246044159, val IOU : 0.7000667154788971, val DICE : 0.8387617349624634
epoch : 26


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.2074470113889844, train IOU : 0.924111126684675, train DICE : 0.9666942439827264


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


val loss : 0.7353555619716644, val IOU : 0.7108044862747193, val DICE : 0.8465182304382324
epoch : 27


100%|██████████| 51/51 [00:14<00:00,  3.43it/s]


train loss : 0.19649244172900332, train IOU : 0.9278280735015869, train DICE : 0.9686649745585871


100%|██████████| 10/10 [00:02<00:00,  4.05it/s]


val loss : 0.7495669305324555, val IOU : 0.7106425881385803, val DICE : 0.8484707295894622
epoch : 28


100%|██████████| 51/51 [00:14<00:00,  3.43it/s]


train loss : 0.19047048132793576, train IOU : 0.9292997554236767, train DICE : 0.9695277038742515


100%|██████████| 10/10 [00:02<00:00,  4.08it/s]


val loss : 0.9163150548934936, val IOU : 0.6917639255523682, val DICE : 0.8323816359043121
epoch : 29


100%|██████████| 51/51 [00:14<00:00,  3.40it/s]


train loss : 0.18253072511916066, train IOU : 0.932635414834116, train DICE : 0.9710417109377244


100%|██████████| 10/10 [00:02<00:00,  4.08it/s]


val loss : 0.8077127039432526, val IOU : 0.7172720313072205, val DICE : 0.8512865483760834
epoch : 30


100%|██████████| 51/51 [00:14<00:00,  3.44it/s]


train loss : 0.1756596790224898, train IOU : 0.9346261550398434, train DICE : 0.9720988378805273


100%|██████████| 10/10 [00:02<00:00,  3.86it/s]


val loss : 0.8386175870895386, val IOU : 0.7123576581478119, val DICE : 0.845290207862854
epoch : 31


100%|██████████| 51/51 [00:16<00:00,  3.14it/s]


train loss : 0.1622009549070807, train IOU : 0.9397737839642692, train DICE : 0.9743716868699765


100%|██████████| 10/10 [00:02<00:00,  3.89it/s]


val loss : 0.9080664694309235, val IOU : 0.7168885111808777, val DICE : 0.8516594350337983
epoch : 32


100%|██████████| 51/51 [00:16<00:00,  3.05it/s]


train loss : 0.156896693741574, train IOU : 0.9417444327298332, train DICE : 0.9752361225146874


100%|██████████| 10/10 [00:02<00:00,  3.54it/s]


val loss : 0.9116718888282775, val IOU : 0.7191061675548553, val DICE : 0.8525829792022706
epoch : 33


100%|██████████| 51/51 [00:17<00:00,  2.97it/s]


train loss : 0.14972570626174703, train IOU : 0.9442577303624621, train DICE : 0.9764483091877956


100%|██████████| 10/10 [00:02<00:00,  3.50it/s]


val loss : 0.9706218659877777, val IOU : 0.7200993418693542, val DICE : 0.8525216221809387
epoch : 34


100%|██████████| 51/51 [00:17<00:00,  2.93it/s]


train loss : 0.1443133748629514, train IOU : 0.9464981322195015, train DICE : 0.9773210590960932


100%|██████████| 10/10 [00:02<00:00,  3.92it/s]


val loss : 0.9289769768714905, val IOU : 0.7260763823986054, val DICE : 0.8539708733558655
epoch : 35


100%|██████████| 51/51 [00:16<00:00,  3.07it/s]


train loss : 0.13721726557203368, train IOU : 0.9487211318577037, train DICE : 0.9785749292841145


100%|██████████| 10/10 [00:02<00:00,  3.85it/s]


val loss : 1.0427220344543457, val IOU : 0.7186409771442414, val DICE : 0.8527248501777649
epoch : 36


100%|██████████| 51/51 [00:16<00:00,  3.01it/s]


train loss : 0.13078558459585787, train IOU : 0.9511849552977318, train DICE : 0.9795133787042954


100%|██████████| 10/10 [00:02<00:00,  3.68it/s]


val loss : 1.1033837437629699, val IOU : 0.7207356095314026, val DICE : 0.8540454864501953
epoch : 37


100%|██████████| 51/51 [00:17<00:00,  2.88it/s]


train loss : 0.12611699323443806, train IOU : 0.9527813593546549, train DICE : 0.9804309583177754


100%|██████████| 10/10 [00:02<00:00,  3.47it/s]


val loss : 1.0945828199386596, val IOU : 0.7273571908473968, val DICE : 0.856674325466156
epoch : 38


100%|██████████| 51/51 [00:16<00:00,  3.10it/s]


train loss : 0.12401853604059593, train IOU : 0.9536855396102456, train DICE : 0.9806941478860145


100%|██████████| 10/10 [00:02<00:00,  3.81it/s]


val loss : 1.1606201827526093, val IOU : 0.7243730902671814, val DICE : 0.8551961600780487
epoch : 39


100%|██████████| 51/51 [00:15<00:00,  3.25it/s]


train loss : 0.1196915149396541, train IOU : 0.9552882091671813, train DICE : 0.9813637768521029


100%|██████████| 10/10 [00:02<00:00,  4.03it/s]


val loss : 1.1106695234775543, val IOU : 0.7301575839519501, val DICE : 0.8568877935409546
epoch : 40


100%|██████████| 51/51 [00:14<00:00,  3.43it/s]


train loss : 0.11546978412889967, train IOU : 0.9572480250807369, train DICE : 0.9819404751646752


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


val loss : 1.1382209599018096, val IOU : 0.72967050075531, val DICE : 0.857021301984787
epoch : 41


100%|██████████| 51/51 [00:14<00:00,  3.43it/s]


train loss : 0.11139695580099143, train IOU : 0.9585748957652672, train DICE : 0.9826141946456012


100%|██████████| 10/10 [00:02<00:00,  4.04it/s]


val loss : 1.201300472021103, val IOU : 0.7293798446655273, val DICE : 0.8569592416286469
epoch : 42


100%|██████████| 51/51 [00:14<00:00,  3.43it/s]


train loss : 0.10890956208402035, train IOU : 0.9597657357945162, train DICE : 0.9830218939220204


100%|██████████| 10/10 [00:02<00:00,  4.01it/s]


val loss : 1.2041770935058593, val IOU : 0.7293565809726715, val DICE : 0.8568849861621857
epoch : 43


100%|██████████| 51/51 [00:14<00:00,  3.43it/s]


train loss : 0.10746555643923142, train IOU : 0.9602713748520496, train DICE : 0.9832665639765122


100%|██████████| 10/10 [00:02<00:00,  4.08it/s]


val loss : 1.2046112596988678, val IOU : 0.7297796607017517, val DICE : 0.857206380367279
epoch : 44


 14%|█▎        | 7/51 [00:02<00:14,  3.10it/s]


KeyboardInterrupt: 

In [None]:
print(max_IOU)

In [None]:
inference_name = "PMFSNet_baseline"
# model = HarDMSEG(in_channels = 1)
model = PMFSNet(in_channels = 1, out_channels = 1, dim = "2d")
# model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
checkpoint = torch.load(f"models/{inference_name}/last_checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

In [None]:

total_loss_val, total_IOU_val, total_DICE_val = val(test_dataloader, model, loss_fn, "cuda")
print(f"val loss : {total_loss_val}, val IOU : {total_IOU_val}, val DICE : {total_DICE_val}")


In [None]:
DICE = 0
# model.to("cuda")
model.eval()
for image, mask in tqdm(test_dataloader):
    image, mask = image, mask
    preds = model(image)
    DICE += (1 - dice_loss(preds, mask)).item()
print(DICE/len(test_dataloader))

In [None]:
index = 7
plt.subplot(1,2,1)
plt.imshow(preds[index][0])
plt.subplot(1,2,2)
plt.imshow(mask[index][0])

In [None]:
dice_loss = DiceLoss()
print(1 - dice_loss(preds, mask))

In [None]:
plt.imshow(mask[index][0])