In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
from torchvision.datasets import ImageFolder

from scipy import integrate, linalg
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
 

import functools
import os
import math
import functools
import string
import wandb
from tqdm import tqdm


import sys
sys.path.append("..")
from src.utils import Config


from src.fid_score import calculate_frechet_distance, get_loader_stats
from src.inception import InceptionV3
from src.utils import freeze, unfreeze
import gc

## 1. Config

In [None]:
config = Config()

config.device = 'cuda'

# ---- Data ----
config.data = Config()
config.data.num_channels = 3
# `channels` (used for data_dim) and `num_channels` (used for tensor shapes) must agree;
# derive one from the other so they cannot drift apart.
config.data.channels = config.data.num_channels
config.data.centered = True
config.data.img_resize = 32
config.data.image_size = 32


# ---- Training ----
config.training = Config()
config.training.sde = 'poisson'
config.training.continuous = True
config.training.batch_size = 128  # 4096 was the other value noted here
config.training.small_batch_size = 128
config.training.gamma = 5  # offset in the field normalisation: field / (||field|| + gamma)
config.training.restrict_M = True  # samples with z < 0.005 draw m from [0, 0.7 * M)
config.training.tau = 0.03  # perturbation scale multiplier is (1 + tau) ** m
config.training.snapshot_freq = 5_000  # steps between logged sample grids / trajectories
config.training.eval_freq = 5_000  # steps between logged EMA eval losses
config.training.model = 'ddpmpp'
config.training.M = 291  # exponent m is drawn from [0, M)
config.training.reduce_mean = False
config.training.n_iters = 1_000_000
config.training.fid_freq = 25_000  # steps between FID evaluations
config.training.fid_batch_size = 250

# ---- Model ----
config.model = Config()
config.model.name = 'ncsnpp'
config.model.scale_by_sigma = False
config.model.ema_rate = 0.9999  # decay of the ExponentialMovingAverage of the weights
config.model.normalization = 'GroupNorm'
config.model.nonlinearity = 'swish'
config.model.nf = 128
config.model.ch_mult = (1, 2, 2, 2)
config.model.num_res_blocks = 4
config.model.attn_resolutions = (16,)
config.model.resamp_with_conv = True
config.model.conditional = True
config.model.fir = False
config.model.fir_kernel = [1, 3, 3, 1]
config.model.skip_rescale = True
config.model.resblock_type = 'biggan'
config.model.progressive = 'none'
config.model.progressive_input = 'none'
config.model.progressive_combine = 'sum'
config.model.attention_type = 'ddpm'
config.model.init_scale = 0.
config.model.fourier_scale = 16
config.model.embedding_type = 'positional'
config.model.conv_size = 3
config.model.sigma_end = 0.01  # std of the base Gaussian noise scaled by (1 + tau) ** m in forward_pz
config.model.dropout = 0.1

# ---- Optimiser ----
config.optim = Config()
config.optim.weight_decay = 0
config.optim.optimizer = 'Adam'
config.optim.lr = 2e-7  # peak learning rate reached after `warmup` steps
config.optim.beta1 = 0.9
config.optim.eps = 1e-8
config.optim.warmup = 5000  # linear LR warmup length in steps
config.optim.grad_clip = 1.  # max gradient norm; a negative value disables clipping

# ---- Sampling ----
config.sampling = Config()
config.sampling.method = 'ode'
config.sampling.ode_solver = 'rk45'
config.sampling.N = 100
config.sampling.z_max = 30  # initial z of the sampling ODE; z-coordinate of the noise batch in the loss
config.sampling.z_min = 1e-7  # final z the sampling ODE is integrated down to
config.sampling.upper_norm = 3000
config.sampling.z_exp = 1  # below this z, the predicted z-drift is replaced by its closed form
config.sampling.visual_iterations = 10  # number of trajectory frames kept for plotting
# verbose: print z at every ODE evaluation
config.sampling.vs = False


## 2. Data

In [None]:
class Sampler:
    """Minimal interface for objects that draw batches of samples.

    Subclasses override :meth:`sample`; the base implementation returns None.
    """

    def __init__(self, device='cuda'):
        # Device that sampled batches are moved to.
        self.device = device

    def sample(self, size=5):
        pass


