In [14]:
import os
import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm

%matplotlib inline

In [2]:
from graph_bridges.models.backward_rates.backward_rate import GaussianTargetRateImageX0PredEMA
from graph_bridges.data.dataloaders import DoucetTargetData
from graph_bridges.models.samplers.sampling import ReferenceProcess
from graph_bridges.models.reference_process.ctdd_reference import ReferenceProcess
from graph_bridges.models.losses.ctdd_losses import GenericAux
from graph_bridges.models.schedulers.scheduling_ctdd import CTDDScheduler

from graph_bridges.models.pipelines.pipelines_utils import create_pipelines
from graph_bridges.models.schedulers.scheduling_utils import create_scheduler
from graph_bridges.models.backward_rates.backward_rate_utils import create_model
from graph_bridges.models.reference_process.reference_process_utils import create_reference
from graph_bridges.data.dataloaders_utils import create_dataloader
from graph_bridges.models.losses.loss_utils import create_loss

from graph_bridges.configs.graphs.lobster.config_base import BridgeConfig

In [3]:
# Seed torch (and numpy's legacy global RNG, in case project samplers draw from
# it) before any data, model or noise is created, so Restart & Run All
# reproduces the same batches, times and weights.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the LOBSTER graph-bridge configuration and the torch device it names.
config = BridgeConfig()
device = torch.device(config.device)

In [4]:
#=================================================================
# CREATE OBJECTS FROM CONFIGURATION
# Every factory receives the same (config, device) pair. The annotations record
# the concrete classes expected for this setup; they are not enforced at runtime.

data_dataloader: DoucetTargetData = create_dataloader(config, device)
model: GaussianTargetRateImageX0PredEMA = create_model(config, device)
reference_process: ReferenceProcess = create_reference(config, device)
loss: GenericAux = create_loss(config, device)
scheduler: CTDDScheduler = create_scheduler(config, device)

Scheduler


In [5]:
#=================================================================
# Draw one batch of target data, then insert two singleton axes right after
# the batch dimension (same result as unsqueezing dim 1 twice).
sample_ = data_dataloader.sample(config.number_of_paths, device)
minibatch = sample_[:, None, None]

In [16]:
# TIME ===========================================================
# One diffusion time per batch element: torch.rand is uniform on [0, 1),
# so ts is uniform on [config.loss.min_time, 1).
B = minibatch.shape[0]
ts = torch.rand((B,), device=device) * (1.0 - config.loss.min_time) + config.loss.min_time
#==========

# Forward (noising) process, requested first as a tuple and then as a dict.
# NOTE(review): add_noise is called twice with the same arguments; if it samples
# noise internally, `scheduler_noise_output` comes from a different draw than
# x_t / x_tilde and each call advances the torch RNG. Confirm before comparing them.
x_t, x_tilde, qt0, rate = scheduler.add_noise(minibatch,reference_process,ts,device,return_dict=False)
scheduler_noise_output = scheduler.add_noise(minibatch,reference_process,ts,device,return_dict=True)

# Model outputs, again computed twice (tuple form, then return_dict=True form).
# NOTE(review): the second call repeats the forward pass with identical inputs;
# presumably it is kept only to inspect the dict structure. It can likely be dropped.
x_logits,p0t_reg,p0t_sig,reg_x = model.forward(minibatch,ts,x_tilde)
model_forward_output = model.forward(minibatch,ts,x_tilde,return_dict=True)

# Evaluate the loss object (GenericAux) on this single batch.
loss_ = loss.calc_loss(minibatch,x_tilde,qt0,rate,x_logits,reg_x,p0t_sig,p0t_reg,device)

In [9]:
# Finalise the configuration object.
# NOTE(review): the effect of initialize() is not visible in this notebook;
# presumably it prepares result paths such as config.best_model_path. Confirm.
config.initialize()

In [12]:
#config.best_model_path
#config.sinkhorn_plot_path

'C:\\Users\\cesar\\Desktop\\Projects\\DiffusiveGenerativeModelling\\Codes\\graph-bridges\\results\\graph\\lobster\\testing\\sinkhorn_{0}.tr'

In [17]:
from diffusers.optimization import get_cosine_schedule_with_warmup

