```
/media/john/Tertiary/Projects/ML/BayesianFlowNet/.venv/lib/python3.11/site-packages/torch/autograd/graph.py:823: UserWarning: Error detected in PowBackward0. Traceback of forward call that caused the error:
  File "/media/john/Tertiary/Projects/ML/BayesianFlowNet/train_script.py", line 96, in <module>
    train_discrete_model(
  File "/media/john/Tertiary/Projects/ML/BayesianFlowNet/src/training/training.py", line 85, in train_discrete_model
    divergence_loss(x, model.learnable_beta) * divergence_loss_strength
  File "/media/john/Tertiary/Projects/ML/BayesianFlowNet/src/training/discrete_loss.py", line 120, in divergence_loss
    beta_t_dist = y_distribution(beta_t, K, x, deterministic=True)  # should be logits
  File "/media/john/Tertiary/Projects/ML/BayesianFlowNet/src/datasets/discrete_helper.py", line 30, in y_distribution
    return mean + (variance**0.5) * epsilon
  File "/media/john/Tertiary/Projects/ML/BayesianFlowNet/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 39, in wrapped
    return f(*args, **kwargs)
 (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:122.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Loss: 2.6713:  20%|████████████████▊                                                                 | 265895/1300000 [10:05:24<39:14:30,  7.32it/s]
Runtime error occurred: Function 'PowBackward0' returned nan values in its 0th output.
```

In [11]:
import torch

In [12]:
prev = torch.load('debug_data_past_epoch.pt')
curr = torch.load('debug_data_current_epoch.pt')

In [13]:
for k in prev:
    data = prev[k] # could be parameter or dict of parameters
    if isinstance(data, torch.Tensor) and torch.any(torch.isnan(data)):
        print(f'Prev {k} has NaNs')
    if isinstance(data, dict):
        for k2 in data:
            if torch.any(torch.isnan(data[k2])):
                print(f'Prev {k}.{k2} has NaNs')

In [14]:
for k in curr:
    data = curr[k] # could be parameter or dict of parameters
    if isinstance(data, torch.Tensor) and torch.any(torch.isnan(data)):
        print(f'Curr {k} has NaNs')
    if isinstance(data, dict):
        for k2 in data:
            if torch.any(torch.isnan(data[k2])):
                print(f'Curr {k}.{k2} has NaNs')

In [18]:

import sys
# a bit of a hack to make sure we can import the src module
if '..' not in sys.path:
    sys.path.insert(0, '..')

from src.nn.models.discrete_model import DiscreteModel
from src.tokenizers.ascii.ascii_tokenizer import ASCIITokenizer as Tokenizer
from src.training.discrete_loss import (
    alpha_below_linear_loss,
    alpha_variance_loss,
    divergence_loss,
    format_loss,
    loss,
    variance_loss,
)

# from train_script.py
tokenizer = Tokenizer()
max_seq_len = 32
folds = 8

model_kwargs = {
    "max_seq_len": max_seq_len,
    "K": tokenizer.vocab_size(),
    "hidden_dim": 512,
    "num_heads": 8,
    "layers": 5,
    "reference_beta_1": 20.4054 / tokenizer.vocab_size(),
    "learner_weight": 1.0,
    "freeze_body": False,
}
model = DiscreteModel(**model_kwargs)
model.load_state_dict(curr['model_state_dict'])

# from training.py
x = curr["x"]
t = curr["t"]

device = x.device
model.to(device)
t = t.to(device)


# from train_script.py
variance_loss_strength=0.8
divergence_loss_strength=0.8
alpha_linearity_loss_strength=0.4

output, alpha = model.forward(x, t)
formatted_loss = format_loss(
    alpha, x, model_output_logits=output, folds=folds
)
l_infty_loss = loss(formatted_loss)
var_loss = variance_loss(formatted_loss) * variance_loss_strength
alpha_var_loss = (
    alpha_variance_loss(alpha) * alpha_linearity_loss_strength
)
div_loss = (
    divergence_loss(x, model.learnable_beta) * divergence_loss_strength
)
alpha_linear_loss = alpha_below_linear_loss(x, model.learnable_beta)

l = (
    l_infty_loss
    + var_loss
    + div_loss
)

print(f"l_infty_loss: {l_infty_loss}")
print(f"var_loss: {var_loss}")
print(f"div_loss: {div_loss}")
print(f"l: {l}")

print("\nAre any of the new values NaN?")
print(f"output: {torch.isnan(output).any()}")
print(f"alpha: {torch.isnan(alpha).any()}")
print(f"l_infty_loss: {torch.isnan(l_infty_loss).any()}")
print(f"var_loss: {torch.isnan(var_loss).any()}")
print(f"div_loss: {torch.isnan(div_loss).any()}")
print(f"l: {torch.isnan(l).any()}")


l_infty_loss: 1.385916829109192
var_loss: 0.5999146699905396
div_loss: 0.3722866475582123
l: 2.3581180572509766

Are any of the new values NaN?
output: False
alpha: False
l_infty_loss: False
var_loss: False
div_loss: False
l: False


In [19]:

# Let's now debug the backward pass, as the forward pass seems fine.
# The error comes from divergence_loss, so let's re-implement it here
# and check for negative variance.

from src.datasets.discrete_helper import y_distribution
import torch.nn.functional as F

def divergence_loss_debug(x, learnable_beta):
    t_unif = torch.rand(x.shape[0], device=x.device)
    beta_t, _ = learnable_beta.get_alpha(t_unif, x.shape[-1])
    
    # Inside y_distribution
    K = x.shape[-1]
    mean = (K * x - 1) * beta_t.unsqueeze(-1).unsqueeze(-1)
    variance = (K**2 * (1 - x)) * beta_t.unsqueeze(-1).unsqueeze(-1)
    
    if torch.any(variance < 0):
        print("Negative variance detected!")
        print(f"Number of negative variance values: {torch.sum(variance < 0)}")
        print(f"Min variance: {torch.min(variance)}")
    else:
        print("No negative variance detected.")

    # The original implementation calls y_distribution with deterministic=True, which just returns the mean.
    # The error happens on the backward pass of the whole divergence_loss function.
    # The problem is not in y_distribution's forward pass, but what happens to its output.
    
    beta_t_dist = y_distribution(beta_t, K, x, deterministic=True)
    
    # The problem is likely in kl_div, where the input `x` might not be a valid distribution.
    # `x` is one-hot encoded, so log_softmax(x) will be -inf for all 0s and 0 for the 1.
    # Let's check the values of x.
    print(f"Min x: {x.min()}, Max x: {x.max()}")
    
    # The input to kl_div should be log-probabilities.
    # The target should also be probabilities if log_target=False, or log-probabilities if log_target=True.
    # Here, log_target=True, so F.log_softmax(x, dim=-1) is correct if x contains logits.
    # But x is one-hot encoded data.
    
    # Let's try to reproduce the NaN in the backward pass.
    l = divergence_loss(x, model.learnable_beta)
    print(f"divergence_loss output: {l}")
    
    try:
        l.backward()
        print("Backward pass successful.")
    except RuntimeError as e:
        print(f"Backward pass failed with error: {e}")


divergence_loss_debug(x, model.learnable_beta)


No negative variance detected.
Min x: 0.0, Max x: 1.0
divergence_loss output: 0.4226871132850647
Backward pass successful.