class LoaderSampler(Sampler):
    """Draw fixed-size batches from a ``DataLoader``, cycling over epochs indefinitely.

    Each loader item is unpacked as ``(data, label)``; labels are discarded.
    """

    def __init__(self, loader, device='cuda'):
        super().__init__(device)
        self.loader = loader
        self.it = iter(self.loader)

    def sample(self, size=5):
        """Return the first ``size`` items of the next batch holding at least ``size`` items.

        Batches that are too short (e.g. the tail of an epoch) are skipped, and the
        iterator is restarted when an epoch is exhausted.

        Raises:
          AssertionError: if ``size`` exceeds ``loader.batch_size``.
          ValueError: if a complete fresh pass over the loader yields no batch
            with at least ``size`` items (previously this recursed forever).
        """
        assert size <= self.loader.batch_size
        restarted = False
        while True:
            try:
                batch, _ = next(self.it)
            except StopIteration:
                if restarted:
                    raise ValueError(
                        f"LoaderSampler: no batch with at least {size} items "
                        "in a full pass over the loader")
                self.it = iter(self.loader)
                restarted = True
                continue
            if len(batch) >= size:
                return batch[:size].to(self.device)


In [None]:
def get_random_colored_images(images, seed = 0x000000):
    """Tint grayscale images with one random hue per image.

    Args:
      images: tensor of shape (B, 1, H, W) with values in [-1, 1] (load_dataset
        passes MNIST reshaped to (-1, 1, 32, 32) and scaled by 2x - 1).
      seed: seed for the per-image hues; the same seed always gives the same colours.

    Returns:
      Tensor of shape (B, 3, H, W) with values in [-1, 1].
    """
    # Local generator: yields exactly the hues that np.random.seed(seed) followed by
    # np.random.rand would, without overwriting NumPy's global RNG state.
    rng = np.random.RandomState(seed)

    images = 0.5*(images + 1)  # [-1, 1] -> [0, 1]
    size = images.shape[0]
    hues = 360*rng.rand(size)

    if size == 0:
        # torch.stack([]) would raise; return an empty batch of the right shape.
        return torch.zeros((0, 3, *images.shape[2:]))

    colored_images = []
    for V, H in zip(images, hues):
        H = float(H)
        V_min = 0

        a = (V - V_min)*(H % 60)/60
        V_inc = a
        V_dec = V - a

        # NOTE: textbook HSV->RGB uses floor(H / 60); round() shifts the sector
        # boundaries by 30 degrees. Kept so seeded outputs stay reproducible.
        H_i = int(round(H/60)) % 6
        # (R, G, B) source for each hue sector.
        sector_channels = {
            0: (V, V_inc, V_min),
            1: (V_dec, V, V_min),
            2: (V_min, V, V_inc),
            3: (V_min, V_dec, V),
            4: (V_inc, V_min, V),
            5: (V, V_min, V_dec),
        }[H_i]

        colored_image = torch.zeros((3, V.shape[1], V.shape[2]))
        for channel, value in enumerate(sector_channels):
            colored_image[channel] = value
        colored_images.append(colored_image)

    colored_images = torch.stack(colored_images, dim = 0)
    colored_images = 2*colored_images - 1  # back to [-1, 1]

    return colored_images
    

In [None]:
def load_dataset(name, path, img_size=64, batch_size=64, 
                 shuffle=True, device='cuda', return_dataset=False,
                 num_workers=0):
    """Load MNIST (optionally class-filtered and/or randomly colored) train/test splits.

    ``name`` starts with ``"MNIST"`` and may be followed by ``_{digit}`` parts,
    e.g. ``"MNIST_3_7"``; with no digits all ten classes are used. If the part
    before the first underscore ends with ``"colored"`` (e.g. ``"MNIST-colored"``),
    every image is tinted with a random hue via ``get_random_colored_images``.

    Images are resized to 32x32 and scaled to [-1, 1]. Labels are the position of
    the digit in the requested class list (0 .. len(classes) - 1), not the digit.

    Args:
      name: dataset spec as described above.
      path: root directory passed to ``torchvision.datasets.MNIST`` (download=True).
      img_size: currently unused; the 32x32 resolution is hard-coded.
      batch_size, shuffle, num_workers: forwarded to both DataLoaders.
      device: device the returned samplers move batches to.
      return_dataset: if True, return ``(train_set, test_set)`` instead of samplers.

    Returns:
      ``(train_set, test_set)`` if ``return_dataset`` else
      ``(train_sampler, test_sampler, test_set)``.

    Raises:
      ValueError: if ``name`` does not start with ``"MNIST"``.
    """
    if not name.startswith("MNIST"):
        raise ValueError('Unknown dataset')

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: 2 * x - 1)
    ])

    dataset_name = name.split("_")[0]
    is_colored = dataset_name[-7:] == "colored"

    classes = [int(number) for number in name.split("_")[1:]]
    if not classes:
        classes = list(range(10))

    train_set = torchvision.datasets.MNIST(path, train=True, transform=transform, download=True)
    test_set = torchvision.datasets.MNIST(path, train=False, transform=transform, download=True)

    train_test = []
    for dataset in [train_set, test_set]:
        targets = torch.as_tensor(dataset.targets)
        data = []
        labels = []
        for k, digit in enumerate(classes):
            # Vectorised class filter: same ascending index order as a per-item scan,
            # and the transform is still applied only to the selected items.
            indices = torch.nonzero(targets == digit).flatten().tolist()
            data.append(torch.stack([dataset[i][0] for i in indices], dim=0))
            labels += [k] * len(indices)
        data = torch.cat(data, dim=0)
        data = data.reshape(-1, 1, 32, 32)
        labels = torch.tensor(labels)

        if is_colored:
            data = get_random_colored_images(data)

        train_test.append(TensorDataset(data, labels))

    train_set, test_set = train_test

    if return_dataset:
        return train_set, test_set

    train_sampler = LoaderSampler(DataLoader(train_set, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size), device)
    test_sampler = LoaderSampler(DataLoader(test_set, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size), device)
    return train_sampler, test_sampler, test_set

