In [1]:
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.auto import tqdm

from utils import *

In [2]:
random_state = 42

df = pd.read_csv("images.csv")
df_train, df_test = train_test_split(
    df, test_size=2 / 10, stratify=df["score"], random_state=random_state
)
df_train, df_val = train_test_split(
    df_train, test_size=1 / 8, stratify=df_train["score"], random_state=random_state
)

ds_train = MRIDataset(df_train)
ds_val = MRIDataset(df_val)
ds_test = MRIDataset(df_test)

In [3]:
class BasicBlock(nn.Module):
    def __conv_subblock(self, in_channels: int, out_channels: int):
        return nn.Sequential(
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=(3, 1, 1),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(1, 3, 1),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(1, 1, 3),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm3d(out_channels),
        )

    def __init__(self, channels: int):
        super().__init__()

        self.subblock1 = self.__conv_subblock(channels, channels)
        self.rl1 = nn.ReLU()
        self.subblock2 = self.__conv_subblock(channels, channels)
        self.rl2 = nn.ReLU()

    def forward(self, x: Tensor):
        identity = x

        out = self.subblock1(x)
        out = self.rl1(out)
        out = self.subblock2(out)

        out += identity
        out = self.rl2(out)

        return out


class BottleNeck(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv3d(
                in_channels,
                in_channels,
                kernel_size=(1, 1, 1),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm3d(in_channels),
        )
        self.rl1 = nn.ReLU()
        self.conv2 = nn.Sequential(
            nn.Conv3d(
                in_channels,
                in_channels,
                kernel_size=(3, 1, 1),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.Conv3d(
                in_channels,
                in_channels,
                kernel_size=(1, 3, 1),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.Conv3d(
                in_channels,
                in_channels,
                kernel_size=(1, 1, 3),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm3d(in_channels),
        )
        self.rl2 = nn.ReLU()
        self.conv3 = nn.Sequential(
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=(1, 1, 1),
                stride=1,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm3d(out_channels),
        )
        self.rl3 = nn.ReLU()

        if in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv3d(
                    in_channels,
                    out_channels,
                    kernel_size=(1, 1, 1),
                    stride=1,
                    padding="same",
                    bias=False,
                ),
                nn.BatchNorm3d(out_channels),
            )
        else:
            self.downsample = nn.Identity()

    def forward(self, x: Tensor):
        identity = self.downsample(x)

        out = self.conv1(x)
        out = self.rl1(out)
        out = self.conv2(x)
        out = self.rl2(out)
        out = self.conv3(x)

        out += identity
        out = self.rl3(out)

        return out


class Residual(nn.Module):
    def __init__(self):
        super().__init__()

        self.init_block = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=(5, 5, 5), stride=2),
            nn.BatchNorm3d(8),
            nn.ReLU(),
        )

        self.res = nn.Sequential(
            BasicBlock(8),
            BasicBlock(8),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2),
            BottleNeck(8, 8),
            BasicBlock(8),
            BasicBlock(8),
            BasicBlock(8),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2),
            BottleNeck(8, 16),
            BasicBlock(16),
            BasicBlock(16),
            BasicBlock(16),
        )

        self.global_pooling = nn.AdaptiveAvgPool3d((1, 1, 1))

        self.readout_block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16, 32, bias=True),
            nn.Dropout(0.5),
            nn.Linear(32, 2, bias=True),
            nn.Softmax(dim=1),
        )

    def forward(self, x: Tensor):
        out = self.init_block(x)
        out = self.res(out)
        out = self.global_pooling(out)
        out = self.readout_block(out)

        return out

In [5]:
learning_rate = 0.0005
epochs = 400
batch_size = 4
patience = 5
delta = 0.01

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Residual().apply(init_weights).to(device)
summary(model, input_size=(batch_size, 1, 192, 256, 256))

