In [None]:
import numpy as np
from tqdm.notebook import tqdm
import glob
import sys
import os
from PIL import Image


In [None]:
import argparse
def GetParams(**overrides):
    """Build the option namespace used to load and evaluate the model.

    Every option is hard-coded below. Edit the values here, or pass keyword
    arguments to override individual options for a single run, e.g.
    ``GetParams(root='other/dir/', cpu=True)``.

    Args:
        **overrides: option names mapped to replacement values. Each name must
            already be one of the options defined below.

    Returns:
        argparse.Namespace holding every option.

    Raises:
        TypeError: if an override names an option that is not defined here
            (catches typos that would otherwise be silently ignored).
    """
    opt = argparse.Namespace()

    opt.model = 'rcan'  # model to use
    opt.lr = 0.0001  # learning rate
    opt.norm = 'minmax'  # normalization mode
    opt.nepoch = 100  # number of epochs to train for
    opt.saveinterval = 1  # number of epochs between saves
    opt.modifyPretrainedModel = False
    opt.multigpu = False
    opt.undomulti = False  # strip the DataParallel 'module.' prefix from checkpoint keys
    opt.ntrain = 5000  # number of samples to train on
    opt.scheduler = ''  # options for a scheduler, format: stepsize,gamma
    opt.log = False
    opt.noise = ''  # options for noise added, format: poisson,gaussVar

    # data
    opt.dataset = 'fouriersim'  # dataset to train
    opt.imageSize = 255  # the low resolution image size
    opt.weights = 'D:/ML-SIM/OS-SIM/255 models/generated 27-05-2021/results/prelim100.pth'  # model checkpoint to load
    opt.basedir = ''  # path to prepend to all others paths: root, output, weights
    opt.root = 'D:/ML-SIM/OS-SIM/15-05-2021/cells/488-membrane_561-ER_647-mito12/'  # input file or folder
    opt.server = ''  # whether to use server root preset
    opt.local = ''  # whether to use local root preset: C:/phd-data/datasets/
    opt.out = 'D:/User/Edward/Documents/GitHub/ML OS-SIM/Training code/'  # folder for output files and log

    # computation
    opt.workers = 1  # number of data loading workers
    opt.batchSize = 10  # input batch size

    # restoration options
    opt.task = 'sr'  # restoration task
    opt.scale = 1  # low to high resolution scaling factor
    opt.nch_in = 3  # channels in input (frames grouped per reconstruction)
    opt.nch_out = 1  # channels in output

    # architecture options
    opt.narch = 0  # architecture-dependent parameter
    opt.n_resblocks = 3  # number of residual blocks
    opt.n_resgroups = 5  # number of residual groups
    opt.reduction = 16  # feature-map reduction factor (presumably the RCAN channel-attention ratio)
    opt.n_feats = 96  # number of feature maps

    # test options
    opt.ntest = 10  # number of images to test per epoch or test run
    opt.testinterval = 1  # number of epochs between tests during training
    opt.test = False
    opt.cpu = False  # run the network on the CPU instead of CUDA (not supported for training)
    opt.batchSize_test = 1  # input batch size for test loader
    opt.plotinterval = 1  # number of test samples between plotting

    # Apply per-call overrides, rejecting names that are not real options.
    for name, value in overrides.items():
        if not hasattr(opt, name):
            raise TypeError('GetParams() got an unknown option: %r' % name)
        setattr(opt, name, value)

    return opt

In [None]:

import math
import os

import torch
import time 

import torch.optim as optim
import torchvision
from torch.autograd import Variable
import skimage
from skimage import io
from models import *
from datahandler import *

import matplotlib.pyplot as plt
from tqdm import tqdm
import glob

def remove_dataparallel_wrapper(state_dict):
    r"""Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary

    Keys that do not start with ``"module."`` are copied unchanged, so a
    state dict that was never wrapped passes through intact instead of having
    its first seven characters cut off.

    Args:
        state_dict: a torch.nn.DataParallel state dictionary

    Returns:
        An OrderedDict with the same values, in the same order, under the
        unwrapped key names.
    """
    from collections import OrderedDict

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value

    return new_state_dict


def _collect_image_paths(root):
    """Return the list of input files to process for ``root``.

    If ``root`` has a .png/.jpg/.tif/.tiff extension it is returned on its
    own. Otherwise ``root`` is treated as a directory: its top level is
    searched for .jpg, .png, .tif and .tiff files, and only if none are found
    is the whole tree searched recursively.

    NOTE(review): every file is later opened with TiffFile, so .jpg/.png
    inputs will presumably fail to load; confirm whether they should be
    dropped from the search.
    """
    extensions = ('jpg', 'png', 'tif', 'tiff')
    if os.path.splitext(root)[1].lstrip('.').lower() in extensions:
        return [root]

    paths = []
    for ext in extensions:
        paths.extend(glob.glob(os.path.join(root, '*.' + ext)))
    if not paths:  # nothing at the top level: scan the whole tree
        for ext in extensions:
            paths.extend(glob.glob(os.path.join(root, '**', '*.' + ext), recursive=True))
    return paths


