In [4]:
import torch
import torch.nn as nn
import os
import numpy as np
import math

In [5]:
!nvidia-smi

Mon Mar 11 06:37:53 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA RTX A6000               On  | 00000000:21:00.0 Off |                  Off |
| 37%   65C    P8              26W / 200W |   6298MiB / 49140MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               On  | 00000000:41:00.0 Off |  

## Model

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops


class Unembed(nn.Module):
    """Bias-free linear read-out that turns model activations into vocabulary logits."""

    def __init__(self, d_vocab, d_model):
        """Create a (d_vocab, d_model) weight drawn from N(0, 1) and divided by sqrt(d_model)."""
        super().__init__()
        init_scale = np.sqrt(d_model)
        self.W = nn.Parameter(torch.randn(d_vocab, d_model) / init_scale)

    def forward(self, x):
        """Contract the last axis of a 3-D ``x`` with every row of ``W``.

        The einsum fixes ``x`` to three axes with the last one of size d_model;
        the result keeps the first two axes and has d_vocab entries on the last.
        """
        # v: vocabulary entry, m: model dimension, b / s: the two leading axes of x.
        logits = torch.einsum("vm,bsm->bsv", self.W, x)
        return logits


class HookPoint(nn.Module):
    """Identity module whose only purpose is to give activations a place to hang hooks.

    ``forward`` returns its input unchanged; hooks registered through
    ``add_hook`` receive the activation together with this hook point's name.
    """

    def __init__(self):
        super().__init__()
        self.fwd_hooks = []
        self.bwd_hooks = []
        # Set properly by give_name(); initialised here so a hook that fires
        # before naming receives ``name=None`` instead of raising AttributeError.
        self.name = None

    def give_name(self, name):
        # Called by the model at initialisation
        self.name = name

    def add_hook(self, hook, dir="fwd"):
        """Register ``hook`` on this point.

        Args:
          hook: callable invoked as ``hook(activation, name=<hook point name>)``.
            Its return value is handed back to PyTorch's hook machinery as is.
          dir: "fwd" for a forward hook, "bwd" for a backward hook.

        Raises:
          ValueError: if ``dir`` is neither "fwd" nor "bwd".
        """
        # Hook format is fn(activation, hook_name)
        # Change it into PyTorch hook format (this includes input and output,
        # which are the same for a HookPoint)
        def full_hook(module, module_input, module_output):
            return hook(module_output, name=self.name)

        if dir == "fwd":
            handle = self.register_forward_hook(full_hook)
            self.fwd_hooks.append(handle)
        elif dir == "bwd":
            # For a backward hook the third argument PyTorch passes is the
            # gradient w.r.t. the module output rather than an activation.
            # NOTE(review): register_backward_hook is the legacy API; PyTorch
            # recommends register_full_backward_hook. Left untouched because
            # switching changes hook semantics for an identity module.
            handle = self.register_backward_hook(full_hook)
            self.bwd_hooks.append(handle)
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(self, dir="fwd"):
        """Remove the hooks registered in direction ``dir`` ("fwd", "bwd" or "both").

        Raises:
          ValueError: if ``dir`` is not one of the three accepted values.
        """
        if (dir == "fwd") or (dir == "both"):
            for hook in self.fwd_hooks:
                hook.remove()
            self.fwd_hooks = []
        if (dir == "bwd") or (dir == "both"):
            for hook in self.bwd_hooks:
                hook.remove()
            self.bwd_hooks = []
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def forward(self, x):
        return x


class PosEmbed(nn.Module):
    """Learned additive positional embedding."""

    def __init__(self, max_ctx, d_model, weight_scale=1):
        """Create a (max_ctx, d_model) table drawn from N(0, 1) times ``weight_scale``."""
        super().__init__()
        table = torch.randn(max_ctx, d_model) * weight_scale
        self.W_pos = nn.Parameter(table)

    def forward(self, x):
        """Add the first ``x.shape[-2]`` rows of the table to ``x``."""
        seq_len = x.shape[-2]
        positions = self.W_pos[:seq_len]
        return x + positions


class LayerNorm(nn.Module):
    """Layer normalisation over the last axis that an owning model can switch off.

    Args:
      d_model: size of the normalised (last) axis.
      epsilon: added to the standard deviation before dividing.
      model: optional one-element list wrapping the owning model. The owner is
        kept inside a list so that it is not registered as a sub-module of this
        layer. ``owner.use_ln`` is read on every forward pass. When omitted
        (or when the list holds ``None``) normalisation is always applied.
    """

    def __init__(self, d_model, epsilon=1e-4, model=None):
        super().__init__()
        # A fresh list per instance: the previous ``model=[None]`` default was a
        # single list object shared by every LayerNorm built without an owner.
        self.model = [None] if model is None else model
        self.w_ln = nn.Parameter(torch.ones(d_model))
        self.b_ln = nn.Parameter(torch.zeros(d_model))
        self.epsilon = epsilon

    def forward(self, x):
        """Return the normalised ``x``, or ``x`` itself when the owner disables LN."""
        owner = self.model[0]
        if owner is not None and not owner.use_ln:
            return x
        centred = x - x.mean(dim=-1, keepdim=True)
        # Tensor.std defaults to the unbiased (n - 1) estimator; epsilon is added
        # to the standard deviation itself, not to the variance.
        normed = centred / (centred.std(dim=-1, keepdim=True) + self.epsilon)
        return normed * self.w_ln + self.b_ln