In [None]:
# Colored MNIST with all ten digits; MNIST is downloaded into DATASET1_PATH if missing.
DATASET1, DATASET1_PATH = 'MNIST-colored', '../data/MNIST'

X_sampler, X_test_sampler, Dataset = load_dataset(
    DATASET1,
    DATASET1_PATH,
    img_size=config.data.image_size,
    batch_size=config.training.batch_size,
    num_workers=8,
)
 

## 3. Poisson class

In [None]:
class Poisson():
    """PFGM settings holder plus the drift of the backward (sampling) ODE.

    Attributes:
      config: the full configuration object.
      N: ``config.sampling.N``, read once at construction.
    """

    def __init__(self, config):
        """Construct a PFGM.

        Args:
          config: configurations
        """
        self.config = config
        self.N = config.sampling.N

    @property
    def M(self):
        """Upper bound of the perturbation exponent m (``config.training.M``), read live."""
        return self.config.training.M

    def ode(self, net_fn, x, t):
        """Drift of the backward ODE in the time variable t' = log z, i.e. z * dx/dz.

        Args:
          net_fn: network called as ``net_fn(x, z_batch)`` returning
            ``(x_drift, z_drift)`` with ``z_drift`` of shape [B].
          x: current samples, shape (B, num_channels, image_size, image_size).
          t: tensor of log z values; their mean sets the single z used for the batch.

        Returns:
          Tensor shaped like ``x``.
        """
        # One scalar z for the whole batch.
        z = float(torch.exp(t.mean()))
        if self.config.sampling.vs:
            print(z)
        # Build the z input on x's device instead of hard-coding CUDA.
        x_drift, z_drift = net_fn(x, torch.ones(len(x), device=x.device) * z)
        x_drift = x_drift.view(len(x_drift), -1)

        # Substitute the predicted z with the ground-truth
        # Please see Appendix B.2.3 in PFGM paper (https://arxiv.org/abs/2209.11178) for details
        gamma = self.config.training.gamma
        if z < self.config.sampling.z_exp and gamma > 0:
            data_dim = self.config.data.image_size * self.config.data.image_size * self.config.data.channels
            sqrt_dim = np.sqrt(data_dim)
            norm_1 = x_drift.norm(p=2, dim=1) / sqrt_dim
            x_norm = gamma * norm_1 / (1 - norm_1)
            x_norm = torch.sqrt(x_norm ** 2 + z ** 2)
            z_drift = -sqrt_dim * torch.ones_like(z_drift) * z / (x_norm + gamma)

        # Predicted normalized Poisson field
        v = torch.cat([x_drift, z_drift[:, None]], dim=1)

        # dx/dz = (dx/dt) / (dz/dt); the 1e-5 guards against a zero z-component.
        dt_dz = 1 / (v[:, -1] + 1e-5)
        dx_dt = v[:, :-1].view(len(x), self.config.data.num_channels,
                               self.config.data.image_size, self.config.data.image_size)
        dx_dz = dx_dt * dt_dz.view(-1, *([1] * len(x.size()[1:])))
        # dx/dt_prime =  z * dx/dz
        dx_dt_prime = z * dx_dz
        return dx_dt_prime

## 4. Utils Poisson

