## Import Packages

In [1]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from skimage import io
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision.models as models
from torchvision import datasets, transforms
# For utilities
import os, shutil, time
import cv2 as cv
import subprocess
from torch.multiprocessing import Pool, set_start_method

In [2]:
# Detect once whether CUDA is usable; the train/validate cells read this flag
# to decide whether tensors are moved to the GPU.
use_gpu = torch.cuda.is_available()

# Run the helper script that removes .ipynb_checkpoints files (the script
# itself is not part of this notebook — presumably it deletes those folders;
# verify). Stray entries matter because the dataset's __len__ counts every
# entry in the IMC folder. The trailing ';' suppresses the cell's output.
subprocess.run('./rm_ipynbcheckpoints.sh', shell=True);

In [3]:
# code from https://amaarora.github.io/2020/09/13/unet.html#u-net
class Block(nn.Module):
    """Two unpadded 3x3 convolutions with a ReLU in between.

    Each convolution uses kernel size 3 and no padding, so the output is
    4 pixels smaller than the input along both spatial axes.

    Args:
        in_ch: number of channels of the input feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 (no activation after conv2)."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)


class Encoder(nn.Module):
    """Contracting path: a stack of Blocks separated by 2x2 max pooling.

    Args:
        chs: channel sizes; consecutive pairs (chs[i], chs[i+1]) define the
            input/output channels of each Block.
    """

    def __init__(self, chs=(8,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList(
            Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])
        )
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the pre-pooling output of every Block, shallowest first."""
        skips = []
        current = x
        for stage in self.enc_blocks:
            feature = stage(current)
            skips.append(feature)
            # The pooled tensor feeds the next stage; after the last stage it
            # is computed but unused, exactly as in the indexed version.
            current = self.pool(feature)
        return skips


class Decoder(nn.Module):
    """Expanding path: transposed-conv upsampling, skip concatenation, Block.

    Args:
        chs: channel sizes; consecutive pairs define each up-convolution
            (chs[i] -> chs[i+1]) and each Block (chs[i] -> chs[i+1]).

    NOTE(review): each Block receives the up-convolved tensor concatenated
    with a skip tensor, so the skip must contribute chs[i] - chs[i+1]
    channels. The default ending in 40 would need an 88-channel skip at the
    last stage — confirm it is intended; UNet passes its own dec_chs.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 40)):
        super().__init__()
        self.chs = chs
        stage_channels = list(zip(chs[:-1], chs[1:]))
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_channels]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in stage_channels]
        )

    def forward(self, x, encoder_features):
        """Run every decoder stage.

        Args:
            x: deepest encoder feature map.
            encoder_features: skip tensors, deepest first; entry i is used
                by stage i.
        """
        num_stages = len(self.chs) - 1
        for stage in range(num_stages):
            upsampled = self.upconvs[stage](x)
            skip = self.crop(encoder_features[stage], upsampled)
            merged = torch.cat([upsampled, skip], dim=1)
            x = self.dec_blocks[stage](merged)
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop enc_ftrs to the spatial size (H, W) of the 4-D tensor x."""
        _batch, _channels, height, width = x.shape
        center_crop = transforms.CenterCrop([height, width])
        return center_crop(enc_ftrs)


