# ML-SIM Inference script

### Imports

In [18]:
import torch
import matplotlib.pyplot as plt
import torchvision
import skimage
from skimage.measure import compare_ssim
import numpy as np
import time
from PIL import Image
import scipy.ndimage as ndimage
import torch.nn as nn
import os
from skimage import io,exposure,img_as_ubyte
import glob
import argparse
from models import GetModel

## Evaluation functions

In [20]:
def LoadModel(opt):
    """Build the network for ``opt`` and load its checkpoint weights.

    Args:
        opt: options namespace. It is passed unchanged to ``GetModel``; this
            function itself reads ``opt.weights`` (checkpoint path) and
            ``opt.device`` (load location and target device).

    Returns:
        The network returned by ``GetModel(opt)`` with the checkpoint weights
        loaded, moved to ``opt.device`` and switched to evaluation mode.
    """
    print('Loading model')
    print(opt)

    net = GetModel(opt)
    print('loading checkpoint', opt.weights)
    checkpoint = torch.load(opt.weights, map_location=opt.device)

    # A checkpoint is either a training dict that wraps the weights under
    # 'state_dict', or a bare state_dict (OrderedDict or plain dict). Test for
    # the key instead of the exact container type so both forms load.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    net.load_state_dict(state_dict)
    # map_location only places the loaded tensors. The module itself must be
    # moved so it matches the input, which is sent to opt.device at inference.
    net.to(opt.device)
    # Inference only: make any dropout/batch-norm layers use eval behaviour.
    net.eval()

    return net


def _rescale01(t):
    """Return tensor ``t`` linearly rescaled to [0, 1]; a constant tensor maps to zeros."""
    lo = torch.min(t)
    span = torch.max(t) - lo
    # Guard the zero-range case, where the plain formula yields 0/0 = NaN.
    if span > 0:
        return (t - lo) / span
    return torch.zeros_like(t)


def _prep_sim_input(stack, opt):
    """Select, crop and normalise the raw frames of one acquisition.

    Reads ``opt.nch_in`` and ``opt.norm``. Returns ``(inputimg, widefield)``
    as float32 tensors: the selected frames (frames along axis 0) and their
    per-pixel mean.
    """
    # Only the first 9 raw frames are used. For nch_in 6 or 3, a fixed subset
    # of them is picked (so frames up to index 7 or 8 must exist).
    inputimg = stack[:9]
    if opt.nch_in == 6:
        inputimg = inputimg[[0, 1, 3, 4, 6, 7]]
    elif opt.nch_in == 3:
        inputimg = inputimg[[0, 4, 8]]

    if inputimg.shape[1] > 512 or inputimg.shape[2] > 512:
        print('Over 512x512! Cropping')
        inputimg = inputimg[:, :512, :512]

    # Scale by the global peak of the selected frames (used to be /255).
    # Skip the division for an all-zero stack, which would otherwise give NaNs.
    peak = np.max(inputimg)
    inputimg = inputimg.astype('float')
    if peak != 0:
        inputimg = inputimg / peak
    widefield = np.mean(inputimg, 0)

    if opt.norm == 'adapthist':
        for i in range(len(inputimg)):
            inputimg[i] = exposure.equalize_adapthist(inputimg[i], clip_limit=0.001)
        widefield = exposure.equalize_adapthist(widefield, clip_limit=0.001)
        return torch.from_numpy(inputimg).float(), torch.from_numpy(widefield).float()

    inputimg = torch.from_numpy(inputimg).float()
    widefield = _rescale01(torch.from_numpy(widefield).float())
    if opt.norm == 'minmax':
        # Each frame is stretched to [0, 1] independently.
        for i in range(len(inputimg)):
            inputimg[i] = _rescale01(inputimg[i])
    return inputimg, widefield


