In [1]:
import graph_tool as gt
import os
import pathlib
import warnings
import numpy as np

import random
import pickle

import torch
torch.cuda.empty_cache()
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from src.diffusion import diffusion_utils

import src.utils
from src.metrics.abstract_metrics import TrainAbstractMetricsDiscrete, TrainAbstractMetrics

from src.diffusion_model import LiftedDenoisingDiffusion
from src.diffusion_model_discrete import DiscreteDenoisingDiffusion
from src.diffusion.extra_features import DummyExtraFeatures, ExtraFeatures
import src.utils
from torch_geometric.utils import  to_dense_batch
from src.datasets.schenker_dataset import SchenkerDiffHeteroGraphData
import torch.nn.functional as F
from src.schenker_gnn.config import DEVICE

# Runtime settings for this session.
torch.set_float32_matmul_precision('medium')  # allow lower-precision float32 matmuls for speed
warnings.filterwarnings("ignore", category=PossibleUserWarning)  # silence Lightning's PossibleUserWarning category


In [2]:
import os
import torch.distributed as dist

# Set environment variables required by the env:// init_method
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

if not dist.is_initialized():
    dist.init_process_group(backend="gloo", init_method="env://", rank=0, world_size=1)


In [3]:
from src.datasets.schenker_dataset import SchenkerGraphDataModule, SchenkerDatasetInfos
from src.analysis.spectre_utils import PlanarSamplingMetrics, SBMSamplingMetrics, Comm20SamplingMetrics
from src.analysis.visualization import NonMolecularVisualization

from hydra import initialize, compose
from omegaconf import OmegaConf

# Compose the Hydra config named "config" from ../SchenkerDiff/configs (version_base 1.3).
with initialize(config_path="../SchenkerDiff/configs", version_base="1.3"):
    cfg = compose(config_name="config")

# Data module and the dataset-level objects the diffusion model is constructed with.
dataset_config = cfg.dataset
datamodule = SchenkerGraphDataModule(cfg)
sampling_metrics = PlanarSamplingMetrics(datamodule)
dataset_infos = SchenkerDatasetInfos(datamodule, dataset_config)

is_discrete = cfg.model.type == 'discrete'
if is_discrete:
    train_metrics = TrainAbstractMetricsDiscrete()
else:
    train_metrics = TrainAbstractMetrics()
visualization_tools = NonMolecularVisualization()

# Extra features are only built for the discrete model, and only when configured.
use_extra_features = is_discrete and cfg.model.extra_features is not None
extra_features = (ExtraFeatures(cfg.model.extra_features, dataset_info=dataset_infos)
                  if use_extra_features else DummyExtraFeatures())
domain_features = DummyExtraFeatures()

dataset_infos.compute_input_output_dims(datamodule=datamodule, extra_features=extra_features,
                                        domain_features=domain_features)

# Keyword arguments forwarded to the model constructor when loading the checkpoint.
model_kwargs = dict(
    dataset_infos=dataset_infos,
    train_metrics=train_metrics,
    sampling_metrics=sampling_metrics,
    visualization_tools=visualization_tools,
    extra_features=extra_features,
    domain_features=domain_features,
)



In [4]:
from pprint import pprint
pprint(cfg.dataset)  # show the dataset section of the composed config

{'name': 'schenker', 'remove_h': None, 'datadir': 'data/schenker/processed/heterdatacleaned/'}


In [5]:
# Restore the trained discrete diffusion model; model_kwargs passes the dataset info,
# metrics, feature extractors and visualization tools built in the setup cell.
checkpoint_path = "last-v1.ckpt"
loaded_model = DiscreteDenoisingDiffusion.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    **model_kwargs,
)



