In [25]:
import numpy as np
from models.base_models import Transformer
from arithmetic_sampler import ArithmeticSampler
from config import get_config
from train import train
import torch
import torch.nn as nn
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
config = get_config()
sampler = ArithmeticSampler(config.task.max_variables)
model = Transformer(config)

In [3]:
train(model, sampler, config, verbose=False)

Results are saved in:  results/train_e791069c5cae4ac9d47c461d4f718897
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
Transformer                              [256, 82]                 [256, 82, 16]             --
├─Embedding: 1-1                         [256, 82]                 [256, 82, 64]             1,024
├─ModuleList: 1-2                        --                        --                        --
│    └─TFBlock: 2-1                      [256, 82, 64]             [256, 82, 64]             --
│    │    └─Identity: 3-1                [256, 82, 64]             [256, 82, 64]             --
│    │    └─MultiHeadAttention: 3-2      [256, 82, 64]             [256, 82, 64]             16,384
│    └─TFBlock: 2-2                      [256, 82, 64]             [256, 82, 64]             --
│    │    └─Identity: 3-3                [256, 82, 64]             [256, 82, 64]             --
│    │    └─MultiHeadAttention: 3-4      [256, 82, 64]

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhyan84[0m ([33mhyan84-university-of-wisconsin-madison[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Start training...


  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

  0%|          | 0/1250 [00:00<?, ?it/s]

Training complete.


In [123]:
%%timeit
sampler.generate(2**17)

1.16 s ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [46]:
batch, mask = sampler.generate(1)

In [49]:
sampler.decode(batch)

['                        ((4+4)+(0-5))+(6-(1-((7-2)-7)))=(8+(0-5))+(6-(1-((7-2)-7)))']

In [50]:
def sample_from_transformer(model, batch, mask, tokenizer=None, temperature=1.0, top_k=None):
    """
    model: Transformer with output logits of shape (B, T, vocab_size)
    input_ids: LongTensor of shape (1, T) — starting token sequence
    k: how many tokens to sample
    tokenizer: (optional) for decoding results
    """
    k = mask.sum().item()
    model.eval()
    seq_len = batch.shape[1]
    batch = F.pad(batch, pad=(k, 0), value=15)
    print(batch.shape)
    generated = batch[:,:seq_len].clone()

    for t in range(k):
        # Get model output for current sequence
        with torch.no_grad():
            output, _ = model(generated[:,t:seq_len+t])  # logits: (1, T, vocab_size)

        logits = output[:, -1, :]  # get logits for the last token

        # Apply temperature
        logits = logits / temperature

        # (Optional) Top-k filtering
        if top_k is not None:
            values, indices = torch.topk(logits, top_k)
            mask = logits < values[:, [-1]]
            logits[mask] = float('-inf')

        probs = F.softmax(logits, dim=-1)  # shape (1, vocab_size)
        next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)

        # Append sampled token to sequence
        generated = torch.cat((generated, next_token), dim=1)

    if tokenizer:
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    return generated


In [52]:
sampler.decode(sample_from_transformer(model, batch, mask))

torch.Size([1, 110])


['                                                   ((4+4)+(0-5))+(6-(1-((7-2)-7)))=(1+1)))))))))))))1-62-6))))']

In [53]:
targets = batch[:, 1:]
target_mask = mask[:, 1:]

In [55]:
target_mask.shape

torch.Size([1, 82])

In [56]:
targets.shape

torch.Size([1, 82])