In [1]:
import os
from itertools import chain

import numpy as np
import torch
import losses
import sampling
import graph_lib
import noise_lib
import utils
from model import SEDD
from model.ema import ExponentialMovingAverage

from omegaconf import OmegaConf

from torch.utils.data import DataLoader, Dataset

from data.synthetic import utils as data_utils

import io
import PIL
import functools

# Turn on cuDNN benchmark mode for this session; set here, before any model is built.
torch.backends.cudnn.benchmark = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the experiment config and attach the model config to it as `cfg.model`.
cfg_path='configs//synthetic_config_masked.yaml'
cfg = OmegaConf.load(cfg_path)
cfg.model=OmegaConf.load('configs//model//tiny.yaml')

# Root directory for everything this run writes (samples, checkpoints, logs).
# exist_ok=True makes the creation idempotent and avoids the check-then-create race.
work_dir = 'for_synthetic_data_masked'
os.makedirs(work_dir, exist_ok=True)

In [9]:
# Lay out the on-disk structure for this run: samples, checkpoints and logs.
sample_dir = os.path.join(work_dir, "samples")
checkpoint_dir = os.path.join(work_dir, "checkpoints")
# NOTE: despite the "_dir" suffix, this is the path of the preemption checkpoint *file*.
checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
for run_dir in (sample_dir, checkpoint_dir, os.path.dirname(checkpoint_meta_dir)):
    utils.makedirs(run_dir)
logger = utils.get_logger(os.path.join(work_dir, "logs"))

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    print("WARNING: Using device {}".format(device))
    print(f"Found {os.cpu_count()} total number of CPUs.")
else:
    # Report the name and total memory of every visible GPU.
    gpu_count = torch.cuda.device_count()
    print("Found {} CUDA devices.".format(gpu_count))
    for gpu_index in range(gpu_count):
        props = torch.cuda.get_device_properties(gpu_index)
        memory_gb = props.total_memory / (1024 ** 3)
        print("{} \t Memory: {:.2f}GB".format(props.name, memory_gb))

# Token graph, built from the config on the selected device.
graph = graph_lib.get_graph(cfg, device)

# Score model. The DDP wrapping from the original script is disabled in this
# notebook, so the bare module is simply moved to the device.
score_model = SEDD(cfg).to(device)

num_parameters = sum(param.numel() for param in score_model.parameters())
print(f"Number of parameters in the model: {num_parameters}")

# Exponential moving average of the score-model parameters.
ema = ExponentialMovingAverage(score_model.parameters(), decay=cfg.training.ema)
print(score_model)
print(f"EMA: {ema}")

# Noise module (also left without the DDP wrapper), moved to the same device.
noise = noise_lib.get_noise(cfg).to(device)
# Epsilon handed to sampling.get_sampling_fn when the sampler is built.
sampling_eps = 1e-5


# Build the optimisation state: one optimizer over the score-model and noise
# parameters, a gradient scaler, and the bookkeeping dict that is checkpointed.
optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
print(f"Optimizer: {optimizer}")

# torch.cuda.amp.GradScaler() is deprecated in recent PyTorch (it emits a
# FutureWarning); prefer the device-agnostic torch.amp.GradScaler("cuda").
# Fall back to the legacy constructor on versions that lack the new API.
try:
    scaler = torch.amp.GradScaler("cuda")
except (AttributeError, TypeError):
    scaler = torch.cuda.amp.GradScaler()
print(f"Scaler: {scaler}")

state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0)
# Resume from the preemption checkpoint via utils.restore_checkpoint; the
# restored state's step counter is where training picks up.
state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)
initial_step = int(state['step'])

