In [1]:
import os
import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from tqdm import tqdm
import tensorboard

%matplotlib inline

In [1]:
from graph_bridges.models.backward_rates.backward_rate import GaussianTargetRateImageX0PredEMA
from graph_bridges.data.dataloaders import DoucetTargetData, GraphSpinsDataLoader
from graph_bridges.models.samplers.sampling import ReferenceProcess
from graph_bridges.models.reference_process.ctdd_reference import ReferenceProcess
from graph_bridges.models.losses.ctdd_losses import GenericAux
from graph_bridges.models.schedulers.scheduling_ctdd import CTDDScheduler
from graph_bridges.data.dataloaders_config import GraphSpinsDataLoaderConfig

    
from graph_bridges.models.pipelines.pipelines_utils import create_pipelines
from graph_bridges.models.schedulers.scheduling_utils import create_scheduler
from graph_bridges.models.backward_rates.backward_rate_utils import create_model
from graph_bridges.models.reference_process.reference_process_utils import create_reference
from graph_bridges.data.dataloaders_utils import create_dataloader
from graph_bridges.models.losses.loss_utils import create_loss

from graph_bridges.configs.graphs.lobster.config_mlp import BridgeMLPConfig
from graph_bridges.configs.graphs.lobster.config_base import BridgeConfig

In [2]:
# Lobster-graph bridge configuration: start from the base BridgeConfig, then replace its
# data section with a graph-spins dataloader config and its model section with the one
# from the MLP variant of the config.
mlp_config = BridgeMLPConfig()
config = BridgeConfig()
config.data, config.model = GraphSpinsDataLoaderConfig(), mlp_config.model
device = torch.device(config.device)

NameError: name 'torch' is not defined

In [None]:
#=================================================================
# CREATE OBJECTS FROM CONFIGURATION
# Build every training component from `config`; each factory is handed the same `device`.
# Note: `ReferenceProcess` here is the class from ctdd_reference — the later import in the
# imports cell shadows the one from models.samplers.sampling.
data_dataloader: GraphSpinsDataLoader = create_dataloader(config, device)
model: GaussianTargetRateImageX0PredEMA = create_model(config, device)
reference_process: ReferenceProcess = create_reference(config, device)
loss: GenericAux = create_loss(config, device)
scheduler: CTDDScheduler = create_scheduler(config, device)

In [None]:
#=================================================================
# Take the first training batch for the single-step check below; the batch is indexed at
# [0] to get the data tensor (the loader presumably yields tuple/list batches — confirm).
minibatch = next(iter(data_dataloader.train()))[0]

In [6]:
# TIME ===========================================================
# Single-step smoke test on `minibatch`: one time per batch element, drawn uniformly
# on [config.loss.min_time, 1.0), then noise -> model -> loss exactly as the training loop does.
B = minibatch.shape[0]
ts = config.loss.min_time + (1.0 - config.loss.min_time) * torch.rand((B,), device=device)

# Noising step, first in tuple form and then in dict form.
# NOTE(review): the dict-form call is a second, independent invocation of add_noise; if it
# samples randomly, scheduler_noise_output need not match x_t / x_tilde — confirm before comparing.
x_t, x_tilde, qt0, rate = scheduler.add_noise(
    minibatch, reference_process, ts, device, return_dict=False
)
scheduler_noise_output = scheduler.add_noise(
    minibatch, reference_process, ts, device, return_dict=True
)

# Model outputs on (minibatch, ts, x_tilde), again in tuple form and then dict form.
x_logits, p0t_reg, p0t_sig, reg_x = model.forward(minibatch, ts, x_tilde)
model_forward_output = model.forward(minibatch, ts, x_tilde, return_dict=True)

# calc_loss takes reg_x before p0t_sig / p0t_reg (a different order from model.forward's outputs).
loss_ = loss.calc_loss(
    minibatch, x_tilde, qt0, rate, x_logits, reg_x, p0t_sig, p0t_reg, device
)

In [7]:
# Run the config's initialize() hook; what it sets up is defined in BridgeConfig (not shown here).
# NOTE(review): the training cell calls config.initialize() again — confirm it is safe to repeat.
config.initialize()

In [8]:
# Display the log directory that the training cell passes to SummaryWriter.
config.tensorboard_path

'C:\\Users\\cesar\\Desktop\\Projects\\DiffusiveGenerativeModelling\\Codes\\graph-bridges\\results\\graph\\lobster\\testing\\tensorboard'

In [9]:
#config.best_model_path
#config.sinkhorn_plot_path

In [10]:
# Override the configured optimizer learning rate for this run, then echo it as the cell output.
config.optimizer.learning_rate = 1e-2
config.optimizer.learning_rate

0.01

In [11]:
from diffusers.optimization import get_cosine_schedule_with_warmup

# AdamW over all model parameters at the (possibly overridden) configured learning rate.
# The cosine warmup schedule is currently disabled. To enable it, build
#   get_cosine_schedule_with_warmup(optimizer=optimizer,
#                                   num_warmup_steps=config.optimizer.lr_warmup_steps,
#                                   num_training_steps=len(data_dataloader.train()) * config.optimizer.num_epochs)
# and call lr_scheduler.step() after optimizer.step() in the training loop.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.optimizer.learning_rate)