def _process_stack(net, imgfile, index, total, opt):
    """Run ``net`` over one multi-page TIFF and save the results to ``opt.out``.

    Pages are consumed in consecutive groups of ``opt.nch_in``. Each group is
    divided by its maximum and fed to the network. It then contributes one
    page to ``<name>_sr.tif`` (clamped network output) and one page to
    ``<name>_wf.tif`` (mean of the group). Trailing pages that do not fill a
    whole group are ignored.
    """
    output_scale = 32000  # maps the [0, 1] float results onto the uint16 range written to disk
    description = 'Processing image [%d/%d]' % (index + 1, total)

    # Only the page count is needed here; close the handle instead of leaking it.
    with skimage.external.tifffile.TiffFile(imgfile) as handle:
        n_groups = len(handle) // opt.nch_in

    filename = os.path.splitext(os.path.basename(imgfile))[0]
    sr_path = os.path.join(opt.out, filename + '_sr.tif')
    wf_path = os.path.join(opt.out, filename + '_wf.tif')

    for stack_idx in tqdm(range(n_groups), desc=description):
        first = stack_idx * opt.nch_in
        stack = skimage.external.tifffile.imread(imgfile, key=range(first, first + opt.nch_in))
        peak = np.amax(stack)
        # An all-zero group would otherwise divide by zero and write NaN garbage.
        stack = stack / peak if peak > 0 else stack.astype(np.float64)
        wf = np.mean(stack, 0)

        sub_tensor = toTensor(np.moveaxis(stack, 0, 2)).unsqueeze(0).type(torch.FloatTensor)
        with torch.no_grad():
            sr = net(sub_tensor if opt.cpu else sub_tensor.cuda()).cpu()
        sr_frame = np.squeeze(torch.clamp(sr[0], 0, 1).numpy())

        wf = (wf * output_scale).astype('uint16')
        sr_frame = (sr_frame * output_scale).astype('uint16')
        # The first group creates (overwrites) the output files; later groups append pages.
        append = stack_idx > 0
        skimage.external.tifffile.imsave(sr_path, np.expand_dims(sr_frame, axis=(0, 1)), append=append)
        skimage.external.tifffile.imsave(wf_path, np.expand_dims(wf, axis=(0, 1)), append=append)


def EvaluateModel(opt):
    """Load the checkpoint described by ``opt`` and run the network over ``opt.root``.

    ``opt.root`` may be a single file or a directory. For each input file,
    network outputs and group means are written to ``opt.out`` as
    ``<name>_sr.tif`` and ``<name>_wf.tif``. The options are logged to
    ``opt.out/log.txt``.

    Args:
        opt: option namespace, e.g. from ``GetParams()``. ``opt.fid`` is set to
            the log file handle, which is closed before this function returns.
    """
    # exist_ok tolerates an existing folder only; other OS errors
    # (permissions, missing drive) are no longer silently swallowed.
    os.makedirs(opt.out, exist_ok=True)

    opt.fid = open(os.path.join(opt.out, 'log.txt'), 'w')
    try:
        print(opt)
        print(opt, '\n', file=opt.fid)

        net = GetModel(opt)

        # map_location lets a GPU-saved checkpoint be loaded on a CPU-only machine.
        checkpoint = torch.load(opt.weights, map_location='cpu' if opt.cpu else None)
        if opt.cpu:
            net.cpu()
        # NOTE(review): when opt.cpu is False the inputs are moved to CUDA, so the
        # network is presumably placed on the GPU by GetModel; confirm.

        print('loading checkpoint', opt.weights)
        if opt.undomulti:
            checkpoint['state_dict'] = remove_dataparallel_wrapper(checkpoint['state_dict'])
        net.load_state_dict(checkpoint['state_dict'])
        net.eval()  # inference mode (affects layers such as dropout/batch-norm, if present)

        imgs = _collect_image_paths(opt.root)
        for i, imgfile in enumerate(imgs):
            _process_stack(net, imgfile, i, len(imgs), opt)
    finally:
        opt.fid.close()
  

if __name__ == '__main__':
    # Build the configuration and keep it bound to the module-level name
    # ``opt`` so later notebook cells can still inspect it after the run.
    opt = GetParams()
    EvaluateModel(opt=opt)
