## Import Packages

In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision.models as models
from torchvision import datasets, transforms
# For utilities
import os, shutil, time
import cv2 as cv
import subprocess
from torch.multiprocessing import Pool, set_start_method

In [2]:
# Decide once whether CUDA can be used; later cells read this flag.
use_gpu = torch.cuda.is_available()

# Run a helper shell script that presumably removes stray `.ipynb_checkpoints`
# entries (which would otherwise show up in the os.listdir calls used by the
# dataset class) -- NOTE(review): script contents not visible here; confirm.
# The trailing ';' suppresses the CompletedProcess repr in the notebook.
subprocess.run(
    '.././rm_ipynbcheckpoints.sh',
    shell=True,
    cwd='/home/kyang/Shared/Notebooks/Kevin/stpt2imc',
);

In [3]:
# code from https://amaarora.github.io/2020/09/13/unet.html#u-net
class Block(nn.Module):
    """One U-Net stage: BatchNorm2d on the input, then conv3x3 -> ReLU -> conv3x3.

    Both convolutions use PyTorch's default ``padding=0``, so every spatial
    dimension shrinks by 2 pixels per conv (4 in total).

    Args:
        in_ch: channels of the incoming feature map (also the BatchNorm width).
        out_ch: channels produced by both convolutions.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names and creation order are kept as-is: they define the
        # state_dict keys used by saved checkpoints and the seeded init order.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
        self.batchnorm2d = nn.BatchNorm2d(in_ch)

    def forward(self, x):
        normalized = self.batchnorm2d(x)
        activated = self.relu(self.conv1(normalized))
        return self.conv2(activated)


class Encoder(nn.Module):
    """U-Net contracting path: a chain of ``Block``s separated by 2x2 max-pooling.

    Args:
        chs: channel sizes; stage ``k`` is ``Block(chs[k], chs[k + 1])``.

    ``forward`` returns the output of every Block (before pooling), ordered
    shallowest first. The pooled tensor after the final Block is computed
    but not returned.
    """
    def __init__(self, chs=(8,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        features = []
        for enc_block in self.enc_blocks:
            x = enc_block(x)
            features.append(x)  # skip connection for the decoder
            x = self.pool(x)
        return features


class Decoder(nn.Module):
    """U-Net expanding path.

    Stage ``k`` does the following:
    - upsamples ``x`` with a 2x2 / stride-2 transposed conv (``chs[k]`` -> ``chs[k + 1]`` channels);
    - center-crops ``encoder_features[k]`` to the upsampled height/width;
    - concatenates the two along the channel dim;
    - applies ``Block(chs[k], chs[k + 1])``.

    The Block therefore expects the encoder feature to supply
    ``chs[k] - chs[k + 1]`` channels.

    Args:
        chs: channel sizes along the decoder, deepest first.
    """
    def __init__(self, chs=(1024, 512, 256, 128, 40)):
        super().__init__()
        self.chs = chs
        n_stages = len(chs) - 1
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(chs[k], chs[k + 1], 2, 2) for k in range(n_stages)]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(chs[k], chs[k + 1]) for k in range(n_stages)]
        )

    def forward(self, x, encoder_features):
        """Run all stages; ``encoder_features`` is indexed per stage (deepest skip first)."""
        for stage in range(len(self.chs) - 1):
            upsampled = self.upconvs[stage](x)
            skip = self.crop(encoder_features[stage], upsampled)
            x = self.dec_blocks[stage](torch.cat([upsampled, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size (H, W) of ``x``."""
        _, _, height, width = x.shape
        return transforms.CenterCrop([height, width])(enc_ftrs)