Found 1 CUDA devices.
NVIDIA GeForce GTX 1080 	 Memory: 8.00GB
Number of parameters in the model: 609795
SEDD(
  (vocab_embed): EmbeddingLayer()
  (sigma_map): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
    )
  )
  (rotary_emb): Rotary()
  (blocks): ModuleList(
    (0-2): 3 x DDiTBlock(
      (norm1): LayerNorm()
      (attn_qkv): Linear(in_features=64, out_features=192, bias=False)
      (attn_out): Linear(in_features=64, out_features=64, bias=False)
      (dropout1): Dropout(p=0.01, inplace=False)
      (norm2): LayerNorm()
      (mlp): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='tanh')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
      (dropout2): Dropout(p=0.01, inplace=False)
      (adaLN_modulation): Linear(in_features=256, out_features=384, bias=T

  scaler = torch.cuda.amp.GradScaler()


In [10]:
# Read the pre-generated training samples from disk. This assumes
# generate_data.ipynb has already been run to produce data.npy.
data_file = os.path.join(cfg.data.train, 'data.npy')
with open(data_file, 'rb') as npy_handle:
    # Cast to int64 so the samples become LongTensors once wrapped by torch.
    data = np.load(npy_handle).astype(np.int64)
    print('data shape: %s' % str(data.shape))

# Thin Dataset adapter around an in-memory NumPy array
class CustomDataset(Dataset):
    """Expose the rows of a NumPy array as a map-style torch Dataset.

    The array is wrapped with ``torch.from_numpy``; indexing returns the
    corresponding slice of the resulting tensor.
    """

    def __init__(self, data):
        """Wrap ``data`` (a NumPy array) as a tensor stored on ``self.data``."""
        tensor_view = torch.from_numpy(data)
        self.data = tensor_view

    def __len__(self):
        """Return the number of entries along the first axis."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        """Return the entry (or entries) selected by ``idx``."""
        item = self.data[idx]
        return item

# Create a Dataset instance wrapping the full in-memory sample array
train_set = CustomDataset(data)

# Function to cycle through DataLoader
def cycle_loader(dataloader):
    """Yield batches from ``dataloader`` forever, restarting it when exhausted.

    Args:
        dataloader: any re-iterable object (e.g. a ``DataLoader``); a fresh
            iteration is started each time the previous one ends.

    Yields:
        The items produced by ``dataloader``, epoch after epoch.

    Raises:
        ValueError: if a full pass over ``dataloader`` yields nothing. Without
            this check an empty loader would make the generator spin forever
            without ever producing a batch.
    """
    while True:
        produced_any = False
        for data in dataloader:
            produced_any = True
            yield data
        if not produced_any:
            raise ValueError("cycle_loader received an empty dataloader")

# Per-iteration batch size: the configured batch size split across GPUs and
# gradient-accumulation steps.
loader_batch_size = cfg.training.batch_size // (cfg.ngpus * cfg.training.accum)

# Plain shuffled DataLoader (no sampler), loading in the main process.
train_loader = DataLoader(
    train_set,
    batch_size=loader_batch_size,
    shuffle=True,
    num_workers=0,
    persistent_workers=False,
    pin_memory=True,
)

# Endless stream of batches, and the iterator the training loop pulls from.
train_ds = cycle_loader(train_loader)
train_iter = iter(train_ds)

data shape: (10000000, 32)


In [11]:
# Plotting code borrowed from Sun2023.

def plot(xbin, fn_xbin2float, output_file=None):
  """Visualize binary data.

  Args:
    xbin: binary samples; converted to floats with ``fn_xbin2float`` before
      being handed to ``data_utils.plot_samples``.
    fn_xbin2float: callable mapping ``xbin`` to the float data to plot.
    output_file: optional path. When ``None`` the figure is rendered in
      memory as PNG; otherwise it is written to this path, as PNG if the name
      ends with ``.png`` and as PDF otherwise.

  Returns:
    When ``output_file`` is ``None``, the rendered image as a NumPy array with
    a leading axis of size 1; otherwise ``None``.
  """
  # A bare `import PIL` does not guarantee that the PIL.Image submodule has
  # been loaded, so import it explicitly instead of relying on another
  # library having imported it as a side effect.
  from PIL import Image

  float_data = fn_xbin2float(xbin)
  if output_file is None:  # in-memory plot
    buf = io.BytesIO()
    data_utils.plot_samples(float_data, buf, im_size=4.1, im_fmt='png')
    buf.seek(0)  # rewind so the image is read from the start of the buffer
    image = np.asarray(Image.open(buf))[None, ...]
    return image
  else:
    with open(output_file, 'wb') as f:
      im_fmt = 'png' if output_file.endswith('.png') else 'pdf'
      data_utils.plot_samples(float_data, f, im_size=4.1, im_fmt=im_fmt)


class BinarySyntheticHelper(object):
  """Helper that decodes and plots samples of the binary synthetic model."""

  def __init__(self, config):
    """Keep ``config`` and build the 'gray' bit maps for ``config.model.length``."""
    forward_map, inverse_map = data_utils.get_binmap(config.model.length, 'gray')
    self.config = config
    self.bm = forward_map
    self.inv_bm = inverse_map

  def plot(self, xbin, output_file=None):
    """Plot ``xbin`` after decoding it with ``data_utils.bin2float``.

    The decoder is bound to this helper's inverse bit map and to the
    ``model.length`` / ``int_scale`` values of the stored config. The result
    of the module-level ``plot`` function is returned unchanged.
    """
    model_cfg = self.config
    decode = functools.partial(
        data_utils.bin2float,
        inv_bm=self.inv_bm,
        discrete_dim=model_cfg.model.length,
        int_scale=model_cfg.int_scale,
    )
    return plot(xbin, decode, output_file)
  
# Helper instance built from the loaded config; the training loop uses it to plot samples.
model_helper = BinarySyntheticHelper(cfg)

remapping binary repr with gray code


In [12]:
# The metric used in Table 1 of Lou2023.
def binary_mmd(x, y, sim_fn):
  """MMD estimate between two sample sets under the similarity ``sim_fn``.

  Both inputs are cast to float32. The within-set terms average the
  off-diagonal entries of ``sim_fn(s, s)`` (self-similarities are masked out),
  the cross term averages every entry of ``sim_fn(x, y)``, and the result is
  ``kxx + kyy - 2 * kxy``.
  """
  def _offdiag_mean(samples):
    # Mean pairwise similarity within one set, excluding the diagonal.
    count = samples.shape[0]
    gram = sim_fn(samples, samples)
    gram = gram * (1 - np.eye(count))
    return np.sum(gram) / count / (count - 1)

  xs = x.astype(np.float32)
  ys = y.astype(np.float32)

  within_x = _offdiag_mean(xs)
  within_y = _offdiag_mean(ys)

  cross = np.sum(sim_fn(xs, ys))
  cross = cross / xs.shape[0] / ys.shape[0]

  return within_x + within_y - 2 * cross

def binary_exp_hamming_sim(x, y, bd):
  """Pairwise similarity ``exp(-bd * d)``.

  ``d`` is the sum of absolute differences over the last axis between every
  row of ``x`` and every row of ``y``; the result is indexed [row of x, row of y].
  """
  rows = np.asanyarray(x)[:, None]   # new axis at position 1
  cols = np.asanyarray(y)[None]      # new axis at position 0
  distance = np.abs(rows - cols).sum(axis=-1)
  return np.exp(-bd * distance)

def binary_exp_hamming_mmd(x, y, bandwidth=0.1):
  """``binary_mmd`` using ``binary_exp_hamming_sim`` with ``bd=bandwidth``."""
  def kernel(lhs, rhs):
    return binary_exp_hamming_sim(lhs, rhs, bd=bandwidth)

  return binary_mmd(x, y, kernel)

In [13]:
# One-step training and evaluation functions share the same optimisation
# manager; they differ only in the train flag passed to get_step_fn.
optimize_fn = losses.optimization_manager(cfg)
train_step_fn, eval_step_fn = (
    losses.get_step_fn(noise, graph, is_train, optimize_fn, cfg.training.accum)
    for is_train in (True, False)
)

# The sampler is only built when snapshot sampling is enabled in the config.
if cfg.training.snapshot_sampling:
    sampling_shape = (
        cfg.training.batch_size // (cfg.ngpus * cfg.training.accum),
        cfg.model.length,
    )
    sampling_fn = sampling.get_sampling_fn(
        cfg, graph, noise, sampling_shape, sampling_eps, device
    )

num_train_steps = cfg.training.n_iters
print(f"Starting training loop at step {initial_step}.")

Starting training loop at step 10001.


In [14]:
# Main training loop: train, log, checkpoint, and periodically evaluate the
# MMD of EMA-weight samples against ground-truth data.
while state['step'] < num_train_steps + 1:
    step = state['step']
    batch = next(train_iter).to(device)
    loss = train_step_fn(state, batch)

    # state['step'] changes only once a full batch has been computed. Skip all
    # logging/checkpointing/eval until then, so that with gradient accumulation
    # they run once per optimisation step rather than once per micro-batch.
    if step == state['step']:
        continue

    if step % cfg.training.log_freq == 0:
        print("step: %d, training_loss: %.5e" % (step, loss.item()))

    # Rolling checkpoint used to resume after preemption.
    if step % cfg.training.snapshot_freq_for_preemption == 0:
        utils.save_checkpoint(checkpoint_meta_dir, state)

    if (step > 0 and step % cfg.training.snapshot_freq == 0) or step == num_train_steps:
        # Save the checkpoint.
        save_step = step // cfg.training.snapshot_freq
        utils.save_checkpoint(
            os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'), state)

        # sampling_fn only exists when snapshot sampling is enabled in the config.
        if cfg.training.snapshot_sampling:
            # Sample with the EMA weights: stash the live weights, swap the EMA
            # ones in, and always swap back, even if evaluation fails.
            ema.store(score_model.parameters())
            ema.copy_to(score_model.parameters())
            try:
                num_batches = cfg.plot_samples // cfg.training.batch_size

                # Print the metric used in Table 1 of Lou2023. Should get at least as
                # small as 1.62e-5 to be considered done training.
                avg_mmd = 0.0
                for i in range(cfg.eval_rounds):
                    # Ground-truth batches come from the same endless training stream.
                    gt_data = np.concatenate(
                        [next(train_iter).cpu().numpy() for _ in range(num_batches)],
                        axis=0)
                    gt_data = np.reshape(gt_data, (-1, cfg.model.length))

                    sample_data = np.concatenate(
                        [sampling_fn(score_model).cpu().numpy() for _ in range(num_batches)],
                        axis=0)
                    x0 = np.reshape(sample_data, gt_data.shape)

                    mmd = binary_exp_hamming_mmd(x0, gt_data)
                    avg_mmd += mmd
                    print(f'Eval round {i} mmd: {mmd}')
                    # Plot a sample from the model.
                    model_helper.plot(x0, f'step_{save_step}_eval_round_{i}.pdf')

                avg_mmd = avg_mmd / cfg.eval_rounds
                print(f'Average mmd : {avg_mmd}')
                with open('output.txt', 'a') as file:
                    file.write(f'step: {save_step}, loss: {loss.item()}, MMD: {avg_mmd}\n')
            finally:
                ema.restore(score_model.parameters())

step: 11000, training_loss: 2.19158e+01
step: 12000, training_loss: 2.30325e+01
step: 13000, training_loss: 2.24266e+01
step: 14000, training_loss: 2.32441e+01
step: 15000, training_loss: 2.16463e+01
step: 16000, training_loss: 2.23951e+01
step: 17000, training_loss: 2.20677e+01
step: 18000, training_loss: 2.21815e+01
step: 19000, training_loss: 2.21960e+01
step: 20000, training_loss: 2.21214e+01
Eval round 0 mmd: 0.007332335078250596
Average mmd : 0.007332335078250596
step: 21000, training_loss: 2.56982e+01
step: 22000, training_loss: 2.22217e+01
step: 23000, training_loss: 2.23996e+01
step: 24000, training_loss: 2.26569e+01
step: 25000, training_loss: 2.22019e+01
step: 26000, training_loss: 2.33324e+01
step: 27000, training_loss: 2.24384e+01
step: 28000, training_loss: 2.24196e+01
step: 29000, training_loss: 2.22064e+01
step: 30000, training_loss: 2.20200e+01
Eval round 0 mmd: 0.00706799882557424
Average mmd : 0.00706799882557424
step: 31000, training_loss: 2.22789e+01
step: 32000, t