class UNet(nn.Module):
    """U-Net built from Encoder, Decoder and a 1x1 convolution head.

    Args:
        enc_chs: channel sizes passed to Encoder.
        dec_chs: channel sizes passed to Decoder; dec_chs[-1] is the number
            of channels fed to the head.
        num_class: number of output channels produced by the head
            (default 40, the previously hard-coded value).
        out_sz: spatial size (H, W) the output is resized to
            (default (256, 256), the previously hard-coded value).

    Returns (forward): tensor with num_class channels and spatial size out_sz.
    """

    def __init__(self, enc_chs=(8,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64),
                 num_class=40, out_sz=(256, 256)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.out_sz = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        # Deepest feature map first: it seeds the decoder, the remaining
        # (shallower) maps are the skip connections.
        deepest_first = enc_ftrs[::-1]
        out = self.decoder(deepest_first[0], deepest_first[1:])
        out = self.head(out)
        # The unpadded convolutions shrink the feature maps, so the output is
        # resized to the requested target size.
        out = F.interpolate(out, self.out_sz)
        return out

In [4]:
# Build the network with float64 parameters; the training loop casts its
# inputs and targets to double to match.
model = UNet().double()

# Mean-squared error between the model output and the IMC target.
criterion = nn.MSELoss()

# Adam over all model parameters, no weight decay.
# NOTE(review): the optimizer is created before the model is moved to the GPU
# in a later cell; PyTorch's Module.cuda() docs recommend moving the module
# first — confirm this ordering is acceptable for the PyTorch version in use.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0.0)

In [5]:
class STPT_IMC_ImageFolder(datasets.ImageFolder):
    """
    Dataset that pairs the STPT and IMC images of one physical section.

    Expected layout under ``root`` (derived from the paths built below):
        IMC/SECTION_XX/*.tif   one single-channel image per IMC channel
        STPT/SXXX_Z00.tif, STPT/SXXX_Z01.tif   two optical sections

    Each item is already a mini-batch: ``batch_size`` crops produced by
    applying ``transform`` to the stacked IMC+STPT channels, so the same crop
    window is applied to both modalities.

    Args:
        root: dataset root containing the ``IMC`` and ``STPT`` folders.
        transform: callable applied to the combined (C, H, W) tensor once per
            crop.
        bits: bit depth used as the target range for IMC normalisation.
            NOTE: images are always returned as uint8, so values above 255
            (bits > 8) do not fit — assumed to stay at 8; confirm.
        batch_size: number of crops generated per section.

    ImageFolder.__init__ is not called; only the attributes below are set.
    """
    def __init__(self, root, transform, bits=8, batch_size=64):
        self.root = root
        self.transform = transform
        self.imc_folder = os.path.join(self.root, 'IMC')
        self.stpt_folder = os.path.join(self.root, 'STPT')
        self.bits = bits # num bits for each pixel in image
        self.batch_size = batch_size

    def __len__(self):
        # length of dataset dictated by num aligned IMC images b/c len(IMC) < len(STPT)
        # Only SECTION_* directories are counted so stray entries (for example
        # .ipynb_checkpoints) cannot inflate the length and send __getitem__
        # to a section that does not exist.
        return sum(
            1 for entry in os.listdir(self.imc_folder)
            if entry.startswith('SECTION_')
            and os.path.isdir(os.path.join(self.imc_folder, entry))
        )

    def __getitem__(self, index):
        """Return ``(stpt_crops, imc_crops)``: two lists of ``batch_size`` tensors."""
        index += 1  # sections are numbered from 1, dataset indices from 0

        # ====== GET IMAGE PATHS ======
        imc_section_folder = os.path.join(self.imc_folder,
                                          'SECTION_{}'.format(str(index).zfill(2)))

        # os.listdir returns names in arbitrary order; sort them so every
        # section stacks its IMC channels in the same order.
        imc_img_paths = [os.path.join(imc_section_folder, imc_img_name)
                         for imc_img_name in sorted(os.listdir(imc_section_folder))
                         if imc_img_name.endswith('.tif')]
        if not imc_img_paths:
            raise FileNotFoundError(
                'No .tif images found in {}'.format(imc_section_folder))

        # the two optical sections (Z00, Z01) of this physical section
        stpt_img_paths = [os.path.join(self.stpt_folder,
                                       'S{0}_Z{1}.tif'.format(str(index).zfill(3),
                                                          optical_section.zfill(2)))
                          for optical_section in ['0', '1']]

        # ====== LOAD IMAGES ======
        # imap preserves input order, so channel order follows the path lists
        with Pool(maxtasksperchild=100) as p:
            imc_imgs = list(p.imap(self.process_imc_image, imc_img_paths))
            stpt_imgs = list(p.imap(self.process_stpt_image, stpt_img_paths))

        # postprocess loaded images
        imc_imgs = [torch.unsqueeze(img, 0) for img in imc_imgs] # add a channel dimension
        imc_imgs_cat = torch.cat(imc_imgs, 0) # (n_imc, H, W)

        stpt_imgs = [img.permute((2,0,1)) for img in stpt_imgs] # (H,W,C) -> (C,H,W)
        stpt_imgs_cat = torch.cat(stpt_imgs, 0) # channels of both optical sections

        # ====== TRANSFORMS ======
        # make STPT the same size as IMC (Resize with an int matches the
        # smaller edge; assumes square images — confirm)
        stpt_imgs_cat = transforms.Resize(imc_imgs[0].shape[1])(stpt_imgs_cat)
        # stack both modalities so one random crop covers the same window in each
        combine = torch.cat((imc_imgs_cat, stpt_imgs_cat), 0)

        # obtain a batch of random crops
        img_set = [self.transform(combine) for _ in range(self.batch_size)]

        # separate the modalities again using the actual IMC channel count
        # (identical to the former fixed split at 40 when there are 40 channels)
        n_imc = imc_imgs_cat.shape[0]
        imc_imgs = [img[:n_imc] for img in img_set]
        stpt_imgs = [img[n_imc:] for img in img_set]

        return stpt_imgs, imc_imgs

    def process_stpt_image(self, file_name):
        """Read one STPT image and return it as a uint8 tensor."""
        img = io.imread(file_name)
        return torch.from_numpy(img.astype('uint8'))

    def process_imc_image(self, file_name):
        """Read one IMC image, min-max normalise it and apply a log transform.

        Returns:
            uint8 tensor with the same shape as the image on disk.

        Raises:
            OSError: if OpenCV cannot read ``file_name``.
        """
        # read image file
        img = cv.imread(file_name, cv.IMREAD_UNCHANGED)
        if img is None:
            # cv.imread signals failure by returning None instead of raising
            raise OSError('Could not read IMC image: {}'.format(file_name))

        max_val = 2**self.bits - 1

        # normalize image to [0, max_val] in the image's own dtype
        norm_img = img.copy()
        cv.normalize(img, norm_img, alpha=0, beta=max_val, norm_type=cv.NORM_MINMAX)

        # Widen before adding 1: for integer images whose normalised maximum
        # equals the dtype maximum, `norm_img + 1` would wrap around to 0.
        norm_img = norm_img.astype(np.float64)
        peak = norm_img.max()
        if peak <= 0:
            # constant image: the log scale factor would divide by log(1) == 0
            return torch.zeros(norm_img.shape, dtype=torch.uint8)

        # Apply log transformation method
        c = max_val / np.log1p(peak)
        log_image = c * np.log1p(norm_img)

        # Specify the data type so that
        # float value will be converted to int
        return torch.from_numpy(log_image.astype('uint8'))

In [6]:
def merge_cropped_images(batch):
    """
    Collate function for a DataLoader with batch_size=1.

    `batch` holds a single dataset item: a pair of lists (STPT crops, IMC
    crops) from one physical section. Each list is stacked along a new
    leading dimension and the two stacked tensors are returned as a tuple.
    """
    section = batch[0]
    stpt_crops, imc_crops = section[0], section[1]
    stacked_stpt = torch.stack(stpt_crops)
    stacked_imc = torch.stack(imc_crops)
    return stacked_stpt, stacked_imc

# Training
# Every crop is 256x256, matching the size UNet.forward resizes its output to.
train_transforms = transforms.Compose([transforms.RandomCrop(256)])
train_imagefolder = STPT_IMC_ImageFolder(root='data/train',
                                         transform=train_transforms)
# batch_size=1 because each dataset item is already a list of crops from one
# section; merge_cropped_images stacks those crops into the actual mini-batch.
train_loader = torch.utils.data.DataLoader(train_imagefolder,
                                           batch_size=1,
                                           shuffle=True,
                                           collate_fn=merge_cropped_images)

# Validation (disabled): `val_loader` is NOT defined while these lines stay
# commented out, so any cell that needs validation data must define it first.
# val_transforms = transforms.Compose([transforms.Resize(256)])
# val_imagefolder = STPT_IMC_ImageFolder(root='data/val', transform=val_transforms)
# val_loader = torch.utils.data.DataLoader(val_imagefolder, batch_size=1, shuffle=False)

In [7]:
class AverageMeter(object):
  """Tracks the latest value, weighted sum, sample count and running average.

  Adapted from the PyTorch ImageNet tutorial.
  """

  def __init__(self):
    self.reset()

  def reset(self):
    """Set every statistic back to zero."""
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def update(self, val, n=1):
    """Record `val` observed over `n` samples and refresh the average."""
    weighted = val * n
    self.val = val
    self.sum += weighted
    self.count += n
    self.avg = self.sum / self.count

In [8]:
def validate(val_loader, model, criterion, epoch):
  """Run one validation pass and return the average loss.

  Args:
    val_loader: iterable yielding (stpt, imc) tensor batches.
    model: network mapping STPT input to an IMC reconstruction.
    criterion: loss comparing the reconstruction with the IMC target.
    epoch: epoch number, used only for logging.

  Returns:
    Sample-weighted average loss over all batches (0 if the loader is empty).
  """
  print('Starting validation epoch {}'.format(epoch))
  model.eval()

  # Prepare value counters and timers
  batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

  end = time.time()
  # No gradients are needed here; disabling them inside the function keeps it
  # safe to call without an outer torch.no_grad() block.
  with torch.no_grad():
    for i, (stpt, imc) in enumerate(val_loader):
      data_time.update(time.time() - end)

      # Use GPU
      if use_gpu:
        stpt, imc = stpt.cuda(), imc.cuda()

      # Run model and record loss. Inputs and targets are cast to double,
      # exactly as in train(), because the model holds float64 parameters
      # while the dataset yields uint8 tensors.
      imc_recons = model(stpt.double())
      loss = criterion(imc_recons.double(), imc.double())
      losses.update(loss.item(), stpt.size(0))

      # Record time to do the forward pass
      batch_time.update(time.time() - end)
      end = time.time()

      # Print model loss -- in the code below, val refers to both value and validation
      if i % 25 == 0:
        print('Validate: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
               i, len(val_loader), batch_time=batch_time, loss=losses))

  print('Finished validation.')
  return losses.avg

