In [16]:
import numpy as np
from tqdm.notebook import tqdm
import glob
import sys
import os
from PIL import Image


In [17]:
import argparse
def GetParams():
  opt = argparse.Namespace()

  opt.model='rcan'#'model to use'
  opt.lr = 0.0001 # learning rate
  opt.norm = '' # if normalization should not be used
  opt.nepoch =100 # number of epochs to train for
  opt.saveinterval =1 # number of epochs between saves
  opt.modifyPretrainedModel = False
  opt.multigpu = False
  opt.undomulti = False
  opt.ntrain = 5000 # number of samples to train on
  opt.scheduler = '' # options for a scheduler, format: stepsize,gamma
  opt.log = False
  opt.noise ='' # options for noise added, format: poisson,gaussVar

  # data
  opt.dataset = 'fouriersim' # dataset to train
  opt.imageSize = 512 # the low resolution image size
  opt.weights = 'D:/User/Edward/Downloads/HPC-download/prelim32.pth' # model to retrain from
  opt.basedir = '' # path to prepend to all others paths: root, output, weights
  opt.root ='D:/Work/Test datasets/OS-SIM' # dataset to train
  opt.server = '' # whether to use server root preset
  opt.local = '' # whether to use local root preset: C:/phd-data/datasets/
  opt.out = 'D:/Work/Test datasets/OS-SIM/ML-SIM reconstructions/HPC 01-05-2021' # folder to output model training results

  # computation 
  opt.workers  = 1 # number of data loading workers
  opt.batchSize = 10 # input batch size

  # restoration options
  opt.task ='sr' # restoration task 
  opt.scale = 1 # low to high resolution scaling factor
  opt.nch_in = 3 # channels in input 
  opt.nch_out = 1 # channels in output 

  # architecture options 
  opt.narch = 0 # architecture-dependent parameter
  opt.n_resblocks  = 2 # number of residual blocks 
  opt.n_resgroups  = 3 # number of residual groups 
  opt.reduction  = 2 # number of 36eature maps
  opt.n_feats = 54 

  # test options
  opt.ntest  = 10 # number of images to test per epoch or test run 
  opt.testinterval  = 1 # number of epochs between tests during training 
  opt.test = False
  opt.cpu = False # not supported for training
  opt.batchSize_test  = 1 # input batch size for test loader 
  opt.plotinterval  = 1 # number of test samples between plotting 
    
  return opt

In [18]:

import math
import os

import torch
import time 

import torch.optim as optim
import torchvision
from torch.autograd import Variable

from skimage import io
from models import *
from datahandler import *

import matplotlib.pyplot as plt
from tqdm import tqdm
import glob

def remove_dataparallel_wrapper(state_dict):
	r"""Converts a DataParallel model to a normal one by removing the "module."
	wrapper in the module dictionary

	Args:
		state_dict: a torch.nn.DataParallel state dictionary
	"""
	from collections import OrderedDict

	new_state_dict = OrderedDict()
	for k, vl in state_dict.items():
		name = k[7:] # remove 'module.' of DataParallel
		new_state_dict[name] = vl

	return new_state_dict


def changeColour(I): # change colours (used to match WEKA output)
    Inew = np.zeros(I.shape + (3,)).astype('uint8')
    for rowidx in range(I.shape[0]):
        for colidx in range(I.shape[1]):
            if I[rowidx][colidx] == 0:
                Inew[rowidx][colidx] = [198,118,255]
            elif I[rowidx][colidx] == 127:
                Inew[rowidx][colidx] = [79,255,130]
            elif I[rowidx][colidx] == 255:
                Inew[rowidx][colidx] = [255,0,0]
    return Inew


def loadimg(imgfile):
    stack = io.imread(imgfile)
    inputimgs,wfimgs = [],[]

    for i in range(int(len(stack)/9)):
        inputimg = stack[i*9:(i+1)*9]

        if inputimg.shape[1] != 512 or inputimg.shape[2] != 512:
            print(imgfile,'not 512x512! Cropping')
            inputimg = inputimg[:,:512,:512]
                 
        widefield = np.mean(inputimg,0)
        widefield = (widefield - np.min(widefield)) / (np.max(widefield) - np.min(widefield))    


        if opt.norm == 'convert': # raw img from microscope, needs normalisation and correct frame ordering
            print('Raw input assumed - converting')


            inputimg = np.rot90(inputimg,axes=(1,2))
            inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
            for i in range(len(inputimg)):
                inputimg[i] = 100 / np.max(inputimg[i]) * inputimg[i]
            inputimg = torch.tensor(inputimg.astype('float') / 255).float()
        elif 'convert' in opt.norm:
            fac = float(opt.norm[7:])
            inputimg = np.rot90(inputimg,axes=(1,2))
            inputimg = inputimg[[6,7,8,3,4,5,0,1,2]] # could also do [8,7,6,5,4,3,2,1,0]
            for i in range(len(inputimg)):
                inputimg[i] = fac * 255 / np.max(inputimg[i]) * inputimg[i]
            inputimg = torch.tensor(inputimg.astype('float') / 255).float()
        elif opt.norm == 'minmax':
            inputimg = torch.tensor(inputimg.astype('float') / 255).float()
            for i in range(len(inputimg)):
                inputimg[i] = (inputimg[i] - torch.min(inputimg[i])) / (torch.max(inputimg[i]) - torch.min(inputimg[i]))

        else:
            inputimg = torch.tensor(inputimg.astype('float') / 255).float()


        widefield = torch.tensor(widefield).float()
        
        inputimgs.append(inputimg)
        wfimgs.append(widefield)

    return inputimgs,wfimgs