In [None]:
def forward_pz(sde, config,  samples_batch_x, samples_batch_y):
    """Perturb the augmented training data. See Algorithm 2 in the PFGM paper
    (https://arxiv.org/abs/2209.11178).

    Args:
      sde: object exposing ``M``, the upper bound of the exponent m
        (the ``Poisson`` instance in this notebook).
      config: configurations; reads training.tau, training.restrict_M,
        model.sigma_end and data.channels / data.image_size.
      samples_batch_x: mini-batch of un-augmented data, shape (B, C, H, W) with
        C * H * W == config.data.channels * config.data.image_size ** 2.
      samples_batch_y: not used by this perturbation; kept for call compatibility.

    Returns:
      Tensor of shape (B, C*H*W + 1): flattened perturbed x, with the perturbed
      z as the last column.
    """
    device = samples_batch_x.device
    batch_size = len(samples_batch_x)
    data_dim = config.data.channels * config.data.image_size * config.data.image_size  # N
    tau = config.training.tau

    # Exponents m ~ U[0, M); the perturbation scale grows as (1 + tau) ** m.
    m = torch.rand((batch_size,), device=device) * sde.M
    z = torch.randn((batch_size, 1, 1, 1)).to(device) * config.model.sigma_end
    # view(-1) rather than squeeze(): keeps a 1-D [B] tensor even when B == 1.
    z = z.abs().view(-1)

    # Confine the norms of perturbed data.
    # see Appendix B.1.1 of https://arxiv.org/abs/2209.11178
    if config.training.restrict_M:
        idx = z < 0.005
        num = int(idx.sum())
        restrict_m = int(sde.M * 0.7)
        m[idx] = torch.rand((num,), device=device) * restrict_m

    multiplier = (1 + tau) ** m  # [B]
    perturbed_z = z * multiplier  # [B]

    # ---- perturbation of the x component ----
    # Direction: uniform on the unit sphere in R^N.
    gaussian = torch.randn(batch_size, data_dim).to(device)
    unit_gaussian = gaussian / torch.norm(gaussian, p=2, dim=1, keepdim=True)

    # Radius: norm of a sigma_end-scaled Gaussian, stretched by the multiplier.
    noise = torch.randn_like(samples_batch_x).reshape(batch_size, -1) * config.model.sigma_end
    norm_m = torch.norm(noise, p=2, dim=1) * multiplier  # [B]

    perturbation_x = (unit_gaussian * norm_m[:, None]).view_as(samples_batch_x)
    perturbed_x = samples_batch_x + perturbation_x

    # Augment with the extra dimension z: [B, N + 1].
    return torch.cat((perturbed_x.reshape(batch_size, -1), perturbed_z[:, None]), dim=1)
    

## 5. losses

In [None]:
def _empirical_poisson_field(perturbed_samples_vec, anchors_vec, data_dim, gamma):
    """Normalized empirical Poisson field at each perturbed point, generated by ``anchors_vec``.

    Args:
      perturbed_samples_vec: [B, D+1] augmented query points.
      anchors_vec: [B', D+1] augmented source points.
      data_dim: D, the un-augmented dimensionality; weights fall off as r^-(D+1).
      gamma: offset added to the field norm before normalization.

    Returns:
      [B, D+1] field whose norm is sqrt(D) * ||E|| / (||E|| + gamma).
    """
    with torch.no_grad():
        # [B, B', D+1] displacement of every query point from every anchor.
        offsets = perturbed_samples_vec.unsqueeze(1) - anchors_vec
        gt_distance = torch.sum(offsets ** 2, dim=[-1]).sqrt()  # [B, B']

        # For numerical stability, scale each row by its minimum distance first,
        # so the nearest anchor gets a weight of about 1 before the large power.
        distance = torch.min(gt_distance, dim=1, keepdim=True)[0] / (gt_distance + 1e-7)
        distance = (distance ** (data_dim + 1))[:, :, None]

        # Normalize the coefficients (effectively multiply by c(\tilde{x}) in the paper)
        coeff = distance / (torch.sum(distance, dim=1, keepdim=True) + 1e-7)

        # Empirical Poisson field in the N+1-dimensional augmented space
        gt_direction = torch.sum(coeff * -offsets, dim=1)
        gt_direction = gt_direction.view(gt_direction.size(0), -1)

    gt_norm = gt_direction.norm(p=2, dim=1)
    gt_direction /= (gt_norm.view(-1, 1) + gamma)
    gt_direction *= np.sqrt(data_dim)
    return gt_direction


