In [None]:
# from pytorch_fid import fid_score
from pytorch_fid import fid_score
import torch

# TODO

TODO: explain the FID metric and how to install `pytorch-fid` (`pip install pytorch-fid`), then compute the score with `python -m pytorch_fid path/to/fake_images path/to/real_images`.

In [None]:
from scenes_dataset import DatasetConstants

def normalize_to_0_1(scene_matrix, range_matrix=None, clip=True):
    """
    Normalize the scene matrix to the range [0, 1] using per-feature max/min bounds.

    Args:
        scene_matrix: tensor of scene values; must broadcast against each row of
            ``range_matrix`` (presumably [N, D] against a [D]-sized row — confirm).
        range_matrix: optional bounds where ``range_matrix[0]`` holds the max values and
            ``range_matrix[1]`` the min values. Defaults to
            ``DatasetConstants.get_range_matrix()``.
        clip: clamp the result into [0, 1]. The default bounds are 99%-percentile (not
            100%) min/max values, so unclipped results can fall outside [0, 1].

    Returns:
        Tensor on ``scene_matrix``'s device containing the normalized values.
    """
    if range_matrix is None:
        range_matrix = DatasetConstants.get_range_matrix()
    # Follow the input's device so samples living on CUDA don't hit a device mismatch.
    range_matrix = torch.as_tensor(range_matrix, device=scene_matrix.device)
    max_value, min_value = range_matrix[0], range_matrix[1]
    span = max_value - min_value
    # A constant feature (max == min) would divide by zero; divide by 1 instead so the
    # feature is only shifted rather than turned into NaN/inf.
    span = torch.where(span == 0, torch.ones_like(span), span)
    normalized_scene_matrix = (scene_matrix - min_value) / span
    if clip:
        normalized_scene_matrix = normalized_scene_matrix.clamp(0.0, 1.0)
    return normalized_scene_matrix

# Ground truth dataset preparation

Map every scene to its scene matrix, normalize it to the range [0, 1], treat it as a grayscale (PIL mode "L") image, and save it as a PNG.

In [None]:
import json

def _load_scenes(path):
    """Return the list stored under the top-level 'scenes' key of a JSON dataset file."""
    with open(path, 'r') as handle:
        return json.load(handle)['scenes']


# Ground truth = all scenes from both splits, train first, then val.
train_data = _load_scenes('datasets/data/train.json')
val_data = _load_scenes('datasets/data/val.json')
gt_data = train_data + val_data

In [None]:

from PIL import Image

import os

# Output directory for the ground-truth PNGs; created up front so save() cannot fail on it.
GT_IMAGE_DIR = "fid_data/gt"
os.makedirs(GT_IMAGE_DIR, exist_ok=True)

for scene in gt_data:
    scene_matrix = torch.tensor(scene["scene_matrix"], dtype=torch.float32)
    scene_matrix = normalize_to_0_1(scene_matrix)
    # Mode "L" is 8-bit grayscale: a float32 array passed with mode="L" has its raw float
    # bytes reinterpreted as pixels. Rescale to 0..255 and cast to uint8 instead.
    # clamp guards against values outside [0, 1], since the range bounds are 99%-percentile values.
    pixels = (scene_matrix.clamp(0.0, 1.0) * 255).round().to(torch.uint8).numpy()
    scene_img = Image.fromarray(pixels)  # a 2-D uint8 array is inferred as mode "L"
    scene_img.save(os.path.join(GT_IMAGE_DIR, f"{scene['scene_id']}.png"))

# Synthetic dataset generation and preparation

In [None]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler
from scenes_dataset import ScenesDataset, DatasetConstants

In [None]:
import json

# Load data from JSON file
with open('datasets/data/train.json', 'r') as file:
    train_data = json.load(file)['scenes']

with open('datasets/data/val.json', 'r') as file:
    val_data = json.load(file)['scenes']

In [None]:
# --- Batching
B = 32  # scenes per batch

# --- Scene shape
N = 20  # objects per scene
D = 15  # feature dimension of each object

# --- Time
T = 14  # time hyperparameter (not referenced in the cells shown here)

# --- Conditioning
C = 300     # node (object) condition feature dimension
R = 23 + 1  # number of relations, used as rgc_num_relations

hparams = dict(
    # fixed settings
    epochs=2000,
    scheduler_loss='l2',
    rgc_activation='tanh',
    # values found by the hyperparameter search
    batch_size=B,
    time_dim=44,
    rgc_hidden_dims='()',
    rgc_num_bases=4,
    rgc_aggr='mean',
    rgc_dp_rate=0.14463856683812687,
    rgc_bias=False,
    attention_self_head_dims=30,
    attention_num_heads=1,
    attention_cross_head_dims=30,
    scheduler_timesteps=1000,
    scheduler_beta_schedule='linear',
    cfg_cond_drop_prob=0.16303181894889107,
    optimizer_lr=0.000571096217369203,
    optimizer_weight_decay=0.00010261093147577781,
    lr_scheduler_factor=0.813888153675873,
    lr_scheduler_patience=60,
    lr_scheduler_minlr=0.00036368282361166394,
)

In [None]:
# Shape-related settings for the network; layer 2 sees the object features plus the time embedding.
general_params = dict(
    num_obj=N,
    obj_cond_dim=C,
    layer_1_dim=D,
    layer_2_dim=D + hparams['time_dim'],
    time_dim=hparams['time_dim'],
)

# Attention settings (the hparam keys use a plural "..._dims" spelling).
attention_params = dict(
    attention_self_head_dim=hparams['attention_self_head_dims'],
    attention_num_heads=hparams['attention_num_heads'],
    attention_cross_head_dim=hparams['attention_cross_head_dims'],
)

