In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import math
import pandas as pd
import random
import wandb

from torch import nn, optim
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.utils.data import DataLoader, Dataset, ConcatDataset

from datasets import load_dataset
from transformers import AutoTokenizer
from staticvectors import StaticVectors
from datetime import datetime
from tqdm import tqdm

from models.LanguageTransformer import LanguageTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# dynamically select device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
model_config = {
    'emb_dim': 256,
    'num_layers': 1,
    'num_heads': 1
}

In [4]:
# load language model
model = LanguageTransformer(
    vocab_size=5,
    embed_dim=model_config['emb_dim'],
    num_layers=model_config['num_layers'],
    num_heads=model_config['num_heads']
).to(device)

In [6]:
seq = torch.randint(0, 5, (1, 5)).to(device)
out = model(seq)
print(out.shape)
assert out.shape == (1, 5, 5)
print("[dry_run] passed")
print(out)

tensor([[[[-0.6866, -0.7309, -0.5838, -0.7103, -0.6750],
          [   -inf, -0.7311, -0.6775, -0.7035, -0.6630],
          [   -inf,    -inf, -0.3126, -0.8994, -0.8327],
          [   -inf,    -inf,    -inf, -0.6293, -0.5698],
          [   -inf,    -inf,    -inf,    -inf, -0.5161]]]], device='mps:0',
       grad_fn=<MaskedFillBackward0>)
tensor([[[[0.1979, 0.1893, 0.2193, 0.1933, 0.2002],
          [0.0000, 0.2407, 0.2540, 0.2475, 0.2577],
          [0.0000, 0.0000, 0.4650, 0.2586, 0.2764],
          [0.0000, 0.0000, 0.0000, 0.4851, 0.5149],
          [0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]]], device='mps:0',
       grad_fn=<SoftmaxBackward0>)
torch.Size([1, 5, 5])
[dry_run] passed
tensor([[[ 0.4615, -0.8638,  0.2154,  0.3924,  0.4707],
         [ 0.0386, -0.9822,  0.1829,  0.3597,  0.3074],
         [ 0.2930,  2.4258, -0.7168, -0.3578, -2.0050],
         [-0.4893, -0.4189,  0.2224, -0.1561,  0.6248],
         [-0.4342, -0.1854,  0.2496, -0.1335,  0.8159]]], device='mps:0',
       