## Pipeline to compare reconstruction to target IMC images

In [None]:
# For plotting
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from skimage import io
import cv2 as cv
# For everything
import torch
import torch.nn as nn
import torch.nn.functional as F
# For our model
import torchvision.models as models
from torchvision import datasets, transforms
from torch.multiprocessing import Pool, set_start_method
# For utilities
import os, shutil, time

In [None]:
# code from https://amaarora.github.io/2020/09/13/unet.html#u-net
class Block(nn.Module):
    """Batch-norm the input, then apply conv3x3 -> ReLU -> conv3x3.

    Both convolutions use a 3x3 kernel with no padding, so each one trims
    two pixels from the height and the width (four pixels in total).

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys, so
        # they must stay exactly as they are.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3)
        # Normalisation is applied to the *input*, hence in_ch features.
        self.batchnorm2d = nn.BatchNorm2d(in_ch)

    def forward(self, x):
        """Return the block output for a (N, in_ch, H, W) feature map."""
        normalized = self.batchnorm2d(x)
        activated = self.relu(self.conv1(normalized))
        return self.conv2(activated)


class Encoder(nn.Module):
    """Contracting path: a chain of ``Block`` stages separated by 2x2 max-pooling.

    Args:
        chs: channel sizes; consecutive pairs ``(chs[i], chs[i+1])`` give the
            input/output channels of each stage, so ``len(chs) - 1`` stages
            are created.
    """

    def __init__(self, chs=(8,64,128,256,512,1024)):
        super().__init__()
        stages = [Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.enc_blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the list of pre-pooling outputs of every stage, shallowest first."""
        skips = []
        for stage in self.enc_blocks:
            x = stage(x)
            # Keep the un-pooled activation; pooling only feeds the next stage.
            skips.append(x)
            x = self.pool(x)
        return skips


class Decoder(nn.Module):
    """Expanding path: upsample, concatenate the cropped skip feature, refine.

    For stage ``i`` the transposed convolution maps ``chs[i]`` to ``chs[i+1]``
    channels and the following ``Block`` expects ``chs[i]`` channels again, so
    the concatenated skip feature has to supply ``chs[i] - chs[i+1]`` channels.

    NOTE(review): the default ``chs`` ends in 40 whereas ``UNet`` passes a
    tuple ending in 64 -- confirm the default is intentional.

    Args:
        chs: channel sizes; ``len(chs) - 1`` decoding stages are created.
    """

    def __init__(self, chs=(1024, 512, 256, 128, 40)):
        super().__init__()
        self.chs = chs
        stage_channels = list(zip(chs[:-1], chs[1:]))
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_channels]
        )
        self.dec_blocks = nn.ModuleList(
            [Block(c_in, c_out) for c_in, c_out in stage_channels]
        )

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features[i]`` as the skip input of stage ``i``."""
        for idx, (upconv, block) in enumerate(zip(self.upconvs, self.dec_blocks)):
            x = upconv(x)
            skip = self.crop(encoder_features[idx], x)
            x = block(torch.cat([x, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size (H, W) of ``x``."""
        _, _, height, width = x.shape
        return transforms.CenterCrop([height, width])(enc_ftrs)