def loss_pfgm(model, batch_x, batch_y):
    """Mean squared error between the predicted and the empirical N+1-dim field.

    The target is the normalized empirical Poisson field generated by ``batch_x``
    (placed on the z = 0 plane) minus the one generated by ``batch_y`` (placed on
    z = z_max). Both are evaluated at points perturbed from ``batch_x`` by ``forward_pz``.

    NOTE: reads the notebook-global ``sde`` (a ``Poisson`` instance) and takes all
    settings from ``sde.config``; it is not self-contained.

    Args:
      model: network called as ``model(x, z)`` returning ``(x_field, z_field)``.
      batch_x: data mini-batch, (B, C, H, W).
      batch_y: mini-batch anchored at z = z_max, same shape as ``batch_x``
        (Gaussian noise in the training loop).

    Returns:
      loss: scalar mean of the squared error over the batch and all N+1 components.
    """
    cfg = sde.config
    data_dim = cfg.data.image_size * cfg.data.image_size * cfg.data.channels
    gamma = cfg.training.gamma

    perturbed_samples_vec = forward_pz(sde, cfg, batch_x, batch_y)

    with torch.no_grad():
        real_samples_vec_x = torch.cat(
            (batch_x.reshape(len(batch_x), -1),
             torch.zeros((len(batch_x), 1), device=batch_x.device)), dim=1)
        real_samples_vec_y = torch.cat(
            (batch_y.reshape(len(batch_y), -1),
             cfg.sampling.z_max * torch.ones((len(batch_y), 1), device=batch_y.device)), dim=1)

    gt_direction_x = _empirical_poisson_field(perturbed_samples_vec, real_samples_vec_x, data_dim, gamma)
    gt_direction_y = _empirical_poisson_field(perturbed_samples_vec, real_samples_vec_y, data_dim, gamma)
    target = gt_direction_x - gt_direction_y

    perturbed_samples_x = perturbed_samples_vec[:, :-1].view_as(batch_x)
    perturbed_samples_z = perturbed_samples_vec[:, -1]
    net_x, net_z = model(perturbed_samples_x, perturbed_samples_z)

    # Predicted N+1-dimensional Poisson field
    net_x = net_x.view(net_x.shape[0], -1)
    net = torch.cat([net_x, net_z[:, None]], dim=1)
    return torch.mean((net - target) ** 2)

## 6. ODE

