# Resunet Explorer

## Import libraries

In [7]:
from re import split

import pyprog

# Initial imports and device setting
from pathlib import Path
from functools import partial


import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import torch
import torch.nn as nn
from torch.nn.functional import interpolate 
from torch.functional import F
import imgaug.augmenters as iaa

# Libraries for graph
import networkx.drawing as draw
import networkx as nx

import math 
import numpy as np
import pandas as pd
import seaborn as sns
import plotly.express as px

from plotly.subplots import make_subplots
import plotly.graph_objects as go


from torchtrainer.imagedataset import ImageSegmentationDataset
from torchtrainer import img_util
from torchtrainer import transforms
from torchtrainer.models.resunet import ResUNet
from torchtrainer.module_util import ActivationSampler

# Import Resunet Explorer library 
from resunetexplorer.layer_extractor import ExtractResUNetLayers
from resunetexplorer.maps_extractor import ExtractResUNetMaps

import scipy
from scipy import ndimage, misc

# PCA module
from sklearn.decomposition import PCA


use_cuda = False

# Default to the CPU; only switch to the GPU when it was explicitly requested
# via `use_cuda` AND CUDA is actually available on this machine.
if not (use_cuda and torch.cuda.is_available()):
    device = torch.device('cpu')
else:
    device = torch.device('cuda')
    # Report which GPU will be used.
    dev_info = torch.cuda.get_device_properties(device)
    print(dev_info)

## Functions

In [2]:
def dataset_creation(root_dir_, img_dir_, label_dir_):
    """Build the image-segmentation dataset used by the explorer.

    The image directory is ``root_dir_/img_dir_`` and the label directory is
    ``root_dir_/label_dir_``. Each image file name is mapped to a label file
    name by keeping the part before the first '.' and appending ``.png``.

    Parameters
    ----------
    root_dir_ : str or path-like
        Root directory of the dataset.
    img_dir_ : str or path-like
        Image sub-directory, relative to ``root_dir_``.
    label_dir_ : str or path-like
        Label sub-directory, relative to ``root_dir_``.

    Returns
    -------
    ImageSegmentationDataset
        Dataset configured with the validation transform list: CLAHE,
        conversion to tensor and ``TransfWhitten`` normalization.
    """
    def img_name_to_label(filename):
        # Everything from the FIRST '.' onward is dropped, so a name such as
        # 'img.v2.tif' maps to 'img.png' (not 'img.v2.png').
        return filename.split('.')[0] + '.png'

    root_dir = Path(root_dir_)
    img_dir = root_dir / img_dir_
    label_dir = root_dir / label_dir_

    # Contrast enhancement (CLAHE) built as an imgaug pipeline, then adapted
    # to the torchtrainer transform interface.
    imgaug_seq = iaa.Sequential([
        iaa.CLAHE(clip_limit=6, tile_grid_size_px=12),
    ])
    imgaug_seq = transforms.translate_imagaug_seq(imgaug_seq)
    # NOTE(review): 67.576 / 37.556 look like dataset intensity statistics
    # used for whitening — confirm against TransfWhitten's signature.
    valid_transforms = [
        transforms.TransfToImgaug(),
        imgaug_seq,
        transforms.TransfToTensor(),
        transforms.TransfWhitten(67.576, 37.556),
    ]

    img_opener = partial(img_util.pil_img_opener, channel=None)
    label_opener = partial(img_util.pil_img_opener, is_label=True)
    return ImageSegmentationDataset(
        img_dir,
        label_dir,
        name_to_label_map=img_name_to_label,
        img_opener=img_opener,
        label_opener=label_opener,
        transforms=valid_transforms,
    )

In [3]:
def load_model_checkpoint(path, device, num_channels=1, num_classes=2):
    """Load a ResUNet from a checkpoint file and prepare it for inference.

    Parameters
    ----------
    path : str or path-like
        Checkpoint file readable by ``torch.load``. It must be a mapping that
        stores the model weights under the ``'model_state'`` key.
    device : str or torch.device
        Device the checkpoint tensors are mapped to and the model is moved to.
    num_channels : int, optional
        Number of input channels passed to ``ResUNet`` (default 1).
    num_classes : int, optional
        Number of output classes passed to ``ResUNet`` (default 2).

    Returns
    -------
    ResUNet
        The network with the checkpoint weights loaded, switched to
        evaluation mode and placed on ``device``.

    Raises
    ------
    KeyError
        If the checkpoint has no ``'model_state'`` entry.
    """
    # NOTE: depending on the PyTorch version (and its `weights_only` default),
    # torch.load may unpickle arbitrary Python objects — only load trusted
    # checkpoint files.
    checkpoint = torch.load(path, map_location=torch.device(device))
    model = ResUNet(num_channels=num_channels, num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state'])
    # Inference mode: affects layers such as dropout / batch-norm, if present.
    model.eval()
    model.to(device)

    return model

## Feature map visualization example


In [6]:
# Checkpoint file of the trained model (a vessel model, judging by its name).
model_path = 'learner_vessel.tar'
# Dataset rooted at 'data', images in 'CD31(vessels)', labels in 'labels'.
dataset = dataset_creation('data', 'CD31(vessels)', 'labels')
# Restore the network onto the device selected in the setup cell.
model = load_model_checkpoint(model_path, device)

## ExtractResUNetLayers class

In [17]:
# Smoke test for ExtractResUNetLayers: select two Conv2d layers of the model,
# the first conv of encoder residual block 1 and the first conv of '_l4'.
layers_paths = [
    'encoder.resblock1.conv1',
    '_l4.conv1',
]
erl = ExtractResUNetLayers(model, layers_paths)
layers = erl.get_layers()


In [18]:
# Display the dict returned by get_layers(); it holds the keys
# 'network_part', 'n_maps' and 'layer' (rendered as the cell output).
layers

{'network_part': ['encoder.resblock1.conv1', '_l4.conv1'],
 'n_maps': [64, 512],
 'layer': [Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)]}

## ExtractResUNetMaps class

In [None]:

# NOTE(review): get_multiple_feature_maps, show_feature_maps, img_idx and
# layers_name are not defined anywhere in the visible notebook, so this cell
# raises NameError as written. The helpers presumably belong to
# ExtractResUNetMaps (imported at the top but never used) — TODO confirm
# that API, and define img_idx / layers_name before this cell
# (layers['network_part'] looks like a candidate for layers_name).

# Presumably computes the feature maps of image `img_idx` for every layer
# extracted by ExtractResUNetLayers above.
layers_fm_list = get_multiple_feature_maps(img_idx, layers['layer'])

# Indices of the individual feature maps to display.
maps_idx = [2,6,7,52]

show_feature_maps(img_idx, layers_name, layers_fm_list, maps_idx)