class UNet(nn.Module):
    """U-Net built from ``Encoder`` and ``Decoder`` plus a 1x1 output head.

    The output is resized to a fixed spatial size because the unpadded
    convolutions make the decoder output smaller than the input.

    Args:
        enc_chs: channel sizes handed to ``Encoder``.
        dec_chs: channel sizes handed to ``Decoder``; ``dec_chs[-1]`` is the
            number of input channels of the head.
        num_class: number of output channels produced by the head
            (default 40, the value that was previously hard-coded).
        out_sz: ``(height, width)`` the output is interpolated to
            (default ``(256, 256)``, the value that was previously hard-coded).
    """

    def __init__(self, enc_chs=(8,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64),
                 num_class=40, out_sz=(256, 256)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
        # Plain attribute (not a parameter/buffer), so checkpoints are unaffected.
        self.out_sz = tuple(out_sz)

    def forward(self, x):
        """Map a (N, enc_chs[0], H, W) batch to (N, num_class, *out_sz)."""
        # Reverse once: the deepest feature seeds the decoder, the remaining
        # ones (deep -> shallow) are the skip connections.
        reversed_ftrs = self.encoder(x)[::-1]
        out = self.decoder(reversed_ftrs[0], reversed_ftrs[1:])
        out = self.head(out)
        return F.interpolate(out, self.out_sz)
    
    
def process_stpt_image(file_name):
    """Load an image, min-max scale it to [0, 255] and log-compress it.

    The transform is ``c * log(1 + v)`` with ``c = 255 / log(1 + max(v))``,
    so the brightest pixel maps to 255.

    The logarithm is evaluated in float64. Doing it in the image's own
    integer dtype (as the previous version did) makes ``v + 1`` wrap around
    to 0 for 8-bit images, which turns the brightest pixels into ``-inf``,
    and yields a float16/float32 result depending on the input dtype.

    Args:
        file_name: path of the image, passed to ``skimage.io.imread``.

    Returns:
        A float64 ``torch.Tensor`` with the same shape as the loaded image.
        An image whose normalised maximum is 0 (e.g. a constant image) gives
        an all-zero tensor instead of NaNs.
    """
    img = io.imread(file_name)

    # normalize image (8 bits); the code relies on cv.normalize filling the
    # pre-allocated copy in place.
    norm_img = img.copy()
    cv.normalize(img, norm_img, alpha=0, beta=2**8 - 1, norm_type=cv.NORM_MINMAX)

    # Promote before any arithmetic so "+ 1" cannot overflow an integer dtype.
    norm_img = norm_img.astype(np.float64)

    peak = norm_img.max()
    if peak <= 0:
        # log(1 + 0) == 0 would make the scale factor a division by zero.
        return torch.from_numpy(np.zeros_like(norm_img))

    # Apply log transformation method
    c = (2**8 - 1) / np.log1p(peak)
    log_image = c * np.log1p(norm_img)
    return torch.from_numpy(log_image)

In [None]:
# Build the network in double precision.
model = UNet().double()
# Inference only: switch the BatchNorm layers to evaluation mode so they use
# their stored running statistics instead of per-batch statistics.
model.eval()

# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
# machine; the plotting cells call .numpy() directly, which needs CPU tensors.
checkpoint = torch.load('../checkpoints/model-epoch-2-losses-294.774.pth', map_location='cpu')
model.load_state_dict(checkpoint)

In [None]:
phys_sec = '7'   # choose which physical section to reconstruct
chunk = '30_30'  # choose which chunk to extract

# Section directories are zero-padded to two digits (e.g. '07').
section_dir = phys_sec.zfill(2)

# load images
stpt_piece = torch.load('../data/train/STPT/{0}/{1}.pt'.format(section_dir, chunk))
# No gradients are needed for a forward pass used only for plotting.
with torch.no_grad():
    # Add a batch dimension for the model, then drop it again.
    imc_reconst = model(stpt_piece.unsqueeze(0)).squeeze()

# Both placeholders must be filled by format(); previously `chunk` was passed
# to torch.load instead, which raised an IndexError while formatting the path.
imc_true = torch.load('../data/train/IMC/{0}/{1}.pt'.format(section_dir, chunk))

In [None]:
# plot sample and target pairs: one row per channel,
# reconstruction on the left and target on the right

channels = list(range(5))
# squeeze=False keeps axarr two-dimensional even for a single channel.
f, axarr = plt.subplots(len(channels), 2, squeeze=False)
# Index rows by position so `channels` may hold any channel numbers.
for row, channel in enumerate(channels):
    axarr[row, 0].imshow(imc_reconst[channel].detach().numpy(), cmap='gray')
    axarr[row, 1].imshow(imc_true[channel].detach().numpy(), cmap='gray')

In [None]:
# Show channel 25 of the reconstruction on its own.
reconst_channel = imc_reconst[25].detach().numpy()
plt.imshow(reconst_channel, cmap='gray')

In [None]:
# Show channel 25 of the target IMC image for comparison.
target_channel = imc_true[25].detach().numpy()
plt.imshow(target_channel, cmap='gray')