# Diffusion Test 
Notebook to test the new diffusion model 
It should be loading the model, getting random batches from the test dataset and try sampling few next states
Both the current and the next state should be plotted with blue and sampled next states should be plotted in green

In [2]:
import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
import torch
import torch.utils.data as data 

from collections import OrderedDict
from copy import deepcopy
from omegaconf import OmegaConf
from torchvision import transforms
from torch.nn.parallel import DistributedDataParallel as DDP
from contrastive_learning.tests.plotting import plot_corners, plot_rvec_tvec, plot_mean_rot
from tqdm import tqdm
# 
from contrastive_learning.tests.test_model import load_lin_model, predict_traj_actions, load_diff_model
from contrastive_learning.tests.animate_markers import AnimateMarkers
from contrastive_learning.tests.animate_rvec_tvec import AnimateRvecTvec
from contrastive_learning.datasets.dataloaders import get_dataloaders

from contrastive_learning.models.custom_models import LinearInverse, EpsModel
from contrastive_learning.datasets.state_dataset import StateDataset
from contrastive_learning.tests.plotting import plot_rvec_tvec, plot_corners
from contrastive_learning.datasets.dataloaders import get_dataloaders

## Load the Model
Create the distributed group and load the eps model used for the diffusion

In [7]:
# Start the multiprocessing to load the saved models properly
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29506"

torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
torch.cuda.set_device(0)

In [8]:
# Set the device and out_dir
device = torch.device('cuda:0')
out_dir = '/home/irmak/Workspace/DAWGE/contrastive_learning/out/2022.08.16/00-48_diffusion_ref_global_fi_5_pt_mean_rot_bs_32_hd_64_lr_0.0001_zd_8'
cfg = OmegaConf.load(os.path.join(out_dir, '.hydra/config.yaml'))
model_path = os.path.join(out_dir, 'models/eps_model.pt')

In [9]:
# Load the eps model
eps_model = load_diff_model(cfg, device, model_path)

In [10]:
# N_STEPS = cfg.diff_n_steps
# N_SAMPLES = 10 # For each state we'll sample 1k different new states

In [11]:
_, test_loader, dataset = get_dataloaders(cfg)
batch = next(iter(test_loader))
x0, xnext0, a = [b.to(device) for b in batch]

DATASET POS_REF: global
self.action_min: [-0.15000001 -0.30000001], self.action_max: [0.15000001 0.30000001]


