In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
from torchvision import transforms
import pandas as pd
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm import tqdm
import random
import torch.nn as nn
from dataset import Thyroid_Dataset
from model import Eff_Unet
from HarDMSEG import HarDMSEG
from loss_metric import DiceLoss, IOU_score, StructureLoss
from LightMed.model.LightMed import LightMed
from PMFSNet.lib.models.PMFSNet import PMFSNet
from PMFSNet.lib.models.PMFSNet_FFT import PMFSNet_FFT
import torchvision.transforms.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as tx
from torchvision.transforms import GaussianBlur
# from hybrid_model_v1 import HybridSegModel
from hybrid_model_v3 import HybridSegModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [3]:
import wandb
wandb.login()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33ms960068sss[0m ([33ms960068sss-ucl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
image_size = 224
batch_size = 100
def train_augmentation(image, mask, image_size):

    p = 0.5
    
    resize = T.Resize((image_size, image_size))
    image = resize(image)
    mask = resize(mask)
    
    # if(random.random() < p):
    #     jitter = T.ColorJitter(brightness=0.5)
    #     image = jitter(image)
    # if(random.random() < p):
    #     angle = random.uniform(-10, 10)  # 旋轉角度從 ±10 度
    #     translate = (random.uniform(-0.05, 0.05) * image.size[0],
    #                  random.uniform(-0.05, 0.05) * image.size[1])  # 最多平移 ±5%
    #     scale = random.uniform(0.95, 1.05)  # 尺度縮放 ±5%
    #     shear = [random.uniform(-5, 5), random.uniform(-5, 5)]  # 小幅剪切
    
    #     image = F.affine(image, angle=angle, translate=translate, scale=scale, shear=shear)
    #     mask = F.affine(mask, angle=angle, translate=translate, scale=scale, shear=shear)
    # if random.random() < p:
    #     image = F.hflip(image)
    #     mask = F.hflip(mask)
    # if(random.random() < p):
    #     transform = T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    #     image = transform(image)  # image must be a PIL image
    
    image_tensor = tx.to_tensor(image)
    mask_tensor = tx.to_tensor(mask)

    # If standardization
    mean = image_tensor.mean()
    std = image_tensor.std()
    std = std if std > 0 else 1.0  # avoid division by zero
    image_tensor = (image_tensor - mean) / std
    
    return image_tensor, mask_tensor
def test_augmentation(image, mask, image_size):
    resize = T.Resize((image_size, image_size))
    image = resize(image)
    mask = resize(mask)

    image_tensor = tx.to_tensor(image)
    mask_tensor = tx.to_tensor(mask)

    # If standardization
    mean = image_tensor.mean()
    std = image_tensor.std()
    std = std if std > 0 else 1.0  # avoid division by zero
    image_tensor = (image_tensor - mean) / std
    return image_tensor, mask_tensor

    
train_dataset = Thyroid_Dataset("train_v2.csv", transform = train_augmentation, image_size = image_size)
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)

test_dataset = Thyroid_Dataset("test_v2.csv", transform = test_augmentation, image_size = image_size, return_from_dataset = True, crop_DDTI = True, histo_match = False)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size)

image, mask, seg_type = next(iter(train_dataloader))
image = image
mask = mask

In [5]:
print("std : ", torch.std(image))
print("unique : ", torch.unique(mask))

std :  tensor(1.0000)
unique :  tensor([0., 1.])


In [6]:
# index = 39
# for index in range(5):
#     # plt.subplot(1,2,1)
#     # plt.imshow(mask[index][0], alpha = 0.3)
#     plt.imshow(image[index][0])
#     plt.imshow(mask[index][0], alpha = 0.3)
    
#     # plt.subplot(1,2,2)
#     # plt.imshow(mask[index][0])
#     plt.show()

In [7]:
# model = Eff_Unet(
#         layers=[5, 5, 15, 10],
#         embed_dims=[40, 80, 192, 384],
#         downsamples=[True, True, True, True],
#         vit_num=6,
#         drop_path_rate=0.1,
#         num_classes=1,
#         resolution = image_size).cuda()
# model = HarDMSEG(in_channels = 1)
# model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
# model = PMFSNet(in_channels = 1, out_channels = 2, dim = "2d")
# model = PMFSNet_FFT(in_channels = 1, out_channels = 2, dim = "2d")
# model = HybridSegModel(in_channels = 1, out_channels = 2, output_size = image_size, layers_num = 3)

In [8]:
# loss_fn_nodule = StructureLoss()
# loss_fn_gland = StructureLoss()

In [9]:
# model.to("cuda")
# image = image.to("cuda")
# mask = mask.to("cuda")

In [10]:
# optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
# for epoch in range(1000):
#     outputs = model(image)

#     nodule_output = outputs[:, 0:1, :, :][seg_type==1]
#     nodule_mask = mask[seg_type==1]
    
#     gland_output = outputs[:, 1:2, :, :][seg_type==2]
#     gland_mask = mask[seg_type==2]
#     # outputs = torch.sigmoid(logits)
#     # print(nodule_output.shape, nodule_mask.shape)
#     nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
#     gland_loss = loss_fn_gland(gland_output, gland_mask)
    
#     loss = nodule_loss + gland_loss
    
#     IOU = (IOU_score(nodule_output, nodule_mask) + IOU_score(gland_output, gland_mask)) / 2
    
#     dice_loss = DiceLoss()
#     # print(dice_loss(nodule_output, nodule_mask))
#     # print(dice_loss(gland_output, gland_mask)) 
#     DICE = ((1 - dice_loss(nodule_output, nodule_mask)) + (1 - dice_loss(gland_output, gland_mask))) / 2
    
#     # Backward and optimize
#     optimizer.zero_grad()   # clear previous gradients
#     loss.backward()         # compute gradients
#     optimizer.step()        # update weights

#     print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, IOU: {IOU.item():.4f}, DICE: {DICE.item()}")

In [12]:
# (1 - dice_loss(nodule_output, nodule_mask))

In [13]:
# index = 4
# plt.subplot(1,2,1)
# plt.imshow(nodule_mask[index][0].detach().cpu().numpy())
# nodule_output = (nodule_output > 0.5)
# plt.subplot(1,2,2)
# plt.imshow(nodule_output[index][0].detach().cpu().numpy())

In [14]:
# index = 4
# plt.subplot(1,2,1)
# plt.imshow(gland_mask[index][0].detach().cpu().numpy())
# gland_output = (gland_output > 0.5)
# plt.subplot(1,2,2)
# plt.imshow(gland_output[index][0].detach().cpu().numpy())

In [15]:
def train(dataloader, model, optimizer, loss_fn_nodule, loss_fn_gland, device):
    total_loss = 0
    total_IOU = 0
    total_DICE = 0
    
    model.train()
    model.to(device)
    count = 0
    for image, mask, seg_type in tqdm(dataloader):
        image, mask, seg_type = image.to(device), mask.to(device), seg_type.to(device)
        outputs = model(image)

        nodule_output = outputs[:, 0:1, :, :][seg_type==1]
        nodule_mask = mask[seg_type==1]
        
        gland_output = outputs[:, 1:2, :, :][seg_type==2]
        gland_mask = mask[seg_type==2]

        nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
        gland_loss = loss_fn_gland(gland_output, gland_mask)


        
        loss = nodule_loss + gland_loss
        IOU = (IOU_score(nodule_output, nodule_mask) + IOU_score(gland_output, gland_mask)) / 2

        dice_loss = DiceLoss()
        DICE = ((1 - dice_loss(nodule_output, nodule_mask)) + (1 - dice_loss(gland_output, gland_mask))) / 2

        
        # Backward and optimize
        optimizer.zero_grad()   # clear previous gradients
        loss.backward()         # compute gradients
        optimizer.step()        # update weights

        total_loss += loss.item()
        total_IOU += IOU.item()
        total_DICE += DICE.item()
        # count+=1
        # if count==5:
        #     break
    return total_loss/len(dataloader), total_IOU/len(dataloader), total_DICE/len(dataloader)


# 
# Only calculate nodule loss, IOU, DICE, because there is no gland data in the testing set
def val(dataloader, model, loss_fn_nodule, loss_fn_gland, device):
    total_loss = 0
    DDTI_total_loss = 0
    TN3K_total_loss = 0
    
    total_IOU = 0
    DDTI_total_IOU = 0
    TN3K_total_IOU = 0
    
    total_DICE = 0
    DDTI_total_DICE = 0
    TN3K_total_DICE = 0

    model.eval()
    model.to(device)
    for image, mask, seg_type, from_dataset in tqdm(dataloader):
        image, mask, seg_type = image.to(device), mask.to(device), seg_type.to(device)
        outputs = model(image)
        
        nodule_output = outputs[:, 0:1, :, :][seg_type==1]
        nodule_mask = mask[seg_type==1]


        DDTI_nodule_output = nodule_output[from_dataset==1]
        TN3K_nodule_output = nodule_output[from_dataset==3]

        DDTI_nodule_mask = nodule_mask[from_dataset==1]
        TN3K_nodule_mask = nodule_mask[from_dataset==3]
        

        nodule_loss = loss_fn_nodule(nodule_output, nodule_mask)
        DDTI_nodule_loss = loss_fn_nodule(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_nodule_loss = loss_fn_nodule(TN3K_nodule_output, TN3K_nodule_mask)
        
        IOU = IOU_score(nodule_output, nodule_mask)
        DDTI_IOU = IOU_score(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_IOU = IOU_score(TN3K_nodule_output, TN3K_nodule_mask)

        dice_loss = DiceLoss()
        DICE = 1 - dice_loss(nodule_output, nodule_mask)
        DDTI_DICE = 1 - dice_loss(DDTI_nodule_output, DDTI_nodule_mask)
        TN3K_DICE = 1 - dice_loss(TN3K_nodule_output, TN3K_nodule_mask)
    

        total_loss += nodule_loss.item()
        DDTI_total_loss += DDTI_nodule_loss.item()
        TN3K_total_loss += TN3K_nodule_loss.item()
        
        total_IOU += IOU.item()
        DDTI_total_IOU += DDTI_IOU.item()
        TN3K_total_IOU += TN3K_IOU.item()
        
        total_DICE += DICE.item()
        DDTI_total_DICE += DDTI_DICE.item()
        TN3K_total_DICE += TN3K_DICE.item()
    return total_loss/len(dataloader), total_IOU/len(dataloader), total_DICE/len(dataloader), DDTI_total_loss/len(dataloader), DDTI_total_IOU/len(dataloader), DDTI_total_DICE/len(dataloader), TN3K_total_loss/len(dataloader), TN3K_total_IOU/len(dataloader), TN3K_total_DICE/len(dataloader), 

In [16]:
epochs = 50
lr = 0.001
project = "thyroid_hybrid_model"
# name=name="PMFSNet_crop_DDTI_standardization_aug_affine(0.5)_hflip(0.5)_lr_0.001"
# name = "HarDnet_crop_DDTI_standardization_aug_affine(0.5)_lr_0.005"
# name = "LightMed_crop_DDTI"
# name = "test"
# name = "HarDnetMSEG_baseline"
name = "hybrid_v3_baseline_224"
wandb.init(
    project = project,  # Project name in W&B
    name = name,       # (optional) specific run name
    config={                     # (optional) hyperparameters
        "image_size": image_size,
        "learning_rate": lr,
        "epochs": epochs,
        "batch_size": batch_size,
        # "weight_decay":1e-4
    }
)

In [17]:
# model = Eff_Unet(
#         layers=[5, 5, 15, 10],
#         embed_dims=[10, 20, 48, 96],
#         # embed_dims=[40, 80, 192, 384],
#         downsamples=[True, True, True, True],
#         vit_num=6,
#         drop_path_rate=0.1,
#         num_classes=1,
#         resolution = image_size).cuda()
# model = HarDMSEG(in_channels = 1, out_channels = 2)
# model = LightMed(in_channels = 1, out_channels = 2, image_size = image_size)
# model = PMFSNet(in_channels = 1, out_channels = 2, dim = "2d")
# model = PMFSNet_FFT(in_channels = 1, out_channels = 2, dim = "2d")
model =  HybridSegModel(in_channels = 1, out_channels = 2, output_size = image_size, layers_num = 3)

In [18]:
print(f"image size : {image_size}, lr : {lr}, epochs : {epochs}, batch size : {batch_size}, model : {model}")

image size : 224, lr : 0.001, epochs : 50, batch size : 100, model : HybridSegModel(
  (backbone): HarDNetBackbone(
    (base_conv_1): ConvLayer(
      (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU6(inplace=True)
    )
    (base_conv_2): ConvLayer(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU6(inplace=True)
    )
    (base_max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (encoder_blocks): ModuleList(
      (0): EncoderBlock(
        (hardblock): HarDBlock(
          (layers): ModuleList(
            (0): ConvLayer(
              (conv): Conv2d(64, 14, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (norm): Batc

### If using pretrained model

In [19]:
# pretrained_name = "PMFSNet_baseline"
# # model = HarDMSEG(in_channels = 1)
# model = PMFSNet(in_channels = 1, out_channels = 1, dim = "2d")
# # model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
# checkpoint = torch.load(f"models/{pretrained_name}/best_checkpoint.pth")
# model.load_state_dict(checkpoint['model_state_dict'])

In [20]:
# loss_fn = DiceLoss()
loss_fn_nodule = StructureLoss()
loss_fn_gland = StructureLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [21]:
max_IOU = -1
for epoch in range(epochs):
    print(f"epoch : {epoch}")
    total_loss_train, total_IOU_train, total_DICE_train = train(train_dataloader, model, optimizer, loss_fn_nodule, loss_fn_gland, "cuda")
    print(f"train loss : {total_loss_train}, train IOU : {total_IOU_train}, train DICE : {total_DICE_train}")
    total_loss_val, total_IOU_val, total_DICE_val, DDTI_total_loss_val, DDTI_total_IOU_val, DDTI_total_DICE_val, TN3K_total_loss_val, TN3K_total_IOU_val, TN3K_total_DICE_val = val(test_dataloader, model, loss_fn_nodule, loss_fn_gland, "cuda")
    print(f"val loss : {total_loss_val}, val IOU : {total_IOU_val}, val DICE : {total_DICE_val}")
    
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]



    if max_IOU < total_IOU_val:
        max_IOU = total_IOU_val
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            "IOU" : total_IOU_val,
            "DICE" : total_DICE_val,
            "loss" : total_loss_val
        }
        folder = f"models/{name}"
        if not os.path.exists(folder):
            os.mkdir(folder)
        torch.save(checkpoint, f"models/{name}/best_checkpoint.pth")

    
    wandb.log({
        "epoch": epoch,
        "Learning Rate":current_lr,
        
        "train_loss": total_loss_train,
        "train_IOU": total_IOU_train,
        "train_DICE": total_DICE_train,
        
        "val_loss": total_loss_val,
        "val_IOU": total_IOU_val,
        "val_DICE": total_DICE_val,

        "DDTI_val_loss": DDTI_total_loss_val,
        "DDTI_val_IOU": DDTI_total_IOU_val,
        "DDTI_val_DICE": DDTI_total_DICE_val,

        "TN3K_val_loss": TN3K_total_loss_val,
        "TN3K_val_IOU": TN3K_total_IOU_val,
        "TN3K_val_DICE": TN3K_total_DICE_val,
        
    })




checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    "IOU" : total_IOU_val,
    "DICE" : total_DICE_val,
    "loss" : total_loss_val
}
if not os.path.exists(folder):
    os.mkdir(folder)
torch.save(checkpoint, f"models/{name}/last_checkpoint.pth")
wandb.finish()

epoch : 0


100%|██████████| 65/65 [00:37<00:00,  1.74it/s]


train loss : 1.6147734293570886, train IOU : 0.4962871871315516, train DICE : 0.5909997761249542


100%|██████████| 13/13 [00:05<00:00,  2.42it/s]


val loss : 1.023854517019712, val IOU : 0.3605102759141188, val DICE : 0.5511583823424119
epoch : 1


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 1.121367455445803, train IOU : 0.6331697454819313, train DICE : 0.7704144973021287


100%|██████████| 13/13 [00:03<00:00,  3.29it/s]


val loss : 0.9782948218859159, val IOU : 0.34867404515926653, val DICE : 0.5491639696634733
epoch : 2


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 1.0010986117216256, train IOU : 0.6738931738413297, train DICE : 0.8031409520369309


100%|██████████| 13/13 [00:03<00:00,  3.31it/s]


val loss : 0.9075318116408128, val IOU : 0.43839975045277524, val DICE : 0.6217882449810321
epoch : 3


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.8990957681949322, train IOU : 0.7097249471224272, train DICE : 0.8291402697563172


100%|██████████| 13/13 [00:03<00:00,  3.29it/s]


val loss : 0.8604131111731896, val IOU : 0.4590920874705681, val DICE : 0.6566345554131728
epoch : 4


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.8204577711912302, train IOU : 0.7393567809691796, train DICE : 0.8483415090120756


100%|██████████| 13/13 [00:03<00:00,  3.30it/s]


val loss : 0.8287495558078473, val IOU : 0.4972732731929192, val DICE : 0.6554801968427805
epoch : 5


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.7708316592069773, train IOU : 0.7559941961215093, train DICE : 0.859984935246981


100%|██████████| 13/13 [00:03<00:00,  3.30it/s]


val loss : 0.6823715980236347, val IOU : 0.5897650810388418, val DICE : 0.7436569470625657
epoch : 6


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.7148369789123535, train IOU : 0.7756395284946148, train DICE : 0.8727875251036424


100%|██████████| 13/13 [00:03<00:00,  3.27it/s]


val loss : 0.7145812282195458, val IOU : 0.5729852685561547, val DICE : 0.7260613441467285
epoch : 7


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.6741430117533758, train IOU : 0.7898200887900132, train DICE : 0.8835731029510498


100%|██████████| 13/13 [00:03<00:00,  3.30it/s]


val loss : 0.704776983994704, val IOU : 0.5701843316738422, val DICE : 0.7234097398244418
epoch : 8


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.6503768792519202, train IOU : 0.7977463309581463, train DICE : 0.8880484140836276


100%|██████████| 13/13 [00:03<00:00,  3.28it/s]


val loss : 0.6416948621089642, val IOU : 0.6253998142022353, val DICE : 0.7643418403772207
epoch : 9


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.6167224132097684, train IOU : 0.8091231465339661, train DICE : 0.8960516416109525


100%|██████████| 13/13 [00:03<00:00,  3.30it/s]


val loss : 0.7196352665240948, val IOU : 0.5729016019747808, val DICE : 0.7277284447963421
epoch : 10


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.5881059114749615, train IOU : 0.8191769077227666, train DICE : 0.9023809552192688


100%|██████████| 13/13 [00:03<00:00,  3.25it/s]


val loss : 0.8072242049070505, val IOU : 0.5146847367286682, val DICE : 0.6766123955066388
epoch : 11


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.5507989273621485, train IOU : 0.8312314840463492, train DICE : 0.9102608818274278


100%|██████████| 13/13 [00:03<00:00,  3.29it/s]


val loss : 0.6696609533750094, val IOU : 0.619760261132167, val DICE : 0.7570898211919345
epoch : 12


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.5355571788090926, train IOU : 0.8368290020869329, train DICE : 0.9140369956309978


100%|██████████| 13/13 [00:04<00:00,  3.07it/s]


val loss : 0.6721418683345501, val IOU : 0.6218248055531428, val DICE : 0.7591608029145461
epoch : 13


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.5316023432291471, train IOU : 0.8375237070597135, train DICE : 0.9142101737169119


100%|██████████| 13/13 [00:04<00:00,  3.25it/s]


val loss : 0.9358037068293645, val IOU : 0.49713061635310835, val DICE : 0.678575323178218
epoch : 14


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.5078601846328148, train IOU : 0.8458173990249633, train DICE : 0.920620217690101


100%|██████████| 13/13 [00:03<00:00,  3.32it/s]


val loss : 0.6177164591275729, val IOU : 0.6322764112399175, val DICE : 0.7839256754288306
epoch : 15


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.47041650185218226, train IOU : 0.8579173060563895, train DICE : 0.9281891776965214


100%|██████████| 13/13 [00:03<00:00,  3.29it/s]


val loss : 0.6087675369702853, val IOU : 0.6332452755707961, val DICE : 0.786767629476694
epoch : 16


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.4401486616868239, train IOU : 0.8679818153381348, train DICE : 0.9347616681685814


100%|██████████| 13/13 [00:03<00:00,  3.26it/s]


val loss : 0.5976946812409621, val IOU : 0.6628037507717426, val DICE : 0.7983859318953294
epoch : 17


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.41569513586851264, train IOU : 0.8758645066848167, train DICE : 0.9396628132233253


100%|██████████| 13/13 [00:03<00:00,  3.28it/s]


val loss : 0.5749300420284271, val IOU : 0.6791589489349952, val DICE : 0.8179338849507846
epoch : 18


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.40248729311502895, train IOU : 0.8802763388707088, train DICE : 0.941891334607051


100%|██████████| 13/13 [00:03<00:00,  3.30it/s]


val loss : 0.5999939441680908, val IOU : 0.6655369767775903, val DICE : 0.8052432858026944
epoch : 19


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.3809363291813777, train IOU : 0.8873275197469271, train DICE : 0.9461240392464858


100%|██████████| 13/13 [00:03<00:00,  3.28it/s]


val loss : 0.5767710988338177, val IOU : 0.6834839628292964, val DICE : 0.8174050266926105
epoch : 20


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.3651916201298053, train IOU : 0.8921714104138888, train DICE : 0.9487397258098309


100%|██████████| 13/13 [00:03<00:00,  3.29it/s]


val loss : 0.6107127483074481, val IOU : 0.6891362437835107, val DICE : 0.8225430662815387
epoch : 21


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.3491148334283095, train IOU : 0.8973843244405894, train DICE : 0.9518342458284819


100%|██████████| 13/13 [00:03<00:00,  3.34it/s]


val loss : 0.6138841005472037, val IOU : 0.6865703005057114, val DICE : 0.8194639591070322
epoch : 22


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.33346610160974355, train IOU : 0.9018835828854488, train DICE : 0.9545126126362727


100%|██████████| 13/13 [00:03<00:00,  3.31it/s]


val loss : 0.6268582481604356, val IOU : 0.6831236344117385, val DICE : 0.8272845240739676
epoch : 23


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.3230459690093994, train IOU : 0.9053216796654922, train DICE : 0.9561388960251441


100%|██████████| 13/13 [00:03<00:00,  3.26it/s]


val loss : 0.6836749406961294, val IOU : 0.677873799434075, val DICE : 0.8146156622813299
epoch : 24


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.3231003087300521, train IOU : 0.9055789415652935, train DICE : 0.9561481833457947


100%|██████████| 13/13 [00:03<00:00,  3.27it/s]


val loss : 0.6525460137770727, val IOU : 0.6938513746628394, val DICE : 0.8085619944792527
epoch : 25


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.3091280941779797, train IOU : 0.9094937746341412, train DICE : 0.9583167021091168


100%|██████████| 13/13 [00:03<00:00,  3.33it/s]


val loss : 0.6406230743114765, val IOU : 0.6932130960317758, val DICE : 0.8270570131448599
epoch : 26


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.2756603353298627, train IOU : 0.9200954418915969, train DICE : 0.9639252635148855


100%|██████████| 13/13 [00:03<00:00,  3.29it/s]


val loss : 0.7115704600627606, val IOU : 0.6964024901390076, val DICE : 0.8301325440406799
epoch : 27


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.22868883059575007, train IOU : 0.933881165431096, train DICE : 0.970946321120629


100%|██████████| 13/13 [00:03<00:00,  3.30it/s]


val loss : 0.7885082822579604, val IOU : 0.6989262424982511, val DICE : 0.8342499457872831
epoch : 31


100%|██████████| 65/65 [00:33<00:00,  1.95it/s]


train loss : 0.21619968758179592, train IOU : 0.9373990535736084, train DICE : 0.9726592595760639


100%|██████████| 13/13 [00:03<00:00,  3.29it/s]


val loss : 0.851124277481666, val IOU : 0.7017999749917251, val DICE : 0.8395786056151757
epoch : 32


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.2064355533856612, train IOU : 0.9403478494057289, train DICE : 0.9741319289574256


100%|██████████| 13/13 [00:03<00:00,  3.28it/s]


val loss : 0.826722250534938, val IOU : 0.7085270652404199, val DICE : 0.8399726840165945
epoch : 33


100%|██████████| 65/65 [00:33<00:00,  1.97it/s]


train loss : 0.19810563967778133, train IOU : 0.94273067345986, train DICE : 0.9752727150917053


100%|██████████| 13/13 [00:03<00:00,  3.32it/s]


val loss : 0.909263340326456, val IOU : 0.7064114671487075, val DICE : 0.8393315627024724
epoch : 34


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.18993266201936282, train IOU : 0.9450871614309457, train DICE : 0.97630942601424


100%|██████████| 13/13 [00:03<00:00,  3.34it/s]


val loss : 0.8520312630213224, val IOU : 0.7141324694340045, val DICE : 0.8452771672835717
epoch : 35


100%|██████████| 65/65 [00:33<00:00,  1.96it/s]


train loss : 0.18144837067677425, train IOU : 0.9475454009496249, train DICE : 0.9775238238848173


100%|██████████| 13/13 [00:03<00:00,  3.27it/s]


val loss : 0.8992815430347736, val IOU : 0.7161923280129066, val DICE : 0.8463885554900537
epoch : 36


 43%|████▎     | 28/65 [00:14<00:19,  1.88it/s]


KeyboardInterrupt: 

In [None]:
print(max_IOU)

In [None]:
inference_name = "PMFSNet_baseline"
# model = HarDMSEG(in_channels = 1)
model = PMFSNet(in_channels = 1, out_channels = 1, dim = "2d")
# model = LightMed(in_channels = 1, out_channels = 1, image_size = image_size)
checkpoint = torch.load(f"models/{inference_name}/last_checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

In [None]:

total_loss_val, total_IOU_val, total_DICE_val = val(test_dataloader, model, loss_fn, "cuda")
print(f"val loss : {total_loss_val}, val IOU : {total_IOU_val}, val DICE : {total_DICE_val}")


In [None]:
DICE = 0
# model.to("cuda")
model.eval()
for image, mask in tqdm(test_dataloader):
    image, mask = image, mask
    preds = model(image)
    DICE += (1 - dice_loss(preds, mask)).item()
print(DICE/len(test_dataloader))

In [None]:
index = 7
plt.subplot(1,2,1)
plt.imshow(preds[index][0])
plt.subplot(1,2,2)
plt.imshow(mask[index][0])

In [None]:
dice_loss = DiceLoss()
print(1 - dice_loss(preds, mask))

In [None]:
plt.imshow(mask[index][0])