Marginal distribution of the classes: tensor([9.9998e-07, 9.9998e-07, 2.0061e-01, 9.9998e-07, 9.9998e-07, 9.9998e-07,
        3.1970e-02, 1.5754e-01, 2.7365e-02, 4.5067e-02, 9.0963e-02, 8.3377e-02,
        1.7351e-01, 9.9998e-07, 1.4494e-01, 4.4652e-02, 9.9998e-07, 9.9998e-07]) for nodes, tensor([7.5582e-01, 2.8959e-02, 2.8959e-02, 2.2831e-02, 2.2831e-02, 4.7389e-02,
        4.0524e-02, 4.0528e-02, 1.1330e-02, 1.5548e-05, 1.5548e-05, 2.8430e-04,
        2.3988e-04, 2.3988e-04, 3.5537e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]) for edges


In [None]:
# Run this if you want to fine tune: 

# name = cfg.general.name
# use_gpu = cfg.general.gpus > 0 and torch.cuda.is_available()
# datamodule_tune = SchenkerGraphDataModule(cfg, is_tune = True)
# trainer = Trainer(gradient_clip_val = cfg.train.clip_grad,
#                 strategy = "ddp_find_unused_parameters_true",  # Needed to load old checkpoints
#                 accelerator = 'gpu' if use_gpu else 'cpu',
#                 devices = cfg.general.gpus if use_gpu else 1,
#                 max_epochs = cfg.train.n_epochs*2,
#                 check_val_every_n_epoch = cfg.general.check_val_every_n_epochs,
#                 fast_dev_run = cfg.general.name == 'debug',
#                 enable_progress_bar = True,
#                 log_every_n_steps = 50 if name != 'debug' else 1,
#                 logger = [])

# trainer.fit(loaded_model, datamodule = datamodule_tune, ckpt_path = cfg.general.resume)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


: 

: 

In [7]:
# --- Sampling configuration ---
batch_size = 100          # pieces generated per run
keep_chain = 10           # how many of them get their denoising chain recorded
number_chain_steps = 99   # frames recorded per chain (must be < loaded_model.T)
save_final = 10           # final samples to visualize (only read by commented-out code)
use_rules = True          # use sample_p_zs_given_zt_with_rules instead of sample_p_zs_given_zt
num_nodes = None          # replaced in the generation cell by node counts taken from the data

In [8]:
def _load_dense_piece(idx, data_dir):
    """Load processed piece `idx` and return (E_sample, r_sample, name, num_nodes), unpadded."""
    file_path = f"{data_dir}/{idx}_processed.pt"
    # NOTE: torch.load unpickles its input; only point data_dir at trusted files.
    data_dict = torch.load(file_path)
    data = SchenkerDiffHeteroGraphData.hetero_to_data(data_dict)

    m = data.x.shape[0]  # number of nodes in this piece

    # Dense (m, m, edge_dim) edge tensor. Writes are sequential so that, for a repeated
    # (u, v) pair, the last edge in edge_index wins (vectorised index assignment does not
    # guarantee an order for duplicate indices).
    edge_dim = data.edge_attr.shape[1]
    E_sample = torch.zeros((m, m, edge_dim))
    for k, (u, v) in enumerate(zip(*data.edge_index.tolist())):
        if u < m and v < m:  # ignore edges that reference nodes beyond x
            E_sample[u, v, :] = data.edge_attr[k, :]

    # Node-level rhythm/depth features, copied into a default-dtype tensor (first m rows).
    dr = data.r.shape[1]
    r_sample = torch.zeros((m, dr))
    r_sample[:m, :] = data.r[:m, :]

    return E_sample, r_sample, data_dict['name'], m


def _pad_node_dims(tensor, max_nodes, n_node_dims):
    """Zero-pad the leading `n_node_dims` (node) dimensions of `tensor` up to `max_nodes`."""
    deficit = max_nodes - tensor.shape[0]
    # F.pad takes (left, right) pairs starting from the LAST dimension: leave the trailing
    # feature dimension(s) untouched and pad only the right side of each node dimension.
    pad = [0, 0] * (tensor.dim() - n_node_dims) + [0, deficit] * n_node_dims
    return F.pad(tensor, tuple(pad))