# Relational graph conv settings: everything but the relation count is copied from hparams
# under the same key, keeping the original key order.
rgc_params = {"rgc_hidden_dims": hparams['rgc_hidden_dims'], "rgc_num_relations": R}
rgc_params.update(
    (key, hparams[key])
    for key in ("rgc_num_bases", "rgc_aggr", "rgc_activation", "rgc_dp_rate", "rgc_bias")
)

In [None]:
# Prefer CUDA when available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
# Not all operations support MPS yet so this option is not available for now
# elif torch.has_mps:
#     device = torch.device('mps')
else:
    device = torch.device('cpu')


# --- Load the data
# Per-feature range statistics, moved to the model's device because they are handed to the scheduler below.
range_matrix = DatasetConstants.get_range_matrix().to(device)

# --- Instantiate the model
model = GuidedDiffusionNetwork(
    general_params=general_params,
    attention_params=attention_params,
    rgc_params=rgc_params,
    cond_drop_prob=hparams['cfg_cond_drop_prob']
)

# load the best model
model_name = "models/val-model_0146_l2_all+CFG.pt"
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only machine
# (without it, torch.load tries to restore tensors onto their original CUDA device).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load(model_name, map_location=device))
print(f"Loaded model from {model_name}")

scheduler = DDPMScheduler(
    model=model,
    N=N,
    D=D,
    range_matrix=range_matrix,
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5
)

print(f"DDPM Scheduler:\n{scheduler}")

# Move to device
model = model.to(device)
scheduler = scheduler.to(device)

# Inference only: switch both modules to eval mode.
model.eval()
scheduler.eval()

In [None]:
from torch_geometric.loader import DataLoader

# Wrap both splits and iterate them together (train, then val; `+` presumably yields a
# ConcatDataset — confirm) in a fixed order, since shuffle=False.
train_dataset, val_dataset = ScenesDataset(train_data), ScenesDataset(val_data)
combined_dataset = train_dataset + val_dataset
all_dataloader = DataLoader(combined_dataset, batch_size=B, shuffle=False)

In [None]:
import os

from PIL import Image

# Classifier-free guidance scale used for sampling, and the directory for the synthetic PNGs.
CFG_SCALE = 3.
SYNTH_IMAGE_DIR = "fid_data/synth_l2_double_cfg_3"
os.makedirs(SYNTH_IMAGE_DIR, exist_ok=True)

counter = 0

for batch in all_dataloader:
    obj_cond_batch = batch.cond.to(device)
    edge_cond_batch = batch.edge_index.to(device)
    relation_cond_batch = batch.edge_attr.to(device)

    # obj_cond is read as [B*N, C] and needs to be reshaped to [B, N, C]
    obj_cond_batch = obj_cond_batch.view(batch.num_graphs, N, C)

    # Run inference
    with torch.no_grad():
        # Sample from the model with conditions matching the train/validation data
        sampled_scenes = scheduler.sample(
            obj_cond_batch, edge_cond_batch, relation_cond_batch,
            cond_scale=CFG_SCALE, return_all_samples=False,
        )

    # Bring the samples to the CPU: .numpy() raises on CUDA tensors, and the default range
    # matrix used by normalize_to_0_1 is presumably a CPU tensor (it is moved with .to(device)
    # elsewhere before use on the device).
    sampled_scenes = sampled_scenes.cpu()

    # Iterating over the tensor yields sampled_scenes[i] for each scene in the batch.
    for scene_tensor in sampled_scenes:
        scene_matrix = normalize_to_0_1(scene_tensor)
        # Save each sampled scene as an 8-bit grayscale PNG: clamp to [0, 1], scale to 0..255
        # and cast to uint8 (mode "L" with float data would reinterpret the raw float bytes).
        pixels = (scene_matrix.clamp(0.0, 1.0) * 255).round().to(torch.uint8).numpy()
        Image.fromarray(pixels).save(os.path.join(SYNTH_IMAGE_DIR, f"{counter}.png"))
        counter += 1

## Mock data test

In [None]:
from PIL import Image

import os

# NOTE(review): this cell rebinds the notebook-level B (32 above), N and D; re-run the
# hyperparameter cell before re-running any earlier cell that uses them.
B = 400
N = 20
D = 15

# Fixed-seed generator so the two mock image sets are reproducible across runs.
mock_rng = torch.Generator().manual_seed(0)
X1 = torch.randn(B, N, D, generator=mock_rng)
X2 = torch.randn(B, N, D, generator=mock_rng)


def _randn_to_uint8(matrix):
    """Map values in [-3, 3] linearly onto 0..255 (values outside are clipped); returns a uint8 numpy array."""
    return (((matrix + 3.0) / 6.0).clamp(0.0, 1.0) * 255).round().to(torch.uint8).numpy()


os.makedirs("tmp/gt", exist_ok=True)
os.makedirs("tmp/synth", exist_ok=True)

for i in range(B):
    # Treat each NxD matrix as a single 8-bit grayscale image and save it to disk.
    # Raw float32 randn values passed with mode="L" would be reinterpreted as bytes.
    Image.fromarray(_randn_to_uint8(X1[i])).save(f"tmp/gt/X1_{i}.png")
    Image.fromarray(_randn_to_uint8(X2[i])).save(f"tmp/synth/X2_{i}.png")
    



# for i in range(B):
#     torch.save(X1[i], f"tmp/gt/X1_{i}.pt")
#     torch.save(X2[i], f"tmp/synth/X2_{i}.pt")