# Attention
class Attention(nn.Module):
    """
    Multi-head causal self-attention without biases.

    Index letters used by the einsum strings in this class:
    b : batch
    d : model dimension (last axis of the input)
    i : attention head
    h : per-head dimension
    p : key / value position
    q : query position
    f : flattened (head, per-head) dimension fed to the output projection
    """

    def __init__(self, d_model, num_heads, d_head, n_ctx):
        super().__init__()
        self.W_K = nn.Parameter(
            torch.randn(num_heads, d_head, d_model) / np.sqrt(d_model)
        )
        self.W_Q = nn.Parameter(
            torch.randn(num_heads, d_head, d_model) / np.sqrt(d_model)
        )
        self.W_V = nn.Parameter(
            torch.randn(num_heads, d_head, d_model) / np.sqrt(d_model)
        )
        self.W_O = nn.Parameter(
            torch.randn(d_model, d_head * num_heads) / np.sqrt(d_model)
        )
        # Lower-triangular (n_ctx, n_ctx) matrix of ones: entry [q, p] is 1 when
        # key position p is not after query position q.
        self.register_buffer("mask", torch.tril(torch.ones((n_ctx, n_ctx))))
        self.d_head = d_head

    def forward(self, x):
        """Apply attention to ``x`` (three axes: batch, position, model dim).

        The sequence length ``x.shape[-2]`` must not exceed the ``n_ctx`` the
        mask was built with. Returns a tensor with the same three axes.
        """
        k = torch.einsum("ihd,bpd->biph", self.W_K, x)
        q = torch.einsum("ihd,bpd->biph", self.W_Q, x)
        v = torch.einsum("ihd,bpd->biph", self.W_V, x)
        attn_scores_pre = torch.einsum("biph,biqh->biqp", k, q)
        # Zero the scores above the diagonal, then push those same entries to
        # -1e10 so that they receive (numerically) zero weight in the softmax.
        attn_scores_masked = torch.tril(attn_scores_pre) - 1e10 * (
            1 - self.mask[: x.shape[-2], : x.shape[-2]]
        )
        # The 1/sqrt(d_head) scaling is applied after masking.
        attn_matrix = F.softmax(attn_scores_masked / np.sqrt(self.d_head), dim=-1)
        z = torch.einsum("biph,biqp->biqh", v, attn_matrix)
        z_flat = einops.rearrange(z, "b i q h -> b q (i h)")
        out = torch.einsum("df,bqf->bqd", self.W_O, z_flat)
        return out

    def set_weight_ratio(self, weight_ratio):
        """Multiply W_K, W_Q, W_V and W_O by ``weight_ratio`` in place.

        ``TransformerBlock.set_weight_ratio`` calls this method; without it
        that call fails with AttributeError. Scaling is done in place so the
        ``nn.Parameter`` objects (and any optimizer that references them) stay
        the same.
        """
        with torch.no_grad():
            for weight in (self.W_K, self.W_Q, self.W_V, self.W_O):
                weight.mul_(weight_ratio)


class Dense(nn.Module):
    """Bias-free linear layer ``x @ W.T`` with helpers for rescaling ``W``.

    Args:
      d_in: input feature size.
      d_out: output feature size.
      act_type: accepted for interface compatibility; this layer does not read
        it and applies no activation.
      weight_scale: ``W`` is initialised from N(0, (weight_scale / sqrt(d_in))^2).
    """

    def __init__(self, d_in, d_out, act_type, weight_scale=1):
        super().__init__()
        self.W = nn.Parameter(torch.randn(d_out, d_in))
        torch.nn.init.normal_(self.W, mean=0, std=weight_scale / np.sqrt(d_in))

    def set_weight_ratio(self, weight_ratio):
        """Multiply ``W`` by ``weight_ratio`` in place.

        The parameter object is kept (rather than replaced by a new
        ``nn.Parameter``) so an optimizer created earlier keeps updating it.
        """
        with torch.no_grad():
            self.W.mul_(weight_ratio)

    def set_weight_ratio_l2(self, weight_ratio):
        """Multiply ``W`` by ``sqrt(weight_ratio)`` in place.

        ``weight_ratio`` may be a Python number or a tensor; ``torch.sqrt``
        alone rejects plain numbers, so the value is converted first.
        """
        ratio = torch.as_tensor(weight_ratio, dtype=self.W.dtype, device=self.W.device)
        with torch.no_grad():
            self.W.mul_(torch.sqrt(ratio))

    def forward(self, x):
        return x @ self.W.T


# for Transformer
class MLPBlock(nn.Module):
    """
    Two-layer feed-forward block: linear -> ReLU / GeLU -> linear.

    Index letters used by the einsum strings in this class:
    b : batch
    p : sequence position
    d : model dimension
    m : hidden (MLP) dimension
    """

    def __init__(self, d_model, d_mlp, act_type):
        super().__init__()
        # Both linear maps keep a bias (initialised to zero); no layer norm here.
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_model) / np.sqrt(d_model))
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        self.W_out = nn.Parameter(torch.randn(d_model, d_mlp) / np.sqrt(d_model))
        self.b_out = nn.Parameter(torch.zeros(d_model))
        self.act_type = act_type
        assert act_type in ["ReLU", "GeLU"]

    def forward(self, x):
        """Map ``x`` (batch, position, d_model) to a tensor of the same three axes."""
        x = torch.einsum("md,bpd->bpm", self.W_in, x) + self.b_in
        if self.act_type == "ReLU":
            x = F.relu(x)
        elif self.act_type == "GeLU":
            x = F.gelu(x)
        x = torch.einsum("dm,bpm->bpd", self.W_out, x) + self.b_out
        return x

    def set_weight_ratio(self, weight_ratio):
        """Multiply W_in and W_out by ``weight_ratio`` in place (biases untouched).

        Scaling in place keeps the existing ``nn.Parameter`` objects, so an
        optimizer built before this call still references the live weights.
        """
        with torch.no_grad():
            self.W_in.mul_(weight_ratio)
            self.W_out.mul_(weight_ratio)


