In [15]:
import torch
from src.tokenizers.ascii.ascii_tokenizer import ASCIITokenizer as Tokenizer

In [6]:
# Load the debug dump written by the training loop when the NaN appeared.
# The tensors are moved to CPU with `.cpu()` further down, which suggests they
# were saved from a GPU run; map_location="cpu" lets the file load on a kernel
# without CUDA and puts every tensor on the same device as the CPU-built model.
# NOTE: torch.load unpickles the file -- only load dumps you produced yourself
# (torch>=2.6 defaults to weights_only=True, which restricts what can load).
data = torch.load("debug_data_current_epoch.pt", map_location="cpu")

In [14]:
# Keys in this dump (from data.keys() when the notebook was written):
# ['x', 't', 'output', 'alpha', 'formatted_loss', 'l_infty_loss',
#  'var_loss', 'div_loss', 'l', 'model_state_dict']


def report_non_finite(obj, name=""):
    """Print the path of every tensor inside `obj` that contains NaN or Inf.

    Recurses through dicts (e.g. the model state dict) and lists/tuples to any
    depth, so nested containers are fully checked. Inf is reported as well as
    NaN because an overflow to Inf frequently precedes a NaN.

    Args:
        obj: A tensor, or a dict/list/tuple that may contain tensors.
            Any other value is ignored.
        name: Path prefix used in the printed messages; nested dict keys are
            joined with "." and sequence positions are shown as "[i]".
    """
    if isinstance(obj, torch.Tensor):
        if torch.isnan(obj).any():
            print(f"NaN found in {name}")
        if torch.isinf(obj).any():
            print(f"Inf found in {name}")
    elif isinstance(obj, dict):
        for key, value in obj.items():
            report_non_finite(value, f"{name}.{key}" if name else str(key))
    elif isinstance(obj, (list, tuple)):
        for index, item in enumerate(obj):
            report_non_finite(item, f"{name}[{index}]")


report_non_finite(data)

NaN found in output
NaN found in formatted_loss
NaN found in l_infty_loss
NaN found in var_loss
NaN found in div_loss
NaN found in l
Checking model state


In [16]:
# Architecture hyper-parameters used to rebuild the model. They presumably
# have to match the training run that produced the checkpoint, otherwise the
# saved weights will not fit the freshly constructed network.
max_seq_len = 32
tokenizer = Tokenizer()
model_kwargs = dict(
    max_seq_len=max_seq_len,
    K=tokenizer.vocab_size(),  # number of token classes from the tokenizer
    hidden_dim=512,
    num_heads=8,
    layers=5,
)

In [17]:
from src.nn.models.discrete_model import DiscreteModel

# Rebuild the network from `model_kwargs` so the saved weights can be loaded.
# No `.to(device)` call is made here, so the model presumably stays on CPU --
# TODO confirm DiscreteModel does not move itself in __init__.
model = DiscreteModel(**model_kwargs)


In [18]:
# Restore the weights captured when the NaN occurred. With the default
# strict=True every key must match, which the output below confirms.
model.load_state_dict(data['model_state_dict'])

<All keys matched successfully>

In [24]:
# Bring the captured batch to CPU so it sits on the same device as `model`
# (no `.to(device)` is applied to the model in this notebook).
x = data['x'].cpu()
t = data['t'].cpu()

In [25]:
# Re-run the forward pass on the captured batch to see whether the NaNs
# reproduce from the saved weights and inputs alone.
output = model(x, t)

In [26]:
# The forward result is a tuple; every visible entry of its first tensor is
# NaN, so the failure reproduces outside the training loop.
output