def sample_r_E(batch_size, data_dir="data/schenker/processed/heterdatacleaned/processed",
               n_samples=1780, n_test=200, seed=42):
    """
    Load `batch_size` held-out pieces and return their dense edge and rhythm tensors.

    A fixed pool of `n_test` piece indices is drawn without replacement from
    ``range(n_samples)`` with a local ``np.random.RandomState(seed)``, and the first
    `batch_size` indices of that pool are loaded, so repeated calls return the same
    pieces. With the default seed the pool equals the one produced by
    ``np.random.seed(42); np.random.choice(...)``, but the global NumPy RNG is not touched.

    Each piece is read from ``{data_dir}/{idx}_processed.pt`` and converted with
    ``SchenkerDiffHeteroGraphData.hetero_to_data``. The converted object must expose
    ``x`` (one row per node), ``edge_index`` (2, num_edges), ``edge_attr``
    (num_edges, edge_dim) and ``r`` (at least num_nodes rows, dr columns); the loaded
    dict must contain a ``'name'`` entry.

    Args:
        batch_size: number of pieces to load; must satisfy 1 <= batch_size <= n_test.
        data_dir: directory holding the processed ``.pt`` files.
        n_samples: size of the index range the pool is drawn from.
        n_test: size of the fixed held-out pool.
        seed: seed of the local RNG that draws the pool.

    Returns:
        E_tensor: (batch_size, max_nodes, max_nodes, edge_dim); entry [b, u, v] holds the
            attribute of edge (u, v) of piece b, zero elsewhere (including padding).
        r_tensor: (batch_size, max_nodes, dr), zero-padded per piece.
        name_list: the ``'name'`` entry of each loaded dict.
        node_sizes: number of nodes of each piece before padding.

    Raises:
        ValueError: if `batch_size` is outside [1, n_test].
    """
    if not 1 <= batch_size <= n_test:
        raise ValueError(f"batch_size must be in [1, {n_test}], got {batch_size}")

    # Local RNG: same indices as seeding the global RNG, without the global side effect.
    test_indices = np.random.RandomState(seed).choice(n_samples, n_test, replace=False)

    E_list, r_list, name_list, node_sizes = [], [], [], []
    for idx in test_indices[:batch_size]:
        E_sample, r_sample, name, m = _load_dense_piece(int(idx), data_dir)
        E_list.append(E_sample)
        r_list.append(r_sample)
        name_list.append(name)
        node_sizes.append(m)

    # Pad every piece to the largest one in the batch, then stack along a new batch axis.
    max_nodes = max(node_sizes)
    E_tensor = torch.stack([_pad_node_dims(e, max_nodes, 2) for e in E_list], dim=0)
    r_tensor = torch.stack([_pad_node_dims(r, max_nodes, 1) for r in r_list], dim=0)

    return E_tensor, r_tensor, name_list, node_sizes

In [9]:
# Diffusion step budgets per structural depth for the planned depth-wise sampler
# (see the TODO in the generation cell); no active code reads them yet.
num_diff_depth_1 = num_diff_depth_2 = num_diff_depth_3 = 100

In [10]:
# Conditional generation: the edge structure E and rhythm/depth features r of held-out
# pieces come from the data; only the node classes X are denoised below.
E, r, names, n_nodes_list = sample_r_E(batch_size)
print(E.shape)

# Node counts are fixed by the conditioning pieces (E and r are already padded to their
# maximum), so they are neither sampled from node_dist nor taken from the config.
num_nodes = torch.tensor([int(x) for x in n_nodes_list]).to(loaded_model.device)
n_nodes = num_nodes
n_max = torch.max(n_nodes).item()
assert E.shape[1] == n_max, "conditioning tensors must be padded to the largest piece"

# Build the node mask: node_mask[b, k] is True for the first n_nodes[b] nodes of sample b.
arange = torch.arange(n_max, device=loaded_model.device).unsqueeze(0).expand(batch_size, -1)
node_mask = arange < n_nodes.unsqueeze(1)

# Sample a piece, and use the R matrix from that:
# get a random sample from the data, pass it through Stephen's script to get the S matrix,
# and get the R matrix through the data processing (process_file_for_GUI).