class TransformerBlock(nn.Module):
    """
    One attention + MLP block with hook points around each stage.

    ``model`` is the one-element list wrapping the owning model; it is stored
    and forwarded to the block's LayerNorm.
    """

    _HOOK_NAMES = (
        "hook_attn_out",
        "hook_mlp_out",
        "hook_resid_pre",
        "hook_resid_mid",
        "hook_resid_post",
    )

    def __init__(self, d_model, d_mlp, d_head, num_heads, n_ctx, act_type, model):
        super().__init__()
        self.model = model
        self.attn = Attention(d_model, num_heads, d_head, n_ctx)
        self.mlp = MLPBlock(d_model, d_mlp, act_type)
        self.layer_norm = LayerNorm(d_model, model=self.model)
        # Registered in a fixed order so sub-module ordering stays stable.
        for hook_name in self._HOOK_NAMES:
            setattr(self, hook_name, HookPoint())

    def forward(self, x):
        """Run attention, layer norm and the MLP, passing each stage through its hook."""
        resid_pre = self.hook_resid_pre(x)
        attn_out = self.hook_attn_out(self.attn(resid_pre))
        # The residual adds the block input ``x`` itself, not the value
        # returned by hook_resid_pre.
        resid_mid = self.hook_resid_mid(x + attn_out)
        normed = self.layer_norm(resid_mid)
        mlp_out = self.hook_mlp_out(self.mlp(normed))
        # The MLP residual is taken from the normalised activations.
        resid_post = self.hook_resid_post(normed + mlp_out)
        return resid_post

    def set_weight_ratio(self, weight_ratio):
        """Forward ``weight_ratio`` to the attention and MLP sub-modules.

        NOTE(review): relies on ``self.attn`` exposing ``set_weight_ratio`` - confirm.
        """
        for part in (self.attn, self.mlp):
            part.set_weight_ratio(weight_ratio)


class InputEmbedder(nn.Module):
    """Interleave example vectors and label embeddings into one token sequence.

    Every token is the concatenation of a content part (``d_emb - p_dim``
    values) and a one-hot position code (``p_dim`` values).
    """

    def __init__(self, conf):
        """Initialize the input embedder.

        Args:
          conf: configuration object providing
            d_vocab: number of distinct labels (rows of the label table),
            d_emb: size of each output token,
            p_dim: width of the one-hot position code,
            n_ctx: stored on the instance; not used to size the position codes.
        """
        super(InputEmbedder, self).__init__()
        self.num_labels = conf.d_vocab
        self.emb_dim = conf.d_emb
        self.p_dim = conf.p_dim
        self.emb_dim_content = self.emb_dim - self.p_dim
        self.n_ctx = conf.n_ctx

        # Not used by forward(); kept so the module's parameters and the
        # random-number stream consumed at construction stay unchanged.
        self.Emb = nn.Linear(self.emb_dim, self.emb_dim)

        self.label_embs = nn.Parameter(
            torch.randn(self.num_labels, self.emb_dim_content) / np.sqrt(self.emb_dim_content)
        )

    def _position_codes(self, first, count, batch_size, like):
        """One-hot codes for positions first, first + 2, ... (``count`` of them).

        Built directly on ``like``'s device with ``like``'s dtype and expanded
        (without copying) to ``batch_size`` rows.
        """
        positions = torch.arange(first, first + 2 * count, 2, device=like.device)
        codes = F.one_hot(positions, num_classes=self.p_dim).to(like.dtype)
        return codes.unsqueeze(0).expand(batch_size, -1, -1)

    def forward(self, examples, labels, is_training=True):
        """Call to the input embedder.

        Args:
          examples: tensor of shape [batch_size, seq_len, d_emb - p_dim].
          labels: integer tensor of shape [batch_size, seq_len]; the last
            label of each row is not embedded into the output.
          is_training: unused.

        Returns:
          Tensor of shape [batch_size, 2 * seq_len - 1, d_emb]: examples at
          even positions, the labels of all but the last example at odd
          positions.
        """
        B, SS, _ = examples.shape
        # Position codes follow the actual sequence length, so inputs shorter
        # than n_ctx work too. Examples take positions 0, 2, 4, ...
        h_example = torch.cat(
            [examples, self._position_codes(0, SS, B, examples)], dim=2
        )

        # Embed the labels; they take positions 1, 3, 5, ...
        h_label = self.label_embs[labels]
        h_label = torch.cat(
            [h_label, self._position_codes(1, labels.shape[1], B, h_label)], dim=2
        )

        hh = torch.empty(
            (B, SS * 2 - 1, h_example.shape[2]),
            dtype=h_example.dtype,
            device=h_example.device,
        )
        hh[:, 0::2] = h_example
        # Drop the final label: the sequence ends on the last example token.
        hh[:, 1::2] = h_label[:, :-1]

        return hh


