In [None]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler

In [None]:
import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Dataset over a list of scene dictionaries.

    Every scene must provide the keys ``scene_matrix``, ``graph_objects``,
    ``graph_edges`` and ``graph_relationships``; they are converted to tensors
    each time an item is fetched.
    """

    def __init__(self, scenes):
        """Store the list of scene dictionaries without copying it."""
        super().__init__()
        self.scenes = scenes

    def __len__(self):
        """Return the number of scenes."""
        return len(self.scenes)

    def __getitem__(self, index):
        """Return the scene at ``index`` as a dict of tensors.

        Returns:
            dict with keys ``x`` (float32 scene matrix), ``obj_cond`` (float32
            graph objects), ``edge_cond`` (long graph edges) and
            ``relation_cond`` (long graph relationships).
        """
        scene = self.scenes[index]

        scene_matrix = torch.tensor(scene["scene_matrix"], dtype=torch.float32)
        graph_objects = torch.tensor(scene["graph_objects"], dtype=torch.float32)
        graph_edges = torch.tensor(scene["graph_edges"], dtype=torch.long)
        graph_relationships = torch.tensor(scene["graph_relationships"], dtype=torch.long)

        return {
            'x': scene_matrix,
            'obj_cond': graph_objects,
            'edge_cond': graph_edges,
            'relation_cond': graph_relationships
        }

    def collate_fn(self, batch):
        """Merge a list of items from ``__getitem__`` into one batch.

        ``x`` is stacked along a new leading batch dimension. ``obj_cond`` and
        ``relation_cond`` are concatenated along dim 0 and ``edge_cond`` along
        dim 1, so all graphs of the batch form one large disconnected graph.

        Because the node rows of every graph are appended after those of the
        previous graphs, the edge indices of each graph are shifted by the
        number of nodes that precede it. Without this shift the edges of every
        graph would point at the nodes of the first graph. For a batch of one
        item the shift is zero.

        Assumes ``edge_cond`` holds node indices local to its own scene, laid
        out as (2, num_edges) -- TODO confirm against the dataset export.
        """
        x_batch = torch.stack([item['x'] for item in batch], dim=0)
        obj_cond_batch = torch.cat([item['obj_cond'] for item in batch], dim=0)
        relation_cond_batch = torch.cat([item['relation_cond'] for item in batch], dim=0)

        # Shift each graph's edge indices by the count of nodes stacked before it.
        shifted_edges = []
        node_offset = 0
        for item in batch:
            shifted_edges.append(item['edge_cond'] + node_offset)
            node_offset += item['obj_cond'].shape[0]
        edge_cond_batch = torch.cat(shifted_edges, dim=1)

        return {
            'x': x_batch,
            'obj_cond': obj_cond_batch,
            'edge_cond': edge_cond_batch,
            'relation_cond': relation_cond_batch
        }


In [None]:
import json

# Load the scene lists of the training and validation splits from JSON files
def _read_scenes(json_path):
    """Return the value stored under the 'scenes' key of the JSON file at ``json_path``."""
    with open(json_path, 'r') as handle:
        return json.load(handle)['scenes']


train_data = _read_scenes('datasets/data/train.json')
val_data = _read_scenes('datasets/data/val.json')

# Not available yet
# with open('datasets/data/test.json', 'r') as file:
#     test_data = json.load(file)['scenes']

In [None]:
# Range matrix for real data.
# Each object vector is laid out as
#   300 text dims | 3 location dims | 9 normalized-axes dims | 3 size dims
# and for every dimension we record its maximum and minimum value.
text_min, text_max = -torch.ones(300), torch.ones(300)

location_min = torch.tensor([-3.334, -2.619, -1.329])
location_max = torch.tensor([3.285, 3.93, 0.879])

normalized_axes_min, normalized_axes_max = -torch.ones(9), torch.ones(9)

size_min = torch.tensor([0.232, 0.14, 0.094])
size_max = torch.tensor([4.878, 2.655, 2.305])

range_max = torch.cat([text_max, location_max, normalized_axes_max, size_max], dim=0)
range_min = torch.cat([text_min, location_min, normalized_axes_min, size_min], dim=0)

# Row 0 holds the per-dimension maxima, row 1 the per-dimension minima.
range_matrix = torch.stack([range_max, range_min], dim=0)

In [None]:
B = 64   # num of scenes in batch
N = 20   # scene: num of nodes
C = 300  # RGCN: dim of node features
R = 23   # RGCN: num of relations
D = 315  # attention: dim of objects from the scene

# All tunable settings in one place; the key order matches the sections below.
hparams = dict(
    batch_size=B,  # num of graphs in batch

    # --- Attention hyperparams ---
    attention_out_dim=D,
    attention_num_heads=5,  # must be a divisor of D

    # --- Encoder RGCN hyperparams ---
    encoder_out_dim=C,
    # The f-string renders the empty tuple as the string "()";
    # presumably parsed back by the model -- confirm. Alternative: (C, C)
    encoder_hidden_dims=f"{()}",
    encoder_num_bases=None,
    encoder_aggr='mean',
    encoder_activation='leakyrelu',
    encoder_dp_rate=0.,
    encoder_bias=True,

    # --- Fusion RGCN hyperparams ---
    fusion_hidden_dims=f"{()}",  # alternative: (C+D, C+D, D)
    fusion_num_bases=None,
    fusion_aggr='mean',
    fusion_activation='leakyrelu',
    fusion_dp_rate=0.,
    fusion_bias=True,

    # --- Scheduler hyperparams ---
    scheduler_timesteps=1000,
    scheduler_loss='l1',
    scheduler_beta_schedule='cosine',
    # Note: not needed for now
    # scheduler_sampling_timesteps=None,
    # scheduler_objective='pred_noise',
    # scheduler_ddim_sampling_eta=1.0,
    # scheduler_min_snr_loss_weight=False,
    # scheduler_min_snr_gamma=5,

    # --- Classifier-free guidance parameters ---
    cfg_cond_drop_prob=0.1,

    # --- Training and optimizer hyperparams ---
    epochs=5000,
    optimizer_lr=1e-3,
    optimizer_weight_decay=5e-5,
    lr_scheduler_factor=0.8,
    lr_scheduler_patience=20,
    lr_scheduler_minlr=8e-5,
)

In [None]:
from torch.utils.data import DataLoader

# Use the GPU through CUDA when one is available, otherwise fall back to the CPU.
# Not all operations support MPS yet so this option is not available for now
# (it would be: torch.device('mps') when torch.has_mps).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the value ranges on the same device as the model
range_matrix = range_matrix.to(device)

# --- Instantiate the model
model = GuidedDiffusionNetwork(
    attention_in_dim=D,
    attention_out_dim=hparams['attention_out_dim'],
    attention_num_heads=hparams['attention_num_heads'],

    rgcn_num_relations=R,

    encoder_in_dim=C,
    encoder_out_dim=hparams['encoder_out_dim'],
    encoder_hidden_dims=hparams['encoder_hidden_dims'],
    encoder_num_bases=hparams['encoder_num_bases'],
    encoder_aggr=hparams['encoder_aggr'],
    encoder_activation=hparams['encoder_activation'],
    encoder_dp_rate=hparams['encoder_dp_rate'],
    encoder_bias=hparams['encoder_bias'],

    fusion_hidden_dims=hparams['fusion_hidden_dims'],
    fusion_num_bases=hparams['fusion_num_bases'],
    fusion_aggr=hparams['fusion_aggr'],
    fusion_activation=hparams['fusion_activation'],
    fusion_dp_rate=hparams['fusion_dp_rate'],
    fusion_bias=hparams['fusion_bias'],

    cond_drop_prob=hparams['cfg_cond_drop_prob']
)

# Load best model.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved on a GPU can also be loaded on a CPU-only machine
# (without it torch.load tries to restore tensors on their original device).
model.load_state_dict(torch.load('models/best-model_035.pt', map_location=device))

print(f"Model:\n{model}")

# Collect the scheduler settings first so they can be inspected in one place,
# then build the DDPM scheduler around the model.
scheduler_kwargs = {
    'model': model,
    'N': N,
    'D': D,
    'range_matrix': range_matrix,
    'timesteps': hparams['scheduler_timesteps'],
    'sampling_timesteps': None,
    'loss_type': hparams['scheduler_loss'],
    'objective': 'pred_noise',
    'beta_schedule': hparams['scheduler_beta_schedule'],
    'ddim_sampling_eta': 1.0,
    'min_snr_loss_weight': False,
    'min_snr_gamma': 5,
}
scheduler = DDPMScheduler(**scheduler_kwargs)

print(f"DDPM Scheduler:\n{scheduler}")

# Move both objects to the selected device
model = model.to(device)
scheduler = scheduler.to(device)

# Switch both objects to evaluation mode before sampling
model.eval()
scheduler.eval()

In [None]:
# Run inference using scene conditions from the validation set
train_dataset = CustomDataset(train_data)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=train_dataset.collate_fn)

val_dataset = CustomDataset(val_data)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True, collate_fn=val_dataset.collate_fn)

# Only a single (randomly drawn, since shuffle=True) validation batch is needed.
# Fail loudly when there is none: otherwise `sampled_scene` would stay undefined
# (or keep a stale value from an earlier run) and a later cell would break
# with a confusing error.
batch = next(iter(val_dataloader), None)
if batch is None:
    raise RuntimeError("Validation set is empty: no scene conditions available to sample from.")

x = batch['x'].to(device)
obj_cond = batch['obj_cond'].to(device)
edge_cond = batch['edge_cond'].to(device)
relation_cond = batch['relation_cond'].to(device)

# Run inference (gradients are not needed for sampling)
with torch.no_grad():
    # Sample from the model (use the same conditioning as the overfitting)
    sampled_scene = scheduler.sample(obj_cond, edge_cond, relation_cond, cond_scale=5.0)

In [None]:
import os

scan_id = 0

# Take the first element in the batch and only the last 15 dimensions of it.
# The 15 kept values are read below as 3 location + 9 normalized axes + 3 sizes.
filtered_scene = sampled_scene[0, :, -15:]

# Build one entry per object row actually present in the sampled scene
# (instead of assuming a fixed count of 20).
objs = []
for obj_row in filtered_scene:
    label = 'unknown' # TODO: generate label from neighborhood search from the embeddings
    location = obj_row[0:3]
    normalized_axes = obj_row[3:12]
    sizes = obj_row[12:15]

    objs.append({
        'obb': {
            'centroid': location.tolist(),
            'normalizedAxes': normalized_axes.tolist(),
            'axesLengths': sizes.tolist()
        },
        'label': label,
        'dominantNormal': [0, 0, 0], # not used for now
    })

# Store the sampled scene to visualize using DVIS
encoded_scene = {
    'scan_id': scan_id,
    'segGroups': objs, # TODO: add segGroups
}

# Save the sampled scene to a JSON file. The target folder is created first
# when it does not exist yet; opening the file alone would raise
# FileNotFoundError on a fresh checkout.
scene_dir = f'datasets/data/gen/scene_{scan_id}'
os.makedirs(scene_dir, exist_ok=True)
with open(f'{scene_dir}/semseg.v2.json', 'w') as file:
    json.dump(encoded_scene, file, indent=2)