In [None]:
!pip install -r "../requirements.txt"

In [None]:
!pip install -e ..

In [None]:
import functools
import math

import torch
import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader

from heroic_gm.models.heroic_nn import HeroicScoreNet
from heroic_gm.data.heroic_dataset import HeroicDataset

In [None]:
# Remote parquet file from the Hugging Face dataset
# "0xJustin/Dungeons-and-Diffusion" (the file name suggests it is the single
# train shard). This URL is what gets passed to HeroicDataset for training.
PARQUET_DATABASE_URL="https://huggingface.co/datasets/0xJustin/Dungeons-and-Diffusion/resolve/main/data/train-00000-of-00001-9b40395dcd3257f2.parquet"

In [None]:
# From Yang Song notebook
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on first tensor use.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def marginal_prob_std(t, sigma):
  r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

  Evaluates sqrt((sigma^(2t) - 1) / (2 ln sigma)) element-wise.

  Args:
    t: A vector of time steps. May be a tensor, a sequence or a scalar; it is
      converted to a tensor on the module-level `device`.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The standard deviation, as a tensor on `device`.
  """
  # as_tensor accepts scalars, sequences and tensors alike; unlike
  # torch.tensor it does not emit the copy-construct UserWarning (nor make a
  # redundant copy) when `t` already is a tensor.
  t = torch.as_tensor(t, device=device)
  # math.log on the scalar sigma: NumPy is never imported in this notebook,
  # so the former np.log call raised NameError on a fresh kernel.
  return torch.sqrt((sigma**(2 * t) - 1.) / 2. / math.log(sigma))

def diffusion_coeff(t, sigma):
  r"""Compute the diffusion coefficient of our SDE.

  Args:
    t: A vector of time steps (tensor or scalar).
    sigma: The $\sigma$ in our SDE.

  Returns:
    The vector of diffusion coefficients, sigma**t, as a tensor on the
    module-level `device`.
  """
  # sigma**t is already a tensor whenever `t` is one; as_tensor moves it to
  # `device` without the copy-construct UserWarning that torch.tensor emits,
  # and still wraps a plain Python scalar result in a tensor.
  return torch.as_tensor(sigma**t, device=device)
  
# Noise scale of the SDE.
sigma = 25.0
# Bind sigma into both SDE helpers so each becomes a function of t alone.
marginal_prob_std_fn, diffusion_coeff_fn = (
    functools.partial(fn, sigma=sigma)
    for fn in (marginal_prob_std, diffusion_coeff)
)

In [None]:
# From Yang Song notebook
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
  """Denoising score-matching loss for a time-dependent score model.

  Args:
    model: A PyTorch model (or any callable) taking the perturbed batch and
      the sampled time steps and returning the estimated score.
    x: A mini-batch of training data; indexed as a 4-D tensor whose first
      axis is the batch.
    marginal_prob_std: A function mapping time steps to the standard
      deviation of the perturbation kernel.
    eps: Smallest time step that can be sampled, for numerical stability.

  Returns:
    The scalar loss, averaged over the mini-batch.
  """
  n_samples = x.shape[0]
  # One time step per sample, uniform on [eps, 1).
  random_t = eps + (1. - eps) * torch.rand(n_samples, device=x.device)
  noise = torch.randn_like(x)
  # Broadcast the per-sample std over the three trailing axes.
  std = marginal_prob_std(random_t)[:, None, None, None]
  noisy_x = x + noise * std
  predicted_score = model(noisy_x, random_t)
  residual = predicted_score * std + noise
  per_sample = residual.pow(2).sum(dim=(1, 2, 3))
  return per_sample.mean()

In [None]:
# From Yang Song notebook

# Training hyper-parameters.
n_epochs = 50     # number of passes over the dataset
batch_size = 32   # size of a mini-batch
lr = 1e-4         # learning rate

# Score network, wrapped for data parallelism and moved to the target device.
score_model = torch.nn.DataParallel(
    HeroicScoreNet(marginal_prob_std=marginal_prob_std_fn)
).to(device)

# Dataset and shuffled mini-batch loader.
dataset = HeroicDataset(PARQUET_DATABASE_URL)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

optimizer = Adam(score_model.parameters(), lr=lr)

tqdm_epoch = tqdm.trange(n_epochs)
for epoch in tqdm_epoch:
  loss_total, n_seen = 0., 0
  # Each item from the loader is a pair; only the second element is trained on.
  for _, batch in data_loader:
    batch = batch.to(device)
    n_batch = batch.shape[0]
    optimizer.zero_grad()
    loss = loss_fn(score_model, batch, marginal_prob_std_fn)
    loss.backward()
    optimizer.step()
    # Weight by batch size so the epoch average is per sample.
    loss_total += loss.item() * n_batch
    n_seen += n_batch
  # Show the sample-averaged training loss of the epoch on the progress bar.
  tqdm_epoch.set_description('Average Loss: {:5f}'.format(loss_total / n_seen))
  # Overwrite the checkpoint after every epoch of training.
  torch.save(score_model.state_dict(), 'ckpt.pth')