class Transformer(nn.Module):
    """Embedder, a stack of TransformerBlocks, then an un-embedding to vocabulary logits."""

    def __init__(self, embedder, config):
        """Build the model from ``config``.

        Reads num_layers, d_emb, num_heads, n_ctx, act_type, use_cache, use_ln
        and d_vocab. The MLP width is 4 * d_emb and the head size is
        d_emb // num_heads.
        """
        super().__init__()
        depth = config.num_layers
        width = config.d_emb
        hidden = config.d_emb * 4
        head_size = config.d_emb // config.num_heads
        heads = config.num_heads
        context = config.n_ctx
        activation = config.act_type
        cache_flag = config.use_cache
        ln_flag = config.use_ln
        self.cache = {}
        self.use_cache = cache_flag
        vocab = config.d_vocab

        self.embedder = embedder
        layers = []
        for _ in range(depth):
            # ``model=[self]``: the block receives this model wrapped in a list.
            layers.append(
                TransformerBlock(
                    width, hidden, head_size, heads, context, activation, model=[self]
                )
            )
        self.blocks = nn.ModuleList(layers)
        self.unembed = Unembed(vocab, width)
        self.use_ln = ln_flag

        # Tell every HookPoint its fully qualified module name.
        for module_name, module in self.named_modules():
            if type(module) == HookPoint:
                module.give_name(module_name)

    def forward(self, x, labels):
        """Embed ``(x, labels)``, run every block in order and return the logits."""
        hidden_states = self.embedder(x, labels)
        for layer in self.blocks:
            hidden_states = layer(hidden_states)
        return self.unembed(hidden_states)

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        """Return every sub-module whose qualified name contains "hook"."""
        found = []
        for module_name, module in self.named_modules():
            if "hook" in module_name:
                found.append(module)
        return found

    def remove_all_hooks(self):
        """Remove all forward and backward hooks from every hook point."""
        for point in self.hook_points():
            point.remove_hooks("fwd")
            point.remove_hooks("bwd")

    def cache_all(self, cache, incl_bwd=False):
        """Register hooks that copy activations into the ``cache`` dict.

        Forward activations are stored detached under the hook point's name.
        With ``incl_bwd`` the first element of what the backward hook receives
        is stored detached under ``name + "_grad"``.
        """
        def store_activation(tensor, name):
            cache[name] = tensor.detach()

        def store_gradient(tensor, name):
            cache[name + "_grad"] = tensor[0].detach()

        for point in self.hook_points():
            point.add_hook(store_activation, "fwd")
            if incl_bwd:
                point.add_hook(store_gradient, "bwd")

class TransformerICL(nn.Module):
    """Embedder followed by two Attention layers, three Dense layers and an un-embedding.

    The five layers are applied one after another with nothing in between:
    this class adds no residual connections, and ``Dense`` applies no
    activation even though ``act_type`` is passed to it.
    NOTE(review): confirm that the purely sequential, activation-free stack is intended.
    """

    def __init__(self, embedder, config):
        """Build the model from ``config``.

        Reads the same fields as ``Transformer`` (num_layers, d_emb, num_heads,
        n_ctx, act_type, use_cache, use_ln, d_vocab); the layer stack itself is
        fixed and does not depend on ``num_layers``.
        """
        super().__init__()
        depth = config.num_layers
        width = config.d_emb
        hidden = config.d_emb * 4
        head_size = config.d_emb // config.num_heads
        heads = config.num_heads
        context = config.n_ctx
        activation = config.act_type
        cache_flag = config.use_cache
        ln_flag = config.use_ln
        self.cache = {}
        self.use_cache = cache_flag
        vocab = config.d_vocab

        self.embedder = embedder
        layers = [Attention(width, heads, head_size, context) for _ in range(2)]
        layers += [Dense(width, width, activation) for _ in range(3)]
        self.blocks = nn.ModuleList(layers)
        self.unembed = Unembed(vocab, width)
        self.use_ln = ln_flag

        # Tell every HookPoint (if any exist among the sub-modules) its name.
        for module_name, module in self.named_modules():
            if type(module) == HookPoint:
                module.give_name(module_name)

    def forward(self, x, labels):
        """Embed ``(x, labels)``, apply every layer in order and return the logits."""
        hidden_states = self.embedder(x, labels)
        for layer in self.blocks:
            hidden_states = layer(hidden_states)
        return self.unembed(hidden_states)

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def hook_points(self):
        """Return every sub-module whose qualified name contains "hook"."""
        found = []
        for module_name, module in self.named_modules():
            if "hook" in module_name:
                found.append(module)
        return found

    def remove_all_hooks(self):
        """Remove all forward and backward hooks from every hook point."""
        for point in self.hook_points():
            point.remove_hooks("fwd")
            point.remove_hooks("bwd")

    def cache_all(self, cache, incl_bwd=False):
        """Register hooks that copy activations into the ``cache`` dict.

        Forward activations are stored detached under the hook point's name.
        With ``incl_bwd`` the first element of what the backward hook receives
        is stored detached under ``name + "_grad"``.
        """
        def store_activation(tensor, name):
            cache[name] = tensor.detach()

        def store_gradient(tensor, name):
            cache[name + "_grad"] = tensor[0].detach()

        for point in self.hook_points():
            point.add_hook(store_activation, "fwd")
            if incl_bwd:
                point.add_hook(store_gradient, "bwd")

