In [1]:
from torch import Tensor
import torch
from torch.nn import functional as F
from torch import nn
from torch.utils.data import Dataset
import random
from tqdm.auto import tqdm

In [7]:
def loss(
    beta_1: Tensor,
    t: Tensor,
    target: Tensor,
    model_output_probs: Tensor | None = None,
    model_output_logits: Tensor | None = None,
) -> Tensor:
    """
    Args:
        beta_1: Maximum possible accuracy (reached when t=1) of shape (batch_size,).
        t: A tensor representing the time step (batch_size,).
        target: Target tensor of shape (batch_size, seq_len, K).
        model_output_probs: Model output probabilities of shape (batch_size, seq_len, K). If None, model_output_logits must be provided.
        model_output_logits: Model output logits of shape (batch_size, seq_len, K). If None, model_output_probs must be provided.
    Returns:
        Loss value

    Must provide either model_output_probs or model_output_logits, but not both.
    """
    assert (
        model_output_probs is None or model_output_logits is None
    ), "Must provide either model_output_probs or model_output_logits, but not both"
    assert (
        model_output_probs is None or model_output_probs.shape == target.shape
    ), "model_output_probs must have the same shape as target if provided"
    assert (
        model_output_logits is None or model_output_logits.shape == target.shape
    ), "model_output_logits must have the same shape as target if provided"

    batch_size, seq_len, K = target.shape
    model_output = (
        model_output_probs
        if model_output_probs is not None
        else torch.softmax(model_output_logits, dim=-1)
    )
    result = torch.sum(
        K * beta_1 * t * torch.sum((target - model_output) ** 2) / (batch_size**2)
    )
    return result / seq_len

In [2]:
def beta_t(beta_1: Tensor, t: Tensor) -> Tensor:
    """
    Accuracy schedule: scales the final accuracy ``beta_1`` by ``t`` squared.

    Args:
        beta_1: Maximum possible accuracy (reached when t=1) of shape (batch_size,).
        t: Time steps in [0, 1] of shape (batch_size,); 1 corresponds to maximum accuracy.
    Returns:
        Tensor of shape (batch_size,) holding ``beta_1 * t ** 2``.
    """
    # Both inputs must be flat vectors before their shapes are compared.
    for label, tensor in (("beta_1", beta_1), ("t", t)):
        assert tensor.ndim == 1, f"{label} should be a 1D tensor"
    assert beta_1.shape == t.shape, "beta_1 and t should have the same shape"

    # The time step has to stay inside the closed interval [0, 1].
    assert t.ge(0).all(), "t must be at least 0"
    assert t.le(1).all(), "t must be at most 1"

    t_squared = t.pow(2)
    return beta_1 * t_squared


def y_distribution(beta: Tensor, K: int, kron_x: Tensor) -> Tensor:
    """
    Draw a noisy version of one-hot data whose noise level is set by ``beta``.

    Each sample is ``beta * (K * kron_x - 1) + sqrt(beta * K) * epsilon`` with
    ``epsilon`` standard normal noise of the same shape as ``kron_x``.

    Args:
        beta: Accuracy value per batch element, shape (batch_size,).
        K: Number of classes (usually vocabulary size etc.)
        kron_x: One-hot encoded input tensor of shape (batch_size, seq_len, K).
    Returns:
        Tensor with the same shape as kron_x, i.e. (batch_size, seq_len, K).
    """
    # Reshape to (batch_size, 1, 1) so each batch element's accuracy
    # broadcasts over its whole (seq_len, K) slice of kron_x.
    accuracy = beta.view(-1, 1, 1)
    centre = accuracy * (K * kron_x - 1)
    std = (accuracy * K) ** 0.5
    noise = torch.normal(0, 1, kron_x.shape, device=kron_x.device)
    return centre + std * noise


