In [None]:
import os
import sys
from pathlib import Path
ROOT_PATH = Path.cwd().parent
sys.path.append(str(ROOT_PATH))

from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

from hydra.utils import instantiate
from hydra import initialize_config_dir, compose

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import v2
import torch.nn.functional as F

from src.data.dataset import Nuclei
from src.models.unet import UNet
from src.utils.dice import dice_loss, dice_coeff

In [None]:
# Compose the Hydra config from the project's `conf` directory (resolved
# relative to ROOT_PATH so the notebook does not depend on Hydra's own
# working-directory handling).
config_dir = str(ROOT_PATH.joinpath('conf'))
with initialize_config_dir(config_dir=config_dir, version_base=None):
    cfg = compose(config_name='config.yaml')

In [None]:
def _to_float_tensor():
    """Return fresh conversion steps: PIL image -> tensor -> float32 (scale=True)."""
    return [
        v2.PILToTensor(),
        v2.ToDtype(torch.float32, scale=True),
    ]


# Random geometric augmentations, applied only when training.
_augmentations = [
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomRotation(degrees=(0, 360)),
]

# Training pipeline: augment first, then convert to a float tensor.
train_tfms = v2.Compose(_augmentations + _to_float_tensor())

# Validation pipeline: conversion only, no randomness.
val_tfms = v2.Compose(_to_float_tensor())

# Paths come from the Hydra config and are resolved against the project root.
CHECKPOINTS_PATH = ROOT_PATH / cfg.paths.checkpoints
DATA_DIR = ROOT_PATH / cfg.paths.train_data
VAL_PERCENT = 0.2  # fraction of samples held out for validation
cfg.batch_size = 1  # notebook-local override of the configured batch size

train_files = ['Fused_S1_1.tif']
image_paths = [str(DATA_DIR / img_file) for img_file in train_files]

# Two views over the same files: one with training augmentations, one with the
# deterministic validation transforms. Previously the validation subset was
# carved out of the augmented dataset, so `val_tfms` was never used and the
# validation score was computed on randomly flipped/rotated samples.
# NOTE(review): assumes Nuclei maps a given index to the same sample regardless
# of `transforms` — confirm against src/data/dataset.py.
dataset = Nuclei(image_paths, transforms=train_tfms)
val_dataset = Nuclei(image_paths, transforms=val_tfms)
assert len(val_dataset) == len(dataset), 'train/val dataset views must have the same length'

n_val = int(len(dataset) * VAL_PERCENT)
n_train = len(dataset) - n_val

# Split indices once, then apply the disjoint index sets to each view.
train_split, val_split = random_split(range(len(dataset)), [n_train, n_val])
train_set = torch.utils.data.Subset(dataset, train_split.indices)
val_set = torch.utils.data.Subset(val_dataset, val_split.indices)

train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True, pin_memory=True, num_workers=4)
# Validation order does not affect the averaged metric, so no shuffling.
val_loader = DataLoader(val_set, batch_size=cfg.batch_size, shuffle=False, pin_memory=True, num_workers=4)

In [None]:
# Single source of truth for the device: the training loop and `evaluate`
# move every batch to `device`, so the model must live there too. (It was
# previously moved to `cfg.device`, which can differ from `device`.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = UNet(1, 1)  # presumably (in_channels, out_channels) — confirm against UNet
model.to(device)

# Optimizer and loss are built from the Hydra config.
optimizer = instantiate(cfg.optimizer, params=model.parameters())
criterion = instantiate(cfg.loss)

# `scheduler.step()` is called once per batch, so the one-cycle schedule must
# span the number of batches per epoch, not the number of samples.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-4,
    total_steps=cfg.epochs * len(train_loader),
)


In [None]:
def evaluate(model, dataloader, device):
    """Return the mean Dice score of ``model`` over ``dataloader``.

    The model is switched to eval mode and is left in eval mode on return;
    callers that continue training must call ``model.train()`` again.

    Args:
        model: Network whose output is passed through a sigmoid and
            thresholded at 0.5 to obtain a binary mask.
        dataloader: Iterable of ``(images, masks)`` batches that supports
            ``len()``. Mask values must lie in [0, 1].
        device: Device that each batch is moved to before the forward pass.

    Returns:
        A tensor: the per-batch ``dice_coeff`` values summed and divided by
        the number of batches. For an empty dataloader this is a zero scalar
        tensor, so callers can always use ``.item()``.

    Raises:
        AssertionError: If a ground-truth mask has values outside [0, 1].
    """
    model.eval()
    # Tensor accumulator so the return type does not depend on whether the
    # loader yielded any batches.
    dice_score = torch.zeros((), device=device)

    with torch.no_grad():
        # iterate over the validation set
        for imgs, masks in dataloader:
            images, true_masks = imgs.to(device), masks.to(device)

            # predict the mask (one forward pass per batch)
            mask_pred = model(images)

            assert true_masks.min() >= 0 and true_masks.max() <= 1, 'True mask indices should be in [0, 1]'
            mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
            # compute the Dice score
            dice_score = dice_score + dice_coeff(mask_pred, true_masks, reduce_batch_first=False)

    # max(..., 1) avoids division by zero when there are no batches.
    return dice_score / max(len(dataloader), 1)

def save_checkpoint(model, epoch):
    """Save ``model.state_dict()`` to ``CHECKPOINTS_PATH / 'epochNNN.pth'``.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
        epoch: Epoch number; its string form is left-padded with zeros to at
            least three characters (e.g. ``7`` -> ``epoch007.pth``).
    """
    # torch.save does not create missing parent directories.
    CHECKPOINTS_PATH.mkdir(parents=True, exist_ok=True)
    epoch_num = str(epoch).rjust(3, '0')
    checkpoint_path = str(CHECKPOINTS_PATH / f'epoch{epoch_num}.pth')
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Main training loop: one optimizer/scheduler step per batch, then a
# validation pass and a checkpoint at the end of every epoch.
for epoch in range(1, cfg.epochs + 1):
    model.train()  # evaluate() leaves the model in eval mode
    epoch_loss = 0.0
    n_batches = len(train_loader)

    # The bar counts batches (it is advanced once per batch); the context
    # manager closes it even if a step raises.
    with tqdm(total=n_batches, desc=f'Epoch {epoch}') as pbar:
        for imgs, masks in train_loader:
            images, true_masks = imgs.to(device), masks.to(device)
            masks_pred = model(images)

            # Total loss = configured criterion on the raw outputs
            # + Dice loss on the sigmoid probabilities.
            logits = masks_pred.squeeze(1)
            targets = true_masks.squeeze(1)
            batch_loss = criterion(logits, targets)
            masks_pred_dice = torch.sigmoid(logits)
            batch_loss = batch_loss + dice_loss(masks_pred_dice, targets, multiclass=False)

            optimizer.zero_grad(set_to_none=True)
            batch_loss.backward()
            optimizer.step()
            scheduler.step()

            loss_value = batch_loss.item()
            epoch_loss += loss_value

            pbar.update(1)
            pbar.set_postfix(**{'loss': f'{loss_value:.4f}'})

        # epoch_loss is a sum of per-batch losses, so average over batches.
        avg_batch_loss = epoch_loss / max(n_batches, 1)
        pbar.set_postfix(**{'loss': f'{avg_batch_loss:.4f}'})

    # Validate and save checkpoint. float() accepts both a scalar tensor and
    # a plain Python number.
    val_score = float(evaluate(model, val_loader, device))
    print(f'Validation Dice score: {val_score}')
    save_checkpoint(model, epoch)