In [1]:
import os
import numpy as np
import torch
import sampling
import graph_lib
import noise_lib

from torch.utils.data import DataLoader, Dataset

from data.synthetic import utils as data_utils

import io
import PIL
import functools

torch.backends.cudnn.benchmark = True
from torch import nn, Tensor

import math
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F


def transformer_timestep_embedding(timesteps, dim, device,max_period=10000):
    """
    Build sinusoidal timestep embeddings (transformer-style position encodings).

    Args:
        timesteps: 1-D tensor of timesteps, shape (batch,); cast to float internally.
        dim: width of the returned embedding.
        device: device on which the frequency table is created.
        max_period: controls the lowest frequency of the geometric frequency ladder.

    Returns:
        Tensor of shape (batch, dim): `dim // 2` sine features followed by
        `dim // 2` cosine features, plus one trailing zero column when `dim` is odd.
    """
    n_freqs = dim // 2
    ladder = torch.arange(0, n_freqs, dtype=torch.float32, device=device)
    # Geometric progression from 1 down towards 1 / max_period.
    freqs = torch.exp(-math.log(max_period) * ladder / n_freqs)
    phases = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([phases.sin(), phases.cos()], dim=-1)
    # An odd `dim` leaves one column unfilled; pad it with zeros on the right.
    return F.pad(emb, (0, 1)) if dim % 2 == 1 else emb


class CatMLPScoreFunc(nn.Module):
    """
    Time-conditioned MLP over a fixed-length sequence of categorical tokens.

    Tokens are embedded, flattened into one vector per example and passed through
    `num_layers` linear layers. The sinusoidal timestep embedding is added to the
    output of every linear layer before the SiLU non-linearity. A final linear
    layer produces `vocab_size` values for every sequence position.

    Args:
        vocab_size: number of token ids (embedding rows and per-position outputs).
        seq_len: sequence length the network is built for; fixes the input width.
        cat_embed_size: embedding width per token.
        num_layers: number of hidden linear layers (at least one is always created).
        hidden_size: hidden width; also the width of the timestep embedding.
        time_scale_factor: multiplier applied to `t` before it is embedded.
    """

    def __init__(self, vocab_size, seq_len,cat_embed_size, num_layers, hidden_size, time_scale_factor=1000.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.cat_embed_size = cat_embed_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.time_scale_factor = time_scale_factor
        self.input_dim = cat_embed_size * seq_len

        self.embed = nn.Embedding(vocab_size, cat_embed_size)
        # The first layer consumes the flattened embeddings; the others are hidden -> hidden.
        in_widths = [self.input_dim] + [hidden_size] * (num_layers - 1)
        self.layers = nn.ModuleList([nn.Linear(width, hidden_size) for width in in_widths])
        self.final = nn.Linear(hidden_size, seq_len * vocab_size)

    def forward(self, x, t):
        """
        Args:
            x: (batch, seq_len) integer token ids.
            t: (batch,) timesteps.

        Returns:
            Tensor of shape (batch, seq_len, vocab_size).
        """
        batch, length = x.shape
        flat = self.embed(x).view(batch, -1)  # (batch, seq_len * cat_embed_size)

        temb = transformer_timestep_embedding(t * self.time_scale_factor, self.hidden_size, flat.device)
        temb = temb.to(flat.device)

        hidden = flat
        for linear in self.layers:
            # Re-inject the time signal at every layer.
            hidden = F.silu(linear(hidden) + temb)

        out = self.final(hidden)  # (batch, seq_len * vocab_size)
        return out.view(batch, length, self.vocab_size)


In [3]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Directory that contains the pre-generated synthetic training data.
data_directory = 'data//synthetic//checkerboard'

In [4]:
from omegaconf import OmegaConf
import utils
from model.ema import ExponentialMovingAverage
import losses
from itertools import chain
import os

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
graph_type = 'mixed'  # one of: 'uniform', 'masked', 'mixed'

cfg_path = f'configs//synthetic_config_{graph_type}.yaml'
cfg = OmegaConf.load(cfg_path)
cfg.model = OmegaConf.load('configs//model//tiny.yaml')

# ---------------------------------------------------------------------------
# Output directories and logging
# ---------------------------------------------------------------------------
work_dir = f'for_synthetic_data_{graph_type}_nolog_loss'
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(work_dir, exist_ok=True)
sample_dir = os.path.join(work_dir, "samples")
checkpoint_dir = os.path.join(work_dir, "checkpoints")
# NOTE: despite the name, this is the path of a checkpoint *file*, not a directory.
checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
utils.makedirs(sample_dir)
utils.makedirs(checkpoint_dir)
utils.makedirs(os.path.dirname(checkpoint_meta_dir))
logger = utils.get_logger(os.path.join(work_dir, "logs"))

# ---------------------------------------------------------------------------
# Device selection (with a short hardware report)
# ---------------------------------------------------------------------------
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Found {} CUDA devices.".format(torch.cuda.device_count()))
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
else:
    print("WARNING: Using device {}".format(device))
    print(f"Found {os.cpu_count()} total number of CPUs.")

# build token graph
graph = graph_lib.get_graph(cfg, device)

# ---------------------------------------------------------------------------
# Score model
# ---------------------------------------------------------------------------
batch_size = cfg.training.batch_size
# NOTE(review): sequence length is hard-coded here while the sampling shape later uses
# cfg.model.length; presumably both are 32 -- confirm they agree.
seq_len = 32
vocab_size = 2
# Index one past the real tokens; presumably the mask/absorbing token id -- confirm in graph_lib.
mask_id = vocab_size
# Every graph type except 'uniform' gets one extra token id on top of the data vocabulary.
if graph_type == 'uniform':
    expand_vocab_size = vocab_size
else:
    expand_vocab_size = vocab_size + 1
time_scale_factor = 1000.0
embed_dim = 256
num_layers = 3

# Pass the named constant instead of repeating the literal 1000.0, so the two cannot drift apart.
score_model = CatMLPScoreFunc(
    vocab_size=expand_vocab_size,
    cat_embed_size=embed_dim,
    num_layers=num_layers,
    hidden_size=embed_dim,
    seq_len=seq_len,
    time_scale_factor=time_scale_factor,
).to(device)

num_parameters = sum(p.numel() for p in score_model.parameters())
print(f"Number of parameters in the model: {num_parameters}")
ema = ExponentialMovingAverage(
        score_model.parameters(), decay=cfg.training.ema)
print(score_model)
print(f"EMA: {ema}")

# build noise
noise = noise_lib.get_noise(cfg).to(device)
# (The upstream multi-GPU DDP wrapper around `noise` is intentionally not used in this notebook.)
sampling_eps = 1e-5

# ---------------------------------------------------------------------------
# Optimization state
# ---------------------------------------------------------------------------
# The optimizer covers both the score model and the parameters of the noise schedule.
optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
print(f"Optimizer: {optimizer}")
# torch.cuda.amp.GradScaler() is deprecated; torch.amp.GradScaler("cuda") is the documented replacement.
scaler = torch.amp.GradScaler("cuda")
print(f"Scaler: {scaler}")
state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0)
# Returns `state` unchanged when no checkpoint file exists at `checkpoint_meta_dir` (see the log output).
state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)


