In [None]:
TRAIN_SEG = False
TRAIN_CLASS = False
TRAIN_STAR = True
TRAIN_STAR_3D = False
colab = False

try:
    from google.colab import drive
    colab = True
    import os
    import sys
    drive.mount('/content/drive/')
    %cd /content/drive/MyDrive/prl_seg
except ImportError as e:
    pass

if colab:
    try:
        import monai
    except ImportError as e:
        ! pip install monai
    try:
        import torcheval
    except ImportError as e:
        ! pip install torcheval 

In [8]:

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset, TensorDataset
from sklearn.model_selection import StratifiedKFold

from torcheval.metrics import (
    BinaryAccuracy,
    BinaryPrecision,
    BinaryRecall,
    BinaryF1Score,
    MulticlassAccuracy,
    MulticlassF1Score

)


from monai.metrics import DiceMetric
from monai.losses import DiceLoss, GeneralizedWassersteinDiceLoss, GeneralizedDiceLoss, DiceCELoss

import os
import random
import matplotlib.pyplot as plt

%reload_ext autoreload
%autoreload 2
from config import *
from pylib import training
from pylib.models import resnet, unet, resnet3d, unet3d
from pylib.datasets.lazy_dataset import LazyDataset

if colab:
    from monai.networks.nets import UNet
    from monai.networks.layers import Norm


# Seed Python's and PyTorch's global RNGs with the same SEED value
# (SEED is not defined in this notebook; presumably it comes from `config`).
for seed_rng in (random.seed, torch.manual_seed):
    seed_rng(SEED)

# Report the PyTorch build and the CUDA status of this runtime.
print(
    f"PyTorch Version: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
    f"CUDA Version: {torch.version.cuda}",
    sep="\n",
)

PyTorch Version: 2.5.1
CUDA Available: False
CUDA Version: None


# x - Lesion segmentation

In [9]:
# Segmentation training hyperparameters, read by the TRAIN_SEG cells below.
# Batch size handed to both segmentation DataLoaders.
BATCH_SIZE_SEG = 16
# Learning rate handed to Adam (and forwarded to the trainer as `lr`).
LEARNING_RATE_SEG = 0.005
# Forwarded to the trainer as `epochs`.
EPOCHS_SEG = 8

In [10]:
if TRAIN_SEG:
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so these paths must only ever point at trusted files.
    train_seg_data, val_seg_data = (
        torch.load(dataset_path, weights_only=False)
        for dataset_path in (PATH_TRAIN_LESION_SEG_DATASET, PATH_VAL_LESION_SEG_DATASET)
    )

    # Both loaders share the same settings; the validation loader is shuffled too.
    seg_loader_kwargs = dict(batch_size=BATCH_SIZE_SEG, shuffle=True)
    train_seg_loader = DataLoader(dataset=train_seg_data, **seg_loader_kwargs)
    val_seg_loader = DataLoader(dataset=val_seg_data, **seg_loader_kwargs)

In [11]:
if TRAIN_SEG:
    # Stack every training label so the class balance is measured over the
    # whole training set (assumes all label tensors share one shape, as
    # torch.stack requires).
    all_labels = torch.stack([label for _, label in train_seg_data])
    n_lesion_pixels = torch.sum(all_labels)
    n_total_pixels = all_labels.numel()
    n_background_pixels = n_total_pixels - n_lesion_pixels

    if n_lesion_pixels == 0:
        # Without positives the ratio below would be inf and poison the loss.
        raise ValueError("train_seg_data contains no lesion pixels; cannot compute pos_weight")

    # weight = negative / positive, the ratio BCEWithLogitsLoss documents for
    # `pos_weight`. The previous `1 - positive/total` was a fraction below 1,
    # which down-weighted the rare lesion class instead of boosting it.
    weight = (n_background_pixels / n_lesion_pixels).to(DEVICE)
    print(weight)