Layer (type:depth-idx)                   Output Shape              Param #
Residual                                 [4, 2]                    --
├─Sequential: 1-1                        [4, 8, 94, 126, 126]      --
│    └─Conv3d: 2-1                       [4, 8, 94, 126, 126]      1,008
│    └─BatchNorm3d: 2-2                  [4, 8, 94, 126, 126]      16
│    └─ReLU: 2-3                         [4, 8, 94, 126, 126]      --
├─Sequential: 1-2                        [4, 16, 22, 30, 30]       --
│    └─BasicBlock: 2-4                   [4, 8, 94, 126, 126]      --
│    │    └─Sequential: 3-1              [4, 8, 94, 126, 126]      592
│    │    └─ReLU: 3-2                    [4, 8, 94, 126, 126]      --
│    │    └─Sequential: 3-3              [4, 8, 94, 126, 126]      592
│    │    └─ReLU: 3-4                    [4, 8, 94, 126, 126]      --
│    └─BasicBlock: 2-5                   [4, 8, 94, 126, 126]      --
│    │    └─Sequential: 3-5              [4, 8, 94, 126, 126]      592
│    │   

In [9]:
loader_train = DataLoader(ds_train, batch_size=batch_size)
loader_val = DataLoader(ds_val, batch_size=batch_size)
loader_test = DataLoader(ds_test, batch_size=batch_size)

In [28]:
loss_fn = nn.CrossEntropyLoss().to(device)
opt = optim.Adam(model.parameters(), lr=learning_rate)
es = EarlyStopping(patience=patience, delta=delta)

In [29]:
for epoch in tqdm(range(epochs)):
    model.train()
    for X, y in loader_train:
        X = X.to(device)
        y = y.to(device)

        opt.zero_grad()
        pred = model(X)
        cost = loss_fn(pred, y)
        cost.backward()
        opt.step()

    model.eval()
    with torch.no_grad():
        val_loss = sum(
            loss_fn(model(X.to(device)), y.to(device)) for X, y in loader_val
        )

    es.evaluate(model, val_loss)

    print(f"Epoch {epoch}: val_loss={val_loss / len(loader_val)}, early_stopping_count={es.counter}")

    if es.should_stop():
        model = es.load_best(model)
        break

  0%|          | 0/400 [00:00<?, ?it/s]

Epoch 0: val_loss=0.5775119662284851, early_stopping_count=0
Epoch 1: val_loss=0.5402666330337524, early_stopping_count=0
Epoch 2: val_loss=0.4507451057434082, early_stopping_count=0
Epoch 3: val_loss=0.49698540568351746, early_stopping_count=1
Epoch 4: val_loss=0.44499099254608154, early_stopping_count=0
Epoch 5: val_loss=0.49840179085731506, early_stopping_count=1
Epoch 6: val_loss=0.6692925691604614, early_stopping_count=2
Epoch 7: val_loss=0.416386216878891, early_stopping_count=0
Epoch 8: val_loss=0.4204461872577667, early_stopping_count=1
Epoch 9: val_loss=0.4149344265460968, early_stopping_count=0
Epoch 10: val_loss=0.40338993072509766, early_stopping_count=0
Epoch 11: val_loss=0.6030913591384888, early_stopping_count=1
Epoch 12: val_loss=0.48103559017181396, early_stopping_count=2
Epoch 13: val_loss=0.4471224248409271, early_stopping_count=3
Epoch 14: val_loss=0.4542314112186432, early_stopping_count=4
Epoch 15: val_loss=0.42859435081481934, early_stopping_count=5


In [7]:
model.load_state_dict(torch.load('residual_2_240403_01.pth'))

<All keys matched successfully>

In [10]:
y_true_raw = []
y_pred_raw = []

model.eval()
with torch.no_grad():
    for X, y in loader_test:
        X = X.to(device)
        y = y.to(device)

        pred = model(X)

        y_true_raw.append(y)
        y_pred_raw.append(pred)

y_true_raw = torch.concatenate(y_true_raw).cpu().numpy()
y_pred_raw = torch.concatenate(y_pred_raw).cpu().numpy()

In [11]:
y_true = y_true_raw.argmax(axis=1)
y_pred = y_pred_raw.argmax(axis=1)
y_true, y_pred

(array([1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1]),
 array([1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0,
        1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0,
        0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1]))

In [12]:
confusion_matrix(y_true, y_pred)

array([[41,  7],
       [ 1, 39]])

In [13]:
np.where((y_true == 1) & (y_pred == 0))

(array([47]),)

In [14]:
np.where((y_true == 0) & (y_pred == 1))

(array([ 8, 19, 22, 24, 32, 36, 50]),)