In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
from templates import *
from templates_cls import *
import monai.transforms as T
from monai.utils.misc import first

from medmnist.dataset import RetinaMNIST
from medmnist.info import DEFAULT_ROOT

import matplotlib.pyplot as plt

In [9]:
batch_size = 32

train_transform = T.Compose([
                ToTensor(),
                T.RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
                T.RandFlip(spatial_axis=[-1, -2], prob=0.5),
                T.RandGridDistortion(prob=0.5),
                T.RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
                T.ScaleIntensity(),
            ])
dataset_train =  RetinaMNIST(split="train", 
        transform=train_transform, 
        download=True, 
        as_rgb=True, 
        size=128, 
        root="/DATA/NAS/datasets_source/other")

train_loader = DataLoader(dataset_train, batch_size, shuffle=False, num_workers=8, pin_memory=True)

Using downloaded and verified file: /DATA/NAS/datasets_source/other/retinamnist_128.npz


In [20]:
val_transform = T.Compose([
                ToTensor(),
                T.ScaleIntensity(),
            ])
dataset_val =  RetinaMNIST(split="test", 
        transform=val_transform, 
        download=True, 
        as_rgb=True, 
        size=128, 
        root="/DATA/NAS/datasets_source/other")

val_loader = DataLoader(dataset_val, batch_size, shuffle=False, num_workers=8, pin_memory=True)

Using downloaded and verified file: /DATA/NAS/datasets_source/other/retinamnist_128.npz


# Regression

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from monai.networks.nets import DenseNet121
from sklearn.metrics import mean_absolute_error, f1_score
from torch.utils.data import DataLoader

# Assuming train_loader and val_loader are defined

# Model setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=3, out_channels=1).to(device)
loss_function = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Learning Rate Scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)

# Early Stopping Setup
early_stopping_patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0

max_epochs = 100
for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for batch_data in train_loader:
        images, labels = batch_data[0].to(device), batch_data[1].float().to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Learning rate scheduler step
    scheduler.step()

    # Validation loop
    model.eval()
    val_loss = 0.0
    total_preds = []
    total_labels = []
    with torch.no_grad():
        for batch_data in val_loader:
            images, labels = batch_data[0].to(device), batch_data[1].float().to(device)
            outputs = model(images)
            loss = loss_function(outputs, labels)
            val_loss += loss.item()

            preds = outputs
            total_preds.extend(preds.cpu().numpy())
            total_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader)
    val_mae = mean_absolute_error(total_labels, total_preds)

    print(f'Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val MAE: {val_mae:.4f}')

    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == early_stopping_patience:
            print('Early stopping!')
            break


Epoch 1/100, Train Loss: 1.2347, Val Loss: 1.2016, Val MAE: 1.1937
Epoch 2/100, Train Loss: 1.1259, Val Loss: 1.0779, Val MAE: 1.0699
Epoch 3/100, Train Loss: 1.0254, Val Loss: 0.9975, Val MAE: 0.9902
Epoch 4/100, Train Loss: 0.9326, Val Loss: 0.9375, Val MAE: 0.9303
Epoch 5/100, Train Loss: 0.8594, Val Loss: 0.8966, Val MAE: 0.8896
Epoch 6/100, Train Loss: 0.7972, Val Loss: 0.8452, Val MAE: 0.8377
Epoch 7/100, Train Loss: 0.7390, Val Loss: 0.8035, Val MAE: 0.7961
Epoch 8/100, Train Loss: 0.6845, Val Loss: 0.7871, Val MAE: 0.7788
Epoch 9/100, Train Loss: 0.6398, Val Loss: 0.7948, Val MAE: 0.7864
Epoch 10/100, Train Loss: 0.5980, Val Loss: 0.7889, Val MAE: 0.7802
Epoch 11/100, Train Loss: 0.5605, Val Loss: 0.7944, Val MAE: 0.7860
Epoch 12/100, Train Loss: 0.5228, Val Loss: 0.7924, Val MAE: 0.7839
Epoch 13/100, Train Loss: 0.4877, Val Loss: 0.7901, Val MAE: 0.7827
Epoch 14/100, Train Loss: 0.4511, Val Loss: 0.7815, Val MAE: 0.7736
Epoch 15/100, Train Loss: 0.4112, Val Loss: 0.8100, Val M

In [12]:
# Save the model state dictionary
torch.save(model.state_dict(), "/home/matan/latent_dae/diffae/resnet/retina_resnet_reg.pt")

In [22]:
from monai.networks.nets import DenseNet121
from sklearn.metrics import mean_absolute_error, f1_score

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=3, out_channels=1).to(device)
model.load_state_dict(torch.load("/home/matan/latent_dae/diffae/resnet/retina_resnet_reg.pt"))
model.eval()