# Sample noise  -- z has size (n_samples, n_nodes, n_features). Only X and y are kept: the
# edges are fixed by the conditioning pieces, so the noise's E component is discarded.
z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=loaded_model.limit_dist, node_mask=node_mask)
X, _, y = z_T.X, z_T.E, z_T.y

# Symmetrize with an element-wise max so an attribute present on either (u, v) or (v, u)
# appears on both.
E_transpose = E.permute(0, 2, 1, 3)  # (bs, n_max, n_max, edge_dim)

# Put the conditioning tensors on the same device as node_mask and the noise above
# (loaded_model.device); the config DEVICE can differ because the model is never moved.
E = torch.maximum(E, E_transpose).to(loaded_model.device)
r = r.to(loaded_model.device)

assert (E == torch.transpose(E, 1, 2)).all()
assert number_chain_steps < loaded_model.T
# NOTE(review): chain_E has no edge-feature axis, so the loop relies on apply_node_mask_E_r
# returning E without its trailing feature dimension — verify.
chain_X_size = torch.Size((number_chain_steps, keep_chain, X.size(1)))
chain_E_size = torch.Size((number_chain_steps, keep_chain, E.size(1), E.size(2)))

chain_X = torch.zeros(chain_X_size)
chain_E = torch.zeros(chain_E_size)

# Reverse diffusion: sample p(z_s | z_t) for t = T, ..., 1 with s = t - 1. Only X (and y)
# are updated; E and r stay fixed as conditioning.
# load_from_checkpoint returns the module in training mode, so switch to eval mode to
# disable dropout and other train-only behaviour while sampling.
loaded_model.eval()
sample_step = (loaded_model.sample_p_zs_given_zt_with_rules if use_rules
               else loaded_model.sample_p_zs_given_zt)

for s_int in reversed(range(0, loaded_model.T)):
    s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
    t_array = s_array + 1
    s_norm = s_array / loaded_model.T
    t_norm = t_array / loaded_model.T

    # Sample z_s
    sampled_s, discrete_sampled_s = sample_step(s_norm, t_norm, X, E, r, y, node_mask)
    X, _, y = sampled_s.X, sampled_s.E, sampled_s.y

    # E is never resampled, so the chain records the masked conditioning edges.
    discrete_sampled_s_E, _ = loaded_model.apply_node_mask_E_r(E, r, node_mask)

    # Save the first keep_chain graphs; the T steps are bucketed into number_chain_steps frames.
    write_index = (s_int * number_chain_steps) // loaded_model.T
    chain_X[write_index] = discrete_sampled_s.X[:keep_chain]
    chain_E[write_index] = discrete_sampled_s_E[:keep_chain]

# Apply the node mask to the final sample (collapse=True; presumably one-hot -> class
# indices as in DiGress' PlaceHolder.mask — verify).
sampled_s = sampled_s.mask(node_mask, collapse=True)

# unique_depths = sorted({int(d.item()) for d in torch.unique(r[:, :, -1])}, reverse=True)

# # We'll accumulate updates into these tensors
# X_acc = X.clone()
# E_acc = E.clone()
# r_acc = r.clone()


# for depth in unique_depths:
#     steps = 200
#     # Build a mask of nodes at exactly this structural depth
#     depth_mask = (r_acc[:, :, -1] >= depth)

#     # print(depth_mask.shape)
#     # print(node_mask.shape)
#     depth_mask = depth_mask & node_mask
#     if not depth_mask.any():
#         continue

#     # Extract the subgraph at this depth
#     # apply_node_mask_E_r will zero out everything except the masked nodes & edges
#     # E_sub, r_sub = loaded_model.apply_node_mask_E_r(E_acc, r_acc, depth_mask)
#     # X_sub = X_acc.clone()
#     X_sub = X.clone()
#     E_sub = E.clone()
#     r_sub = r.clone()


