In [1]:
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.auto import tqdm

from utils import *

In [2]:
random_state = 42

df = pd.read_csv("images.csv")
df_train, df_test = train_test_split(
    df, test_size=2 / 10, stratify=df["score"], random_state=random_state
)
df_train, df_val = train_test_split(
    df_train, test_size=1 / 8, stratify=df_train["score"], random_state=random_state
)

ds_train = MRIDataset(df_train)
ds_val = MRIDataset(df_val)
ds_test = MRIDataset(df_test)

KeyboardInterrupt: 

In [3]:
class Baseline(nn.Module):
    def __conv_subblock(self, in_channels: int, out_channels: int):
        return nn.Sequential(
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=(5, 1, 1),
                stride=1,
                padding="same",
            ),
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(1, 5, 1),
                stride=1,
                padding="same",
            ),
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(1, 1, 5),
                stride=1,
                padding="same",
            ),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=(5, 5, 5), stride=2),
            nn.BatchNorm3d(8),
            nn.ReLU(),
            self.__conv_subblock(8, 8),
            self.__conv_subblock(8, 16),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2),
            self.__conv_subblock(16, 8),
            self.__conv_subblock(8, 16),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2),
            self.__conv_subblock(16, 8),
            self.__conv_subblock(8, 16),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2),
            self.__conv_subblock(16, 8),
            self.__conv_subblock(8, 16),
            nn.AvgPool3d((10, 14, 14)),
        )
        self.readout_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16, 32, bias=True),
            nn.Dropout(0.5),
            nn.Linear(32, 2, bias=True),
            nn.Softmax(dim=1),
        )

    def forward(self, x: Tensor):
        out = self.conv_layers(x)
        out = self.readout_layers(out)

        return out

In [12]:
learning_rate = 0.0005
epochs = 400
batch_size = 4
patience = 5
delta = 0.01

In [13]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Baseline().apply(init_weights).to(device)
summary(model, input_size=(batch_size, 1, 192, 256, 256))

Layer (type:depth-idx)                   Output Shape              Param #
Baseline                                 [4, 2]                    --
├─Sequential: 1-1                        [4, 16, 1, 1, 1]          --
│    └─Conv3d: 2-1                       [4, 8, 94, 126, 126]      1,008
│    └─BatchNorm3d: 2-2                  [4, 8, 94, 126, 126]      16
│    └─ReLU: 2-3                         [4, 8, 94, 126, 126]      --
│    └─Sequential: 2-4                   [4, 8, 94, 126, 126]      --
│    │    └─Conv3d: 3-1                  [4, 8, 94, 126, 126]      328
│    │    └─Conv3d: 3-2                  [4, 8, 94, 126, 126]      328
│    │    └─Conv3d: 3-3                  [4, 8, 94, 126, 126]      328
│    │    └─BatchNorm3d: 3-4             [4, 8, 94, 126, 126]      16
│    │    └─ReLU: 3-5                    [4, 8, 94, 126, 126]      --
│    └─Sequential: 2-5                   [4, 16, 94, 126, 126]     --
│    │    └─Conv3d: 3-6                  [4, 16, 94, 126, 126]     656
│    │  

In [14]:
loader_train = DataLoader(ds_train, batch_size=batch_size)
loader_val = DataLoader(ds_val, batch_size=batch_size)
loader_test = DataLoader(ds_test, batch_size=batch_size)

In [15]:
loss_fn = nn.CrossEntropyLoss().to(device)
opt = optim.Adam(model.parameters(), lr=learning_rate)
es = EarlyStopping(patience=patience, delta=delta)

In [16]:
for epoch in tqdm(range(epochs)):
    model.train()
    for X, y in loader_train:
        X = X.to(device)
        y = y.to(device)

        opt.zero_grad()
        pred = model(X)
        cost = loss_fn(pred, y)
        cost.backward()
        opt.step()

    model.eval()
    with torch.no_grad():
        val_loss = sum(
            loss_fn(model(X.to(device)), y.to(device)) for X, y in loader_val
        )

    es.evaluate(model, val_loss)

    print(f"Epoch {epoch}: val_loss={val_loss / len(loader_val)}, early_stopping_count={es.counter}")

    if es.should_stop():
        model = es.load_best(model)
        break

  0%|          | 0/400 [00:00<?, ?it/s]

Epoch 0: val_loss=0.5932536721229553, early_stopping_count=0
Epoch 1: val_loss=0.5151541233062744, early_stopping_count=0
Epoch 2: val_loss=0.46176382899284363, early_stopping_count=0
Epoch 3: val_loss=0.5268815755844116, early_stopping_count=1
Epoch 4: val_loss=0.44984298944473267, early_stopping_count=0
Epoch 5: val_loss=0.5222452282905579, early_stopping_count=1
Epoch 6: val_loss=0.42839497327804565, early_stopping_count=0
Epoch 7: val_loss=0.4613119959831238, early_stopping_count=1
Epoch 8: val_loss=0.43004995584487915, early_stopping_count=2
Epoch 9: val_loss=0.42143091559410095, early_stopping_count=0
Epoch 10: val_loss=0.5452969670295715, early_stopping_count=1
Epoch 11: val_loss=0.5484550595283508, early_stopping_count=2
Epoch 12: val_loss=0.4542331099510193, early_stopping_count=3
Epoch 13: val_loss=0.504102349281311, early_stopping_count=4
Epoch 14: val_loss=0.43712279200553894, early_stopping_count=5


In [17]:
y_true_raw = []
y_pred_raw = []

model.eval()
with torch.no_grad():
    for X, y in loader_test:
        X = X.to(device)
        y = y.to(device)

        pred = model(X)

        y_true_raw.append(y)
        y_pred_raw.append(pred)

y_true_raw = torch.concatenate(y_true_raw).cpu().numpy()
y_pred_raw = torch.concatenate(y_pred_raw).cpu().numpy()

In [18]:
y_true = y_true_raw.argmax(axis=1)
y_pred = y_pred_raw.argmax(axis=1)
y_true, y_pred

(array([1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1]),
 array([1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0,
        1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1,
        1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1]))

In [19]:
confusion_matrix(y_true, y_pred)

array([[40,  8],
       [ 0, 40]])