## In-context learning with Transformers

---

Paper: https://arxiv.org/pdf/2208.01066.pdf

Official Repository: https://github.com/dtsip/in-context-learning

In [None]:
! pip install transformers

In [None]:
! pip install wandb

In [None]:
import torch
import torch.nn as nn
import random
import numpy as np
import wandb

from tqdm.notebook import tqdm
from transformers import GPT2Model, GPT2Config
from google.colab import drive

# fix seeds for reproducibility
def set_random_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    # CPU-side generators: stdlib, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only touched when a GPU is actually present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every generator before any other work so the notebook is repeatable.
set_random_seed()

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Mount Google Drive so files under /content/drive are reachable.
drive.mount('/content/drive')

Mounted at /content/drive


### Model

In [None]:
class InContextModel(nn.Module):
    """GPT-2 backbone that predicts y for each x in an interleaved prompt.

    The prompt is the sequence x_1, y_1, x_2, y_2, ...; every token is
    projected to the embedding width, passed through GPT-2, and mapped to a
    scalar. Predictions are read at the x-token positions.

    Args:
        n_dims: Width of each x vector (input features of ``in_fc``).
        n_emb: GPT-2 embedding width.
        n_layer: Number of transformer layers.
        n_head: Number of attention heads.
    """

    def __init__(self, n_dims, n_emb=256, n_layer=12, n_head=4):
        super().__init__()

        # All dropout probabilities are set to zero.
        config = GPT2Config(
            n_embd=n_emb,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
        )

        self.in_fc = nn.Linear(n_dims, n_emb)
        self.gpt = GPT2Model(config)
        self.out_fc = nn.Linear(n_emb, 1)

    @staticmethod
    def _combine(xs, ys):
        """Interleave xs and ys into one token sequence.

        Args:
            xs: Tensor of shape (batch, points, dim).
            ys: Tensor with batch * points elements (e.g. (batch, points, 1)).

        Returns:
            Tensor of shape (batch, 2 * points, dim) laid out as
            x_1, y_1, x_2, y_2, ... where each y token holds the value in
            its first coordinate and zeros elsewhere.
        """
        batch, points, dim = xs.shape
        # Create the padding in ys' own dtype so half/bfloat16 inputs are not
        # silently promoted to float32 by the concatenation.
        padding = torch.zeros(
            batch, points, dim - 1, dtype=ys.dtype, device=ys.device
        )
        ys_wide = torch.cat((ys.reshape(batch, points, 1), padding), dim=2)
        # (batch, points, 2, dim) -> (batch, 2 * points, dim)
        return torch.stack((xs, ys_wide), dim=2).view(batch, -1, dim)

    def forward(self, xs, ys, inds=None):
        """Return the scalar prediction at each requested x position.

        Args:
            xs: Tensor of shape (batch, points, dim).
            ys: Targets matching xs, one scalar per point.
            inds: Optional indices (sequence or tensor) of the points whose
                predictions are returned; ``None`` returns all of them.

        Returns:
            Tensor of shape (batch, points), or (batch, len(inds)) when
            ``inds`` is given.
        """
        zs = self._combine(xs, ys)
        hidden = self.gpt(inputs_embeds=self.in_fc(zs)).last_hidden_state
        # Even positions are the x tokens; keep the single output channel.
        predictions = self.out_fc(hidden)[:, ::2, 0]

        if inds is None:
            return predictions
        # as_tensor avoids the copy/UserWarning torch.tensor() emits for
        # tensor inputs and puts the index on the predictions' device.
        inds = torch.as_tensor(inds, device=predictions.device)
        return predictions[:, inds]

### Sample generation

In [None]:
class LinearSampler():
    """Batch of random linear functions f(x) = x @ w.

    Attributes are documented inline; one weight vector is drawn per batch
    element from a standard normal distribution.
    """

    def __init__(self, batch, dim):
        # w: (batch, dim, 1) standard-normal weights.
        weight_shape = (batch, dim, 1)
        self.w = torch.randn(*weight_shape)

    def eval(self, xs):
        """Apply each batch element's linear map to its points.

        Args:
            xs: Tensor of shape (batch, points, dim).

        Returns:
            Tensor of shape (batch, points, 1) on xs' device.
        """
        weights = self.w.to(xs.device)
        return torch.matmul(xs, weights)  # batch x points x 1