In [12]:
class DiffusionTest:
    def __init__(self, eps_model, n_steps, n_samples, device):
        self.eps_model = eps_model
        self.n_steps = n_steps 
        self.n_samples = n_samples 
        self.device = device 

        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps # Number of steps to noise and denoise the data
        self.sigma2 = self.beta

    def gather(self, consts: torch.Tensor, t: torch.Tensor):
        c = consts.gather(-1, t)
        return c.reshape(-1, 1)

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, x0: torch.Tensor, a: torch.Tensor):
        eps_theta = self.eps_model(xt, t, x0, a) # Input to this will be complete noise
        alpha_bar = self.gather(self.alpha_bar, t)
        alpha = self.gather(self.alpha, t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = self.gather(self.sigma2, t) # (1 - self.alpha)

        eps = torch.randn(xt.shape, device=xt.device)
        return mean + (var ** 5.) * eps

    def get_sample(self, curr_x0, curr_a): # curr_x0.shape: (1, pos_dim*2) - so this is only for one element in the batch
        xt = torch.randn((curr_x0.shape), device=curr_x0.device)
        for t_ in range(self.n_steps):
            curr_t = self.n_steps - t_ - 1
            t = xt.new_full((xt.shape[0],), curr_t, dtype=torch.long)
            xt = self.p_sample(xt, t, curr_x0, curr_a)
        
        return xt

    def get_all_samples(self, curr_x0, curr_a): # curr_x0.shape: (1, pos_dim*2) - this time the output will be concatenated version of xts
        pbar = tqdm(total=self.n_samples)
        for i in range(self.n_samples):
            if i == 0:
                xt = self.get_sample(curr_x0, curr_a)
            else:
                xt = torch.cat((xt, self.get_sample(curr_x0, curr_a)), dim=0)
            pbar.update(1)

        pbar.close()
        return xt

    def get_all_samples_for_batch(self, x0, a):
        bs = x0.shape[0]
        for i in range(bs):
            curr_x0, curr_a = x0[i:i+1], a[i:i+1]
            if i == 0:
                all_xt = torch.unsqueeze(self.get_all_samples(curr_x0, curr_a),0)
            else:
                all_xt = torch.cat((all_xt, torch.unsqueeze(self.get_all_samples(curr_x0, curr_a),0)), dim=0)

        return all_xt
    
    

In [13]:
N_SAMPLES = 50
N_STEPS = 50

In [14]:
diff_test = DiffusionTest(eps_model, n_steps=N_STEPS, n_samples=N_SAMPLES, device=device)
all_xt = diff_test.get_all_samples_for_batch(x0, a)

100%|██████████| 50/50 [00:01<00:00, 32.26it/s]
100%|██████████| 50/50 [00:00<00:00, 50.08it/s]
100%|██████████| 50/50 [00:00<00:00, 51.64it/s]
100%|██████████| 50/50 [00:01<00:00, 49.69it/s]
100%|██████████| 50/50 [00:01<00:00, 45.31it/s]
100%|██████████| 50/50 [00:00<00:00, 51.48it/s]
100%|██████████| 50/50 [00:00<00:00, 51.46it/s]
100%|██████████| 50/50 [00:00<00:00, 52.96it/s]
100%|██████████| 50/50 [00:00<00:00, 51.76it/s]
100%|██████████| 50/50 [00:00<00:00, 52.62it/s]
100%|██████████| 50/50 [00:00<00:00, 52.51it/s]
100%|██████████| 50/50 [00:00<00:00, 51.41it/s]
100%|██████████| 50/50 [00:01<00:00, 47.29it/s]
100%|██████████| 50/50 [00:00<00:00, 51.56it/s]
100%|██████████| 50/50 [00:00<00:00, 51.88it/s]
100%|██████████| 50/50 [00:00<00:00, 51.51it/s]
100%|██████████| 50/50 [00:00<00:00, 52.44it/s]
100%|██████████| 50/50 [00:00<00:00, 52.10it/s]
100%|██████████| 50/50 [00:00<00:00, 52.18it/s]
100%|██████████| 50/50 [00:00<00:00, 52.25it/s]
100%|██████████| 50/50 [00:00<00:00, 50.

In [15]:
print(all_xt[0,:])

tensor([[0.7478, 0.1294, 0.2482, 0.7092, 0.1449, 0.0974],
        [0.7457, 0.1306, 0.2492, 0.7084, 0.1464, 0.1135],
        [0.7495, 0.1300, 0.2489, 0.7180, 0.1458, 0.0743],
        [0.7443, 0.1313, 0.2501, 0.7077, 0.1476, 0.1211],
        [0.7421, 0.1169, 0.2545, 0.6978, 0.1532, 0.4100],
        [0.7482, 0.1296, 0.2495, 0.7121, 0.1453, 0.0918],
        [0.7421, 0.1086, 0.2525, 0.6783, 0.1419, 0.3889],
        [0.7446, 0.1354, 0.2511, 0.7249, 0.1519, 0.1100],
        [0.7451, 0.1311, 0.2495, 0.7084, 0.1471, 0.1173],
        [0.7543, 0.1261, 0.2483, 0.7175, 0.1429, 0.0482],
        [0.7466, 0.1303, 0.2494, 0.7103, 0.1460, 0.1034],
        [0.7458, 0.1308, 0.2532, 0.7188, 0.1471, 0.0946],
        [0.7435, 0.1160, 0.2488, 0.6878, 0.1497, 0.3892],
        [0.7472, 0.1300, 0.2477, 0.7079, 0.1454, 0.1032],
        [0.7423, 0.1263, 0.2447, 0.6880, 0.1449, 0.1790],
        [0.7467, 0.1300, 0.2489, 0.7085, 0.1458, 0.1046],
        [0.7517, 0.1276, 0.2486, 0.7149, 0.1437, 0.0675],
        [0.747

In [16]:

# Plot the sampled states
bs = x0.shape[0]
pos_dim = int(x0.shape[1] / 2)
if pos_dim == 8: # Pos type is corners
    plotting_fn = plot_corners
    denormalize_fn = dataset.denormalize_corner
elif pos_dim == 6: # Pos type is rotational and translational vectors
    plotting_fn = plot_rvec_tvec
    denormalize_fn = dataset.denormalize_pos_rvec_tvec # NOTE: This will def cause some problems
elif pos_dim == 3: # Pos type is just the mean and rotation of the box
    plotting_fn = plot_mean_rot
    denormalize_fn = dataset.denormalize_mean_rot

ncols = 10
nrows = math.ceil(bs / ncols)
fig, axs = plt.subplots(figsize=(10*ncols, 10*nrows), nrows=nrows, ncols=ncols)

pbar = tqdm(total=bs)
# Denormalize all the positions
for i in range(bs):
    x0_curr = denormalize_fn(x0[i].cpu().detach().numpy())
    xnext0_curr = denormalize_fn(xnext0[i].cpu().detach().numpy())

    # Plot the denormalized corners
    axs_row = int(i / ncols)
    axs_col = int(i % ncols)
    axs[axs_row, axs_col].set_title("Data {} in the batch".format(i))
    _, frame_axis = plotting_fn(axs[axs_row, axs_col], x0_curr, color_scheme=1)
    _, frame_axis = plotting_fn(axs[axs_row, axs_col], xnext0_curr, use_frame_axis=True, frame_axis=frame_axis, color_scheme=1)

    for j in range(N_SAMPLES):
        xt_curr = denormalize_fn(all_xt[i,j].cpu().detach().numpy())
        _, frame_axis = plotting_fn(axs[axs_row, axs_col], xt_curr, use_frame_axis=True, frame_axis=frame_axis, color_scheme=2)

    pbar.update(1)
    
# Save the saving plot
pbar.close()
exp_name = '{}_{}'.format(out_dir.split('/')[-2], out_dir.split('/')[-1].split('_')[0])
plt.savefig('diff_samples_{}_steps_{}_samples_{}.png'.format(exp_name, N_STEPS, N_SAMPLES))


100%|██████████| 32/32 [00:05<00:00,  6.09it/s]
