In [8]:
import torch
import torch.nn as nn

# import from guided-diffusion folder
from model_alternative import GuidedDiffusionNetwork
from ddpm_scheduler import DDPMScheduler
from scenes_dataset import ScenesDataset, DatasetConstants

In [9]:
import torch
from torch_geometric.data import Data
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset over a sequence of scene dictionaries.

    Every scene must provide the keys ``scene_matrix``, ``graph_objects``,
    ``graph_edges`` and ``graph_relationships``. Items are returned as a dict
    of tensors keyed ``x``, ``obj_cond``, ``edge_cond`` and ``relation_cond``.
    """

    def __init__(self, scenes):
        super().__init__()
        self.scenes = scenes

    def __len__(self):
        """Number of scenes held by the dataset."""
        return len(self.scenes)

    def __getitem__(self, index):
        """Convert the scene at ``index`` into a dict of tensors.

        The scene matrix and object features become ``float32`` tensors, the
        edges and relationships become ``long`` tensors.
        """
        scene = self.scenes[index]

        return {
            'x': torch.tensor(scene["scene_matrix"], dtype=torch.float32),
            'obj_cond': torch.tensor(scene["graph_objects"], dtype=torch.float32),
            'edge_cond': torch.tensor(scene["graph_edges"], dtype=torch.long),
            'relation_cond': torch.tensor(scene["graph_relationships"], dtype=torch.long),
        }

    def collate_fn(self, batch):
        """Merge a list of items produced by ``__getitem__`` into one batch.

        ``x`` is stacked along a new leading dimension. ``obj_cond`` and
        ``relation_cond`` are concatenated along dim 0 and ``edge_cond`` along
        dim 1, so all graphs of the batch form one large disconnected graph.

        Because the object rows of all scenes end up in a single tensor, the
        edge indices of each scene are shifted by the number of object rows
        that precede it; without the shift every scene's edges would point at
        the rows of the first scene.
        (Assumes ``edge_cond`` holds row indices into ``obj_cond`` - TODO confirm.)

        Args:
            batch: list of dicts as returned by ``__getitem__``.

        Returns:
            Dict with the batched ``x``, ``obj_cond``, ``edge_cond`` and
            ``relation_cond`` tensors.
        """
        x_batch = torch.stack([item['x'] for item in batch], dim=0)
        obj_cond_batch = torch.cat([item['obj_cond'] for item in batch], dim=0)
        relation_cond_batch = torch.cat([item['relation_cond'] for item in batch], dim=0)

        # Shift each graph's edge indices so they address its own rows in obj_cond_batch.
        shifted_edges = []
        node_offset = 0
        for item in batch:
            shifted_edges.append(item['edge_cond'] + node_offset)
            node_offset += item['obj_cond'].shape[0]
        edge_cond_batch = torch.cat(shifted_edges, dim=1)

        return {
            'x': x_batch,
            'obj_cond': obj_cond_batch,
            'edge_cond': edge_cond_batch,
            'relation_cond': relation_cond_batch
        }


In [10]:
import json

# Read the train / validation scenes from their JSON dumps.
def _read_scenes(json_path):
    """Return the value stored under the 'scenes' key of the JSON file at ``json_path``."""
    with open(json_path, 'r') as handle:
        return json.load(handle)['scenes']


train_data = _read_scenes('datasets/data/train.json')
val_data = _read_scenes('datasets/data/val.json')

# The test split is not available yet:
# test_data = _read_scenes('datasets/data/test.json')

In [11]:
# Number of scenes per batch
B = 3

# Scene hyperparameters
N = 20  # objects per scene
D = 15  # feature dimension of each scene object

# Conditioning hyperparameters
C = 300  # node-feature dimension
R = 23 + 1  # number of relations

hparams = dict(
    batch_size=B,  # graphs per batch
    # NOTE(review): the original note said this "must be a divisor of 300",
    # but 29 does not divide 300 - confirm the actual constraint.
    layer_2_dim=29,

    # --- RGCN hyperparameters ---
    # String form of an empty tuple; alternative noted by the author: (C+D, C+D, D)
    rgc_hidden_dims="()",
    rgc_num_bases=5,  # alternative: None
    rgc_aggr='mean',
    rgc_activation='tanh',
    rgc_dp_rate=0.0,
    rgc_bias=True,

    # --- Attention hyperparameters ---
    attention_self_head_dims=10,
    attention_num_heads=3,
    attention_cross_head_dims=30,

    # --- Scheduler hyperparameters ---
    scheduler_timesteps=1000,
    scheduler_loss='l2',
    scheduler_beta_schedule='cosine',
    # Not needed for now:
    #   scheduler_sampling_timesteps=None,
    #   scheduler_objective='pred_noise',
    #   scheduler_ddim_sampling_eta=1.0,
    #   scheduler_min_snr_loss_weight=False,
    #   scheduler_min_snr_gamma=5,

    # --- Classifier-free guidance ---
    cfg_cond_drop_prob=0.0,

    # --- Training and optimizer hyperparameters ---
    epochs=2000,
    optimizer_lr=1e-3,
    optimizer_weight_decay=5e-5,
    lr_scheduler_factor=0.8,
    lr_scheduler_patience=20,
    lr_scheduler_minlr=8e-5,
)


In [12]:
# Parameter bundles derived from the constants and `hparams` defined above.
general_params = dict(
    num_obj=N,
    obj_cond_dim=C,
)

# Note the key rename: hparams uses the plural "*_head_dims", these use "*_head_dim".
attention_params = dict(
    attention_self_head_dim=hparams['attention_self_head_dims'],
    attention_num_heads=hparams['attention_num_heads'],
    attention_cross_head_dim=hparams['attention_cross_head_dims'],
)

# RGCN settings are copied verbatim from hparams; the relation count comes from R.
rgc_params = dict(
    rgc_hidden_dims=hparams['rgc_hidden_dims'],
    rgc_num_relations=R,
    rgc_num_bases=hparams['rgc_num_bases'],
    rgc_aggr=hparams['rgc_aggr'],
    rgc_activation=hparams['rgc_activation'],
    rgc_dp_rate=hparams['rgc_dp_rate'],
    rgc_bias=hparams['rgc_bias'],
)

In [13]:
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.loader import DataLoader

# Select the compute device: CUDA when available, CPU otherwise.
# MPS is deliberately not used: not all required operations support it yet.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# --- Load the dataset-wide range matrix onto the device
range_matrix = DatasetConstants.get_range_matrix().to(device)

# --- Instantiate the model
model = GuidedDiffusionNetwork(
    layer_1_dim=D,
    layer_2_dim=hparams['layer_2_dim'],
    general_params=general_params,
    attention_params=attention_params,
    rgc_params=rgc_params,
    cond_drop_prob=hparams['cfg_cond_drop_prob']
)

# Load the trained weights. `map_location` remaps the stored tensors onto the
# selected device, so a checkpoint saved on a GPU also loads on a CPU-only host.
model.load_state_dict(torch.load('models/overfit-model.pt', map_location=device))

print(f"Model:\n{model}")

# --- Wrap the model in the diffusion scheduler
# NOTE(review): only the columns from index C onward of the range matrix are
# passed on. If get_range_matrix() returns fewer than C + 1 columns this slice
# is empty, which may be related to the "quantile() input tensor must be
# non-empty" error recorded for the sampling cell - TODO confirm the shape.
scheduler = DDPMScheduler(
    model=model,
    N=N,
    D=D,
    range_matrix = range_matrix[:, C:],
    timesteps=hparams['scheduler_timesteps'],
    sampling_timesteps=None,
    loss_type=hparams['scheduler_loss'],
    objective='pred_noise',
    beta_schedule=hparams['scheduler_beta_schedule'],
    ddim_sampling_eta=1.0,
    min_snr_loss_weight=False,
    min_snr_gamma=5
)

print(f"DDPM Scheduler:\n{scheduler}")

# Move both modules to the selected device
model = model.to(device)
scheduler = scheduler.to(device)

# Switch both modules to evaluation mode for inference
model.eval()
scheduler.eval()

Model:
GuidedDiffusionNetwork(
  (block1): GuidedDiffusionBlock(
    (time_embedding_module): TimeEmbedding()
    (max_pool): MaxPool1d(kernel_size=(20,), stride=(20,), padding=0, dilation=1, ceil_mode=False)
    (rgc_module): RelationalRGCN(
      (layers): ModuleList(
        (0): RGCNConv(15, 15, num_relations=24)
        (1): Tanh()
      )
    )
    (self_attention_module): SelfMultiheadAttention(
      (qkv_proj): Linear(in_features=15, out_features=90, bias=False)
      (o_proj): Linear(in_features=30, out_features=15, bias=False)
      (layer_norm): LayerNorm((20, 15), eps=1e-05, elementwise_affine=True)
    )
    (cross_attention_module): CrossMultiheadAttention(
      (q_proj): Linear(in_features=15, out_features=90, bias=False)
      (kv_proj): Linear(in_features=300, out_features=180, bias=False)
      (o_proj): Linear(in_features=90, out_features=15, bias=False)
      (layer_norm): LayerNorm((20, 15), eps=1e-05, elementwise_affine=True)
    )
  )
  (linear1): Linear(in_fea

DDPMScheduler(
  (model): GuidedDiffusionNetwork(
    (block1): GuidedDiffusionBlock(
      (time_embedding_module): TimeEmbedding()
      (max_pool): MaxPool1d(kernel_size=(20,), stride=(20,), padding=0, dilation=1, ceil_mode=False)
      (rgc_module): RelationalRGCN(
        (layers): ModuleList(
          (0): RGCNConv(15, 15, num_relations=24)
          (1): Tanh()
        )
      )
      (self_attention_module): SelfMultiheadAttention(
        (qkv_proj): Linear(in_features=15, out_features=90, bias=False)
        (o_proj): Linear(in_features=30, out_features=15, bias=False)
        (layer_norm): LayerNorm((20, 15), eps=1e-05, elementwise_affine=True)
      )
      (cross_attention_module): CrossMultiheadAttention(
        (q_proj): Linear(in_features=15, out_features=90, bias=False)
        (kv_proj): Linear(in_features=300, out_features=180, bias=False)
        (o_proj): Linear(in_features=90, out_features=15, bias=False)
        (layer_norm): LayerNorm((20, 15), eps=1e-05, elem

In [14]:
# Wrap each split in the project dataset and serve it one graph at a time,
# in shuffled order.
_loader_options = dict(batch_size=1, shuffle=True)

train_dataset = ScenesDataset(train_data)
train_dataloader = DataLoader(train_dataset, **_loader_options)

val_dataset = ScenesDataset(val_data)
val_dataloader = DataLoader(val_dataset, **_loader_options)

In [15]:
def generate_semseg_file(sampled_scene, scan_id=0):
    """Write the first scene of a sampled batch to a ``*_semseg.v2.json`` file.

    Only batch element 0 is exported, and only the last 15 features of each
    object are used. They are split as:

    * ``[0:3]``   -> ``obb.centroid``
    * ``[3:12]``  -> ``obb.normalizedAxes``
    * ``[12:15]`` -> ``obb.axesLengths``

    One entry is written per object row of the scene. The file is saved as
    ``datasets/data/gen/{scan_id}_semseg.v2.json``; the folder is created
    when it does not exist yet.

    Args:
        sampled_scene: tensor indexed as ``[batch, object, feature]`` with at
            least 15 features per object.
        scan_id: identifier stored in the file and used as file-name prefix.
    """
    # Local import: `os` is only imported in a later cell of this notebook.
    import os

    # Take the first element in the batch and only the last 15 dimensions of it
    filtered_scene = sampled_scene[0, :, -15:]

    objs = []
    for obj_features in filtered_scene:
        label = 'unknown' # TODO: generate label from neighborhood search from the embeddings

        objs.append({
            'obb': {
                'centroid': obj_features[0:3].tolist(),
                'normalizedAxes': obj_features[3:12].tolist(),
                'axesLengths': obj_features[12:15].tolist()
            },
            'label': label,
            'dominantNormal': [0, 0, 0], # not used for now
        })

    # Store the sampled scene to visualize using DVIS
    encoded_scene = {
        'scan_id': scan_id,
        'segGroups': objs, # TODO: add segGroups
    }

    # Make sure the output folder exists before writing; open(..., 'w') alone
    # raises FileNotFoundError when it is missing.
    output_dir = 'datasets/data/gen'
    os.makedirs(output_dir, exist_ok=True)

    with open(f'{output_dir}/{scan_id}_semseg.v2.json', 'w') as file:
        json.dump(encoded_scene, file, indent=2)

In [16]:
# Draw one validation batch and sample a single scene conditioned on its graph.
# NOTE(review): the output saved with this cell shows
# "RuntimeError: quantile() input tensor must be non-empty" raised during sampling.
for batch in val_dataloader:
    num_graphs = batch.num_graphs

    # Node features come flattened over the batch; keep the columns from C
    # onward and restore the per-scene layout [B, N, D].
    x_batch = batch.x.to(device)[:, C:].view(num_graphs, N, D)
    # Object conditioning, reshaped from [B*N, C] to [B, N, C].
    obj_cond_batch = batch.cond.to(device).view(num_graphs, N, C)
    # Edge indices and edge attributes are passed to the sampler unchanged.
    edge_cond_batch = batch.edge_index.to(device)
    relation_cond_batch = batch.edge_attr.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Sample from the model (use the same conditioning as the overfitting)
        sampled_scene = scheduler.sample(
            obj_cond_batch,
            edge_cond_batch,
            relation_cond_batch,
            cond_scale=5.0,
            return_all_samples=False,
        )

    # Only the first batch is used.
    break

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

RuntimeError: quantile() input tensor must be non-empty

In [None]:
# Export the scene sampled in the previous cell; with scan_id=0 the file is
# written as datasets/data/gen/0_semseg.v2.json.
generate_semseg_file(sampled_scene, scan_id=0)

In [None]:
# Draw one validation batch and sample again, this time asking the scheduler
# for every intermediate sample instead of only the final one.
for batch in val_dataloader:
    num_graphs = batch.num_graphs

    # Node features come flattened over the batch; keep the columns from C
    # onward and restore the per-scene layout [B, N, D].
    x_batch = batch.x.to(device)[:, C:].view(num_graphs, N, D)
    # Object conditioning, reshaped from [B*N, C] to [B, N, C].
    obj_cond_batch = batch.cond.to(device).view(num_graphs, N, C)
    # Edge indices and edge attributes are passed to the sampler unchanged.
    edge_cond_batch = batch.edge_index.to(device)
    relation_cond_batch = batch.edge_attr.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Sample from the model (use the same conditioning as the overfitting)
        # (!) NOTICE: this will return all the samples from the scheduler
        sampled_scenes = scheduler.sample(
            obj_cond_batch,
            edge_cond_batch,
            relation_cond_batch,
            cond_scale=5.0,
            return_all_samples=True,
        )

    # Only the first batch is used.
    break

In [None]:
# Write one semseg file per sampled scene, using the paired value `t` as the
# scan id (and therefore as the file-name prefix).
# NOTE(review): this unpacking assumes `sampled_scenes` iterates as
# (t, scene) pairs - TODO confirm against scheduler.sample(return_all_samples=True);
# if it yields plain scenes, enumerate() would be needed here.
for t, sampled_scene in sampled_scenes:
    generate_semseg_file(sampled_scene, scan_id=t)

## DVIS Visualizer

In [None]:
from dvis import dvis
from mathutils import Matrix
import numpy as np
from scipy.spatial.transform import Rotation
import os
import json

In [None]:
def encode_rotation(normalized_axes, rotation_angle, rotation_axis):
    """Right-multiply ``normalized_axes`` by an elementary rotation matrix.

    Args:
        normalized_axes: array-like that can be multiplied with a 3x3 matrix.
        rotation_angle: rotation angle in degrees.
        rotation_axis: coordinate axis to rotate about: 'x', 'y' or 'z'.

    Returns:
        ``np.dot(normalized_axes, R)`` where ``R`` is the rotation matrix for
        the requested axis and angle.

    Raises:
        ValueError: if ``rotation_axis`` is not 'x', 'y' or 'z'.
    """
    # The trigonometric terms are shared by all three matrices.
    theta = np.deg2rad(rotation_angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    if rotation_axis == 'x':
        rows = [
            [1, 0, 0],
            [0, cos_t, -sin_t],
            [0, sin_t, cos_t],
        ]
    elif rotation_axis == 'y':
        rows = [
            [cos_t, 0, sin_t],
            [0, 1, 0],
            [-sin_t, 0, cos_t],
        ]
    elif rotation_axis == 'z':
        rows = [
            [cos_t, -sin_t, 0],
            [sin_t, cos_t, 0],
            [0, 0, 1],
        ]
    else:
        raise ValueError("Invalid rotation axis. Supported values are 'x', 'y', and 'z'.")

    return np.dot(normalized_axes, np.array(rows))

def translate_corners(corners, translation):
    """Return ``corners`` shifted by ``translation`` (plain ``+``, inputs are not modified)."""
    return corners + translation

# Unit cube corners, enumerated with z changing fastest, then x, then y.
unit_cube_corners = np.array([
    [x, y, z]
    for y in (0, 1)
    for x in (0, 1)
    for z in (0, 1)
])

# Re-centre the cube on the origin.
centroid = unit_cube_corners.mean(axis=0)
unit_cube_corners = unit_cube_corners - centroid

# Reference orientation of the unit cube: the identity axes.
normalized_axes = np.eye(3, dtype=int)

# Rotation applied to the reference axes (angle in degrees, about `axis`).
rotation_angle, axis = 0, 'y'
encoded_normalized_axes = encode_rotation(normalized_axes, rotation_angle, axis)

# Rotate the corners with the encoded axes ...
unit_cube_corners = unit_cube_corners @ encoded_normalized_axes

# ... and then shift them (a zero translation here).
translation = np.array([0, 0, 0])
unit_cube_corners = translate_corners(unit_cube_corners, translation)

# dvis(unit_cube_corners, 'corners', c=-1)

In [None]:
# Root folder of the dataset; visualize_scene() joins it with a scene id to
# locate that scene's semseg files.
dataset_path = 'datasets/data'

def generate_corners(obj):
    """Compute the 8 corner points of an object's oriented bounding box.

    Args:
        obj: mapping with an ``'obb'`` entry that holds ``'axesLengths'``,
            ``'centroid'`` and ``'normalizedAxes'`` (reshaped to 3x3).

    Returns:
        ``(8, 3)`` float array. Row ``i`` is
        ``normalized_axes @ (unit_cube_corners[i] * axes_lengths) + centroid``,
        computed from the module-level ``unit_cube_corners``.
    """
    obb = obj['obb']
    extents = np.array(obb['axesLengths'])
    center = np.array(obb['centroid'])
    rotation = np.array(obb['normalizedAxes']).reshape(3, 3)

    # Exchange the y and z components of the lengths and the centroid.
    # The axes themselves are NOT swapped (original TODO: rotation is off).
    yz_swapped = [0, 2, 1]
    extents = extents[yz_swapped]
    center = center[yz_swapped]

    # Scale each unit-cube corner, rotate it, then move it to the centroid.
    return np.array(
        [np.dot(rotation, unit_cube_corners[i] * extents) + center for i in range(8)],
        dtype=float,
    )

def visualize_scene(scene_id, t=0):
    """Send every object of a stored scene to the DVIS visualizer.

    Reads ``{dataset_path}/{scene_id}/{t}_semseg.v2.json``, computes the
    bounding-box corners of each entry in ``segGroups`` with
    ``generate_corners`` and passes them to ``dvis``. Objects that share a
    label share a color index: the index of the first object with that label.

    Args:
        scene_id: name of the scene folder inside ``dataset_path``.
        t: prefix of the semseg file to load; also forwarded to ``dvis`` as ``t``.

    Raises:
        FileNotFoundError: if the semseg file does not exist. (Previously the
            function called ``exit(1)``, which terminates the whole
            interpreter / notebook kernel instead of reporting the problem.)
    """
    scan_folder_path = os.path.join(dataset_path, scene_id)

    # Check if the folder contains the requested semseg.v2.json file
    semseg_file = os.path.join(scan_folder_path, f'{t}_semseg.v2.json')
    if not os.path.isfile(semseg_file):
        raise FileNotFoundError(f"Semantic segmentation file not found: {semseg_file}")

    # Read and parse the semseg.v2.json file
    with open(semseg_file, 'r') as file:
        semseg_data = json.load(file)

    seg_groups = semseg_data['segGroups']

    # label -> color index of the first object that carried this label
    colors_labels_map = {}

    for col_index, obj in enumerate(seg_groups):
        # Optional filters kept from the original for reference:
        # if obj['dominantNormal'][0] != 0:
        #     continue
        # if obj['label'] in ['wall', 'floor', 'ceiling']:
        #     continue

        corners = generate_corners(obj)

        label = obj['label']
        colors_labels_map.setdefault(label, col_index)

        # Pass the corners to the visualizer
        dvis(corners, "corners", name=label, c=colors_labels_map[label], t=t)

In [None]:
# Folder (inside dataset_path) that holds the generated semseg files
scene_id = 'gen'

# Single scene visualization (loads the t=0 file)
# visualize_scene(scene_id)

# Visualize every 20th stored timestep from 0 up to (but excluding) 500.
# Each t must have a matching '{t}_semseg.v2.json' file in the 'gen' folder.
# NOTE(review): the original comment describes this as "reverse order (clean to
# noisy)"; which end of the range is the clean sample depends on how the
# scheduler numbers its samples - TODO confirm.
for t in range(0, 500, 20):
    visualize_scene(scene_id, t=t)