### Imports

In [410]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import lightning as L
from lightning.pytorch.loggers import WandbLogger

import tqdm
import os, glob

from sklearn.model_selection import StratifiedKFold
import cv2
import numpy as np
from albumentations import Compose, Normalize, Resize, HorizontalFlip
import matplotlib.pyplot as plt
import monai

### Dataset Class and Dataloader

In [345]:
class GazeSegmentationDataset(Dataset):
    """Grayscale image / mask pairs with an optional rasterised gaze channel.

    Every sample is read with OpenCV, scaled from 0-255 to 0-1, optionally
    fused with a gaze map, passed through ``transform`` and returned
    channel-first with an extra leading axis of length 1.

    Args:
        image_paths: Paths of the input images. Index ``i`` must refer to the
            same sample in all three path lists.
        mask_paths: Paths of the segmentation masks.
        gaze_paths: Paths of gaze sequences readable by ``np.load``. Only
            accessed when ``use_gaze`` is true.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries. ``None`` leaves the arrays
            untouched.
        use_gaze: When true, the gaze map is appended to the image as a
            second channel.
    """

    def __init__(self, image_paths, mask_paths, gaze_paths, transform=None, use_gaze=False):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.gaze_paths = gaze_paths
        self.transform = transform
        self.use_gaze = use_gaze

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, mask)`` as NumPy arrays shaped ``(1, C, H, W)`` and
            ``(1, 1, H, W)``; ``C`` is 2 with gaze and 1 without.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        image = self._read_unit_grayscale(self.image_paths[idx])
        # NOTE(review): the mask is transposed while the image is not --
        # presumably the stored masks are transposed relative to the images;
        # confirm against the data before changing.
        mask = self._read_unit_grayscale(self.mask_paths[idx]).T

        gaze = None
        if self.use_gaze:
            # Only touch the gaze file when it is actually needed.
            gaze_seq = np.load(self.gaze_paths[idx])
            # prepare_gaze paints with 255; bring it onto the same 0-1 scale
            # as the image so both channels see the same normalisation.
            gaze = self.prepare_gaze(gaze_seq, size=image.shape[:2]) / 255

        return self._assemble(image, mask, gaze)

    def prepare_gaze(self, gaze_seq, size=(1024, 1024)):
        """Rasterise a gaze sequence into a map of filled discs.

        Args:
            gaze_seq: Iterable of triples ``(g0, g1, g2)``. ``g0`` and ``g1``
                are multiplied by the map height and width and truncated to
                the row and column of a disc centre (presumably normalised
                fixation coordinates -- confirm); ``int(g2) * 20`` is the
                disc radius in pixels.
            size: ``(height, width)`` of the map. Defaults to the original
                hard-coded 1024 x 1024.

        Returns:
            A float64 array of shape ``size`` that is 0 everywhere except
            inside the discs, where it is 255.
        """
        height, width = size
        gaze = np.zeros((height, width))
        for g in gaze_seq:
            row = int(g[0] * height)
            col = int(g[1] * width)
            radius = int(g[2]) * 20
            # OpenCV expects the centre as an (x, y) tuple; thickness -1 fills.
            cv2.circle(gaze, (col, row), radius, (255, 255, 255), -1)
        return gaze

    @staticmethod
    def _read_unit_grayscale(path):
        """Read ``path`` as a single-channel image scaled to the 0-1 range."""
        pixels = cv2.imread(path, 0)
        if pixels is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return pixels.astype(np.float64) / 255

    def _assemble(self, image, mask, gaze=None):
        """Stack channels, apply the transform and return channel-first arrays."""
        channels = [image] if gaze is None else [image, gaze]
        stacked = np.stack(channels, axis=-1)  # (H, W, C)
        mask = mask[:, :, np.newaxis]  # (H, W, 1)

        if self.transform is not None:
            transformed = self.transform(image=stacked, mask=mask)
            stacked, mask = transformed["image"], transformed["mask"]

        image_chw = np.transpose(stacked, (2, 0, 1))
        mask_chw = np.transpose(mask, (2, 0, 1))
        # The leading singleton axis is kept so existing callers that feed a
        # single sample straight into a model keep working.
        return np.expand_dims(image_chw, 0), np.expand_dims(mask_chw, 0)


