# Import Required Libraries

In [None]:
import matplotlib.pyplot as plt
import torch,yaml
from utils import dict2namespace
from runners.DiffusionBasedModelRunners.BBDMRunner import DualBrownianBridgeModel
from torch.utils.data import DataLoader
from runners.utils import get_dataset

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

# Load Config and Model Checkpoint

In [None]:
# Paths used by this notebook; edit here rather than inside the calls below.
CONFIG_PATH = "configs/Dual.yaml"
CHECKPOINT_PATH = 'results/Cityscapes/DualBrownianBridge/checkpoint/latest_model_200.pth'

with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    # FullLoader kept as in the original; yaml.safe_load suffices if the
    # config contains only plain YAML.
    dict_config = yaml.load(f, Loader=yaml.FullLoader)

nconfig = dict2namespace(dict_config)
nconfig.training.use_DDP = False

# Prefer the second GPU (as originally hard-coded), but fall back to the first
# GPU or the CPU instead of failing later when fewer devices are present.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
nconfig.training.device = [device]

batch_size = 1
models = DualBrownianBridgeModel(nconfig.model)
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

# strict=False tolerates key mismatches, which also hides a partial load
# (e.g. a key-prefix mismatch). Report what was skipped so it is visible.
load_result = models.load_state_dict(checkpoint['model'], strict=False)
missing_keys = list(getattr(load_result, 'missing_keys', []))
unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
if missing_keys or unexpected_keys:
    print(f"load_state_dict: {len(missing_keys)} missing key(s), "
          f"{len(unexpected_keys)} unexpected key(s)")
    print("  missing (first 10):", missing_keys[:10])
    print("  unexpected (first 10):", unexpected_keys[:10])


# Load the Dataset

In [None]:
_, val_dataset, test_dataset = get_dataset(nconfig.data)
# Fall back to the validation split when no test split is provided.
if test_dataset is None:
    test_dataset = val_dataset

# A DistributedSampler is needed only under DDP; sampler=None is DataLoader's
# default. DistributedSampler shuffles by default, so shuffle=False is passed
# explicitly to keep the deterministic evaluation order the loader asks for.
test_sampler = (
    torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)
    if nconfig.training.use_DDP
    else None
)

# drop_last=False: evaluation should cover every sample, including a final
# partial batch.
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=1,
                         drop_last=False,
                         sampler=test_sampler)

# Kept for interactive inspection, e.g. `next(test_iter)`.
test_iter = iter(test_loader)

# Sample, Save, and Visualize Test Outputs


In [None]:
import os
from PIL import Image

from runners.utils import get_image_grid, make_dir

# At most this many images from each batch are sampled and shown.
MAX_IMAGES_PER_BATCH = 4
# Stop after this many batches; None processes the whole test loader.
MAX_BATCHES = None

device = nconfig.training.device[0]
to_normal = nconfig.data.dataset_config.to_normal
# Second argument to get_image_grid, unchanged from the original notebook.
grid_size = 1

models = models.to(device)
models.eval()

sample_path = make_dir(os.path.join("results", "_sample"))
print(sample_path)

for i, ((x, _x_name), (x_cond, _x_cond_name)) in enumerate(test_loader):
    if MAX_BATCHES is not None and i >= MAX_BATCHES:
        break

    n_show = min(x.shape[0], MAX_IMAGES_PER_BATCH)
    # Ground truth is only saved/displayed, so it stays on the CPU; only the
    # condition is fed to the model.
    x = x[:n_show]
    x_cond = x_cond[:n_show].to(device)

    # Pure inference: no_grad keeps autograd from retaining activations.
    with torch.no_grad():
        sample = models.sample(x_cond, clip_denoised=nconfig.testing.clip_denoised).to('cpu')

    # (panel title, output file name, image grid) -- each grid computed once
    # and reused for both saving and display.
    panels = [
        ('Sample', f'skip_sample_{i}.png',
         get_image_grid(sample, grid_size, to_normal=to_normal)),
        ('Condition', f'condition_{i}.png',
         get_image_grid(x_cond.to('cpu'), grid_size, to_normal=to_normal)),
        ('Ground Truth', f'ground_truth_{i}.png',
         get_image_grid(x, grid_size, to_normal=to_normal)),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (title, file_name, grid) in zip(axes, panels):
        Image.fromarray(grid).save(os.path.join(sample_path, file_name))
        ax.imshow(grid)
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    # Release the figure so long loops do not accumulate open figures.
    plt.close(fig)