# Minibatches drawn per epoch; the LR schedule below is sized from it.
num_batches = 20

# AdamW on all model parameters, using the learning rate from the config.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.optimizer.learning_rate,
)

# Linear warm-up followed by cosine decay over every step of the training loop.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.optimizer.lr_warmup_steps,
    num_training_steps=config.optimizer.num_epochs * num_batches,
)

In [18]:
# Re-evaluate the loss on the batch prepared earlier in the notebook.
# NOTE(review): every argument is the same object passed to the earlier
# calc_loss call, so this recomputes that value. It looks like a leftover
# sanity check and can likely be removed.
loss_ = loss.calc_loss(minibatch,x_tilde,qt0,rate,x_logits,reg_x,p0t_sig,p0t_reg,device)

In [22]:
global_step = 0

# Train the backward-rate model. Each step draws a fresh batch of target data,
# samples one diffusion time per element, noises the batch through the
# reference process and takes an optimizer step on the returned loss.
for epoch in range(config.optimizer.num_epochs):
    # The context manager closes the bar at the end of every epoch. Unclosed bars
    # were what left stale, nested progress lines ("[A" residue) in the output.
    with tqdm(total=num_batches, desc=f"Epoch {epoch}") as progress_bar:
        for step in range(num_batches):
            # Fresh batch of target data with two singleton axes after the batch dim.
            sample_ = data_dataloader.sample(config.number_of_paths, device)
            minibatch = sample_.unsqueeze(1).unsqueeze(1)
            B = minibatch.shape[0]

            # One diffusion time per batch element, uniform on [min_time, 1).
            ts = torch.rand((B,), device=device) * (1.0 - config.loss.min_time) + config.loss.min_time

            # Forward (noising) process driven by the reference process.
            x_t, x_tilde, qt0, rate = scheduler.add_noise(minibatch,
                                                          reference_process,
                                                          ts,
                                                          device,
                                                          return_dict=False)

            # Model outputs for the data batch, the times and the noised sample.
            x_logits, p0t_reg, p0t_sig, reg_x = model.forward(minibatch,
                                                              ts,
                                                              x_tilde)

            loss_ = loss.calc_loss(minibatch,
                                   x_tilde,
                                   qt0,
                                   rate,
                                   x_logits,
                                   reg_x,
                                   p0t_sig,
                                   p0t_reg,
                                   device)

            # Clear gradients before backward, so a previously interrupted run of
            # this cell cannot carry stale gradients into the first update.
            optimizer.zero_grad()
            loss_.backward()
            optimizer.step()
            lr_scheduler.step()

            progress_bar.update(1)
            logs = {"loss": loss_.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step}
            progress_bar.set_postfix(**logs)
            global_step += 1


