# reproduce figure 

objective: THE EFFECTS OF PRETRAINING TASK DIVERSITY ON IN-CONTEXT LEARNING OF RIDGE REGRESSION

1. dataset construction

In [31]:
# import libs
import random
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [60]:
from omegaconf import OmegaConf
Args = {
    'dataset' : {
        'sigma'     : 0.1,
        'num_tasks' : 32,
        'dims'      : 1,
        'k'         : 16, # incontext sample
        'batch_size': 16,
    },
    'model' : {
        'emb_dims'  : 64,
    },
    'optim' : {
        'epochs' : int(1e4),
        'lr'     : 3e-5,
    }
}
args = OmegaConf.structured(Args)

In [61]:
# dataset
from einops import einsum
class Dataset:
    def __init__(self, args):
        assert args.dataset.dims == 1
        self.task_vectors = args.dataset.sigma * torch.randn(
            args.dataset.num_tasks, args.dataset.dims
        )
        self.dims = args.dataset.dims
        self.k = args.dataset.k

        self.eval_task_vector = 5 * args.dataset.sigma * torch.randn(
            args.dataset.dims
        )

    def __getitem__(self, index):
        # generate toy dataset
        x = torch.randn(self.k, self.dims)
        y = x @ self.task_vectors[index]
        label = y[-1]
        y[-1] = 0
        
        # organize 
        input = torch.empty(self.k, 2*self.dims)
        input[:, :self.dims] = x
        input[:, self.dims:] = y[:, None]
        return input, label

    def __len__(self):
        return len(self.task_vectors)

ds = Dataset(args)
dl = torch.utils.data.DataLoader(
    ds, batch_size=args.dataset.batch_size, pin_memory=True, num_workers=16
)

In [62]:
# model
class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.embs = args.model.emb_dims
        self.emb = nn.Linear(2*args.dataset.dims, self.embs, bias=False)
        self.attn = Attention(self.embs)
        self.out = nn.Linear(self.embs, args.dataset.dims, bias=False)

    def forward(self, input):
        h = self.emb(input)
        h = self.attn(h)
        h = self.out(h)
        return h

class Attention(nn.Module):
    def __init__(self, embs):
        super().__init__()
        self.to_q = nn.Linear(embs, embs, bias=False)
        self.to_k = nn.Linear(embs, embs, bias=False)
        self.to_v = nn.Linear(embs, embs, bias=False)
        # self.to_o = nn.Linear(embs, embs, bias=False)

    def forward(self, x, is_causal=True):
        # x : |batch_size, seq_len, hid_dim|
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # attention mask
        L, S = q.size(-2), k.size(-2)
        attn_bias = torch.zeros(L, S, dtype=q.dtype)
        if is_causal:
            temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(q.device, q.dtype)

        # attention calc.
        attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_output = attn_weight @ v

        return attn_output

model = Model(args)

In [64]:
from tqdm import tqdm
device='cuda'

def eval(ds, num_eval_samples=512):
    x = torch.randn(num_eval_samples, ds.k, ds.dims)
    y = x @ ds.eval_task_vector
    label = y[:, -1, :]
    
    # organize 
    input = torch.empty(num_eval_samples, ds.k, 2*ds.dims)
    input[num_eval_samples, :, :ds.dims] = x
    input[num_eval_samples, :, ds.dims:] = y[num_eval_samples, :, None]
    return input, label

# train 
optim = torch.optim.Adam(model.parameters(), lr=args.optim.lr)
model = model.to(device)
loss_traj = []
val_loss_traj = []
for epoch in tqdm(range(args.optim.epochs)):
    model.train()
    for input, label in dl:
        output = model(input.to(device))[:, -1].squeeze()
        loss = F.mse_loss(output, label.to(device))
        loss.backward()
        loss_traj.append(loss.item())

    # validation 
    if epoch % 100 and epoch != 0:
        input, label = eval(ds)
        input, label = input.to(device), label.to(device)
        output = model(input)[:, -1].squeeze()
        loss = F.mse_loss(output, label.to(device))
        val_loss_traj.append(loss.item())
        assert False

  0%|          | 1/10000 [00:11<31:41:48, 11.41s/it]


IndexError: too many indices for tensor of dimension 2