In [24]:
import torch.nn as nn
import tensorflow as tf
import functools
import torch
import numpy as np
from scipy import special as sf
from scipy.stats import binom as spbinom
#from numba import njit,float64,int64,jit
#from numba.types import UniTuple
from matplotlib import pyplot as plt
#import numba_scipy
import gc
import os
from utils import save_checkpoint_withEval as save_checkpoint
from utils import restore_checkpoint_withEval as restore_checkpoint
from loadDataPipeline import generateData
from torch.utils.data import DataLoader

from scipy.optimize import bisect
from scipy.stats import binom

from torch import Tensor 

import sys
sys.path.append('score_sde_pytorch/')

softplus = nn.functional.softplus

In [6]:
from torch.utils.cpp_extension import load
#from score_sde_pytorch.models import ncsnpp
from configs.vp import cifar10_ncsnpp_continuous as configLoader
from models import utils as mutils
from models.ema import ExponentialMovingAverage

### Loading the ML model from Song et al.

In [28]:
# Start from the `configs.vp.cifar10_ncsnpp_continuous` config and override the
# fields this notebook needs.
config = configLoader.get_config()
config.training.batch_size = 128
# Interval (in steps) at which the training loop overwrites the resumable
# "checkpoints-meta" checkpoint.
config.training.snapshot_freq_for_preemption = 10
# Interval (in steps) at which the training loop writes a numbered checkpoint.
config.training.snapshot_freq = 50_000
# Interval (in steps) at which the training loop evaluates the EMA weights and prints the losses.
config.training.log_freq = 100


config.data.dataset = 'CELEBA'
config.data.image_size = 64
# NOTE(review): absolute, machine-specific path; consider deriving it from a
# configurable data directory so the notebook runs elsewhere.
config.data.data_path = '/project/smartFRACs/jesantos/generative-discrete-state-diffusion-models/dataBuffer/celebA/celebA_64_64.npy'

In [13]:
# Number of observation times; blackout_loader.__init__ stores it as self.T.
config.model.num_scales = 1_000
# Final time of the forward process; blackout_loader.__init__ stores it as self.t_end.
config.model.t_end = 15
# NOTE(review): presumably the number of discrete pixel intensity levels (0..255) — confirm.
config.model.num_bins = 256

### Solve for the observation times (noise schedule) and the forward solution, or load previously saved ones

In [27]:
def f(x):
    """Return the log-odds (logit) of ``x``, i.e. ``log(x / (1 - x))``.

    ``x`` may be a scalar or a NumPy array; the operation is element-wise.
    No range checking is performed on ``x``.
    """
    odds = x / (1 - x)
    return np.log(odds)

In [None]:
from torch.utils.data import DataLoader, Dataset

