In [1]:
from utils.image_utils import plot_XY, gen_index_file
from unet.dataset import SegThorImagesDataset
import torch
from torch.utils.data import DataLoader, random_split
from unet.unet_model import UNet
from unet.simplified_unet_model import SimplifiedUNet
from torch import optim
from torch import nn
import time


In [2]:
# if index file doesn't exist, generate and save 
index_file_train = gen_index_file()

Filename: data/train_patient_idx.csv already exists, skipping gen


In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU is available. ")
else:
    device = torch.device("cpu")
    print("GPU is not available, using CPU instead.")

GPU is not available, using CPU instead.


In [4]:
# define model
# define training loop 

def train_model(data_dir: str = '/home/jupyter/ecs271_data/data/train',
                epochs=10,
                dropout=0.2,
                lr=0.0001):

    input_dataset = SegThorImagesDataset(
        root_dir=data_dir,
        img_crop_size=312, 
        mask_output_size=220
        ) 
    train_dataset, valid_dataset = random_split(input_dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
    train_dl = DataLoader(train_dataset, batch_size=32, shuffle=True)
    valid_dl = DataLoader(valid_dataset, batch_size=32, shuffle=True)
    
    model = SimplifiedUNet(n_channels=1, n_classes=5, dropout=dropout) # 0: no-classification 1: organ, 2: organ, 3: organ, 4: organ

    # TODO: loss function that is SegThor paper
    criterion = torch.nn.CrossEntropyLoss()  
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.to(device)
    epoch_train_losses = []
    epoch_val_losses = []

    # TODO: checkpointing of model
    for epoch in range(epochs):
        epoch_start_time = time.time()
        model.train()
        running_loss = 0.0
        for idx, sample in enumerate(train_dl):
            start_time = time.time()
            inputs, targets = sample
            inputs = inputs.to(device)
            outputs = model(inputs)
            outputs = outputs.cpu()
            loss = criterion(outputs, targets.long())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() 

            print(f'Train: {idx}/{len(train_dl)}: {time.time() - start_time}')
            break
        train_loss = running_loss / len(train_dl)
        epoch_train_losses.append(train_loss)

        model.eval()
        running_val_loss = 0.0
        for idx, sample in enumerate(valid_dl):
            start_time = time.time()
            inputs, targets = sample
            inputs=inputs.to(device)
            outputs=model(inputs)
            val_loss = criterion(outputs, targets.long())
            running_val_loss += val_loss.item()
            print(f'Validation: {idx}/{len(train_dl)}: {time.time() - start_time}')
            break
        validation_loss = running_val_loss / len(valid_dl)
        epoch_val_losses.append(validation_loss)
        
        print(f'Epoch {epoch + 1} | Duration: {time.time()- epoch_start_time} | Train Loss: {train_loss} | Validation Loss: {validation_loss}')

    # log the epoch_val_losses and epoch_train_losses
    return model, epoch_val_losses, epoch_train_losses

Train: 0/186: 37.88753628730774
Epoch 1 | Duration: -74.55666208267212 | Train Loss: 0.008972191682425879 | Validation Loss: 0.033516620067839925
Train: 0/186: 37.97469902038574
Epoch 2 | Duration: -75.41443395614624 | Train Loss: 0.008851199380813106 | Validation Loss: 0.03359932087837381
Train: 0/186: 37.78674077987671
Epoch 3 | Duration: -75.73859000205994 | Train Loss: 0.008726408404688682 | Validation Loss: 0.033706550902508674
Train: 0/186: 37.801799297332764
Epoch 4 | Duration: -79.12963676452637 | Train Loss: 0.008628254936587426 | Validation Loss: 0.0338451938426241
Train: 0/186: 38.45128917694092
Epoch 5 | Duration: -78.24546790122986 | Train Loss: 0.008507203030329879 | Validation Loss: 0.03380670446030637