def theta(y: Tensor):
    """
    Turn noisy values into the model input.

    Args:
        y: Tensor of shape (batch_size, seq_len, K) representing the noisy version of kron_x.
    Returns:
        Softmax of ``y`` over the last dimension, rescaled from [0, 1] to [-1, 1].
    """
    assert y.ndim == 3, "y should be a 3D tensor of shape (batch_size, seq_len, K)"
    class_probs = y.softmax(dim=-1)
    # Affine map p -> 2p - 1 moves the probabilities into [-1, 1].
    return 2 * class_probs - 1


def sample_t(batch_size, min_t=1e-6):
    """
    Sample one time step per batch element.

    Args:
        batch_size: Number of time steps to draw.
        min_t: Lower bound applied to every sample, so no returned time step
            is smaller than this value.
    Returns:
        Tensor of shape (batch_size,) drawn uniformly from [0, 1) and then
        clamped from below at ``min_t``.
    """
    # torch.rand replaces the legacy torch.FloatTensor(n).uniform_(0, 1)
    # constructor; it draws from the same uniform [0, 1) distribution.
    return torch.rand(batch_size).clamp(min=min_t)


def collate_fn(batch):
    """
    Collate dataset items into one batch, truncating every sequence to the
    length of the shortest sequence in the batch, and build the noisy model
    input from the truncated data.

    Args:
        batch: List of dictionaries, each containing 'x', 't', 'beta' and 'beta_1'.
    Returns:
        A dictionary with keys 'x', 't', 'beta_1' and 'theta', where 'x' is a tensor of shape
        (batch_size, seq_len, K), 't' is a tensor of shape (batch_size,), 'beta_1'
        is a tensor of shape (batch_size,), and 'theta' is the transformed version of 'x'.
    """

    def _concat(key):
        # Each item stores a length-1 tensor under `key`; join them into (batch_size,).
        return torch.cat([item[key] for item in batch], dim=0)

    # Cut all sequences down to the shortest one so they can be stacked.
    shortest = min(item["x"].shape[0] for item in batch)
    x = torch.stack(
        [item["x"][:shortest] for item in batch], dim=0
    )  # Shape: (batch_size, seq_len, K)

    t = _concat("t")
    beta = _concat("beta")
    beta_1 = _concat("beta_1")

    # Noise the one-hot data at accuracy `beta`, then map it to the model input.
    num_classes = x.shape[-1]
    theta_tensor = theta(
        y_distribution(beta, num_classes, x)
    )  # Shape: (batch_size, seq_len, K)

    return {"x": x, "t": t, "beta_1": beta_1, "theta": theta_tensor}

In [3]:
class TokenizerBase:
    """
    Interface shared by tokenizers: every method raises NotImplementedError
    until a subclass overrides it.
    """

    # Message raised by every method that a subclass has not overridden.
    _NOT_IMPLEMENTED = "This method should be implemented by subclasses."

    def vocab_size(self) -> int:
        """Return the number of entries in the vocabulary."""
        raise NotImplementedError(self._NOT_IMPLEMENTED)

    def encode(self, text: str) -> Tensor:
        """Convert ``text`` into a tensor of tokens."""
        raise NotImplementedError(self._NOT_IMPLEMENTED)

    def decode(self, tokens: Tensor) -> str:
        """Convert a tensor of tokens back into text."""
        raise NotImplementedError(self._NOT_IMPLEMENTED)