class blackout_loader(Dataset):
    """Dataset whose items are complete, already-noised training batches.

    Every ``__getitem__`` call draws ``batch_size`` random images, draws one
    observation-time index per image, samples the corrupted state ``n_t`` of
    every pixel by inverse-CDF lookup in a precomputed table, and returns the
    rescaled ``n_t``, the regression target ``n_0 - n_t`` and the time indices.
    Because an item is already a batch, wrap this dataset in a ``DataLoader``
    with ``batch_size=None``.

    Attributes set by the constructor
    ---------------------------------
    observation_times : (T,) increasing times, the last one equal to ``t_end``.
    pt                : (T,) ``exp(-observation_times)``.
    cumulative        : (T + 1, num_bins, num_bins); ``cumulative[k, n, n0]`` is
                        ``P(n_t <= n | n_0 = n0)`` with ``n_t ~ Binomial(n0, p)``.
                        Index ``k = 0`` is the undisturbed state (t = 0) and
                        ``k = i + 1`` corresponds to ``observation_times[i]``.
    table             : (num_bins, num_bins, T) read-only view with
                        ``table[n0, n, t] = max(n0 - n, 0)``.
    sampling_prob     : (T,) uniform probability of picking each time index.
    weights           : (T,) ``pt * diff([0, observation_times]) / sampling_prob``.
    data              : array loaded from ``config.data.data_path``.
    """

    def __init__(self, config):
        """Read the hyper-parameters from ``config`` and precompute all tables.

        Reads ``config.model.num_scales``, ``config.model.t_end``,
        ``config.training.batch_size``, ``config.data.data_path`` and, when
        present, ``config.model.num_bins`` (falls back to 256).
        """
        self.config = config
        self.T = config.model.num_scales
        self.t_end = config.model.t_end
        self.batch_size = config.training.batch_size
        self.num_bins = getattr(config.model, 'num_bins', 256)

        self.forward_solution()
        self.compute_weights()
        self.load_data()

    def forward_solution(self):
        """Build the observation times, the target table and the forward CDF table."""
        num_bins = getattr(self, 'num_bins', 256)
        T = self.T

        # The grid is uniform in logit(x) between logit(1 - x_end) and
        # logit(x_end), with x = exp(-t) and x_end = exp(-t_end).
        x_end = np.exp(-self.t_end)
        logit_end = np.log(x_end / (1.0 - x_end))
        f_grid = np.linspace(-logit_end, logit_end, T)

        # Solving logit(x) = g for x gives x = sigmoid(g), hence
        # t = -log(sigmoid(g)) = log(1 + exp(-g)).  The closed form replaces a
        # per-point bisection and is accurate even for the tiny first times.
        self.observation_times = np.logaddexp(0.0, -f_grid)
        self.pt = np.exp(-self.observation_times)

        support = np.arange(num_bins)

        # table[n0, n] = n0 - n for n < n0 and 0 otherwise.  The values do not
        # depend on time, so the T axis is a zero-copy broadcast view instead
        # of T physical copies.
        table = np.maximum(support[:, None] - support[None, :], 0).astype(np.float64)
        self.table = np.broadcast_to(table[:, :, None], (num_bins, num_bins, T))

        # Axis 1 indexes the state n, axis 2 the initial value n0.
        cumulative = np.empty((T + 1, num_bins, num_bins))
        cumulative[0] = np.cumsum(np.eye(num_bins), axis=0)
        for t in range(T):
            # binom.cdf returns exactly 1.0 for n >= n0, so the inverse-CDF
            # lookup in __getitem__ always finds a valid state.
            cumulative[t + 1] = binom.cdf(support[:, None], support[None, :], self.pt[t])
        self.cumulative = cumulative

        return

    def compute_weights(self):
        """Set the time-sampling distribution and the per-time loss weights."""
        e_observation_times = np.insert(self.observation_times, 0, 0)

        # Uniform sampling over the T observation times.
        self.sampling_prob = np.ones_like(self.pt) / self.pt.size
        self.weights = self.pt * np.diff(e_observation_times) / self.sampling_prob

        return

    def load_data(self):
        """Load the image array and derive how many batches make up one epoch."""
        self.data = np.load(self.config.data.data_path)
        self.training_ims = self.data.shape[0]
        self.num_batches = self.training_ims // self.batch_size
        return

    def __len__(self):
        """Number of batches reported per epoch."""
        return self.num_batches

    def __getitem__(self, idx):
        """Return one freshly sampled training batch; ``idx`` is ignored.

        Returns
        -------
        x : float32 tensor (B, C, H, W)
            ``(n_t - mean) / width`` with ``mean = (num_bins - 1) / 2 * p_t``.
        birth_rate : float32 tensor (B, C, H, W)
            ``n_0 - n_t`` for every pixel.
        t_ind : int64 tensor (B,)
            Zero-based observation-time index of every image.

        Assumes ``self.data`` is channels-last ``(N, H, W, C)`` with integer
        values in ``[0, num_bins)`` — TODO confirm against the .npy file.
        """
        ims_ind = torch.randperm(self.training_ims)[:self.batch_size].numpy()
        ims = np.asarray(self.data[ims_ind]).astype(np.int64)
        n = ims.shape[0]

        t_ind = np.random.choice(self.T, size=(n, 1, 1, 1), p=self.sampling_prob)

        # +1 because cumulative[0] is the t = 0 identity; cumulative[i + 1]
        # belongs to observation_times[i] and pt[i].
        # Result shape: (B, H, W, C, num_bins).  For large batches this
        # temporary is big (B*H*W*C*num_bins float64 values).
        cp = self.cumulative[t_ind + 1, :, ims]
        u = torch.rand(ims.shape).numpy()[..., None]
        # First state whose CDF exceeds u, i.e. an inverse-CDF sample.
        nt = np.argmax(u < cp, axis=-1)

        birth_rate = self.table[ims, nt, t_ind]

        p = self.pt[t_ind]
        half_range = (getattr(self, 'num_bins', 256) - 1) / 2
        mean_v = half_range * p
        # NOTE(review): `width` was undefined in the original.  Half the
        # intensity range is used because the visualisation cell inverts the
        # scaling with 255 * (x + 1) / 2 — confirm this is the intended scale.
        width = half_range

        x = torch.from_numpy(((nt - mean_v) / width).astype(np.float32))
        birth = torch.from_numpy(birth_rate.astype(np.float32))
        t_out = torch.from_numpy(t_ind.reshape(n).astype(np.int64))

        return x.permute((0, 3, 1, 2)), birth.permute((0, 3, 1, 2)), t_out

    def one_batch(self):
        """Placeholder; currently returns ``None``."""
        return None
        
        

