In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
import math
import os

In [2]:
# Show GPU status (memory use, utilisation) before choosing a CUDA device.
!nvidia-smi

Wed Mar 13 12:56:29 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               On  | 00000000:21:00.0 Off |                  Off |
| 72%   87C    P2             199W / 200W |  24828MiB / 49140MiB |    100%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               On  | 00000000:41:00.0 Off |  

In [37]:
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torch
import numpy as np
import math

class SamplingDataset:
  """Fixed random class prototypes and their base labels.

  Drawn once, at construction time:
    * ``mu``     -- float tensor ``(num_classes, dim)``, entries ~ N(0, 1/dim);
    * ``labels`` -- int64 tensor ``(num_classes, 1)``, entries uniform in
      ``[0, num_labels)``.

  Args:
    conf: object exposing ``num_classes``, ``dim`` and ``num_labels``.
  """

  def __init__(self, conf):
    self.num_classes = conf.num_classes
    self.dim = conf.dim
    self.num_labels = conf.num_labels
    self.mu, self.labels = self._get_data()

  def _get_data(self):
    """Sample the prototypes, then the labels; returns ``(mu, labels)``."""
    proto_std = math.sqrt(1 / self.dim)
    prototypes = torch.normal(mean=0, std=proto_std, size=(self.num_classes, self.dim))
    base_labels = torch.randint(self.num_labels, size=(self.num_classes, 1))
    return prototypes, base_labels

