## Import Packages

In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision.models as models
from torchvision import datasets, transforms
# For utilities
import os, shutil, time
import cv2 as cv
import subprocess
from torch.multiprocessing import Pool, set_start_method

In [2]:
# Check if GPU is available
use_gpu = torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# remove .ipynb_chaeckpoint files
subprocess.run('.././rm_ipynbcheckpoints.sh', shell=True, cwd='/home/kyang/Shared/Notebooks/Kevin/stpt2imc');

In [3]:
class Block3(nn.Module):
    '''
    Module consisting of 3 convolutional layers
    '''
    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=2),  # first stride is always 2
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
            
            nn.Conv2d(out_ch, out_ch, kernel_size=3),  # constant kernel size from here
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
            
            nn.Conv2d(out_ch, out_ch, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch)
        )
    
    def forward(self, x):
        return self.layers(x)

    
class PointSetGen(nn.Module):
    def __init__(self, in_ch=8, batch_size=8):
        super().__init__()
        self.batch_size = batch_size
        self.relu = nn.ReLU()
        
        # ====== ENCODER 1 ======
        
        self.beginning = nn.Sequential(
            nn.BatchNorm2d(8),
            
            nn.Conv2d(8, 16, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            
            nn.Conv2d(16, 16, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(16)
        )
        
        self.block3_1 = Block3(16, 32)
        self.block3_2 = Block3(32, 64)
        self.block3_3 = Block3(64, 128)
        self.block3_4 = Block3(128, 256, kernel_size=5)
        self.upblock = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1))
        
        # ====== DECODER 1 ======
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        self.fully_connected1 = nn.Sequential(
            nn.Flatten(-2, -1),
            nn.Linear(4, 2048),
            nn.ReLU()
        )
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=5),
            nn.Upsample(scale_factor=2),
            nn.BatchNorm2d(256)
        )
        
        self.skip1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3),
            nn.Upsample((12, 12))
        )
        self.comb1 = nn.Conv2d(256, 256, kernel_size=3)
        self.blue1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=5),
            nn.Upsample(scale_factor=2)
        )
        
        self.skip2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3),
            nn.Upsample((28, 28))
        )
        self.comb2 = nn.Conv2d(128, 128, kernel_size=3)
        self.blue2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=5),
            nn.Upsample(scale_factor=2)
        )
        
        self.skip3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3),
            nn.Upsample((60, 60))
        )
        self.comb3 = nn.Conv2d(64, 64, kernel_size=3)
        self.blue3 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=5),
            nn.Upsample(scale_factor=2)
        )   
        
        self.skip4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3),
            nn.Upsample((124, 124))
        )
        self.comb4 = nn.Conv2d(32, 32, kernel_size=3)
        self.blue4 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=5),
            nn.Upsample(scale_factor=2)
        )  
        
        self.skip5 = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3),
            nn.Upsample((252, 252))
        )
        self.comb5 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, kernel_size=3, stride=2)
        )
        
        # ====== ENCODER 2 ======
        
        self.enc_skip1 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3),
            nn.Upsample((124, 124))
        )
        self.enc_comb1 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            
            nn.Conv2d(32, 64, kernel_size=3, stride=2)
        )
        
        self.enc_skip2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3),
            nn.Upsample((60, 60))
        )
        self.enc_comb2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(64)
        )
        self.enc_last2 = nn.Conv2d(64, 128, kernel_size=5, stride=2)
        
        self.enc_skip3 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3),
            nn.Upsample((27, 27))
        )
        self.enc_comb3 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(128)
        )
        self.enc_last3 = nn.Conv2d(128, 256, kernel_size=5, stride=2)  
        
        self.enc_skip4 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3),
            nn.Upsample((11, 11))
        )
        self.enc_comb4 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(256)
        )
        self.enc_last4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        
        # ====== PREDICTOR ======

        self.fully_connected2 = nn.Linear(2048, 2048)
        self.fully_connected3 = nn.Sequential(
            nn.Flatten(-2, -1),
            nn.Linear(9, 2048)
        )

        self.dec_blue1 = nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2)
        self.dec_skip1 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3),
            nn.Upsample((9, 9))
        )
        self.convdeconv1 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(256),

            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2),
            nn.Upsample(scale_factor=2)
        )

        self.dec_skip2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3),
            nn.Upsample((34, 34))
        )
        self.convdeconv2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(128),

            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.Upsample(scale_factor=2)
        )
        
        self.dec_skip3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3),
            nn.Upsample((134, 134))
        )
        self.convdeconv3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 64, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(64)
        )
        
        self.fully_connected4 = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            
            nn.Linear(1024, 300)
