In [1]:
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch import Tensor, nn, optim
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm.auto import tqdm
from collections import OrderedDict

from pretrained.models.medicalnet import resnet10

from utils import *

In [2]:
class ResidualBlock(nn.Module):
    def __conv_subblock(self, in_channels: int, out_channels: int):
        return nn.Sequential(
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=(1, 1, 1),
                stride=1,
                padding="same",
            ),
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(3, 1, 1),
                stride=1,
                padding="same",
            ),
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(1, 3, 1),
                stride=1,
                padding="same",
            ),
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(1, 1, 3),
                stride=1,
                padding="same",
            ),
            nn.BatchNorm3d(out_channels),
        )

    def __init__(
        self, pooling: bool = True, io_channels: int = 16, latent_channels: int = 8
    ):
        super().__init__()

        self.subblock1 = self.__conv_subblock(io_channels, latent_channels)
        self.lr1 = nn.LeakyReLU()
        self.subblock2 = self.__conv_subblock(latent_channels, io_channels)
        self.lr2 = nn.LeakyReLU()

        if pooling:
            self.pooling = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2)
        else:
            self.pooling = nn.Identity()

    def forward(self, x: Tensor):
        identity = x

        out = self.subblock1(x)
        out = self.lr1(out)
        out = self.subblock2(out)

        out += identity
        out = self.lr2(out)

        out = self.pooling(out)

        return out


class TransMedicalNet(nn.Module):
    def __init__(self):
        super().__init__()

        state_dict = torch.load(
            "pretrained/weights/medicalnet_resnet_10_23dataset.pth"
        )["state_dict"]
        state_dict = OrderedDict({k[7:]: v for k, v in state_dict.items()})

        self.medicalnet = resnet10(
            sample_input_D=192,
            sample_input_H=256,
            sample_input_W=256,
            num_seg_classes=2,
        )
        self.medicalnet.conv_seg = nn.Identity()
        self.medicalnet.load_state_dict(state_dict)

        self.bottleneck = nn.Sequential(
            nn.Conv3d(512, 8, kernel_size=(1, 1, 1), stride=1, padding="same"),
            nn.ReLU(),
        )

        # self.init_block = nn.Sequential(
        #     nn.Conv3d(4, 16, kernel_size=(3, 3, 3), stride=2),
        #     nn.BatchNorm3d(16),
        #     nn.ReLU(),
        # )

        # self.conv1 = ResidualBlock()
        # self.conv2 = ResidualBlock(pooling=False)

        self.global_pooling = nn.AdaptiveAvgPool3d((1, 1, 1))

        self.readout_block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(8, 32, bias=True),
            nn.Dropout(0.5),
            nn.Linear(32, 2, bias=True),
            nn.Softmax(dim=1),
        )

        for p in self.medicalnet.parameters():
            p.requires_grad = False

        self.bottleneck.apply(init_weights)
        # self.conv1.apply(init_weights)
        # self.conv2.apply(init_weights)
        self.global_pooling.apply(init_weights)
        self.readout_block.apply(init_weights)

    def forward(self, x: Tensor):
        out = self.medicalnet(x)
        out = self.bottleneck(out)
        # out = self.init_block(out)
        # out = self.conv1(out)
        # out = self.conv2(out)
        out = self.global_pooling(out)
        out = self.readout_block(out)

        return out

In [3]:
learning_rate = 0.0005
epochs = 100
batch_size = 2
patience = 3

random_state = 42

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TransMedicalNet().to(device)
summary(model, input_size=(batch_size, 1, 192, 256, 256))

  m.weight = nn.init.kaiming_normal(m.weight, mode="fan_out")


Layer (type:depth-idx)                        Output Shape              Param #
TransMedicalNet                               [2, 2]                    --
├─ResNet: 1-1                                 [2, 512, 24, 32, 32]      --
│    └─Conv3d: 2-1                            [2, 64, 96, 128, 128]     (21,952)
│    └─BatchNorm3d: 2-2                       [2, 64, 96, 128, 128]     (128)
│    └─ReLU: 2-3                              [2, 64, 96, 128, 128]     --
│    └─MaxPool3d: 2-4                         [2, 64, 48, 64, 64]       --
│    └─Sequential: 2-5                        [2, 64, 48, 64, 64]       --
│    │    └─BasicBlock: 3-1                   [2, 64, 48, 64, 64]       (221,440)
│    └─Sequential: 2-6                        [2, 128, 24, 32, 32]      --
│    │    └─BasicBlock: 3-2                   [2, 128, 24, 32, 32]      (672,512)
│    └─Sequential: 2-7                        [2, 256, 24, 32, 32]      --
│    │    └─BasicBlock: 3-3                   [2, 256, 24, 32, 32]      

In [5]:
df = pd.read_csv("images.csv")
df_train, df_test = train_test_split(
    df, test_size=2 / 10, stratify=df["score"], random_state=random_state
)
df_train, df_val = train_test_split(
    df_train, test_size=1 / 8, stratify=df_train["score"], random_state=random_state
)

loader_train = DataLoader(MRIDataset(df_train), batch_size=batch_size)
loader_val = DataLoader(MRIDataset(df_val), batch_size=batch_size)
loader_test = DataLoader(MRIDataset(df_test), batch_size=batch_size)

In [6]:
loss_fn = nn.CrossEntropyLoss().to(device)
opt = optim.Adam(model.parameters(), lr=learning_rate)
es = EarlyStopping(patience=patience)

In [7]:
for epoch in tqdm(range(epochs)):
    model.train()
    for X, y in loader_train:
        X = X.to(device)
        y = y.to(device)

        opt.zero_grad()
        pred = model(X)
        cost = loss_fn(pred, y)
        cost.backward()
        opt.step()

    model.eval()
    with torch.no_grad():
        val_loss = sum(
            loss_fn(model(X.to(device)), y.to(device)) for X, y in loader_val
        )

    es.evaluate(model, val_loss)

    print(f"Epoch {epoch}: val_loss={val_loss / len(loader_val)}, early_stopping_count={es.counter}")

    if es.should_stop():
        model = es.load_best(model)
        break

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0: val_loss=0.77482670545578, early_stopping_count=0
Epoch 1: val_loss=0.7038728594779968, early_stopping_count=0
Epoch 2: val_loss=0.710848331451416, early_stopping_count=1
Epoch 3: val_loss=0.6841521263122559, early_stopping_count=0
Epoch 4: val_loss=0.6836827993392944, early_stopping_count=0
Epoch 5: val_loss=0.6807718276977539, early_stopping_count=0
Epoch 6: val_loss=0.6788797974586487, early_stopping_count=0
Epoch 7: val_loss=0.6805113554000854, early_stopping_count=1
Epoch 8: val_loss=0.680249810218811, early_stopping_count=2
Epoch 9: val_loss=0.6839407086372375, early_stopping_count=3


In [8]:
y_true_raw = []
y_pred_raw = []

model.eval()
with torch.no_grad():
    for X, y in loader_test:
        X = X.to(device)
        y = y.to(device)

        pred = model(X)

        y_true_raw.append(y)
        y_pred_raw.append(pred)

y_true_raw = torch.concatenate(y_true_raw).cpu().numpy()
y_pred_raw = torch.concatenate(y_pred_raw).cpu().numpy()

In [9]:
y_true = y_true_raw.argmax(axis=1)
y_pred = y_pred_raw.argmax(axis=1)
y_true, y_pred

(array([1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0,
        0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1,
        1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [10]:
confusion_matrix(y_true, y_pred)

array([[48,  0],
       [35,  5]])