In [12]:
# Re-evaluate the smoke-test loss from the tensors produced by the single-step check.
loss_ = loss.calc_loss(
    minibatch,
    x_tilde,
    qt0,
    rate,
    x_logits,
    reg_x,
    p0t_sig,
    p0t_reg,
    device,
)

In [13]:
#logs_dir = Path(save_dir).joinpath('tensorboard')
#logs_dir.mkdir(exist_ok=True)


#writer = SummaryWriter(logs_dir)


In [3]:
# Inspect the configured number of sampler steps.
config.sampler.num_steps

1000

In [None]:
        # Scratch copy of the training loop's time draw: one time per batch element,
        # uniform on [config.loss.min_time, 1.0); relies on B from an earlier cell.
        ts = config.loss.min_time + (1.0 - config.loss.min_time) * torch.rand((B,), device=device)
        

In [14]:
# Train the backward-rate model for config.optimizer.num_epochs epochs, logging the
# per-step loss to TensorBoard under config.tensorboard_path.
config.initialize()
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(config.tensorboard_path)
global_step = 0

try:
    for epoch in range(config.optimizer.num_epochs):
        train_loader = data_dataloader.train()
        # The context manager closes each epoch's bar. Leaving bars open makes tqdm stack
        # them, which produced the duplicated lines and "[A" cursor escapes seen in old output.
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch}") as progress_bar:
            for step, minibatch in enumerate(train_loader):
                # Element 0 of each batch is the data tensor. Move it to the device `ts` lives on
                # (Tensor.to is a no-op when the loader already yields tensors on `device`).
                minibatch = minibatch[0].to(device)
                B = minibatch.shape[0]

                # One time per batch element, uniform on [config.loss.min_time, 1.0).
                ts = torch.rand((B,), device=device) * (1.0 - config.loss.min_time) + config.loss.min_time

                # Noise the clean batch at times ts.
                x_t, x_tilde, qt0, rate = scheduler.add_noise(
                    minibatch, reference_process, ts, device, return_dict=False
                )

                x_logits, p0t_reg, p0t_sig, reg_x = model.forward(minibatch, ts, x_tilde)

                # calc_loss takes reg_x before p0t_sig / p0t_reg.
                loss_ = loss.calc_loss(
                    minibatch, x_tilde, qt0, rate, x_logits, reg_x, p0t_sig, p0t_reg, device
                )

                # Clear stale gradients (e.g. left by exploratory cells) before backprop,
                # so each step uses only this batch's gradient.
                optimizer.zero_grad()
                loss_.backward()
                optimizer.step()
                # NOTE(review): the model is typed as an EMA variant (GaussianTargetRateImageX0PredEMA);
                # if it keeps an EMA copy of the weights it presumably needs an explicit update
                # here — confirm against its API.

                loss_value = loss_.item()
                progress_bar.update(1)
                logs = {"loss": loss_value, "step": global_step}
                writer.add_scalar("loss", loss_value, global_step)
                progress_bar.set_postfix(**logs)
                global_step += 1
finally:
    # Flush and close the TensorBoard event file even if training is interrupted.
    writer.close()