#             nn.Linear(1024, 49152)
        )
        
        self.finalconv_full = nn.Conv2d(512, 64, kernel_size=1)
        self.finalconv_deconv = nn.Sequential(
#             nn.Conv2d(64, 512, kernel_size=3),
#             nn.Upsample((45, 45))
            nn.Upsample((30,30))
        )
        
        self.mlp = nn.Conv2d(64, 40, kernel_size=1)

        
    def forward(self, x):
        
        # ====== ENCODER 1 ======
        x = self.beginning(x)
        x = self.block3_1(x)
        x1 = x    # can do this because torch returns new tensors for operations like nn.Conv2d
        
        # sequence of blocks of 3 convolutional layers
        x = self.block3_2(x) 
        x2 = x
        x = self.block3_3(x) 
        x3 = x
        x = self.block3_4(x) 
        x4 = x

        # substitute for block of 4 conv. layers b/c convolutions make images too small
        x = self.upblock(x)
        x5 = x
        
        # ====== DECODER 1 ======
        
        x = self.conv1(x)
        x_additional = self.fully_connected1(x)  # save for fully connected layer
        x = self.deconv1(x)
        
        x5 = self.skip1(x5)
        x = self.relu(torch.add(x, x5))  # torch.Size([1, 256, 12, 12])
        x = self.relu(self.comb1(x))
        x5 = x
        x = self.blue1(x)
        
        x4 = self.skip2(x4)
        x = self.relu(torch.add(x, x4))
        x = self.relu(self.comb2(x))
        x4 = x
        x = self.blue2(x)
        
        x3 = self.skip3(x3)
        x = self.relu(torch.add(x, x3))
        x = self.relu(self.comb3(x))
        x3 = x
        x = self.blue3(x)

        x2 = self.skip4(x2)
        x = self.relu(torch.add(x, x2))
        x = self.relu(self.comb4(x))
        x2 = x
        x = self.blue4(x)   
        
        x1 = self.skip5(x1)
        x = self.relu(torch.add(x, x1))
        x = self.comb5(x)
        
        # ====== ENCODER 2 ======
        # the function name and variable names should be off by 1
        x2 = self.enc_skip1(x2)
        x = self.relu(torch.add(x, x2))
        x = self.enc_comb1(x)
        
        x3 = self.enc_skip2(x3)
        x = self.relu(torch.add(x, x3))
        x = self.enc_comb2(x)
        x3 = x
        x = self.enc_last2(x)
        
        x4 = self.enc_skip3(x4)
        x = self.relu(torch.add(x, x4))
        x = self.enc_comb3(x)
        x4 = x
        x = self.enc_last3(x)
        
        x5 = self.enc_skip4(x5)
        x = self.relu(torch.add(x, x5))
        x = self.enc_comb4(x)
        x5 = x
        x = self.enc_last4(x)
        
        # ====== PREDICTOR ======
        
        x_additional = self.fully_connected2(x_additional)
        x_additional = self.relu(torch.add(x_additional, self.fully_connected3(x)))
        
        x = self.dec_blue1(x)
        x5 = self.dec_skip1(x5)
        x = self.relu(torch.add(x, x5))
        x = self.convdeconv1(x)
        
        x4 = self.dec_skip2(x4)
        x = self.relu(torch.add(x, x4))
        x = self.convdeconv2(x)
        
        x3 = self.dec_skip3(x3)
        x = self.relu(torch.add(x, x3))
        x = self.convdeconv3(x)
        
        x_additional = self.fully_connected4(x_additional) # torch.Size([1, 512, 600])
        x_additional = torch.reshape(x_additional, (self.batch_size, 512, 100, 3))
        x_additional = self.finalconv_full(x_additional)
        x = self.finalconv_deconv(x)
        x = torch.reshape(x, (self.batch_size, 64, 300, 3))
        x = torch.cat((x_additional, x), 2)
    
        uv = torch.meshgrid(torch.arange(0, 256), torch.arange(0, 256))
        uv = torch.stack(uv).permute(1,2,0).type(torch.uint8).cuda()  # [256, 256, 2]
        xy = torch.sum(x.type(torch.float32), dim=1) # [self.batch_size, 40, 875, 3]
#         xy = torch.sum(x.squeeze().type(torch.float32), dim=0) # [40, 875, 3]