In [9]:
def train(train_loader, model, criterion, optimizer, epoch):
  """Train the model for one epoch.

  Args:
    train_loader: iterable yielding (stpt, imc) tensor batches.
    model: network mapping STPT input to an IMC reconstruction.
    criterion: loss comparing the reconstruction with the IMC target.
    optimizer: optimizer updating the model parameters.
    epoch: epoch number, used only for logging.
  """
  print('='*10, 'Starting training epoch {}'.format(epoch), '='*10)
  model.train()

  # Prepare value counters and timers
  batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()

  end = time.time()
  for i, (stpt, imc) in enumerate(train_loader):
    print('**Training iteration {}**'.format(i))

    # Use GPU if available
    if use_gpu:
      stpt, imc = stpt.cuda(), imc.cuda()

    # Record time to load data (above)
    data_time.update(time.time() - end)

    # Run forward pass. The output already lives on the same device as the
    # model and its input, so it is not moved again: an unconditional .cuda()
    # here would raise on machines without a GPU.
    imc_recons = model(stpt.double())
    loss = criterion(imc_recons.double(), imc.double())
    losses.update(loss.item(), stpt.size(0))

    # Compute gradient and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record time to do forward and backward passes
    batch_time.update(time.time() - end)
    end = time.time()

    # Print model loss -- in the code below, val refers to value, not validation
    if i % 25 == 0:
      print('Epoch: [{0}][{1}/{2}]\t'
            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
              epoch, i, len(train_loader), batch_time=batch_time,
             data_time=data_time, loss=losses))

  print('='*10, 'Finished training epoch {}'.format(epoch), '='*10)