In [12]:
# blackout_loader.__init__ only accepts `config`; the batching/worker options
# belong to torch's DataLoader.  batch_size=None disables automatic batching,
# because each dataset item is already a complete batch.
train_loader = DataLoader(blackout_loader(config), batch_size=None, pin_memory=True, num_workers=16, shuffle=True)

### Visualize one batch

In [None]:
# NOTE(review): `train_iter`, `generateBatchDataGPU` and `T` are not defined in
# the visible cells of this notebook — confirm where they come from before
# running on a fresh kernel.
train_batch_GPU = next(train_iter)
# Calling .numpy() directly implies the batch is a CPU tensor here, despite the name.
train_batch = train_batch_GPU.numpy()

output_image_batch, brRate_batch, tIndexArray = generateBatchDataGPU(train_batch_GPU, T)

# Move axis 1 to the end — presumably channels-first -> channels-last for imshow; TODO confirm.
output_image_batch = np.transpose(output_image_batch.detach().cpu().numpy(), (0,2,3,1))
brRate_batch = np.transpose(brRate_batch.detach().cpu().numpy(), (0,2,3,1))
tIndexArray = tIndexArray.detach().cpu().numpy()

# One figure per sample: original image | noised input | birth-rate target.
for i in range(20):
    
    testImage = train_batch[i,:,:,:]
    
    
    # Map the network input back with 255*(x+1)/2 and truncate to int32.
    output_image = (255.0*(output_image_batch[i,:,:,:]+1.)/2.).astype('int32')
    birthRate = brRate_batch[i,:,:,:]
    targetTime = tIndexArray[i]
    
    fig, ax = plt.subplots(1,3, figsize=(4.8,1.5))
    
    ax[0].imshow(testImage)
    
    # Normalise by the maximum for display; skip when the image is all zeros
    # to avoid dividing by zero.
    if np.amax(output_image)!=0:
        ax[1].imshow(output_image/np.amax(output_image))
    else:
        ax[1].imshow(output_image)
        
    ax[1].set_title('$t='+str(targetTime)+'$')
    
    # Min-max normalise the target for display; skip when it is constant.
    if np.amax(birthRate)-np.amin(birthRate)!=0:
        ax[2].imshow((birthRate-np.amin(birthRate))/(np.amax(birthRate)-np.amin(birthRate)))
    else:
        ax[2].imshow(birthRate)
        
    # Hide the tick labels on all three panels.
    for j in range(3):
        
        ax[j].set_xticklabels('')
        ax[j].set_yticklabels('')
    
    fig.tight_layout()


### Instantiate an ML model to learn the transition rate

In [None]:
# Build the network from `config` and get its forward function in training mode.
score_model = mutils.create_model(config)
score_fn = mutils.get_model_fn(score_model, train=True)
optimizer = torch.optim.Adam(score_model.parameters(),lr=config.optim.lr) 

