In [1]:
import torch
import matplotlib.pyplot as plt
import imageio
import numpy as np
import os
torch.manual_seed(1337)

data_dir = '/n/projects/kc2819/projects/ChotaLLM/data/'
assets_dir = '/n/projects/kc2819/projects/ChotaLLM/assets/'

## Self attention

In [2]:
'''

Self attention -----> away for tokens to interact with each other in a sequence. 

Lets say we have a batch (B) of sequences, with each context lentgh or block size (B) of tokens, and each token has a dimension of features, or lets call them channels (C).

B,T,C 

Lets say, ith token in T, wants to know the relationship with all other tokens in T (including itself), one naive way is to take the mean of all channel infrmation of all tokens in T, 
such a way that the information propagates from the 0 to i-1 -> ith token. This summerize the context to current token with context of all other tokens in T in the past.

This is a extreamly simple way to do it, but it is not enough, because it is not considering the relationship between the tokens in the context, but lets worry about it later.


'''

B,T,C = 4,8,2 # batch, tokens, channels
x = torch.randn(B,T,C)
print(x.shape)

torch.Size([4, 8, 2])


In [3]:
# x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(x.size(0)):
    for t in range(x.size(1)):
        prev = x[b,:t+1]
        xbow[b,t] = torch.mean(prev, dim=0)

In [4]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [12]:
xbow[0]

tensor([[ 0.1275, -0.0560],
        [ 0.4795, -0.3036],
        [ 0.6689,  0.3372],
        [ 0.6057,  0.3369],
        [ 0.3943,  0.1296],
        [ 0.4821, -0.0581],
        [ 0.5984, -0.3702],
        [ 0.5895, -0.3438]])

In [30]:
B, T, C = 10, 8, 2
x = torch.randn(B, T, C)
xbow = torch.zeros((B, T, C))


def plot_frame(x, xbow, b, t, frame_number, save_dir, vmin, vmax):
    fig, ax = plt.subplots(2, 1, figsize=(10, 10))
    axes = ax.flatten()

    
    axes[0].set_title(f'Original Data (b={b}, t={t})')
    cax0 = axes[0].imshow(x[b].detach().numpy().T, aspect='auto', cmap='rainbow', vmin=vmin, vmax=vmax)
    fig.colorbar(cax0, ax=axes[0])
    for i in range(x.size(1)):
        for j in range(x.size(2)):
            axes[0].text(i, j, f'{x[b, i, j]:.2f}', ha='center', va='center', color='white' if x[b, i, j] < 0 else 'black')

    
    axes[1].set_title(f'Attention Result (b={b}, t={t})')
    cax1 = axes[1].imshow(xbow[b].detach().numpy().T, aspect='auto', cmap='rainbow', vmin=vmin, vmax=vmax)
    fig.colorbar(cax1, ax=axes[1])
    for i in range(x.size(1)):
        for j in range(x.size(2)):
            axes[1].text(i, j, f'{xbow[b, i, j]:.2f}', ha='center', va='center', color='white' if xbow[b, i, j] < 0 else 'black')

    fig.tight_layout()
    fig.savefig(f"{save_dir}/frame_{frame_number}.png")
    plt.close(fig)

    for ax in axes:
        ax.set_yticks(range(x.size(2)))
        ax.set_yticklabels([f'{i}' for i in range(x.size(2))], fontsize=12, va='center', fontweight='bold', fontfamily='serif')
        ax.set_xticks(range(x.size(1)))
        for s in ['top', 'bottom', 'left', 'right']:
            ax.spines[s].set_visible(False)

img_dir = assets_dir + "/images"
gif_dir = assets_dir + "/gifs"
os.makedirs(img_dir, exist_ok=True)
os.makedirs(gif_dir, exist_ok=True)

vmin, vmax = -1, 1
frame_number = 0
for b in range(x.size(0)):
    for t in range(x.size(1)):
        prev = x[b, :t+1]
        xbow[b, t] = torch.mean(prev, dim=0)
        plot_frame(x, xbow, b, t, frame_number, save_dir=img_dir, vmin=vmin, vmax=vmax)
        frame_number += 1

with imageio.get_writer(f"{gif_dir}/attention.gif", mode='I', duration=0.1) as writer:
    for i in range(frame_number):
        image = imageio.imread(f"{img_dir}/frame_{i}.png")
        writer.append_data(image)
        os.remove(f"{img_dir}/frame_{i}.png")


  image = imageio.imread(f"{img_dir}/frame_{i}.png")