In [4]:
class DiscreteSyntheticTokenizer(TokenizerBase):
    """
    Whitespace tokenizer intended only for the discrete synthetic dataset.

    It only tokenizes strings like " 8 , 9 , 1 0 , 1 1 , 1 2 ,": the digits
    "0"-"9" map to ids 0-9 and "," maps to id 10.
    """

    def __init__(self):
        super().__init__()
        # "," is inserted first, followed by the digits in ascending order.
        self.vocab = {",": 10, **{str(digit): digit for digit in range(10)}}
        # Reverse lookup: token id -> symbol.
        self.anti_vocab = {token_id: symbol for symbol, token_id in self.vocab.items()}

    def vocab_size(self) -> int:
        """Return the number of symbols in the vocabulary."""
        return len(self.vocab)

    def encode(self, text: str) -> Tensor:
        """
        Split ``text`` on whitespace and map every piece to its token id.

        Pieces missing from the vocabulary fall back to id 0, which is also
        the id of the digit "0".

        Returns:
            1D long tensor with one id per whitespace-separated piece.
        """
        token_ids = []
        for piece in text.split():
            token_ids.append(self.vocab.get(piece, 0))
        return torch.tensor(token_ids, dtype=torch.long)

    def decode(self, tokens: Tensor) -> str:
        """
        Convert per-position class scores back into a space-separated string.

        Args:
            tokens: Tensor of shape (seq_len, K); the argmax of each row picks
                the token id for that position.
        Returns:
            The decoded symbols joined by single spaces; ids missing from the
            vocabulary decode to an empty string.
        """
        assert tokens.ndim == 2, "tokens should be a 2D tensor of shape (seq_len, K)"
        symbols = [
            self.anti_vocab.get(torch.argmax(row).item(), "") for row in tokens
        ]
        return " ".join(symbols)

In [5]:
class DiscreteSyntheticDataset(Dataset):
    """
    Synthetic dataset of comma-separated runs of consecutive integers.

    Every item is generated on the fly: the index passed to ``__getitem__`` is
    ignored, so each access draws a fresh random sequence and time step.
    """

    def __init__(
        self,
        tokenizer: DiscreteSyntheticTokenizer,
        length: int = 32,
        tokenized_length: int = 32,
        mini: int = 0,
        maxi: int = 100,
        beta_1: float = 4.0,
        min_t: float = 1e-6,
        num_samples: int = 10000,
    ):
        """
        Args:
            tokenizer: Tokenizer used to encode the generated text.
            length: The run covers the integers ``start`` to ``start + length`` inclusive.
            tokenized_length: Maximum number of tokens kept per sequence.
            mini: Smallest integer a run may start at.
            maxi: Largest integer a run may contain.
            beta_1: Accuracy at t=1, stored as a tensor of shape (1,).
            min_t: Lower bound for the sampled time step.
            num_samples: Value reported by ``len(dataset)``, i.e. the nominal epoch size.
        Raises:
            ValueError: If no run of ``length`` fits between ``mini`` and ``maxi``,
                or if ``num_samples`` is negative.
        """
        # Fail at construction instead of on the first item, where random.randint
        # would raise a much less helpful "empty range" error.
        if maxi - length < mini:
            raise ValueError(
                f"no valid start value: need mini <= maxi - length, "
                f"got mini={mini}, maxi={maxi}, length={length}"
            )
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        self.length = length
        self.tokenized_length = tokenized_length
        self.tokenizer = tokenizer
        self.mini = mini
        self.maxi = maxi
        self.min_t = min_t
        self.num_samples = num_samples
        self.beta_1 = torch.tensor([beta_1])

    def generate_sequence(self):
        """
        Build one tokenized run of consecutive integers.

        The start is drawn uniformly from ``[mini, maxi - length]``; the run
        then covers ``start`` to ``start + length`` inclusive. Every digit
        becomes its own token and each number is followed by a "," token.

        Returns:
            1D tensor of token ids, truncated to ``tokenized_length`` entries.
        """
        start = random.randint(self.mini, self.maxi - self.length)
        end = start + self.length
        # e.g. 9, 10 -> "9 , 1 0 ,"; the tokenizer splits on whitespace.
        text = " ".join(
            f"{' '.join(str(number))} ," for number in range(start, end + 1)
        )
        tokenized = self.tokenizer.encode(text)
        return tokenized[: self.tokenized_length]

    def __len__(self):
        """Return the configured nominal number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Generate a fresh sample; ``idx`` is not used.

        Returns:
            Dict with 'x' (one-hot sequence of shape (seq_len, vocab_size)),
            't' (sampled time step, shape (1,)), 'beta' (accuracy at ``t``,
            shape (1,)) and 'beta_1' (shape (1,)).
        """
        seq = F.one_hot(
            self.generate_sequence(), num_classes=self.tokenizer.vocab_size()
        )
        t = sample_t(1, self.min_t)
        beta = beta_t(self.beta_1, t)
        return {"x": seq, "t": t, "beta": beta, "beta_1": self.beta_1}

