In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from templates import *
from templates_cls import *
import monai.transforms as T
from monai.networks.nets import densenet121



In [3]:
channel = 0  # 0 = Flair
train_transforms = T.Compose(
[
        T.LoadImaged(keys=["image", "label"]),
        T.EnsureChannelFirstd(keys=["image", "label"]),
        T.Lambdad(keys=["image"], func=lambda x: x[channel, None, :, :, :]),
        T.EnsureTyped(keys=["image", "label"]),
        T.Orientationd(keys=["image", "label"], axcodes="RAS"),
        T.Spacingd(keys=["image", "label"], pixdim=(3.0, 3.0, 2.0), mode=("bilinear", "nearest")),
        T.CenterSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 64)),
        T.ScaleIntensityRangePercentilesd(keys="image", lower=0, upper=99.5, b_min=0, b_max=1),
        T.RandSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 1), random_size=False),
        T.Lambdad(keys=["image", "label"], func=lambda x: x.squeeze(-1)),
        T.CopyItemsd(keys=["label"], times=1, names=["slice_label"]),
        T.Lambdad(keys=["slice_label"], func=lambda x: 1.0 if x.sum() > 0 else 0.0),
]
)

dataset_train =  DecathlonDataset(
root_dir="/DATA/NAS/datasets_source/brain",
task="Task01_BrainTumour",
section="training",
cache_rate=1.0,  # you may need a few Gb of RAM... Set to 0 otherwise
num_workers=4,
download=True,  # Set download to True if the dataset hasnt been downloaded yet
seed=0,
transform=train_transforms,
)

val_transforms = T.Compose(
[
        T.LoadImaged(keys=["image", "label"]),
        T.EnsureChannelFirstd(keys=["image", "label"]),
        T.Lambdad(keys=["image"], func=lambda x: x[channel, None, :, :, :]),
        T.EnsureTyped(keys=["image", "label"]),
        T.Orientationd(keys=["image", "label"], axcodes="RAS"),
        T.Spacingd(keys=["image", "label"], pixdim=(3.0, 3.0, 2.0), mode=("bilinear", "nearest")),
        T.CenterSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 1)),
        T.ScaleIntensityRangePercentilesd(keys="image", lower=0, upper=99.5, b_min=0, b_max=1),
        T.Lambdad(keys=["image", "label"], func=lambda x: x.squeeze(-1)),
        T.CopyItemsd(keys=["label"], times=1, names=["slice_label"]),
        T.Lambdad(keys=["slice_label"], func=lambda x: 1 if x.sum() > 0 else 0),
]
)

dataset_val =  DecathlonDataset(
root_dir="/DATA/NAS/datasets_source/brain",
task="Task01_BrainTumour",
section="validation",
cache_rate=1.0,  # you may need a few Gb of RAM... Set to 0 otherwise
num_workers=4,
download=True,  # Set download to True if the dataset hasnt been downloaded yet
seed=0,
transform=val_transforms,
)


2024-03-25 13:27:40,445 - INFO - Verified 'Task01_BrainTumour.tar', md5: 240a19d752f0d9e9101544901065d872.
2024-03-25 13:27:40,446 - INFO - File exists: /DATA/NAS/datasets_source/brain/Task01_BrainTumour.tar, skipped downloading.
2024-03-25 13:27:40,446 - INFO - Non-empty folder exists in /DATA/NAS/datasets_source/brain/Task01_BrainTumour, skipped extracting.


Loading dataset: 100%|██████████| 388/388 [02:17<00:00,  2.82it/s]


2024-03-25 13:30:08,779 - INFO - Verified 'Task01_BrainTumour.tar', md5: 240a19d752f0d9e9101544901065d872.
2024-03-25 13:30:08,781 - INFO - File exists: /DATA/NAS/datasets_source/brain/Task01_BrainTumour.tar, skipped downloading.
2024-03-25 13:30:08,781 - INFO - Non-empty folder exists in /DATA/NAS/datasets_source/brain/Task01_BrainTumour, skipped extracting.


Loading dataset: 100%|██████████| 96/96 [00:32<00:00,  2.93it/s]