In [20]:
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torch
import numpy as np

class SamplingDataset(object):
  """Fixed pool of classes: one mean vector and one label per class."""

  def __init__(self,conf):
    """Read num_classes, dim and num_labels from ``conf`` and draw the pool once."""
    self.num_classes = conf.num_classes
    self.dim = conf.dim
    self.num_labels = conf.num_labels
    pool = self._get_data()
    self.mu, self.labels = pool

  def _get_data(self):
    """Return ``(mu, labels)``.

    ``mu`` has shape (num_classes, dim) with entries from N(0, 1/dim);
    ``labels`` has shape (num_classes, 1) with integers in [0, num_labels).
    """
    spread = math.sqrt(1/self.dim)
    class_means = torch.normal(mean=0, std=spread, size=(self.num_classes,self.dim))
    class_labels = torch.randint(self.num_labels, size=(self.num_classes,1))
    return class_means, class_labels

class SamplingLoader(DataLoader):
  """Endless generator of few-shot sequences drawn from a class pool.

  Each yielded item is a dict with
    "examples": float32 tensor (num_seq + 1, dim) - context items, then the query,
    "labels":   1-D label tensor of length num_seq + 1 (last entry = query label),
    "classes":  1-D tensor with the class index of every item.

  ``conf.data_type`` selects the sampling scheme:
    "bursty":     with probability p_icl a sequence of freshly drawn classes,
                  otherwise with probability p_bursty a few-shot sequence of
                  pool classes, otherwise a rank-frequency sequence;
    "no_support": rank-frequency sequences only;
    "holdout":    sequences of freshly drawn classes only;
    "flip":       few-shot sequences of pool classes whose labels are shifted
                  by one modulo num_labels.

  NOTE: ``DataLoader.__init__`` is deliberately not called; only ``get_seq``
  and the helper methods defined here are meant to be used.
  """

  def __init__(self,conf, dataset):
    self.dataset = dataset
    self.mu, self.labels = self.dataset.mu, self.dataset.labels
    self.data_type = conf.data_type
    self.num_seq = conf.num_seq
    self.alpha = conf.alpha
    self.num_classes = conf.num_classes
    self.num_labels = conf.num_labels
    self.ways = conf.ways
    self.p_bursty = conf.p_bursty
    self.p_icl = conf.p_icl
    self.eps = conf.eps
    self.dim = conf.dim
    if self.ways != 0:
      assert self.num_seq % self.ways == 0
    if self.ways == 0:
      self.p_bursty = 0
    # Sampling probability of class k is proportional to 1 / (k + 1) ** alpha.
    prob = np.array([1/((k+1)**self.alpha) for k in range(self.num_classes)])
    self.prob = prob/prob.sum()

  def get_seq(self):
    """Yield sequences forever according to ``self.data_type``.

    Raises:
      ValueError: on the first ``next()`` if ``data_type`` is not one of
        "bursty", "no_support", "holdout" or "flip" (this used to spin forever).
    """
    while True:
      if self.data_type == "bursty":
        if self.p_icl > np.random.rand():
          yield self._sample_novel_few_shot()
        elif self.p_bursty > np.random.rand():
          yield self._sample_pool_few_shot(flip=False)
        else:
          yield self._sample_rank_frequency()
      elif self.data_type == "no_support":
        yield self._sample_rank_frequency()
      elif self.data_type == "holdout":
        yield self._sample_novel_few_shot()
      elif self.data_type == "flip":
        yield self._sample_pool_few_shot(flip=True)
      else:
        raise ValueError(f"Invalid data_type {self.data_type}")

  def _assemble(self, x, labels, classes, query_x, query_label, query_class):
    """Append the query to the context and build the yielded dict."""
    return {
        "examples": torch.cat([x, query_x]).to(torch.float32),
        "labels": torch.cat([labels.flatten(), query_label.flatten()]),
        "classes": torch.cat([torch.from_numpy(classes).flatten(), torch.from_numpy(query_class).flatten()]),
    }

  def _sample_novel_few_shot(self):
    """Few-shot sequence over freshly drawn classes (numbered 0..n-1 locally)."""
    num_few_shot_class = self.num_seq//self.ways
    mus, labels = self._get_novel_class_seq(num_few_shot_class)
    mus = torch.repeat_interleave(mus, self.ways, dim=0) # expand ways
    labels = torch.repeat_interleave(labels, self.ways, dim=0) # expand ways
    classes = np.repeat(np.arange(num_few_shot_class), self.ways)
    # add noise
    x = self.add_noise(mus)
    # permutation shuffle
    ordering = np.random.permutation(self.num_seq)
    x = x[ordering]
    labels = labels[ordering]
    classes = classes[ordering]
    # The means are shuffled with everything else so that one index refers to
    # the same item in all four arrays. Previously the query mean came from
    # the unshuffled array and did not match the query's label and class.
    mus = mus[ordering]
    # select the query among the context items
    query_idx = np.random.choice(len(classes), 1)
    query_x = self.add_noise(mus[query_idx])
    return self._assemble(x, labels, classes, query_x, labels[query_idx], classes[query_idx])

  def _sample_pool_few_shot(self, flip=False):
    """Few-shot sequence over classes from the pool; ``flip`` shifts labels by one."""
    num_few_shot_class = self.num_seq//self.ways
    few_shot_class = np.random.choice(self.num_classes, num_few_shot_class, replace=False)
    classes = np.repeat(few_shot_class, self.ways) # expand ways
    mus = self.mu[classes]
    labels = self._maybe_flip(self.labels[classes], flip)
    # add noise
    x = self.add_noise(mus)
    # permutation shuffle
    ordering = np.random.permutation(self.num_seq)
    x = x[ordering]
    labels = labels[ordering]
    classes = classes[ordering]
    # select the query class among the classes present in the context
    query_class = np.random.choice(few_shot_class, 1)
    query_label = self._maybe_flip(self.labels[query_class], flip)
    query_x = self.add_noise(self.mu[query_class])
    return self._assemble(x, labels, classes, query_x, query_label, query_class)

  def _maybe_flip(self, labels, flip):
    """Return ``labels`` shifted by one modulo num_labels when ``flip`` is set."""
    if flip:
      return (labels + 1) % self.num_labels
    return labels

  def _sample_rank_frequency(self):
    """num_seq + 1 items whose classes are drawn independently from ``self.prob``."""
    classes = np.random.choice(self.num_classes, self.num_seq+1, p=self.prob)
    mus = self.mu[classes]
    labels = self.labels[classes]
    x = self.add_noise(mus)
    # permutation shuffle
    ordering = np.random.permutation(self.num_seq+1)
    return {
        "examples": x[ordering].to(torch.float32),
        "labels": labels[ordering].flatten(),
        "classes": torch.from_numpy(classes[ordering]),
    }

  def add_noise(self, x):
    """Return ``(x + eps * noise) / sqrt(1 + eps**2)`` with noise from N(0, 1/dim)."""
    noise = torch.normal(mean=0, std=math.sqrt(1/self.dim), size=(x.shape))
    return (x + self.eps*noise)/math.sqrt(1+self.eps**2)

  def _get_novel_class_seq(self,num_class):
    """Draw ``num_class`` new class means (num_class, dim) and labels (num_class, 1)."""
    mu = torch.normal(mean=0, std=math.sqrt(1/self.dim), size=(num_class,self.dim))
    labels = torch.randint(self.num_labels, size=(num_class,1))
    return mu, labels

