In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm import tqdm
import random
import torch.nn as nn
from dataset import Thyroid_Dataset
from model import Eff_Unet
from HarDMSEG import HarDMSEG
from loss_metric import DiceLoss, IOU_score, StructureLoss
from LightMed.model.LightMed import LightMed
from PMFSNet.lib.models.PMFSNet import PMFSNet
from PMFSNet.lib.models.PMFSNet_FFT import PMFSNet_FFT
import torchvision.transforms.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as tx
from torchvision.transforms import GaussianBlur
# from hybrid_model_v1 import HybridSegModel
from hybrid_model_v3_upsample import HybridSegModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [3]:
import wandb
wandb.login()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33ms960068sss[0m ([33ms960068sss-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
image_size = 128
batch_size = 64
def train_augmentation(image, mask, image_size):

    p = 0.5
    
    resize = T.Resize((image_size, image_size))
    image = resize(image)
    mask = resize(mask)
    
    # if(random.random() < p):
    #     jitter = T.ColorJitter(brightness=0.5)
    #     image = jitter(image)
    # if(random.random() < p):
    #     angle = random.uniform(-10, 10)  # 旋轉角度從 ±10 度
    #     translate = (random.uniform(-0.05, 0.05) * image.size[0],
    #                  random.uniform(-0.05, 0.05) * image.size[1])  # 最多平移 ±5%
    #     scale = random.uniform(0.95, 1.05)  # 尺度縮放 ±5%
    #     shear = [random.uniform(-5, 5), random.uniform(-5, 5)]  # 小幅剪切
    
    #     image = F.affine(image, angle=angle, translate=translate, scale=scale, shear=shear)
    #     mask = F.affine(mask, angle=angle, translate=translate, scale=scale, shear=shear)
    # if random.random() < p:
    #     image = F.hflip(image)
    #     mask = F.hflip(mask)
    # if(random.random() < p):
    #     transform = T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    #     image = transform(image)  # image must be a PIL image
    
    image_tensor = tx.to_tensor(image)
    mask_tensor = tx.to_tensor(mask)

    # If standardization
    mean = image_tensor.mean()
    std = image_tensor.std()
    std = std if std > 0 else 1.0  # avoid division by zero
    image_tensor = (image_tensor - mean) / std
    
    return image_tensor, mask_tensor
def test_augmentation(image, mask, image_size):
    resize = T.Resize((image_size, image_size))
    image = resize(image)
    mask = resize(mask)

    image_tensor = tx.to_tensor(image)
    mask_tensor = tx.to_tensor(mask)

    # If standardization
    mean = image_tensor.mean()
    std = image_tensor.std()
    std = std if std > 0 else 1.0  # avoid division by zero
    image_tensor = (image_tensor - mean) / std
    return image_tensor, mask_tensor

    
train_dataset = Thyroid_Dataset("train_v2.csv", transform = train_augmentation, image_size = image_size)
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)

test_dataset = Thyroid_Dataset("test_v2.csv", transform = test_augmentation, image_size = image_size, return_from_dataset = True, crop_DDTI = True, histo_match = False)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size)

image, mask, seg_type = next(iter(train_dataloader))
image = image
mask = mask

In [5]:
print("std : ", torch.std(image))
print("unique : ", torch.unique(mask))

std :  tensor(1.0000)
unique :  tensor([0., 1.])


In [6]:
# index = 39
# for index in range(5):
#     # plt.subplot(1,2,1)
#     # plt.imshow(mask[index][0], alpha = 0.3)
#     plt.imshow(image[index][0])
#     plt.imshow(mask[index][0], alpha = 0.3)
    
#     # plt.subplot(1,2,2)
#     # plt.imshow(mask[index][0])
#     plt.show()

In [7]:
# model = Eff_Unet(
#         layers=[5, 5, 15, 10],
#         embed_dims=[40, 80, 192, 384],
#         downsamples=[True, True, True, True],
#         vit_num=6,
#         drop_path_rate=0.1,
#         num_classes=1,
#         resolution = image_size).cuda()
# model = HarDMSEG(in_channels = 1)
# model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
# model = PMFSNet(in_channels = 1, out_channels = 2, dim = "2d")
# model = PMFSNet_FFT(in_channels = 1, out_channels = 2, dim = "2d")
# model = HybridSegModel(in_channels = 1, out_channels = 2, output_size = image_size, layers_num = 3)

In [8]:
# loss_fn_nodule = StructureLoss()
# loss_fn_gland = StructureLoss()

In [9]:
# model.to("cuda")
# image = image.to("cuda")
# mask = mask.to("cuda")

In [10]:
# optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
# for epoch in range(1000):
#     outputs = model(image)

#     nodule_output = outputs[:, 0:1, :, :][seg_type==1]
#     nodule_mask = mask[seg_type==1]
    
#     gland_output = outputs[:, 1:2, :, :][seg_type==2]
#     gland_mask = mask[seg_type==2]
#     # outputs = torch.sigmoid(logits)
#     # print(nodule_output.shape, nodule_mask.shape)
#     nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
#     gland_loss = loss_fn_gland(gland_output, gland_mask)
    
#     loss = nodule_loss + gland_loss
    
#     IOU = (IOU_score(nodule_output, nodule_mask) + IOU_score(gland_output, gland_mask)) / 2
    
#     dice_loss = DiceLoss()
#     # print(dice_loss(nodule_output, nodule_mask))
#     # print(dice_loss(gland_output, gland_mask)) 
#     DICE = ((1 - dice_loss(nodule_output, nodule_mask)) + (1 - dice_loss(gland_output, gland_mask))) / 2
    
#     # Backward and optimize
#     optimizer.zero_grad()   # clear previous gradients
#     loss.backward()         # compute gradients
#     optimizer.step()        # update weights

#     print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, IOU: {IOU.item():.4f}, DICE: {DICE.item()}")

In [12]:
# (1 - dice_loss(nodule_output, nodule_mask))

In [13]:
# index = 4
# plt.subplot(1,2,1)
# plt.imshow(nodule_mask[index][0].detach().cpu().numpy())
# nodule_output = (nodule_output > 0.5)
# plt.subplot(1,2,2)
# plt.imshow(nodule_output[index][0].detach().cpu().numpy())

In [14]:
# index = 4
# plt.subplot(1,2,1)
# plt.imshow(gland_mask[index][0].detach().cpu().numpy())
# gland_output = (gland_output > 0.5)
# plt.subplot(1,2,2)
# plt.imshow(gland_output[index][0].detach().cpu().numpy())

In [15]:
def train(dataloader, model, optimizer, loss_fn_nodule, loss_fn_gland, device):
    total_loss = 0
    total_IOU = 0
    total_DICE = 0
    
    model.train()
    model.to(device)
    count = 0
    for image, mask, seg_type in tqdm(dataloader):
        image, mask, seg_type = image.to(device), mask.to(device), seg_type.to(device)
        outputs = model(image)

        nodule_output = outputs[:, 0:1, :, :][seg_type==1]
        nodule_mask = mask[seg_type==1]
        
        gland_output = outputs[:, 1:2, :, :][seg_type==2]
        gland_mask = mask[seg_type==2]

        nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
        gland_loss = loss_fn_gland(gland_output, gland_mask)


        
        loss = nodule_loss + gland_loss
        IOU = (IOU_score(nodule_output, nodule_mask) + IOU_score(gland_output, gland_mask)) / 2

        dice_loss = DiceLoss()
        DICE = ((1 - dice_loss(nodule_output, nodule_mask)) + (1 - dice_loss(gland_output, gland_mask))) / 2

        
        # Backward and optimize
        optimizer.zero_grad()   # clear previous gradients
        loss.backward()         # compute gradients
        optimizer.step()        # update weights

        total_loss += loss.item()
        total_IOU += IOU.item()
        total_DICE += DICE.item()
        # count+=1
        # if count==5:
        #     break
    return total_loss/len(dataloader), total_IOU/len(dataloader), total_DICE/len(dataloader)


# 
# Only calculate nodule loss, IOU, DICE, because there is no gland data in the testing set
def val(dataloader, model, loss_fn_nodule, loss_fn_gland, device):
    total_loss = 0
    DDTI_total_loss = 0
    TN3K_total_loss = 0
    
    total_IOU = 0
    DDTI_total_IOU = 0
    TN3K_total_IOU = 0
    
    total_DICE = 0
    DDTI_total_DICE = 0
    TN3K_total_DICE = 0

    model.eval()
    model.to(device)
    for image, mask, seg_type, from_dataset in tqdm(dataloader):
        image, mask, seg_type = image.to(device), mask.to(device), seg_type.to(device)
        outputs = model(image)
        
        nodule_output = outputs[:, 0:1, :, :][seg_type==1]
        nodule_mask = mask[seg_type==1]


        DDTI_nodule_output = nodule_output[from_dataset==1]
        TN3K_nodule_output = nodule_output[from_dataset==3]

        DDTI_nodule_mask = nodule_mask[from_dataset==1]
        TN3K_nodule_mask = nodule_mask[from_dataset==3]
        

        nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
        DDTI_nodule_loss = loss_fn_nodule(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_nodule_loss = loss_fn_nodule(TN3K_nodule_output, TN3K_nodule_mask)
        
        IOU = IOU_score(nodule_output, nodule_mask)
        DDTI_IOU = IOU_score(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_IOU = IOU_score(TN3K_nodule_output, TN3K_nodule_mask)

        dice_loss = DiceLoss()
        DICE = 1 - dice_loss(nodule_output, nodule_mask)
        DDTI_DICE = 1 - dice_loss(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_DICE = 1 - dice_loss(TN3K_nodule_output, TN3K_nodule_mask)
    

        total_loss += nodule_loss.item()
        DDTI_total_loss += DDTI_nodule_loss.item()
        TN3K_total_loss += TN3K_nodule_loss.item()
        
        total_IOU += IOU.item()
        DDTI_total_IOU += DDTI_IOU.item()
        TN3K_total_IOU += TN3K_IOU.item()
        
        total_DICE += DICE.item()
        DDTI_total_DICE += DDTI_DICE.item()
        TN3K_total_DICE += TN3K_DICE.item()
    return total_loss/len(dataloader), total_IOU/len(dataloader), total_DICE/len(dataloader), DDTI_total_loss/len(dataloader), DDTI_total_IOU/len(dataloader), DDTI_total_DICE/len(dataloader), TN3K_total_loss/len(dataloader), TN3K_total_IOU/len(dataloader), TN3K_total_DICE/len(dataloader), 

In [16]:
epochs = 50
lr = 0.001
project = "thyroid_hybrid_model"
# name=name="PMFSNet_crop_DDTI_standardization_aug_affine(0.5)_hflip(0.5)_lr_0.001"
# name = "HarDnet_crop_DDTI_standardization_aug_affine(0.5)_lr_0.005"
# name = "LightMed_crop_DDTI"
# name = "test"
# name = "HarDnetMSEG_baseline"
name = "hybrid_v3_upsample_baseline"
wandb.init(
    project = project,  # Project name in W&B
    name = name,       # (optional) specific run name
    config={                     # (optional) hyperparameters
        "image_size": image_size,
        "learning_rate": lr,
        "epochs": epochs,
        "batch_size": batch_size,
        # "weight_decay":1e-4
    }
)

In [17]:
# model = Eff_Unet(
#         layers=[5, 5, 15, 10],
#         embed_dims=[10, 20, 48, 96],
#         # embed_dims=[40, 80, 192, 384],
#         downsamples=[True, True, True, True],
#         vit_num=6,
#         drop_path_rate=0.1,
#         num_classes=1,
#         resolution = image_size).cuda()
# model = HarDMSEG(in_channels = 1, out_channels = 2)
# model = LightMed(in_channels = 1, out_channels = 2, image_size = image_size)
# model = PMFSNet(in_channels = 1, out_channels = 2, dim = "2d")
# model = PMFSNet_FFT(in_channels = 1, out_channels = 2, dim = "2d")
model =  HybridSegModel(in_channels = 1, out_channels = 2, output_size = image_size, layers_num = 3)

In [18]:
print(f"image size : {image_size}, lr : {lr}, epochs : {epochs}, batch size : {batch_size}, model : {model}")

image size : 128, lr : 0.001, epochs : 50, batch size : 64, model : HybridSegModel(
  (backbone): HarDNetBackbone(
    (base_conv_1): ConvLayer(
      (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU6(inplace=True)
    )
    (base_conv_2): ConvLayer(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU6(inplace=True)
    )
    (base_max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (encoder_blocks): ModuleList(
      (0): EncoderBlock(
        (hardblock): HarDBlock(
          (layers): ModuleList(
            (0): ConvLayer(
              (conv): Conv2d(64, 14, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm): Batch

### If using pretrained model

In [19]:
# pretrained_name = "PMFSNet_baseline"
# # model = HarDMSEG(in_channels = 1)
# model = PMFSNet(in_channels = 1, out_channels = 1, dim = "2d")
# # model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
# checkpoint = torch.load(f"models/{pretrained_name}/best_checkpoint.pth")
# model.load_state_dict(checkpoint['model_state_dict'])

In [20]:
# loss_fn = DiceLoss()
loss_fn_nodule = StructureLoss()
loss_fn_gland = StructureLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [21]:
max_IOU = -1
for epoch in range(epochs):
    print(f"epoch : {epoch}")
    total_loss_train, total_IOU_train, total_DICE_train = train(train_dataloader, model, optimizer, loss_fn_nodule, loss_fn_gland, "cuda")
    print(f"train loss : {total_loss_train}, train IOU : {total_IOU_train}, train DICE : {total_DICE_train}")
    total_loss_val, total_IOU_val, total_DICE_val, DDTI_total_loss_val, DDTI_total_IOU_val, DDTI_total_DICE_val, TN3K_total_loss_val, TN3K_total_IOU_val, TN3K_total_DICE_val = val(test_dataloader, model, loss_fn_nodule, loss_fn_gland, "cuda")
    print(f"val loss : {total_loss_val}, val IOU : {total_IOU_val}, val DICE : {total_DICE_val}")
    
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]



    if max_IOU < total_IOU_val:
        max_IOU = total_IOU_val
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            "IOU" : total_IOU_val,
            "DICE" : total_DICE_val,
            "loss" : total_loss_val
        }
        folder = f"models/{name}"
        if not os.path.exists(folder):
            os.mkdir(folder)
        torch.save(checkpoint, f"models/{name}/best_checkpoint.pth")

    
    wandb.log({
        "epoch": epoch,
        "Learning Rate":current_lr,
        
        "train_loss": total_loss_train,
        "train_IOU": total_IOU_train,
        "train_DICE": total_DICE_train,
        
        "val_loss": total_loss_val,
        "val_IOU": total_IOU_val,
        "val_DICE": total_DICE_val,

        "DDTI_val_loss": DDTI_total_loss_val,
        "DDTI_val_IOU": DDTI_total_IOU_val,
        "DDTI_val_DICE": DDTI_total_DICE_val,

        "TN3K_val_loss": TN3K_total_loss_val,
        "TN3K_val_IOU": TN3K_total_IOU_val,
        "TN3K_val_DICE": TN3K_total_DICE_val,
        
    })




checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    "IOU" : total_IOU_val,
    "DICE" : total_DICE_val,
    "loss" : total_loss_val
}
if not os.path.exists(folder):
    os.mkdir(folder)
torch.save(checkpoint, f"models/{name}/last_checkpoint.pth")
wandb.finish()

epoch : 0


100%|██████████| 101/101 [01:18<00:00,  1.28it/s]


train loss : 1.285356316235986, train IOU : 0.5534881978747573, train DICE : 0.6806488426605074


100%|██████████| 20/20 [00:09<00:00,  2.05it/s]


val loss : 0.8283487975597381, val IOU : 0.47170667350292206, val DICE : 0.6437902420759201
epoch : 1


100%|██████████| 101/101 [00:18<00:00,  5.37it/s]


train loss : 0.9482538930260309, train IOU : 0.659552006438227, train DICE : 0.7950005560818285


100%|██████████| 20/20 [00:02<00:00,  7.38it/s]


val loss : 0.7260559886693955, val IOU : 0.536646680533886, val DICE : 0.7005877524614335
epoch : 2


100%|██████████| 101/101 [00:18<00:00,  5.35it/s]


train loss : 0.829348893448858, train IOU : 0.7041357448785612, train DICE : 0.8254050759985896


100%|██████████| 20/20 [00:02<00:00,  7.40it/s]


val loss : 0.6918688863515854, val IOU : 0.5594232261180878, val DICE : 0.7122121185064316
epoch : 3


100%|██████████| 101/101 [00:18<00:00,  5.35it/s]


train loss : 0.7427708719036367, train IOU : 0.7331751561400914, train DICE : 0.8482421093648023


100%|██████████| 20/20 [00:02<00:00,  7.56it/s]


val loss : 0.6468048006296158, val IOU : 0.5916127309203147, val DICE : 0.7489693939685822
epoch : 4


100%|██████████| 101/101 [00:18<00:00,  5.37it/s]


train loss : 0.686064358394925, train IOU : 0.7537501106167784, train DICE : 0.8635201985293096


100%|██████████| 20/20 [00:02<00:00,  7.54it/s]


val loss : 0.6438147515058518, val IOU : 0.5945205062627792, val DICE : 0.7543130695819855
epoch : 5


100%|██████████| 101/101 [00:18<00:00,  5.44it/s]


train loss : 0.6263367102877928, train IOU : 0.7756139335065785, train DICE : 0.877552339346102


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.6517310529947281, val IOU : 0.6040982365608215, val DICE : 0.7613282471895217
epoch : 6


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.5825876655555008, train IOU : 0.7916907573690509, train DICE : 0.8877710458075646


100%|██████████| 20/20 [00:02<00:00,  7.58it/s]


val loss : 0.7291287004947662, val IOU : 0.558964017033577, val DICE : 0.7402098685503006
epoch : 7


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.5515437391724917, train IOU : 0.8024789950635174, train DICE : 0.894392202986349


100%|██████████| 20/20 [00:02<00:00,  7.61it/s]


val loss : 0.5547318264842034, val IOU : 0.6551604062318802, val DICE : 0.7998262345790863
epoch : 8


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.5239285437187345, train IOU : 0.8121552467346191, train DICE : 0.9013856525468354


100%|██████████| 20/20 [00:02<00:00,  7.58it/s]


val loss : 0.587027621269226, val IOU : 0.6371447950601578, val DICE : 0.7795161128044128
epoch : 9


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.4914483321775304, train IOU : 0.8235145234825587, train DICE : 0.9089510499840916


100%|██████████| 20/20 [00:02<00:00,  7.57it/s]


val loss : 0.6241381108760834, val IOU : 0.610661256313324, val DICE : 0.7711813598871231
epoch : 10


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.4763409161921775, train IOU : 0.8289167426600315, train DICE : 0.9118518616893504


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.6079150184988975, val IOU : 0.6349625170230866, val DICE : 0.7794553756713867
epoch : 11


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.44612565459591325, train IOU : 0.8400373553285504, train DICE : 0.9189791042025727


100%|██████████| 20/20 [00:02<00:00,  7.61it/s]


val loss : 0.658744877576828, val IOU : 0.625190943479538, val DICE : 0.7775425106287003
epoch : 12


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.423157806443696, train IOU : 0.8482258249037337, train DICE : 0.9234769161384885


100%|██████████| 20/20 [00:02<00:00,  7.62it/s]


val loss : 0.5387239381670952, val IOU : 0.6783502489328385, val DICE : 0.8212381482124329
epoch : 13


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.39067366040579166, train IOU : 0.8592152890592518, train DICE : 0.9308091319433534


100%|██████████| 20/20 [00:02<00:00,  7.64it/s]


val loss : 0.5516443595290184, val IOU : 0.6781838715076447, val DICE : 0.79954574406147
epoch : 14


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.3700290109851573, train IOU : 0.8668713734881712, train DICE : 0.9353267804230794


100%|██████████| 20/20 [00:02<00:00,  7.60it/s]


val loss : 0.596599668264389, val IOU : 0.6612340867519378, val DICE : 0.7861560463905335
epoch : 15


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.34424441610232437, train IOU : 0.8762624983740325, train DICE : 0.9401517321567724


100%|██████████| 20/20 [00:02<00:00,  7.64it/s]


val loss : 0.526800000667572, val IOU : 0.694926318526268, val DICE : 0.8319249421358108
epoch : 16


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.31952053662573937, train IOU : 0.8844271862860953, train DICE : 0.945702857310229


100%|██████████| 20/20 [00:02<00:00,  7.55it/s]


val loss : 0.5442569926381111, val IOU : 0.6976105004549027, val DICE : 0.8343873709440232
epoch : 17


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.29537695823329513, train IOU : 0.8933281585721686, train DICE : 0.9506231463781678


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.6552057325839996, val IOU : 0.674617412686348, val DICE : 0.8262146472930908
epoch : 18


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.2757848992501155, train IOU : 0.9001975513920926, train DICE : 0.9537466428067425


100%|██████████| 20/20 [00:02<00:00,  7.57it/s]


val loss : 0.6026426300406456, val IOU : 0.7022740811109542, val DICE : 0.834193542599678
epoch : 19


100%|██████████| 101/101 [00:18<00:00,  5.58it/s]


train loss : 0.25843285761847357, train IOU : 0.9065186546580626, train DICE : 0.9570484397434952


100%|██████████| 20/20 [00:02<00:00,  7.61it/s]


val loss : 0.6489278271794319, val IOU : 0.6950477659702301, val DICE : 0.8395493656396866
epoch : 20


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.2429286623650258, train IOU : 0.9116520155774485, train DICE : 0.9602299905059362


100%|██████████| 20/20 [00:02<00:00,  7.60it/s]


val loss : 0.7379490956664085, val IOU : 0.6775136351585388, val DICE : 0.8234014689922333
epoch : 21


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.23110585003206047, train IOU : 0.9157803176653267, train DICE : 0.9619958318106019


100%|██████████| 20/20 [00:02<00:00,  7.63it/s]


val loss : 0.7168415993452072, val IOU : 0.6982773840427399, val DICE : 0.8376651793718338
epoch : 22


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.21821187982464782, train IOU : 0.9202465816299514, train DICE : 0.9643894664131769


100%|██████████| 20/20 [00:02<00:00,  7.63it/s]


val loss : 0.6750083804130554, val IOU : 0.7201738804578781, val DICE : 0.8499150067567826
epoch : 23


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.20473379737669878, train IOU : 0.9246331825114713, train DICE : 0.966698677823095


100%|██████████| 20/20 [00:02<00:00,  7.60it/s]


val loss : 0.7372164830565453, val IOU : 0.7054192006587983, val DICE : 0.8455593019723893
epoch : 24


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.197249306015449, train IOU : 0.9272661150091945, train DICE : 0.9681009473186908


100%|██████████| 20/20 [00:02<00:00,  7.61it/s]


val loss : 0.7825571894645691, val IOU : 0.7068104654550552, val DICE : 0.8477076351642608
epoch : 25


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.18960314487466717, train IOU : 0.9297374616755117, train DICE : 0.9696068327025612


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.6360206589102745, val IOU : 0.7286186307668686, val DICE : 0.8545337229967117
epoch : 26


100%|██████████| 101/101 [00:18<00:00,  5.54it/s]


train loss : 0.1736598560715666, train IOU : 0.9355127958968135, train DICE : 0.9722978803190855


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.7417069479823113, val IOU : 0.7263279974460601, val DICE : 0.8572685688734054
epoch : 27


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.1661550932296432, train IOU : 0.9381738499839707, train DICE : 0.9736156233466498


100%|██████████| 20/20 [00:02<00:00,  7.55it/s]


val loss : 0.8689594328403473, val IOU : 0.7066820800304413, val DICE : 0.8475569814443589
epoch : 28


100%|██████████| 101/101 [00:18<00:00,  5.54it/s]


train loss : 0.15406319130175183, train IOU : 0.9425250297725791, train DICE : 0.9756894979146448


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.7186224684119225, val IOU : 0.7397445648908615, val DICE : 0.8636872977018356
epoch : 29


100%|██████████| 101/101 [00:18<00:00,  5.58it/s]


train loss : 0.14696022204243311, train IOU : 0.9444949774458857, train DICE : 0.9770016245322652


100%|██████████| 20/20 [00:02<00:00,  7.62it/s]


val loss : 0.8702337354421615, val IOU : 0.7209880262613296, val DICE : 0.8536013841629029
epoch : 30


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.13965504845180135, train IOU : 0.9476504756672548, train DICE : 0.978033694300321


100%|██████████| 20/20 [00:02<00:00,  7.58it/s]


val loss : 0.8411388665437698, val IOU : 0.732870128750801, val DICE : 0.8609986513853073
epoch : 31


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.13056525868354457, train IOU : 0.9508112851936038, train DICE : 0.9795676216040508


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.8404191225767136, val IOU : 0.7364051878452301, val DICE : 0.862163645029068
epoch : 32


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.1260612072183354, train IOU : 0.9525151211436432, train DICE : 0.9803072521001985


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 0.9619240581989288, val IOU : 0.7270538687705994, val DICE : 0.8556590259075165
epoch : 33


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.12041233568498404, train IOU : 0.954413098273891, train DICE : 0.9811611488313958


100%|██████████| 20/20 [00:02<00:00,  7.62it/s]


val loss : 0.9630084156990051, val IOU : 0.73413844704628, val DICE : 0.8601498812437057
epoch : 34


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.11237987101373106, train IOU : 0.9574124470795735, train DICE : 0.9824889337662424


100%|██████████| 20/20 [00:02<00:00,  7.67it/s]


val loss : 0.9663858324289322, val IOU : 0.7397847950458527, val DICE : 0.8637845844030381
epoch : 35


100%|██████████| 101/101 [00:18<00:00,  5.54it/s]


train loss : 0.10845707932321152, train IOU : 0.9589429379689811, train DICE : 0.9830632445835831


100%|██████████| 20/20 [00:02<00:00,  7.58it/s]


val loss : 1.0086657255887985, val IOU : 0.7426986217498779, val DICE : 0.8655901610851288
epoch : 36


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.10328614888804974, train IOU : 0.9610833785321453, train DICE : 0.983927598093996


100%|██████████| 20/20 [00:02<00:00,  7.63it/s]


val loss : 0.9822363227605819, val IOU : 0.7444828420877456, val DICE : 0.8662304818630219
epoch : 37


100%|██████████| 101/101 [00:18<00:00,  5.52it/s]


train loss : 0.09880282973298932, train IOU : 0.9626056551933289, train DICE : 0.984626041780604


100%|██████████| 20/20 [00:02<00:00,  7.62it/s]


val loss : 1.0728956490755082, val IOU : 0.7426381796598435, val DICE : 0.8647291541099549
epoch : 38


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.0937787931744415, train IOU : 0.9645046326193479, train DICE : 0.985445777968605


100%|██████████| 20/20 [00:02<00:00,  7.57it/s]


val loss : 1.0690275430679321, val IOU : 0.7436597853899002, val DICE : 0.8660279422998428
epoch : 39


100%|██████████| 101/101 [00:18<00:00,  5.57it/s]


train loss : 0.09034211771322949, train IOU : 0.9657048270253852, train DICE : 0.9859890625028327


100%|██████████| 20/20 [00:02<00:00,  7.67it/s]


val loss : 1.0891041964292527, val IOU : 0.7487425833940506, val DICE : 0.8676401674747467
epoch : 40


100%|██████████| 101/101 [00:18<00:00,  5.53it/s]


train loss : 0.0869433721988508, train IOU : 0.9672053922521006, train DICE : 0.9864923328456312


100%|██████████| 20/20 [00:02<00:00,  7.57it/s]


val loss : 1.1339978218078612, val IOU : 0.748973286151886, val DICE : 0.8680414497852326
epoch : 41


100%|██████████| 101/101 [00:18<00:00,  5.52it/s]


train loss : 0.08388376154816977, train IOU : 0.9685374958680408, train DICE : 0.9870030065574268


100%|██████████| 20/20 [00:02<00:00,  7.60it/s]


val loss : 1.1675984561443329, val IOU : 0.7494499623775482, val DICE : 0.8677129834890366
epoch : 42


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.08118721407533873, train IOU : 0.9694725021277324, train DICE : 0.9874425807801803


100%|██████████| 20/20 [00:02<00:00,  7.67it/s]


val loss : 1.1958944499492645, val IOU : 0.7491907924413681, val DICE : 0.8675025045871735
epoch : 43


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.07839410936478342, train IOU : 0.9707423442661172, train DICE : 0.9878575860863865


100%|██████████| 20/20 [00:02<00:00,  7.60it/s]


val loss : 1.2054522186517715, val IOU : 0.7507305026054383, val DICE : 0.8682935506105423
epoch : 44


100%|██████████| 101/101 [00:18<00:00,  5.55it/s]


train loss : 0.07666465367125992, train IOU : 0.9713945341582345, train DICE : 0.9881281528142419


100%|██████████| 20/20 [00:02<00:00,  7.60it/s]


val loss : 1.2578779995441436, val IOU : 0.7493494927883149, val DICE : 0.8676337748765945
epoch : 45


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.07562870384737996, train IOU : 0.971786624724322, train DICE : 0.9882923518076981


100%|██████████| 20/20 [00:02<00:00,  7.64it/s]


val loss : 1.2578438609838485, val IOU : 0.7501315683126449, val DICE : 0.8680261313915253
epoch : 46


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.07459201538326717, train IOU : 0.9722408852954902, train DICE : 0.9884687012965137


100%|██████████| 20/20 [00:02<00:00,  7.57it/s]


val loss : 1.2550379902124404, val IOU : 0.7506937116384507, val DICE : 0.8681362688541412
epoch : 47


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.07368332051699704, train IOU : 0.9726602559042449, train DICE : 0.9886211127337843


100%|██████████| 20/20 [00:02<00:00,  7.61it/s]


val loss : 1.2581817597150802, val IOU : 0.7514533460140228, val DICE : 0.8685450971126556
epoch : 48


100%|██████████| 101/101 [00:18<00:00,  5.53it/s]


train loss : 0.07341946218863572, train IOU : 0.9726770496604467, train DICE : 0.9886275795426699


100%|██████████| 20/20 [00:02<00:00,  7.58it/s]


val loss : 1.2728574395179748, val IOU : 0.7504221498966217, val DICE : 0.867884686589241
epoch : 49


100%|██████████| 101/101 [00:18<00:00,  5.56it/s]


train loss : 0.0730737428558935, train IOU : 0.9728870769538501, train DICE : 0.9887151446672949


100%|██████████| 20/20 [00:02<00:00,  7.59it/s]


val loss : 1.2830171018838883, val IOU : 0.7504754453897476, val DICE : 0.8680468410253525


0,1
DDTI_val_DICE,▁▄▄▅▅▃▇▇▄▅█▆▅██▇▇▅▆▇▆█▇▆█▇▇▇▇▇█▇▇▇▇▇▇▇▇▇
DDTI_val_IOU,▁▄▅▆▅▂█▇▄▃▅██▅▆▄▆▇▅▆▇▅▇▆▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇
DDTI_val_loss,▂▁▁▁▁▂▁▁▂▁▁▁▂▁▁▂▂▃▃▃▄▃▄▃▅▄▅▅▅▆▆▆▆▇▇▇████
Learning Rate,███████▇▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁
TN3K_val_DICE,▁▂▃▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇█▇██████████████████
TN3K_val_IOU,▁▂▃▃▃▄▅▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███████████████
TN3K_val_loss,█▇▆▆▆▆▄▅▄▅▄▃▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
train_DICE,▁▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████████████
train_IOU,▁▃▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████

0,1
DDTI_val_DICE,0.70435
DDTI_val_IOU,0.53437
DDTI_val_loss,2.47729
Learning Rate,0.0
TN3K_val_DICE,0.98916
TN3K_val_IOU,0.97293
TN3K_val_loss,0.04381
epoch,49.0
train_DICE,0.98872
train_IOU,0.97289


In [22]:
print(max_IOU)

0.7514533460140228


In [23]:
inference_name = "PMFSNet_baseline"
# model = HarDMSEG(in_channels = 1)
model = PMFSNet(in_channels = 1, out_channels = 1, dim = "2d")
# model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
checkpoint = torch.load(f"models/{inference_name}/last_checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

FileNotFoundError: [Errno 2] No such file or directory: 'models/PMFSNet_baseline/last_checkpoint.pth'

In [None]:

total_loss_val, total_IOU_val, total_DICE_val = val(test_dataloader, model, loss_fn, "cuda")
print(f"val loss : {total_loss_val}, val IOU : {total_IOU_val}, val DICE : {total_DICE_val}")


In [None]:
DICE = 0
# model.to("cuda")
model.eval()
for image, mask in tqdm(test_dataloader):
    image, mask = image, mask
    preds = model(image)
    DICE += (1 - dice_loss(preds, mask)).item()
print(DICE/len(test_dataloader))

In [None]:
index = 7
plt.subplot(1,2,1)
plt.imshow(preds[index][0])
plt.subplot(1,2,2)
plt.imshow(mask[index][0])

In [None]:
dice_loss = DiceLoss()
print(1 - dice_loss(preds, mask))

In [None]:
plt.imshow(mask[index][0])