In [None]:
def get_rk45_sampler_pfgm(sde, y, config, shape,   rtol=1e-4, atol=1e-4,
                    method='RK45', eps=1e-3, device='cuda'):
    """RK45 ODE sampler for PFGM.

    Integrates the backward PFGM ODE in the variable t = log z, from
    z = sde.config.sampling.z_max down to z = eps, with `scipy.integrate.solve_ivp`.

    Args:
    sde: object whose `.config` holds the data/sampling/training settings (the `Poisson` instance).
    y: unused here; the starting points are passed to the returned function instead.
    config: configurations; only `config.sampling.visual_iterations` is read here.
    shape: full batch shape (B, C, H, W) of the samples, batch dimension included.
    rtol: A `float` number. The relative tolerance level of the ODE solver.
    atol: A `float` number. The absolute tolerance level of the ODE solver.
    method: A `str`. The algorithm used for the black-box ODE solver.
      See the documentation of `scipy.integrate.solve_ivp`.
    eps: final value of z the ODE is integrated down to (must be > 0 since t = log z).
    device: PyTorch device for the model inputs and the returned tensors.

    Returns:
    A function `ode_sampler(model, y)` returning `(samples, nfe, trajectory)`:
    samples of shape `shape`, the number of function evaluations, and a tensor of
    `config.sampling.visual_iterations` intermediate states stacked on dim 0.
    """

    # Local flatten helpers, so this cell does not depend on helpers defined in a later cell.
    def _to_flat(tensor):
        return tensor.detach().cpu().numpy().reshape((-1,))

    def _from_flat(array, target_shape):
        return torch.from_numpy(array.reshape(target_shape))

    def ode_sampler(model, y):
        cfg = sde.config
        img_size = cfg.data.image_size

        # Augment the samples with the extra dimension z, stored as one additional
        # constant image channel so the solver state is a plain 4-D tensor.
        z = torch.ones((len(y), 1, img_size, img_size), device=y.device) * cfg.sampling.z_max
        x = torch.cat((y.view(shape), z), dim=1).float()
        new_shape = (len(x), cfg.data.channels + 1, img_size, img_size)

        def ode_func(t, x_flat):
            if cfg.sampling.vs:
                print(np.exp(t))

            x_aug = _from_flat(x_flat, new_shape).to(device).type(torch.float32)

            # Change-of-variable z=exp(t)
            z = np.exp(t)
            x_drift, z_drift = model(x_aug[:, :-1], torch.ones(len(x_aug), device=x_aug.device) * z)
            x_drift = x_drift.view(len(x_drift), -1)

            # Substitute the predicted z with the ground-truth
            # Please see Appendix B.2.3 in PFGM paper (https://arxiv.org/abs/2209.11178) for details
            gamma = cfg.training.gamma
            if z < cfg.sampling.z_exp and gamma > 0:
                data_dim = img_size * img_size * cfg.data.channels
                sqrt_dim = np.sqrt(data_dim)
                norm_1 = x_drift.norm(p=2, dim=1) / sqrt_dim
                x_norm = gamma * norm_1 / (1 - norm_1)
                x_norm = torch.sqrt(x_norm ** 2 + z ** 2)
                z_drift = -sqrt_dim * torch.ones_like(z_drift) * z / (x_norm + gamma)

            # Predicted normalized Poisson field; dx/dz = (dx/dt) / (dz/dt).
            v = torch.cat([x_drift, z_drift[:, None]], dim=1)
            dt_dz = 1 / (v[:, -1] + 1e-5)
            dx_dt = v[:, :-1].view(shape)
            dx_dz = dx_dt * dt_dz.view(-1, *([1] * len(x_aug.size()[1:])))
            # drift = z * (dx/dz, dz/dz) = z * (dx/dz, 1)
            drift = torch.cat([z * dx_dz,
                               torch.ones((len(dx_dz), 1, img_size, img_size),
                                          device=dx_dz.device) * z], dim=1)
            return _to_flat(drift)

        # Black-box ODE solver for the probability flow ODE.
        # Note that we use z = exp(t) for change-of-variable to accelerate the ODE simulation
        solution = integrate.solve_ivp(ode_func,
                                       (np.log(cfg.sampling.z_max), np.log(eps)),
                                       _to_flat(x),
                                       rtol=rtol, atol=atol, method=method)

        nfe = solution.nfev
        num_itrs = solution.y.shape[1]  # number of time points returned by the solver

        def _state(col):
            return torch.tensor(solution.y[:, col]).reshape(new_shape).to(device).type(torch.float32)

        # Frame columns for visualisation. Clip at 0 so a short solution (< 8 points)
        # never yields column -1, which would wrap around to the final state.
        frame_cols = np.linspace(num_itrs // 8, num_itrs,
                                 config.sampling.visual_iterations).astype(int) - 1
        frame_cols = np.clip(frame_cols, 0, num_itrs - 1)
        trajectory = torch.stack([_state(col)[:, :-1] for col in frame_cols], dim=0)

        # Detach augmented z dimension
        x = _state(-1)[:, :-1]
        return x, nfe, trajectory

    return ode_sampler

## 7. Model

In [None]:
def optimization_manager(config):
    """Build the per-step optimizer callback: linear LR warmup plus gradient-norm clipping.

    The defaults of the returned function are bound from `config.optim` once, here.
    """
    optim_cfg = config.optim

    def optimize_fn(optimizer, params, step, lr=optim_cfg.lr,
                    warmup=optim_cfg.warmup,
                    grad_clip=optim_cfg.grad_clip):
        """Take one optimizer step.

        When warmup > 0, every param group's learning rate is set to
        lr * min(step / warmup, 1). Gradients are clipped to norm `grad_clip`
        unless it is negative.
        """
        if warmup > 0:
            warmup_factor = np.minimum(step / warmup, 1.0)
            for group in optimizer.param_groups:
                group['lr'] = lr * warmup_factor
        if grad_clip >= 0:
            torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
        optimizer.step()

    return optimize_fn

In [None]:
# Make the repository root importable for `models`. The guard keeps the cell
# idempotent, so re-running it does not keep growing sys.path.
if ".." not in sys.path:
    sys.path.append("..")
from models import DDPM, ExponentialMovingAverage

net = DDPM(config).to(config.device)

params = net.parameters()
optimizer = torch.optim.Adam(params,
                             lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
                             weight_decay=config.optim.weight_decay)

# EMA copy of the weights; the training loop swaps it in for evaluation and sampling.
ema = ExponentialMovingAverage(net.parameters(), decay=config.model.ema_rate)
state = dict(optimizer=optimizer, model=net, ema=ema, step=0)
sde = Poisson(config=config)
sampling_eps = config.sampling.z_min
optimize_fn = optimization_manager(config)
reduce_mean = config.training.reduce_mean

num_train_steps = config.training.n_iters

