In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (here the project-local `models` and `data.dataset`) are re-imported before
# each cell execution and edits to them are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
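# Optional one-time environment setup (left disabled). `%pip` installs into the
# environment of the running kernel, which `!pip` does not guarantee.
# from IPython.display import clear_output
# %pip install pandas tqdm torch scikit-learn
# clear_output()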

In [4]:
import os
import argparse
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import models
from data.dataset import PatientDataset

In [5]:
# Run configuration: every value a reader might want to tune lives in this cell.

# Directory holding one sub-directory per patient, named after the integer patient id.
images_path = "dataset/reduced/train"
# CSV with a `patient_id` column plus the label columns.
labels_path = "dataset/reduced/train.csv"
batch_size = 8
# Architecture name. NOTE: the training cell later rebinds `model` to the
# network object itself, so this name stops being a string after that cell runs.
model = "cnn"
# Use the GPU only when one is actually available; a hard-coded True makes
# `model.cuda()` raise on CPU-only machines.
cuda = torch.cuda.is_available()
lr = 1e-3
epochs = 100
# NOTE(review): not read by any cell visible in this notebook - the network
# depth is taken from the shape of a dataset sample instead. Confirm before relying on it.
depth = 100 # number of CT slices

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# repeatable across Restart & Run All (same value as the split's random_state).
seed = 42
torch.manual_seed(seed)

In [6]:
# Collect the patient ids for which a scan directory exists under `images_path`.
entries = os.listdir(images_path)
patient_ids = []
for entry in entries:
    if not os.path.isdir(os.path.join(images_path, entry)):
        continue
    try:
        patient_ids.append(int(entry))
    except ValueError:
        # Not a patient directory (e.g. `.ipynb_checkpoints`): skip it instead
        # of aborting the whole cell.
        continue

# Keep only the label rows whose patient has a scan directory on disk.
df = pd.read_csv(labels_path)
df = df[df['patient_id'].isin(patient_ids)]
print(df.shape)

(2886, 6)


In [7]:
# Hold out 20% of the rows, then cut that hold-out in half to obtain the
# validation and test sets. Both calls use a fixed random_state so the split
# is the same on every run.
train_data, val_test_data = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
)
val_data, test_data = train_test_split(
    val_test_data,
    test_size=0.5,
    random_state=42,
)

# One shape per line: train, validation, test.
print(*(part.shape for part in (train_data, val_data, test_data)), sep="\n")

(2308, 6)
(289, 6)
(289, 6)


In [8]:
# Preview the first five rows of the training split.
train_data.head(n=5)

Unnamed: 0,patient_id,bowel,extravastion,kidney,liver,spleen
1426,38343,0,1,0,0,1
964,29412,0,0,0,0,0
2608,63193,0,0,0,0,0
662,23709,0,0,0,0,0
2408,58547,0,0,0,0,0


In [9]:
# Wrap each split in a PatientDataset that reads from the same image directory.
train_dataset, val_dataset, test_dataset = (
    PatientDataset(images_path, split_df)
    for split_df in (train_data, val_data, test_data)
)

# Only the training loader shuffles; validation and test are iterated in order.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader, test_dataloader = (
    DataLoader(eval_dataset, shuffle=False, batch_size=batch_size)
    for eval_dataset in (val_dataset, test_dataset)
)

# Fetch a single training sample and show the shapes of its input and label tensors.
inputs, labels = train_dataset[0]
print(inputs.shape, labels.shape, sep="\n")

torch.Size([1, 100, 128, 128])
torch.Size([5])


In [10]:
def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` and print a confirmation.

    Missing parent directories of ``path`` are created first, so saving to
    e.g. ``checkpoints/model.pth`` works on a fresh checkout.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        path: destination file path of the checkpoint.
    """
    parent_dir = os.path.dirname(path)
    if parent_dir:  # empty when `path` is a bare file name in the working directory
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(model, path, map_location="cpu"):
    """Load a ``state_dict`` checkpoint from ``path`` into ``model`` in place.

    Args:
        model: object exposing ``load_state_dict`` and ``eval`` (e.g. a
            ``torch.nn.Module``); it is modified in place.
        path: checkpoint file containing a ``state_dict``.
        map_location: forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written from a GPU can also be read on a machine
            without one; ``load_state_dict`` then copies the values into the
            model's existing parameters. Pass ``None`` for ``torch.load``'s
            own default placement.

    Note:
        The model is left in eval mode. Call ``model.train()`` again before
        continuing training.
    """
    # NOTE(security): torch.load unpickles the file - only load checkpoints
    # from a trusted source (newer torch versions also offer weights_only=True).
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Model loaded from {path}")

In [None]:
# `model` arrives here as the architecture name from the config cell and is
# rebound to the network object below. Remember the name so that re-running
# this cell in the same kernel still selects an architecture instead of
# falling through to the ValueError branch.
if isinstance(model, str):
    model_name = model

# Size the network from a freshly fetched sample (channels, depth, height,
# width for the input; number of labels for the output) so the result does not
# depend on whatever `inputs` / `labels` currently hold in the kernel.
sample_inputs, sample_labels = train_dataset[0]

if model_name == "cnn":
    model = models.ConvNet3D(
        in_channels=sample_inputs.shape[0],
        out_channels=sample_labels.shape[0],
        depth=sample_inputs.shape[1],
        height=sample_inputs.shape[2],
        width=sample_inputs.shape[3],
    )
elif model_name == "unet":
    # Fail clearly instead of binding `model` to a placeholder that only
    # crashes later with an unrelated error.
    raise NotImplementedError("The 'unet' model is not implemented yet.")
else:
    raise ValueError("Invalid model selected for training.")

# Honour the `cuda` flag only when a GPU is actually present.
use_cuda = cuda and torch.cuda.is_available()
if use_cuda:
    model = model.cuda()

# The checkpoint directory must exist before the first save.
os.makedirs("checkpoints", exist_ok=True)

# test save and loader functions
save_model(model, "checkpoints/model.pth")
load_model(model, "checkpoints/model.pth")

criterion = nn.BCEWithLogitsLoss()  # multi-label classification
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# NOTE(review): `val_dataloader` is built earlier in the notebook but is not
# used here, so only the training loss is reported - consider adding a
# validation pass per epoch.
for epoch in range(epochs):
    # load_model() above leaves the network in eval mode; switch back to
    # training mode explicitly. This matters if ConvNet3D contains dropout or
    # batch-norm layers (its definition is not visible from this notebook).
    model.train()
    running_loss = 0.0

    # Batch tensors get their own names so the notebook-level `inputs` /
    # `labels` sample is not overwritten by the loop.
    for batch_inputs, batch_labels in tqdm(train_dataloader, total=len(train_dataloader)):
        if use_cuda:
            batch_inputs, batch_labels = batch_inputs.cuda(), batch_labels.cuda()

        # zero gradients for every batch
        optimizer.zero_grad()
        outputs = model(batch_inputs.float())
        # BCEWithLogitsLoss expects floating-point targets.
        loss = criterion(outputs, batch_labels.float())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_dataloader)}")
    # Checkpoint every second epoch. File names use the zero-based epoch
    # index, so `model-e0.pth` holds the weights after "Epoch 1".
    if epoch % 2 == 0:
        save_model(model, f"checkpoints/model-e{epoch}.pth")


Model saved to checkpoints/model.pth
Model loaded from checkpoints/model.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 1, Loss: 0.3495239856114005
Model saved to checkpoints/model-e0.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 2, Loss: 0.26117522250423386


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 3, Loss: 0.2645952412505934
Model saved to checkpoints/model-e2.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 4, Loss: 0.26345168899306287


100%|██████████| 289/289 [04:30<00:00,  1.07it/s]


Epoch 5, Loss: 0.2581877589274207
Model saved to checkpoints/model-e4.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 6, Loss: 0.25751749083664016


100%|██████████| 289/289 [04:30<00:00,  1.07it/s]


Epoch 7, Loss: 0.25571891520765316
Model saved to checkpoints/model-e6.pth


100%|██████████| 289/289 [04:28<00:00,  1.08it/s]


Epoch 8, Loss: 0.2554730119696049


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 9, Loss: 0.25489510044868025
Model saved to checkpoints/model-e8.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 10, Loss: 0.2547092673417091


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 11, Loss: 0.2554032043866478
Model saved to checkpoints/model-e10.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 12, Loss: 0.25522790607838375


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 13, Loss: 0.25441717516927354
Model saved to checkpoints/model-e12.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 14, Loss: 0.25405139478961175


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 15, Loss: 0.254874172123555
Model saved to checkpoints/model-e14.pth


100%|██████████| 289/289 [04:28<00:00,  1.08it/s]


Epoch 16, Loss: 0.2544681899036866


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 17, Loss: 0.25442275511919443
Model saved to checkpoints/model-e16.pth


100%|██████████| 289/289 [04:28<00:00,  1.08it/s]


Epoch 18, Loss: 0.25430688460044926


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 19, Loss: 0.2546381137981545
Model saved to checkpoints/model-e18.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 20, Loss: 0.25423266874516726


100%|██████████| 289/289 [04:28<00:00,  1.08it/s]


Epoch 21, Loss: 0.254106835273955
Model saved to checkpoints/model-e20.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 22, Loss: 0.25433831633581


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 23, Loss: 0.25425578541602556
Model saved to checkpoints/model-e22.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 24, Loss: 0.2544335936779434


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 25, Loss: 0.2538652981538002
Model saved to checkpoints/model-e24.pth


100%|██████████| 289/289 [04:27<00:00,  1.08it/s]


Epoch 26, Loss: 0.25406735682453957


100%|██████████| 289/289 [04:26<00:00,  1.08it/s]


Epoch 27, Loss: 0.25434522645670454
Model saved to checkpoints/model-e26.pth


100%|██████████| 289/289 [04:29<00:00,  1.07it/s]


Epoch 28, Loss: 0.2543351226624838


100%|██████████| 289/289 [04:26<00:00,  1.08it/s]


Epoch 29, Loss: 0.25552565136595695
Model saved to checkpoints/model-e28.pth


 78%|███████▊  | 224/289 [03:27<00:56,  1.14it/s]