In [105]:
import sys
import os

parent_dir = os.path.abspath(os.path.join(os.path.dirname(os.getcwd()), "./"))
sys.path.append(parent_dir)

In [106]:
import torch
from torch import Tensor
from torch.nn import functional as F

In [107]:
def beta_t(beta_1: Tensor, t: Tensor) -> Tensor:
    """
    Args:
        beta_1: Maximum possible accuracy (reached when t=1) of shape (batch_size,).
        t: A tensor representing the time step, where 1 corresponds to maximum accuracy of shape (batch_size,).
    Returns:
        Beta value at given time step t
    """
    assert beta_1.ndim == 1, "beta_1 should be a 1D tensor"
    assert t.ndim == 1, "t should be a 1D tensor"
    assert beta_1.shape == t.shape, "beta_1 and t should have the same shape"
    assert torch.all(t >= 0), "t must be at least 0"
    assert torch.all(t <= 1), "t must be at most 1"
    return beta_1 * (t ** 2)

def y_distribution(beta: Tensor, K: int, kron_x: Tensor) -> Tensor:
    """
    Args:
        beta: Tensor of accuracy values for each batch of shape (batch_size,).
        K: Number of classes (usually vocabulary size etc.)
        kron_x: One-hot encoded input tensor of shape (batch_size, seq_len, K).
    Returns:
        Noisy version of kron_x with the amount of noise controlled
        by beta. The shape of the output tensor is the same as kron_x, i.e., (batch_size, seq_len, K).
    """
    beta = beta.view(-1, 1, 1) # allows for broadcasting with reach appropriate batch in kron_x
    mean = beta * (K * kron_x - 1)
    variance = beta * K
    epsilon = torch.normal(0, 1, kron_x.shape, device=kron_x.device)
    return mean + variance * epsilon # I know the name `variance` suggests it should be squared, but this works just fine

def theta(y: Tensor):
    """
    Args:
        y: Tensor of shape (batch_size, seq_len, K) representing the noisy version of kron_x.
    Returns:
        Tensor representing the scaled softmax of y, which is the input to the model.
    """
    assert y.ndim == 3, "y should be a 3D tensor of shape (batch_size, seq_len, K)"
    theta = F.softmax(y, dim=-1)
    theta = 2 * theta - 1  # scale to [-1, 1]
    return theta

def sample_t(batch_size, min_t=1e-6):
   return torch.clamp(torch.FloatTensor(batch_size).uniform_(0,1), min=min_t)

In [108]:
from torch.nn import functional as F
from torch.utils.data import Dataset
from src.tokenizers.discrete_synthetic.discrete_synthetic_tokenizer import DiscreteSyntheticTokenizer
import random

class DiscreteSyntheticDataset(Dataset):
    def __init__(self, tokenizer: DiscreteSyntheticTokenizer, length: int = 32, tokenized_length: int = 32, mini: int = 0, maxi: int = 100, min_t: float = 1e-6, beta_1: float = 4.0):
        self.length = length
        self.tokenized_length = tokenized_length
        self.tokenizer = tokenizer
        self.mini = mini
        self.maxi = maxi
        self.min_t = min_t
        self.beta_1 = torch.tensor([beta_1])

    def generate_sequence(self):
        start = random.randint(self.mini, self.maxi - self.length)
        end = start + self.length
        acc = ""
        for i in range(start, end+1):
            for c in str(i):
                acc += " " + c
            acc += " ,"
        tokenized = self.tokenizer.encode(acc)
        return tokenized[:self.tokenized_length - random.randint(1, 2)] # add jitter to mimic real data variability
    
    def __len__(self):
        return 10000
    
    def __getitem__(self, idx):
        seq = F.one_hot(self.generate_sequence(), num_classes=self.tokenizer.vocab_size())
        t = sample_t(1, self.min_t)
        beta = beta_t(self.beta_1, t)
        # y = y_distribution(beta, self.tokenizer.vocab_size(), seq)
        # theta_y = theta(y)
        return {
            "x": seq,
            "t": t,
            "beta": beta,
        }

In [109]:
from src.tokenizers.discrete_synthetic.discrete_synthetic_tokenizer import DiscreteSyntheticTokenizer

In [110]:
tk = DiscreteSyntheticTokenizer()
ds = DiscreteSyntheticDataset(tk, length=4, tokenized_length=4)

In [111]:
from torch.utils.data import DataLoader

In [112]:
def collage_fn(batch):
    """
    We expect batch to be a list of dictionaries with keys 'x', 't', and 'beta'.
    """

    # first, we handle truncating the `x` tensors as needed
    x = [item['x'] for item in batch]
    min_length = min(tensor.shape[0] for tensor in x)
    x = [tensor[:min_length] for tensor in x]

    # and now we stack the x tensors and concatenate the t and beta tensors
    x = torch.stack(x, dim=0)
    t = torch.cat([item['t'] for item in batch], dim=0)
    beta = torch.cat([item['beta'] for item in batch], dim=0)

    # and now we use this to calculate the theta
    y = y_distribution(beta, x.shape[-1], x)
    theta_y = theta(y)

    return {
        "x": x,
        "t": t,
        "theta": theta_y
    }

In [113]:
dl = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=collage_fn)

In [114]:
iter_dl = iter(dl)

In [115]:
batch = next(iter_dl)

In [116]:
batch

{'x': tensor([[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
          [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
 
         [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]]),
 't': tensor([0.5281, 0.3546]),
 'theta': tensor([[[-0.9997, -1.0000, -1.0000, -1.0000, -1.0000, -0.7456, -0.9963,
           -1.0000,  0.7416, -1.0000, -1.0000],
          [-1.0000, -0.9826, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
           -1.0000, -1.0000, -1.0000,  0.9826]],
 
         [[-1.0000,  0.7658, -1.0000, -0.9959, -0.8931, -1.0000, -0.9996,
           -0.8772, -1.0000, -1.0000, -1.0000],
          [-1.0000, -0.9991, -1.0000, -1.0000, -0.9997, -1.0000, -1.0000,
           -0.9997, -1.0000, -1.0000,  0.9985]]])}