In [None]:
def to_flattened_numpy(x):
    """Detach tensor `x`, move it to the CPU, and return it as a 1-D NumPy array."""
    host_array = x.detach().cpu().numpy()
    return host_array.reshape(-1)


def from_flattened_numpy(x, shape):
    """Reshape the flat NumPy array `x` to `shape` and wrap it as a torch tensor."""
    reshaped = np.reshape(x, shape)
    return torch.from_numpy(reshaped)

In [None]:
def plot(x, n_rows=5, n_cols=5):
    """Draw the top-left ``n_rows`` x ``n_cols`` grid of images from ``x`` and return the figure.

    Args:
      x: array indexed as ``x[row, col]``, where each entry is an image that ``imshow``
        accepts. The training loop passes a (5, 5, H, W, C) uint8 array.
      n_rows, n_cols: grid size; the defaults reproduce the original 5x5 layout.

    Returns:
      The matplotlib Figure. It is not closed here; the caller owns it.
    """
    # squeeze=False keeps `ax` 2-D even for a single row or column.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for idx in range(n_rows):
        for jdx in range(n_cols):
            ax[idx, jdx].imshow(x[idx, jdx])
            ax[idx, jdx].set_yticks([])
            ax[idx, jdx].set_xticks([])
    fig.tight_layout(pad=0.001)
    return fig

In [None]:
def plot_trajectory(traj, n_rows=5):
    """Plot ODE trajectories: one row per sample, one column per saved time step.

    Args:
      traj: tensor of shape (T, B, C, H, W), as stacked by the RK45 sampler. The data
        space is [-1, 1]; values outside it are clipped for display.
      n_rows: maximum number of samples to show; capped at B.

    Returns:
      The matplotlib Figure (not closed here; the caller owns it).
    """
    frames = traj.detach().permute(0, 1, 3, 4, 2).cpu().numpy()  # (T, B, H, W, C)
    # Map [-1, 1] -> [0, 255]. A plain `* 255` would clip the lower half of the
    # range to black and distort the hues of the colored digits.
    frames = np.clip((frames + 1) * 127.5, 0, 255).astype(np.uint8)

    n_steps = len(traj)
    n_rows = min(n_rows, frames.shape[1])
    # squeeze=False keeps `ax` 2-D for a single row or a single time step.
    fig, ax = plt.subplots(n_rows, n_steps, figsize=(n_steps, n_rows),
                           sharex=True, sharey=True, squeeze=False)
    for time in range(n_steps):
        for idx in range(n_rows):
            ax[idx, time].imshow(frames[time, idx])
            ax[idx, time].set_xticks([])
            ax[idx, time].set_yticks([])

    fig.tight_layout(pad=0.01)
    return fig

In [None]:
# Reference Inception statistics of the real data (the test split returned by load_dataset).
loader = DataLoader(
    Dataset,
    batch_size=config.training.fid_batch_size,
    shuffle=True,
    num_workers=8,
)
mu_data, sigma_data = get_loader_stats(loader)

In [None]:
def get_pushed_loader_stats(net, sde, config, batch_size, verbose=False, device='cuda',
                            use_downloaded_weights=False, data_loader=None):
    """Inception mean and covariance of samples generated by ``net`` with the RK45 sampler.

    One sample is generated per item of the reference loader, so the statistics
    cover as many samples as the reference data.

    Args:
      net: the PFGM network. It is frozen here and left frozen; the caller unfreezes it.
      sde: Poisson object passed to the sampler.
      config: configurations (data shape, ``sampling.z_min``, ``device``).
      batch_size: maximum number of samples generated per ODE solve.
      verbose: kept for compatibility; a progress bar is always shown.
      device: device for the Inception network.
      use_downloaded_weights: forwarded to ``InceptionV3``.
      data_loader: reference loader whose items set the number of samples.
        Defaults to the notebook-global ``loader``.

    Returns:
      (mu, sigma): mean vector and covariance matrix of the 2048-d Inception features.
    """
    if data_loader is None:
        data_loader = loader  # notebook-global reference loader
    dims = 2048
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx], use_downloaded_weights=use_downloaded_weights).to(device)
    freeze(net); freeze(model)

    img_shape = (config.data.num_channels, config.data.image_size, config.data.image_size)
    pred_arr = []

    with torch.no_grad():
        for X, _ in tqdm(data_loader):
            for start in range(0, len(X), batch_size):
                # Generate exactly as many samples as this chunk holds, so partial
                # chunks reshape correctly and no batch has to be discarded.
                n = min(start + batch_size, len(X)) - start
                batch_y = torch.randn(n, *img_shape)
                sampling_fn = get_rk45_sampler_pfgm(sde=sde, y=batch_y, config=config,
                                                    shape=(n, *img_shape),
                                                    eps=config.sampling.z_min,
                                                    device=config.device)
                batch, _, _ = sampling_fn(net, batch_y)
                pred_arr.append(model(batch)[0].cpu().numpy().reshape(n, -1))

    pred_arr = np.vstack(pred_arr)
    mu, sigma = np.mean(pred_arr, axis=0), np.cov(pred_arr, rowvar=False)
    gc.collect(); torch.cuda.empty_cache()
    return mu, sigma

