In [None]:
from train import get_model

In [None]:
import yaml
import torch
import os
import train
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Training-run output directory to inspect; the cells below read its
# "config.yaml" (model configuration) and "best.checkpoint" (saved weights).
directory = "output/gaussian_cnn_4l_16h_alpha_pos_init"

In [None]:
# Rebuild the model with the configuration it was trained with.
# yaml.safe_load: PyYAML >= 6 rejects yaml.load(stream) without an explicit Loader,
# and a plain config file should not need arbitrary Python object construction.
# NOTE(review): if config.yaml turns out to use python/* tags, switch to
# yaml.load(config_file, Loader=yaml.FullLoader) instead.
with open(os.path.join(directory, "config.yaml"), "r") as config_file:
    config = yaml.safe_load(config_file)
# NOTE(review): get_model presumably reads the module-level `train.config`, which is
# why it is set before the call — confirm in train.py.
train.config = config
model = get_model('cuda:0')


In [None]:
# torch.load unpickles the file, so only load checkpoints you trust.
# map_location="cpu": load_state_dict copies these tensors into the model's existing
# parameters on whatever device they live on, so there is no need to restore them
# onto the GPU they were saved from (this also lets the cell run on CPU-only hosts).
checkpoint_data = torch.load(os.path.join(directory, "best.checkpoint"), map_location="cpu")

In [None]:
# Copy the saved weights into the freshly built model; strict=True (the default)
# makes any missing or unexpected key raise instead of being silently ignored.
model.load_state_dict(state_dict=checkpoint_data['model_state_dict'], strict=True)

In [None]:
# Collect each encoder layer's positional-attention parameters, in layer order, as
# detached CPU tensors so the plotting cells can use them without autograd.
self_attn_modules = [encoder_layer.attention.self for encoder_layer in model.encoder.layer]
alpha_per_layer = [module.attention_alpha.detach().cpu() for module in self_attn_modules]
centers_per_layer = [module.attention_centers.detach().cpu() for module in self_attn_modules]

In [None]:
def softmax(X):
    """Softmax taken jointly over every dimension except the last one.

    The last dimension indexes independent distributions (e.g. attention heads):
    for every ``h``, ``Y[..., h]`` is non-negative and sums to 1 over all leading
    dimensions.

    Args:
        X: floating-point tensor of rank >= 1. It is not modified.

    Returns:
        Tensor with the same shape as ``X``.
    """
    n_last = X.shape[-1]
    # Subtract each slice's maximum for numerical stability. Done out-of-place so
    # the caller's tensor is left untouched; reshape (not view) also accepts
    # non-contiguous inputs.
    shifted = X - X.reshape(-1, n_last).max(dim=0).values
    exp_X = shifted.exp()
    normalizer = exp_X.reshape(-1, n_last).sum(dim=0)
    # normalizer has shape (n_last,), so it broadcasts against an input of any rank.
    return exp_X / normalizer

# Largest window side length (in pixels) for which relative offsets are precomputed.
MAX_WIDTH_HEIGHT = 50
range_ = torch.arange(MAX_WIDTH_HEIGHT)
# grid[k, l] == (k, l) for every pixel ('ij' order); shape (W, W, 2), int64.
grid = torch.stack(torch.meshgrid([range_, range_]), dim=-1)
# relative_indices[i, j, k, l] == grid[k, l] - grid[i, j]: offset from query pixel
# (i, j) to key pixel (k, l); shape (W, W, W, W, 2).
relative_indices = grid[None, None, :, :, :] - grid[:, :, None, None, :]
# Per (query, key) pair the float32 feature vector [dx, dy, dx**2, dy**2, 1, 1];
# shape (W, W, W, W, 6).
R = torch.cat(
    (relative_indices, relative_indices.pow(2), torch.ones_like(relative_indices)),
    dim=-1,
).float()

def plot_attention_positions(relative_positions, alphas, width=20, ax=None, levels=(0.1, 0.4)):
    """Draw contour lines of each head's Gaussian positional attention pattern.

    For a query at the centre of a ``width`` x ``width`` window, each key pixel at
    offset ``d`` from the centre gets score ``-|d - c_h|**2 / alpha_h`` for head
    ``h`` (``c_h`` = that head's centre). Scores are normalized per head over the
    whole window with ``softmax``, and one contour set per head is drawn in colour
    ``C{h}``. The query pixel is marked with a red dot.

    Args:
        relative_positions: tensor of shape (n_head, 2) with each head's centre
            offset (2 columns so the concatenated targets match R's 6 features).
        alphas: 1-D tensor with one positive scale per head; larger values give a
            wider attention pattern.
        width: side length of the plotted window, between 2 and R's extent.
        ax: matplotlib Axes to draw on; a new figure is created when None.
        levels: probability levels at which the contours are drawn.

    Returns:
        The Axes that was drawn on.

    Raises:
        ValueError: if ``width`` is outside ``[2, R.shape[2]]`` (slicing R with a
            larger width would silently truncate the window and shift its centre).
    """
    max_width = R.shape[2]
    if not 2 <= width <= max_width:
        raise ValueError(f"width must be in [2, {max_width}], got {width}")

    centre = width // 2
    # Features [dx, dy, dx**2, dy**2, 1, 1] of every window pixel relative to the
    # centre pixel; shape (width, width, 6).
    relative_encoding_from_center = R[centre, centre, :width, :width]
    # [-2c, 1, 1, c**2]: its dot product with the features above is |d - c|**2.
    targets = torch.cat([-2 * relative_positions, torch.ones_like(relative_positions), relative_positions ** 2], dim=-1)

    attention_scores = torch.einsum('ijd,hd->ijh', [relative_encoding_from_center, targets])
    attention_scores /= -alphas.view(1, 1, -1)  # -|d - c|**2 / alpha, per head
    attention_probs = softmax(attention_scores)

    if ax is None:
        _, ax = plt.subplots()

    for head in range(len(alphas)):
        ax.contour(attention_probs[:, :, head], levels=list(levels), colors=f"C{head}")

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(aspect=1)
    ax.scatter([centre], [centre], c='r', zorder=2)
    return ax
    
# Example for a single layer: plot_attention_positions(centers_per_layer[0], alpha_per_layer[0].exp(), width=20)

In [None]:
# One panel per encoder layer. The layer count comes from the extracted per-layer
# parameters; the name `layers` used previously is not defined in this notebook.
n_layers = len(alpha_per_layer)
# squeeze=False keeps `axes` 2-D, so indexing also works when there is one layer.
fig, axes = plt.subplots(1, n_layers, figsize=(24, 6), squeeze=False)

for i, ax in enumerate(axes[0]):
    # NOTE(review): alpha appears to be stored in log space (hence .exp()) — confirm
    # against the attention module's definition.
    plot_attention_positions(centers_per_layer[i], alpha_per_layer[i].exp(), width=32, ax=ax)
    ax.set_title(f"Layer {i}")