In [4]:
batch_size = 8
train_loader = DataLoader(dataset_train, batch_size, shuffle=False, num_workers=8, pin_memory=True)
val_loader = DataLoader(dataset_val, batch_size, shuffle=False, num_workers=8, pin_memory=True)

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from monai.networks.nets import DenseNet121
from sklearn.metrics import accuracy_score

# Model setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=1).to(device)
loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Learning Rate Scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)

# Early Stopping Setup
early_stopping_patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0

max_epochs = 100
for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for batch_data in train_loader:
        images, labels = batch_data["image"].to(device), batch_data["slice_label"].float().to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels.unsqueeze(1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Learning rate scheduler step
    scheduler.step()

    # Validation loop
    model.eval()
    val_loss = 0.0
    total_preds = []
    total_labels = []
    with torch.no_grad():
        for batch_data in val_loader:
            images, labels = batch_data["image"].to(device), batch_data["slice_label"].float().to(device)
            outputs = model(images)
            loss = loss_function(outputs, labels.unsqueeze(1))
            val_loss += loss.item()

            preds = torch.sigmoid(outputs).round() 
            total_preds.extend(preds.cpu().numpy())
            total_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader)
    val_acc = accuracy_score(total_labels, total_preds)

    print(f'Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == early_stopping_patience:
            print('Early stopping!')
            break


Epoch 1/100, Train Loss: 0.6204, Val Loss: 0.5165, Val Acc: 0.9062
Epoch 2/100, Train Loss: 0.4991, Val Loss: 0.4646, Val Acc: 0.9167
Epoch 3/100, Train Loss: 0.4287, Val Loss: 0.4331, Val Acc: 0.9167
Epoch 4/100, Train Loss: 0.3752, Val Loss: 0.4066, Val Acc: 0.9167
Epoch 5/100, Train Loss: 0.3308, Val Loss: 0.3858, Val Acc: 0.9167
Epoch 6/100, Train Loss: 0.2931, Val Loss: 0.3686, Val Acc: 0.9271
Epoch 7/100, Train Loss: 0.2601, Val Loss: 0.3539, Val Acc: 0.9271
Epoch 8/100, Train Loss: 0.2310, Val Loss: 0.3423, Val Acc: 0.9271
Epoch 9/100, Train Loss: 0.2053, Val Loss: 0.3317, Val Acc: 0.9271
Epoch 10/100, Train Loss: 0.1827, Val Loss: 0.3220, Val Acc: 0.9271
Epoch 11/100, Train Loss: 0.1629, Val Loss: 0.3140, Val Acc: 0.9271
Epoch 12/100, Train Loss: 0.1457, Val Loss: 0.3055, Val Acc: 0.9271
Epoch 13/100, Train Loss: 0.1306, Val Loss: 0.2989, Val Acc: 0.9271
Epoch 14/100, Train Loss: 0.1177, Val Loss: 0.2923, Val Acc: 0.9271
Epoch 15/100, Train Loss: 0.1062, Val Loss: 0.2880, Val A

In [15]:
# Save the model state dictionary
torch.save(model.state_dict(), "/home/matan/latent_dae/diffae/resnet/brats_resnet.pt")

In [6]:
from monai.networks.nets import DenseNet121
from sklearn.metrics import roc_auc_score, f1_score

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=1).to(device)

model.load_state_dict(torch.load("/home/matan/latent_dae/diffae/resnet/brats_resnet.pt"))
model.eval()

total_preds = []
total_labels = []
with torch.no_grad():
    for batch_data in val_loader:
        images, labels = batch_data['image'].to(device), batch_data['slice_label'].to(device)
        labels = (labels > 0)
        labels = labels.float()
        outputs = model(images)
        preds = torch.sigmoid(outputs).round()
        total_preds.extend(preds.cpu().numpy())
        total_labels.extend(labels.cpu().numpy())

auc = roc_auc_score(total_labels, total_preds)
f1 = f1_score(total_labels, total_preds)

print("AUC: ", auc)
print("F1: ", f1)

AUC:  0.49444444444444446
F1:  0.9621621621621622
