In [26]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
# from IPython.display import clear_output
# !pip install pandas tqdm torch scikit-learn
# clear_output()

In [28]:
import os
import argparse
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import models
from data.dataset import PatientDataset

In [29]:
images_path = "dataset/reduced/train"
labels_path = "dataset/reduced/train.csv"
batch_size = 1
model = "cnn"
cuda = False
lr = 1e-3
epochs = 100
depth = 100 # number of CT slices

In [30]:
entries = os.listdir(images_path)
patient_ids = [int(entry) for entry in entries if os.path.isdir(os.path.join(images_path, entry))]

df = pd.read_csv(labels_path)
df = df[df['patient_id'].isin(patient_ids)]
print(df.shape)

(2886, 6)


In [31]:
# Split the dataset into train, validation, and test sets
train_data, val_test_data = train_test_split(df, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=42)

print(train_data.shape)
print(val_data.shape)
print(test_data.shape)

(2308, 6)
(289, 6)
(289, 6)


In [32]:
train_data.head()

Unnamed: 0,patient_id,bowel,extravastion,kidney,liver,spleen
1426,38343,0,1,0,0,1
964,29412,0,0,0,0,0
2608,63193,0,0,0,0,0
662,23709,0,0,0,0,0
2408,58547,0,0,0,0,0


In [33]:
train_dataset = PatientDataset(images_path, train_data)
val_dataset = PatientDataset(images_path, val_data)
test_dataset = PatientDataset(images_path, test_data)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# get one sample
inputs, labels = train_dataset[0]

print(inputs.shape)
print(labels.shape)

torch.Size([1, 100, 128, 128])
torch.Size([5])


In [34]:
def save_model(model, path):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(model, path):
    model.load_state_dict(torch.load(path))
    model.eval()
    print(f"Model loaded from {path}")

In [35]:
if model == "cnn":
    model = models.ConvNet3D(
        in_channels=inputs.shape[0],
        out_channels=labels.shape[0],
        depth=inputs.shape[1],
        height=inputs.shape[2],
        width=inputs.shape[3],
    )
elif model == "unet":
    model = ...
else:
    raise ValueError("Invalid model selected for training.")

if cuda:
    model = model.cuda()

# test save and loader functions
os.path.exists("checkpoints")
save_model(model, "checkpoints/test.pth")
load_model(model, "checkpoints/test.pth")

criterion = nn.BCEWithLogitsLoss()  # multi-label classification
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(epochs):
    running_loss = 0.0

    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        inputs, labels = batch
        if cuda:
            inputs, labels = inputs.cuda(), labels.cuda()

        # print(inputs.shape)
        # zero gradients for every batch
        optimizer.zero_grad()
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
        print(outputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_dataloader)}")
    if epoch % 2 == 0:
        save_model(model, f"checkpoints/model-e{epoch}.pth")


Model saved to checkpoints/model.pth
Model loaded from checkpoints/model.pth


  0%|          | 0/2308 [00:00<?, ?it/s]

tensor([[-0.0007,  0.0009,  0.0050,  0.0098,  0.0115]],
       grad_fn=<AddmmBackward0>)


  0%|          | 1/2308 [00:06<3:55:39,  6.13s/it]

tensor([[-864.6986, -759.8212, -860.5419, -829.6760, -825.0506]],
       grad_fn=<AddmmBackward0>)


  0%|          | 2/2308 [00:12<4:04:58,  6.37s/it]

tensor([[-0.0759,  0.1034, -0.0656, -0.0632, -0.0635]],
       grad_fn=<AddmmBackward0>)


  0%|          | 3/2308 [00:19<4:15:06,  6.64s/it]

tensor([[-6.7401,  8.7230, -6.1609, -6.1953, -6.0039]],
       grad_fn=<AddmmBackward0>)


  0%|          | 4/2308 [00:26<4:18:41,  6.74s/it]

tensor([[-2.2931, -0.4756, -2.1336, -2.1198, -2.0381]],
       grad_fn=<AddmmBackward0>)


  0%|          | 5/2308 [00:40<5:13:10,  8.16s/it]


KeyboardInterrupt: 

# Inference