In [None]:
import torch
import os

import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from argparse import ArgumentParser
from torch import nn
from pl_bolts.models.autoencoders import AE
from fastai.vision.all import *
from PIL import Image, ImageFile

from autoencoder import AESigmoid
from classifier import Classifier
from tripletnet import TripletNet

from lucent.optvis import render, param, transform, objectives

In [None]:
# Ask PIL to load image files that are truncated instead of raising an error,
# so a partially written image does not abort the notebook.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Visualise activations of trained networks

This notebook visualises optimised activations for the most variable channel in the feature extractor layer of each of three trained networks (an autoencoder, a triplet network, and a classifier), and puts them together with example images.

## Setup

We just need to set the image height used during training, and load the genera of the training examples, which are needed to set up the classifier network.

In [None]:
input_height = 256
label_data = "../data/herbarium-2021-fgvc8-sampled/sample-metadata.csv"

# Series of genus names indexed by image id, with the ids cast to strings.
genera = (
    pd.read_csv(label_data)
    .astype({"image_id": str})
    .set_index("image_id")["genus"]
)

# Number of distinct genera in the metadata.
n_genera = len(genera.unique())

## Load networks

### Autoencoder

In [None]:
# `load_from_checkpoint` builds and returns a new model, so it is called on the
# class rather than on a throwaway instance: this avoids constructing the network
# twice, and newer Lightning releases refuse to run it on an instance.
# NOTE(review): assumes AESigmoid inherits Lightning's `load_from_checkpoint` - confirm.
ae_model = AESigmoid.load_from_checkpoint(
    "../lightning_logs/resnet18_size-256_ae/version_6/checkpoints/epoch=24-step=331849.ckpt",
    strict=False,
    input_height=input_height,
)
print("set up autoencoder")

### Triplet network

In [None]:
# Feature extractor: every child module of ResNet-18 except the last one,
# followed by a flatten so the output is one vector per image.
backbone = resnet18(pretrained=True)
encoder = nn.Sequential(*list(backbone.children())[:-1], nn.Flatten())

# `load_from_checkpoint` builds and returns a new model, so it is called on the
# class rather than on a throwaway instance: this avoids constructing the network
# twice, and newer Lightning releases refuse to run it on an instance.
# NOTE(review): assumes TripletNet inherits Lightning's `load_from_checkpoint` - confirm.
tpl_model = TripletNet.load_from_checkpoint(
    "../lightning_logs/resnet18_size-256_tpl/version_6/checkpoints/epoch=24-step=331849.ckpt",
    strict=False,
    encoder=encoder,
)
print("set up triplet network")

### Classifier

In [None]:
# Feature extractor: every child module of ResNet-18 except the last one,
# followed by a flatten so the output is one vector per image.
backbone = resnet18(pretrained=True)
encoder = nn.Sequential(*list(backbone.children())[:-1], nn.Flatten())

# `load_from_checkpoint` builds and returns a new model, so it is called on the
# class rather than on a throwaway instance: this avoids constructing the network
# twice, and newer Lightning releases refuse to run it on an instance.
# The number of classes is the genus count already computed in the setup cell.
# NOTE(review): assumes Classifier inherits Lightning's `load_from_checkpoint` - confirm.
clf_model = Classifier.load_from_checkpoint(
    "../lightning_logs/resnet18_size-256_clf/version_17/checkpoints/epoch=24-step=331849.ckpt",
    strict=False,
    encoder=encoder,
    input_dim=512,
    num_classes=n_genera,
)
print("set up classifier")

## Load features

We're using features extracted from the training data here. 

We're also finding the channel with the greatest variation (the highest standard deviation across the training examples). There are 512 channels in the feature extraction layer of each network, so looking at all of them would be hard. Instead, we'll just look at the single most variable channel for each network.

In [None]:
# For each network: load the saved training features, take the standard deviation
# down axis 0 (one value per column), and keep the index of the largest one.
feat_values_ae = np.load("../output/resnet18_size-256_ae/version_6/features_train.npy")
max_channel_ae = np.argmax(np.std(feat_values_ae, axis=0))