class LinearSparseSampler():
    """Batch of random sparse linear functions f(x) = x @ w.

    Each batch element gets a standard-normal weight vector in which all but
    ``k`` randomly chosen coordinates are set to zero.

    Args:
        batch: Number of independent functions.
        dim: Input dimensionality.
        k: Number of non-zero coordinates kept per function. Values of
            ``k >= dim`` leave the weights fully dense.
    """

    def __init__(self, batch, dim, k=5):
        # w: (batch, dim, 1) weights, sparsified below.
        self.w = torch.randn(batch, dim, 1)

        # Clamp at zero: a negative count would turn the slice below into a
        # "drop the last n" slice and zero coordinates although k >= dim.
        n_zero = max(dim - k, 0)
        if n_zero == 0:
            return

        for i in range(batch):
            # Independent random support for every batch element.
            idx = torch.randperm(dim)[:n_zero]
            self.w[i, idx, :] = 0

    def eval(self, xs):
        """Evaluate the functions on ``xs``.

        Args:
            xs: Tensor of shape (batch, points, dim).

        Returns:
            Tensor of shape (batch, points, 1) on xs' device.
        """
        ws = self.w.to(xs.device)
        return xs @ ws  # batch x points x 1

class ReLUSamlpler():
    """Batch of random two-layer ReLU networks with ``r`` hidden units.

    Each function is f(x) = relu(x @ w) @ a, with ``w`` standard normal and
    ``a`` standard normal scaled by sqrt(2 / r).
    """

    def __init__(self, batch, dim, r=100):
        output_scale = np.sqrt(2 / r)
        # w: (batch, dim, r) hidden-layer weights.
        self.w = torch.randn(batch, dim, r)
        # a: (batch, r, 1) output-layer weights.
        self.a = torch.randn(batch, r, 1) * output_scale
        self.activation = nn.ReLU()

    def eval(self, xs):
        """Evaluate the networks on ``xs``.

        Args:
            xs: Tensor of shape (batch, points, dim).

        Returns:
            Tensor of shape (batch, points, 1) on xs' device.
        """
        target_device = xs.device
        hidden = self.activation(torch.matmul(xs, self.w.to(target_device)))
        return torch.matmul(hidden, self.a.to(target_device))

### Training

In [None]:
def train_step(model, xs, ys, optimizer, loss_func):
    """Run one optimisation step and return the scalar loss.

    Args:
        model: Callable invoked as ``model(xs, ys)`` to produce predictions.
        xs: Input tensor passed to the model.
        ys: Target tensor with as many elements as the model output.
        optimizer: Optimizer whose gradients are reset and then stepped.
        loss_func: Callable ``loss_func(output, target)`` returning a scalar
            tensor.

    Returns:
        The loss of this step as a Python float.

    Raises:
        RuntimeError: If ``ys`` cannot be reshaped to the output's shape.
    """
    optimizer.zero_grad()
    output = model(xs, ys)
    # Match the target to the prediction shape explicitly. A bare squeeze()
    # also removes a size-1 batch or points axis, after which the loss would
    # silently broadcast mismatched shapes.
    target = ys.reshape(output.shape)
    loss = loss_func(output, target)
    loss.backward()
    optimizer.step()
    return loss.item()

def train_epoch(model, sampler_class, optimizer, loss_func, steps, batch, dim, points, device, epoch, cum_steps=5):
    """Train for ``steps`` steps on freshly sampled functions.

    Every step draws new standard-normal inputs and a new function from
    ``sampler_class``; the mean loss of each full window of ``cum_steps``
    steps is logged to wandb.

    Args:
        model: Model handed to ``train_step``.
        sampler_class: Class constructed as ``sampler_class(batch, dim)``
            and exposing ``eval(xs)``.
        optimizer: Optimizer handed to ``train_step``.
        loss_func: Loss handed to ``train_step``.
        steps: Number of optimisation steps in this epoch.
        batch: Number of functions per step.
        dim: Input dimensionality.
        points: Number of in-context points per function.
        device: Device the inputs are moved to.
        epoch: Index of the current epoch (used for the logged "step").
        cum_steps: Number of steps averaged per logged value.
    """
    window_loss = 0
    last_in_window = cum_steps - 1

    for step in range(steps):
        # Inputs first, then the function: keeps the RNG draw order fixed.
        inputs = torch.randn(batch, points, dim).to(device)
        targets = sampler_class(batch, dim).eval(inputs)
        window_loss += train_step(model, inputs, targets, optimizer, loss_func)

        if step % cum_steps != last_in_window:
            continue

        # Fractional epoch on the x-axis, window-averaged loss on the y-axis.
        record = {
            "step" : epoch + step / steps,
            "train_loss" : window_loss / cum_steps
        }
        wandb.log(record)
        window_loss = 0