#     # Run diffusion on this subgraph for `steps` timesteps
#     for s_int in reversed(range(0, loaded_model.T)):
#         s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
#         t_array = s_array + 1
#         s_norm = s_array / loaded_model.T
#         t_norm = t_array / loaded_model.T
#         # print(E_sub.shape)
#         sampled_sub, _ = loaded_model.sample_p_zs_given_zt(
#             s_norm, t_norm, X_sub, E_sub, r_sub, y, depth_mask
#         )
#         # Unpack the results
#         X_sub = sampled_sub.X
#         discrete_sampled_s_E, _ = loaded_model.apply_node_mask_E_r(E_sub, r_sub, depth_mask)

#     # Integrate the updated subgraph back into the full graph
#     # Update node features
#     print( X_acc[depth_mask].shape)
#     print(X_sub[depth_mask].shape)

#     # ensure both sides are floats
#     X_sub = X_sub.to(X_acc.dtype)
#     X_acc[depth_mask] = X_sub[depth_mask]

#     # and likewise for E_sub/E_acc
#     E_sub = E_sub.to(E_acc.dtype)
#     E_acc[:, depth_mask.squeeze(), :][:, :, depth_mask.squeeze()] = E_sub[:, depth_mask.squeeze(), :][:, :, depth_mask.squeeze()]


#     # X_acc[depth_mask] = X_sub[depth_mask]
#     # Update edges among these nodes
#     # E_acc[:, depth_mask, :][:, :, depth_mask] = E_sub[:, depth_mask, :][:, :, depth_mask]
#     # Update rhythm/depth tensor for these nodes
#     r_sub = r_sub.to(r_acc.dtype)
#     r_acc[depth_mask] = r_sub[depth_mask]

# # Replace the originals with the depth‐driven updated versions
# X, E, r = X_acc, E_acc, r_acc


# sampled_s = sampled_sub.mask(depth_mask, collapse=True)

# --- End of structural‐depth‐driven block ---


# TODO: replace the sampling loop above with the following procedure.
# Also remember to implement data processing and the model such that they can take in depth information d.

# For each sampled piece with noisy nodes:
# for each edge type of depth d, iterating from max(d) down to min(d):

# 1. Get only the nodes that are connected by structural edges of depth d, using the last column of the R tensor.
# 2. Use them to create a noisy subgraph defined by X', E', R'. If X', E', R' already exist, append the new nodes and edges to them.
# 3. Feed X', E', R' into the diffusion model with num_diff_step_d steps.
# 4. Save the resulting graph as X', E', R'.
# 5. Repeat until all nodes and edges of the original graph have been added.





# Unpack the final masked sample and mask the conditioning edges the same way.
X, _, y = sampled_s.X, sampled_s.E, sampled_s.y
E, _ = loaded_model.apply_node_mask_E_r(E, r, node_mask)

n_repeat_frames = 10  # extra copies of the final frame, to make the end of a chain easier to see
if keep_chain > 0:
    # Frame 0 was written at s = 0; overwrite it with the final X, E, then reverse the chain
    # so that this frame ends up last.
    chain_X[0] = X[:keep_chain]
    chain_E[0] = E[:keep_chain]
    chain_X = diffusion_utils.reverse_tensor(chain_X)
    chain_E = diffusion_utils.reverse_tensor(chain_E)

    chain_X = torch.cat([chain_X, chain_X[-1:].repeat(n_repeat_frames, 1, 1)], dim=0)
    chain_E = torch.cat([chain_E, chain_E[-1:].repeat(n_repeat_frames, 1, 1, 1)], dim=0)
    assert chain_X.size(0) == number_chain_steps + n_repeat_frames

# One [X, E, r, name] entry per sample, trimmed to its real node count and moved to CPU.
molecule_list = []
for i in range(batch_size):
    n = n_nodes[i]
    molecule_list.append([
        X[i, :n].cpu(),
        E[i, :n, :n].cpu(),
        r[i, :n, :].cpu(),
        names[i],
    ])