Found 1 CUDA devices.
NVIDIA GeForce GTX 1080 	 Memory: 8.00GB
Number of parameters in the model: 2254432
CatMLPScoreFunc(
  (embed): Embedding(3, 256)
  (layers): ModuleList(
    (0): Linear(in_features=8192, out_features=256, bias=True)
    (1-2): 2 x Linear(in_features=256, out_features=256, bias=True)
  )
  (final): Linear(in_features=256, out_features=96, bias=True)
)
EMA: <model.ema.ExponentialMovingAverage object at 0x00000225AAB2C220>


  scaler = torch.cuda.amp.GradScaler()
2025-04-23 12:58:07,035 - No checkpoint found at for_synthetic_data_mixed_nolog_loss\checkpoints-meta\checkpoint.pth. Returned the same state as input


Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-06
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-06
)
Scaler: <torch.cuda.amp.grad_scaler.GradScaler object at 0x00000225AAB2C490>


In [5]:
# Read the pre-generated samples from disk (they are expected to have been produced
# beforehand with generate_data.ipynb) and widen them to int64 so they can be used
# as integer token ids.
data_file = os.path.join(data_directory, 'data.npy')
with open(data_file, 'rb') as npy_file:
    data = np.load(npy_file).astype(np.int64)
print('data shape: %s' % str(data.shape))

# Thin Dataset wrapper around an in-memory NumPy array.
class CustomDataset(Dataset):
    """Map-style dataset whose items are rows of a NumPy array."""

    def __init__(self, data):
        # from_numpy shares memory with the array instead of copying it.
        self.data = torch.from_numpy(data)

    def __len__(self):
        """Number of rows along the first axis."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tensor row(s) selected by `idx`."""
        item = self.data[idx]
        return item

# Wrap the loaded array so that a DataLoader can batch it.
train_set = CustomDataset(data=data)

# Turn a finite DataLoader into an endless stream of batches.
def cycle_loader(dataloader):
    """Yield batches from `dataloader` forever.

    Each pass starts a fresh iteration over `dataloader`, so a shuffling loader
    produces a new order on every pass (nothing is cached between passes).
    """
    while True:
        yield from dataloader