feat_values_tpl = np.load("../output/resnet18_size-256_tpl/version_6/features_train.npy")
max_channel_tpl = np.argmax(np.std(feat_values_tpl, axis=0))

feat_values_clf = np.load("../output/resnet18_size-256_clf/version_17/features_train.npy")
max_channel_clf = np.argmax(np.std(feat_values_clf, axis=0))

print(f"Autoencoder channel with greatest variation: {max_channel_ae}")
print(f"Triplet channel with greatest variation: {max_channel_tpl}")
print(f"Classifier channel with greatest variation: {max_channel_clf}")

## Visualise optimised activations

We'll use [lucent](https://github.com/greentfrapp/lucent) to visualise the optimised activations. We're visualising the most positive optimised activation and the most negative, so we can look at images that give a spectrum of activations for each channel.

First we need to set our models to evaluation mode and transfer them to the device we want to use. **NB: the optimisation in the cells that follow may take a long time if you don't have access to a GPU.**

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Only the encoder of each network is visualised: move it to the device and
# switch it to evaluation mode.
ae, tpl, clf = (
    model.encoder.to(device).eval()
    for model in (ae_model, tpl_model, clf_model)
)

### Autoencoder

In [None]:
def param_f():
    """Parameterise a batch of two 512-pixel images to optimise together."""
    return param.image(256 * 2, batch=2)


# The objective is the batch-1 term minus the batch-0 term for the chosen channel
# of the "avgpool" layer, so the two images are driven in opposite directions.
# Later in the notebook image 0 is plotted as "Minimum optimised" and image 1 as
# "Maximum optimised".
positive_term = objectives.channel("avgpool", max_channel_ae, batch=1)
negative_term = objectives.channel("avgpool", max_channel_ae, batch=0)
obj = positive_term - negative_term

ae_activations = render.render_vis(
    ae,
    obj,
    param_f,
    fixed_image_size=256,
    thresholds=(1024,),
)

### Triplet network

In [None]:
def param_f():
    """Parameterise a batch of two 512-pixel images to optimise together."""
    return param.image(256*2, batch=2)


# The objective is the batch-1 term minus the batch-0 term for the chosen channel
# of layer "8", so the two images are driven in opposite directions.
# Later in the notebook image 0 is plotted as "Minimum optimised" and image 1 as
# "Maximum optimised".
positive_term = objectives.channel("8", max_channel_tpl, batch=1)
negative_term = objectives.channel("8", max_channel_tpl, batch=0)
obj = positive_term - negative_term

tpl_activations = render.render_vis(
    tpl,
    obj,
    param_f,
    fixed_image_size=256,
    thresholds=(1024,),
)

### Classifier

In [None]:
def param_f():
    """Parameterise a batch of two 512-pixel images to optimise together."""
    return param.image(256*2, batch=2)


# The objective is the batch-1 term minus the batch-0 term for the chosen channel
# of layer "8", so the two images are driven in opposite directions.
# Later in the notebook image 0 is plotted as "Minimum optimised" and image 1 as
# "Maximum optimised".
positive_term = objectives.channel("8", max_channel_clf, batch=1)
negative_term = objectives.channel("8", max_channel_clf, batch=0)
obj = positive_term - negative_term

clf_activations = render.render_vis(
    clf,
    obj,
    param_f,
    fixed_image_size=256,
    thresholds=(1024,),
)

## Compare with example images

We'll now create a figure comparing images that produce a spectrum of activation values to the optimised activations.

### Load paths to images

In [None]:
# One table of feature labels per network, read in the order
# autoencoder, triplet network, classifier.
ae_labels, tpl_labels, clf_labels = (
    pd.read_csv(csv_path)
    for csv_path in (
        "../output/resnet18_size-256_ae/version_6/feature-labels_train.csv",
        "../output/resnet18_size-256_tpl/version_6/feature-labels_train.csv",
        "../output/resnet18_size-256_clf/version_17/feature-labels_train.csv",
    )
)

### Utility functions

These will make loading, resizing, and plotting the images easier.