class UNet(nn.Module):
    """Encoder/decoder U-Net with a 1x1 conv head and an optional final resize.

    Args:
        enc_chs: channel sizes for ``Encoder`` (``enc_chs[0]`` = input channels).
        dec_chs: channel sizes for ``Decoder`` (deepest first).
        num_class: output channels of the 1x1 head (default 40, as before).
        out_sz: spatial size passed to ``F.interpolate`` on the head output
            (default ``(256, 256)``, as before). ``None`` skips the resize and
            returns the head output at its natural size.
    """
    def __init__(self, enc_chs=(8,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64),
                 num_class=40, out_sz=(256, 256)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.out_sz = out_sz

    def forward(self, x):
        # Deepest feature map seeds the decoder; the rest are skip connections.
        enc_ftrs = self.encoder(x)[::-1]
        out = self.decoder(enc_ftrs[0], enc_ftrs[1:])
        out = self.head(out)
        if self.out_sz is not None:
            # Unpadded convs shrink the map, so resize back to the target size.
            out = F.interpolate(out, self.out_sz)
        return out

In [4]:
# Build the network in float64; train/validate feed `.double()` inputs.
model = UNet().double()
criterion = nn.MSELoss()

# Move model and loss function to GPU
if use_gpu:
    criterion = criterion.cuda()
    model = model.cuda()

# NOTE: these hyper-parameters only matter for a fresh optimizer. PyTorch's
# Optimizer.load_state_dict (below) restores the saved param_groups, including
# lr and weight_decay, so a resumed run continues with the checkpoint's values.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=0.01)

# https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training
# map_location lets a checkpoint written on a GPU machine load on a CPU-only
# kernel; load_state_dict then copies the tensors onto the model's device.
checkpoint = torch.load('../checkpoints/model-epoch-7-losses-282.914.pth',
                        map_location='cuda' if use_gpu else 'cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
inc = checkpoint['epoch'] + 1  # resume numbering right after the last completed epoch

In [5]:
class STPT_IMC_ImageFolder(datasets.ImageFolder):
    """
    Paired STPT (input) / IMC (target) dataset stored as per-section tensor files.

    Directory layout read by this class::

        <root>/STPT/<NN>/<name>   input tensors, loaded with ``torch.load``
        <root>/IMC/<NN>/<name>    target tensors, paired with STPT by file name

    where ``<NN>`` is the zero-padded physical-section number taken from
    ``index_to_phys_sec`` (sections 01-15; section 16 is skipped).

    ``datasets.ImageFolder.__init__`` is deliberately not called; only the
    ``Dataset`` protocol (``__len__`` / ``__getitem__``) is used.

    Args:
        root: split directory, e.g. ``'../data/train'``.
        transform: stored on the instance but NOT applied to the loaded tensors.
        bits: bits per pixel; stored only, not used by this class.
        batch_size: stored only; batching is done by the DataLoader.

    Assumes every physical section holds as many files as ``IMC/01``; that
    count drives ``__len__`` and the flat-index -> (section, file) mapping.
    """
    def __init__(self, root, transform, bits=8, batch_size=64):
        self.root = root
        self.transform = transform
        self.imc_folder = os.path.join(self.root, 'IMC')
        self.stpt_folder = os.path.join(self.root, 'STPT')
        self.bits = bits  # num bits for each pixel in image
        self.batch_size = batch_size

        # Physical sections included in the dataset (phys_sec 16 is skipped).
        self.index_to_phys_sec = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

        # Dataset length = files per section (taken from IMC/01) * number of sections.
        self.num_imgs_per_phys_sec = len(os.listdir(os.path.join(self.imc_folder, '01')))
        self.num_imgs = self.num_imgs_per_phys_sec * len(self.index_to_phys_sec)

        # Sorted directory listings keyed by (folder, phys_sec), filled lazily so
        # each directory is listed once instead of on every __getitem__ call.
        self._listing_cache = {}

    def __len__(self):
        return self.num_imgs

    def _sorted_listing(self, folder, phys_sec):
        """Return the sorted file names in ``<folder>/<phys_sec zero-padded>`` (cached)."""
        key = (folder, phys_sec)
        if key not in self._listing_cache:
            section_dir = os.path.join(folder, str(phys_sec).zfill(2))
            # os.listdir order is arbitrary; sorting makes the STPT and IMC
            # listings line up position-by-position when they hold the same names.
            self._listing_cache[key] = sorted(os.listdir(section_dir))
        return self._listing_cache[key]

    def __getitem__(self, index):
        """Load the ``(stpt_img, imc_img)`` pair for a flat dataset index.

        Raises:
            ValueError: if the STPT and IMC files at this position have
                different names, i.e. the two folders are out of sync.
        """
        # Integer division picks the section, the remainder picks the file.
        section_idx, file_idx = divmod(int(index), self.num_imgs_per_phys_sec)
        phys_sec = self.index_to_phys_sec[section_idx]
        section_name = str(phys_sec).zfill(2)

        stpt_name = self._sorted_listing(self.stpt_folder, phys_sec)[file_idx]
        imc_name = self._sorted_listing(self.imc_folder, phys_sec)[file_idx]

        # Make sure the files line up; training on a mismatched pair is silent corruption.
        if stpt_name != imc_name:
            raise ValueError(
                'STPT/IMC file mismatch in section {}: {!r} vs {!r}'.format(
                    section_name, stpt_name, imc_name))

        stpt_img = torch.load(os.path.join(self.stpt_folder, section_name, stpt_name))
        imc_img = torch.load(os.path.join(self.imc_folder, section_name, imc_name))
        return stpt_img, imc_img

In [6]:
# Datasets and loaders for the train / val splits. No transforms are applied.
transform = None

train_imagefolder = STPT_IMC_ImageFolder(root='../data/train', transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_imagefolder,
    batch_size=64,
    shuffle=True,   # reshuffle training pairs every epoch
)

val_imagefolder = STPT_IMC_ImageFolder(root='../data/val', transform=transform)
val_loader = torch.utils.data.DataLoader(
    val_imagefolder,
    batch_size=64,
    shuffle=False,  # fixed order for validation
)

In [7]:
class AverageMeter(object):
    """Track the latest value and the running (count-weighted) mean of a metric.

    Adapted from the PyTorch ImageNet example. Every ``update`` also appends
    the value and the new average to the ``vals`` / ``avgs`` history lists;
    ``reset`` zeroes the running totals but leaves that history intact.
    """

    def __init__(self):
        self.vals = []
        self.avgs = []
        self.reset()

    def reset(self):
        """Zero the latest value, average, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.vals.append(val)
        self.avgs.append(self.avg)

In [8]:
def validate(val_loader, model, criterion, epoch, plot=True):
    """Evaluate ``model`` on every batch of ``val_loader`` and return the mean loss.

    Args:
        val_loader: iterable of ``(stpt, imc)`` batches (input, target).
        model: network mapping an STPT batch to an IMC reconstruction.
        criterion: loss called as ``criterion(prediction, target)``.
        epoch: epoch number, used only in the log banner.
        plot: unused; kept so existing call sites keep working.

    Returns:
        The batch-size-weighted average of ``criterion`` over the loader
        (0 if the loader yields no batches).
    """
    print('='*10, 'Starting validation epoch {}'.format(epoch), '='*10)
    model.eval()

    # Prepare value counters and timers
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

    end = time.time()
    # no_grad here too, so a direct call to validate() never builds an autograd graph.
    with torch.no_grad():
        for i, (stpt, imc) in enumerate(val_loader):
            # Time spent waiting on the DataLoader
            data_time.update(time.time() - end)

            if use_gpu:
                stpt, imc = stpt.cuda(), imc.cuda()

            # The output already lives on the model's device; the former
            # unconditional `.cuda()` here broke CPU-only runs.
            imc_recons = model(stpt.double())
            loss = criterion(imc_recons.double(), imc.double())
            losses.update(loss.item(), stpt.size(0))

            # Record time for data loading + forward pass of this batch
            batch_time.update(time.time() - end)
            end = time.time()

            # Every 25 batches, log the latest value and running average
            if i % 25 == 0:
                print('Validate: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses))

    return losses.avg

In [9]:
def train(train_loader, model, criterion, optimizer, epoch, plot=True):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: iterable of ``(stpt, imc)`` batches (input, target).
        model: network mapping an STPT batch to an IMC reconstruction.
        criterion: loss called as ``criterion(prediction, target)``.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only in log lines.
        plot: unused; kept so existing call sites keep working.
    """
    print('='*10, 'Starting training epoch {}'.format(epoch), '='*10)
    model.train()

    # Prepare value counters and timers
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

    end = time.time()
    for i, (stpt, imc) in enumerate(train_loader):
        # Time spent waiting on the DataLoader, measured before any device
        # transfer (same convention as validate()).
        data_time.update(time.time() - end)

        # Use GPU if available
        if use_gpu:
            stpt, imc = stpt.cuda(), imc.cuda()

        # Forward pass. The output already lives on the model's device; the
        # former unconditional `.cuda()` here broke CPU-only runs.
        imc_recons = model(stpt.double())
        loss = criterion(imc_recons.double(), imc.double())
        losses.update(loss.item(), stpt.size(0))

        # Compute gradient and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record time for data loading + forward + backward of this batch
        batch_time.update(time.time() - end)
        end = time.time()

        # Every 25 batches, log the latest value and running average
        if i % 25 == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses))