class MultiTaskSamplingLoader(DataLoader):
  """Infinite sampler of multi-task few-shot classification sequences.

  Every yielded sequence holds ``num_seq`` context items followed by one query
  item, so each per-item output has ``num_seq + 1`` entries. A task id ``t``
  shifts an item's base label: ``label = (base_label + t) % num_labels``.

  NOTE(review): this subclasses ``DataLoader`` but never calls
  ``DataLoader.__init__``; only ``get_seq`` is used (it is wrapped in
  ``IterDataset`` and a real ``DataLoader``).

  Args:
    conf: data config exposing data_type, num_seq, alpha, num_classes,
      num_tasks, num_labels, task_ways, item_ways, p_bursty, p_icl, eps, dim.
    dataset: object exposing tensors ``mu`` ``(num_classes, dim)`` and
      ``labels`` ``(num_classes, 1)``, e.g. ``SamplingDataset``.
  """

  def __init__(self, conf, dataset):
    self.dataset = dataset
    self.mu, self.labels = self.dataset.mu, self.dataset.labels
    self.data_type = conf.data_type
    self.num_seq = conf.num_seq
    self.alpha = conf.alpha
    self.num_classes = conf.num_classes
    self.num_task = conf.num_tasks
    self.num_labels = conf.num_labels
    self.task_ways = conf.task_ways
    self.item_ways = conf.item_ways
    self.p_bursty = conf.p_bursty
    self.p_icl = conf.p_icl
    self.eps = conf.eps
    self.dim = conf.dim
    # Check each "ways" on its own so that a value of 0 never causes a modulo by zero.
    if self.item_ways != 0:
      assert self.num_seq % self.item_ways == 0, "num_seq must be divisible by item_ways"
    if self.task_ways != 0:
      assert self.num_seq % self.task_ways == 0, "num_seq must be divisible by task_ways"
    if self.item_ways == 0 or self.task_ways == 0:
      self.p_bursty = 0
    # Zipf-like rank-frequency distribution: p(k) proportional to 1 / (k+1)**alpha.
    prob = np.array([1 / ((k + 1) ** self.alpha) for k in range(self.num_classes)])
    self.prob = prob / prob.sum()

  def get_seq(self):
    """Yield sequences forever according to ``self.data_type``.

    Each yielded dict has "examples" (float32, ``(num_seq+1, dim)``) and
    "labels" / "classes" (int, ``(num_seq+1,)``). Every type except "flip"
    also has "tasks" (int, ``(num_seq+1,)``).

    Raises:
      ValueError: on the first ``next()`` if ``data_type`` is unknown.
    """
    while True:
      if self.data_type == "bursty":
        if self.p_bursty > np.random.rand():
          yield self._bursty_seq()
        else:
          yield self._rank_frequency_seq()
      elif self.data_type == "no_support":
        yield self._no_support_seq()
      elif self.data_type == "holdout":
        yield self._holdout_seq()
      elif self.data_type == "flip":
        yield self._flip_seq()
      else:
        raise ValueError(f"Unknown data_type {self.data_type!r}")

  def _sample_context_tasks(self):
    """Pick ``num_seq // task_ways`` distinct tasks, each repeated ``task_ways`` times.

    Returns ``(few_shot_task, tasks)``; ``tasks`` is an int array ``(num_seq, 1)``.
    """
    num_few_shot_task = self.num_seq // self.task_ways
    few_shot_task = np.random.choice(self.num_task, num_few_shot_task, replace=False)
    tasks = np.repeat(few_shot_task, self.task_ways, axis=0).reshape(-1, 1)
    return few_shot_task, tasks

  def _sample_bursty_items(self):
    """Pick ``num_seq // item_ways`` distinct classes, each repeated ``item_ways`` times.

    Returns ``(few_shot_class, mus, labels, classes)`` where ``mus`` is
    ``(num_seq, dim)``, ``labels`` is the base labels ``(num_seq, 1)`` and
    ``classes`` is the class id of every item ``(num_seq,)``.
    """
    num_few_shot_class = self.num_seq // self.item_ways
    few_shot_class = np.random.choice(self.num_classes, num_few_shot_class, replace=False)
    mus = self.mu[few_shot_class].repeat_interleave(self.item_ways, dim=0)
    labels = self.labels[few_shot_class].repeat_interleave(self.item_ways, dim=0)
    classes = np.repeat(few_shot_class, self.item_ways)
    return few_shot_class, mus, labels, classes

  def _sample_query(self, class_pool, task_pool):
    """Draw one query class and one query task; return class, task, shifted label and noisy example."""
    query_class = np.random.choice(class_pool, 1)
    query_task = np.random.choice(task_pool, 1)
    query_label = (self.labels[query_class] + torch.from_numpy(query_task).reshape(-1, 1)) % self.num_labels
    query_x = self.add_noise(self.mu[query_class])
    return query_class, query_task, query_label, query_x

  @staticmethod
  def _pack(x, query_x, labels, query_label, classes, query_class, tasks, query_task):
    """Append the query item to the context and build the output dict."""
    return {
        "tasks": torch.cat([torch.as_tensor(tasks).flatten(), torch.as_tensor(query_task).flatten()]),
        "examples": torch.cat([x, query_x]).to(torch.float32),
        "labels": torch.cat([labels.flatten(), query_label.flatten()]),
        "classes": torch.cat([torch.as_tensor(classes).flatten(), torch.as_tensor(query_class).flatten()]),
    }

  def _bursty_seq(self):
    """Bursty context (few classes x item_ways, few tasks x task_ways) plus a query drawn from them."""
    few_shot_task, tasks = self._sample_context_tasks()
    few_shot_class, mus, labels, classes = self._sample_bursty_items()
    x = self.add_noise(mus)
    ordering = np.random.permutation(self.num_seq)
    x, labels, classes = x[ordering], labels[ordering], classes[ordering]
    # Tasks are shuffled independently of the items, then used to shift labels.
    tasks = tasks[np.random.permutation(self.num_seq)]
    labels = (labels + torch.from_numpy(tasks)) % self.num_labels
    query_class, query_task, query_label, query_x = self._sample_query(few_shot_class, few_shot_task)
    return self._pack(x, query_x, labels, query_label, classes, query_class, tasks, query_task)

  def _rank_frequency_seq(self):
    """Non-bursty sequence: ``num_seq + 1`` classes drawn i.i.d. from ``self.prob``.

    After shuffling, the last item acts as the query. ``tasks`` only covers
    the ``num_seq`` context slots, so the query's task is drawn from the same
    task pool. Labels are shifted by the task of the same item.
    """
    few_shot_task, tasks = self._sample_context_tasks()
    classes = np.random.choice(self.num_classes, self.num_seq + 1, p=self.prob)
    x = self.add_noise(self.mu[classes])
    labels = self.labels[classes]
    ordering = np.random.permutation(self.num_seq + 1)
    x, labels, classes = x[ordering], labels[ordering], classes[ordering]
    tasks = tasks[np.random.permutation(self.num_seq)]
    query_task = np.random.choice(few_shot_task, 1)
    all_tasks = np.concatenate([tasks.flatten(), query_task])
    labels = (labels.flatten() + torch.from_numpy(all_tasks)) % self.num_labels
    return {
        "tasks": torch.from_numpy(all_tasks),
        "examples": x.to(torch.float32),
        "labels": labels,
        "classes": torch.from_numpy(classes),
    }

  def _no_support_seq(self):
    """Context with random labels (no support); the query keeps its true task-shifted label."""
    few_shot_task, tasks = self._sample_context_tasks()
    classes = np.random.choice(self.num_classes, self.num_seq, p=self.prob)
    mus = self.mu[classes]
    labels = torch.from_numpy(np.random.randint(self.num_labels, size=(self.num_seq, 1))).long()
    x = self.add_noise(mus)
    ordering = np.random.permutation(self.num_seq)
    x, labels, classes, tasks = x[ordering], labels[ordering], classes[ordering], tasks[ordering]
    query_class, query_task, query_label, query_x = self._sample_query(self.num_classes, few_shot_task)
    return self._pack(x, query_x, labels, query_label, classes, query_class, tasks, query_task)

  def _holdout_seq(self):
    """Bursty context whose query slot shows a randomly drawn task id.

    The query label is still computed with a task drawn from the context
    tasks. NOTE(review): the displayed task is not excluded from those
    context tasks, so it may coincide with the true query task.
    """
    few_shot_task, tasks = self._sample_context_tasks()
    false_tasks = np.random.choice(self.num_task, 1, replace=False)
    few_shot_class, mus, labels, classes = self._sample_bursty_items()
    x = self.add_noise(mus)
    ordering = np.random.permutation(self.num_seq)
    x, labels, classes, tasks = x[ordering], labels[ordering], classes[ordering], tasks[ordering]
    labels = (labels + torch.from_numpy(tasks)) % self.num_labels
    query_class, query_task, query_label, query_x = self._sample_query(few_shot_class, few_shot_task)
    return self._pack(x, query_x, labels, query_label, classes, query_class, tasks, false_tasks)

  def _flip_seq(self):
    """Bursty context whose labels (context and query) are all shifted by +1; no "tasks" key."""
    few_shot_class, mus, labels, classes = self._sample_bursty_items()
    labels = (labels + 1) % self.num_labels
    x = self.add_noise(mus)
    ordering = np.random.permutation(self.num_seq)
    x, labels, classes = x[ordering], labels[ordering], classes[ordering]
    query_class = np.random.choice(few_shot_class, 1)
    query_label = (self.labels[query_class] + 1) % self.num_labels
    query_x = self.add_noise(self.mu[query_class])
    return {
        "examples": torch.cat([x, query_x]).to(torch.float32),
        "labels": torch.cat([labels.flatten(), query_label.flatten()]),
        "classes": torch.cat([torch.from_numpy(classes).flatten(), torch.from_numpy(query_class).flatten()]),
    }

  def add_noise(self, x):
    """Return ``(x + eps * n) / sqrt(1 + eps**2)`` with ``n ~ N(0, 1/dim)`` elementwise.

    For inputs whose entries have variance 1/dim (like the prototypes of
    ``SamplingDataset``), the rescaling keeps that variance at 1/dim.
    """
    noise = torch.normal(mean=0, std=math.sqrt(1 / self.dim), size=x.shape)
    return (x + self.eps * noise) / np.sqrt(1 + self.eps ** 2)

  def _get_novel_class_seq(self, num_class):
    """Sample ``num_class`` fresh prototypes and labels, not stored on the instance."""
    mu = torch.normal(mean=0, std=math.sqrt(1 / self.dim), size=(num_class, self.dim))
    labels = torch.randint(self.num_labels, size=(num_class, 1))
    return mu, labels

class IterDataset(IterableDataset):
    """Expose a zero-argument generator function as an ``IterableDataset``.

    ``generator`` is called anew on every ``iter()``, so each iteration gets a
    fresh generator object.
    """

    def __init__(self, generator):
        self.generator = generator  # zero-argument callable returning an iterator

    def __iter__(self):
        fresh = self.generator()
        return fresh


In [38]:
from dataclasses import dataclass, asdict