#         img = torch.exp(((uv[None,:,:,0]-xy[:,None,None,0])**2 + (uv[None,:,:,1]-xy[:,None,None,1])**2) / (xy[:,None,None,2]**2 + 1))  # [875,256,256]
        img = torch.exp(((uv[None,None,:,:,0]-xy[:,:,None,None,0])**2 + (uv[None,None,:,:,1]-xy[:,:,None,None,1])**2) / (xy[:,:,None,None,2]**2 + 1))  # [875,256,256]
#         x = self.mlp(x).squeeze()
        x = self.mlp(x)
        x = torch.sum(x, dim=-1)
        x = x[:,:,:,None,None] * img[:,None,:,:,:]
        x = torch.sum(x, dim=2)
        
        return x

In [4]:
batch_size = 1
# model = PointSetGen(batch_size=batch_size).double()
model = PointSetGen(batch_size=batch_size).double()
        
criterion = nn.MSELoss()

# Move model and loss function to GPU
if use_gpu: 
    criterion = criterion.cuda()
    model = model.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

# https://pytorch.org/tutorials/beginner/saving_loading_models.html#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training

# checkpoint = torch.load('../checkpoints/model-epoch-5-losses-285.453.pth')
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# inc = checkpoint['epoch'] + 1 # increment depending on how many epochs we already completed
inc = 0

In [5]:
class STPT_IMC_ImageFolder(datasets.ImageFolder):    
    """
    Preprocesses
    """
    def __init__(self, root, transform, bits=8, batch_size=64):
        self.root = root
        self.transform = transform
        self.imc_folder = os.path.join(self.root, 'IMC')
        self.stpt_folder = os.path.join(self.root, 'STPT')
        self.bits = bits # num bits for each pixel in image
        
        # length of dataset will be the total number of files contained in all subdirectories inside self.imc_folder
        self.num_imgs_per_phys_sec = len(os.listdir(os.path.join(self.imc_folder, '01')))
        self.num_imgs = self.num_imgs_per_phys_sec * 15  # 15 physical sections
        
        self.index_to_phys_sec = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]  # skip phys_sec 16
            
    def __len__(self):
        return self.num_imgs
        
    def __getitem__(self, index):
        
        phys_sec = self.index_to_phys_sec[int(np.floor(index / self.num_imgs_per_phys_sec))]  # mod to find physical section
                                                         
        # ====== GET LIST OF IMAGE FILES ======
        stpt_imgs = os.listdir(os.path.join(self.stpt_folder,
                                    '{}'.format(str(phys_sec).zfill(2)))) 
                                                         
        imc_imgs = os.listdir(os.path.join(self.imc_folder,
                                           '{}'.format(str(phys_sec).zfill(2))))
        
        # ====== GET IMAGE FILE PATH ======
        stpt_path = os.path.join(self.stpt_folder,
                                           '{}'.format(str(phys_sec).zfill(2)),
                                           stpt_imgs[int(index % self.num_imgs_per_phys_sec)])
        
        imc_path = os.path.join(self.imc_folder,
                                          '{}'.format(str(phys_sec).zfill(2)),
                                          imc_imgs[int(index % self.num_imgs_per_phys_sec)])

        # make sure the files line up
        try:
            assert(os.path.basename(stpt_path) == os.path.basename(imc_path))
        except:
            print('stpt path:', os.path.basename(stpt_path))
            print('imc path:', os.path.basename(imc_path))
                                       
        # ====== LOAD IMAGES ======
#         stpt_img = self.transform[0](torch.load(stpt_path))  
        stpt_img = torch.load(stpt_path)

#         imc_img = self.transform[1](torch.load(imc_path))
        imc_img = torch.load(imc_path)     
                                                                     
        return stpt_img.double(), imc_img.double()   

In [6]:
# Training
stpt_normalize_param = [0.5 for i in range(8)]
imc_normalize_param = [0.5 for i in range(40)]
transform = [transforms.Normalize(stpt_normalize_param, stpt_normalize_param),
              transforms.Normalize(imc_normalize_param, imc_normalize_param)]

train_imagefolder = STPT_IMC_ImageFolder(root='../data/train',
                                         transform=transform)
train_loader = torch.utils.data.DataLoader(train_imagefolder,
                                           batch_size=batch_size,
                                           shuffle=True)

# Validation 
# val_transforms = transforms.Compose([transforms.Normalize(normalize_param, normalize_param)])
val_imagefolder = STPT_IMC_ImageFolder(root='../data/val',
                                       transform=transform)
val_loader = torch.utils.data.DataLoader(val_imagefolder,
                                         batch_size=batch_size,
                                         shuffle=False)

