In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Resize
from datasets.transform import Transform

from models import Encoder, Decoder, ConditionalDiffusion
from datasets.datasets import DiffusionDataset, DiffusionDatasetNoBG
from datasets.datasets import CondDiffusionDataset, CondDiffusionDatasetNoBG

## Evaluate Encoder

In [None]:
# Settings for the encoder evaluation data.
remove_bg = False
data_path = 'data/'
batch_size = 16

# Read the `rgb_128.npy` and `segmap_128.npy` arrays stored under `data_path`.
rgb_data, mask_data = (
    np.load(os.path.join(data_path, fname))
    for fname in ('rgb_128.npy', 'segmap_128.npy')
)

# The NoBG dataset variant additionally receives the segmentation array.
dataset = (
    DiffusionDatasetNoBG(rgb_data, mask_data)
    if remove_bg
    else DiffusionDataset(rgb_data)
)
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Build the encoder/decoder pair and restore the best checkpoint of run `encoder_path`.
encoder_path = '1129_1701'
# NOTE(review): machine-specific absolute path; point it at your own checkpoint root.
ckpt_dir = '/home/gun/ssd/disk/PreferenceDiffusion/tidying-line-diffusion'
# `latent_dim` has to exist before the models are constructed, otherwise this cell
# raises NameError on a fresh kernel. 16 is the value used in the diffusion section
# of this notebook -- TODO confirm it matches this encoder checkpoint.
latent_dim = 16

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(output_dim=latent_dim).to(device)
decoder = Decoder(input_dim=latent_dim).to(device)

# map_location remaps tensors saved on another device onto `device`.
state_dict = torch.load(
    os.path.join(ckpt_dir, 'encoder_%s/checkpoint_best.pt' % encoder_path),
    map_location=device,
)
encoder.load_state_dict(state_dict['encoder'])
decoder.load_state_dict(state_dict['decoder'])

# This section only evaluates the models, so switch them to inference mode.
encoder.eval()
decoder.eval()

In [None]:
# Encode every batch, sample a latent from the posterior, decode it, and print
# the shape of the reconstruction.
transform = Transform()

# Evaluation only: no autograd graph is needed.
with torch.no_grad():
    for batch in data_loader:
        x, mask = batch
        # Move the last axis to position 1 before the transform
        # (presumably NHWC -> NCHW -- TODO confirm against the dataset).
        x = transform(x.permute((0, 3, 1, 2)))
        x = x.to(torch.float32).to(device)
        posterior, prior_loss = encoder(x)
        z = posterior.rsample()
        x_recon = decoder(z)
        print(x_recon.shape)

## Evaluate Diffusion

In [None]:
# Settings for the conditional-diffusion evaluation data.
remove_bg = False
data_path = 'data/'
batch_size = 16

# `rgb_128.npy` and `segmap_16.npy` are always needed.
rgb_data = np.load(os.path.join(data_path, 'rgb_128.npy'))
segmap_data = np.load(os.path.join(data_path, 'segmap_16.npy'))

if not remove_bg:
    dataset = CondDiffusionDataset(rgb_data, segmap_data)
else:
    # `segmap_128.npy` is only read for the NoBG dataset variant.
    mask_data = np.load(os.path.join(data_path, 'segmap_128.npy'))
    dataset = CondDiffusionDatasetNoBG(rgb_data, segmap_data, mask_data)

data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)

In [None]:
# Build encoder, decoder and conditional diffusion model, then restore all three
# from the best checkpoint of the diffusion run whose directory ends in `ldm_path`.
ldm_path = '1208_2347'

latent_dim = 16
n_timesteps = 1000
# NOTE(review): machine-specific absolute path; point it at your own checkpoint root.
ckpt_dir = '/home/gun/ssd/disk/PreferenceDiffusion/tidying-line-diffusion'

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(output_dim=latent_dim).to(device)
decoder = Decoder(input_dim=latent_dim).to(device)
diffusion = ConditionalDiffusion(input_dim=latent_dim, cond_dim=latent_dim, n_timesteps=n_timesteps).to(device)

# os.listdir returns entries in arbitrary order; sort so the pick is deterministic
# when several run directories match.
matching_dirs = sorted(
    d for d in os.listdir(ckpt_dir)
    if d.startswith('diffusion_') and d.endswith(ldm_path)
)
if not matching_dirs:
    # Fail with an explicit message instead of a bare IndexError.
    raise FileNotFoundError(
        "no 'diffusion_*%s' directory found in %s" % (ldm_path, ckpt_dir)
    )
ldm_dir = matching_dirs[0]

# map_location remaps tensors saved on another device onto `device`.
state_dict = torch.load(
    os.path.join(ckpt_dir, ldm_dir, 'checkpoint_best.pt'),
    map_location=device,
)
encoder.load_state_dict(state_dict['encoder'])
decoder.load_state_dict(state_dict['decoder'])
diffusion.load_state_dict(state_dict['diffusion'])

# This section only evaluates the models, so switch them to inference mode.
encoder.eval()
decoder.eval()
diffusion.eval()

In [None]:
# For every batch: encode the image, build a conditioning tensor from the
# segmentation map, run the diffusion model on it, decode the result, and print
# the shape of the decoded image.
transform = Transform()

# Conditioning mode:
#   'point' - keep the encoder feature wherever the segmentation label is non-zero
#   'mask'  - replace each labelled region by the mean feature of that region
cond_type = 'point'

# Evaluation only: no autograd graph is needed.
with torch.no_grad():
    for batch in data_loader:
        x, masks = batch
        # Move the last axis to position 1 before the transform
        # (presumably NHWC -> NCHW -- TODO confirm against the dataset).
        x = transform(x.permute((0, 3, 1, 2)))
        x = x.to(torch.float32).to(device)
        # The masks are multiplied with `feature`, so they must be on the same device.
        masks = masks.to(device)
        posterior, _ = encoder(x)
        feature = posterior.mean

        if cond_type == 'point':
            cond = (masks != 0).to(torch.float32).view(-1, 1, 16, 16) * feature
        elif cond_type == 'mask':
            cond = torch.zeros_like(feature)
            # Labels 1..4 are visited; presumably the object ids in the
            # segmentation map -- TODO confirm the label range.
            for m in range(1, 5):
                mask_m = (masks == m).to(torch.float32).view(-1, 1, 16, 16)
                count_m = mask_m.sum((2, 3))
                feature_m = mask_m * feature
                # An absent label has count 0 and an all-zero feature sum; clamping
                # the divisor to 1 yields a zero mean without producing NaN.
                feature_m_mean = feature_m.sum((2, 3)) / count_m.clamp(min=1)
                # Broadcast the per-channel mean back over the spatial axes.
                cond += mask_m * feature_m_mean[:, :, None, None]
        else:
            raise ValueError("unknown cond_type: %r" % (cond_type,))
        feature_recon = diffusion(cond)

        img = decoder(feature)
        img_recon = decoder(feature_recon)
        print(img_recon.shape)