In [1]:
import os
import numpy as np
import torch
import sampling
import graph_lib
import noise_lib

from torch.utils.data import DataLoader, Dataset

from data.synthetic import utils as data_utils

import io
import PIL
import functools

torch.backends.cudnn.benchmark = True
from torch import nn, Tensor

import math
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F


def transformer_timestep_embedding(timesteps, dim, device,max_period=10000):
    """
    Create sinusoidal timestep embeddings (like in transformer position encodings).
    timesteps: (batch,) or (N,)
    Returns: (batch, dim)
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32,device=device) / half)
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:  # zero pad if needed
        emb = F.pad(emb, (0, 1))
    return emb


class CatMLPScoreFunc(nn.Module):
    def __init__(self, vocab_size, seq_len,cat_embed_size, num_layers, hidden_size, time_scale_factor=1000.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.cat_embed_size = cat_embed_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.time_scale_factor = time_scale_factor
        input_dim=cat_embed_size*seq_len
        self.input_dim=input_dim

        self.embed = nn.Embedding(vocab_size, cat_embed_size)
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim,hidden_size))  # will init in forward
        for _ in range(num_layers-1):
            self.layers.append(nn.Linear(hidden_size,hidden_size))  # will init in forward
        self.final = nn.Linear(hidden_size, seq_len*vocab_size)

    def forward(self, x, t):
        """
        x: (batch, seq_len) – categorical token ids
        t: (batch,) – timesteps
        """
        B, L = x.shape
        V=self.vocab_size
        x = self.embed(x)  # (B, L, cat_embed_size)
        x = x.view(B, -1)  # (B, L * cat_embed_size)

        temb = transformer_timestep_embedding(t * self.time_scale_factor, self.hidden_size,x.device).to(x.device)
        
        for layer in self.layers:
            x = layer(x) + temb
            x = F.silu(x)

        x = self.final(x)  # (B, L*vocab_size)
        x=x.view(B, L, V)
        return x


In [3]:
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
data_directory='data//synthetic//checkerboard'

In [4]:
from omegaconf import OmegaConf
import utils
from model.ema import ExponentialMovingAverage
import losses
from itertools import chain
import os

graph_type='mixed' #uniform, masked, or mixed


cfg_path=f'configs//synthetic_config_{graph_type}.yaml'
cfg = OmegaConf.load(cfg_path)
cfg.model=OmegaConf.load('configs//model//tiny.yaml')
work_dir = f'for_synthetic_data_{graph_type}'
if not os.path.exists(work_dir):
    os.makedirs(work_dir)
# Create directories for experimental logs
sample_dir = os.path.join(work_dir, "samples")
checkpoint_dir = os.path.join(work_dir, "checkpoints")
checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
utils.makedirs(sample_dir)
utils.makedirs(checkpoint_dir)
utils.makedirs(os.path.dirname(checkpoint_meta_dir))
logger = utils.get_logger(os.path.join(work_dir, "logs"))
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Found {} CUDA devices.".format(torch.cuda.device_count()))
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
else:
    print("WARNING: Using device {}".format(device))
    print(f"Found {os.cpu_count()} total number of CPUs.")

# build token graph
graph = graph_lib.get_graph(cfg, device)
    
# build score model
batch_size = cfg.training.batch_size
seq_len=32
vocab_size = 2
mask_id = vocab_size
if graph_type=='uniform':
    expand_vocab_size = vocab_size 
else:
    expand_vocab_size = vocab_size + 1
time_scale_factor=1000.0
embed_dim=256
num_layers=3

score_model = CatMLPScoreFunc(vocab_size=expand_vocab_size,cat_embed_size=embed_dim,num_layers= num_layers,
    hidden_size=embed_dim,seq_len=seq_len,time_scale_factor= 1000.0,
    ).to(device)

num_parameters = sum(p.numel() for p in score_model.parameters())
print(f"Number of parameters in the model: {num_parameters}")
ema = ExponentialMovingAverage(
        score_model.parameters(), decay=cfg.training.ema)
print(score_model)
print(f"EMA: {ema}")

# build noise
noise = noise_lib.get_noise(cfg).to(device)
#noise = DDP(noise, device_ids=[rank], static_graph=True) Z:Commented this out
sampling_eps = 1e-5


# build optimization state
optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
print(f"Optimizer: {optimizer}")
scaler = torch.cuda.amp.GradScaler()
print(f"Scaler: {scaler}")
state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0) 
state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)


Found 1 CUDA devices.
NVIDIA GeForce GTX 1080 	 Memory: 8.00GB
Number of parameters in the model: 2254432
CatMLPScoreFunc(
  (embed): Embedding(3, 256)
  (layers): ModuleList(
    (0): Linear(in_features=8192, out_features=256, bias=True)
    (1-2): 2 x Linear(in_features=256, out_features=256, bias=True)
  )
  (final): Linear(in_features=256, out_features=96, bias=True)
)
EMA: <model.ema.ExponentialMovingAverage object at 0x00000140A46BBDC0>


  scaler = torch.cuda.amp.GradScaler()
2025-04-21 18:51:50,082 - No checkpoint found at for_synthetic_data_mixed\checkpoints-meta\checkpoint.pth. Returned the same state as input


Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-06
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-06
)
Scaler: <torch.cuda.amp.grad_scaler.GradScaler object at 0x00000140A46D4190>


In [5]:
# Load the data from file - assumes we have already generated samples using generate_data.ipynb
data_file = os.path.join(data_directory, 'data.npy')
with open(data_file, 'rb') as f:
    data = np.load(f).astype(np.int64)
    print('data shape: %s' % str(data.shape))

# Define a custom Dataset to wrap the data
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = torch.from_numpy(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Create a Dataset instance
train_set = CustomDataset(data)

# Function to cycle through DataLoader
def cycle_loader(dataloader):
    while True:
        for data in dataloader:
            yield data

# Initialize DataLoader without a sampler
train_ds = cycle_loader(DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=True,  # Shuffle the data as needed
    persistent_workers=False,
))

# Create an iterator for the data
train_iter = iter(train_ds)

data shape: (10000000, 32)


In [6]:
#plotting code borrowed from Sun2023

def plot(xbin, fn_xbin2float, output_file=None):
  """Visualize binary data."""
  float_data = fn_xbin2float(xbin)
  if output_file is None:  # in-memory plot
    buf = io.BytesIO()
    data_utils.plot_samples(float_data, buf, im_size=4.1, im_fmt='png')
    buf.seek(0)
    image = np.asarray(PIL.Image.open(buf))[None, ...]
    return image
  else:
    with open(output_file, 'wb') as f:
      im_fmt = 'png' if output_file.endswith('.png') else 'pdf'
      data_utils.plot_samples(float_data, f, im_size=4.1, im_fmt=im_fmt,axis=True)


class BinarySyntheticHelper(object):
  """Binary synthetic model helper."""

  def __init__(self, seq_len,int_scale):
    self.seq_len = seq_len
    self.int_scale=int_scale
    self.bm, self.inv_bm = data_utils.get_binmap(seq_len,
                                                 'gray')

  def plot(self, xbin, output_file=None):
    fn_xbin2float = functools.partial(
        data_utils.bin2float, inv_bm=self.inv_bm,
        discrete_dim=self.seq_len, int_scale=self.int_scale)
    return plot(xbin, fn_xbin2float, output_file)

int_scale=5461.760975376213
model_helper = BinarySyntheticHelper(seq_len,int_scale)

remapping binary repr with gray code


In [7]:
#the metric used in Table 1 of Lou2023
def binary_mmd(x, y, sim_fn):
  """MMD for binary data."""
  x = x.astype(np.float32)
  y = y.astype(np.float32)
  kxx = sim_fn(x, x)
  kxx = kxx * (1 - np.eye(x.shape[0]))
  kxx = np.sum(kxx) / x.shape[0] / (x.shape[0] - 1)

  kyy = sim_fn(y, y)
  kyy = kyy * (1 - np.eye(y.shape[0]))
  kyy = np.sum(kyy) / y.shape[0] / (y.shape[0] - 1)
  kxy = np.sum(sim_fn(x, y))
  kxy = kxy / x.shape[0] / y.shape[0]
  mmd = kxx + kyy - 2 * kxy
  return mmd

def binary_exp_hamming_sim(x, y, bd):
  x = np.expand_dims(x, axis=1)
  y = np.expand_dims(y, axis=0)
  d = np.sum(np.abs(x - y), axis=-1)
  return np.exp(-bd * d)

def binary_exp_hamming_mmd(x, y, bandwidth=0.1):
  sim_fn = functools.partial(binary_exp_hamming_sim, bd=bandwidth)
  return binary_mmd(x, y, sim_fn)

In [8]:
num_train_steps=cfg.training.n_iters
log_freq=cfg.training.log_freq
snapshot_freq=cfg.training.snapshot_freq
eval_rounds=1
plot_samples=cfg.plot_samples
# Build one-step training and evaluation functions
optimize_fn = losses.optimization_manager(cfg)
train_step_fn = losses.get_step_fn(noise, graph, True, optimize_fn, cfg.training.accum)
eval_step_fn = losses.get_step_fn(noise, graph, False, optimize_fn, cfg.training.accum)
sampling_shape = (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum), cfg.model.length)
sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device)
num_train_steps = cfg.training.n_iters
initial_step = int(state['step'])
print(f"Starting training loop at step {initial_step}.")

Starting training loop at step 0.


In [9]:
for step in range(initial_step,num_train_steps+1):
#while state['step'] < num_train_steps + 1:
    #step = state['step']
    batch=next(train_iter).to(device)
    # Compute loss using the simplified loss function
    loss=train_step_fn(state, batch)
    
        
    if step % log_freq == 0:
        print("step: %d, training_loss: %.5e" % (step, loss.mean().item()))
            
    if step > 0 and step % cfg.training.snapshot_freq == 0 or step == num_train_steps:
        # Save the checkpoint.
        save_step = step // cfg.training.snapshot_freq
        utils.save_checkpoint(os.path.join(
                        checkpoint_dir, f'checkpoint_{save_step}.pth'), state)
        
        #want to use the ema weights for sampling
        ema.store(score_model.parameters())
        ema.copy_to(score_model.parameters())

        #print the metric used in Table1 of Lou2023. Should get at least as small as 1.62e-5 to be considered done training. 
        avg_mmd = 0.0
        for i in range(eval_rounds):
            gt_data = []
            for _ in range(plot_samples // batch_size):
                gt_data.append(next(train_ds).cpu().numpy())
            gt_data = np.concatenate(gt_data, axis=0)
            gt_data = np.reshape(gt_data, (-1, seq_len))
            sample_data=[]
            for _ in range(plot_samples // batch_size):
                sample_data.append(sampling_fn(score_model).cpu().numpy())
            sample_data=np.concatenate(sample_data,axis=0)
            x0 = np.reshape(sample_data, gt_data.shape)
            mmd = binary_exp_hamming_mmd(x0, gt_data)
            avg_mmd += mmd
            model_helper.plot(x0,os.path.join(sample_dir,f'step_{save_step}_eval_round_{i}.pdf')) #plot a sample from the model

        avg_mmd = avg_mmd / eval_rounds
        print(f'Average mmd : {avg_mmd}')
        with open('output.txt', 'a') as file:
            file.write(f'time: {datetime.now()},step: {save_step}, loss: {loss.mean().item()}, MMD: {avg_mmd}\n')
            
        ema.restore(score_model.parameters())

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 0, training_loss: 1.71482e+01
step: 1000, training_loss: 7.56344e+00
step: 2000, training_loss: 7.38736e+00
step: 3000, training_loss: 6.21957e+00
step: 4000, training_loss: 7.24757e+00
step: 5000, training_loss: 7.60618e+00
step: 6000, training_loss: 7.50251e+00
step: 7000, training_loss: 9.82674e+00
step: 8000, training_loss: 6.46239e+00
step: 9000, training_loss: 8.39887e+00
step: 10000, training_loss: 7.53453e+00
Average mmd : 0.008014215259827373


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 11000, training_loss: 6.20032e+00
step: 12000, training_loss: 5.55305e+00
step: 13000, training_loss: 8.70048e+00
step: 14000, training_loss: 6.91900e+00
step: 15000, training_loss: 7.50054e+00
step: 16000, training_loss: 7.99133e+00
step: 17000, training_loss: 8.19992e+00
step: 18000, training_loss: 6.02326e+00
step: 19000, training_loss: 6.65472e+00
step: 20000, training_loss: 5.56809e+00
Average mmd : 0.006183738895954027


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 21000, training_loss: 5.24581e+00
step: 22000, training_loss: 6.28522e+00
step: 23000, training_loss: 5.63971e+00
step: 24000, training_loss: 5.14239e+00
step: 25000, training_loss: 5.73482e+00
step: 26000, training_loss: 6.30408e+00
step: 27000, training_loss: 7.19137e+00
step: 28000, training_loss: 5.68646e+00
step: 29000, training_loss: 1.13469e+01
step: 30000, training_loss: 5.94672e+00
Average mmd : 0.005127112353450314


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 31000, training_loss: 8.28453e+00
step: 32000, training_loss: 5.42605e+00
step: 33000, training_loss: 7.10979e+00
step: 34000, training_loss: 4.94279e+00
step: 35000, training_loss: 5.16012e+00
step: 36000, training_loss: 7.05031e+00
step: 37000, training_loss: 4.96110e+00
step: 38000, training_loss: 9.43828e+00
step: 39000, training_loss: 5.00266e+00
step: 40000, training_loss: 6.73908e+00
Average mmd : 0.004143952856917599


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 41000, training_loss: 5.83709e+00
step: 42000, training_loss: 6.91426e+00
step: 43000, training_loss: 4.69429e+00
step: 44000, training_loss: 6.28773e+00
step: 45000, training_loss: 6.68255e+00
step: 46000, training_loss: 6.14435e+00
step: 47000, training_loss: 5.97496e+00
step: 48000, training_loss: 8.47275e+00
step: 49000, training_loss: 6.89346e+00
step: 50000, training_loss: 6.67046e+00
Average mmd : 0.004575772925793142


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 51000, training_loss: 6.07232e+00
step: 52000, training_loss: 7.39342e+00
step: 53000, training_loss: 5.48835e+00
step: 54000, training_loss: 9.65886e+00
step: 55000, training_loss: 5.84599e+00
step: 56000, training_loss: 5.10499e+00
step: 57000, training_loss: 5.68367e+00
step: 58000, training_loss: 5.74119e+00
step: 59000, training_loss: 6.80108e+00
step: 60000, training_loss: 7.66481e+00
Average mmd : 0.004514003534140776


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 61000, training_loss: 5.72118e+00
step: 62000, training_loss: 6.78144e+00
step: 63000, training_loss: 5.53180e+00
step: 64000, training_loss: 4.87567e+00
step: 65000, training_loss: 8.20247e+00
step: 66000, training_loss: 6.39982e+00
step: 67000, training_loss: 1.13396e+01
step: 68000, training_loss: 6.64314e+00
step: 69000, training_loss: 6.53169e+00
step: 70000, training_loss: 7.70550e+00
Average mmd : 0.005099337407798943


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 71000, training_loss: 7.86062e+00
step: 72000, training_loss: 5.79530e+00
step: 73000, training_loss: 7.30661e+00
step: 74000, training_loss: 7.95391e+00
step: 75000, training_loss: 8.81277e+00
step: 76000, training_loss: 7.53140e+00
step: 77000, training_loss: 7.14815e+00
step: 78000, training_loss: 5.99551e+00
step: 79000, training_loss: 6.14385e+00
step: 80000, training_loss: 6.36565e+00
Average mmd : 0.0042862445600751475


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 81000, training_loss: 6.86002e+00
step: 82000, training_loss: 7.01986e+00
step: 83000, training_loss: 6.48730e+00
step: 84000, training_loss: 6.09210e+00
step: 85000, training_loss: 5.76579e+00
step: 86000, training_loss: 5.80154e+00
step: 87000, training_loss: 7.96791e+00
step: 88000, training_loss: 5.62312e+00
step: 89000, training_loss: 7.18913e+00
step: 90000, training_loss: 7.71258e+00
Average mmd : 0.004098657170211739


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 91000, training_loss: 5.51851e+00
step: 92000, training_loss: 6.09230e+00
step: 93000, training_loss: 7.66061e+00
step: 94000, training_loss: 7.52122e+00
step: 95000, training_loss: 8.76316e+00
step: 96000, training_loss: 5.80143e+00
step: 97000, training_loss: 1.18751e+01
step: 98000, training_loss: 6.87599e+00
step: 99000, training_loss: 6.90630e+00
step: 100000, training_loss: 5.99662e+00
Average mmd : 0.004875880921448583


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 101000, training_loss: 6.19692e+00
step: 102000, training_loss: 6.72991e+00
step: 103000, training_loss: 6.48671e+00
step: 104000, training_loss: 6.53297e+00
step: 105000, training_loss: 6.48122e+00
step: 106000, training_loss: 8.79349e+00
step: 107000, training_loss: 5.41652e+00
step: 108000, training_loss: 5.59052e+00
step: 109000, training_loss: 7.61667e+00
step: 110000, training_loss: 5.01813e+00
Average mmd : 0.005549620991961768


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 111000, training_loss: 6.71742e+00
step: 112000, training_loss: 5.74333e+00
step: 113000, training_loss: 5.14003e+00
step: 114000, training_loss: 4.82829e+00
step: 115000, training_loss: 5.64233e+00
step: 116000, training_loss: 6.86800e+00
step: 117000, training_loss: 7.07904e+00
step: 118000, training_loss: 8.38668e+00
step: 119000, training_loss: 6.26332e+00
step: 120000, training_loss: 7.00603e+00
Average mmd : 0.0055802792701933335


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 121000, training_loss: 6.59375e+00
step: 122000, training_loss: 7.51922e+00
step: 123000, training_loss: 6.95140e+00
step: 124000, training_loss: 9.80435e+00
step: 125000, training_loss: 4.68892e+00
step: 126000, training_loss: 5.65260e+00
step: 127000, training_loss: 5.41659e+00
step: 128000, training_loss: 7.78098e+00
step: 129000, training_loss: 6.44449e+00
step: 130000, training_loss: 5.68276e+00
Average mmd : 0.005255278781293005


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 131000, training_loss: 4.20282e+00
step: 132000, training_loss: 8.50259e+00
step: 133000, training_loss: 8.23747e+00
step: 134000, training_loss: 5.64922e+00
step: 135000, training_loss: 1.01628e+01
step: 136000, training_loss: 7.24032e+00
step: 137000, training_loss: 5.50666e+00
step: 138000, training_loss: 6.98588e+00
step: 139000, training_loss: 8.30356e+00
step: 140000, training_loss: 8.10367e+00
Average mmd : 0.005060532787630023


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 141000, training_loss: 5.91157e+00
step: 142000, training_loss: 4.80830e+00
step: 143000, training_loss: 7.09075e+00
step: 144000, training_loss: 7.48178e+00
step: 145000, training_loss: 6.33236e+00
step: 146000, training_loss: 8.68794e+00
step: 147000, training_loss: 8.68512e+00
step: 148000, training_loss: 7.30269e+00
step: 149000, training_loss: 5.57020e+00
step: 150000, training_loss: 5.81595e+00
Average mmd : 0.0056356134093047405


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 151000, training_loss: 6.71158e+00
step: 152000, training_loss: 6.25601e+00
step: 153000, training_loss: 5.92863e+00
step: 154000, training_loss: 8.13591e+00
step: 155000, training_loss: 6.51802e+00
step: 156000, training_loss: 6.96736e+00
step: 157000, training_loss: 7.22419e+00
step: 158000, training_loss: 7.76303e+00
step: 159000, training_loss: 6.59953e+00
step: 160000, training_loss: 7.04725e+00
Average mmd : 0.0055915635959540855


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 161000, training_loss: 5.25667e+00
step: 162000, training_loss: 9.68425e+00
step: 163000, training_loss: 7.47592e+00
step: 164000, training_loss: 6.67079e+00
step: 165000, training_loss: 5.43872e+00
step: 166000, training_loss: 5.41939e+00
step: 167000, training_loss: 6.28497e+00
step: 168000, training_loss: 6.48529e+00
step: 169000, training_loss: 5.32108e+00
step: 170000, training_loss: 5.52277e+00
Average mmd : 0.006086738169605865


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 171000, training_loss: 8.03352e+00
step: 172000, training_loss: 7.21258e+00
step: 173000, training_loss: 5.76316e+00
step: 174000, training_loss: 6.49017e+00
step: 175000, training_loss: 9.95176e+00
step: 176000, training_loss: 7.28934e+00
step: 177000, training_loss: 8.18254e+00
step: 178000, training_loss: 4.46652e+00
step: 179000, training_loss: 5.37013e+00
step: 180000, training_loss: 8.12510e+00
Average mmd : 0.006452796496756896


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 181000, training_loss: 5.63393e+00
step: 182000, training_loss: 5.62833e+00
step: 183000, training_loss: 1.13450e+01
step: 184000, training_loss: 6.22040e+00
step: 185000, training_loss: 5.29311e+00
step: 186000, training_loss: 6.74946e+00
step: 187000, training_loss: 8.16230e+00
step: 188000, training_loss: 6.11681e+00
step: 189000, training_loss: 1.12089e+01
step: 190000, training_loss: 7.59340e+00
Average mmd : 0.0063040166436793


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 191000, training_loss: 6.00727e+00
step: 192000, training_loss: 4.65562e+00
step: 193000, training_loss: 5.71405e+00
step: 194000, training_loss: 6.38425e+00
step: 195000, training_loss: 6.78953e+00
step: 196000, training_loss: 6.33279e+00
step: 197000, training_loss: 5.08725e+00
step: 198000, training_loss: 9.66574e+00
step: 199000, training_loss: 6.70251e+00
step: 200000, training_loss: 5.92803e+00
Average mmd : 0.006165108258018359


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 201000, training_loss: 1.15165e+01
step: 202000, training_loss: 6.07237e+00
step: 203000, training_loss: 7.09740e+00
step: 204000, training_loss: 6.69464e+00
step: 205000, training_loss: 6.28498e+00
step: 206000, training_loss: 6.29545e+00
step: 207000, training_loss: 6.52742e+00
step: 208000, training_loss: 7.01406e+00
step: 209000, training_loss: 7.05415e+00
step: 210000, training_loss: 6.86799e+00
Average mmd : 0.006842140646523498


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 211000, training_loss: 6.57255e+00
step: 212000, training_loss: 6.04637e+00
step: 213000, training_loss: 1.42476e+01
step: 214000, training_loss: 5.09158e+00
step: 215000, training_loss: 6.04804e+00
step: 216000, training_loss: 5.60534e+00
step: 217000, training_loss: 4.80957e+00
step: 218000, training_loss: 6.15242e+00
step: 219000, training_loss: 5.80425e+00
step: 220000, training_loss: 5.36893e+00
Average mmd : 0.006704628274250335


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 221000, training_loss: 6.52713e+00
step: 222000, training_loss: 6.13855e+00
step: 223000, training_loss: 6.13467e+00
step: 224000, training_loss: 7.67262e+00
step: 225000, training_loss: 5.71007e+00
step: 226000, training_loss: 7.31298e+00
step: 227000, training_loss: 7.78920e+00
step: 228000, training_loss: 8.97919e+00
step: 229000, training_loss: 6.33180e+00
step: 230000, training_loss: 6.66581e+00
Average mmd : 0.007195369014367681


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 231000, training_loss: 9.54192e+00
step: 232000, training_loss: 7.09519e+00
step: 233000, training_loss: 7.10993e+00
step: 234000, training_loss: 4.93627e+00
step: 235000, training_loss: 5.87007e+00
step: 236000, training_loss: 7.17744e+00
step: 237000, training_loss: 6.73492e+00
step: 238000, training_loss: 5.93736e+00
step: 239000, training_loss: 6.10162e+00
step: 240000, training_loss: 7.03871e+00
Average mmd : 0.006454251440129122


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 241000, training_loss: 4.65923e+00
step: 242000, training_loss: 6.31370e+00
step: 243000, training_loss: 9.97977e+00
step: 244000, training_loss: 8.21424e+00
step: 245000, training_loss: 5.05551e+00
step: 246000, training_loss: 1.55409e+01
step: 247000, training_loss: 5.14649e+00
step: 248000, training_loss: 5.15760e+00
step: 249000, training_loss: 1.04008e+01
step: 250000, training_loss: 5.41006e+00
Average mmd : 0.0067997103632913


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 251000, training_loss: 4.25668e+00
step: 252000, training_loss: 5.66461e+00
step: 253000, training_loss: 7.52434e+00
step: 254000, training_loss: 5.98570e+00
step: 255000, training_loss: 6.20129e+00
step: 256000, training_loss: 7.13139e+00
step: 257000, training_loss: 6.76070e+00
step: 258000, training_loss: 6.86370e+00
step: 259000, training_loss: 7.74012e+00
step: 260000, training_loss: 9.28472e+00
Average mmd : 0.006375067448422089


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 261000, training_loss: 6.94636e+00
step: 262000, training_loss: 8.44631e+00
step: 263000, training_loss: 6.16020e+00
step: 264000, training_loss: 5.72176e+00
step: 265000, training_loss: 5.40070e+00
step: 266000, training_loss: 7.17984e+00
step: 267000, training_loss: 6.37594e+00
step: 268000, training_loss: 5.91050e+00
step: 269000, training_loss: 5.71467e+00
step: 270000, training_loss: 7.44411e+00
Average mmd : 0.006250351471408178


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 271000, training_loss: 5.95332e+00
step: 272000, training_loss: 7.86274e+00
step: 273000, training_loss: 6.24064e+00
step: 274000, training_loss: 7.34550e+00
step: 275000, training_loss: 5.67584e+00
step: 276000, training_loss: 7.07193e+00
step: 277000, training_loss: 6.87819e+00
step: 278000, training_loss: 5.44106e+00
step: 279000, training_loss: 6.16385e+00
step: 280000, training_loss: 6.69821e+00
Average mmd : 0.006746526415710552


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 281000, training_loss: 8.13327e+00
step: 282000, training_loss: 7.64615e+00
step: 283000, training_loss: 2.52936e+01
step: 284000, training_loss: 5.13257e+00
step: 285000, training_loss: 6.95082e+00
step: 286000, training_loss: 5.01486e+00
step: 287000, training_loss: 7.13616e+00
step: 288000, training_loss: 6.52189e+00
step: 289000, training_loss: 7.85740e+00
step: 290000, training_loss: 6.25495e+00
Average mmd : 0.00759036481226838


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 291000, training_loss: 4.73005e+00
step: 292000, training_loss: 6.62744e+00
step: 293000, training_loss: 8.96135e+00
step: 294000, training_loss: 6.60660e+00
step: 295000, training_loss: 7.99044e+00
step: 296000, training_loss: 6.78157e+00
step: 297000, training_loss: 5.61602e+00
step: 298000, training_loss: 7.39843e+00
step: 299000, training_loss: 5.26534e+00
step: 300000, training_loss: 4.51206e+00
Average mmd : 0.006221866321461489