@dataclass
class TransformerConfig:
  """Hyper-parameters for the Transformer models and input embedders.

  NOTE(review): ``Transformer`` derives its widths from ``d_emb`` and
  ``num_heads`` (d_mlp = 4 * d_emb, d_head = d_emb // num_heads), so
  ``d_model``, ``d_mlp`` and ``d_head`` below are not read by it.
  """
  num_layers: int = 2
  d_vocab: int = 32            # rows of the label-embedding and unembedding tables
  d_model: int = 128
  d_mlp: int = 128
  d_head: int = 128
  num_heads: int = 1
  n_ctx: int = 2 * 9           # causal-mask size; matches MultiTaskInputEmbedderV2's 2*(num_seq+1) tokens for num_seq=8
  act_type: str = "ReLU"       # "ReLU" or "GeLU"
  use_cache: bool = False
  use_ln: bool = True          # toggles every block's LayerNorm
  p_dim: int = 65              # width of the one-hot positional code appended to each example
  d_emb: int = 128             # token width; example dim + p_dim must equal this
  num_tasks: int = 3           # rows of the task-embedding table
  num_seq_per_task: int = 8

@dataclass
class TrainDataConfig:
  """Sampling parameters read by ``MultiTaskSamplingLoader`` (training split)."""
  num_classes: int = 10        # number of class prototypes
  dim: int = 63                # example dimensionality
  num_labels: int = 6
  eps: float = 0.1             # within-class noise scale
  alpha: float = 0             # exponent of the 1/(k+1)**alpha class-frequency law (0 = uniform)
  item_ways: int = 1           # copies of each class in a bursty context
  num_seq: int = 8             # context items per sequence; one query is appended
  p_bursty: float = 1          # probability that a "bursty" sequence is actually bursty
  data_type: str = "bursty"    # bursty, holdout, no_support, flip
  num_holdout_classes: int = 2  # NOTE(review): not read by MultiTaskSamplingLoader
  num_tasks: int = 3
  task_ways: int = 8           # copies of each task in the context
  # Annotated so it is a real dataclass field (settable via __init__ and included
  # in asdict); an un-annotated assignment would only be a class attribute.
  p_icl: float = 0

@dataclass
class IWLDataConfig(TrainDataConfig):
  """Evaluation split where context labels are random ("no_support"), so only
  stored class->label knowledge can predict the query."""
  data_type: str = "no_support"  # one of: bursty, holdout, no_support, flip

@dataclass
class ICLDataConfig(TrainDataConfig):
  """Evaluation split with a bursty context whose query position shows a
  randomly drawn task id ("holdout")."""
  data_type: str = "holdout"  # one of: bursty, holdout, no_support, flip


@dataclass
class ICL2DataConfig(TrainDataConfig):
  """Evaluation split where every label is shifted by +1 ("flip")."""
  data_type: str = "flip"  # one of: bursty, holdout, no_support, flip
  
@dataclass
class TrainConfig:
  """Optimisation settings."""
  batch_size: int = 3
  optimize_step: int = 50_000   # same value as int(5e4)
  lr: float = 0.01
  optimizer: str = "sgd"        # one of: adam, sgd, adamw

@dataclass
class MainConfig:
  """Top-level bundle of every config object used by the notebook.

  NOTE(review): each default is a single instance created when the class is
  defined. All ``MainConfig()`` objects therefore share (and mutate) the same
  sub-configs, and the notebook reads them as class attributes
  (``MainConfig.traindataconfig``). Python 3.11+ rejects such unhashable
  dataclass defaults with ``ValueError``. Switching to
  ``field(default_factory=...)`` fixes that but removes those class
  attributes, so the cell that reads them must change at the same time.
  """
  traindataconfig: TrainDataConfig = TrainDataConfig()
  icldataconfig: ICLDataConfig = ICLDataConfig()
  iwldataconfig: IWLDataConfig = IWLDataConfig()
  icl2dataconfig: ICL2DataConfig = ICL2DataConfig()
  modelconfig: TransformerConfig = TransformerConfig()
  trainconfig: TrainConfig = TrainConfig()
  device: str = "cuda:0"  # torch device string
# define config

In [39]:
# The sub-configs are class-level defaults of MainConfig, read without instantiating it.
traindataconfig = MainConfig.traindataconfig
icldataconfig = MainConfig.icldataconfig
iwldataconfig = MainConfig.iwldataconfig
icl2dataconfig = MainConfig.icl2dataconfig
trainconfig = MainConfig.trainconfig

# One shared set of class prototypes/labels, reused by every split below.
sampling_dataset = SamplingDataset(traindataconfig)
# NOTE(review): this name shadows `torch.utils.data.Dataset` imported earlier;
# kept only so later cells that still read `Dataset` keep working.
Dataset = sampling_dataset

# os.cpu_count() may return None, which DataLoader does not accept.
NUM_WORKERS = os.cpu_count() or 0


def build_split(data_conf, dataset, batch_size, num_workers=NUM_WORKERS):
    """Wrap one data config as sampler -> IterDataset -> batched DataLoader.

    Returns:
        (sampler, seq_generator_fn, iterable_dataset, dataloader)
    """
    sampler = MultiTaskSamplingLoader(data_conf, dataset=dataset)
    seq_generator = sampler.get_seq
    iterable = IterDataset(seq_generator)
    loader = torch.utils.data.DataLoader(
        iterable, batch_size=batch_size, pin_memory=True, num_workers=num_workers
    )
    return sampler, seq_generator, iterable, loader


trainloader, train_seq_generator, train_dataset, train_dataloader = build_split(
    traindataconfig, sampling_dataset, trainconfig.batch_size)
iclloader, icl_seq_generator, icl_dataset, icl_dataloader = build_split(
    icldataconfig, sampling_dataset, trainconfig.batch_size)
iwlloader, iwl_seq_generator, iwl_dataset, iwl_dataloader = build_split(
    iwldataconfig, sampling_dataset, trainconfig.batch_size)
icl2loader, icl2_seq_generator, icl2_dataset, icl2_dataloader = build_split(
    icl2dataconfig, sampling_dataset, trainconfig.batch_size)


def preview_batch(name, dataloader):
    """Print the first batch drawn from ``dataloader`` and return it."""
    batch = next(iter(dataloader))
    print(name)
    print("example", batch["examples"].shape)
    print("task\n", batch["tasks"])
    print("class\n", batch["classes"])
    print("label\n", batch["labels"])
    return batch


preview_batch("Train", train_dataloader)
preview_batch("ICL", icl_dataloader)
data = preview_batch("IWL", iwl_dataloader)
# Keep the last previewed batch available to later cells under these names.
task, examples, labels, classes = data["tasks"], data["examples"], data["labels"], data["classes"]
# ICL2 ("flip") batches carry no "tasks" key, so they are not previewed here.


