In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

## CROP IMAGES
- Crops image pairs from stpt2imc/data/{IMC, STPT}/ and saves cropped images to disk

In [11]:
from img_processing import crop_imc, crop_stpt

# Physical sections to process in this run, as a half-open range [first, last).
# NOTE(review): the original note mentions 18 physical sections, but only
# section 10 is selected here -- widen the range to process more of them.
SECTIONS_TO_PROCESS = range(10, 11)

# Sections that must be skipped; physical section 16 is defective.
DEFECTIVE_SECTIONS = {16}

for section in SECTIONS_TO_PROCESS:
    if section in DEFECTIVE_SECTIONS:
        continue
    # Crop both modalities for the same physical section so the saved images
    # stay paired.
    crop_imc(section)
    crop_stpt(section)

IMC: Done physical section: 10
STPT: Done physical section: 10


## MODEL TRAINING

In [2]:
from models import UNet, PointSetGen
from datasets import STPT_IMC_ImageFolder
from fitting import train_model, validate_model

import os
import math
import torch
import torch.nn as nn

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Seed PyTorch's global RNG so that everything drawing from it in this
# notebook (e.g. the training DataLoader's shuffle order) is repeatable after
# a kernel restart. The train/val split below gets its own seeded generator so
# the partition does not depend on how much randomness was consumed before it.
SEED = 0
torch.manual_seed(SEED)

use_gpu = torch.cuda.is_available()

# The network is cast to float64, so inputs are presumably expected as double
# tensors as well -- confirm against the dataset / fitting code.
model = UNet().double()
criterion = nn.MSELoss()

if use_gpu:
    model = model.cuda()
    criterion = criterion.cuda()

# NOTE(review): lr=0.1 is two orders of magnitude above AdamW's default of
# 1e-3 -- confirm this is intentional. Value left unchanged here.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=0.01)

########### SPLIT TRAIN AND VAL ###########

# 80 % of the samples go to training; the remainder (including any rounding
# leftover) goes to validation.
img_folder = STPT_IMC_ImageFolder(root='processed_data')
train_size = math.floor(len(img_folder) * .8)
val_size = len(img_folder) - train_size
train_data, val_data = torch.utils.data.random_split(
    img_folder,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)

# Validation order does not matter, so it is left unshuffled.
val_loader = torch.utils.data.DataLoader(val_data,
                                         batch_size=64,
                                         shuffle=False)

In [None]:
if __name__ == '__main__':
    # Any finite validation loss beats the initial "best".
    best_losses = float('inf')
    epochs = 100

    # torch.save does not create missing parent directories, so make sure the
    # checkpoint folder exists before the first save.
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Path of the best checkpoint written so far (None until the first save).
    prev_chkpt_file = None
    for epoch in range(epochs):
        # Train for one epoch, then validate.
        # `mod=10` is forwarded as-is; presumably a logging interval --
        # confirm in fitting.py.
        train_model(train_loader, model, criterion, optimizer, epoch, use_gpu=use_gpu, mod=10)
        with torch.no_grad():
            losses = validate_model(val_loader, model, criterion, epoch, use_gpu=use_gpu, mod=10)

        # Nothing to save unless this epoch improved on the best loss so far.
        if not losses < best_losses:
            continue

        best_losses = losses
        # The file name uses the 1-based epoch number, while the 'epoch' entry
        # stored inside the checkpoint is the 0-based loop index.
        chkpt_file = os.path.join(
            checkpoint_dir,
            'model-epoch-{}-losses-{:.3f}.pth'.format(epoch + 1, losses))
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                    'loss': losses,
                    }, chkpt_file)

        # Only keep the best model: the new checkpoint is written first, then
        # the previous one is removed. A checkpoint that has already been
        # deleted externally must not abort training.
        if prev_chkpt_file and os.path.exists(prev_chkpt_file):
            os.remove(prev_chkpt_file)
        prev_chkpt_file = chkpt_file