# load dataset

In [1]:
import os
import torch
import torchio as tio

dataset = torch.load('../data/preproc.pt', weights_only=False)

# split into train and test
BATCH_SIZE = 10

train_data, test_data = torch.utils.data.random_split(dataset, [0.7, 0.3])

train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# check loaders
for i, batch in enumerate(train_loader):
    inputs = batch['mri'][tio.DATA] # shape: (B, 1, D, H, W)
    labels = batch['label']
    print(i, inputs.shape, labels.shape)
    break

train_N = len(train_loader.dataset)
test_N = len(test_loader.dataset)
len(dataset), train_N, test_N

  from .autonotebook import tqdm as notebook_tqdm

0 torch.Size([10, 1, 512, 512, 60]) torch.Size([10])


(290, 203, 87)

# model

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class MRIClassifier(nn.Module):
    def __init__(self, in_channels=1, num_classes=2):
        super(MRIClassifier, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(32)
        self.pool1 = nn.MaxPool3d(kernel_size=2)

        self.conv2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(64)
        self.pool2 = nn.MaxPool3d(kernel_size=2)

        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.global_pool(x) # shape: (B, 64, 1, 1, 1)
        x = x.view(x.size(0), -1) # shape: (B, 64)
        return self.fc(x)

# train

In [None]:
# compile model
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MRIClassifier()
model = nn.DataParallel(model)
model.to(DEVICE)

# # weighted loss
# import numpy as np
# from sklearn.utils.class_weight import compute_class_weight
# train_labels = np.array([dataset[i][1] for i in train_data.indices])
# classes = np.arange(train_loader.dataset.dataset.num_classes)
# class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_labels)
# loss_fn = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32)).to(DEVICE) # expects logits

loss_fn = torch.nn.CrossEntropyLoss() # expects logits
optimizer = torch.optim.Adam(model.parameters())

# train
EPOCHS = 100
train_losses = []
val_losses = []
train_accs = []
val_accs = []
best_val_loss = torch.inf
for epoch in range(EPOCHS):
    # train
    model.train()
    running_loss = 0.0
    running_correct = 0
    for i, batch in enumerate(train_loader):
        inputs = batch['mri'][tio.DATA] # shape: (B, 1, D, H, W)
        labels = batch['label']
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_correct += (torch.softmax(outputs, dim=1).argmax(dim=1) == labels).sum().item()

    print(f"Epoch {epoch+1}/{EPOCHS}\tTrain Loss: {running_loss/train_N:.4f} | Train Acc: {running_correct/train_N:.4f}")
    train_losses.append(running_loss/train_N)
    train_accs.append(running_correct/train_N)

    # save model
    torch.save(model, f'../results/model.pt')



Epoch 1/100	Train Loss: 0.0699 | Train Acc: 0.6108




Epoch 2/100	Train Loss: 0.0542 | Train Acc: 0.7980




Epoch 3/100	Train Loss: 0.0515 | Train Acc: 0.7980




Epoch 4/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 5/100	Train Loss: 0.0508 | Train Acc: 0.7980




Epoch 6/100	Train Loss: 0.0534 | Train Acc: 0.7980




Epoch 7/100	Train Loss: 0.0542 | Train Acc: 0.7980




Epoch 8/100	Train Loss: 0.0525 | Train Acc: 0.7980




Epoch 9/100	Train Loss: 0.0516 | Train Acc: 0.7980




Epoch 10/100	Train Loss: 0.0540 | Train Acc: 0.7980




Epoch 11/100	Train Loss: 0.0530 | Train Acc: 0.7980




Epoch 12/100	Train Loss: 0.0546 | Train Acc: 0.7980




Epoch 13/100	Train Loss: 0.0524 | Train Acc: 0.7980




Epoch 14/100	Train Loss: 0.0542 | Train Acc: 0.7980




Epoch 15/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 16/100	Train Loss: 0.0511 | Train Acc: 0.7980




Epoch 17/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 18/100	Train Loss: 0.0526 | Train Acc: 0.7980




Epoch 19/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 20/100	Train Loss: 0.0523 | Train Acc: 0.7980




Epoch 21/100	Train Loss: 0.0510 | Train Acc: 0.7980




Epoch 22/100	Train Loss: 0.0531 | Train Acc: 0.7980




Epoch 23/100	Train Loss: 0.0526 | Train Acc: 0.7980




Epoch 24/100	Train Loss: 0.0511 | Train Acc: 0.7980




Epoch 25/100	Train Loss: 0.0514 | Train Acc: 0.7980




Epoch 26/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 27/100	Train Loss: 0.0526 | Train Acc: 0.7980




Epoch 28/100	Train Loss: 0.0531 | Train Acc: 0.7980




Epoch 29/100	Train Loss: 0.0510 | Train Acc: 0.7980




Epoch 30/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 31/100	Train Loss: 0.0511 | Train Acc: 0.7980




Epoch 32/100	Train Loss: 0.0523 | Train Acc: 0.7980




Epoch 33/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 34/100	Train Loss: 0.0533 | Train Acc: 0.7980




Epoch 35/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 36/100	Train Loss: 0.0514 | Train Acc: 0.7980




Epoch 37/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 38/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 39/100	Train Loss: 0.0558 | Train Acc: 0.7980




Epoch 40/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 41/100	Train Loss: 0.0525 | Train Acc: 0.7980




Epoch 42/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 43/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 44/100	Train Loss: 0.0530 | Train Acc: 0.7980




Epoch 45/100	Train Loss: 0.0531 | Train Acc: 0.7980




Epoch 46/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 47/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 48/100	Train Loss: 0.0525 | Train Acc: 0.7980




Epoch 49/100	Train Loss: 0.0511 | Train Acc: 0.7980




Epoch 50/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 51/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 52/100	Train Loss: 0.0532 | Train Acc: 0.7980




Epoch 53/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 54/100	Train Loss: 0.0523 | Train Acc: 0.7980




Epoch 55/100	Train Loss: 0.0507 | Train Acc: 0.7980




Epoch 56/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 57/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 58/100	Train Loss: 0.0510 | Train Acc: 0.7980




Epoch 59/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 60/100	Train Loss: 0.0531 | Train Acc: 0.7980




Epoch 61/100	Train Loss: 0.0508 | Train Acc: 0.7980




Epoch 62/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 63/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 64/100	Train Loss: 0.0514 | Train Acc: 0.7980




Epoch 65/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 66/100	Train Loss: 0.0517 | Train Acc: 0.7980




Epoch 67/100	Train Loss: 0.0510 | Train Acc: 0.7980




Epoch 68/100	Train Loss: 0.0536 | Train Acc: 0.7980




Epoch 69/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 70/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 71/100	Train Loss: 0.0526 | Train Acc: 0.7980




Epoch 72/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 73/100	Train Loss: 0.0515 | Train Acc: 0.7980




Epoch 74/100	Train Loss: 0.0547 | Train Acc: 0.7980




Epoch 75/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 76/100	Train Loss: 0.0531 | Train Acc: 0.7980




Epoch 77/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 78/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 79/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 80/100	Train Loss: 0.0515 | Train Acc: 0.7980




Epoch 81/100	Train Loss: 0.0523 | Train Acc: 0.7980




Epoch 82/100	Train Loss: 0.0511 | Train Acc: 0.7980




Epoch 83/100	Train Loss: 0.0514 | Train Acc: 0.7980




Epoch 84/100	Train Loss: 0.0549 | Train Acc: 0.7980




Epoch 85/100	Train Loss: 0.0528 | Train Acc: 0.7980




Epoch 86/100	Train Loss: 0.0530 | Train Acc: 0.7980




Epoch 87/100	Train Loss: 0.0530 | Train Acc: 0.7980




Epoch 88/100	Train Loss: 0.0533 | Train Acc: 0.7980




Epoch 89/100	Train Loss: 0.0543 | Train Acc: 0.7980




Epoch 90/100	Train Loss: 0.0510 | Train Acc: 0.7980




Epoch 91/100	Train Loss: 0.0510 | Train Acc: 0.7980




Epoch 92/100	Train Loss: 0.0512 | Train Acc: 0.7980




Epoch 93/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 94/100	Train Loss: 0.0520 | Train Acc: 0.7980




Epoch 95/100	Train Loss: 0.0524 | Train Acc: 0.7980




Epoch 96/100	Train Loss: 0.0527 | Train Acc: 0.7980




Epoch 97/100	Train Loss: 0.0546 | Train Acc: 0.7980




Epoch 98/100	Train Loss: 0.0513 | Train Acc: 0.7980




Epoch 99/100	Train Loss: 0.0525 | Train Acc: 0.7980




Epoch 100/100	Train Loss: 0.0525 | Train Acc: 0.7980




AttributeError: 'str' object has no attribute 'to'

# test

In [None]:
import numpy as np

# load best model
model = torch.load(f'../results/model.pt', weights_only=False).to(DEVICE)

model.eval()

y_true = []
y_logits = []
y_scores = []
y_pred = []

correct = 0

with torch.no_grad():
    for i, batch in enumerate(test_loader):
        inputs = batch['mri'][tio.DATA] # shape: (B, 1, D, H, W)
        labels = batch['label']
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        outputs = model(inputs)

        y_true.extend(labels.cpu().numpy())
        y_logits.extend(outputs.cpu().numpy())
        y_scores.extend(torch.softmax(outputs, dim=1).cpu().numpy())
        y_pred.extend(torch.softmax(outputs, dim=1).argmax(dim=1).cpu().numpy())

        correct += (torch.softmax(outputs, dim=1).argmax(dim=1) == labels).sum().item()

y_true = np.array(y_true)
y_logits = np.array(y_logits)
y_scores = np.array(y_scores)
y_pred = np.array(y_pred)

print()
print(f'Test Accuracy: {correct/test_N:.4f}')

# save

In [None]:
import pandas as pd

# save datasets, optimizer state, and metrics to metadata file
torch.save({'train_loader': train_loader,#train_data,
            'test_loader': test_loader,#test_data,
            'optimizer_state': optimizer.state_dict(),
            'y_true': y_true,
            'y_logits': y_logits,
            'y_scores': y_scores,
            'y_pred': y_pred,
            'metrics': pd.DataFrame({'train_loss': train_losses,
                                     'train_acc': train_accs,})},
            
            f'../results/metadata.pt')