In [14]:
import os
import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm

%matplotlib inline

In [2]:
from graph_bridges.models.backward_rates.backward_rate import GaussianTargetRateImageX0PredEMA
from graph_bridges.data.dataloaders import DoucetTargetData
from graph_bridges.models.samplers.sampling import ReferenceProcess
from graph_bridges.models.reference_process.ctdd_reference import ReferenceProcess
from graph_bridges.models.losses.ctdd_losses import GenericAux
from graph_bridges.models.schedulers.scheduling_ctdd import CTDDScheduler

from graph_bridges.models.pipelines.pipelines_utils import create_pipelines
from graph_bridges.models.schedulers.scheduling_utils import create_scheduler
from graph_bridges.models.backward_rates.backward_rate_utils import create_model
from graph_bridges.models.reference_process.reference_process_utils import create_reference
from graph_bridges.data.dataloaders_utils import create_dataloader
from graph_bridges.models.losses.loss_utils import create_loss

from graph_bridges.configs.graphs.lobster.config_base import BridgeConfig

In [3]:
config = BridgeConfig()
device = torch.device(config.device)

In [4]:
#=================================================================
# CREATE OBJECTS FROM CONFIGURATION

data_dataloader: DoucetTargetData
model : GaussianTargetRateImageX0PredEMA
reference_process: ReferenceProcess
loss : GenericAux
scheduler:CTDDScheduler

data_dataloader = create_dataloader(config,device)
model = create_model(config,device)
reference_process = create_reference(config,device)
loss = create_loss(config,device)
scheduler = create_scheduler(config,device)

Scheduler


In [5]:
#=================================================================
sample_ = data_dataloader.sample(config.number_of_paths, device)
minibatch = sample_.unsqueeze(1).unsqueeze(1)

In [16]:
# TIME ===========================================================
#ts = torch.rand((minibatch.shape[0],), device=device) * (1.0 - config.loss.min_time) + config.loss.min_time
B = minibatch.shape[0]
ts = torch.rand((B,), device=device) * (1.0 - config.loss.min_time) + config.loss.min_time
#==========

x_t, x_tilde, qt0, rate = scheduler.add_noise(minibatch,reference_process,ts,device,return_dict=False)
scheduler_noise_output = scheduler.add_noise(minibatch,reference_process,ts,device,return_dict=True)

x_logits,p0t_reg,p0t_sig,reg_x = model.forward(minibatch,ts,x_tilde)
model_forward_output = model.forward(minibatch,ts,x_tilde,return_dict=True)

loss_ = loss.calc_loss(minibatch,x_tilde,qt0,rate,x_logits,reg_x,p0t_sig,p0t_reg,device)

In [9]:
config.initialize()

In [12]:
#config.best_model_path
#config.sinkhorn_plot_path

'C:\\Users\\cesar\\Desktop\\Projects\\DiffusiveGenerativeModelling\\Codes\\graph-bridges\\results\\graph\\lobster\\testing\\sinkhorn_{0}.tr'

In [17]:
from diffusers.optimization import get_cosine_schedule_with_warmup
num_batches = 20

optimizer = torch.optim.AdamW(model.parameters(), lr=config.optimizer.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.optimizer.lr_warmup_steps,
    num_training_steps=(num_batches * config.optimizer.num_epochs),
)

In [18]:
loss_ = loss.calc_loss(minibatch,x_tilde,qt0,rate,x_logits,reg_x,p0t_sig,p0t_reg,device)

In [20]:
global_step = 0

# Now you train the model
for epoch in range(config.optimizer.num_epochs):
    progress_bar = tqdm(total=num_batches)
    progress_bar.set_description(f"Epoch {epoch}")

    for step in range(num_batches):
        # Obtain data
        sample_ = data_dataloader.sample(config.number_of_paths, device)
        minibatch = sample_.unsqueeze(1).unsqueeze(1)
        B = minibatch.shape[0]
        
        # Sample a random timestep for each image
        ts = torch.rand((B,), device=device) * (1.0 - config.loss.min_time) + config.loss.min_time
        
        # Sample noise to add to the images
        # Add noise to the clean images according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        
        #noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

        x_t, x_tilde, qt0, rate = scheduler.add_noise(minibatch,
                                                      reference_process,
                                                      ts,
                                                      device,
                                                      return_dict=False)
        
        #scheduler_noise_output = scheduler.add_noise(minibatch,
        #                                             reference_process,
        #                                             ts,
        #                                             device,
        #                                             return_dict=True)
        # Predict the noise residual
        #noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
        #loss = F.mse_loss(noise_pred, noise)
        #accelerator.backward(loss)
            
        x_logits,p0t_reg,p0t_sig,reg_x = model.forward(minibatch,
                                                       ts,
                                                       x_tilde)
            
        #model_forward_output = model.forward(minibatch,
        #                                     ts,
        #                                     x_tilde,
        #                                     return_dict=True)
            
        loss_ = loss.calc_loss(minibatch,
                               x_tilde,
                               qt0,
                               rate,
                               x_logits,
                               reg_x,
                               p0t_sig,
                               p0t_reg,
                               device)
        
        loss_.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        progress_bar.update(1)
        logs = {"loss": loss.detach().item(), 
                "lr": lr_scheduler.get_last_lr()[0], 
                "step": global_step}
        
        progress_bar.set_postfix(**logs)
        #accelerator.log(logs, step=global_step)
        global_step += 1


Epoch 0:   0%|                                                                                    | 0/20 [00:11<?, ?it/s][A

Epoch 0:   0%|                                                                                    | 0/20 [00:00<?, ?it/s][A

NameError: name 'noise_scheduler' is not defined