# Exponential moving average of the model parameters, used for the eval loss in the training loop.
ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)

# NOTE(review): `train_iter` is not defined in the visible cells.  The .to()
# call requires it to yield a single tensor; blackout_loader.__getitem__
# returns a 3-tuple, so confirm which iterator is meant here.
train_batch = next(train_iter).to(config.device).float()
# Presumably channels-last (N, H, W, C) -> channels-first (N, C, H, W); TODO confirm.
train_batch = train_batch.permute(0, 3, 1, 2)
imgBatch = train_batch

workdir = 'blackout-celebA64'

# Everything that is checkpointed and restored together.
state = dict(optimizer=optimizer, model=score_model, ema=ema, lossHistory=[], evalLossHistory=[], step=0)

checkpoint_dir = os.path.join(workdir, "checkpoints")
# Despite the name, this is the path of the resumable checkpoint FILE, not a directory.
checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
tf.io.gfile.makedirs(checkpoint_dir)
tf.io.gfile.makedirs(os.path.dirname(checkpoint_meta_dir))
# Presumably returns `state` unchanged when no checkpoint file exists — verify in utils.
state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
# Resume point and loss histories for the training loop.
initial_step = int(state['step'])
lossHistory = state['lossHistory']
evalLossHistory = state['evalLossHistory']

### Training

In [None]:
# NOTE(review): `train_iter`, `generateBatchDataGPU`, `weightsGPU` and `T` are
# not defined in the visible cells of this notebook — confirm they exist before
# running on a fresh kernel.

def _blackout_loss(rate, birth_rate, t_index):
    """Time-weighted mean of ``rate - birth_rate * log(rate)`` over the batch."""
    batch_weights = weightsGPU[t_index.long()].reshape([config.training.batch_size, 1, 1, 1])
    return torch.mean(batch_weights * (rate - birth_rate * torch.log(rate)))


for step in range(initial_step, config.training.n_iters):

    # Restart the iterator when the loader is exhausted.
    try:
        train_batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        train_batch = next(train_iter)

    output_image_batch, birthRate_batch, tIndexArray = generateBatchDataGPU(train_batch, T)

    optimizer.zero_grad()

    # softplus keeps the predicted rate positive so that log(y) is defined.
    y = softplus(score_fn(output_image_batch, tIndexArray))
    loss = _blackout_loss(y, birthRate_batch, tIndexArray)
    loss.backward()
    optimizer.step()

    # Update the EMA only after the optimizer step, so it tracks the freshly
    # updated weights rather than lagging one step behind.
    state['ema'].update(state['model'].parameters())

    lossHistory.append(loss.detach().cpu().numpy())

    if np.mod(step, config.training.log_freq)==0:

        # Temporarily swap in the EMA weights and score the same batch.
        ema.store(score_model.parameters())
        ema.copy_to(score_model.parameters())

        # No gradients are needed for this monitoring pass.
        with torch.no_grad():
            y_ema = softplus(score_fn(output_image_batch, tIndexArray))
            eval_loss = _blackout_loss(y_ema, birthRate_batch, tIndexArray)

        ema.restore(score_model.parameters())

        evalLossHistory.append(eval_loss.detach().cpu().numpy())

        print(f'current iter: {step}, loss: {lossHistory[-1]}, eval loss: {evalLossHistory[-1]}')

    # Refresh the bookkeeping BEFORE saving, so a checkpoint records this step
    # as completed and a resumed run continues with the next one.
    state['step'] = step + 1
    state['lossHistory'] = lossHistory
    state['evalLossHistory'] = evalLossHistory

    if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
        save_checkpoint(checkpoint_meta_dir, state)

    # NOTE(review): `step == n_iters` can never be true inside
    # range(initial_step, n_iters); a final checkpoint is only written when the
    # last step is a multiple of snapshot_freq.  Confirm whether the range
    # should run to n_iters + 1.
    if (step != 0 and step % config.training.snapshot_freq == 0) or step == config.training.n_iters:
        save_step = step // config.training.snapshot_freq
        save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'), state)

    gc.collect()
    torch.cuda.empty_cache()