def train(model, sampler_class, optimizer, loss_func, steps, batch, dim, points, device, epochs, cum_steps=5, max_point=41, delta=10, name="base", checkpoint_dir='/content/drive/MyDrive/InContext'):
    """Train ``model`` for ``epochs`` epochs with a growing prompt length.

    After every epoch the model and optimizer states are checkpointed and
    the number of in-context points grows by ``delta`` up to ``max_point``.

    Args:
        model: Model to train; moved to ``device`` and put in train mode.
        sampler_class: Function sampler class forwarded to ``train_epoch``.
        optimizer: Optimizer forwarded to ``train_epoch``.
        loss_func: Loss forwarded to ``train_epoch``.
        steps: Optimisation steps per epoch.
        batch: Functions per step.
        dim: Input dimensionality.
        points: Initial number of in-context points.
        device: Device used for training.
        epochs: Number of epochs.
        cum_steps: Logging window forwarded to ``train_epoch``.
        max_point: Upper bound for ``points``.
        delta: Increment of ``points`` after each epoch.
        name: Run name embedded in both checkpoint file names.
        checkpoint_dir: Directory receiving the checkpoints; created if
            missing.
    """
    import os

    # Create the target directory up front so a missing folder cannot fail
    # the first save after a whole epoch of training.
    os.makedirs(checkpoint_dir, exist_ok=True)

    model.to(device)
    model.train()
    for epoch in tqdm(range(epochs)):
        train_epoch(model, sampler_class, optimizer, loss_func, steps, batch, dim, points, device, epoch, cum_steps)

        # Both files carry the run name; without it, model checkpoints of
        # different runs would overwrite each other.
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'model_{name}_{epoch}.pt'))
        torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, f'optimizer_{name}_{epoch}.pt'))

        # Lengthen the prompt after each epoch, capped at max_point.
        if points < max_point:
            points = min(points + delta, max_point)

In [None]:
# Experiment: sparse linear functions.
set_random_seed()

wandb.init(project="In context transformer", name="Sparse Sampler")

model = InContextModel(n_dims=20, n_emb=128, n_layer=6, n_head=4)
sampler_class = LinearSparseSampler
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Prompt length starts at 10 points and grows by 5 per epoch
# (train's default max_point applies).
train(
    model,
    sampler_class,
    optimizer,
    loss_func,
    steps=10000,
    batch=64,
    dim=20,
    points=10,
    device=device,
    epochs=10,
    cum_steps=100,
    delta=5,
)

wandb.finish()

  0%|          | 0/10 [00:00<?, ?it/s]

VBox(children=(Label(value='0.001 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.083912…

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,███████▇▇▆▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
step,9.9999
train_loss,1.28555


In [None]:
# Experiment: random two-layer ReLU networks.
set_random_seed()

wandb.init(project="In context transformer", name="ReLU Sampler")

model = InContextModel(n_dims=20, n_emb=128, n_layer=6, n_head=4)
sampler_class = ReLUSamlpler
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Prompt length starts at 25 points and grows by 10 per epoch up to 101.
train(
    model,
    sampler_class,
    optimizer,
    loss_func,
    steps=10000,
    batch=64,
    dim=20,
    points=25,
    device=device,
    epochs=10,
    cum_steps=100,
    max_point=101,
    name="relu",
    delta=10,
)

wandb.finish()

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,███████▇▇██▇▇▇▇▆▅▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁

0,1
step,3.0699
train_loss,10.13741


  0%|          | 0/10 [00:00<?, ?it/s]

Графики лосса:

https://wandb.ai/jakokorina/In%20context%20transformer?workspace=user-jakokorina

Итоговый лосс можно считать равным последней отметке на графике, поскольку фиксированной обучающей выборки нет и на каждом шаге функции генерируются заново.