class SegmentationDataModule(L.LightningDataModule):
    """Stratified K-fold data module for the image / mask / gaze folders.

    Files are collected from ``<dataset_dir>/images``, ``<dataset_dir>/masks``
    and ``<dataset_dir>/gaze``. Each list is sorted so that index ``i`` refers
    to the same sample in all three (this relies on matching file names
    sorting identically in every folder).

    Args:
        dataset_dir: Root folder containing the three sub-folders.
        transform: Transform handed to the training dataset.
        batch_size: Batch size of both loaders.
        num_folds: Number of stratified folds.
        use_gaze: Forwarded to the datasets; adds the gaze channel.
        val_transform: Transform for the validation dataset. Defaults to
            ``transform`` so validation inputs match the training inputs in
            size and scaling; pass a deterministic pipeline to avoid random
            augmentation during validation.
    """

    def __init__(self, dataset_dir, transform=None, batch_size=8, num_folds=5, use_gaze=False, val_transform=None):
        super().__init__()
        # glob returns files in arbitrary order; sort so the lists stay paired.
        self.image_paths = sorted(glob.glob(os.path.join(dataset_dir, "images", "*")))
        self.mask_paths = sorted(glob.glob(os.path.join(dataset_dir, "masks", "*")))
        self.gaze_paths = sorted(glob.glob(os.path.join(dataset_dir, "gaze", "*")))
        self.transform = transform
        self.val_transform = transform if val_transform is None else val_transform
        self.batch_size = batch_size
        self.num_folds = num_folds
        self.use_gaze = use_gaze

    def setup(self, stage=None):
        """Compute the stratified folds and store them in ``self.folds``.

        Each mask is labelled 1 when it contains any non-zero pixel and 0
        otherwise, so every fold keeps a similar share of empty masks.

        Raises:
            ValueError: If the number of images and masks differ.
            FileNotFoundError: If a mask cannot be read.
        """
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} images but {len(self.mask_paths)} masks"
            )
        labels = [self._has_foreground(mask_path) for mask_path in self.mask_paths]
        skf = StratifiedKFold(n_splits=self.num_folds)
        self.folds = list(skf.split(self.image_paths, labels))

    def prepare_data(self):
        """Nothing to download or preprocess."""
        pass

    def train_dataloader(self, fold_idx):
        """Return ``(dataset, loader)`` for the training split of ``fold_idx``."""
        train_idx, _ = self.folds[fold_idx]
        train_dataset = Subset(self._make_dataset(self.transform), train_idx)
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self._concat_collate,
        )
        return train_dataset, train_loader

    def val_dataloader(self, fold_idx):
        """Return ``(dataset, loader)`` for the validation split of ``fold_idx``."""
        _, val_idx = self.folds[fold_idx]
        val_dataset = Subset(self._make_dataset(self.val_transform), val_idx)
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            collate_fn=self._concat_collate,
        )
        return val_dataset, val_loader

    def _make_dataset(self, transform):
        """Build a dataset over all files with the given transform."""
        return GazeSegmentationDataset(
            self.image_paths, self.mask_paths, self.gaze_paths, transform, self.use_gaze
        )

    @staticmethod
    def _has_foreground(mask_path):
        """Return 1 if the mask at ``mask_path`` has any non-zero pixel, else 0."""
        mask = cv2.imread(mask_path, 0)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")
        # An integer class label keeps StratifiedKFold from seeing a
        # continuous target when a mask holds intermediate grey values.
        return int(mask.max() > 0)

    @staticmethod
    def _concat_collate(samples):
        """Join samples along their existing leading axis.

        The dataset returns arrays that already carry a leading axis of
        length 1, so concatenating yields ``(B, C, H, W)`` batches where the
        default stacking collate would yield ``(B, 1, C, H, W)``.
        """
        images, masks = zip(*samples)
        return (
            torch.from_numpy(np.concatenate(images, axis=0)),
            torch.from_numpy(np.concatenate(masks, axis=0)),
        )

### Model Class

In [411]:
class UNET(nn.Module):
    """Thin ``nn.Module`` wrapper that owns a MONAI ``UNet`` as ``self.model``.

    Args:
        model_params: Optional mapping providing ``"spatial_dims"``,
            ``"in_channels"``, ``"out_channels"``, ``"channels"`` and
            ``"strides"``. All five keys are required when a mapping is
            given. With ``None`` a 2-D network with two input channels, one
            output channel, channel widths ``(16, 32, 64, 128, 256)`` and
            stride 2 at every level is built.
    """

    def __init__(self, model_params=None):
        super().__init__()
        if model_params is not None:
            # Pick exactly the keys the network constructor needs; a missing
            # key raises KeyError.
            unet_kwargs = {
                key: model_params[key]
                for key in ("spatial_dims", "in_channels", "out_channels", "channels", "strides")
            }
        else:
            unet_kwargs = {
                "spatial_dims": 2,
                "in_channels": 2,
                "out_channels": 1,
                "channels": (16, 32, 64, 128, 256),
                "strides": (2, 2, 2, 2),
            }
        self.model = monai.networks.nets.UNet(**unet_kwargs)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        prediction = self.model(x)
        return prediction

