In [None]:
from dataclasses import dataclass, field
from typing import Callable

import transformer_lens as tl
from plotly import subplots
from torch import Tensor
import torch
import plotly.graph_objects as go
import plotly.express as px

In [None]:
# Reload edited modules (e.g. the local `src` package imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Hooked-transformer configuration: 3 layers, 8 heads of width 8, bidirectional attention.
# NOTE(review): `config` is not referenced by any later cell shown here; the experiments
# below use an MLP from `src` instead.
config = tl.HookedTransformerConfig(
    d_vocab=12,
    n_ctx=12,
    d_model=64,
    n_layers=3,
    n_heads=8,
    d_head=8,
    act_fn="relu",
    attention_dir="bidirectional",
)

print(config)

# Defining the data distribution

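We define two tasks, SORT and REVERSE. An input has the form `[TASK, *numbers]`: a task id (0 = SORT, 1 = REVERSE) followed by 10 integers drawn uniformly from 1 to 9. Note that the current "SORT" target is not actually a sort: it is the numbers at even positions followed by those at odd positions (see `gen_task`).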

In [None]:
import numpy as np
import src
from jaxtyping import Int, Float

def gen_task(
    batch_size: int,
    force_task=None,
    n_numbers: int = 10,
) -> "tuple[Float[Tensor, 'batch in_len'], Float[Tensor, 'batch out_len']]":
    """Sample a batch of (input, target) pairs for the two-task toy problem.

    Each input row is ``[task, *numbers]`` where ``task`` is 0 (SORT) or 1 (REVERSE)
    and ``numbers`` are ``n_numbers`` integers drawn uniformly from 1 to 9 inclusive.
    Targets:

    * task 0 ("SORT"): despite the name this is *not* a sort -- it is the numbers at
      even positions followed by the numbers at odd positions
      (e.g. ``[a, b, c, d] -> [a, c, b, d]``);
    * task 1 (REVERSE): the numbers in reverse order.

    Args:
        batch_size: number of rows to sample.
        force_task: if given (0 or 1), every row uses this task instead of a random one.
        n_numbers: length of the number sequence (default 10, i.e. 11 input columns).

    Returns:
        ``(x, y)`` as float32 tensors of shape ``(batch_size, n_numbers + 1)`` and
        ``(batch_size, n_numbers)``. Values are whole numbers stored as floats because
        the model is trained as an MSE regression and predictions are rounded.

    Raises:
        ValueError: if ``force_task`` is not None, 0 or 1.
    """
    if force_task is not None and force_task not in (0, 1):
        raise ValueError(f"force_task must be None, 0 or 1, got {force_task!r}")

    numbers = np.random.randint(1, 10, size=(batch_size, n_numbers))
    # Alternative targets tried previously: np.sort(numbers, axis=1),
    # np.cumsum(numbers, axis=1) % 10 and np.cumprod(numbers, axis=1) % 10.
    y_sort = np.concatenate([numbers[:, ::2], numbers[:, 1::2]], axis=1)
    y_reverse = np.flip(numbers, axis=1)

    task = np.random.randint(0, 2, size=(batch_size, 1))
    if force_task is not None:
        task[:] = force_task
    # Per-row selection: (batch, 1) task column broadcasts against (batch, n_numbers).
    y = np.where(task == 0, y_sort, y_reverse)
    x = np.concatenate([task, numbers], axis=1)

    return torch.tensor(x).float(), torch.tensor(y).float()

@torch.inference_mode()
def get_accuracy(model, n: int = 1000, task=None):
    """Per-token accuracy of ``model`` on each task, measured on a fresh batch.

    Predictions are rounded to the nearest integer and compared element-wise with the
    targets, so the result is the fraction of *positions* predicted exactly (not the
    fraction of fully-correct sequences).

    Args:
        model: callable mapping inputs ``(n, in_len)`` to predictions shaped like the targets.
        n: number of examples to sample.
        task: batch generator ``task(n) -> (x, y)`` whose first input column is the task
            id (0 = SORT, 1 = REVERSE); defaults to ``gen_task``, looked up at call time.

    Returns:
        ``(acc_sort, acc_reverse)`` as Python floats. A task with no examples in the
        batch yields ``nan`` (mean over an empty selection).
    """
    if task is None:
        task = gen_task
    xs, ys = task(n)
    correct = torch.round(model(xs)) == ys
    task_ids = xs[:, 0]
    acc_sort = correct[task_ids == 0].float().mean()
    acc_reverse = correct[task_ids == 1].float().mean()
    return acc_sort.item(), acc_reverse.item()



In [None]:
@dataclass
class Stats:
    """Training history: one entry per logged optimisation step.

    ``examples[i]`` is the cumulative number of training examples seen after step ``i``;
    ``losses`` / ``acc_sorts`` / ``acc_reverses`` hold the metrics recorded at that step.
    """

    losses: list[float] = field(default_factory=list)
    acc_sorts: list[float] = field(default_factory=list)
    acc_reverses: list[float] = field(default_factory=list)
    examples: list[int] = field(default_factory=list)
    n_examples: int = 0  # running total of examples seen so far

    def log(self, n_examples, loss, acc_sort=None, acc_reverse=None):
        """Record one step that consumed ``n_examples`` examples.

        Accuracies are expensive to compute, so callers may omit them; a missing value
        repeats the previous one (``nan`` if nothing has been logged yet).
        """
        if acc_sort is None:
            acc_sort = self.acc_sorts[-1] if self.acc_sorts else float("nan")
        if acc_reverse is None:
            acc_reverse = self.acc_reverses[-1] if self.acc_reverses else float("nan")

        self.n_examples += n_examples
        self.examples.append(self.n_examples)
        self.losses.append(loss)
        self.acc_sorts.append(acc_sort)
        self.acc_reverses.append(acc_reverse)

    def plot(self, smooth_window: int = 100, **kwargs):
        """Log-scale plot of the moving-averaged loss and of 1 - accuracy for each task.

        Args:
            smooth_window: moving-average window in logged steps (clipped to the history
                length). Each smoothed point is drawn at the last step of its window.
            **kwargs: forwarded to ``fig.update_layout``.
        """
        fig = go.Figure()
        series = [
            (self.losses, "Loss", False),
            (self.acc_sorts, "1 - Accuracy for SORT", True),
            (self.acc_reverses, "1 - Accuracy for REVERSE", True),
        ]
        for values, label, is_accuracy in series:
            xs, ys = self._moving_average(self.examples, values, smooth_window)
            fig.add_trace(go.Scatter(x=xs, y=1 - ys if is_accuracy else ys, name=label))
        fig.update_layout(
            title="Training metrics",
            xaxis_title="Number of examples",
            yaxis_type="log",
            **kwargs,
        )
        fig.show()

    @staticmethod
    def _moving_average(xs, ys, window):
        """Valid-mode moving average of ``ys`` with ``xs`` aligned to each window's last point."""
        ys = np.asarray(ys, dtype=float)
        xs = np.asarray(xs)
        if ys.size == 0:
            return xs[:0], ys
        window = max(1, min(window, ys.size))
        smoothed = np.convolve(ys, np.ones(window) / window, mode="valid")
        # "valid" drops window-1 leading points, so drop the same number of x values.
        return xs[window - 1:], smoothed

In [None]:
# Training loop
def train(model, optim, task: Callable[[int], tuple]=gen_task, epochs=1000, batch_size=1000,
          stats: Stats = None, extra_loss = lambda _: 0, log_every: int = 100):
    """Train ``model`` with MSE regression on batches sampled from ``task``.

    Args:
        model: module mapping inputs to predictions shaped like the targets.
        optim: optimizer over the parameters to train; its gradients are reset every step.
        task: batch generator ``task(batch_size) -> (x, y)``.
        epochs: number of optimisation steps (one fresh batch per step).
        batch_size: examples per step.
        stats: ``Stats`` history to append to; a new one is created if None.
        extra_loss: callable ``extra_loss(model)`` added to the MSE loss (e.g. a
            regulariser such as ``penalise_masks``).
        log_every: every this many steps, evaluate accuracies with ``get_accuracy`` and
            print a progress line; other steps reuse the last logged accuracies.

    Raises:
        ValueError: if ``log_every`` < 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    if stats is None:
        stats = Stats()
    loss_fn = torch.nn.MSELoss()
    for epoch in range(epochs):
        x, y = task(batch_size)
        y_pred = model(x)
        loss = loss_fn(y_pred, y) + extra_loss(model)
        # Zero the gradients of the optimizer we were given (not a notebook-level global),
        # otherwise gradients accumulate across steps.
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Logs
        loss_value = loss.item()
        if epoch % log_every == 0:
            acc_sort, acc_reverse = get_accuracy(model)
            stats.log(batch_size, loss_value, acc_sort, acc_reverse)
            print(f"Epoch {epoch} - Loss: {loss_value:.2f} - Accuracy for SORT: {acc_sort:.2f} - Accuracy for REVERSE: {acc_reverse:.2f}")
        else:
            stats.log(batch_size, loss_value)

In [None]:
# Baseline model trained on both tasks.
# NOTE(review): the positional widths presumably read (in=11 = task id + 10 numbers,
# hidden 64/64/64, out=10) -- confirm against src.MLP's signature.
stats = Stats()
model = src.MLP(11, 64, 64, 64, 10, activation=torch.nn.ReLU)
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
# Train the baseline on the mixed SORT/REVERSE distribution (train's default task), logging into `stats`.
train(model=model, optim=optimizer, stats=stats)

In [None]:
# Smooth the loss
def smooth(x, kernel_size: int = 100):
    """Moving average of the 1-D sequence ``x`` (``np.convolve`` "valid" mode).

    The window is ``k = min(kernel_size, len(x))`` (at least 1), so the result has
    ``len(x) - k + 1`` values: a history shorter than ``kernel_size`` yields its overall
    mean instead of a wrongly scaled series, and an empty input yields an empty array.
    """
    values = np.asarray(x, dtype=float)
    if values.size == 0:
        return values
    k = max(1, min(kernel_size, values.size))
    return np.convolve(values, np.ones(k) / k, mode="valid")

# logplot of all the metrics: loss, 1-sort accuracy, 1-reverse accuracy, logscale



In [None]:
# Qualitative check: print input, target and raw (unrounded) prediction for 10 fresh samples.
xs, ys = gen_task(10)
y_preds = model(xs)
for x, y, y_pred in zip(xs, ys, y_preds):
    x, y = x.int().tolist(), y.int().tolist()
    print(f"Input:  {x}")
    print(f"Target: {y}")
    print(f"Pred:   [{', '.join(f'{v:.1f}' for v in y_pred.tolist())}]")
    print()
    


In [None]:
# Show the weight matrix (left column) and bias vector (right column) of every Linear layer.
# NOTE(review): assumes `model` is iterable layer by layer (e.g. an nn.Sequential).
linears = [layer for layer in model if isinstance(layer, torch.nn.Linear)]
n_layers = len(linears)
assert linears, "model has no Linear layers to plot"

fig = subplots.make_subplots(
    rows=n_layers,
    cols=2,
    # make_subplots assigns titles row-major, so provide one title per (row, col) panel.
    subplot_titles=[f"Layer {i} {kind}" for i in range(n_layers) for kind in ("weight", "bias")],
)

# One colour scale shared by all panels (weights and biases), symmetric around 0 so
# that white in RdBu means "zero".
max_abs = max(
    param.detach().abs().max().item()
    for layer in linears
    for param in (layer.weight, layer.bias)
    if param is not None
) or 1.0

for i, layer in enumerate(linears):
    fig.add_trace(
        go.Heatmap(z=layer.weight.detach().numpy(), coloraxis="coloraxis"),
        row=i + 1, col=1,
    )
    if layer.bias is not None:
        fig.add_trace(
            go.Heatmap(z=layer.bias.detach().numpy().reshape(-1, 1), coloraxis="coloraxis"),
            row=i + 1, col=2,
        )

fig.update_layout(
    height=2000,
    width=1000,
    coloraxis=dict(colorscale="RdBu", cmin=-max_abs, cmax=max_abs),
)
fig.show()


In [None]:
# Inspect the architecture; later cells iterate over / index into `model` layer by layer.
model

In [None]:
from copy import deepcopy


class MaskedLinear(torch.nn.Linear):
    """``nn.Linear`` whose weight is multiplied element-wise by a learnable mask.

    The effective weight is ``mask.clip(0, 1) * weight``: mask entries <= 0 remove a
    connection and entries >= 1 keep it unchanged. The mask is initialised uniformly in
    [0, 1). The bias is not masked.
    """

    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device=device, dtype=dtype)
        # Same shape, device and dtype as the weight it multiplies.
        self.mask = torch.nn.Parameter(
            torch.rand(self.weight.shape, device=self.weight.device, dtype=self.weight.dtype)
        )

    def forward(self, x):
        return torch.nn.functional.linear(x, self.mask.clip(0, 1) * self.weight, self.bias)

    @classmethod
    def from_linear(cls, linear):
        """Build a MaskedLinear initialised from ``linear``'s weight and bias.

        The parameters are copied, not aliased, so training or editing the masked layer
        never modifies ``linear``. Device and dtype follow ``linear.weight``.
        """
        masked = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            device=linear.weight.device,
            dtype=linear.weight.dtype,
        )
        with torch.no_grad():
            masked.weight.copy_(linear.weight)
            if linear.bias is not None:
                masked.bias.copy_(linear.bias)
        return masked
    
def penalise_masks(model, coef=0.01):
    """Return ``coef * sum(|m * (1 - m)|)`` over every parameter whose name contains "mask".

    The term vanishes at m = 0 and m = 1, so minimising it pushes mask entries towards
    binary values. With no mask parameters the result is ``0 * coef`` (a plain number).
    """
    total = sum(
        (
            (mask * (1 - mask)).abs().sum()
            for param_name, mask in model.named_parameters()
            if "mask" in param_name
        ),
        0,
    )
    return total * coef


# Build the masked copy: deep-copy the trained model, then replace each nn.Linear
# (by position) with a MaskedLinear initialised from that layer's weight and bias.
# NOTE(review): assumes `model` is an nn.Sequential-like container (iterable, index-assignable).
# Alternative regulariser tried earlier: src.L1WeightDecay(masked_model, 0.0001, name_filter="mask").
masked_model = deepcopy(model)
for i, layer in enumerate(masked_model):
    if not isinstance(layer, torch.nn.Linear):
        continue
    masked_model[i] = MaskedLinear.from_linear(layer)

masked_model

In [None]:
# 1. Freeze everything except the masks: only parameters whose name contains "mask"
#    keep requires_grad=True.
for name, param in masked_model.named_parameters():
    param.requires_grad = "mask" in name
    print(name, param.requires_grad)

# 2. Optimise only the masks. Which task they are trained on is chosen by `force_task`
#    in the training cell below.
optimizer = torch.optim.AdamW(
    [param for param in masked_model.parameters() if param.requires_grad],
    weight_decay=0,
)
stats = Stats()


In [None]:
from functools import partial

# Training loop: only the masks are trainable (see the freezing cell above). Train on
# REVERSE examples only (force_task=1), adding a penalty that is zero when every mask
# entry is exactly 0 or 1.
train(
    masked_model,
    optimizer,
    task=lambda bs: gen_task(bs, force_task=1),
    extra_loss=partial(penalise_masks, coef=0.001),
    stats=stats,
    epochs=3_000,
)

In [None]:
stats.plot()
# NOTE: `plot_masks` is defined in a later cell, so calling it here raised NameError on a
# fresh kernel (Restart & Run All). Plot the masks with `plot_masks(masked_model)` once
# that cell has been run.

In [None]:
# Snapshot of the current masks, labelled "task 0".
# NOTE(review): presumably meant to be run manually after a mask-training run with
# force_task=0 -- on a straight Run All the training cell above uses force_task=1, and
# this snapshot is identical to `masked_model_task_1` below.
masked_model_task_0 = deepcopy(masked_model)

In [None]:
# Snapshot of the current masks, labelled "task 1" (matches the force_task=1 training cell above).
# NOTE(review): presumably meant to be run after a separate force_task=1 run -- on a straight
# Run All it is identical to `masked_model_task_0`.
masked_model_task_1 = deepcopy(masked_model)

In [None]:
# plot the masks
def plot_masks(*masked_models):
    """Heatmaps of every mask parameter: one row per mask, one column per model.

    Raw parameter values are shown (not clipped to the effective [0, 1] range) on a
    fixed [-1, 1] RdBu colour scale. Mask names are taken from the first model.
    """
    per_model = []
    for masked in masked_models:
        per_model.append(
            [(name, p.detach().numpy()) for name, p in masked.named_parameters() if "mask" in name]
        )

    first = per_model[0]
    # Titles are assigned row-major: each mask name repeated once per model column.
    titles = [name for name, _ in first for _ in masked_models]
    fig = subplots.make_subplots(rows=len(first), cols=len(per_model), subplot_titles=titles)

    for col, masks in enumerate(per_model, start=1):
        for row, (_, mask) in enumerate(masks, start=1):
            heatmap = go.Heatmap(z=mask, colorscale="RdBu", zmin=-1, zmax=1)
            fig.add_trace(heatmap, row=row, col=col)

    fig.update_layout(height=2000, width=1000)
    fig.show()
    
# Side by side: left column = "task 0" snapshot, right column = "task 1" snapshot.
plot_masks(masked_model_task_0, masked_model_task_1)


In [None]:
def compare_masks(masked_model_1, masked_model_2):
    """Visually compare the effective masks of two models with the same architecture.

    Masks are matched by position in ``named_parameters()`` and clipped to [0, 1]
    (the range ``MaskedLinear.forward`` actually uses). Each mask gets one row:

    * left: heatmap of the intersection ``min(A, B)``;
    * right: RGB overlay with A in the red channel and B in the blue channel, so
      magenta = kept by both, red = only A, blue = only B, black = neither.

    Raises:
        ValueError: if the first model has no mask parameters, or the two models'
            masks differ in number or shape.
    """
    def effective_masks(model):
        return [
            (name, param.detach().clamp(0, 1).numpy())
            for name, param in model.named_parameters()
            if "mask" in name
        ]

    masks_a = effective_masks(masked_model_1)
    masks_b = effective_masks(masked_model_2)
    if not masks_a:
        raise ValueError("masked_model_1 has no mask parameters")
    if len(masks_a) != len(masks_b):
        raise ValueError(f"models have different numbers of masks ({len(masks_a)} vs {len(masks_b)})")
    for (name, mask_a), (_, mask_b) in zip(masks_a, masks_b):
        if mask_a.shape != mask_b.shape:
            raise ValueError(f"mask {name!r} shapes differ: {mask_a.shape} vs {mask_b.shape}")

    fig = subplots.make_subplots(
        rows=len(masks_a),
        cols=2,
        subplot_titles=[
            name + suffix
            for name, _ in masks_a
            for suffix in (" - min(A, B)", " - A (red) / B (blue)")
        ],
    )

    for row, ((_, mask_a), (_, mask_b)) in enumerate(zip(masks_a, masks_b), start=1):
        fig.add_trace(go.Heatmap(
            z=np.minimum(mask_a, mask_b),
            colorscale="RdBu",
            zmin=-1,
            zmax=1,
        ), row=row, col=1)
        # Masks are in [0, 1] after clipping, so scaling by 255 gives valid pixel values.
        rgb = np.stack([mask_a, np.zeros_like(mask_a), mask_b], axis=-1)
        fig.add_trace(go.Image(z=np.round(rgb * 255).astype(np.uint8)), row=row, col=2)

    fig.update_layout(height=2000, width=1000)
    fig.show()
    
# `binary_masked` is only created in the next cell, so comparing against it here raised
# NameError on a fresh kernel. Compare the two task snapshots instead; to inspect the
# binarised masks, run `compare_masks(masked_model, binary_masked)` after that cell.
compare_masks(masked_model_task_0, masked_model_task_1)
    

In [None]:
# Binarize the masks: replace every mask entry m by a Bernoulli sample with
# p = clip(m, 0, 1) (the effective mask), then measure how much accuracy survives.
# Deterministic alternative: (mask.clip(0, 1) > 0.5).float().
# NOTE(review): sampling is unseeded, so the accuracy below varies between runs.
binary_masked = deepcopy(masked_model)

for name, param in binary_masked.named_parameters():
    if "mask" not in name:
        continue
    param.data = torch.bernoulli(param.data.clip(0, 1))

plot_masks(binary_masked)
# Check accuracy
get_accuracy(binary_masked)