In [None]:
if __name__ == '__main__':
    best_losses = float('inf')  # any finite validation loss counts as an improvement
    epochs = 20

    # Train model; epoch numbering continues from the resumed checkpoint (`inc`).
    for epoch in range(inc, inc + epochs):
        # Train for one epoch, then validate
        train(train_loader, model, criterion, optimizer, epoch)
        with torch.no_grad():
            losses = validate(val_loader, model, criterion, epoch)
        # Save a new checkpoint whenever validation loss improves (older
        # checkpoint files are kept). The file name uses epoch + 1 while the
        # stored 'epoch' is the completed epoch, so resuming with
        # checkpoint['epoch'] + 1 continues where the file name says.
        if losses < best_losses:
            best_losses = losses
            torch.save({'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': losses,
                        'epoch': epoch,
                       }, '../checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch + 1, losses))

Epoch: [7][0/1125]	Time 11.514 (11.514)	Data 7.681 (7.681)	Loss 311.0457 (311.0457)	
Epoch: [7][25/1125]	Time 9.374 (9.579)	Data 5.569 (5.772)	Loss 303.8727 (279.2831)	
Epoch: [7][50/1125]	Time 9.361 (9.615)	Data 5.554 (5.810)	Loss 286.9803 (281.5090)	
Epoch: [7][75/1125]	Time 9.894 (9.625)	Data 6.089 (5.820)	Loss 259.1132 (280.1500)	
Epoch: [7][100/1125]	Time 9.913 (9.672)	Data 6.109 (5.867)	Loss 298.8916 (281.5730)	
Epoch: [7][125/1125]	Time 9.345 (9.648)	Data 5.532 (5.842)	Loss 206.5334 (277.5395)	
Epoch: [7][150/1125]	Time 9.777 (9.644)	Data 5.973 (5.839)	Loss 334.1176 (277.9638)	
Epoch: [7][175/1125]	Time 9.794 (9.683)	Data 5.989 (5.877)	Loss 288.3357 (277.9542)	
Epoch: [7][200/1125]	Time 9.728 (9.700)	Data 5.918 (5.895)	Loss 245.4786 (277.4311)	
Epoch: [7][225/1125]	Time 10.526 (9.726)	Data 6.719 (5.921)	Loss 220.3418 (276.3437)	