In [None]:
def cropresize(img, output_size, centered=True):
    """Crop an image to the aspect ratio of ``output_size``, then resize it.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to crop. Only its ``size`` attribute and its ``crop`` and
        ``resize`` methods are used.
    output_size : int or tuple of int
        Target size as ``(height, width)``; a single int gives a square output.
    centered : bool, default True
        If True the crop window is centred on the image, otherwise it is
        placed at a random offset.

    Returns
    -------
    PIL.Image.Image
        The cropped image, resized to ``output_size``.
    """
    if centered:
        crop_pct = (0.5, 0.5)
    else:
        # `np.random.rand` takes its dimensions positionally; it has no `size` keyword.
        crop_pct = np.random.rand(2)

    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    # PIL reports `size` as (width, height); reading it directly avoids relying
    # on a `shape` attribute that plain PIL images do not have.
    w, h = img.size
    hs, ws = output_size
    h_pct, w_pct = crop_pct

    # Largest window with the target aspect ratio that still fits in the image.
    scale = min(w / ws, h / hs)
    crop_h, crop_w = int(scale * hs), int(scale * ws)

    # Offset of the window inside the slack left over on each axis.
    top = int(h_pct * (h - crop_h))
    left = int(w_pct * (w - crop_w))

    # The crop box is (left, upper, right, lower) in absolute coordinates, so the
    # far edges are the offsets plus the window size; resize takes (width, height).
    return img.crop((left, top, left + crop_w, top + crop_h)).resize((ws, hs))

def make_grid(paths, ids, size=256, pad=8, nrows=3, ncols=3):
    """Load images and tile them into a single white-backed numpy array.

    Parameters
    ----------
    paths : sequence
        Image file paths, indexed by the entries of ``ids``.
    ids : iterable of int
        Indices into ``paths`` of the images to show, in display order
        (left to right, then top to bottom).
    size : int, default 256
        Side length in pixels of each grid cell, padding included.
    pad : int, default 8
        Total white padding per cell; each image is resized to ``size - pad``.
    nrows, ncols : int, default 3
        Number of rows and columns in the grid.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(size * nrows, size * ncols, 3)``.

    Raises
    ------
    ValueError
        If more images are requested than the grid has cells.
    """
    ids = list(ids)
    if len(ids) > nrows * ncols:
        raise ValueError(
            f"got {len(ids)} images for a {nrows}x{ncols} grid"
        )

    grid = np.full((size * nrows, size * ncols, 3), 255, dtype=int)

    padded_size = size - pad
    offset = pad // 2
    for i, idx in enumerate(ids):
        # Cells fill a row before moving down, so the row index advances once
        # every `ncols` images.
        row, col = divmod(i, ncols)

        # Close the file once the tile has been produced, and force three
        # channels so greyscale or RGBA files still fit the grid.
        with Image.open(paths[idx]) as img:
            tile = cropresize(img.convert("RGB"), padded_size)

        # Slice by the tile size itself so an odd `pad` cannot cause a mismatch.
        top = row * size + offset
        left = col * size + offset
        grid[top:top + padded_size, left:left + padded_size, :] = np.asarray(tile)

    return grid

