In [107]:
from src.tokenizers.character_level.character_level import CharacterLevelTokenizer as Tk
from src.datasets.dataset_helper import make_collate_fn
from torch.utils.data import DataLoader
from src.datasets.shakespeare.shakespeare import ShakespeareDataset as Ds
from src.schedule.vanilla import VanillaScheduler as Schedule

In [108]:
tk = Tk()
ds = Ds(tk, max_length=10, train=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [109]:
s = Schedule(beta_1 = 20.4054 / tk.vocab_size())

In [110]:
cfn = make_collate_fn(s, tk.vocab_size())

In [111]:
dl = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=cfn)

In [112]:
result = next(iter(dl))

In [113]:
result

{'ground_truth': tensor([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        

In [114]:
s_output = result["scheduler_output"]

In [115]:
s_output["beta"].unsqueeze(-1).shape

torch.Size([1, 20, 1])

In [116]:
model_input = result["model_input"]

In [117]:
model_input.shape

torch.Size([1, 20, 35])

In [118]:
mask = result["mask"]
t = result["t"]
doc_id = result["document_id"]

In [119]:
mask, t, doc_id

(tensor([[False, False, False,  True,  True, False, False, False, False, False,
           True,  True, False, False, False,  True,  True,  True,  True,  True]]),
 tensor([[0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200,
          0.1200, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069,
          0.0069, 0.0069]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))

In [120]:
from torch import nn
import torch
import math

In [121]:
class SinusoidalTimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device, dtype=t.dtype) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

In [122]:
embedding = nn.Parameter(torch.randn(tk.vocab_size(), 4))

In [123]:
x = model_input @ embedding

In [124]:
x, mask, t, doc_id

(tensor([[[-0.4772, -1.2780,  1.7580, -1.5087],
          [-2.0157, -1.9540,  0.9939,  0.5318],
          [ 0.7912, -1.1400,  1.6959, -1.4159],
          [-0.0956, -0.2108,  0.3923, -0.1514],
          [-0.2087, -0.3106,  0.4243, -0.0647],
          [-0.2620,  0.5741, -0.2553,  1.4251],
          [-1.0236, -0.1750,  1.7195, -0.2238],
          [-0.0156,  0.7904, -1.1333, -0.8141],
          [-0.0536,  0.3852, -1.1051, -0.5567],
          [-2.0157, -1.9540,  0.9939,  0.5318],
          [-0.1455, -0.3069,  0.3221, -0.0270],
          [-0.1481, -0.3050,  0.3194, -0.0184],
          [ 0.8501,  0.1454, -0.1745,  1.7904],
          [-0.0536,  0.3852, -1.1051, -0.5567],
          [-0.1069, -0.1281, -0.2889,  1.5231],
          [-0.1497, -0.3087,  0.3302, -0.0164],
          [-0.1462, -0.3073,  0.3258, -0.0296],
          [-0.1418, -0.3082,  0.3310, -0.0164],
          [-0.1509, -0.3096,  0.3265, -0.0177],
          [-0.1483, -0.3106,  0.3282, -0.0240]]], grad_fn=<UnsafeViewBackward0>),
 tenso

In [125]:
temb = SinusoidalTimeEmbedding(4)

In [126]:
device = t.device
half_dim = 4 // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device, dtype=t.dtype) * -emb)

In [127]:
emb, t, doc_id

(tensor([1.0000e+00, 1.0000e-04]),
 tensor([[0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200, 0.1200,
          0.1200, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069,
          0.0069, 0.0069]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))

In [128]:
emb.shape, t.shape, doc_id.shape

(torch.Size([2]), torch.Size([1, 20]), torch.Size([1, 20]))

In [129]:
emb = t.unsqueeze(-1) * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)

In [132]:
emb.shape

torch.Size([1, 20, 4])

In [139]:
mask.shape, mask

(torch.Size([1, 20]),
 tensor([[False, False, False,  True,  True, False, False, False, False, False,
           True,  True, False, False, False,  True,  True,  True,  True,  True]]))

In [138]:
t.shape

torch.Size([1, 20])

In [140]:
torch.where(mask == False, 1, t)

tensor([[1.0000, 1.0000, 1.0000, 0.1200, 0.1200, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000, 0.0069, 0.0069, 1.0000, 1.0000, 1.0000, 0.0069, 0.0069, 0.0069,
         0.0069, 0.0069]])

In [142]:
result["ground_truth"].shape

torch.Size([1, 20, 35])

In [143]:
mask.shape

torch.Size([1, 20])

In [144]:
s_output["alpha"].shape

torch.Size([1, 20])

In [148]:
(result["ground_truth"] * mask.unsqueeze(-1)) * s_output["alpha"]

RuntimeError: The size of tensor a (35) must match the size of tensor b (20) at non-singleton dimension 2

In [150]:
torch.sum(result["ground_truth"] * mask.unsqueeze(-1), dim=(-2, -1)).shape

torch.Size([1])