Epoch 0:   5%|███▊                                                                        | 1/20 [00:09<03:05,  9.76s/it][A

Epoch 0:   0%|                                                                                    | 0/20 [00:00<?, ?it/s][A
Epoch 0:   5%|██▍                                              | 1/20 [00:00<00:00, 32.55it/s, loss=107, lr=4e-7, step=0][A
Epoch 0:  10%|████▉                                            | 2/20 [00:00<00:00, 32.72it/s, loss=115, lr=6e-7, step=1][A
Epoch 0:  15%|███████▎                                         | 3/20 [00:00<00:00, 36.70it/s, loss=107, lr=8e-7, step=2][A
Epoch 0:  20%|█████████▊                                       | 4/20 [00:00<00:00, 37.06it/s, loss=107, lr=8e-7, step=2][A
Epoch 0:  20%|█████████▊                                       | 4/20 [00:00<00:00, 37.06it/s, loss=107, lr=1e-6, step=3][A
Epoch 0:  25%|███████████▊                                   | 5/20 [00:00<00:00, 37.06it/s, loss=106, lr=1.2e-6, step=4][

Epoch 4:  35%|███████████████▋                             | 7/20 [00:00<00:00, 36.19it/s, loss=111, lr=1.76e-5, step=86][A
Epoch 4:  40%|██████████████████                           | 8/20 [00:00<00:00, 36.08it/s, loss=111, lr=1.76e-5, step=86][A
Epoch 4:  40%|██████████████████                           | 8/20 [00:00<00:00, 36.08it/s, loss=100, lr=1.78e-5, step=87][A
Epoch 4:  45%|████████████████████▋                         | 9/20 [00:00<00:00, 36.08it/s, loss=106, lr=1.8e-5, step=88][A
Epoch 4:  50%|██████████████████████                      | 10/20 [00:00<00:00, 36.08it/s, loss=102, lr=1.82e-5, step=89][A
Epoch 4:  55%|████████████████████████▏                   | 11/20 [00:00<00:00, 36.08it/s, loss=104, lr=1.84e-5, step=90][A
Epoch 4:  60%|██████████████████████████▍                 | 12/20 [00:00<00:00, 36.94it/s, loss=104, lr=1.84e-5, step=90][A
Epoch 4:  60%|██████████████████████████▍                 | 12/20 [00:00<00:00, 36.94it/s, loss=109, lr=1.86e-5, step=91][A


Epoch 8:  70%|██████████████████████████████▊             | 14/20 [00:00<00:00, 35.45it/s, loss=105, lr=3.5e-5, step=173][A
Epoch 8:  75%|███████████████████████████████▌          | 15/20 [00:00<00:00, 35.45it/s, loss=97.9, lr=3.52e-5, step=174][A
Epoch 8:  80%|█████████████████████████████████▌        | 16/20 [00:00<00:00, 35.60it/s, loss=97.9, lr=3.52e-5, step=174][A
Epoch 8:  80%|██████████████████████████████████▍        | 16/20 [00:00<00:00, 35.60it/s, loss=107, lr=3.54e-5, step=175][A
Epoch 8:  85%|███████████████████████████████████▋      | 17/20 [00:00<00:00, 35.60it/s, loss=95.7, lr=3.56e-5, step=176][A
Epoch 8:  90%|██████████████████████████████████████▋    | 18/20 [00:00<00:00, 35.60it/s, loss=112, lr=3.58e-5, step=177][A
Epoch 8:  95%|█████████████████████████████████████████▊  | 19/20 [00:00<00:00, 35.60it/s, loss=105, lr=3.6e-5, step=178][A
Epoch 8: 100%|████████████████████████████████████████████| 20/20 [00:00<00:00, 35.69it/s, loss=105, lr=3.6e-5, step=178][A


Epoch 13: 100%|██████████████████████████████████████████| 20/20 [00:00<00:00, 41.67it/s, loss=111, lr=5.62e-5, step=279]
Epoch 13: 100%|██████████████████████████████████████████| 20/20 [00:00<00:00, 40.77it/s, loss=111, lr=5.62e-5, step=279][A

Epoch 14:   0%|                                                                                   | 0/20 [00:00<?, ?it/s][A
Epoch 14:   5%|██▏                                        | 1/20 [00:00<00:00, 33.95it/s, loss=103, lr=5.64e-5, step=280][A
Epoch 14:  10%|████▏                                     | 2/20 [00:00<00:00, 35.14it/s, loss=78.1, lr=5.66e-5, step=281][A
Epoch 14:  15%|██████▍                                    | 3/20 [00:00<00:00, 35.94it/s, loss=107, lr=5.68e-5, step=282][A
Epoch 14:  20%|████████▌                                  | 4/20 [00:00<00:00, 38.55it/s, loss=107, lr=5.68e-5, step=282][A
Epoch 14:  20%|████████▌                                  | 4/20 [00:00<00:00, 38.55it/s, loss=97.4, lr=5.7e-5, step=283][A
Ep

Epoch 18:  35%|██████████████▋                           | 7/20 [00:00<00:00, 38.81it/s, loss=85.9, lr=7.36e-5, step=366][A
Epoch 18:  40%|█████████████████▏                         | 8/20 [00:00<00:00, 38.81it/s, loss=102, lr=7.38e-5, step=367][A
Epoch 18:  45%|███████████████████▎                       | 9/20 [00:00<00:00, 39.22it/s, loss=102, lr=7.38e-5, step=367][A
Epoch 18:  45%|███████████████████▎                       | 9/20 [00:00<00:00, 39.22it/s, loss=95.5, lr=7.4e-5, step=368][A
Epoch 18:  50%|████████████████████▌                    | 10/20 [00:00<00:00, 39.22it/s, loss=91.3, lr=7.42e-5, step=369][A
Epoch 18:  55%|███████████████████████▋                   | 11/20 [00:00<00:00, 39.22it/s, loss=78, lr=7.44e-5, step=370][A
Epoch 18:  60%|████████████████████████▌                | 12/20 [00:00<00:00, 39.22it/s, loss=96.5, lr=7.46e-5, step=371][A
Epoch 18:  65%|██████████████████████████▋              | 13/20 [00:00<00:00, 38.11it/s, loss=96.5, lr=7.46e-5, step=371][A


Epoch 22:  75%|███████████████████████████████▌          | 15/20 [00:00<00:00, 35.11it/s, loss=106, lr=9.12e-5, step=454][A
Epoch 22:  80%|█████████████████████████████████▌        | 16/20 [00:00<00:00, 35.52it/s, loss=106, lr=9.12e-5, step=454][A
Epoch 22:  80%|████████████████████████████████▊        | 16/20 [00:00<00:00, 35.52it/s, loss=83.5, lr=9.14e-5, step=455][A
Epoch 22:  85%|███████████████████████████████████▋      | 17/20 [00:00<00:00, 35.52it/s, loss=113, lr=9.16e-5, step=456][A
Epoch 22:  90%|█████████████████████████████████████▊    | 18/20 [00:00<00:00, 35.52it/s, loss=108, lr=9.18e-5, step=457][A
Epoch 22:  95%|████████████████████████████████████████▊  | 19/20 [00:00<00:00, 35.52it/s, loss=103, lr=9.2e-5, step=458][A
Epoch 22: 100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 35.96it/s, loss=103, lr=9.2e-5, step=458][A
Epoch 22: 100%|█████████████████████████████████████████| 20/20 [00:00<00:00, 35.10it/s, loss=98.2, lr=9.22e-5, step=459][A


Epoch 28:   0%|                                                                                   | 0/20 [00:00<?, ?it/s][A
Epoch 28:   5%|██                                        | 1/20 [00:00<00:00, 34.41it/s, loss=82.9, lr=9.63e-5, step=560][A
Epoch 28:  10%|████▏                                     | 2/20 [00:00<00:00, 36.23it/s, loss=99.1, lr=9.61e-5, step=561][A
Epoch 28:  15%|██████▍                                    | 3/20 [00:00<00:00, 35.32it/s, loss=92.9, lr=9.6e-5, step=562][A
Epoch 28:  20%|████████▌                                  | 4/20 [00:00<00:00, 34.77it/s, loss=92.9, lr=9.6e-5, step=562][A
Epoch 28:  20%|████████▍                                 | 4/20 [00:00<00:00, 34.77it/s, loss=85.8, lr=9.59e-5, step=563][A
Epoch 28:  25%|██████████▊                                | 5/20 [00:00<00:00, 34.77it/s, loss=104, lr=9.58e-5, step=564][A
Epoch 28:  30%|████████████▌                             | 6/20 [00:00<00:00, 34.77it/s, loss=89.3, lr=9.56e-5, step=565][A


Epoch 32:  45%|███████████████████▎                       | 9/20 [00:00<00:00, 32.52it/s, loss=109, lr=7.94e-5, step=648][A
Epoch 32:  50%|████████████████████▌                    | 10/20 [00:00<00:00, 32.52it/s, loss=89.3, lr=7.91e-5, step=649][A
Epoch 32:  55%|██████████████████████▌                  | 11/20 [00:00<00:00, 32.52it/s, loss=88.5, lr=7.89e-5, step=650][A
Epoch 32:  60%|████████████████████████▌                | 12/20 [00:00<00:00, 34.26it/s, loss=88.5, lr=7.89e-5, step=650][A
Epoch 32:  60%|█████████████████████████▏                | 12/20 [00:00<00:00, 34.26it/s, loss=101, lr=7.86e-5, step=651][A
Epoch 32:  65%|██████████████████████████▋              | 13/20 [00:00<00:00, 34.26it/s, loss=86.2, lr=7.84e-5, step=652][A
Epoch 32:  70%|████████████████████████████▋            | 14/20 [00:00<00:00, 34.26it/s, loss=95.1, lr=7.81e-5, step=653][A
Epoch 32:  75%|███████████████████████████████▌          | 15/20 [00:00<00:00, 34.26it/s, loss=101, lr=7.78e-5, step=654][A


Epoch 36:  85%|██████████████████████████████████▊      | 17/20 [00:00<00:00, 36.47it/s, loss=94.7, lr=5.38e-5, step=736][A
Epoch 36:  90%|██████████████████████████████████████▋    | 18/20 [00:00<00:00, 36.47it/s, loss=96, lr=5.35e-5, step=737][A
Epoch 36:  95%|██████████████████████████████████████▉  | 19/20 [00:00<00:00, 36.47it/s, loss=95.6, lr=5.31e-5, step=738][A
Epoch 36: 100%|█████████████████████████████████████████| 20/20 [00:00<00:00, 35.78it/s, loss=95.6, lr=5.31e-5, step=738][A
Epoch 36: 100%|█████████████████████████████████████████| 20/20 [00:00<00:00, 35.47it/s, loss=85.6, lr=5.28e-5, step=739][A
Epoch 37: 100%|█████████████████████████████████████████| 20/20 [00:00<00:00, 38.42it/s, loss=83.1, lr=4.65e-5, step=759]
Epoch 37: 100%|█████████████████████████████████████████| 20/20 [00:00<00:00, 39.02it/s, loss=83.1, lr=4.65e-5, step=759][A

Epoch 38:   0%|                                                                                   | 0/20 [00:00<?, ?it/s][A
Ep

Epoch 42:  15%|██████▎                                   | 3/20 [00:00<00:00, 35.76it/s, loss=81.3, lr=2.22e-5, step=842][A
Epoch 42:  20%|████████▍                                 | 4/20 [00:00<00:00, 37.79it/s, loss=81.3, lr=2.22e-5, step=842][A
Epoch 42:  20%|████████▍                                 | 4/20 [00:00<00:00, 37.79it/s, loss=85.3, lr=2.19e-5, step=843][A
Epoch 42:  25%|██████████▌                               | 5/20 [00:00<00:00, 37.79it/s, loss=81.2, lr=2.16e-5, step=844][A
Epoch 42:  30%|████████████▌                             | 6/20 [00:00<00:00, 37.79it/s, loss=84.3, lr=2.14e-5, step=845][A
Epoch 42:  35%|██████████████▋                           | 7/20 [00:00<00:00, 37.79it/s, loss=93.8, lr=2.11e-5, step=846][A
Epoch 42:  40%|████████████████▊                         | 8/20 [00:00<00:00, 38.94it/s, loss=93.8, lr=2.11e-5, step=846][A
Epoch 42:  40%|████████████████▊                         | 8/20 [00:00<00:00, 38.94it/s, loss=90.5, lr=2.09e-5, step=847][A


Epoch 46:  55%|██████████████████████▌                  | 11/20 [00:00<00:00, 34.92it/s, loss=96.5, lr=4.49e-6, step=930][A
Epoch 46:  60%|████████████████████████▌                | 12/20 [00:00<00:00, 33.47it/s, loss=96.5, lr=4.49e-6, step=930][A
Epoch 46:  60%|████████████████████████▌                | 12/20 [00:00<00:00, 33.47it/s, loss=91.9, lr=4.37e-6, step=931][A
Epoch 46:  65%|██████████████████████████▋              | 13/20 [00:00<00:00, 33.47it/s, loss=70.3, lr=4.24e-6, step=932][A
Epoch 46:  70%|██████████████████████████████             | 14/20 [00:00<00:00, 33.47it/s, loss=87, lr=4.11e-6, step=933][A
Epoch 46:  75%|██████████████████████████████▊          | 15/20 [00:00<00:00, 33.47it/s, loss=85.9, lr=3.99e-6, step=934][A
Epoch 46:  80%|████████████████████████████████▊        | 16/20 [00:00<00:00, 35.26it/s, loss=85.9, lr=3.99e-6, step=934][A
Epoch 46:  80%|█████████████████████████████████▌        | 16/20 [00:00<00:00, 35.26it/s, loss=109, lr=3.87e-6, step=935][A