Train
task
 tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 2, 2, 2],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]])
train_class
 tensor([[3, 5, 4, 9, 0, 6, 7, 2, 3],
        [4, 6, 5, 3, 1, 8, 2, 7, 4],
        [6, 4, 5, 9, 8, 2, 3, 1, 3]])
train_label
 tensor([[2, 1, 5, 4, 2, 4, 0, 5, 2],
        [5, 4, 1, 2, 2, 5, 5, 0, 5],
        [3, 4, 0, 3, 4, 4, 1, 1, 1]])
ICL
task
 tensor([[2, 2, 2, 2, 2, 2, 2, 2, 0],
        [2, 2, 2, 2, 2, 2, 2, 2, 2],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]])
train_class
 tensor([[1, 4, 7, 9, 5, 3, 2, 8, 3],
        [9, 7, 6, 2, 3, 4, 8, 5, 3],
        [4, 3, 9, 5, 7, 0, 6, 1, 5]])
train_label
 tensor([[2, 5, 0, 4, 1, 2, 5, 5, 2],
        [4, 0, 4, 5, 2, 5, 5, 1, 2],
        [4, 1, 3, 0, 5, 1, 3, 1, 0]])
IWL
example torch.Size([3, 9, 63])
task
 tensor([[2, 2, 2, 2, 2, 2, 2, 2, 2],
        [1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2]])
