# ERnet Transformer

# Table of Contents

- [1. Install packages](#install)
- [2. Download files](#download)
- [3. Functions](#functions)
- [4. Inference](#inference)
    - [a. Single images](#single)
    - [b. Batch processing](#batch)

<a id='install'></a>
## 1. Install packages

In [None]:
# Third-party packages that this notebook does not import directly; presumably
# required by the ERnet modules downloaded below (TODO confirm against the repo).
!pip install sknw timm einops

<a id='download'></a>
## 2. Download files

In [None]:
# architectures
!mkdir -p archs
!wget https://raw.githubusercontent.com/charlesnchr/ERnet-v2/main/Training/archs/swinir_rcab_arch.py -P Training/archs
!wget https://raw.githubusercontent.com/charlesnchr/ERnet-v2/main/Training/archs/rcan_arch.py -P Training/archs

# inference code
!wget https://raw.githubusercontent.com/charlesnchr/ERnet-v2/main/Inference/model_evaluation.py -P Inference
!wget https://raw.githubusercontent.com/charlesnchr/ERnet-v2/main/Inference/graph_processing.py -P Inference

# models
!mkdir -p models
!wget https://github.com/charlesnchr/ERnet-v2/releases/download/v2.0/20220306_ER_4class_swinir_nch1.pth -P models

# image files
!mkdir -p images
!wget https://github.com/charlesnchr/ERnet-v2/releases/download/v2.0/TestImage1.png -P images
!wget https://github.com/charlesnchr/ERnet-v2/releases/download/v2.0/TestImage2.png -P images
!wget https://github.com/charlesnchr/ERnet-v2/releases/download/v2.0/TestImage3-stack.tif -P images
!wget https://github.com/charlesnchr/ERnet-v2/releases/download/v2.0/TestImage4-stack.tif -P images

<a id='functions'></a>
## 3. Functions

In [18]:
%reload_ext autoreload
%autoreload 2

# Imports
import os
import datetime
import math
import time
from skimage import io
import os
from argparse import Namespace
import sys

# Load code from repository (ERnet architecture + inference code)
dirs = ["Training/archs", "Inference"]
[sys.path.append(os.path.abspath(f)) for f in dirs]
from Inference import model_evaluation


def segment(
    exportdir,
    filepaths,
    weka_colours,
    stats_tubule_sheet,
    graph_metrics,
    save_in_original_folders,
    save_input=True,
    model=None,
):
    """Run ERnet segmentation on ``filepaths`` via ``model_evaluation.EvaluateModel``.

    Builds the ``Namespace`` of options that ``EvaluateModel`` consumes and
    calls it once.

    Args:
        exportdir: Output directory; created if it does not exist. The
            optional CSV files are written here.
        filepaths: Input image paths, stored as ``opt.root``
            (``opt.ext`` is set to jpg/png/tif).
        weka_colours: Stored as ``opt.weka_colours``.
        stats_tubule_sheet: If true, ``<exportdir>/<jobname>_stats_tubule_sheet.csv``
            is opened for writing and passed as ``opt.csvfid``.
        graph_metrics: If true, ``<exportdir>/<jobname>_graph_metrics.csv``
            is opened for writing and passed as ``opt.graphfid``.
        save_in_original_folders: If true, sets ``opt.out = "root"``;
            otherwise ``opt.out`` is left unset.
        save_input: Flag stored as ``opt.save_input``. This is the 7th
            positional parameter. Do not pass the model path here.
        model: Path to the model weights, stored as ``opt.weights``. When
            omitted, the module/notebook-global ``model`` is used, which is
            what earlier versions of this function did implicitly.

    Returns:
        Whatever ``model_evaluation.EvaluateModel(opt)`` returns. The notebook
        treats it as a flat list of output file paths.

    Raises:
        ValueError: If no model path is passed and no global ``model`` exists.
    """
    from contextlib import ExitStack

    if model is None:
        # Backward compatibility: existing callers rely on the global ``model``.
        model = globals().get("model")
    if model is None:
        raise ValueError(
            "segment() needs a model path: pass model=... or define a global 'model'"
        )

    opt = Namespace()
    opt.root = filepaths
    opt.ext = ["jpg", "png", "tif"]
    opt.stats_tubule_sheet = stats_tubule_sheet
    opt.graph_metrics = graph_metrics
    opt.weka_colours = weka_colours
    opt.save_input = save_input

    opt.exportdir = exportdir
    os.makedirs(exportdir, exist_ok=True)
    # Millisecond-resolution UTC timestamp that names this job's CSV files.
    # datetime.utcnow() is deprecated (Python 3.12+); this gives the same string.
    opt.jobname = datetime.datetime.now(datetime.timezone.utc).strftime(
        "%Y%m%d%H%M%S%f"
    )[:-3]

    ## model specific
    # NOTE(review): presumably these must match the checkpoint's training
    # configuration -- confirm before changing.
    opt.imageSize = 600
    opt.n_resblocks = 10
    opt.n_resgroups = 3
    opt.n_feats = 64
    opt.reduction = 16
    opt.narch = 0
    opt.norm = None
    opt.nch_in = 1
    opt.nch_out = 4
    opt.cpu = False  # NOTE(review): hard-coded; presumably needs True on CPU-only runtimes
    opt.weights = model
    opt.scale = 1

    if save_in_original_folders:
        opt.out = "root"

    # The CSV handles are only needed while EvaluateModel runs. ExitStack
    # flushes and closes them afterwards, even if evaluation raises.
    with ExitStack() as stack:
        if stats_tubule_sheet:
            csvfid_path = os.path.join(
                opt.exportdir, "%s_stats_tubule_sheet.csv" % opt.jobname
            )
            opt.csvfid = stack.enter_context(open(csvfid_path, "w"))

        if opt.graph_metrics:
            graphfid_path = os.path.join(
                opt.exportdir, "%s_graph_metrics.csv" % opt.jobname
            )
            opt.graphfid = stack.enter_context(open(graphfid_path, "w"))

        print(vars(opt))

        return model_evaluation.EvaluateModel(opt)

<a id='inference'></a>
## 4. Inference

<a id='single'></a>
### a. Example of using ERnet Transformer on single images

In [None]:
exportdir = 'output'
filepaths = ['images/TestImage1.png', 'images/TestImage2.png']
# segment() reads the weights path from this notebook-level `model` variable.
# Do not pass it as a 7th positional argument: that slot is `save_input`, so
# the path would silently be treated as a (truthy) flag.
model = 'models/20220306_ER_4class_swinir_nch1.pth'
weka_colours = False
stats_tubule_sheet = True
graph_metrics = True
save_in_original_folders = True
outpaths = segment(exportdir, filepaths, weka_colours, stats_tubule_sheet,
                   graph_metrics, save_in_original_folders)


#### Visualise result

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.clf()

# output files per input file
n = len(outpaths)//len(filepaths)

for idx, inpath in enumerate(filepaths):

    outpath = [outpaths[i] for i in range(idx*n, (idx+1)*n )]

    plt.figure(figsize=(20,10))
    plt.subplot(221)
    plt.imshow(io.imread(inpath))
    plt.title('Input %d: %s' % (idx+1,inpath))
    plt.subplot(222)
    plt.imshow(io.imread(outpath[0]))
    plt.title('Output %d' % (idx+1))

    plt.subplot(223)
    plt.imshow(io.imread(outpath[1]))
    plt.title('Graph representation %d' % (idx+1))
    plt.subplot(224)
    plt.imshow(io.imread(outpath[2]))
    plt.title('Degree histogram %d' % (idx+1))
plt.show()

<a id='batch'></a>
### b. Example of batch processing

`TestImage3-stack.tif` consists of 5 frames

`TestImage4-stack.tif` consists of 30 frames

In [None]:
exportdir = 'output'
filepaths = ['images/TestImage3-stack.tif', 'images/TestImage4-stack.tif']
# segment() reads the weights path from this notebook-level `model` variable.
# Do not pass it as a 7th positional argument: that slot is `save_input`, so
# the path would silently be treated as a (truthy) flag.
model = 'models/20220306_ER_4class_swinir_nch1.pth'
weka_colours = False
stats_tubule_sheet = True
graph_metrics = True
save_in_original_folders = True
outpaths = segment(exportdir, filepaths, weka_colours, stats_tubule_sheet,
                   graph_metrics, save_in_original_folders)