def EvaluateModel(opt):

    try:
        os.makedirs(opt.out)
    except IOError:
        pass

    opt.fid = open(opt.out + '/log.txt','w')
    print(opt)
    print(opt,'\n',file=opt.fid)
    
    net = GetModel(opt)

    checkpoint = torch.load(opt.weights)
    if opt.cpu:
        net.cpu()
    
    print('loading checkpoint',opt.weights)
    if opt.undomulti:
        checkpoint['state_dict'] = remove_dataparallel_wrapper(checkpoint['state_dict'])
    net.load_state_dict(checkpoint['state_dict'])

    if opt.root.split('.')[-1] == 'png' or opt.root.split('.')[-1] == 'jpg':
        imgs = [opt.root]
    else:
        imgs = []
        imgs.extend(glob.glob(opt.root + '/*.jpg'))
        imgs.extend(glob.glob(opt.root + '/*.png'))
        imgs.extend(glob.glob(opt.root + '/*.tif'))
        if len(imgs) == 0: # scan everything
            imgs.extend(glob.glob(opt.root + '/**/*.jpg',recursive=True))
            imgs.extend(glob.glob(opt.root + '/**/*.png',recursive=True))
            imgs.extend(glob.glob(opt.root + '/**/*.tif',recursive=True))

    imageSize = opt.imageSize

    for i, imgfile in enumerate(imgs):
        print('\rProcessing image [%d/%d]' % (i+1,len(imgs)),end='')
        img = io.imread(imgfile)
        img = np.array(img)
        img = img/np.amax(img)

        nImgs = img.shape[0] // opt.nch_in
        srs = np.zeros([512,512,nImgs]) 
        wfs = np.zeros([512,512,nImgs])

        if img.shape[1] != 512:
            print('\rimage', imgfile,' is not 512x512! Cropping')
            img = img[:,:512,:512]

        frames = np.zeros([512,512,nImgs])
        for stack_idx in tqdm(range(nImgs // opt.nch_in)):
            stackSubset = img[stack_idx*opt.nch_in:(stack_idx+1)*opt.nch_in] 
            wfs[:,:,stack_idx] = np.mean(stackSubset,0)

            sub_tensor = toTensor(np.moveaxis(stackSubset,0,2))
            sub_tensor = sub_tensor.unsqueeze(0)
            sub_tensor = sub_tensor.type(torch.FloatTensor)
            
            
            with torch.no_grad():
                if opt.cpu:
                    sr = net(sub_tensor)
                else:
                    sr = net(sub_tensor.cuda())
                sr = sr.cpu()

                sr = torch.clamp(sr[0],0,1)
                sr_frame = sr.numpy()
                sr_frame = np.squeeze(sr_frame)
                srs[:,:,stack_idx] = sr_frame
                               



        frames = (frames * 32000).astype('uint16')
        if nImgs > 1:
            frames = np.moveaxis(frames,2,0)
        
        filename = os.path.basename(imgfile)[:-4]
        svPath = opt.out + '/' + filename +'_sr.tif'
        io.imsave(svPath,frames)



if __name__ == '__main__':
    opt = GetParams()

    EvaluateModel(opt)


Namespace(basedir='', batchSize=10, batchSize_test=1, cpu=False, dataset='fouriersim', fid=<_io.TextIOWrapper name='D:/Work/Test datasets/OS-SIM/ML-SIM reconstructions/HPC 01-05-2021/log.txt' mode='w' encoding='cp1252'>, imageSize=512, local='', log=False, lr=0.0001, model='rcan', modifyPretrainedModel=False, multigpu=False, n_feats=54, n_resblocks=2, n_resgroups=3, narch=0, nch_in=3, nch_out=1, nepoch=100, noise='', norm='', ntest=10, ntrain=5000, out='D:/Work/Test datasets/OS-SIM/ML-SIM reconstructions/HPC 01-05-2021', plotinterval=1, reduction=2, root='D:/Work/Test datasets/OS-SIM', saveinterval=1, scale=1, scheduler='', server='', task='sr', test=False, testinterval=1, undomulti=False, weights='D:/User/Edward/Downloads/HPC-download/prelim32.pth', workers=1)
not using normalization
  0%|          | 0/3 [00:00<?, ?it/s]loading checkpoint D:/User/Edward/Downloads/HPC-download/prelim32.pth
100%|██████████| 3/3 [00:00<00:00,  5.93it/s]
100%|██████████| 20/20 [00:03<00:00,  6.08it/s]
100

In [21]:
 opt.out

'D:/Work/Test datasets/OS-SIM/ML-SIM reconstructions/HPC 01-05-2021'