# Endless stream of shuffled training batches; loading happens in the main process
# (num_workers=0) and no custom sampler is used.
train_ds = cycle_loader(
    DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,  # Shuffle the data as needed
        num_workers=0,
        persistent_workers=False,
        pin_memory=True,
    )
)

# cycle_loader returns a generator and iter() of a generator is that same generator,
# so `train_iter` and `train_ds` advance one shared stream of batches.
train_iter = iter(train_ds)

data shape: (10000000, 32)


In [6]:
# Plotting code borrowed from Sun2023.

def plot(xbin, fn_xbin2float, output_file=None):
  """Visualize binary data.

  Args:
    xbin: binary samples, passed unchanged to `fn_xbin2float`.
    fn_xbin2float: callable mapping `xbin` to the float data that is plotted.
    output_file: destination path. A path ending in '.png' is written as PNG,
      any other path as PDF. When None the figure is rendered in memory.

  Returns:
    For `output_file=None`, the rendered PNG as a NumPy array with a leading
    singleton axis. Otherwise None (the figure is written to `output_file`).
  """
  float_data = fn_xbin2float(xbin)
  if output_file is None:  # in-memory plot
    # `import PIL` alone does not load the `PIL.Image` submodule; importing it
    # explicitly removes the dependency on another library having loaded it first.
    from PIL import Image
    buf = io.BytesIO()
    data_utils.plot_samples(float_data, buf, im_size=4.1, im_fmt='png')
    buf.seek(0)
    image = np.asarray(Image.open(buf))[None, ...]
    return image

  # Decide the format before opening so a bad argument does not leave an empty file behind.
  im_fmt = 'png' if output_file.endswith('.png') else 'pdf'
  with open(output_file, 'wb') as f:
    data_utils.plot_samples(float_data, f, im_size=4.1, im_fmt=im_fmt,axis=True)
  return None


class BinarySyntheticHelper:
  """Helper that plots binary-encoded synthetic samples.

  Holds the binary maps obtained from `data_utils.get_binmap(seq_len, 'gray')`
  together with the scale used to decode binary samples back to floats.
  """

  def __init__(self, seq_len,int_scale):
    self.seq_len, self.int_scale = seq_len, int_scale
    binmap_pair = data_utils.get_binmap(seq_len, 'gray')
    self.bm, self.inv_bm = binmap_pair

  def plot(self, xbin, output_file=None):
    """Decode `xbin` with `data_utils.bin2float` and forward it to the module-level `plot`."""
    decode_kwargs = dict(
        inv_bm=self.inv_bm,
        discrete_dim=self.seq_len,
        int_scale=self.int_scale,
    )
    to_float = functools.partial(data_utils.bin2float, **decode_kwargs)
    return plot(xbin, to_float, output_file)

# Scale passed to data_utils.bin2float when decoding samples for plotting.
# NOTE(review): presumably this must equal the value used when the data was generated -- confirm.
int_scale = 5461.760975376213
model_helper = BinarySyntheticHelper(seq_len, int_scale)

remapping binary repr with gray code


In [7]:
# The metric used in Table 1 of Lou2023.
def binary_mmd(x, y, sim_fn):
  """MMD for binary data.

  Args:
    x: array whose first axis indexes the samples of the first set.
    y: array whose first axis indexes the samples of the second set.
    sim_fn: kernel; `sim_fn(a, b)` returns the pairwise similarity matrix
      between the rows of `a` and the rows of `b`.

  Returns:
    `kxx + kyy - 2 * kxy`, where the within-set terms exclude the diagonal
    (self-similarity) and are averaged over the n * (n - 1) ordered pairs.
  """
  x = x.astype(np.float32)
  y = y.astype(np.float32)

  def _within_set_mean(samples):
    count = samples.shape[0]
    # Zero the diagonal so that a sample is never compared with itself.
    gram = sim_fn(samples, samples) * (1 - np.eye(count))
    return np.sum(gram) / count / (count - 1)

  k_xx = _within_set_mean(x)
  k_yy = _within_set_mean(y)
  k_xy = np.sum(sim_fn(x, y)) / x.shape[0] / y.shape[0]
  return k_xx + k_yy - 2 * k_xy

def binary_exp_hamming_sim(x, y, bd):
  """Pairwise kernel exp(-bd * d) between the rows of `x` and the rows of `y`.

  `d` is the sum of absolute element-wise differences over the last axis, which
  equals the Hamming distance when the entries are 0/1.
  """
  pairwise_diff = np.expand_dims(x, axis=1) - np.expand_dims(y, axis=0)
  distance = np.abs(pairwise_diff).sum(axis=-1)
  return np.exp(-bd * distance)