train_class
 tensor([[2, 4, 7, 1, 3, 9, 5, 5, 7],
        [9, 6, 7, 6, 2, 8, 0, 2, 2],
        [1, 7,

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops


class Unembed(nn.Module):
    """Project residual-stream vectors onto ``d_vocab`` logits.

    ``W`` has shape ``(d_vocab, d_model)``; ``forward`` maps an input of shape
    ``(b, s, d_model)`` to ``(b, s, d_vocab)``, i.e. ``x @ W.T``.
    """

    def __init__(self, d_vocab, d_model):
        super().__init__()
        weight = torch.randn(d_vocab, d_model) / np.sqrt(d_model)
        self.W = nn.Parameter(weight)

    def forward(self, x):
        return torch.einsum("bse,pe->bsp", x, self.W)


class HookPoint(nn.Module):
    """Identity module that marks an activation so hooks can be attached to it.

    Hook callables have the signature ``hook(tensor, name=...)``. Forward hooks
    receive the activation. Backward hooks receive ``grad_output``, the tuple of
    gradients w.r.t. the activation. A non-None return value replaces the
    output (forward) or ``grad_input`` (backward), as in PyTorch module hooks.
    """

    def __init__(self):
        super().__init__()
        self.fwd_hooks = []
        self.bwd_hooks = []

    def give_name(self, name):
        """Record the module path; called by the owning model at initialisation."""
        self.name = name

    def add_hook(self, hook, dir="fwd"):
        """Attach ``hook`` for direction ``dir`` ("fwd" or "bwd").

        Raises:
            ValueError: if ``dir`` is neither "fwd" nor "bwd".
        """
        # Adapt fn(activation, name) to PyTorch's (module, input, output) hook
        # signature; for backward hooks the last two are (grad_input, grad_output).
        def full_hook(module, module_input, module_output):
            return hook(module_output, name=self.name)

        if dir == "fwd":
            self.fwd_hooks.append(self.register_forward_hook(full_hook))
        elif dir == "bwd":
            # The deprecated register_backward_hook attaches to the grad_fn of
            # whatever op produced this identity module's output (and does
            # nothing for leaf inputs); the full variant sees this module's own
            # gradients.
            self.bwd_hooks.append(self.register_full_backward_hook(full_hook))
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir="fwd"):
        """Remove hooks for ``dir`` in {"fwd", "bwd", "both"}.

        Raises:
            ValueError: for any other ``dir`` (nothing is removed then).
        """
        if dir not in ("fwd", "bwd", "both"):
            raise ValueError(f"Invalid direction {dir}")
        if dir in ("fwd", "both"):
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir in ("bwd", "both"):
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []

    def forward(self, x):
        return x


class PosEmbed(nn.Module):
    """Learned absolute positional embedding added to the input.

    Args:
        max_ctx: maximum sequence length (rows of ``W_pos``).
        d_model: embedding width.
        weight_scale: multiplier on the standard-normal initialisation.
    """

    def __init__(self, max_ctx, d_model, weight_scale=1):
        super().__init__()
        initial = torch.randn(max_ctx, d_model) * weight_scale
        self.W_pos = nn.Parameter(initial)

    def forward(self, x):
        seq_len = x.shape[-2]
        positional = self.W_pos[:seq_len]
        return x + positional


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension that the owning model can switch off.

    Args:
        d_model: size of the normalised (last) dimension.
        epsilon: added to the standard deviation before dividing.
        model: one-element list holding the owning model. Normalisation runs
            only when ``model[0].use_ln`` is truthy. The list wrapper presumably
            keeps the owner from being registered as a sub-module. When omitted,
            or when ``model[0]`` is None, normalisation is always applied.

    NOTE(review): uses the unbiased (Bessel-corrected) ``std`` and adds
    ``epsilon`` to the std, unlike ``nn.LayerNorm`` (biased variance, eps
    inside the square root).
    """

    def __init__(self, d_model, epsilon=1e-4, model=None):
        super().__init__()
        # A fresh list per instance instead of a shared mutable default.
        self.model = [None] if model is None else model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        owner = self.model[0]
        if owner is not None and not owner.use_ln:
            return x
        x = x - x.mean(axis=-1)[..., None]
        x = x / (x.std(axis=-1)[..., None] + self.epsilon)
        return x * self.w_ln + self.b_ln


# Attention
class Attention(nn.Module):
    """Causal multi-head self-attention without biases.

    Index legend for the einsum strings:
        b : batch
        d : model (embedding) width
        i : head
        h : per-head width
        p : key position, q : query position
    ``n_ctx`` is the longest sequence the causal mask supports.
    """

    def __init__(self, d_model, num_heads, d_head, n_ctx):
        super().__init__()
        self.W_K = nn.Parameter(torch.randn(num_heads, d_head, d_model) / np.sqrt(d_model))
        self.W_Q = nn.Parameter(torch.randn(num_heads, d_head, d_model) / np.sqrt(d_model))
        self.W_V = nn.Parameter(torch.randn(num_heads, d_head, d_model) / np.sqrt(d_model))
        self.W_O = nn.Parameter(torch.randn(d_model, d_head * num_heads) / np.sqrt(d_model))
        # Lower-triangular ones: query position q may attend to key positions p <= q.
        self.register_buffer("mask", torch.tril(torch.ones((n_ctx, n_ctx))))
        self.d_head = d_head

    def forward(self, x):
        """Map ``x`` of shape (b, seq, d) with seq <= n_ctx to (b, seq, d)."""
        seq_len = x.shape[-2]
        k = torch.einsum("ihd,bpd->biph", self.W_K, x)
        q = torch.einsum("ihd,bpd->biph", self.W_Q, x)
        v = torch.einsum("ihd,bpd->biph", self.W_V, x)
        attn_scores_pre = torch.einsum("biph,biqh->biqp", k, q)
        causal = self.mask[:seq_len, :seq_len]
        # Future keys get a huge negative score so softmax gives them zero weight.
        attn_scores_masked = torch.tril(attn_scores_pre) - 1e10 * (1 - causal)
        attn_matrix = F.softmax(attn_scores_masked / np.sqrt(self.d_head), dim=-1)
        z = torch.einsum("biph,biqp->biqh", v, attn_matrix)
        # (b, i, q, h) -> (b, q, i*h) with heads outermost, i.e. "b i q h -> b q (i h)".
        n_batch, n_heads, q_len, head_dim = z.shape
        z_flat = z.permute(0, 2, 1, 3).reshape(n_batch, q_len, n_heads * head_dim)
        return torch.einsum("df,bqf->bqd", self.W_O, z_flat)


class Dense(nn.Module):
    """Bias-free linear map ``y = x @ W.T`` with ``W`` of shape (d_out, d_in).

    Args:
        d_in, d_out: input / output widths.
        act_type: accepted for signature compatibility but NOT used; no
            activation is applied. NOTE(review): confirm whether an activation
            was intended (TransformerICL stacks Dense layers back to back).
        weight_scale: initialisation std is ``weight_scale / sqrt(d_in)``.
    """

    def __init__(self, d_in, d_out, act_type, weight_scale=1):
        super().__init__()
        self.W = nn.Parameter(torch.randn(d_out, d_in))
        torch.nn.init.normal_(self.W, mean=0, std=weight_scale / np.sqrt(d_in))

    def set_weight_ratio(self, weight_ratio):
        """Multiply ``W`` in place by ``weight_ratio``.

        Done in place under ``no_grad`` so the Parameter object, and any
        optimizer already holding it, stays the same.
        """
        with torch.no_grad():
            self.W.mul_(weight_ratio)

    def set_weight_ratio_l2(self, weight_ratio):
        """Multiply ``W`` in place by ``sqrt(weight_ratio)``; accepts a float or a tensor."""
        factor = torch.as_tensor(weight_ratio, dtype=self.W.dtype, device=self.W.device).sqrt()
        with torch.no_grad():
            self.W.mul_(factor)

    def forward(self, x):
        return x @ self.W.T


# for Transformer
class MLPBlock(nn.Module):
    """Position-wise two-layer MLP: ``W_out @ act(W_in @ x + b_in) + b_out``.

    Einsum legend: b batch, p position, d model width, m hidden width.
    Both layers have biases; no layer norm is applied inside the block.

    Args:
        d_model: input/output width.
        d_mlp: hidden width.
        act_type: "ReLU" or "GeLU" (anything else fails the assertion).
    """

    def __init__(self, d_model, d_mlp, act_type):
        super().__init__()
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model) / np.sqrt(d_model))
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        # NOTE(review): W_out is scaled by 1/sqrt(d_model), not 1/sqrt(d_mlp) (its fan-in); confirm.
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp) / np.sqrt(d_model))
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        assert act_type in ["ReLU", "GeLU"]

    def forward(self, x):
        x = torch.einsum("md,bpd->bpm", self.W_in, x) + self.b_in
        if self.act_type == "ReLU":
            x = F.relu(x)
        elif self.act_type == "GeLU":
            x = F.gelu(x)
        x = torch.einsum("dm,bpm->bpd", self.W_out, x) + self.b_out
        return x

    def set_weight_ratio(self, weight_ratio):
        """Multiply ``W_in`` and ``W_out`` (not the biases) in place by ``weight_ratio``.

        In place under ``no_grad`` so the Parameter objects, and any optimizer
        already holding them, stay the same.
        """
        with torch.no_grad():
            self.W_in.mul_(weight_ratio)
            self.W_out.mul_(weight_ratio)


class TransformerBlock(nn.Module):
    """Residual attention + MLP block with a LayerNorm between them.

    ``forward`` computes
        resid_mid  = x + attn(x)
        normed     = layer_norm(resid_mid)
        resid_post = normed + mlp(normed)
    with HookPoints on resid_pre, attn_out, resid_mid, mlp_out and resid_post.

    Args:
        d_model, d_mlp, d_head, num_heads, n_ctx, act_type: forwarded to
            Attention / MLPBlock.
        model: one-element list holding the owning model; LayerNorm reads
            ``model[0].use_ln`` to decide whether to normalise.
    """

    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        self.attn = Attention(d_model, num_heads, d_head, n_ctx)
        self.mlp = MLPBlock(d_model, d_mlp, act_type)
        self.layer_norm = LayerNorm(d_model, model=self.model)
        self.hook_attn_out = HookPoint()
        self.hook_mlp_out = HookPoint()
        self.hook_resid_pre = HookPoint()
        self.hook_resid_mid = HookPoint()
        self.hook_resid_post = HookPoint()

    def forward(self, x):
        x = self.hook_resid_mid(
            x + self.hook_attn_out(self.attn((self.hook_resid_pre(x))))
        )
        x = self.layer_norm(x)
        x = self.hook_resid_post(x + self.hook_mlp_out(self.mlp((x))))
        return x

    def set_weight_ratio(self, weight_ratio):
        """Scale the block's attention and MLP weights by ``weight_ratio``.

        ``Attention`` defines no ``set_weight_ratio`` of its own. When it is
        missing, every attention parameter (attention has no biases) is scaled
        in place here. NOTE(review): confirm that scaling all attention weights
        is the intended semantics.
        """
        attn_setter = getattr(self.attn, "set_weight_ratio", None)
        if attn_setter is not None:
            attn_setter(weight_ratio)
        else:
            with torch.no_grad():
                for param in self.attn.parameters():
                    param.mul_(weight_ratio)
        self.mlp.set_weight_ratio(weight_ratio)


class MultiTaskInputEmbedderV1(nn.Module):
    """Interleave task, example and label tokens for ``T`` task segments.

    Layout of segment ``t`` (length ``2*SS + 1``):
        [task_t, x_t0, y_t0, x_t1, y_t1, ..., x_t(SS-1), y_t(SS-1)]
    The final label token of the whole sequence is dropped, so the output has
    ``T*(2*SS + 1) - 1`` tokens.
    """

    def __init__(self, conf):
        """Create the label and task embedding tables.

        Args:
            conf: model config exposing ``d_vocab`` (rows of the label table),
                ``d_emb`` (token width), ``p_dim`` (one-hot positional width),
                ``num_tasks`` (rows of the task table) and ``num_seq_per_task``.
        """
        super(MultiTaskInputEmbedderV1, self).__init__()
        self._num_classes = conf.d_vocab
        self._emb_dim = conf.d_emb
        self.p_dim = conf.p_dim
        self.num_tasks = conf.num_tasks
        self.num_seq_per_task = conf.num_seq_per_task

        # NOTE(review): `Emb` is not used by forward(); kept so parameter
        # registration and initialisation-time RNG use are unchanged.
        self.Emb = nn.Linear(self._emb_dim, self._emb_dim)

        self.label_embs = nn.Parameter(
            torch.randn(self._num_classes, self._emb_dim) / np.sqrt(self._emb_dim)
        )

        self.task_embs = nn.Parameter(
            torch.randn(self.num_tasks, self._emb_dim) / np.sqrt(self._emb_dim)
        )

    def forward(self, examples, labels, tasks):
        """Build the interleaved token sequence.

        Args:
            examples: float tensor (B, T, SS, D); ``D + p_dim`` must equal ``d_emb``
                and ``T*SS`` must not exceed ``p_dim``.
            labels: int tensor indexing ``label_embs`` with B*T*SS entries,
                presumably shaped (B, T, SS) -- TODO confirm against callers.
            tasks: int tensor (B, T) of task ids.

        Returns:
            Tensor (B, T*(2*SS + 1) - 1, d_emb).

        Raises:
            ValueError: if ``D + p_dim != d_emb``.
        """
        B, T, SS, D = examples.shape
        n_items = T * SS
        examples = examples.view(B, n_items, D)
        # One-hot position of each example, appended to its features: (B, T*SS, p_dim).
        pos_enc = F.one_hot(
            torch.arange(n_items, device=examples.device), num_classes=self.p_dim
        ).repeat(B, 1, 1).to(examples.dtype)
        h_example = torch.cat([examples, pos_enc], dim=2)  # (B, T*SS, E)
        if h_example.shape[-1] != self._emb_dim:
            raise ValueError(
                f"example dim {D} + p_dim {self.p_dim} must equal d_emb {self._emb_dim}"
            )

        h_label = self.label_embs[labels].view(B, n_items, self._emb_dim)  # (B, T*SS, E)
        task_embs = self.task_embs[tasks]  # (B, T, E)

        seg_len = SS * 2 + 1
        hh = torch.empty(
            (B, seg_len * T, self._emb_dim), dtype=h_example.dtype, device=h_example.device
        )
        hh[:, 0::seg_len] = task_embs
        for t in range(T):
            start = t * seg_len + 1
            hh[:, start:start + SS * 2:2] = h_example[:, t * SS:(t + 1) * SS]
            hh[:, start + 1:start + SS * 2:2] = h_label[:, t * SS:(t + 1) * SS]

        # Drop the very last label token: it is the target to predict.
        return hh[:, :-1]

class MultiTaskInputEmbedderV2(nn.Module):
    """Interleave a single task token with example and label tokens.

    Output layout (length ``2*SS``):
        [task, x_0, y_0, x_1, y_1, ..., x_(SS-2), y_(SS-2), x_(SS-1)]
    The label of the last (query) example is not included. Only the task id
    at the first position, ``tasks[:, 0]``, is embedded.
    NOTE(review): "holdout" sequences put a different task id at
    ``tasks[:, -1]``, which this embedder therefore ignores -- confirm intended.
    """

    def __init__(self, conf):
        """Create the label and task embedding tables.

        Args:
            conf: model config exposing ``d_vocab`` (rows of the label table),
                ``d_emb`` (token width), ``p_dim`` (one-hot positional width),
                ``num_tasks`` (rows of the task table) and ``num_seq_per_task``.
        """
        super(MultiTaskInputEmbedderV2, self).__init__()
        self._num_classes = conf.d_vocab
        self._emb_dim = conf.d_emb
        self.p_dim = conf.p_dim
        self.num_tasks = conf.num_tasks
        self.num_seq_per_task = conf.num_seq_per_task

        # NOTE(review): `Emb` is not used by forward(); kept so parameter
        # registration and initialisation-time RNG use are unchanged.
        self.Emb = nn.Linear(self._emb_dim, self._emb_dim)

        self.label_embs = nn.Parameter(
            torch.randn(self._num_classes, self._emb_dim) / np.sqrt(self._emb_dim)
        )

        self.task_embs = nn.Parameter(
            torch.randn(self.num_tasks, self._emb_dim) / np.sqrt(self._emb_dim)
        )

    def forward(self, examples, labels, tasks):
        """Build the interleaved token sequence.

        Args:
            examples: float tensor (B, SS, D); ``D + p_dim`` must equal ``d_emb``
                and ``SS`` must not exceed ``p_dim``.
            labels: int tensor indexing ``label_embs`` with B*SS entries,
                presumably shaped (B, SS) -- TODO confirm against callers.
            tasks: int tensor (B, SS); only column 0 is used.

        Returns:
            Tensor (B, 2*SS, d_emb).

        Raises:
            ValueError: if ``D + p_dim != d_emb``.
        """
        B, SS, D = examples.shape
        # One-hot position of each example, appended to its features: (B, SS, p_dim).
        pos_enc = F.one_hot(
            torch.arange(SS, device=examples.device), num_classes=self.p_dim
        ).repeat(B, 1, 1).to(examples.dtype)
        h_example = torch.cat([examples, pos_enc], dim=2)  # (B, SS, E)
        if h_example.shape[-1] != self._emb_dim:
            raise ValueError(
                f"example dim {D} + p_dim {self.p_dim} must equal d_emb {self._emb_dim}"
            )

        h_label = self.label_embs[labels].view(B, SS, self._emb_dim)  # (B, SS, E)
        # A single task embedding per sequence, taken from the first position: (B, E).
        task_emb = self.task_embs[tasks[:, 0]]

        hh = torch.empty((B, SS * 2, self._emb_dim), dtype=h_example.dtype, device=h_example.device)
        hh[:, 0, :] = task_emb
        hh[:, 1::2] = h_example
        hh[:, 2::2] = h_label[:, :-1]
        return hh

class Transformer(nn.Module):
    """Embedder -> stack of TransformerBlocks -> Unembed to ``d_vocab`` logits.

    NOTE(review): widths come from ``config.d_emb`` (d_model = d_emb,
    d_mlp = 4 * d_emb, d_head = d_emb // num_heads); ``config.d_model``,
    ``config.d_mlp`` and ``config.d_head`` are not read.
    """

    def __init__(self, embedder, config):
        super().__init__()
        num_layers = config.num_layers
        d_model = config.d_emb
        d_mlp = config.d_emb * 4
        d_head = config.d_emb // config.num_heads
        num_heads = config.num_heads
        n_ctx = config.n_ctx
        act_type = config.act_type
        self.cache = {}
        self.use_cache = config.use_cache
        d_vocab = config.d_vocab

        self.embedder = embedder
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model=[self]
                )
                for _ in range(num_layers)
            ]
        )
        self.unembed = Unembed(d_vocab, d_model)
        # Read by each block's LayerNorm through its `model=[self]` back-reference.
        self.use_ln = config.use_ln

        # Give every HookPoint its module path so hooks can report it.
        for name, module in self.named_modules():
            if isinstance(module, HookPoint):
                module.give_name(name)

    def forward(self, x, labels, tasks=None):
        """Embed the inputs, run all blocks, and return logits over ``d_vocab``.

        ``tasks`` is passed on to the embedder when given. Both
        MultiTaskInputEmbedder variants require it; it stays optional only for
        embedders that take ``(x, labels)``.
        """
        if tasks is None:
            x = self.embedder(x, labels)
        else:
            x = self.embedder(x, labels, tasks)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        """Return every sub-module whose path contains "hook"."""
        return [module for name, module in self.named_modules() if "hook" in name]

    def remove_all_hooks(self):
        for hp in self.hook_points():
            hp.remove_hooks("fwd")
            hp.remove_hooks("bwd")

    def cache_all(self, cache, incl_bwd=False):
        """Store every HookPoint activation (and optionally its gradient) in ``cache``."""
        def save_hook(tensor, name):
            cache[name] = tensor.detach()

        def save_hook_back(tensor, name):
            # Backward hooks receive a tuple of gradients; keep the first.
            cache[name + "_grad"] = tensor[0].detach()

        for hp in self.hook_points():
            hp.add_hook(save_hook, "fwd")
            if incl_bwd:
                hp.add_hook(save_hook_back, "bwd")

class TransformerICL(nn.Module):
    """Fixed-depth in-context-learning model: embedder -> 2x Attention -> 3x Dense -> Unembed.

    Args:
        embedder: module called as ``embedder(examples, labels, tasks)`` at the start
            of ``forward``; its output feeds the first Attention block.
        config: model config. Reads ``d_emb``, ``num_heads``, ``n_ctx``, ``act_type``,
            ``use_cache``, ``use_ln`` and ``d_vocab``.

    Notes:
        * The block stack is hard-coded; ``config.num_layers`` is NOT consulted
          (it used to be read into an unused local, which suggested otherwise).
        * ``forward`` does ``x = block(x)``, so no residual is added at this level;
          whether ``Attention``/``Dense`` add one internally is defined elsewhere.
        * ``use_ln`` is stored for parity with the config, but no LayerNorm is applied.
    """

    def __init__(self, embedder, config):
        super().__init__()
        d_model = config.d_emb
        num_heads = config.num_heads
        d_head = config.d_emb // config.num_heads
        n_ctx = config.n_ctx
        act_type = config.act_type
        d_vocab = config.d_vocab

        self.cache = {}
        self.use_cache = config.use_cache

        self.embedder = embedder
        self.blocks = nn.ModuleList(
            [
                Attention(d_model, num_heads, d_head, n_ctx),
                Attention(d_model, num_heads, d_head, n_ctx),
                Dense(d_model, d_model, act_type),
                Dense(d_model, d_model, act_type),
                Dense(d_model, d_model, act_type),
            ]
        )
        self.unembed = Unembed(d_vocab, d_model)
        self.use_ln = config.use_ln  # stored only; no LayerNorm module exists in this model

        # Tell every HookPoint its qualified module path (via give_name).
        for name, module in self.named_modules():
            if isinstance(module, HookPoint):
                module.give_name(name)

    def forward(self, examples, labels, tasks):
        """Embed the sequence, apply the five blocks in order, and unembed.

        Returns:
            The output of ``self.unembed`` (presumably (batch, seq, d_vocab) logits;
            confirm against Unembed).
        """
        x = self.embedder(examples, labels, tasks)
        for block in self.blocks:
            x = block(x)
        x = self.unembed(x)
        return x

    def set_use_cache(self, use_cache):
        """Overwrite the stored ``use_cache`` flag."""
        self.use_cache = use_cache

    def hook_points(self):
        """Return all submodules whose qualified name contains "hook"."""
        return [module for name, module in self.named_modules() if "hook" in name]

    def remove_all_hooks(self):
        """Remove forward and backward hooks from every hook point."""
        for hp in self.hook_points():
            hp.remove_hooks("fwd")
            hp.remove_hooks("bwd")

    def cache_all(self, cache, incl_bwd=False):
        """Register hooks that store detached activations (and optionally grads) in ``cache``.

        Args:
            cache: dict mutated in place; keys are hook names, with ``"_grad"``
                appended for backward entries.
            incl_bwd: also register backward hooks when True.
        """
        def save_hook(tensor, name):
            cache[name] = tensor.detach()

        def save_hook_back(tensor, name):
            # Backward hook input is indexed; only element 0 is cached.
            cache[name + "_grad"] = tensor[0].detach()

        for hp in self.hook_points():
            hp.add_hook(save_hook, "fwd")
            if incl_bwd:
                hp.add_hook(save_hook_back, "bwd")

In [None]:
def cal_acc(t, p):
    """Fraction of rows whose argmax prediction in ``p`` equals the target in ``t``.

    Args:
        t: integer targets, one per row of ``p``. The tensor is flattened first, so a
            column vector of shape (B, 1) is handled correctly instead of silently
            broadcasting against the (B,) predictions into a (B, B) comparison.
        p: scores/logits of shape (B, C); the prediction is ``argmax`` over dim 1.

    Returns:
        0-d float tensor in [0, 1] (NaN for an empty batch, as before).

    Raises:
        ValueError: if ``t`` does not contain exactly one target per row of ``p``.
    """
    preds = torch.argmax(p, dim=1)
    targets = t.reshape(-1)
    if targets.shape[0] != preds.shape[0]:
        raise ValueError(
            f"cal_acc: got {targets.shape[0]} targets for {preds.shape[0]} predictions"
        )
    return (targets == preds).float().mean()
def to_gpu_dict(dic, device="cuda:0", non_blocking=False):
    """Return a new dict with every tensor value moved to ``device``.

    Args:
        dic: mapping of field names to values (e.g. one collated DataLoader batch).
        device: target device. Defaults to "cuda:0" as before; pass the model's
            device (e.g. ``config.device``) so inputs and weights agree.
        non_blocking: forwarded to ``Tensor.to``; useful with ``pin_memory=True``
            loaders to overlap host-to-device copies.

    Non-tensor values are passed through unchanged instead of raising AttributeError.
    """
    return {
        k: v.to(device, non_blocking=non_blocking) if torch.is_tensor(v) else v
        for k, v in dic.items()
    }


def to_gpu_dict_list(dic_list, device="cuda:0", non_blocking=False):
    """Apply ``to_gpu_dict`` to each dict and return a 1-D NumPy object array of the results."""
    # dtype=object keeps the dtype consistent (an empty list would otherwise become float64).
    return np.array(
        [to_gpu_dict(dic, device=device, non_blocking=non_blocking) for dic in dic_list],
        dtype=object,
    )

In [None]:
# train
from itertools import islice

import wandb
from tqdm import tqdm

config = MainConfig()
wandb.init(project="icl-minima-multitask", config=asdict(config))

# data -- read every sub-config from the `config` instance (not the MainConfig
# class) so the values used below are exactly the ones logged to wandb above.
traindataconfig = config.traindataconfig
icldataconfig = config.icldataconfig
iwldataconfig = config.iwldataconfig
icl2dataconfig = config.icl2dataconfig
trainconfig = config.trainconfig

# os.cpu_count() may return None; DataLoader requires an int.
NUM_WORKERS = os.cpu_count() or 0


def make_dataloader(dataconfig, dataset, batch_size):
    """Wrap ``dataset`` in a MultiTaskSamplingLoader and batch its ``get_seq`` stream.

    Returns:
        (sampling_loader, dataloader): the MultiTaskSamplingLoader and the torch
        DataLoader yielding collated batches from its sequence generator.
    """
    sampling_loader = MultiTaskSamplingLoader(dataconfig, dataset=dataset)
    seq_dataset = IterDataset(sampling_loader.get_seq)
    dataloader = torch.utils.data.DataLoader(
        seq_dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )
    return sampling_loader, dataloader


# One shared set of class means/labels for every split. Deliberately NOT named
# `Dataset`: that name is torch.utils.data.Dataset, imported earlier.
sampling_dataset = SamplingDataset(traindataconfig)

trainloader, train_dataloader = make_dataloader(traindataconfig, sampling_dataset, trainconfig.batch_size)
# Evaluation splits; not consumed by the loop below yet (see TODO there).
iclloader, icl_dataloader = make_dataloader(icldataconfig, sampling_dataset, trainconfig.batch_size)
iwlloader, iwl_dataloader = make_dataloader(iwldataconfig, sampling_dataset, trainconfig.batch_size)
icl2loader, icl2_dataloader = make_dataloader(icl2dataconfig, sampling_dataset, trainconfig.batch_size)

# model
embedder = MultiTaskInputEmbedderV2(config.modelconfig)
model = TransformerICL(embedder, config.modelconfig)
model.to(config.device)

# optimizer
OPTIMIZERS = {
    "adam": torch.optim.Adam,
    "sgd": torch.optim.SGD,
    "adamw": torch.optim.AdamW,
}
optimizer_name = config.trainconfig.optimizer
if optimizer_name not in OPTIMIZERS:
    # Previously an unknown name fell through and failed later with NameError.
    raise ValueError(
        f"unsupported optimizer {optimizer_name!r}; expected one of {sorted(OPTIMIZERS)}"
    )
optimizer = OPTIMIZERS[optimizer_name](model.parameters(), lr=config.trainconfig.lr)

# loss
criterion = nn.CrossEntropyLoss()
max_steps = config.trainconfig.optimize_step

# TODO: periodic ICL/IWL evaluation is disabled. To re-enable, zip icl/iwl/icl2
# dataloaders with train_dataloader and, under torch.no_grad() with model.eval(),
# compute cal_acc on logits[:, -1, :]. NOTE(review): the earlier draft used the
# batch key "task" and labels[:, -1, -1], whereas training uses "tasks" and
# labels[:, -1] -- confirm which schema the eval loaders actually emit.

# islice caps training at exactly `max_steps` optimizer updates (the old
# `step > optimize_step` check performed one extra) and also covers a finite loader.
progress = tqdm(islice(train_dataloader, max_steps), total=max_steps, desc="train")
for step, data_dict in enumerate(progress):
    model.train()  # per step, so an in-loop model.eval() can be re-enabled safely
    # Move the batch to the model's device; non_blocking pairs with pin_memory=True.
    data_dict = {k: v.to(config.device, non_blocking=True) for k, v in data_dict.items()}
    query_labels = data_dict["labels"][:, -1]

    logits = model(data_dict["examples"], data_dict["labels"], data_dict["tasks"])
    query_logit = logits[:, -1, :]  # prediction at the final (query) position

    optimizer.zero_grad()
    loss = criterion(query_logit, query_labels)
    loss.backward()
    optimizer.step()

    train_acc = cal_acc(query_labels, query_logit)
    # Log plain Python floats rather than (grad-carrying) tensors.
    acc_value = train_acc.item()
    loss_value = loss.item()
    wandb.log({"train/acc": acc_value, "train/loss": loss_value}, step=step)
    progress.set_postfix(acc=f"{acc_value:.3f}", loss=f"{loss_value:.4f}")