### Lightning Module

In [None]:
class CustomLightningModule(L.LightningModule):
    """Lightning wrapper that trains a segmentation model with a Dice loss.

    Args:
        model: Network mapping an input batch to per-pixel logits. The loss
            applies the sigmoid itself (``sigmoid=True``).
    """

    def __init__(self, model):
        # Must run first: nn.Module refuses submodule assignment before it.
        super().__init__()
        self.model = model
        self.dice_loss = monai.losses.GeneralizedDiceLoss(sigmoid=True)
        self.dice_metric = monai.metrics.DiceMetric()

    def forward(self, X):
        """Return the wrapped model's output for ``X``."""
        return self.model(X)

    def training_step(self, batch, batch_idx):
        """Compute and log the Dice loss for one training batch."""
        X, y = self._as_float(batch)

        y_hat = self(X)
        loss = self.dice_loss(y_hat, y)

        self.log("train/loss", loss, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the Dice loss and the thresholded Dice score."""
        X, y = self._as_float(batch)

        logits = self(X)
        val_loss = self.dice_loss(logits, y)

        # Binarise prediction and target at 0.5 before scoring.
        y_pred = (torch.sigmoid(logits) > 0.5).float()
        y_true = (y > 0.5).float()
        per_sample = self.dice_metric(y_pred=y_pred, y=y_true)
        # The metric object also buffers every result it computes; drop the
        # buffer because only the per-batch value is logged here.
        self.dice_metric.reset()
        # self.log needs a scalar. nanmean skips samples whose score is NaN
        # (the metric may return NaN when the ground truth is empty).
        val_metric = torch.nanmean(per_sample)

        self.log("val/loss", val_loss, sync_dist=True)
        self.log("val/metric", val_metric, sync_dist=True)

        return val_loss

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters with lr 1e-3."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    @staticmethod
    def _as_float(batch):
        """Cast ``(X, y)`` to float32 without changing their device."""
        X, y = batch
        return X.float(), y.float()

### Prepare data and setup model

In [None]:
# Train one fresh UNet per cross-validation fold and log each fold to W&B.
dataset_dir = "pneumothorax"
use_gaze = False

gpus_available = torch.cuda.device_count()

transform = Compose([
    Resize(height=224, width=224),
    HorizontalFlip(p=0.5),
    # The dataset already scales pixels to [0, 1]. With the default
    # max_pixel_value of 255 the image would be divided by 255 a second time
    # and collapse to values just above -1.
    Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=1.0),
])
data_module = SegmentationDataModule(dataset_dir, transform=transform, use_gaze=use_gaze)
data_module.setup()

# The network input is the image alone, or image + gaze map when use_gaze is on.
model_params = {
    "spatial_dims": 2,
    "in_channels": 2 if use_gaze else 1,
    "out_channels": 1,
    "channels": (16, 32, 64, 128, 256),
    "strides": (2, 2, 2, 2),
}

# Train the model using K-fold cross-validation
num_folds = data_module.num_folds
for fold_idx in range(num_folds):
    print(f"Training fold {fold_idx+1}/{num_folds}")

    # A new model per fold: reusing one instance would start every later fold
    # from weights already trained on that fold's validation samples.
    model = UNET(model_params)
    lightning_module = CustomLightningModule(model)

    wandb_logger = WandbLogger(project="GAZE", name=f"Fold-{fold_idx}")
    trainer = L.Trainer(
        devices=gpus_available,
        accelerator="gpu",
        strategy="ddp",
        logger=wandb_logger,
        max_epochs=100)

    # The data module returns (dataset, loader) pairs; the trainer only takes
    # the loaders.
    _, train_loader = data_module.train_dataloader(fold_idx)
    _, val_loader = data_module.val_dataloader(fold_idx)
    trainer.fit(lightning_module, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # Close this fold's run so the next WandbLogger opens a new one instead
    # of reusing the run that is still active in this process.
    wandb_logger.experiment.finish()