class IterDataset(IterableDataset):
    """IterableDataset adapter around a zero-argument callable that returns an iterator."""

    def __init__(self, generator):
        # Stored uncalled; it is invoked anew every time iteration starts.
        self.generator = generator

    def __iter__(self):
        """Call the stored callable and hand back whatever iterator it returns."""
        make_iterator = self.generator
        return make_iterator()



In [21]:
from dataclasses import dataclass, asdict

# The nested config classes are declared with ``unsafe_hash=True`` because
# MainConfig uses instances of them as field defaults. From Python 3.11 on,
# dataclasses rejects un-hashable defaults with "mutable default ... is not
# allowed". Making the classes hashable keeps the defaults as class
# attributes (``MainConfig.traindataconfig`` keeps working) and keeps every
# field assignable.
# NOTE: the default instances are shared by all MainConfig objects; do not use
# a config object as a dict key and then mutate it.

@dataclass(unsafe_hash=True)
class TransformerConfig:
  """Model hyper-parameters."""
  num_layers: int = 2
  d_vocab: int = 32 # same as num_labels
  # d_model, d_mlp and d_head are not read by Transformer / TransformerICL,
  # which derive their sizes from d_emb and num_heads.
  d_model: int = 128 
  d_mlp: int = 128
  d_head: int = 128
  num_heads: int = 1
  n_ctx: int = int(8*2+1)
  act_type: str = "ReLU"
  use_cache: bool = False
  use_ln: bool = True
  p_dim: int = 65 # width of the one-hot position code in InputEmbedder
  d_emb: int = 128 # token size; InputEmbedder uses d_emb - p_dim for content

@dataclass(unsafe_hash=True)
class TrainDataConfig:
  """Settings read by SamplingDataset and SamplingLoader."""
  num_classes: int = 512
  dim: int = 63
  num_labels: int = 32
  eps: float = 0.1
  alpha: float = 0
  ways: int = 2
  num_seq: int = 8
  p_bursty: float = 0.75
  p_icl: float = 0
  data_type: str = "bursty" # bursty, holdout, no_support, flip

@dataclass(unsafe_hash=True)
class IWLDataConfig(TrainDataConfig):
  """TrainDataConfig with data_type "no_support"."""
  data_type: str = "no_support" # bursty, holdout, no_support, flip

@dataclass(unsafe_hash=True)
class ICLDataConfig(TrainDataConfig):
  """TrainDataConfig with data_type "holdout"."""
  data_type: str = "holdout" # bursty, holdout, no_support, flip


@dataclass(unsafe_hash=True)
class ICL2DataConfig(TrainDataConfig):
  """TrainDataConfig with data_type "flip"."""
  data_type: str = "flip" # bursty, holdout, no_support, flip
  
@dataclass(unsafe_hash=True)
class TrainConfig:
  """Optimisation settings."""
  batch_size: int = 1
  optimize_step: int = int(2e5)
  lr: float = 0.01
  optimizer: str = "sgd" # adam, sgd, adamw

@dataclass
class MainConfig:
  """Top-level bundle of all configs plus the device string."""
  traindataconfig : TrainDataConfig = TrainDataConfig()
  icldataconfig: ICLDataConfig = ICLDataConfig()
  iwldataconfig: IWLDataConfig = IWLDataConfig()
  icl2dataconfig: ICL2DataConfig = ICL2DataConfig()
  modelconfig: TransformerConfig = TransformerConfig()
  trainconfig: TrainConfig = TrainConfig()
  device: str = "cuda:1"