(tensor([[[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],
 
         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],
 
         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],
 
         ...,
 
         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, na

In [27]:
# Compute the model's (beta, alpha) for the captured time steps directly,
# to check whether the NaNs already originate there. beta holds one value per
# batch element (it is broadcast over x via view(-1, 1, 1) further down).
beta, alpha = model.beta_and_alpha(t, tokenizer.vocab_size())

In [28]:
# Finding: every displayed beta is negative, including ~-2.7e-06 entries that
# recur at every 8th position (possibly samples sharing the same t -- TODO
# confirm). A negative beta makes variance = beta * K negative below.
beta, alpha

(tensor([-2.6785e-06, -1.4770e+00, -1.5288e+00, -2.4835e+00, -8.8041e-01,
         -4.0479e-01, -3.3264e-01, -6.1584e-01, -2.6785e-06, -7.0165e-01,
         -9.4166e-01, -1.8685e+00, -9.3534e-01, -2.6451e+00, -2.5092e+00,
         -1.6764e+00, -2.6785e-06, -2.2560e+00, -4.4587e-01, -4.6710e-01,
         -1.3812e+00, -1.7136e-01, -4.2483e-01, -2.0943e+00, -2.6785e-06,
         -2.1890e+00, -1.7288e+00, -1.2947e+00, -1.2306e+00, -2.1246e+00,
         -1.1625e+00, -1.0675e+00, -2.6785e-06, -4.8445e-02, -2.6781e+00,
         -5.6973e-01, -2.0704e+00, -2.2214e+00, -1.2079e+00, -5.0947e-02,
         -2.6785e-06, -2.0688e+00, -4.2078e-01, -2.3500e+00, -1.8849e+00,
         -1.8465e+00, -2.5623e+00, -1.4666e+00, -2.6785e-06, -2.4352e-01,
         -3.6537e-01, -2.4676e+00, -1.7725e+00, -5.7731e-01, -2.5315e+00,
         -2.2546e+00, -2.6785e-06, -2.8782e-03, -2.7426e-01, -5.9708e-01,
         -8.2768e-01, -1.9792e+00, -2.0898e+00, -2.3763e+00],
        grad_fn=<SqueezeBackward1>),
 tensor([-2.6

In [29]:
from src.datasets.discrete_helper import theta, y_distribution

In [30]:
# Feed this beta straight into y_distribution: the result is already NaN,
# so the NaNs appear as soon as this beta is used.
y_distribution(beta, tokenizer.vocab_size(), x)

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        ...,

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]

In [31]:
# Reshape beta to (N, 1, 1) so each batch element's beta broadcasts over the
# trailing two dimensions of x.
beta_v = beta.view(-1, 1, 1)

In [32]:
# Recompute the mean by hand -- presumably the same formula y_distribution
# uses internally (TODO confirm against src.datasets.discrete_helper).
mean = beta_v * (tokenizer.vocab_size() * x - 1)

In [33]:
# Visible values are -beta where x == 0 and beta * (K - 1) where x == 1
# (x appears to be one-hot); all of them are finite.
mean

tensor([[[ 2.6785e-06,  2.6785e-06,  2.6785e-06,  ...,  2.6785e-06,
           2.6785e-06,  2.6785e-06],
         [ 2.6785e-06, -9.9103e-05,  2.6785e-06,  ...,  2.6785e-06,
           2.6785e-06,  2.6785e-06],
         [ 2.6785e-06,  2.6785e-06,  2.6785e-06,  ...,  2.6785e-06,
           2.6785e-06,  2.6785e-06],
         ...,
         [ 2.6785e-06,  2.6785e-06,  2.6785e-06,  ...,  2.6785e-06,
          -9.9103e-05,  2.6785e-06],
         [-9.9103e-05,  2.6785e-06,  2.6785e-06,  ...,  2.6785e-06,
           2.6785e-06,  2.6785e-06],
         [-9.9103e-05,  2.6785e-06,  2.6785e-06,  ...,  2.6785e-06,
           2.6785e-06,  2.6785e-06]],

        [[ 1.4770e+00,  1.4770e+00,  1.4770e+00,  ...,  1.4770e+00,
           1.4770e+00,  1.4770e+00],
         [ 1.4770e+00, -5.4647e+01,  1.4770e+00,  ...,  1.4770e+00,
           1.4770e+00,  1.4770e+00],
         [ 1.4770e+00,  1.4770e+00,  1.4770e+00,  ...,  1.4770e+00,
           1.4770e+00,  1.4770e+00],
         ...,
         [ 1.4770e+00,  1

In [34]:
# Recompute the variance by hand (presumably the same formula as inside
# y_distribution -- TODO confirm). It is beta scaled by K, so it inherits
# beta's sign.
variance = beta_v * tokenizer.vocab_size()

In [35]:
# Every displayed variance is negative, because every beta above is negative.
variance

tensor([[[-1.0178e-04]],

        [[-5.6124e+01]],

        [[-5.8094e+01]],

        [[-9.4372e+01]],

        [[-3.3456e+01]],

        [[-1.5382e+01]],

        [[-1.2640e+01]],

        [[-2.3402e+01]],

        [[-1.0178e-04]],

        [[-2.6663e+01]],

        [[-3.5783e+01]],

        [[-7.1002e+01]],

        [[-3.5543e+01]],

        [[-1.0051e+02]],

        [[-9.5348e+01]],

        [[-6.3702e+01]],

        [[-1.0178e-04]],

        [[-8.5728e+01]],

        [[-1.6943e+01]],

        [[-1.7750e+01]],

        [[-5.2484e+01]],

        [[-6.5118e+00]],

        [[-1.6144e+01]],

        [[-7.9583e+01]],

        [[-1.0178e-04]],

        [[-8.3184e+01]],

        [[-6.5695e+01]],

        [[-4.9200e+01]],

        [[-4.6761e+01]],

        [[-8.0735e+01]],

        [[-4.4174e+01]],

        [[-4.0565e+01]],

        [[-1.0178e-04]],

        [[-1.8409e+00]],

        [[-1.0177e+02]],

        [[-2.1650e+01]],

        [[-7.8674e+01]],

        [[-8.4413e+01]],

        [[-4

In [36]:
# Root cause: the square root of a negative float is NaN, so every negative
# variance becomes NaN here. The fix belongs upstream -- beta returned by
# model.beta_and_alpha needs to be non-negative. Clamping at 0 would only mask
# the problem for the clearly negative values (e.g. -2.48), not explain them.
variance ** 0.5

tensor([[[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[nan]],

        [[