def binary_exp_hamming_mmd(x, y, bandwidth=0.1):
  """MMD between `x` and `y` under the exponential Hamming kernel with the given bandwidth."""
  def kernel(a, b):
    return binary_exp_hamming_sim(a, b, bd=bandwidth)

  return binary_mmd(x, y, kernel)

In [8]:
# Training schedule, read from the config.
num_train_steps = cfg.training.n_iters
log_freq = cfg.training.log_freq
snapshot_freq = cfg.training.snapshot_freq

# Evaluation settings: how many MMD rounds to average and how many samples each round uses.
eval_rounds = 1
plot_samples = cfg.plot_samples

# Build one-step training and evaluation functions (third argument: True = train, False = eval).
optimize_fn = losses.optimization_manager(cfg)
train_step_fn = losses.get_step_fn(
    noise, graph, True, optimize_fn, cfg.training.accum
)
eval_step_fn = losses.get_step_fn(
    noise, graph, False, optimize_fn, cfg.training.accum
)

# Each sampler call draws batch_size // (ngpus * accum) sequences of length cfg.model.length.
sampling_shape = (
    cfg.training.batch_size // (cfg.ngpus * cfg.training.accum),
    cfg.model.length,
)
sampling_fn = sampling.get_sampling_fn(
    cfg, graph, noise, sampling_shape, sampling_eps, device
)

# Continue from the step stored in `state` (0 unless a checkpoint was restored).
initial_step = int(state['step'])
print(f"Starting training loop at step {initial_step}.")

Starting training loop at step 0.