In [12]:
if TRAIN_SEG:
    # One input channel, one output channel, trained with BCE-with-logits.
    # NOTE(review): this cell spells the class `PRLUNET` while the SWI cell
    # uses `PRLUNet` — confirm both names exist in pylib.models.unet.
    bin_unet = unet.PRLUNET(n_channels=1, n_classes=1).to(DEVICE)

    # `weight` is the positive-class weight computed in the previous cell.
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=weight)
    optimizer = Adam(bin_unet.parameters(), lr=LEARNING_RATE_SEG)

    seg_run_config = dict(
        train_loader=train_seg_loader,
        val_loader=val_seg_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        lr=LEARNING_RATE_SEG,
        batch_size=BATCH_SIZE_SEG,
        epochs=EPOCHS_SEG,
        device=DEVICE,
    )

    # The Trainer is constructed with OUT_DIR and then invoked as a callable.
    trainer = training.Trainer(OUT_DIR)
    trainer(bin_unet, **seg_run_config)


# x - Classification of PRLs

In [13]:
# Classification training hyperparameters, read by the TRAIN_CLASS cells below.
# Batch size handed to both classification DataLoaders.
BATCH_SIZE_CLASS = 64
# Learning rate handed to Adam (and forwarded to the trainer as `lr`).
LEARNING_RATE_CLASS = 1e-5
# Forwarded to the trainer as `epochs`.
EPOCHS_CLASS = 20

In [14]:
if TRAIN_CLASS:
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so these paths must only ever point at trusted files.
    train_class_data, val_class_data = (
        torch.load(dataset_path, weights_only=False)
        for dataset_path in (PATH_3D_TRAIN_PRL_CLASS_DATASET, PATH_3D_VAL_PRL_CLASS_DATASET)
    )

    # Both loaders share the same settings; the validation loader is shuffled too.
    class_loader_kwargs = dict(batch_size=BATCH_SIZE_CLASS, shuffle=True)
    train_class_loader = DataLoader(dataset=train_class_data, **class_loader_kwargs)
    val_class_loader = DataLoader(dataset=val_class_data, **class_loader_kwargs)

In [15]:
# print(len(train_class_data))
# print(len(val_class_data))
# print([i for i in val_class_data[:][1] if i == 0].count(0))

In [16]:
if TRAIN_CLASS:
    # Stack every training label to count positives (PRLs) and negatives
    # (assumes labels are 0/1 valued, so their sum is the positive count —
    # TODO confirm against the dataset builder).
    all_labels = torch.stack([label for _, label in train_class_data])
    n_prls = torch.sum(all_labels)
    n_negatives = all_labels.numel() - n_prls

    if n_prls == 0:
        # Without positives the ratio below would be inf and poison the loss.
        raise ValueError("train_class_data contains no positive labels; cannot compute pos_weight")

    # weight = negative / positive, used as `pos_weight` of BCEWithLogitsLoss.
    # The division already yields a tensor, so it is moved to DEVICE directly
    # instead of being re-wrapped in torch.tensor() (which warns about
    # copy-constructing from a tensor) and moved twice.
    weight = (n_negatives / n_prls).to(DEVICE)
    print(weight)

In [17]:

if TRAIN_CLASS:
    # One input channel, one output logit, trained with BCE-with-logits.
    resnet_3d = resnet3d.ResNet3d(n_channels=1, n_classes=1).to(DEVICE)

    # `weight` is the positive-class weight computed in the previous cell.
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=weight)
    optimizer = Adam(resnet_3d.parameters(), lr=LEARNING_RATE_CLASS)

    # torcheval metrics handed to the trainer through its `metrics` argument.
    class_metrics = [BinaryAccuracy(), BinaryF1Score()]

    class_run_config = dict(
        train_loader=train_class_loader,
        val_loader=val_class_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        lr=LEARNING_RATE_CLASS,
        batch_size=BATCH_SIZE_CLASS,
        epochs=EPOCHS_CLASS,
        device=DEVICE,
        metrics=class_metrics,
    )

    # The Trainer is constructed with OUT_DIR and then invoked as a callable.
    trainer = training.Trainer(OUT_DIR)
    trainer(resnet_3d, **class_run_config)