total_preds = []
total_labels = []
with torch.no_grad():
    for batch_data in val_loader:
        images, labels = batch_data[0].to(device), batch_data[1].float().to(device)
        preds = model(images)
        total_preds.extend(preds.cpu().numpy())
        total_labels.extend(labels.cpu().numpy())

val_mae = mean_absolute_error(total_labels, total_preds)
val_f1 = f1_score(total_labels, np.round(total_preds), average="macro")

print("MAE: ", val_mae)
print("F1: ", val_f1)

MAE:  0.7647887
F1:  0.3140680188814577


# Classification

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from monai.networks.nets import DenseNet121
from sklearn.metrics import accuracy_score, f1_score
from torch.utils.data import DataLoader

# Model setup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=3, out_channels=1).to(device)
loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Learning Rate Scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)

# Early Stopping Setup
early_stopping_patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0

max_epochs = 100
for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for batch_data in train_loader:
        images, labels = batch_data[0].to(device), batch_data[1].to(device)
        labels = (labels > 0)
        labels = labels.float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # Learning rate scheduler step
    scheduler.step()

    # Validation loop
    model.eval()
    val_loss = 0.0
    total_preds = []
    total_labels = []
    with torch.no_grad():
        for batch_data in val_loader:
            images, labels = batch_data[0].to(device), batch_data[1].to(device)
            labels = (labels > 0)
            labels = labels.float()
            outputs = model(images)
            loss = loss_function(outputs, labels)
            val_loss += loss.item()

            preds = torch.sigmoid(outputs)
            total_preds.extend(preds.cpu().numpy())
            total_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader)
    val_acc = accuracy_score(total_labels, np.round(total_preds))

    print(f'Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val ACC: {val_acc:.4f}')

    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve == early_stopping_patience:
            print('Early stopping!')
            break


Epoch 1/100, Train Loss: 0.6357, Val Loss: 0.6162, Val ACC: 0.7333
Epoch 2/100, Train Loss: 0.5530, Val Loss: 0.5099, Val ACC: 0.8167
Epoch 3/100, Train Loss: 0.5043, Val Loss: 0.4710, Val ACC: 0.8417
Epoch 4/100, Train Loss: 0.4691, Val Loss: 0.4511, Val ACC: 0.8500
Epoch 5/100, Train Loss: 0.4408, Val Loss: 0.4395, Val ACC: 0.8500
Epoch 6/100, Train Loss: 0.4162, Val Loss: 0.4326, Val ACC: 0.8500
Epoch 7/100, Train Loss: 0.3935, Val Loss: 0.4293, Val ACC: 0.8500
Epoch 8/100, Train Loss: 0.3713, Val Loss: 0.4270, Val ACC: 0.8417
Epoch 9/100, Train Loss: 0.3489, Val Loss: 0.4266, Val ACC: 0.8417
Epoch 10/100, Train Loss: 0.3254, Val Loss: 0.4268, Val ACC: 0.8417
Epoch 11/100, Train Loss: 0.3003, Val Loss: 0.4277, Val ACC: 0.8417
Epoch 12/100, Train Loss: 0.2734, Val Loss: 0.4289, Val ACC: 0.8417
Epoch 13/100, Train Loss: 0.2444, Val Loss: 0.4316, Val ACC: 0.8417
Epoch 14/100, Train Loss: 0.2126, Val Loss: 0.4335, Val ACC: 0.8417
Epoch 15/100, Train Loss: 0.1793, Val Loss: 0.4397, Val A

In [18]:
# Save the model state dictionary
torch.save(model.state_dict(), "/home/matan/latent_dae/diffae/resnet/retina_resnet_cls.pt")

In [23]:
from sklearn.metrics import roc_auc_score, f1_score

model = DenseNet121(spatial_dims=2, in_channels=3, out_channels=1).to(device)
model.load_state_dict(torch.load("/home/matan/latent_dae/diffae/resnet/retina_resnet_cls.pt"))
model.eval()

total_preds = []
total_labels = []
with torch.no_grad():
    for batch_data in val_loader:
        images, labels = batch_data[0].to(device), batch_data[1].to(device)
        labels = (labels > 0)
        labels = labels.float()
        outputs = model(images)
        preds = torch.sigmoid(outputs).round()
        total_preds.extend(preds.cpu().numpy())
        total_labels.extend(labels.cpu().numpy())

auc = roc_auc_score(total_labels, total_preds)
f1 = f1_score(total_labels, total_preds)

print("AUC: ", auc)
print("F1: ", f1)

AUC:  0.7932560268538297
F1:  0.8268398268398268