In [9]:
# Compute the metric used in Table 1 of Lou2023. According to the original author, it should
# get at least as small as 1.62e-5 for training to be considered done.
def run_sampling(score_model,save_step):
    """Sample from `score_model` and return the MMD against ground-truth data.

    For each of `eval_rounds` rounds this draws ground-truth batches from
    `train_ds`, draws the same number of batches from `sampling_fn`, computes the
    exponential-Hamming MMD between the two sets and saves a plot of the generated
    samples to `<sample_dir>/step_<save_step>_eval_round_<round>.pdf`.

    Notes:
        * Ground truth comes from `train_ds`, so every call advances that stream.
        * At least one batch is always drawn, even when `plot_samples < batch_size`
          (previously that case produced zero batches and `np.concatenate` raised).
        * Samples are reshaped to (-1, seq_len) rather than to the ground-truth shape,
          so a sampler batch size different from `batch_size` (for example when
          ngpus * accum > 1 in `sampling_shape`) no longer breaks the reshape. The MMD
          estimator normalises each set by its own size, so unequal counts are fine.

    Args:
        score_model: model handed to `sampling_fn`.
        save_step: integer used only in the name of the saved plot.

    Returns:
        The MMD averaged over `eval_rounds` rounds.
    """
    num_batches = max(1, plot_samples // batch_size)
    total_mmd = 0.0
    for round_idx in range(eval_rounds):
        gt_batches = [next(train_ds).cpu().numpy() for _ in range(num_batches)]
        gt_data = np.reshape(np.concatenate(gt_batches, axis=0), (-1, seq_len))

        sample_batches = [sampling_fn(score_model).cpu().numpy() for _ in range(num_batches)]
        x0 = np.reshape(np.concatenate(sample_batches, axis=0), (-1, seq_len))

        total_mmd += binary_exp_hamming_mmd(x0, gt_data)
        # Plot the samples generated by the model in this round.
        plot_path = os.path.join(sample_dir, f'step_{save_step}_eval_round_{round_idx}.pdf')
        model_helper.plot(x0, plot_path)
    return total_mmd / eval_rounds

In [None]:
# Optional: run this cell only to load the weights of a previously saved snapshot.
# utils.restore_checkpoint in the setup cell looks for <work_dir>/checkpoints-meta/checkpoint.pth,
# whereas the training loop in this notebook writes <work_dir>/checkpoints/checkpoint_<n>.pth;
# that mismatch is presumably why restore_checkpoint reports "No checkpoint found".
# SECURITY: weights_only=False unpickles arbitrary Python objects -- only load checkpoint
# files that you created yourself or otherwise trust.
loaded_state = torch.load(
    os.path.join(checkpoint_dir, 'checkpoint_30.pth'),
    map_location=device,
    weights_only=False,
)
# Only the model weights are restored here (not optimizer, EMA or step counter).
# strict=False tolerates missing/unexpected keys, so check the report displayed below.
score_model.load_state_dict(loaded_state['model'], strict=False)

<All keys matched successfully>

In [None]:
# Run sampling with the currently loaded weights, without retraining.
#
# Author's notes on the predictor choice:
# - The analytic predictor performs very poorly here.
# - Suspected cause: the score is not forced to be 1 on the diagonal. In particular, the score
#   at masked positions ends up being used as the probability that a token stays masked,
#   although it is never trained for that.
# - For the Euler predictor this does not matter: when the new rate matrix is built with
#   graph.reverse_rate, the diagonal is zeroed out and replaced by the negative column sums.
# - It could also simply be a bug.
cfg.sampling.predictor = 'euler'
sampling_fn = sampling.get_sampling_fn(
    cfg, graph, noise, sampling_shape, sampling_eps, device
)
run_sampling(score_model, 50)

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


np.float64(0.00029790387683237274)

In [None]:
# Main training loop: one optimisation step per iteration, periodic logging, and a
# checkpoint + EMA evaluation every `cfg.training.snapshot_freq` steps and at the final step.
for step in range(initial_step, num_train_steps + 1):
    batch = next(train_iter).to(device)
    # train_step_fn was built by losses.get_step_fn in training mode; it returns the loss.
    loss = train_step_fn(state, batch)

    if step % log_freq == 0:
        print("step: %d, training_loss: %.5e" % (step, loss.mean().item()))

    # Explicit grouping of the original `a and b or c` condition.
    is_snapshot_step = step > 0 and step % cfg.training.snapshot_freq == 0
    if is_snapshot_step or step == num_train_steps:
        # Save the checkpoint.
        save_step = step // cfg.training.snapshot_freq
        utils.save_checkpoint(os.path.join(
                        checkpoint_dir, f'checkpoint_{save_step}.pth'), state)

        # Evaluate with the EMA weights. try/finally guarantees that the raw training
        # weights are put back even if sampling fails or the cell is interrupted;
        # otherwise training would silently continue from the EMA weights.
        ema.store(score_model.parameters())
        try:
            ema.copy_to(score_model.parameters())

            avg_mmd = run_sampling(score_model, save_step)

            with open('output.txt', 'a') as file:
                file.write(f'time: {datetime.now()},step: {save_step}, loss: {loss.mean().item()}, MMD: {avg_mmd}\n')
        finally:
            ema.restore(score_model.parameters())

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 0, training_loss: 1.41226e+01
step: 1000, training_loss: 7.29549e+00
step: 2000, training_loss: 6.39916e+00
step: 3000, training_loss: 7.38684e+00
step: 4000, training_loss: 8.18592e+00
step: 5000, training_loss: 4.82452e+01
step: 6000, training_loss: 5.81248e+00
step: 7000, training_loss: 8.41568e+00
step: 8000, training_loss: 7.31727e+00
step: 9000, training_loss: 6.00721e+00
step: 10000, training_loss: 7.49004e+00
Average mmd : 0.0003640006553522479


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 11000, training_loss: 7.20117e+00
step: 12000, training_loss: 6.31513e+00
step: 13000, training_loss: 1.00243e+01
step: 14000, training_loss: 6.22494e+00
step: 15000, training_loss: 1.50774e+01
step: 16000, training_loss: 7.63878e+00
step: 17000, training_loss: 5.47862e+00
step: 18000, training_loss: 4.78853e+00
step: 19000, training_loss: 8.57851e+00
step: 20000, training_loss: 8.88760e+00
Average mmd : 0.0005333319484661647


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 21000, training_loss: 8.54827e+00
step: 22000, training_loss: 4.84739e+00
step: 23000, training_loss: 7.53565e+00
step: 24000, training_loss: 6.68250e+00
step: 25000, training_loss: 5.91620e+00
step: 26000, training_loss: 6.97186e+00
step: 27000, training_loss: 7.15542e+00
step: 28000, training_loss: 5.55506e+00
step: 29000, training_loss: 6.30056e+00
step: 30000, training_loss: 7.87461e+00
Average mmd : 0.00028390120539179


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 31000, training_loss: 6.29522e+00
step: 32000, training_loss: 6.47704e+00
step: 33000, training_loss: 6.49403e+00
step: 34000, training_loss: 8.06421e+00
step: 35000, training_loss: 8.34981e+00
step: 36000, training_loss: 7.02493e+00
step: 37000, training_loss: 5.69342e+00
step: 38000, training_loss: 6.88028e+00
step: 39000, training_loss: 5.99629e+00
step: 40000, training_loss: 6.28630e+00
Average mmd : 0.00012038011993886766


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 41000, training_loss: 7.84677e+00
step: 42000, training_loss: 7.75513e+00
step: 43000, training_loss: 6.73912e+00
step: 44000, training_loss: 6.69367e+00
step: 45000, training_loss: 5.69774e+00
step: 46000, training_loss: 5.77730e+00
step: 47000, training_loss: 6.53800e+00
step: 48000, training_loss: 5.39389e+00
step: 49000, training_loss: 5.71389e+00
step: 50000, training_loss: 6.37098e+00
Average mmd : 5.086176924623542e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 51000, training_loss: 1.00758e+01
step: 52000, training_loss: 5.56583e+00
step: 53000, training_loss: 5.88243e+00
step: 54000, training_loss: 6.86145e+00
step: 55000, training_loss: 7.03091e+00
step: 56000, training_loss: 5.79215e+00
step: 57000, training_loss: 8.24274e+00
step: 58000, training_loss: 6.88065e+00
step: 59000, training_loss: 6.06866e+00
step: 60000, training_loss: 6.68651e+00
Average mmd : 0.0001695241510300538


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 61000, training_loss: 6.34751e+00
step: 62000, training_loss: 6.95545e+00
step: 63000, training_loss: 6.79104e+00
step: 64000, training_loss: 6.64213e+00
step: 65000, training_loss: 9.87277e+00
step: 66000, training_loss: 6.00274e+00
step: 67000, training_loss: 5.19443e+00
step: 68000, training_loss: 6.19763e+00
step: 69000, training_loss: 7.54788e+00
step: 70000, training_loss: 6.44177e+00
Average mmd : 4.863812519101396e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 71000, training_loss: 5.83654e+00
step: 72000, training_loss: 1.05447e+01
step: 73000, training_loss: 9.54310e+00
step: 74000, training_loss: 6.95478e+00
step: 75000, training_loss: 9.93359e+00
step: 76000, training_loss: 5.30638e+00
step: 77000, training_loss: 5.32468e+00
step: 78000, training_loss: 1.69249e+01
step: 79000, training_loss: 7.35572e+00
step: 80000, training_loss: 6.98753e+00
Average mmd : 0.00013383285598972394


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 81000, training_loss: 5.80521e+00
step: 82000, training_loss: 6.34355e+00
step: 83000, training_loss: 6.75836e+00
step: 84000, training_loss: 8.75452e+00
step: 85000, training_loss: 7.13454e+00
step: 86000, training_loss: 7.09371e+00
step: 87000, training_loss: 5.73604e+00
step: 88000, training_loss: 5.89546e+00
step: 89000, training_loss: 6.71797e+00
step: 90000, training_loss: 5.25486e+00
Average mmd : 0.00016679028102728477


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 91000, training_loss: 8.05806e+00
step: 92000, training_loss: 6.16469e+00
step: 93000, training_loss: 4.93728e+00
step: 94000, training_loss: 5.14612e+00
step: 95000, training_loss: 6.94902e+00
step: 96000, training_loss: 5.17130e+00
step: 97000, training_loss: 4.47778e+00
step: 98000, training_loss: 5.50824e+00
step: 99000, training_loss: 8.44211e+00
step: 100000, training_loss: 7.55701e+00
Average mmd : 0.0002918816018480097


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 101000, training_loss: 2.66682e+01
step: 102000, training_loss: 6.34122e+00
step: 103000, training_loss: 6.73332e+00
step: 104000, training_loss: 9.63464e+00
step: 105000, training_loss: 7.59741e+00
step: 106000, training_loss: 5.06771e+00
step: 107000, training_loss: 9.49790e+00
step: 108000, training_loss: 4.48823e+00
step: 109000, training_loss: 7.96648e+00
step: 110000, training_loss: 6.24460e+00
Average mmd : 9.020412417315438e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 111000, training_loss: 8.35492e+00
step: 112000, training_loss: 6.79260e+00
step: 113000, training_loss: 7.06547e+00
step: 114000, training_loss: 7.88346e+00
step: 115000, training_loss: 7.61630e+00
step: 116000, training_loss: 7.34974e+00
step: 117000, training_loss: 5.91417e+00
step: 118000, training_loss: 6.69014e+00
step: 119000, training_loss: 8.05774e+00
step: 120000, training_loss: 6.62320e+00
Average mmd : 8.313656158831506e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 121000, training_loss: 6.54395e+00
step: 122000, training_loss: 7.08533e+00
step: 123000, training_loss: 9.04494e+00
step: 124000, training_loss: 5.07058e+00
step: 125000, training_loss: 7.15364e+00
step: 126000, training_loss: 6.44002e+00
step: 127000, training_loss: 6.01305e+00
step: 128000, training_loss: 6.61388e+00
step: 129000, training_loss: 7.58973e+00
step: 130000, training_loss: 5.86444e+00
Average mmd : 0.0002010154093174954


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 131000, training_loss: 5.97580e+00
step: 132000, training_loss: 6.62909e+00
step: 133000, training_loss: 4.49641e+00
step: 134000, training_loss: 6.00001e+00
step: 135000, training_loss: 8.74806e+00
step: 136000, training_loss: 5.46258e+00
step: 137000, training_loss: 8.31015e+00
step: 138000, training_loss: 7.40444e+00
step: 139000, training_loss: 7.24390e+00
step: 140000, training_loss: 6.50993e+00
Average mmd : 0.00017007240940619672


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 141000, training_loss: 5.46711e+00
step: 142000, training_loss: 6.47495e+00
step: 143000, training_loss: 5.13338e+00
step: 144000, training_loss: 6.69692e+00
step: 145000, training_loss: 5.06469e+00
step: 146000, training_loss: 5.24880e+00
step: 147000, training_loss: 7.59048e+00
step: 148000, training_loss: 6.75145e+00
step: 149000, training_loss: 4.79330e+00
step: 150000, training_loss: 1.01656e+01
Average mmd : 0.00016242291064316738


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 151000, training_loss: 5.70037e+00
step: 152000, training_loss: 8.53155e+00
step: 153000, training_loss: 8.28711e+00
step: 154000, training_loss: 5.67736e+00
step: 155000, training_loss: 5.38800e+00
step: 156000, training_loss: 8.95873e+00
step: 157000, training_loss: 4.59274e+00
step: 158000, training_loss: 6.25256e+00
step: 159000, training_loss: 5.34112e+00
step: 160000, training_loss: 7.83740e+00
Average mmd : 9.468582584137852e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 161000, training_loss: 6.26759e+00
step: 162000, training_loss: 7.25783e+00
step: 163000, training_loss: 5.74947e+00
step: 164000, training_loss: 6.92614e+00
step: 165000, training_loss: 1.00917e+01
step: 166000, training_loss: 6.43917e+00
step: 167000, training_loss: 5.83975e+00
step: 168000, training_loss: 6.11093e+00
step: 169000, training_loss: 6.53162e+00
step: 170000, training_loss: 5.89162e+00
Average mmd : 0.0003490734321175415


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 171000, training_loss: 6.35228e+00
step: 172000, training_loss: 1.13059e+01
step: 173000, training_loss: 6.58251e+00
step: 174000, training_loss: 8.23549e+00
step: 175000, training_loss: 1.12891e+01
step: 176000, training_loss: 6.66461e+00
step: 177000, training_loss: 7.13236e+00
step: 178000, training_loss: 7.73617e+00
step: 179000, training_loss: 8.67772e+00
step: 180000, training_loss: 7.06429e+00
Average mmd : 0.00019060092025391384


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 181000, training_loss: 5.70394e+00
step: 182000, training_loss: 1.05196e+01
step: 183000, training_loss: 5.58397e+00
step: 184000, training_loss: 5.73516e+00
step: 185000, training_loss: 8.83622e+00
step: 186000, training_loss: 7.82314e+00
step: 187000, training_loss: 6.06141e+00
step: 188000, training_loss: 8.21182e+00
step: 189000, training_loss: 5.86241e+00
step: 190000, training_loss: 6.38748e+00
Average mmd : 0.0002595504474547372


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 191000, training_loss: 7.75131e+00
step: 192000, training_loss: 7.23712e+00
step: 193000, training_loss: 5.83502e+00
step: 194000, training_loss: 6.97503e+00
step: 195000, training_loss: 6.01744e+00
step: 196000, training_loss: 7.66085e+00
step: 197000, training_loss: 5.46033e+00
step: 198000, training_loss: 7.12464e+00
step: 199000, training_loss: 5.59343e+00
step: 200000, training_loss: 9.05375e+00
Average mmd : 0.00018088776639901827


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 201000, training_loss: 6.47947e+00
step: 202000, training_loss: 5.72223e+00
step: 203000, training_loss: 6.47357e+00
step: 204000, training_loss: 6.68989e+00
step: 205000, training_loss: 8.49762e+00
step: 206000, training_loss: 6.37124e+00
step: 207000, training_loss: 6.67030e+00
step: 208000, training_loss: 5.81955e+00
step: 209000, training_loss: 9.32071e+00
step: 210000, training_loss: 7.21467e+00
Average mmd : 0.00017813098686120243


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 211000, training_loss: 5.37753e+00
step: 212000, training_loss: 8.11352e+00
step: 213000, training_loss: 6.80517e+00
step: 214000, training_loss: 6.22781e+00
step: 215000, training_loss: 7.32033e+00
step: 216000, training_loss: 7.18331e+00
step: 217000, training_loss: 8.14957e+00
step: 218000, training_loss: 6.30968e+00
step: 219000, training_loss: 7.92091e+00
step: 220000, training_loss: 7.40144e+00
Average mmd : 0.00019098800607364463


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 221000, training_loss: 5.76765e+00
step: 222000, training_loss: 9.33074e+00
step: 223000, training_loss: 7.93380e+00
step: 224000, training_loss: 1.01962e+01
step: 225000, training_loss: 6.40387e+00
step: 226000, training_loss: 7.81309e+00
step: 227000, training_loss: 9.11794e+00
step: 228000, training_loss: 6.34831e+00
step: 229000, training_loss: 5.87277e+00
step: 230000, training_loss: 6.55686e+00
Average mmd : 0.00030725405402853845


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 231000, training_loss: 6.27106e+00
step: 232000, training_loss: 6.70316e+00
step: 233000, training_loss: 6.16224e+00
step: 234000, training_loss: 6.43222e+00
step: 235000, training_loss: 6.66565e+00
step: 236000, training_loss: 9.01056e+00
step: 237000, training_loss: 7.26311e+00
step: 238000, training_loss: 5.60443e+00
step: 239000, training_loss: 6.26367e+00
step: 240000, training_loss: 7.83426e+00
Average mmd : 0.00029431629016080096


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 241000, training_loss: 7.08268e+00
step: 242000, training_loss: 5.80816e+00
step: 243000, training_loss: 6.77672e+00
step: 244000, training_loss: 7.64548e+00
step: 245000, training_loss: 7.52814e+00
step: 246000, training_loss: 5.95119e+00
step: 247000, training_loss: 6.63526e+00
step: 248000, training_loss: 6.34319e+00
step: 249000, training_loss: 6.13735e+00
step: 250000, training_loss: 4.44191e+00
Average mmd : 0.00033410929097421604


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 251000, training_loss: 7.32456e+00
step: 252000, training_loss: 6.08945e+00
step: 253000, training_loss: 5.41248e+00
step: 254000, training_loss: 8.48174e+00
step: 255000, training_loss: 7.77450e+00
step: 256000, training_loss: 5.43819e+00
step: 257000, training_loss: 6.01581e+00
step: 258000, training_loss: 8.52086e+00
step: 259000, training_loss: 7.61097e+00
step: 260000, training_loss: 1.79977e+01
Average mmd : 0.00039947567637660386


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 261000, training_loss: 8.50192e+00
step: 262000, training_loss: 7.22020e+00
step: 263000, training_loss: 4.27770e+00
step: 264000, training_loss: 6.75724e+00
step: 265000, training_loss: 7.62675e+00
step: 266000, training_loss: 5.79714e+00
step: 267000, training_loss: 7.44918e+00
step: 268000, training_loss: 6.82013e+00
step: 269000, training_loss: 5.60125e+00
step: 270000, training_loss: 9.31445e+00
Average mmd : 0.00039709591426828617


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 271000, training_loss: 5.65927e+00
step: 272000, training_loss: 5.96117e+00
step: 273000, training_loss: 6.95887e+00
step: 274000, training_loss: 5.81115e+00
step: 275000, training_loss: 5.64223e+00
step: 276000, training_loss: 6.00206e+00
step: 277000, training_loss: 6.05838e+00
step: 278000, training_loss: 6.14898e+00
step: 279000, training_loss: 7.19445e+00
step: 280000, training_loss: 5.00060e+00
Average mmd : 0.00043629542371792507


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 281000, training_loss: 5.63261e+00
step: 282000, training_loss: 6.73427e+00
step: 283000, training_loss: 6.84433e+00
step: 284000, training_loss: 6.59641e+00
step: 285000, training_loss: 7.01850e+00
step: 286000, training_loss: 6.07613e+00
step: 287000, training_loss: 6.06228e+00
step: 288000, training_loss: 6.40427e+00
step: 289000, training_loss: 8.07224e+00
step: 290000, training_loss: 7.13761e+00
Average mmd : 0.00027940588081754036


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 291000, training_loss: 6.44657e+00
step: 292000, training_loss: 6.27870e+00
step: 293000, training_loss: 5.07977e+00
step: 294000, training_loss: 1.24528e+01
step: 295000, training_loss: 5.29632e+00
step: 296000, training_loss: 7.60272e+00
step: 297000, training_loss: 5.56409e+00
step: 298000, training_loss: 9.05021e+00
step: 299000, training_loss: 7.01238e+00
step: 300000, training_loss: 6.50277e+00
Average mmd : 0.0003802995859849556