# define config

In [22]:
def cal_acc(t,p):
    """Fraction of rows of ``p`` whose arg-max along dim 1 equals the target in ``t``.

    Returns a 0-dim tensor: the number of matches divided by ``p.shape[0]``.
    """
    predicted = p.argmax(dim=1)
    num_correct = (t == predicted).sum()
    return num_correct / p.shape[0]
def to_gpu_dict(dic, device="cuda:1"):
    """Return a new dict with every value moved to ``device`` via ``.to``.

    Args:
      dic: mapping whose values all support ``.to(device)``.
      device: target device; defaults to "cuda:1", the previously hard-coded value.
    """
    return {k: v.to(device) for k, v in dic.items()}

In [25]:
# Read the sub-configs from a MainConfig instance rather than from class
# attributes, so this cell keeps working however the defaults are declared.
config = MainConfig()
traindataconfig = config.traindataconfig
icldataconfig = config.icldataconfig
iwldataconfig = config.iwldataconfig
icl2dataconfig = config.icl2dataconfig
trainconfig = config.trainconfig

# One class pool shared by all four streams. Deliberately not named `Dataset`:
# that name is the torch.utils.data class imported above.
sampling_dataset = SamplingDataset(traindataconfig)


def make_stream(data_conf, dataset, batch_size):
    """Build the sampler, its generator function, the IterDataset and the DataLoader."""
    loader = SamplingLoader(data_conf, dataset=dataset)
    seq_generator = loader.get_seq
    iter_dataset = IterDataset(seq_generator)
    dataloader = torch.utils.data.DataLoader(
        iter_dataset,
        batch_size=batch_size,
        pin_memory=True,
        # os.cpu_count() may return None; fall back to loading in-process.
        num_workers=os.cpu_count() or 0,
    )
    return loader, seq_generator, iter_dataset, dataloader


def preview_batches(tag, dataloader, num_batches):
    """Print the classes and labels of the first ``num_batches`` batches."""
    for shown, data in enumerate(dataloader, start=1):
        print(f"{tag}_class", data["classes"])
        print(f"{tag}_label", data["labels"])
        if shown >= num_batches:
            break


trainloader, train_seq_generator, train_dataset, train_dataloader = make_stream(
    traindataconfig, sampling_dataset, trainconfig.batch_size
)
iclloader, icl_seq_generator, icl_dataset, icl_dataloader = make_stream(
    icldataconfig, sampling_dataset, trainconfig.batch_size
)
iwlloader, iwl_seq_generator, iwl_dataset, iwl_dataloader = make_stream(
    iwldataconfig, sampling_dataset, trainconfig.batch_size
)
icl2loader, icl2_seq_generator, icl2_dataset, icl2_dataloader = make_stream(
    icl2dataconfig, sampling_dataset, trainconfig.batch_size
)

# Same order and batch counts as before: three batches each, six for "icl".
preview_batches("train", train_dataloader, 3)
preview_batches("iwl", iwl_dataloader, 3)
preview_batches("icl2", icl2_dataloader, 3)
preview_batches("icl", icl_dataloader, 6)


train_class tensor([[245,  84, 251,  84, 245, 251,  80,  80,  84]])
train_label tensor([[ 7,  2, 26,  2,  7, 26, 21, 21,  2]])
train_class tensor([[ 50, 417, 493,  50,  12, 417,  12, 493, 493]])
train_label tensor([[14, 14,  3, 14, 28, 14, 28,  3,  3]])
train_class tensor([[413, 284, 284, 166, 166, 413,   6,   6, 284]])
train_label tensor([[ 2, 31, 31, 21, 21,  2, 30, 30, 31]])
iwl_class tensor([[237, 243,  91, 172, 133, 218, 206, 128, 315]])
iwl_label tensor([[18,  0, 12, 10, 25, 19, 26,  7, 27]])
iwl_class tensor([[ 69, 334, 407, 370,  97, 459,  51, 106,  10]])
iwl_label tensor([[ 5,  9, 14,  5, 18, 19,  0, 12, 27]])
iwl_class tensor([[337,  72,  40, 444,  39, 338,  46, 425, 183]])
iwl_label tensor([[28, 10,  1,  9, 28, 19, 23, 24, 28]])
icl2_class tensor([[333, 270, 333,  23, 471,  23, 471, 270,  23]])
icl2_label tensor([[19,  4, 19,  3, 16,  3, 16,  4,  3]])
icl2_class tensor([[144, 497, 497,  70, 127,  70, 144, 127,  70]])
icl2_label tensor([[ 1, 27, 27,  1, 17,  1,  1, 17,  1]])


In [8]:
# train
import wandb
from tqdm import tqdm
config = MainConfig()
wandb.init(project="icl-minima", config=asdict(config))

# Every setting used below comes from `config`, so this cell does not depend on
# variables left behind by another cell.
traindataconfig = config.traindataconfig
icldataconfig = config.icldataconfig
iwldataconfig = config.iwldataconfig
icl2dataconfig = config.icl2dataconfig
trainconfig = config.trainconfig
device = config.device

# data
# Not named `Dataset`, which is the torch.utils.data class imported above.
sampling_dataset = SamplingDataset(traindataconfig)