# x - Semantic classification of SWI

In [None]:
# SWI training hyperparameters, read by the TRAIN_STAR cells below.
# Batch size of the train, validation and test DataLoaders.
BATCH_SIZE_STAR = 32
# Number of epochs; also passed as `validation_interval` in the training cell.
EPOCHS_STAR = 40

In [19]:
if TRAIN_STAR:
    # Echo the dataset directory so the notebook output records which data was used.
    print(DIR_TRAIN_SWI_DATASET)
    # Combined train+validation pool; the k-fold cell below splits it by index.
    # Later cells unpack each item as a (magnitude, phase, label) triple.
    train_val_star_data = LazyDataset(DIR_TRAIN_SWI_DATASET)

In [20]:
if TRAIN_STAR:
    # Sum each of the first three label channels over the whole dataset.
    # Channel order (0 = background, 1 = lesions, 2 = PRLs) is taken from the
    # names used here; the label layout itself is assumed to put channels on
    # dim 1 — TODO confirm against LazyDataset.
    channel_totals = [0, 0, 0]
    for _, _, label in train_val_star_data:
        for channel in range(3):
            channel_totals[channel] += torch.sum(label[:, channel]).item()
    n_background, n_lesions, n_prls = channel_totals

    print(f"n background: {n_background}, n lesions: {n_lesions}, n prls: {n_prls}")

    # Foreground-only weights (total foreground / class count); the background
    # count is reported above but plays no part in the weights.
    n_total = n_lesions + n_prls
    foreground_weights = [n_total / n_lesions, n_total / n_prls]
    weight = torch.tensor(foreground_weights).to(DEVICE)
    print(weight)