def SIM_reconstruct_singleStack(model, opt, stack):
    """Reconstruct one SIM acquisition with ``model``.

    Args:
        model: callable network. It receives a ``(1, C, H, W)`` float tensor
            on ``opt.device``; its output is clamped to [0, 1].
        opt: options namespace; reads ``nch_in``, ``norm`` and ``device``.
        stack: array of raw frames stacked along axis 0.

    Returns:
        ``(inputimg, wf, sr)``:
        - ``inputimg``: the normalised input tensor fed to the model.
        - ``wf``: the widefield (mean) image as uint8.
        - ``sr``: the reconstruction as uint8.
    """
    inputimg, wf = _prep_sim_input(stack, opt)
    wf = (255 * wf.numpy()).astype('uint8')

    with torch.no_grad():
        sr = model(inputimg.unsqueeze(0).to(opt.device))
        sr = torch.clamp(sr.cpu(), min=0, max=1)

    sr = (255 * sr.squeeze().numpy()).astype('uint8')
    if opt.norm == 'adapthist':
        # equalize_adapthist returns float64 in [0, 1]. Convert back so the
        # result is uint8 in every normalisation mode.
        sr = img_as_ubyte(exposure.equalize_adapthist(sr, clip_limit=0.01))
    return inputimg, wf, sr
        
def SIM_reconstruct(model, opt):
    """Reconstruct every ``*.tif`` file in ``opt.root`` and save JPEGs in ``opt.out``.

    A file with at least two acquisitions (>= 18 frames) is split into
    consecutive 9-frame acquisitions; a trailing partial group is ignored.
    Smaller files are passed through whole as one acquisition.

    For every acquisition, ``test_wf_<k>.jpg`` (widefield) and
    ``test_sr_<k>.jpg`` (reconstruction) are written. ``k`` counts up across
    all files, which are processed in sorted path order.
    """
    os.makedirs(opt.out, exist_ok=True)
    # glob returns files in arbitrary, filesystem-dependent order. Sort so the
    # output numbering is reproducible.
    files = sorted(glob.glob(os.path.join(opt.root, '*.tif')))
    if not files:
        print('No .tif files found in %s' % opt.root)

    # SIM_reconstruct_singleStack always reads the first 9 raw frames, then
    # picks a subset of them for nch_in 6 or 3. Acquisitions are therefore
    # 9 frames long whatever opt.nch_in is; chunking by nch_in would
    # IndexError for nch_in 3 or 6.
    frames_per_acq = 9
    count = 0

    for iidx, imgfile in enumerate(files):
        print('[%d/%d] Reconstructing %s' % (iidx + 1, len(files), imgfile))
        stack = io.imread(imgfile)

        if stack.shape[0] >= 2 * frames_per_acq:
            n_acq = stack.shape[0] // frames_per_acq
            substacks = [stack[k * frames_per_acq:(k + 1) * frames_per_acq]
                         for k in range(n_acq)]
        else:
            substacks = [stack]

        for substack in substacks:
            _, wf, sr = SIM_reconstruct_singleStack(model, opt, substack)
            io.imsave(os.path.join(opt.out, 'test_wf_%d.jpg' % count), wf)
            io.imsave(os.path.join(opt.out, 'test_sr_%d.jpg' % count), sr)
            count += 1

### Model 1

In [23]:
# Options for model 1: RCAN with 96 feature maps, min-max input normalisation.
opt = argparse.Namespace(
    root='AtheiSIM_batch_data',
    out='test-output_model-1',
    task='simin_gtout',
    norm='minmax',
    dataset='fouriersim',
    model='rcan',
    # data
    imageSize=512,
    weights='DIV2K_randomised_3x3_20200317.pth',
    # input/output layer options
    scale=1,
    nch_in=9,
    nch_out=1,
    # architecture options
    narch=0,
    n_resblocks=10,
    n_resgroups=3,
    reduction=16,
    n_feats=96,
    # test options
    test=False,
    cpu=True,
)
# Prefer the GPU only when one exists and CPU mode has not been forced.
opt.device = torch.device('cpu' if opt.cpu or not torch.cuda.is_available() else 'cuda')

In [24]:
# Build and load the model, then reconstruct every stack found under opt.root.
net = LoadModel(opt)
SIM_reconstruct(model=net, opt=opt)