In [6]:
class DiscreteModel(nn.Module):
    def __init__(
        self,
        max_seq_len: int,
        K: int,
        hidden_dim: int,
        num_heads: int,
        layers: int = 3,
        dropout: float = 0.1,
    ):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisble by num_heads"
        self.emb = nn.Parameter(torch.randn(K, hidden_dim))
        self.pos_emb = nn.Parameter(torch.randn(max_seq_len, hidden_dim))
        self.time_vec = nn.Parameter(torch.randn(1, hidden_dim))
        self.layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    hidden_dim,
                    num_heads,
                    hidden_dim * 4,
                    dropout,
                    batch_first=True,
                    bias=False,
                )
                for i in range(layers)
            ]
        )
        self.classifier = nn.Parameter(torch.randn(hidden_dim, K))

    def token_emb(self, x):
        return x @ self.emb

    def positional_emb(self, x):
        return x + self.pos_emb[: x.shape[1]]

    def time_emb(self, x, t):
        assert t.ndim == 1, "time vector `t` should be vector of length batch_size"
        # we need to first unsqueeze t to get it from shape (batch_size,)
        # to (batch_size, 1) so it is compatible with the time_vec's (1, hidden_dim)
        # the result is (batch_size, hidden_dim) however the x is
        # (batch_size, seq_len, hidden_dim) so we need a second unsqueeze
        return (t.unsqueeze(-1) @ self.time_vec).unsqueeze(-2) + x

    def forward(self, x, t):
        x = self.token_emb(x)
        x = self.positional_emb(x)
        x = self.time_emb(x, t)
        for i, l in enumerate(self.layers):
            x = l.forward(x)
        return F.relu(x @ self.classifier)

In [None]:
# Training setup
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Seed the RNGs used by the dataset (`random`, torch), the parameter
# initialisation and the DataLoader shuffling so a fresh
# "Restart & Run All" repeats the same run.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Initialize tokenizer and dataset
tokenizer = DiscreteSyntheticTokenizer()
dataset = DiscreteSyntheticDataset(
    tokenizer=tokenizer,
    length=16,
    tokenized_length=32,
    beta_1=4.0
)

# Create data loader; collate_fn also builds the noisy model input "theta"
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)

# Initialize model
model = DiscreteModel(
    max_seq_len=32,
    K=tokenizer.vocab_size(),
    hidden_dim=128,
    num_heads=4,
    layers=3,
    dropout=0.1
)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Vocab size: {tokenizer.vocab_size()}")
print(f"Dataset size: {len(dataset)}")

# Training loop for 2 epochs
num_epochs = 2
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in progress_bar:
        optimizer.zero_grad()

        # Get batch data
        x = batch["x"]  # Target one-hot encoded sequences
        t = batch["t"]  # Time steps
        beta_1 = batch["beta_1"]  # Accuracy at t=1, one value per sample
        theta_tensor = batch["theta"]  # Input to model (noisy version)

        # Forward pass
        model_output_logits = model(theta_tensor, t)

        # Compute loss (one-hot targets are integer tensors, hence .float())
        batch_loss = loss(
            beta_1=beta_1,
            t=t,
            target=x.float(),
            model_output_logits=model_output_logits
        )

        # Backward pass
        batch_loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the statistics and the bar
        loss_value = batch_loss.item()
        epoch_loss += loss_value
        num_batches += 1

        # Update progress bar
        progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1} completed. Average loss: {avg_loss:.4f}")

print("Training completed!")

Model parameters: 597,632
Vocab size: 11
Dataset size: 10000


Epoch 1/2:   0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 1 completed. Average loss: 41.1602


Epoch 2/2:   0%|          | 0/1250 [00:00<?, ?it/s]

Epoch 2 completed. Average loss: 40.8119
Training completed!