In [7]:
class AverageMeter(object):
  '''A handy class from the PyTorch ImageNet tutorial''' 
  def __init__(self):
    self.reset()
    self.vals = []
    self.avgs = []
  def reset(self):
    self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
  def update(self, val, n=1):
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count
    self.vals.append(self.val)
    self.avgs.append(self.avg)

In [8]:
def validate(val_loader, model, criterion, epoch, plot=True):
  print('='*10, 'Starting validation epoch {}'.format(epoch), '='*10) 
  model.eval()

  # Prepare value counters and timers
  batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

  end = time.time()
  already_saved_images = False
  for i, (stpt, imc) in enumerate(val_loader):
    data_time.update(time.time() - end)

    # Use GPU
    if use_gpu: 
        stpt, imc = stpt.cuda(), imc.cuda()

    # Run model and record loss
    imc_recons = model(stpt.double()).cuda() # throw away class predictions
    loss = criterion(imc_recons.double(), imc.double())
    losses.update(loss.item(), stpt.size(0))

    # Record time to do forward passes and save images
    batch_time.update(time.time() - end)
    end = time.time()

    # Print model accuracy -- in the code below, val refers to both value and validation
    if i % 25 == 0:
      print('Validate: [{0}/{1}]\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
             i, len(val_loader), batch_time=batch_time, loss=losses))
    
  return losses.avg

In [9]:
def train(train_loader, model, criterion, optimizer, epoch, plot=True):
  print('='*10, 'Starting training epoch {}'.format(epoch), '='*10)
  model.train()
  
  # Prepare value counters and timers
  batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

  end = time.time()
  for i, (stpt, imc) in enumerate(train_loader):
    
    # Use GPU if available
    if use_gpu:
        stpt, imc = stpt.cuda(), imc.cuda()

    # Record time to load data (above)
    data_time.update(time.time() - end)

    imc_recons = model(stpt).cuda()
    loss = criterion(imc_recons, imc)

    losses.update(loss.item(), stpt.size(0))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record time to do forward and backward passes
    batch_time.update(time.time() - end)
    end = time.time()

    # Print model accuracy -- in the code below, val refers to value, not validation
    if i % 25 == 0:
      print('Epoch: [{0}][{1}/{2}]\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
              epoch, i, len(train_loader), batch_time=batch_time,
             data_time=data_time, loss=losses)) 

In [None]:
if __name__ == '__main__':
    best_losses = 1e10
    epochs = 20

    # Train model
    for epoch in range(epochs):
      epoch += inc
      # Train for one epoch, then validate
      train(train_loader, model, criterion, optimizer, epoch)
      with torch.no_grad():
        losses = validate(val_loader, model, criterion, epoch)
      # Save checkpoint and replace old best model if current model is better
      if losses < best_losses:
        best_losses = losses
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': losses,
                    'epoch': epoch,
                    'loss': losses
                   }, '../checkpoints/pointsetgen/model-epoch-{}-losses-{:.3f}.pth'.format(epoch+1,losses))

Epoch: [0][0/71940]	Time 1.285 (1.285)	Data 0.123 (0.123)	Loss nan (nan)	
Epoch: [0][25/71940]	Time 0.348 (0.390)	Data 0.096 (0.101)	Loss nan (nan)	
Epoch: [0][50/71940]	Time 0.377 (0.377)	Data 0.126 (0.106)	Loss nan (nan)	
Epoch: [0][75/71940]	Time 0.331 (0.366)	Data 0.079 (0.101)	Loss nan (nan)	
Epoch: [0][100/71940]	Time 0.337 (0.360)	Data 0.086 (0.099)	Loss nan (nan)	
Epoch: [0][125/71940]	Time 0.344 (0.357)	Data 0.092 (0.097)	Loss nan (nan)	
Epoch: [0][150/71940]	Time 0.346 (0.356)	Data 0.094 (0.097)	Loss nan (nan)	
Epoch: [0][175/71940]	Time 0.330 (0.354)	Data 0.079 (0.096)	Loss nan (nan)	
Epoch: [0][200/71940]	Time 0.335 (0.354)	Data 0.083 (0.098)	Loss nan (nan)	
Epoch: [0][225/71940]	Time 0.352 (0.354)	Data 0.101 (0.098)	Loss nan (nan)	
Epoch: [0][250/71940]	Time 0.344 (0.353)	Data 0.093 (0.097)	Loss nan (nan)	
Epoch: [0][275/71940]	Time 0.441 (0.353)	Data 0.189 (0.098)	Loss nan (nan)	
Epoch: [0][300/71940]	Time 0.366 (0.354)	Data 0.114 (0.099)	Loss nan (nan)	
Epoch: [0][325/71