In [12]:
import os
import numpy as np
import torch
import sampling
import graph_lib
import noise_lib

from torch.utils.data import DataLoader, Dataset

from data.synthetic import utils as data_utils

import io
import PIL
import functools

torch.backends.cudnn.benchmark = True
from torch import nn, Tensor

import math
from datetime import datetime

In [13]:
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F


def transformer_timestep_embedding(timesteps, dim, device,max_period=10000):
    """
    Create sinusoidal timestep embeddings (like in transformer position encodings).
    timesteps: (batch,) or (N,)
    Returns: (batch, dim)
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32,device=device) / half)
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:  # zero pad if needed
        emb = F.pad(emb, (0, 1))
    return emb


class CatMLPScoreFunc(nn.Module):
    def __init__(self, vocab_size, seq_len,cat_embed_size, num_layers, hidden_size, time_scale_factor=1000.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.cat_embed_size = cat_embed_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.time_scale_factor = time_scale_factor
        input_dim=cat_embed_size*seq_len
        self.input_dim=input_dim

        self.embed = nn.Embedding(vocab_size, cat_embed_size)
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(input_dim,hidden_size))  # will init in forward
        for _ in range(num_layers-1):
            self.layers.append(nn.Linear(hidden_size,hidden_size))  # will init in forward
        self.final = nn.Linear(hidden_size, seq_len*vocab_size)

    def forward(self, x, t):
        """
        x: (batch, seq_len) – categorical token ids
        t: (batch,) – timesteps
        """
        B, L = x.shape
        V=self.vocab_size
        x = self.embed(x)  # (B, L, cat_embed_size)
        x = x.view(B, -1)  # (B, L * cat_embed_size)

        temb = transformer_timestep_embedding(t * self.time_scale_factor, self.hidden_size,x.device).to(x.device)
        
        for layer in self.layers:
            x = layer(x) + temb
            x = F.silu(x)

        x = self.final(x)  # (B, L*vocab_size)
        x=x.view(B, L, V)
        return x


In [14]:
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
data_directory='data//synthetic//checkerboard'

In [None]:
from omegaconf import OmegaConf
import utils
from model.ema import ExponentialMovingAverage
import losses
from itertools import chain
import os

graph_type='mixed' #uniform, masked, or mixed


cfg_path=f'configs//synthetic_config_{graph_type}.yaml'
cfg = OmegaConf.load(cfg_path)
cfg.model=OmegaConf.load('configs//model//tiny.yaml')
work_dir = f'for_synthetic_data_{graph_type}'
if not os.path.exists(work_dir):
    os.makedirs(work_dir)
# Create directories for experimental logs
sample_dir = os.path.join(work_dir, "samples")
checkpoint_dir = os.path.join(work_dir, "checkpoints")
checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
utils.makedirs(sample_dir)
utils.makedirs(checkpoint_dir)
utils.makedirs(os.path.dirname(checkpoint_meta_dir))
logger = utils.get_logger(os.path.join(work_dir, "logs"))
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Found {} CUDA devices.".format(torch.cuda.device_count()))
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
else:
    print("WARNING: Using device {}".format(device))
    print(f"Found {os.cpu_count()} total number of CPUs.")

# build token graph
graph = graph_lib.get_graph(cfg, device)
    
# build score model
batch_size = cfg.training.batch_size
seq_len=32
vocab_size = 2
mask_id = vocab_size
if graph_type=='uniform':
    expand_vocab_size = vocab_size 
else:
    expand_vocab_size = vocab_size + 1
time_scale_factor=1000.0
embed_dim=256
num_layers=3

score_model = CatMLPScoreFunc(vocab_size=expand_vocab_size,cat_embed_size=embed_dim,num_layers= num_layers,
    hidden_size=embed_dim,seq_len=seq_len,time_scale_factor= 1000.0,
    ).to(device)

num_parameters = sum(p.numel() for p in score_model.parameters())
print(f"Number of parameters in the model: {num_parameters}")
ema = ExponentialMovingAverage(
        score_model.parameters(), decay=cfg.training.ema)
print(score_model)
print(f"EMA: {ema}")

# build noise
noise = noise_lib.get_noise(cfg).to(device)
#noise = DDP(noise, device_ids=[rank], static_graph=True) Z:Commented this out
sampling_eps = 1e-5


# build optimization state
optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
print(f"Optimizer: {optimizer}")
scaler = torch.cuda.amp.GradScaler()
print(f"Scaler: {scaler}")
state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0) 
state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)


Found 1 CUDA devices.
NVIDIA GeForce GTX 1080 	 Memory: 8.00GB
Number of parameters in the model: 2245952
CatMLPScoreFunc(
  (embed): Embedding(2, 256)
  (layers): ModuleList(
    (0): Linear(in_features=8192, out_features=256, bias=True)
    (1-2): 2 x Linear(in_features=256, out_features=256, bias=True)
  )
  (final): Linear(in_features=256, out_features=64, bias=True)
)
EMA: <model.ema.ExponentialMovingAverage object at 0x00000255D7AC3F10>
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-06
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-06
)
Scaler: <torch.cuda.amp.grad_scaler.GradScaler object at 0x00000255D7AC30A0>


  scaler = torch.cuda.amp.GradScaler()
2025-04-18 11:21:11,060 - No checkpoint found at for_synthetic_data_uniform\checkpoints-meta\checkpoint.pth. Returned the same state as input


In [16]:
# Load the data from file - assumes we have already generated samples using generate_data.ipynb
data_file = os.path.join(data_directory, 'data.npy')
with open(data_file, 'rb') as f:
    data = np.load(f).astype(np.int64)
    print('data shape: %s' % str(data.shape))

# Define a custom Dataset to wrap the data
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = torch.from_numpy(data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Create a Dataset instance
train_set = CustomDataset(data)

# Function to cycle through DataLoader
def cycle_loader(dataloader):
    while True:
        for data in dataloader:
            yield data

# Initialize DataLoader without a sampler
train_ds = cycle_loader(DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=True,  # Shuffle the data as needed
    persistent_workers=False,
))

# Create an iterator for the data
train_iter = iter(train_ds)

data shape: (10000000, 32)


In [17]:
#plotting code borrowed from Sun2023

def plot(xbin, fn_xbin2float, output_file=None):
  """Visualize binary data."""
  float_data = fn_xbin2float(xbin)
  if output_file is None:  # in-memory plot
    buf = io.BytesIO()
    data_utils.plot_samples(float_data, buf, im_size=4.1, im_fmt='png')
    buf.seek(0)
    image = np.asarray(PIL.Image.open(buf))[None, ...]
    return image
  else:
    with open(output_file, 'wb') as f:
      im_fmt = 'png' if output_file.endswith('.png') else 'pdf'
      data_utils.plot_samples(float_data, f, im_size=4.1, im_fmt=im_fmt,axis=True)


class BinarySyntheticHelper(object):
  """Binary synthetic model helper."""

  def __init__(self, seq_len,int_scale):
    self.seq_len = seq_len
    self.int_scale=int_scale
    self.bm, self.inv_bm = data_utils.get_binmap(seq_len,
                                                 'gray')

  def plot(self, xbin, output_file=None):
    fn_xbin2float = functools.partial(
        data_utils.bin2float, inv_bm=self.inv_bm,
        discrete_dim=self.seq_len, int_scale=self.int_scale)
    return plot(xbin, fn_xbin2float, output_file)

int_scale=5461.760975376213
model_helper = BinarySyntheticHelper(seq_len,int_scale)

remapping binary repr with gray code


In [18]:
#the metric used in Table 1 of Lou2023
def binary_mmd(x, y, sim_fn):
  """MMD for binary data."""
  x = x.astype(np.float32)
  y = y.astype(np.float32)
  kxx = sim_fn(x, x)
  kxx = kxx * (1 - np.eye(x.shape[0]))
  kxx = np.sum(kxx) / x.shape[0] / (x.shape[0] - 1)

  kyy = sim_fn(y, y)
  kyy = kyy * (1 - np.eye(y.shape[0]))
  kyy = np.sum(kyy) / y.shape[0] / (y.shape[0] - 1)
  kxy = np.sum(sim_fn(x, y))
  kxy = kxy / x.shape[0] / y.shape[0]
  mmd = kxx + kyy - 2 * kxy
  return mmd

def binary_exp_hamming_sim(x, y, bd):
  x = np.expand_dims(x, axis=1)
  y = np.expand_dims(y, axis=0)
  d = np.sum(np.abs(x - y), axis=-1)
  return np.exp(-bd * d)

def binary_exp_hamming_mmd(x, y, bandwidth=0.1):
  sim_fn = functools.partial(binary_exp_hamming_sim, bd=bandwidth)
  return binary_mmd(x, y, sim_fn)

In [19]:
num_train_steps=cfg.training.n_iters
log_freq=cfg.training.log_freq
snapshot_freq=cfg.training.snapshot_freq
eval_rounds=1
plot_samples=cfg.plot_samples
# Build one-step training and evaluation functions
optimize_fn = losses.optimization_manager(cfg)
train_step_fn = losses.get_step_fn(noise, graph, True, optimize_fn, cfg.training.accum)
eval_step_fn = losses.get_step_fn(noise, graph, False, optimize_fn, cfg.training.accum)
sampling_shape = (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum), cfg.model.length)
sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device)
num_train_steps = cfg.training.n_iters
initial_step = int(state['step'])
print(f"Starting training loop at step {initial_step}.")

Starting training loop at step 0.


In [20]:
for step in range(initial_step,num_train_steps+1):
#while state['step'] < num_train_steps + 1:
    #step = state['step']
    batch=next(train_iter).to(device)
    # Compute loss using the simplified loss function
    loss=train_step_fn(state, batch)
    
        
    if step % log_freq == 0:
        print("step: %d, training_loss: %.5e" % (step, loss.mean().item()))
            
    if step > 0 and step % cfg.training.snapshot_freq == 0 or step == num_train_steps:
        # Save the checkpoint.
        save_step = step // cfg.training.snapshot_freq
        utils.save_checkpoint(os.path.join(
                        checkpoint_dir, f'checkpoint_{save_step}.pth'), state)
        
        #want to use the ema weights for sampling
        ema.store(score_model.parameters())
        ema.copy_to(score_model.parameters())

        #print the metric used in Table1 of Lou2023. Should get at least as small as 1.62e-5 to be considered done training. 
        avg_mmd = 0.0
        for i in range(eval_rounds):
            gt_data = []
            for _ in range(plot_samples // batch_size):
                gt_data.append(next(train_ds).cpu().numpy())
            gt_data = np.concatenate(gt_data, axis=0)
            gt_data = np.reshape(gt_data, (-1, seq_len))
            sample_data=[]
            for _ in range(plot_samples // batch_size):
                sample_data.append(sampling_fn(score_model).cpu().numpy())
            sample_data=np.concatenate(sample_data,axis=0)
            x0 = np.reshape(sample_data, gt_data.shape)
            mmd = binary_exp_hamming_mmd(x0, gt_data)
            avg_mmd += mmd
            model_helper.plot(x0,os.path.join(sample_dir,f'step_{save_step}_eval_round_{i}.pdf')) #plot a sample from the model

        avg_mmd = avg_mmd / eval_rounds
        print(f'Average mmd : {avg_mmd}')
        with open('output.txt', 'a') as file:
            file.write(f'time: {datetime.now()},step: {save_step}, loss: {loss.mean().item()}, MMD: {avg_mmd}\n')
            
        ema.restore(score_model.parameters())

  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 0, training_loss: 2.42984e+01
step: 1000, training_loss: 2.35295e+01
step: 2000, training_loss: 2.08440e+01
step: 3000, training_loss: 1.99424e+01
step: 4000, training_loss: 2.07609e+01
step: 5000, training_loss: 3.32889e+01
step: 6000, training_loss: 1.94826e+01
step: 7000, training_loss: 1.96789e+01
step: 8000, training_loss: 1.98326e+01
step: 9000, training_loss: 2.04955e+01
step: 10000, training_loss: 2.72859e+01
Average mmd : 0.0001212035835974623


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 11000, training_loss: 2.67193e+01
step: 12000, training_loss: 1.84116e+01
step: 13000, training_loss: 2.13653e+01
step: 14000, training_loss: 2.62131e+01
step: 15000, training_loss: 2.14237e+01
step: 16000, training_loss: 2.06137e+01
step: 17000, training_loss: 1.48178e+01
step: 18000, training_loss: 2.26144e+01
step: 19000, training_loss: 2.12487e+01
step: 20000, training_loss: 2.29754e+01
Average mmd : 1.632814446689279e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 21000, training_loss: 1.97591e+01
step: 22000, training_loss: 1.81924e+01
step: 23000, training_loss: 1.84218e+01
step: 24000, training_loss: 2.00165e+01
step: 25000, training_loss: 1.96742e+01
step: 26000, training_loss: 1.84061e+01
step: 27000, training_loss: 1.89762e+01
step: 28000, training_loss: 1.77145e+01
step: 29000, training_loss: 3.01783e+01
step: 30000, training_loss: 2.55082e+01
Average mmd : -5.340756874139263e-06


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 31000, training_loss: 2.04053e+01
step: 32000, training_loss: 2.06433e+01
step: 33000, training_loss: 1.89799e+01
step: 34000, training_loss: 1.94992e+01
step: 35000, training_loss: 2.11672e+01
step: 36000, training_loss: 1.81417e+01
step: 37000, training_loss: 2.31990e+01
step: 38000, training_loss: 1.88216e+01
step: 39000, training_loss: 2.11051e+01
step: 40000, training_loss: 1.95474e+01
Average mmd : 2.5428415665662563e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 41000, training_loss: 1.83713e+01
step: 42000, training_loss: 2.45525e+01
step: 43000, training_loss: 2.84870e+01
step: 44000, training_loss: 2.54536e+01
step: 45000, training_loss: 2.30064e+01
step: 46000, training_loss: 1.83735e+01
step: 47000, training_loss: 1.92659e+01
step: 48000, training_loss: 1.80435e+01
step: 49000, training_loss: 1.95341e+01
step: 50000, training_loss: 2.03498e+01
Average mmd : 2.825103205378321e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 51000, training_loss: 2.11859e+01
step: 52000, training_loss: 1.85472e+01
step: 53000, training_loss: 2.09118e+01
step: 54000, training_loss: 2.56620e+01
step: 55000, training_loss: 1.70449e+01
step: 56000, training_loss: 2.08101e+01
step: 57000, training_loss: 2.00293e+01
step: 58000, training_loss: 3.78715e+01
step: 59000, training_loss: 1.89577e+01
step: 60000, training_loss: 1.68416e+01
Average mmd : -5.002449323665559e-07


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 61000, training_loss: 2.10935e+01
step: 62000, training_loss: 1.94929e+01
step: 63000, training_loss: 1.72051e+01
step: 64000, training_loss: 2.01241e+01
step: 65000, training_loss: 1.91059e+01
step: 66000, training_loss: 1.71747e+01
step: 67000, training_loss: 1.71328e+01
step: 68000, training_loss: 2.14887e+01
step: 69000, training_loss: 1.93940e+01
step: 70000, training_loss: 2.11639e+01
Average mmd : -5.656282333710294e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 71000, training_loss: 1.93090e+01
step: 72000, training_loss: 1.90613e+01
step: 73000, training_loss: 3.55533e+01
step: 74000, training_loss: 2.13146e+01
step: 75000, training_loss: 2.36965e+01
step: 76000, training_loss: 2.37611e+01
step: 77000, training_loss: 1.74447e+01
step: 78000, training_loss: 1.91886e+01
step: 79000, training_loss: 1.81170e+01
step: 80000, training_loss: 2.21272e+01
Average mmd : 1.7922467292486033e-06


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 81000, training_loss: 1.93499e+01
step: 82000, training_loss: 1.79138e+01
step: 83000, training_loss: 1.71285e+01
step: 84000, training_loss: 2.06671e+01
step: 85000, training_loss: 2.04101e+01
step: 86000, training_loss: 2.42155e+01
step: 87000, training_loss: 1.85299e+01
step: 88000, training_loss: 1.68726e+01
step: 89000, training_loss: 1.72409e+01
step: 90000, training_loss: 1.81432e+01
Average mmd : -3.711220933344528e-06


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 91000, training_loss: 1.86171e+01
step: 92000, training_loss: 1.87462e+01
step: 93000, training_loss: 2.26295e+01
step: 94000, training_loss: 2.15814e+01
step: 95000, training_loss: 2.26632e+01
step: 96000, training_loss: 1.95392e+01
step: 97000, training_loss: 1.87940e+01
step: 98000, training_loss: 1.69678e+01
step: 99000, training_loss: 1.93619e+01
step: 100000, training_loss: 1.75012e+01
Average mmd : -1.0201209253379862e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 101000, training_loss: 2.12920e+01
step: 102000, training_loss: 2.28504e+01
step: 103000, training_loss: 1.79551e+01
step: 104000, training_loss: 1.83885e+01
step: 105000, training_loss: 2.64216e+01
step: 106000, training_loss: 1.97800e+01
step: 107000, training_loss: 1.95902e+01
step: 108000, training_loss: 1.85535e+01
step: 109000, training_loss: 1.78531e+01
step: 110000, training_loss: 2.14032e+01
Average mmd : -2.9705265058466157e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 111000, training_loss: 2.18645e+01
step: 112000, training_loss: 1.82142e+01
step: 113000, training_loss: 1.92611e+01
step: 114000, training_loss: 3.28514e+01
step: 115000, training_loss: 2.18657e+01
step: 116000, training_loss: 2.02603e+01
step: 117000, training_loss: 1.62913e+01
step: 118000, training_loss: 1.77747e+01
step: 119000, training_loss: 2.28871e+01
step: 120000, training_loss: 2.08035e+01
Average mmd : 7.08472549915129e-06


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 121000, training_loss: 1.92482e+01
step: 122000, training_loss: 4.01608e+01
step: 123000, training_loss: 2.23597e+01
step: 124000, training_loss: 1.68935e+01
step: 125000, training_loss: 2.63937e+01
step: 126000, training_loss: 1.91021e+01
step: 127000, training_loss: 1.86300e+01
step: 128000, training_loss: 2.03073e+01
step: 129000, training_loss: 1.79578e+01
step: 130000, training_loss: 2.17059e+01
Average mmd : 3.0190679296770995e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 131000, training_loss: 1.83912e+01
step: 132000, training_loss: 1.90754e+01
step: 133000, training_loss: 2.00229e+01
step: 134000, training_loss: 2.07455e+01
step: 135000, training_loss: 2.31827e+01
step: 136000, training_loss: 1.99798e+01
step: 137000, training_loss: 2.67335e+01
step: 138000, training_loss: 2.16678e+01
step: 139000, training_loss: 2.60009e+01
step: 140000, training_loss: 1.56220e+01
Average mmd : 8.318851971389485e-06


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 141000, training_loss: 1.98153e+01
step: 142000, training_loss: 1.43805e+01
step: 143000, training_loss: 1.83044e+01
step: 144000, training_loss: 1.99401e+01
step: 145000, training_loss: 1.77299e+01
step: 146000, training_loss: 1.94549e+01
step: 147000, training_loss: 2.24804e+01
step: 148000, training_loss: 2.44748e+01
step: 149000, training_loss: 1.97695e+01
step: 150000, training_loss: 2.05691e+01
Average mmd : 7.205973194390758e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 151000, training_loss: 1.87167e+01
step: 152000, training_loss: 2.35711e+01
step: 153000, training_loss: 1.55677e+01
step: 154000, training_loss: 1.63002e+01
step: 155000, training_loss: 2.25106e+01
step: 156000, training_loss: 1.79228e+01
step: 157000, training_loss: 1.77518e+01
step: 158000, training_loss: 1.71027e+01
step: 159000, training_loss: 1.69288e+01
step: 160000, training_loss: 1.65109e+01
Average mmd : -1.2365235741662595e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 161000, training_loss: 2.17977e+01
step: 162000, training_loss: 2.14804e+01
step: 163000, training_loss: 1.93641e+01
step: 164000, training_loss: 2.27103e+01
step: 165000, training_loss: 2.17447e+01
step: 166000, training_loss: 2.04266e+01
step: 167000, training_loss: 1.77442e+01
step: 168000, training_loss: 1.84252e+01
step: 169000, training_loss: 1.83157e+01
step: 170000, training_loss: 1.93273e+01
Average mmd : 5.682422043062907e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 171000, training_loss: 1.80365e+01
step: 172000, training_loss: 1.92828e+01
step: 173000, training_loss: 2.14613e+01
step: 174000, training_loss: 1.93449e+01
step: 175000, training_loss: 2.50170e+01
step: 176000, training_loss: 2.16177e+01
step: 177000, training_loss: 1.97483e+01
step: 178000, training_loss: 1.93271e+01
step: 179000, training_loss: 1.93184e+01
step: 180000, training_loss: 1.86955e+01
Average mmd : -1.8148298809417263e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 181000, training_loss: 2.59543e+01
step: 182000, training_loss: 1.90063e+01
step: 183000, training_loss: 1.89390e+01
step: 184000, training_loss: 2.50664e+01
step: 185000, training_loss: 2.20781e+01
step: 186000, training_loss: 1.65911e+01
step: 187000, training_loss: 1.85015e+01
step: 188000, training_loss: 1.77582e+01
step: 189000, training_loss: 2.84059e+01
step: 190000, training_loss: 1.90760e+01
Average mmd : 2.0296853412316018e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 191000, training_loss: 2.23102e+01
step: 192000, training_loss: 1.73466e+01
step: 193000, training_loss: 2.10173e+01
step: 194000, training_loss: 1.91711e+01
step: 195000, training_loss: 3.04460e+01
step: 196000, training_loss: 1.98811e+01
step: 197000, training_loss: 2.02583e+01
step: 198000, training_loss: 2.65102e+01
step: 199000, training_loss: 1.82702e+01
step: 200000, training_loss: 1.71762e+01
Average mmd : -3.322864373966894e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 201000, training_loss: 3.74201e+01
step: 202000, training_loss: 1.68194e+01
step: 203000, training_loss: 1.77708e+01
step: 204000, training_loss: 1.85826e+01
step: 205000, training_loss: 2.37129e+01
step: 206000, training_loss: 1.87821e+01
step: 207000, training_loss: 1.88308e+01
step: 208000, training_loss: 2.03442e+01
step: 209000, training_loss: 1.64164e+01
step: 210000, training_loss: 1.80733e+01
Average mmd : -7.027717269425526e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 211000, training_loss: 1.95369e+01
step: 212000, training_loss: 1.88554e+01
step: 213000, training_loss: 1.82109e+01
step: 214000, training_loss: 2.13106e+01
step: 215000, training_loss: 2.62926e+01
step: 216000, training_loss: 2.09522e+01
step: 217000, training_loss: 3.63382e+01
step: 218000, training_loss: 1.92660e+01
step: 219000, training_loss: 2.31504e+01
step: 220000, training_loss: 1.79929e+01
Average mmd : 9.376785618409045e-06


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 221000, training_loss: 2.15940e+01
step: 222000, training_loss: 1.87405e+01
step: 223000, training_loss: 1.81374e+01
step: 224000, training_loss: 1.92206e+01
step: 225000, training_loss: 2.71982e+01
step: 226000, training_loss: 1.71303e+01
step: 227000, training_loss: 2.17340e+01
step: 228000, training_loss: 1.94151e+01
step: 229000, training_loss: 2.32967e+01
step: 230000, training_loss: 1.67300e+01
Average mmd : 0.00011119047215912836


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 231000, training_loss: 2.55445e+01
step: 232000, training_loss: 2.82795e+01
step: 233000, training_loss: 1.97671e+01
step: 234000, training_loss: 1.91936e+01
step: 235000, training_loss: 2.13338e+01
step: 236000, training_loss: 2.38660e+01
step: 237000, training_loss: 1.84668e+01
step: 238000, training_loss: 1.88737e+01
step: 239000, training_loss: 2.08938e+01
step: 240000, training_loss: 1.74645e+01
Average mmd : -3.587005642780028e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 241000, training_loss: 2.61986e+01
step: 242000, training_loss: 2.12288e+01
step: 243000, training_loss: 2.25842e+01
step: 244000, training_loss: 2.32395e+01
step: 245000, training_loss: 1.94309e+01
step: 246000, training_loss: 2.20262e+01
step: 247000, training_loss: 2.16806e+01
step: 248000, training_loss: 1.74307e+01
step: 249000, training_loss: 2.14717e+01
step: 250000, training_loss: 2.10141e+01
Average mmd : -4.735319371451663e-06


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 251000, training_loss: 1.89682e+01
step: 252000, training_loss: 2.23543e+01
step: 253000, training_loss: 2.50968e+01
step: 254000, training_loss: 1.69513e+01
step: 255000, training_loss: 1.82974e+01
step: 256000, training_loss: 2.99302e+01
step: 257000, training_loss: 1.76267e+01
step: 258000, training_loss: 1.80767e+01
step: 259000, training_loss: 1.82399e+01
step: 260000, training_loss: 1.83938e+01
Average mmd : -6.816454444824593e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 261000, training_loss: 2.09887e+01
step: 262000, training_loss: 2.07003e+01
step: 263000, training_loss: 1.87934e+01
step: 264000, training_loss: 2.09138e+01
step: 265000, training_loss: 1.83559e+01
step: 266000, training_loss: 2.71851e+01
step: 267000, training_loss: 2.13770e+01
step: 268000, training_loss: 2.40124e+01
step: 269000, training_loss: 1.86528e+01
step: 270000, training_loss: 2.41552e+01
Average mmd : 8.641048432711518e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 271000, training_loss: 1.88365e+01
step: 272000, training_loss: 1.77788e+01
step: 273000, training_loss: 2.04835e+01
step: 274000, training_loss: 1.93604e+01
step: 275000, training_loss: 2.51989e+01
step: 276000, training_loss: 1.85279e+01
step: 277000, training_loss: 1.86710e+01
step: 278000, training_loss: 2.26047e+01
step: 279000, training_loss: 2.19829e+01
step: 280000, training_loss: 1.91292e+01
Average mmd : 2.6044790146673158e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 281000, training_loss: 1.98236e+01
step: 282000, training_loss: 1.76718e+01
step: 283000, training_loss: 1.82537e+01
step: 284000, training_loss: 1.95154e+01
step: 285000, training_loss: 2.26115e+01
step: 286000, training_loss: 1.74093e+01
step: 287000, training_loss: 2.14619e+01
step: 288000, training_loss: 2.59510e+01
step: 289000, training_loss: 2.02672e+01
step: 290000, training_loss: 2.32667e+01
Average mmd : 2.009589560203473e-05


  with torch.cuda.amp.autocast(dtype=torch.bfloat16):


step: 291000, training_loss: 1.87541e+01
step: 292000, training_loss: 2.72955e+01
step: 293000, training_loss: 3.57964e+01
step: 294000, training_loss: 2.20398e+01
step: 295000, training_loss: 1.85206e+01
step: 296000, training_loss: 1.66638e+01
step: 297000, training_loss: 1.92094e+01
step: 298000, training_loss: 1.80110e+01
step: 299000, training_loss: 2.49247e+01
step: 300000, training_loss: 1.86149e+01
Average mmd : -2.437535099542032e-06