Loading model
Namespace(cpu=True, dataset='fouriersim', device=device(type='cpu'), imageSize=512, model='rcan', n_feats=96, n_resblocks=10, n_resgroups=3, narch=0, nch_in=9, nch_out=1, norm='minmax', out='test-output_model-1', reduction=16, root='AtheiSIM_batch_data', scale=1, task='simin_gtout', test=False, weights='DIV2K_randomised_3x3_20200317.pth')
loading checkpoint DIV2K_randomised_3x3_20200317.pth
[1/1] Reconstructing AtheiSIM_batch_data\AtheiSIM-concat.tif


### Model 2

In [34]:
# Options for model 2: RCAN with 48 feature maps.
opt = argparse.Namespace(
    root='AtheiSIM_batch_data',
    out='test-output_model-2',
    task='simin_gtout',
    dataset='fouriersim',
    # trained with 'adapthist', but 'minmax' may work better at test time
    norm='minmax',
    model='rcan',
    # data
    imageSize=512,
    weights="0216_SIMRec_0214_rndAll_rcan_continued.pth",
    # input/output layer options
    scale=1,
    nch_in=9,
    nch_out=1,
    # architecture options (narch was previously assigned twice)
    narch=0,
    n_resgroups=3,
    n_resblocks=10,
    n_feats=48,
    reduction=16,
    # test options
    test=False,
    cpu=True,
)
# Prefer the GPU only when one exists and CPU mode has not been forced.
opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')

In [35]:
# Build and load the model, then reconstruct every stack found under opt.root.
net = LoadModel(opt)
SIM_reconstruct(model=net, opt=opt)

Loading model
Namespace(cpu=True, dataset='fouriersim', device=device(type='cpu'), imageSize=512, model='rcan', n_feats=48, n_resblocks=10, n_resgroups=3, narch=0, nch_in=9, nch_out=1, norm='minmax', out='test-output_model-2', reduction=16, root='AtheiSIM_batch_data', scale=1, task='simin_gtout', test=False, weights='0216_SIMRec_0214_rndAll_rcan_continued.pth')
loading checkpoint 0216_SIMRec_0214_rndAll_rcan_continued.pth
[1/1] Reconstructing AtheiSIM_batch_data\AtheiSIM-concat.tif


### Model 2 -- adaptive histogram (adapthist) normalisation

In [36]:
# Options for model 2 with adaptive-histogram ('adapthist') normalisation.
opt = argparse.Namespace(
    root='AtheiSIM_batch_data',
    out='test-output_model-2-adapthist',
    task='simin_gtout',
    dataset='fouriersim',
    # same normalisation the model was trained with
    norm='adapthist',
    model='rcan',
    # data
    imageSize=512,
    weights="0216_SIMRec_0214_rndAll_rcan_continued.pth",
    # input/output layer options
    scale=1,
    nch_in=9,
    nch_out=1,
    # architecture options (narch was previously assigned twice)
    narch=0,
    n_resgroups=3,
    n_resblocks=10,
    n_feats=48,
    reduction=16,
    # test options
    test=False,
    cpu=True,
)
# Prefer the GPU only when one exists and CPU mode has not been forced.
opt.device = torch.device('cuda' if torch.cuda.is_available() and not opt.cpu else 'cpu')

In [37]:
# Build and load the model, then reconstruct every stack found under opt.root.
net = LoadModel(opt)
SIM_reconstruct(model=net, opt=opt)

Loading model
Namespace(cpu=True, dataset='fouriersim', device=device(type='cpu'), imageSize=512, model='rcan', n_feats=48, n_resblocks=10, n_resgroups=3, narch=0, nch_in=9, nch_out=1, norm='adapthist', out='test-output_model-2-adapthist', reduction=16, root='AtheiSIM_batch_data', scale=1, task='simin_gtout', test=False, weights='0216_SIMRec_0214_rndAll_rcan_continued.pth')
loading checkpoint 0216_SIMRec_0214_rndAll_rcan_continued.pth
[1/1] Reconstructing AtheiSIM_batch_data\AtheiSIM-concat.tif