# Visualize chains
# if loaded_model.visualization_tools is not None:
#     loaded_model.print('Visualizing chains...')
#     current_path = os.getcwd()
#     num_molecules = chain_X.size(1)       # number of molecules
#     for i in range(num_molecules):
#         result_path = os.path.join(current_path, f'chains/{loaded_model.cfg.general.name}/'
#                                                     f'epoch{loaded_model.current_epoch}/')
#         if not os.path.exists(result_path):
#             os.makedirs(result_path)
#             _ = loaded_model.visualization_tools.visualize_chain(result_path,
#                                                             chain_X[:, i, :].numpy(),
#                                                             chain_E[:, i, :].numpy())
#         loaded_model.print('\r{}/{} complete'.format(i+1, num_molecules), end='', flush=True)
#     loaded_model.print('\nVisualizing molecules...')

#     # Visualize the final molecules 
#     result_path = os.path.join(current_path,
#                                 f'graphs/{loaded_model.name}/epoch{loaded_model.current_epoch}/')
#     loaded_model.visualization_tools.visualize(result_path, molecule_list, save_final)
#     loaded_model.print("Done.")



torch.Size([100, 37, 37, 30])
[-30, -44, -42, -36, -37, -46, -35, -27]
[-32, -33, -33, -34, -29, -37, -35, -39]
[-35, -34, -35, -31, -37, -35, -33, -23]
[-21, -26, -37, -21, -29, -29, -40, -35]
[-28, -39, -35, -23, -31, -29, -26, -35]
[-32, -38, -27, -39, -26, -25, -25, -22]
[-40, -33, -34, -31, -35, -32, -36, -38]
[-30, -40, -18, -39, -31, -40, -28, -39]
[-43, -31, -32, -36, -34, -40, -25, -37]
[-28, -52, -42, -38, -33, -40, -41, -43]
[-29, -41, -34, -26, -30, -38, -31, -25]
[-34, -28, -32, -40, -33, -23, -26, -23]
[-30, -36, -39, -36, -46, -30, -39, -38]
[-31, -27, -29, -42, -25, -38, -46, -20]
[-33, -43, -23, -27, -37, -35, -26, -34]
[-36, -40, -32, -39, -43, -33, -31, -47]
[-41, -27, -26, -31, -41, -38, -32, -31]
[-39, -28, -35, -31, -31, -36, -36, -41]
[-29, -37, -34, -27, -33, -33, -47, -30]
[-34, -35, -36, -27, -33, -41, -34, -38]
[-32, -34, -24, -29, -31, -31, -31, -24]
[-31, -27, -37, -34, -34, -21, -28, -26]
[-37, -25, -34, -30, -37, -25, -32, -34]
[-22, -32, -45, -36, -43, -

In [11]:
# Sanity check: exactly one generated sample per conditioning piece.
assert len(molecule_list) == batch_size, (len(molecule_list), batch_size)
len(molecule_list)

100

In [12]:
# Write every generated sample as plain text: node count, X, E (one row per line),
# R (one row per line), then the piece name followed by a blank line.
samples = molecule_list
filename = 'generated_samples1.txt'

# Explicit UTF-8 so piece names are written identically on every platform.
with open(filename, 'w', encoding='utf-8') as f:
    # Distinct loop names so the global `r` tensor is not overwritten by this cell.
    for atom_types, edge_types, rhythm_types, sample_name in samples:
        f.write(f"N={atom_types.shape[0]}\n")
        f.write("X: \n")
        for at in atom_types.tolist():
            f.write(f"{at} ")
        f.write("\n")
        f.write("E: \n")
        for bond_row in edge_types.tolist():
            for bond in bond_row:
                f.write(f"{int(bond)} ")
            f.write("\n")
        f.write("R: \n")
        for r_row in rhythm_types.tolist():
            for r_val in r_row:
                f.write(f"{r_val} ")
            f.write("\n")
        f.write("".join(sample_name))
        f.write("\n")
        f.write("\n")
print("Generated graphs Saved. ")

Generated graphs Saved. 