def make_stream(data_conf, dataset, batch_size):
  """Build the sampler, its generator function, the IterDataset and the DataLoader."""
  loader = SamplingLoader(data_conf, dataset=dataset)
  seq_generator = loader.get_seq
  iter_dataset = IterDataset(seq_generator)
  dataloader = torch.utils.data.DataLoader(
      iter_dataset,
      batch_size=batch_size,
      pin_memory=True,
      # os.cpu_count() may return None; fall back to loading in-process.
      num_workers=os.cpu_count() or 0,
  )
  return loader, seq_generator, iter_dataset, dataloader


trainloader, train_seq_generator, train_dataset, train_dataloader = make_stream(
    traindataconfig, sampling_dataset, trainconfig.batch_size
)
iclloader, icl_seq_generator, icl_dataset, icl_dataloader = make_stream(
    icldataconfig, sampling_dataset, trainconfig.batch_size
)
iwlloader, iwl_seq_generator, iwl_dataset, iwl_dataloader = make_stream(
    iwldataconfig, sampling_dataset, trainconfig.batch_size
)
icl2loader, icl2_seq_generator, icl2_dataset, icl2_dataloader = make_stream(
    icl2dataconfig, sampling_dataset, trainconfig.batch_size
)

# model
embedder = InputEmbedder(config.modelconfig)
model = TransformerICL(embedder, config.modelconfig)
model.to(device)

# optimizer
optimizer_classes = {
    "adam": torch.optim.Adam,
    "sgd": torch.optim.SGD,
    "adamw": torch.optim.AdamW,
}
if trainconfig.optimizer not in optimizer_classes:
  # Previously an unknown name left `optimizer` undefined until first use.
  raise ValueError(f"Unknown optimizer {trainconfig.optimizer!r}")
optimizer = optimizer_classes[trainconfig.optimizer](model.parameters(), lr=trainconfig.lr)


def to_device(batch):
  """Move every tensor of a batch dict to the device the model lives on."""
  return {k: v.to(device) for k, v in batch.items()}


def query_accuracy(batch):
  """Accuracy of the model's prediction at the last position of each sequence."""
  logits = model(batch["examples"], batch["labels"])
  return cal_acc(batch["labels"][:, -1], logits[:, -1, :])


# loss
criterion = nn.CrossEntropyLoss()
step = 0
for (data_dict, icl_data_dict, iwl_data_dict, icl2_data_dict) in zip(train_dataloader, icl_dataloader, iwl_dataloader, icl2_dataloader):
  model.train()
  data_dict = to_device(data_dict)
  icl_data_dict = to_device(icl_data_dict)
  iwl_data_dict = to_device(iwl_data_dict)
  icl2_data_dict = to_device(icl2_data_dict)

  # Only the prediction at the final (query) position is trained and scored.
  optimizer.zero_grad()
  logits = model(data_dict["examples"], data_dict["labels"])
  query_logit = logits[:,-1,:]
  query_target = data_dict["labels"][:,-1]
  loss = criterion(query_logit, query_target)
  loss.backward()
  optimizer.step()
  train_acc = cal_acc(query_target, query_logit)

  model.eval()
  with torch.no_grad():
    icl_acc = query_accuracy(icl_data_dict)
    iwl_acc = query_accuracy(iwl_data_dict)
    icl2_acc = query_accuracy(icl2_data_dict)

  # One log call per step, with plain floats instead of tensors.
  wandb.log(
      {
          "train/acc": train_acc.item(),
          "train/loss": loss.item(),
          "valid/icl_acc": icl_acc.item(),
          "valid/iwl_acc": iwl_acc.item(),
          "valid/icl2_acc": icl2_acc.item(),
      },
      step=step,
  )

  print("\r",step, train_acc.item(), iwl_acc.item(), icl_acc.item(), icl2_acc.item(), end="")
  step+=1
  # Stop after exactly optimize_step updates (`>` allowed one extra update).
  if step >= trainconfig.optimize_step:
    break


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgouki[0m. Use [1m`wandb login --relogin`[0m to force relogin


 197668 1.0 1.0 0.0234375 0.025 0.0 0.00007812539 0.2421875 0.015625 0.1640625 0.2890625 0.0703125 0.21875 0.1796875 0.0234375 0.2734375 0.2421875 0.1953125 0.015625 0.2578125 0.3046875 0.2421875 0.0390625 0.3515625 0.28906250.1953125 0.03125 0.265625 0.25 0.265625 0.2421875 0.1796875 0.0234375 0.28125 0.24218750.0234375 0.265625 0.31250.03125 0.296875 0.2890625 0.3046875 0.28125 0.242187526613 0.2578125 0.0625 0.296875 0.250.3046875 0.1796875 0.0390625 0.25 0.2890625 0.2265625 0.03125 0.3125 0.265625 0.2578125 0.078125 0.328125 0.2734375 0.2890625 0.343750.3125 0.03125 0.3671875 0.26562539063 0.265625 0.0546875 0.265625 0.2890625 0.2968750.0625 0.2421875 0.23437541675 0.375 0.0390625 0.3125 0.3125 0.6015625 0.2109375 0.1796875 0.1718750.1796875 0.1875 0.226562544411 0.6640625 0.2421875 0.1875 0.1718750.171875 0.375 0.1484375 0.078125 0.3359375 0.140625 0.09375 0.5859375 0.0546875 0.101562550913 0.8671875 0.734375 0.0390625 0.031250.859375 0.03125 0.015625 0.96875 0.0390625 0.0156250.9

KeyboardInterrupt: 