Epoch 0: 100%|██████████████████████████████████████████████████████| 25/25 [00:00<00:00, 107.32it/s, loss=86.1, step=24]
Epoch 0: 100%|██████████████████████████████████████████████████████| 25/25 [00:00<00:00, 105.03it/s, loss=86.1, step=24][A

Epoch 1:   0%|                                                                                    | 0/25 [00:00<?, ?it/s][A
Epoch 1:   4%|██▏                                                     | 1/25 [00:00<00:00, 83.98it/s, loss=88.9, step=25][A
Epoch 1:   8%|████▍                                                   | 2/25 [00:00<00:00, 77.96it/s, loss=99.3, step=26][A
Epoch 1:  12%|██████▋                                                 | 3/25 [00:00<00:00, 82.02it/s, loss=98.8, step=27][A
Epoch 1:  16%|████████▉                                               | 4/25 [00:00<00:00, 85.87it/s, loss=95.7, step=28][A
Epoch 1:  20%|███████████▏                                            | 5/25 [00:00<00:00, 89.96it/s, loss=97.2, step=29][A
Ep

Epoch 9:  20%|███████████                                            | 5/25 [00:00<00:00, 63.90it/s, loss=83.2, step=229][A
Epoch 9:  24%|█████████████▍                                          | 6/25 [00:00<00:00, 63.99it/s, loss=102, step=230][A
Epoch 9:  28%|███████████████▋                                        | 7/25 [00:00<00:00, 66.44it/s, loss=102, step=230][A
Epoch 9:  28%|███████████████▋                                        | 7/25 [00:00<00:00, 66.44it/s, loss=101, step=231][A
Epoch 9:  32%|█████████████████▌                                     | 8/25 [00:00<00:00, 66.44it/s, loss=89.3, step=232][A
Epoch 9:  36%|████████████████████▏                                   | 9/25 [00:00<00:00, 66.44it/s, loss=100, step=233][A
Epoch 9:  40%|█████████████████████▌                                | 10/25 [00:00<00:00, 66.44it/s, loss=81.5, step=234][A
Epoch 9:  44%|███████████████████████▊                              | 11/25 [00:00<00:00, 66.44it/s, loss=96.5, step=235][A


Epoch 17:  48%|█████████████████████████▍                           | 12/25 [00:00<00:00, 73.07it/s, loss=82.3, step=436][A
Epoch 17:  52%|███████████████████████████▌                         | 13/25 [00:00<00:00, 73.07it/s, loss=91.2, step=437][A
Epoch 17:  56%|█████████████████████████████▋                       | 14/25 [00:00<00:00, 73.07it/s, loss=92.4, step=438][A
Epoch 17:  60%|███████████████████████████████▊                     | 15/25 [00:00<00:00, 73.07it/s, loss=85.7, step=439][A
Epoch 17:  64%|█████████████████████████████████▉                   | 16/25 [00:00<00:00, 71.98it/s, loss=85.7, step=439][A
Epoch 17:  64%|█████████████████████████████████▉                   | 16/25 [00:00<00:00, 71.98it/s, loss=89.4, step=440][A
Epoch 17:  68%|████████████████████████████████████                 | 17/25 [00:00<00:00, 71.98it/s, loss=98.1, step=441][A
Epoch 17:  72%|███████████████████████████████████████▌               | 18/25 [00:00<00:00, 71.98it/s, loss=88, step=442][A


Epoch 25:  76%|████████████████████████████████████████▎            | 19/25 [00:00<00:00, 82.39it/s, loss=87.1, step=643][A
Epoch 25:  80%|███████████████████████████████████████████▏          | 20/25 [00:00<00:00, 82.39it/s, loss=109, step=644][A
Epoch 25:  84%|█████████████████████████████████████████████▎        | 21/25 [00:00<00:00, 82.39it/s, loss=108, step=645][A
Epoch 25:  88%|██████████████████████████████████████████████▋      | 22/25 [00:00<00:00, 82.39it/s, loss=95.6, step=646][A
Epoch 25:  92%|████████████████████████████████████████████████▊    | 23/25 [00:00<00:00, 82.39it/s, loss=92.3, step=647][A
Epoch 25:  96%|██████████████████████████████████████████████████▉  | 24/25 [00:00<00:00, 82.39it/s, loss=94.2, step=648][A
Epoch 25: 100%|███████████████████████████████████████████████████████| 25/25 [00:00<00:00, 79.28it/s, loss=94, step=649][A
Epoch 26: 100%|████████████████████████████████████████████████████| 25/25 [00:00<00:00, 104.88it/s, loss=89.5, step=674]
Epo

Epoch 34: 100%|█████████████████████████████████████████████████████| 25/25 [00:00<00:00, 91.14it/s, loss=93.4, step=874]
Epoch 34: 100%|█████████████████████████████████████████████████████| 25/25 [00:00<00:00, 85.11it/s, loss=93.4, step=874][A

Epoch 35:   0%|                                                                                   | 0/25 [00:00<?, ?it/s][A
Epoch 35:   4%|██▏                                                   | 1/25 [00:00<00:00, 59.37it/s, loss=90.4, step=875][A
Epoch 35:   8%|████▎                                                 | 2/25 [00:00<00:00, 69.58it/s, loss=81.3, step=876][A
Epoch 35:  12%|██████▋                                                 | 3/25 [00:00<00:00, 69.36it/s, loss=92, step=877][A
Epoch 35:  16%|████████▋                                             | 4/25 [00:00<00:00, 71.12it/s, loss=93.1, step=878][A
Epoch 35:  20%|██████████▊                                           | 5/25 [00:00<00:00, 71.68it/s, loss=92.8, step=879][A
Ep

Epoch 43:  28%|███████████████                                       | 7/25 [00:00<00:00, 82.80it/s, loss=107, step=1081][A
Epoch 43:  32%|█████████████████▎                                    | 8/25 [00:00<00:00, 82.51it/s, loss=109, step=1082][A
Epoch 43:  36%|███████████████████▍                                  | 9/25 [00:00<00:00, 84.94it/s, loss=109, step=1082][A
Epoch 43:  36%|███████████████████▍                                  | 9/25 [00:00<00:00, 84.94it/s, loss=106, step=1083][A
Epoch 43:  40%|████████████████████▊                               | 10/25 [00:00<00:00, 84.94it/s, loss=92.7, step=1084][A
Epoch 43:  44%|██████████████████████▉                             | 11/25 [00:00<00:00, 84.94it/s, loss=93.9, step=1085][A
Epoch 43:  48%|████████████████████████▉                           | 12/25 [00:00<00:00, 84.94it/s, loss=98.3, step=1086][A
Epoch 43:  52%|███████████████████████████                         | 13/25 [00:00<00:00, 84.94it/s, loss=82.1, step=1087][A
