In [1]:
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchinfo import summary
from torchvision import transforms, datasets
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import utils
import os

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Project-wide seeding helper (see utils.set_seeds for exactly which RNGs it seeds).
utils.set_seeds()

# Debugging aid: uncomment to make CUDA errors surface synchronously at the failing op.
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Record the environment this run used.
for env_info in (torch.__version__, device):
    print(env_info)

2.1.0
cuda


In [2]:
# Training hyperparameters.
NUM_EPOCHS = 30
BATCH_SIZE = 1
LEARNING_RATE = 1e-3   # peak learning rate (1e-4 was also tried)
schedule = True        # when True, a OneCycleLR scheduler is built for the optimizer

In [3]:
import data_setup
from torchvision import transforms

# PI-CAI train / test split roots.
train_dir = '../picai/train/'
test_dir = '../picai/test/'

# Evaluation pipeline: deterministic 320x320 centre crop.
# ToPILImage(mode='F') builds a single-channel 32-bit float image; ToTensor does
# not rescale float data, so intensities pass through unchanged.
simple_transform = transforms.Compose(
    [
        transforms.ToPILImage(mode='F'),
        transforms.CenterCrop((320, 320)),
        transforms.ToTensor(),
    ]
)

# Training pipeline: random vertical/horizontal flips, then a random square crop
# covering 30-100% of the image area, resized to 320x320.
# (A fixed CenterCrop((360, 360)) before the flips was tried and disabled.)
# NOTE(review): these are 2-D PIL transforms, while the commented model-summary cell
# uses 5-D inputs (B, 1, 320, 320, 16). If data_setup applies them slice-by-slice,
# each slice would get its own random flip/crop — confirm in data_setup.
aug_transform = transforms.Compose(
    [
        transforms.ToPILImage(mode='F'),
        transforms.RandomVerticalFlip(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=(320, 320), scale=(0.3, 1), ratio=(1, 1)),
        transforms.ToTensor(),
    ]
)

# Augmented pipeline for training, deterministic one for testing.
train_dataloader, test_dataloader = data_setup.create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    train_transform=aug_transform,
    test_transform=simple_transform,
    batch_size=BATCH_SIZE,
)

In [4]:
import vnetrec_model

model_name = "vnet_rec_v0"

# Build the V-Net reconstruction network and move it to the selected device.
# To resume from a checkpoint instead (torch.load unpickles — load trusted files only):
# model = torch.load("./models/vnet_rec_v0_last_trained_epoch.pth").to(device)
model = vnetrec_model.VNetRec(elu=True, se=True, input_ch=1, split_ch=4)
model = model.to(device)

# Tag the module with its run name (presumably what engine.train prints as "Model:").
model.name = model_name

In [5]:
# Freeze parameters
# i = 0
# for param in model.parameters(): #120
#     i += 1
#     if i<30:
#         param.requires_grad = False

In [6]:
# import simple_vit_seg

# model = simple_vit_seg.SimpleSEGViT(
#     image_size = 320,          # image size
#     frames = 16,               # number of frames
#     image_patch_size = 16,     # image patch size
#     frame_patch_size = 1,      # frame patch size
#     dim = 512,
#     depth = 6,
#     heads = 8,
#     mlp_dim = 2048,
#     channels = 1,
#     mask_hidden_dim = 32
# ).to(device)

# model_name="vit_seg_organ_v0"

In [7]:
# # Create random input sizes
# random_input_image = (4, 1, 320, 320, 16)
# #random_input_image = (4, 1, 16, 320, 320)

# # Get a summary of the input and outputs of PatchEmbedding (uncomment for full output)
# summary(model,
#         input_size=random_input_image, # try swapping this for "random_input_image_error"
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"])

In [8]:
# random_input_image = torch.randn(BATCH_SIZE, 1, 320, 320, 16).to(device)

# test_output = model(random_input_image)

# test_output.shape

In [9]:
# Squared error summed (not averaged) over every output element. With BATCH_SIZE = 1
# the loss scales with the number of voxels, which is likely why the logged losses
# are in the millions.
criterion = torch.nn.MSELoss(reduction='sum')

# Adam with its default betas written out; weight decay (0.0001 was tried) is disabled.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
)

In [10]:
from torch.optim.lr_scheduler import CyclicLR

# Learning-rate schedule (None disables scheduling).
#
# OneCycleLR is sized for NUM_EPOCHS * len(train_dataloader) total steps, so it
# expects lr_scheduler.step() once per training batch. This assumes engine.train
# steps it per batch (TODO confirm). PyTorch raises ValueError if it is stepped
# more than total_steps times.
#
# With the default cycle_momentum=True, OneCycleLR also cycles Adam's beta1 between
# 0.85 and 0.95. That overrides the betas[0] = 0.9 set on the optimizer. Pass
# cycle_momentum=False if the fixed beta1 is intended.
#
# Alternative kept for reference:
#   CyclicLR(optimizer, base_lr=LEARNING_RATE/20, max_lr=LEARNING_RATE,
#            step_size_up=8*len(train_dataloader), mode='exp_range',
#            cycle_momentum=False)
if schedule:
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=LEARNING_RATE,
        epochs=NUM_EPOCHS,
        steps_per_epoch=len(train_dataloader),
    )
else:
    lr_scheduler = None

In [11]:
import engine

# Run the training loop. `lr_scheduler` is None when `schedule` is False.
# NOTE(review): the recorded run raised KeyError: 'train_bcloss' right after printing
# the epoch-1 summary. The per-step log labels the second metric "MAE" while the epoch
# summary calls it "bcloss", so the results-dict keys inside engine.train look
# inconsistent. This presumably needs fixing in engine.py, not here.
results = engine.train(
    model=model,
    device=device,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    epochs=NUM_EPOCHS,
    results_path="./results/",  # presumably where per-epoch metrics are written
    models_path="./models/",    # presumably where checkpoints are saved
)

 ## BEGIN TRAINING ## 
    Model:                 	 vnet_rec_v0
    Number of train batches:	 2517
    Number of test batches:	 769


  0%|          | 0/30 [00:00<?, ?it/s]

	 503 	 Loss: 2650371.5265 MAE: 1777053.1461
	 1006 	 Loss: 2604713.539 MAE: 1760952.1194
	 1509 	 Loss: 2560921.3012 MAE: 1743735.2201
	 2012 	 Loss: 2515869.217 MAE: 1725244.3125
Epoch: 1 | train_loss: 2468481.8759 | train_bcloss: 1706674.4621 | test_loss: 2236841.8787 |  test_bcloss: 1598171.1746


KeyError: 'train_bcloss'

In [12]:
# Save the current weights as ./models/<model_name>_weights.pth.
# NOTE(review): training above stopped with a KeyError after epoch 1, so these are
# presumably partially-trained weights.
utils.save_model(
    model=model,
    target_dir="./models/",
    model_name=f"{model_name}_weights.pth",
)

[INFO] Saving model to: models/vnet_rec_v0_weights.pth


###### final test evaluation
import engine

# Evaluate the current model once over the whole test split.
# The loss object in this notebook is `criterion` (MSELoss(reduction='sum')).
# There is no `loss_fn` variable, so pass `criterion` under test_step's
# `loss_fn` keyword.
test_results = engine.test_step(model=model,
                                dataloader=test_dataloader,
                                loss_fn=criterion,
                                device=device)

# NOTE(review): assumes test_step returns a dict with 'loss' and 'mae' keys. The
# training log used 'test_loss'/'test_bcloss', so confirm the key names in engine.
print(
    f"test_loss: {test_results['loss']:.4f} |  "
    f"test_mae: {test_results['mae']:.4f}"
)