In [10]:
# Move model and loss function to GPU when one is available, so they live on
# the same device as the batches the train/validate loops move with .cuda().
if use_gpu: 
    criterion = criterion.cuda()
    model = model.cuda()

# Put the model parameters in shared memory (the dataset spawns worker
# processes via torch.multiprocessing); the trailing ';' hides the cell output.
model.share_memory();

In [None]:
if __name__ == '__main__':
    best_losses = 1e10
    epochs = 20

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs('checkpoints', exist_ok=True)

    # The validation loader is only created if its (currently commented-out)
    # cell has been enabled. Look it up defensively so a full training epoch
    # is not lost to a NameError at the first validation step.
    active_val_loader = globals().get('val_loader')
    if active_val_loader is None:
        print('val_loader is not defined: skipping validation and saving a '
              'checkpoint after every epoch instead.')

    # Train model
    for epoch in range(epochs):
        # Train for one epoch, then validate
        train(train_loader, model, criterion, optimizer, epoch)

        if active_val_loader is None:
            # No validation loss to compare against: keep the latest weights.
            torch.save(model.state_dict(), 'checkpoints/model-epoch-{}.pth'.format(epoch+1))
            continue

        with torch.no_grad():
            losses = validate(active_val_loader, model, criterion, epoch)

        # Save a checkpoint whenever the validation loss improves
        if losses < best_losses:
            best_losses = losses
            torch.save(model.state_dict(), 'checkpoints/model-epoch-{}-losses-{:.3f}.pth'.format(epoch+1,losses))

**Training iteration 0**
Epoch: [0][0/17]	Time 71.433 (71.433)	Data 67.422 (67.422)	Loss 569.1596 (569.1596)	
**Training iteration 1**
**Training iteration 2**
**Training iteration 3**
**Training iteration 4**
**Training iteration 5**
**Training iteration 6**
**Training iteration 7**
**Training iteration 8**
**Training iteration 9**
**Training iteration 10**
**Training iteration 11**
**Training iteration 12**


Process ForkPoolWorker-1710:
Process ForkPoolWorker-1793:
Process ForkPoolWorker-1749:
Process ForkPoolWorker-1794:
Process ForkPoolWorker-1731:
Process ForkPoolWorker-1740:
Traceback (most recent call last):
Process ForkPoolWorker-1727:
Process ForkPoolWorker-1713:
  File "/opt/conda/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/opt/conda/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
Process ForkPoolWorker-1699:
Process ForkPoolWorker-1672:
Process ForkPoolWorker-1760:
Process ForkPoolWorker-1737:
Process ForkPoolWorker-1680:
Process ForkPoolWorker-1698:
Process ForkPoolWorker-1741:
Process ForkPoolWorker-1722:
Process ForkPoolWorker-1709:
Process ForkPoolWorker-1730:
Process ForkPoolWorker-1744:
Process ForkPoolWorker-1695:
Process ForkPoolWorker-1747:
Process ForkPoolWorker-1745:
Process ForkPoolWorker-1748:
  File "/opt/conda/lib/python3.7/multiprocessing/pool.py", line 110, in w