##  8. Training

In [None]:
run_name = (
    f"CM_Exp_Decay_BS_{config.training.batch_size}"
    f"_LR_{config.optim.lr}"
    f"_zmax_{config.sampling.z_max}"
)
wandb.init(project="ElectroGeneration", name=run_name)

In [None]:
def _to_display_uint8(images):
    """Convert a (B, C, H, W) tensor in the data range [-1, 1] to a (B, H, W, C) uint8 array.

    Values outside [-1, 1] (e.g. Gaussian noise) are clipped.
    """
    array = images.detach().permute(0, 2, 3, 1).cpu().numpy()
    return np.clip((array + 1) * 127.5, 0, 255).astype(np.uint8)


initial_step = 0
ema = state['ema']
n_show = 25  # images per snapshot grid (5 x 5)
img_shape = (config.data.num_channels, config.data.image_size, config.data.image_size)
grid_shape = (5, 5, config.data.image_size, config.data.image_size, config.data.num_channels)

for step in tqdm(range(initial_step, num_train_steps + 1)):

    # ---- one optimisation step ----
    batch_x = X_sampler.sample(config.training.batch_size).to(config.device)
    batch_y = torch.randn_like(batch_x)  # Gaussian batch, anchored at z = z_max in the loss

    optimizer = state['optimizer']
    optimizer.zero_grad()

    loss = loss_pfgm(net, batch_x, batch_y)
    loss.backward()
    optimize_fn(optimizer, net.parameters(), step=state['step'])
    state['step'] += 1
    ema.update(net.parameters())
    wandb.log({"loss train":loss.item()},step=step)

    # ---- held-out loss with EMA weights ----
    if step % config.training.eval_freq == 0:
        batch_x = X_test_sampler.sample(config.training.batch_size).to(config.device)
        batch_y = torch.randn_like(batch_x)

        with torch.no_grad():
            ema.store(net.parameters())
            ema.copy_to(net.parameters())
            try:
                eval_loss = loss_pfgm(net, batch_x, batch_y)
            finally:
                # Always put the training weights back, even if evaluation fails.
                ema.restore(net.parameters())
        wandb.log({"loss eval":eval_loss.item()},step=step)

    # ---- samples and ODE trajectories with EMA weights ----
    if step % config.training.snapshot_freq == 0:
        batch_y = torch.randn(n_show, *img_shape)
        with torch.no_grad():
            ema.store(net.parameters())
            ema.copy_to(net.parameters())
            try:
                sampling_fn = get_rk45_sampler_pfgm(sde=sde,
                                                    y=batch_y, config=config,
                                                    shape=(n_show, *img_shape),
                                                    eps=config.sampling.z_min,
                                                    device=config.device)
                sample, n, traj = sampling_fn(net, batch_y)
            finally:
                ema.restore(net.parameters())

        # Samples live in [-1, 1]; map them to [0, 255] instead of clipping x * 255.
        figures = {
            "Generated Images": plot(_to_display_uint8(sample).reshape(grid_shape)),
            "Init Images": plot(_to_display_uint8(batch_y).reshape(grid_shape)),
            "Trajectories": plot_trajectory(traj),
        }
        for key, fig in figures.items():
            wandb.log({key: fig}, step=step)
            # Close after logging, otherwise every snapshot leaks three open figures.
            plt.close(fig)

    # ---- FID ----
    if step % config.training.fid_freq == 0 and step>1 :
        # NOTE(review): FID is computed with the raw training weights, while the
        # snapshots above use EMA weights. Confirm which one is intended.
        with torch.no_grad():
            try:
                mu,sigma = get_pushed_loader_stats(net, sde, config, batch_size=config.training.fid_batch_size,
                                                   verbose=False, device='cuda',
                                                   use_downloaded_weights=False)
            finally:
                unfreeze(net)  # get_pushed_loader_stats freezes `net`
            fid = calculate_frechet_distance(mu,sigma,mu_data,sigma_data)
        wandb.log({"FID":fid},step=step)