def make_activations_plot(feat_values, activations, paths, n=9, d=0.2, img_size=256, inner_pad=4, outer_pad=9):
    """Build one strip comparing optimised activations with example images.

    The strip holds six square panels separated by ``outer_pad`` white pixels,
    from left to right: ``activations[0]``, the ``n`` images with the lowest
    feature values, the ``n`` images closest to ``min + delta``, the ``n``
    images closest to ``max - delta``, the ``n`` images with the highest
    feature values, and ``activations[1]``.

    Parameters
    ----------
    feat_values : numpy.ndarray
        One feature value per image for the chosen channel.
    activations : sequence of numpy.ndarray
        Two optimised images. Each is scaled by 255 and must fit a panel of
        shape ``(grid_size, grid_size, 3)`` where
        ``grid_size = img_size * sqrt(n)``.
        NOTE(review): presumably float images in [0, 1] - confirm.
    paths : sequence
        Image file paths aligned with ``feat_values``.
    n : int, default 9
        Images per example panel; must be a positive perfect square.
    d : float, default 0.2
        Fraction of the feature range used as ``delta`` for the two
        intermediate panels.
    img_size : int, default 256
        Side length in pixels of each example image cell.
    inner_pad : int, default 4
        Padding passed to ``make_grid`` for each example image.
    outer_pad : int, default 9
        White gap in pixels between neighbouring panels.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(grid_size, 6 * grid_size + 5 * outer_pad, 3)``.

    Raises
    ------
    ValueError
        If ``n`` is not a positive perfect square.
    """
    # Validate up front: n = 0 would make `sorted_ids[-n:]` select every image,
    # and a non-square n cannot fill a square grid.
    if n < 1:
        raise ValueError(f"n must be a positive perfect square, got {n}")
    n_side = int(np.sqrt(n))
    if n_side * n_side != n:
        raise ValueError(f"n must be a positive perfect square, got {n}")

    sorted_ids = np.argsort(feat_values)
    max_ids, min_ids = sorted_ids[-n:], sorted_ids[:n]

    # Images whose value lies closest to a point `delta` inside each extreme.
    lowest, highest = feat_values.min(), feat_values.max()
    delta = (highest - lowest) * d
    slightly_min_ids = np.argsort(abs(feat_values - delta - lowest))[:n]
    slightly_max_ids = np.argsort(abs(feat_values + delta - highest))[:n]

    grid_size = img_size * n_side
    plot_size = grid_size * 6 + outer_pad * 5
    plot = np.full((grid_size, plot_size, 3), 255, dtype=int)

    def example_grid(ids):
        # All example panels share the same grid geometry.
        return make_grid(paths, ids, pad=inner_pad, size=img_size, nrows=n_side, ncols=n_side)

    panels = [
        (activations[0] * 255).astype(int),
        example_grid(min_ids),
        example_grid(slightly_min_ids),
        example_grid(slightly_max_ids),
        example_grid(max_ids),
        (activations[1] * 255).astype(int),
    ]

    # Each panel starts one panel width plus one gap after the previous one.
    for position, panel in enumerate(panels):
        start = position * (grid_size + outer_pad)
        plot[:, start:start + grid_size, :] = panel

    return plot

### Make activation plot

This shows examples and activations for all 3 networks.

In [None]:
fig, axes = plt.subplots(nrows=3, figsize=(8, 4))

# One entry per network: the chosen channel's feature values, the optimised
# images, the image paths, and any per-network override of `d`.
panel_inputs = [
    (feat_values_ae[:, max_channel_ae], ae_activations[0], ae_labels.label.values, {"d": 0.3}),
    (feat_values_tpl[:, max_channel_tpl], tpl_activations[0], tpl_labels.label.values, {}),
    (feat_values_clf[:, max_channel_clf], clf_activations[0], clf_labels.label.values, {"d": 0.25}),
]
plots = [
    make_activations_plot(values, optimised, image_paths, n=4, outer_pad=128, **overrides)
    for values, optimised, image_paths, overrides in panel_inputs
]

labels = [
    "Minimum\noptimised",
    "Minimum\nexamples",
    "Negative\nexamples",
    "Positive\nexamples",
    "Maximum\nexamples",
    "Maximum\noptimised",
]
letters = ["A", "B", "C"]

# Vertical data coordinate at which the column labels are written on the last row.
vpos = 619

for row, ax in enumerate(axes):
    plot = plots[row]
    ax.imshow(plot)
    # Vertical rule through the horizontal middle of the strip.
    ax.axvline(plot.shape[1]/2, lw=2, zorder=1000, color="#a9a9a9")
    ax.text(
        -0.01, 0.5, letters[row],
        transform=ax.transAxes, color="black", va="center", ha="right",
    )
    ax.set_axis_off()

# Panels are 512 px wide (2 x 256 px images) with 128 px gaps, so each label is
# centred 256 px into its panel.
bottom_ax = axes[-1]
for column, label in enumerate(labels):
    hpos = column * (512 + 128) + 256
    bottom_ax.text(hpos, vpos, label, color="grey", va="top", ha="center")

fig.savefig("../activations.png", dpi=600, bbox_inches="tight")