In [None]:
# When True, train on every fold; when False, train only on fold `best_fold`.
run_k_fold = False
# Learning rates to sweep for each fold.
lrs = [2e-4]
if TRAIN_STAR:
    n_splits = 5
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Stratification target: does a dataset item (a patient, per the messages
    # below) have a larger PRL-channel sum than the median item?
    n_prls_per_patient = torch.tensor([torch.sum(label[:, 2]) for _, _, label in train_val_star_data])
    median_n_prls = torch.median(n_prls_per_patient).item()
    print(f"Median number of PRLs (pixels) per patient: {median_n_prls}")
    k_fold_split = torch.where(n_prls_per_patient > median_n_prls, 1, 0).numpy()

    print(f"Number of patients with n PRLs > median: {k_fold_split.sum()}")
    print(f"Number of patients with n PRLs =< median: {len(k_fold_split) - k_fold_split.sum()}")

    best_fold = 3
    # random_state is fixed, so materialising the splits once yields exactly
    # the folds that repeated skf.split() calls would produce.
    all_splits = list(skf.split(train_val_star_data, k_fold_split))
    best_train_idx, best_val_idx = all_splits[best_fold]
    fold_splits = all_splits if run_k_fold else [(best_train_idx, best_val_idx)]

    # NOTE(review): in single-fold mode `fold` is 0 although the data comes from
    # `best_fold`, so the loss files below are tagged fold_0. Kept unchanged
    # because downstream code may rely on those file names.
    for fold, (train_idx, val_idx) in enumerate(fold_splits):
        if run_k_fold:
            print(f"Fold {fold + 1}/{n_splits}")
        else:
            print(f"Using Best Fold {best_fold + 1}/{n_splits}")
        print(f"Train idx: {train_idx} | Val idx: {val_idx}")

        # The data depends only on the fold, so it is built once per fold
        # rather than once per learning rate.
        train_subset = Subset(train_val_star_data, train_idx)
        val_subset = Subset(train_val_star_data, val_idx)

        # Combine the images and labels from the subsets into single tensors
        train_mags = torch.cat([img for img, _, _ in train_subset])
        train_phases = torch.cat([img for _, img, _ in train_subset])
        train_labels = torch.cat([label for _, _, label in train_subset])

        val_mags = torch.cat([img for img, _, _ in val_subset])
        val_phases = torch.cat([img for _, img, _ in val_subset])
        val_labels = torch.cat([label for _, _, label in val_subset])

        train_subset = TensorDataset(train_mags, train_phases, train_labels)
        val_subset = TensorDataset(val_mags, val_phases, val_labels)

        train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE_STAR, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE_STAR, shuffle=False)

        for lr in lrs:
            print(f"Using Learning Rate {lr}")

            # A fresh model, optimizer and loss for every (fold, lr) run.
            prl_seg_unet = unet.PRLUNet(n_channels=1, n_classes=3)

            optimizer = Adam(prl_seg_unet.parameters(), lr=lr, weight_decay=1e-5)
            # `weight` holds the two foreground class weights from the cell above.
            loss_fn = DiceLoss(include_background=False, sigmoid=True, weight=weight)

            # In k-fold mode the checkpoint names carry the fold index so that
            # one fold no longer overwrites another; single-fold names are
            # unchanged.
            run_tag = f"lr_{lr}_fold_{fold}" if run_k_fold else f"lr_{lr}"

            losses_train, losses_val = training.train_prl_seg_unet(
                prl_seg_unet,
                train_loader=train_loader,
                val_loader=val_loader,
                loss_fn=loss_fn,
                optimizer=optimizer,
                epochs=EPOCHS_STAR,
                device=DEVICE,
                save_path=os.path.join(OUT_DIR, f"prl_unet_{run_tag}.pt"),
                save_last_epoch_val=True,
                save_last_epoch_val_path=os.path.join(OUT_DIR, f"prl_unet_{run_tag}_last_epoch_val.pt"),
                reduce_lr_on_plateau=False,
                validation_interval=EPOCHS_STAR,
            )
            torch.save(torch.tensor(losses_train), os.path.join(OUT_DIR, f"losses_train_lr_{lr}_fold_{fold}.pt"))
            torch.save(torch.tensor(losses_val), os.path.join(OUT_DIR, f"losses_val_lr_{lr}_fold_{fold}.pt"))

            


In [22]:
if TRAIN_STAR:
    # Relies on `prl_seg_unet` and `lr` left behind by the training cell above:
    # the model evaluated here is the last one trained there.
    test_data = LazyDataset(DIR_TEST_SWI_DATASET)

    # Concatenate the per-item tensors along dim 0 so they can be batched.
    test_mags = torch.cat([img for img, _, _ in test_data])
    test_phases = torch.cat([img for _, img, _ in test_data])
    test_labels = torch.cat([label for _, _, label in test_data])

    test_subset = TensorDataset(test_mags, test_phases, test_labels)
    test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE_STAR, shuffle=False)

    prl_seg_unet.eval()

    all_phases, all_mags, all_labels, all_preds = [], [], [], []
    with torch.no_grad():
        for mags, phases, labels in test_loader:
            mags = mags.to(DEVICE)
            phases = phases.to(DEVICE)
            outputs = prl_seg_unet(mags, phases)
            # Per-channel sigmoid thresholded at 0.5 gives boolean masks.
            preds = torch.sigmoid(outputs) > 0.5

            # Labels are only stored, never fed to the model, so they are not
            # sent to DEVICE and back.
            all_phases.append(phases.cpu())
            all_mags.append(mags.cpu())
            all_labels.append(labels.cpu())
            all_preds.append(preds.cpu())

    all_phases = torch.cat(all_phases, dim=0)
    all_mags = torch.cat(all_mags, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    all_preds = torch.cat(all_preds, dim=0)

    # NOTE(review): this is written to the current working directory, whereas
    # every other artifact goes under OUT_DIR — confirm whether that is intended.
    torch.save(TensorDataset(all_mags, all_phases, all_labels, all_preds), f"prl_unet_lr_{lr}_last_epoch_test.pt")
    
    