<a href="https://colab.research.google.com/github/gerritgr/Alia/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 💊🌀 MoleculeDiffusionGAN 🌀💊

## Setup

In [None]:
# Identifiers used when naming files and folders on disk.
PROJECT_NAME = "MoldDiffGAN_gcnweak3"
PATH_PATTERN_BASE = "moldiffusion"
PATH_PATTERN = PATH_PATTERN_BASE  # prefix of checkpoints, generated samples and logs

# Run switches:
#   BASELINE = True deactivates the discriminator (plain diffusion baseline).
#   DEBUG = True shrinks datasets, epochs and sample counts for quick test runs.
BASELINE = False
DEBUG = False


### Handle Colab

Missing packages (wandb, PyTorch Geometric and RDKit) are installed automatically.
If running on Colab, Google Drive is mounted and used to store results.

In [None]:
import os
import torch

# Check for Google Colab (its `drive` module only exists there) and WandB.
USE_COLAB = False
try:
  from google.colab import drive
  USE_COLAB = True
except ImportError:
  pass

try:
  import wandb # need to do this before changing CWD
except ImportError:
  os.system("pip install wandb")

# Load Google Drive and work inside the project folder there.
if USE_COLAB:
  if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')
  dir_path = f'/content/drive/MyDrive/colab/{PROJECT_NAME}/'
  os.makedirs(dir_path, exist_ok=True)
  print("Current Working Directory: ", os.getcwd())
  # os.getcwd() has no trailing slash, so compare against the normalized path.
  if os.getcwd() != os.path.normpath(dir_path):
    os.chdir(dir_path)
    print("New Working Directory: ", os.getcwd())


# PyG wheel indices are keyed by torch version and CUDA tag, e.g. "torch-2.1.0+cu118".
# torch.__version__ looks like "2.1.0+cu118"; builds without a "+" suffix use "cpu".
torch_version = torch.__version__.split("+")
TORCH_TAG = torch_version[0]
CUDA_TAG = torch_version[1] if len(torch_version) > 1 else "cpu"
try:
  import torch_geometric
except ImportError:
  os.system(f"pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{TORCH_TAG}+{CUDA_TAG}.html")
  os.system("pip install torch-geometric")

try:
  import rdkit
except ImportError:
  os.system("pip install rdkit")


Mounted at /content/drive
Current Working Directory:  /content
New Working Directory:  /content/drive/MyDrive/colab/MoldDiffGAN_gcn


### Imports

In [None]:
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 100  # Set this to 300 to get better image quality
import seaborn as sns

import networkx as nx
import glob
import random
import os
import traceback
import time
import copy
import pickle
import numpy as np
import math
from tqdm import tqdm
import gzip

from rdkit import Chem
from rdkit.Chem import Draw

import torch
from torch import nn
from torch.optim import Adam
from torch.nn import Sequential as Seq
from torch.nn import Linear as Lin
import torch.nn.functional as F
import torch_geometric
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    PNA,
    GATv2Conv,
    GraphNorm,
    BatchNorm,
    global_mean_pool,
    global_add_pool
)
from torch_geometric.utils import erdos_renyi_graph, to_networkx, from_networkx, degree

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
  DEVICE = torch.device("cuda")
else:
  DEVICE = torch.device("cpu")
DEVICE

device(type='cuda')

In [None]:
# On Colab, fetch the helper package that converts molecules to PyG tensors
# (only once: skipped when the smiles_to_pyg folder already exists).
if USE_COLAB:
  if not os.path.exists("smiles_to_pyg"):
    os.system("git clone https://github.com/gerritgr/MoleculeDiffusionGAN.git && cp -R MoleculeDiffusionGAN/* .")

from smiles_to_pyg.molecule_load_and_convert import *

## Hyperparams

In [None]:
##
## Diffusion
##
TIMESTEPS = 1000   # number of diffusion steps
START = 0.0001     # first beta of the linear noise schedule (see generate_schedule)
END = 0.015        # last beta of the linear noise schedule

# Training
BATCH_SIZE = 256
GAMMA = 0.2        # weight of the discriminator term in the denoiser loss

##
## Prediction/Denoising
##
LEARNING_RATE_GEN = 0.001  # NOTE: train_base_model uses 0.01 * this value
EPOCHS_GEN = 60            # denoiser epochs per training round

# PNA Pred
DROPOUT_PRED = 0.05
DEPTH_PRED = 6
HIDDEN_CHANNELS_PRED = 32  # rounded per tower inside PNAnet
TOWERS_PRED = 2
NORMALIZATION_PRED = True

##
## Discriminator
##
EPOCHS_DISC_MODEL = 50
DISC_NOISE = 0.3   # std of the Gaussian noise added to discriminator inputs

# PNA Disc
HIDDEN_CHANNELS_DISC = 4  # only used by the PNA branch of PNAdisc (its GCN uses 5)
DEPTH_DISC = 3 # 4 in original
DROPOUT_DISC = 0.05 # 0.03 in original
NORMALIZATION_DISC = True

##
## Molecule Encoding
##
INDICATOR_FEATURE_DIM = 1
FEATURE_DIM = 5  # (has to be the same for atom and bond)
ATOM_FEATURE_DIM = FEATURE_DIM
BOND_FEATURE_DIM = FEATURE_DIM
# Boolean masks over the feature columns [indicator | atom features | bond features],
# derived from the dimensions above so they stay in sync if those change.
# NON_NODES is True everywhere except the atom block; NON_EDGES everywhere except the bond block.
NON_NODES = [True] * INDICATOR_FEATURE_DIM + [False] * ATOM_FEATURE_DIM + [True] * BOND_FEATURE_DIM
NON_EDGES = [True] * INDICATOR_FEATURE_DIM + [True] * ATOM_FEATURE_DIM + [False] * BOND_FEATURE_DIM

TIME_FEATURE_DIM = 1  # the timestep is appended to the node features as one column


## Utils

In [None]:
def log(d):
  """Send the metrics dict `d` to wandb.

  If wandb cannot be imported or `wandb.log` fails for any reason, the dict is
  printed instead, so logging never interrupts training.
  """
  try:
    from wandb import log as wandb_log
    wandb_log(d)
  except BaseException:
    print(d)


def load_file(filepath):
  """Return the object pickled into the gzip file `filepath`.

  Any error is printed and then re-raised unchanged. Only use this on files you
  wrote yourself: unpickling can execute arbitrary code.
  """
  print("Trying to read", filepath)
  try:
    with gzip.open(filepath, 'rb') as handle:
      payload = pickle.load(handle)
  except Exception as err:
    print(f"An error occurred: {str(err)}")
    raise
  return payload


def write_file(filepath, data):
  """
  Pickle `data` into the gzip-compressed file `filepath`.

  Objects with a `.cpu()` method (e.g. tensors) are moved to the CPU first;
  anything else (tuples, lists of graphs, ...) is pickled as-is.
  """
  try:
    data = data.cpu()
  except AttributeError:
    pass  # no .cpu(): store the object unchanged
  print("Trying to write", filepath)
  with gzip.open(filepath, 'wb') as f:
    pickle.dump(data, f)


In [None]:
def build_dataset(seed=1234):
  """
  Return the (train, test) lists of cleaned QM9 molecule graphs.

  The split is cached in 'dataset.pickle'. If the cache cannot be loaded, QM9 is
  read with read_qm9(). Graphs with at most one row in `x` are dropped, as are
  graphs whose pyg_to_smiles output contains "None" or raises. The rest is shuffled
  with `seed`, split 80/20 and cached. In DEBUG mode only the first 10% of each
  split is returned, for both the cached and the freshly built dataset. The full
  split is what gets cached.
  """
  def _maybe_shrink(dataset_train, dataset_test):
    # DEBUG runs use only the first tenth of each split.
    if DEBUG:
      return dataset_train[:len(dataset_train) // 10], dataset_test[:len(dataset_test) // 10]
    return dataset_train, dataset_test

  try:
    dataset_train, dataset_test = load_file('dataset.pickle')
  except Exception as e:
    print(f"Could not load dataset due to error: {str(e)}, generate it now")
  else:
    return _maybe_shrink(dataset_train, dataset_test)

  dataset_all = [g for g in read_qm9() if g.x.shape[0] > 1]
  dataset = list()

  for g in tqdm(dataset_all):
    # Explicit check rather than `assert`, so the filter also works under `python -O`.
    try:
      smiles = str(pyg_to_smiles(g))
    except Exception:
      continue  # graphs that cannot be converted are dropped
    if "None" not in smiles:
      dataset.append(g)

  print(f"Built and cleaned dataset, length is {len(dataset)}, old length was {len(dataset_all)}")
  random.Random(seed).shuffle(dataset)
  split = int(len(dataset) * 0.8 + 0.5)
  dataset_train = dataset[:split]
  dataset_test = dataset[split:]
  assert(dataset_train[0].x[0, :].numel() == INDICATOR_FEATURE_DIM + ATOM_FEATURE_DIM + BOND_FEATURE_DIM)

  write_file("dataset.pickle", (dataset_train, dataset_test))
  return _maybe_shrink(dataset_train, dataset_test)


In [None]:
def generate_schedule(start = START, end = END, timesteps=TIMESTEPS):
  """
  Generates the linear beta schedule of the forward (noising) process.

  Args:
  start (float): First beta value. Default is START.
  end (float): Last beta value. Default is END.
  timesteps (int): Number of betas, i.e. diffusion steps. Default is TIMESTEPS.

  Returns:
  torch.Tensor: 1-D tensor of `timesteps` betas on DEVICE, linearly spaced from
  `start` to `end`. Callers derive alphas (1 - beta) and their cumulative products
  (alpha bars) themselves.
  """
  betas = torch.linspace(start, end, timesteps, device = DEVICE)
  # Check against the requested length (not the global TIMESTEPS) so that
  # non-default schedule lengths are accepted.
  assert(betas.numel() == timesteps)
  return betas

In [None]:
def visualize_smiles_from_file(filepath):
    """Render up to the first 100 valid molecules listed in `filepath` as a grid image.

    From every line containing a single quote, the text between the first pair of
    quotes is taken as a SMILES string (e.g. a line like "('CCO', 3)"). Invalid
    SMILES are skipped. The grid (at most 10 x 10) is saved to `filepath + '.jpg'`
    and then uploaded to wandb on a best-effort basis.
    """
    print("Visualize molecules.")
    # Read SMILES from file
    with open(filepath, 'r') as file:
        smiles_list = [line.split("'")[1] for line in file.readlines() if "'" in line]

    # Convert SMILES to RDKit Mol objects, filtering out invalid ones
    mols = [Chem.MolFromSmiles(smile) for smile in smiles_list[:100]]
    mols = [mol for mol in mols if mol is not None]

    if len(mols) == 0:
        return

    # Determine grid size
    num_mols = len(mols)
    cols = 10
    rows = min(10, -(-num_mols // cols))  # ceil division

    # squeeze=False keeps `axs` 2-D even for a single row (<= 10 molecules);
    # otherwise 2-D indexing of the axes fails.
    fig, axs = plt.subplots(rows, cols, figsize=(20, 20), squeeze=False,
                            gridspec_kw={'wspace': 0.3, 'hspace': 0.3})

    # axs.flat walks the grid row by row, so idx == row * cols + col.
    for idx, ax in enumerate(axs.flat):
        ax.axis("off")  # hide every axis, including unused cells of the last row
        if idx < num_mols:
            ax.imshow(Draw.MolToImage(mols[idx], size=(200, 200)))

    # Save the figure, then release it (this runs repeatedly during training).
    fig.savefig(filepath + '.jpg', format='jpg', bbox_inches='tight')
    plt.close(fig)

    time.sleep(0.01)
    try:
        wandb.log_artifact(filepath + '.jpg', name=f"jpg_{SWEEP_ID}_{filepath.replace('.','')}", type="smiles_grid_graph")
    except Exception:
        pass  # upload is best effort (no wandb run / no SWEEP_ID)

In [None]:
def _alphabar_at(future_t, row_num):
  """Return the cumulative alpha product (alpha bar) at each timestep in `future_t` as a (row_num, 1) column."""
  alphas_cumprod = torch.cumprod(1. - generate_schedule(), dim=0)
  return torch.gather(alphas_cumprod, 0, future_t).view(row_num, 1)


def get_pred_from_noise(noise_pred, x_with_noise, future_t):
  """
  Recover the clean-signal estimate x_0 from a noise estimate.

  Inverts x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps for x_0.

  Args:
    noise_pred: noise estimate eps, one row per row of `x_with_noise`.
    x_with_noise: the noisy values x_t.
    future_t: 1-D integer tensor with one timestep per row.
  """
  alphabar_t = _alphabar_at(future_t, x_with_noise.shape[0])
  return (x_with_noise - torch.sqrt(1.0 - alphabar_t) * noise_pred) / torch.sqrt(alphabar_t)


def get_noise_from_pred(original_pred, x_with_noise, future_t):
  """
  Recover the noise estimate eps from a clean-signal estimate x_0.

  Inverts x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps for eps.

  Args:
    original_pred: estimate of x_0, one row per row of `x_with_noise`.
    x_with_noise: the noisy values x_t.
    future_t: 1-D integer tensor with one timestep per row.
  """
  alphabar_t = _alphabar_at(future_t, x_with_noise.shape[0])
  return (x_with_noise - torch.sqrt(alphabar_t) * original_pred) / torch.sqrt(1.0 - alphabar_t)


In [None]:
def log_smiles(smiles, filename):
  """
  Write each entry of `smiles` to `filename` (one `str(entry)` per line), upload the
  file to wandb on a best-effort basis, and render it with visualize_smiles_from_file.

  Errors are printed with a traceback and never propagated.
  """
  try:
    with open(filename, "w") as file:
      file.writelines(str(entry) + "\n" for entry in smiles)

    try:
      wandb.log_artifact(filename, name=f"src_txt_{SWEEP_ID}_{filename}", type="smiles")
    except Exception as e:
      print(e)

    time.sleep(0.01)
    visualize_smiles_from_file(filename)
  except Exception as e:
    print("An error occurred while logging SMILES: \n", str(e))
    traceback.print_exc()


## Forward Process

In [None]:
def forward_diffusion(node_features, future_t):
  """
  Performs the forward diffusion (noising) process on a node_features tensor.

  Samples x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps with
  eps ~ N(0, I), i.e. the second equation from https://youtu.be/a4Yfz2FxXiY?t=649.
  Each row can theoretically have its own future time point.

  Args:
    node_features: 2-D tensor (rows x features) holding the clean values x_0.
    future_t: a Python int/float used for every row, or a tensor with one
      integer timestep per row.

  Returns:
    (noisy_node_features, noise), both shaped like node_features.
  """
  row_num = node_features.shape[0]

  if isinstance(future_t, (int, float)):
    future_t = torch.full((row_num,), int(future_t), dtype=torch.long, device=DEVICE)

  future_t = future_t.view(-1)
  assert row_num == future_t.numel()
  if row_num > 1:
    assert future_t[0] == future_t[1]  # Let's assume they belong to the same graph.

  alphas_cumprod = torch.cumprod(1. - generate_schedule(), dim=0)
  alphabar_t = torch.gather(alphas_cumprod, 0, future_t).view(row_num, 1)

  noise = torch.randn_like(node_features, device=DEVICE)
  # The (row_num, 1) coefficient columns broadcast across all feature columns.
  noisy_node_features = torch.sqrt(alphabar_t) * node_features + torch.sqrt(1. - alphabar_t) * noise
  assert noisy_node_features.shape == node_features.shape

  return noisy_node_features, noise

#forward_diffusion(torch.tensor([1,2,3.], device=DEVICE).view(3,1), torch.tensor([0,0,999], device=DEVICE)), print(""), forward_diffusion(torch.tensor([1,2,3.], device=DEVICE).view(3,1), torch.tensor([999,999,999], device=DEVICE))

## Denoising NN

In [None]:
def dataset_to_degree_bin(train_dataset):
  """
  Convert a dataset to a histogram of node in-degrees: entry k counts the nodes
  with in-degree k (computed from edge_index[1]). Returned as a long tensor on DEVICE.

  The histogram is loaded from 'deg.pickle' if available, otherwise computed and
  written there. NOTE: the cache ignores which dataset is passed, so delete the file
  when the dataset changes.
  """
  try:
    # Attempt to load the degree histogram from a file.
    return load_file('deg.pickle').to(DEVICE)
  except Exception as e:
    print(f"Could not find degree bin due to error: {str(e)}, generate it now")

  # Assert that the dataset is provided.
  assert(train_dataset is not None)

  # Single pass on the CPU. Collect every node's in-degree, then histogram them all
  # at once. This equals summing per-graph bincounts and avoids moving each graph
  # to the GPU (Data.to moves the dataset's tensors in place).
  per_graph_degrees = [
      degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long).cpu()
      for data in train_dataset
  ]
  if per_graph_degrees:
    deg = torch.bincount(torch.cat(per_graph_degrees))
  else:
    deg = torch.zeros(0, dtype=torch.long)

  # Save the computed histogram to a file.
  write_file("deg.pickle", deg)

  return deg.to(DEVICE)


In [None]:
class PNAnet(torch.nn.Module):
  """Denoising network: a PyG PNA graph network followed by a 4-layer MLP head.

  Input per node: the INDICATOR + ATOM + BOND feature columns plus one timestep column.
  Output per node: FEATURE_DIM values (the training loop compares them with the
  clean, un-noised feature entries).
  """
  def __init__(self, train_dataset=None, hidden_channels=HIDDEN_CHANNELS_PRED, depth=DEPTH_PRED, dropout=DROPOUT_PRED, towers=TOWERS_PRED, normalization=NORMALIZATION_PRED, pre_post_layers=1):
    super(PNAnet, self).__init__()
    # Not used in forward().
    self.sigmoid = nn.Sigmoid()

    # Adjust hidden channels for the given towers: round down to a multiple of
    # `towers`, then add one more tower's worth (e.g. 32 -> 34 for 2 towers).
    # Parameter shapes depend on this, so changing it invalidates saved checkpoints.
    hidden_channels = towers * ((hidden_channels // towers) + 1) # must match

    # Calculate input and output channels.
    # Input: all feature columns plus the timestep column appended in forward().
    in_channels = INDICATOR_FEATURE_DIM + ATOM_FEATURE_DIM + BOND_FEATURE_DIM + TIME_FEATURE_DIM
    # Output: one value per diffusable feature entry of a node.
    out_channels = FEATURE_DIM

    # Get degree histogram for the dataset (passed to PNA as `deg`; cached on disk).
    deg = dataset_to_degree_bin(train_dataset)

    # Set aggregators and scalers for the PNA layer.
    aggregators = ['mean', 'min', 'max', 'std']
    scalers = ['identity', 'amplification', 'attenuation']

    # Create a normalization layer if required.
    self.normalization = BatchNorm(hidden_channels) if normalization else None

    # Define the PNA layer (outputs hidden_channels per node; the MLP maps to out_channels).
    self.pnanet = PNA(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        out_channels=hidden_channels,
        num_layers=depth,
        aggregators=aggregators,
        scalers=scalers,
        deg=deg,
        dropout=dropout,
        towers=towers,
        norm=self.normalization,
        pre_layers=pre_post_layers,
        post_layers=pre_post_layers
    )

    # Define the final MLP layer (per-node, no graph interaction).
    self.final_mlp = Seq(
        Lin(hidden_channels, hidden_channels),
        nn.ReLU(),
        Lin(hidden_channels, hidden_channels),
        nn.ReLU(),
        Lin(hidden_channels, hidden_channels),
        nn.ReLU(),
        Lin(hidden_channels, out_channels)
    )

  def forward(self, x_in, t, edge_index):
    """
    Perform a forward pass through the PNAnet.

    Args:
      x_in: node feature matrix with one row per node.
      t: timestep per node, with one entry per row of x_in (reshaped to a column).
        The values are appended unscaled as an extra feature column.
      edge_index: graph connectivity passed to PNA.

    Returns:
      Tensor with one row per node and FEATURE_DIM columns.
    """
    row_num = x_in.shape[0]
    t = t.view(-1, TIME_FEATURE_DIM)
    x = torch.concat((x_in, t), dim=1)

    x = self.pnanet(x, edge_index)
    x = self.final_mlp(x)

    # Assertions for sanity checks
    assert(x.numel() > 1)
    assert(x.shape[0] == row_num)

    return x


In [None]:
def load_latest_checkpoint(model, optimizer, loss_list, epoch_i, path_pattern_checkpoint=None):
  """
  Load the latest checkpoint from the disk.

  The checkpoint is the lexicographically last file matching `path_pattern_checkpoint`
  (default: PATH_PATTERN + "_model_epoch_*.pth"). Epoch numbers in file names are
  zero-padded, so this is the highest epoch. Model and optimizer states are loaded in
  place, and the stored epoch and loss history are returned. If no file matches or
  loading fails, the inputs are returned unchanged; errors are printed, not raised.
  If a failure happens part-way, the model may already have been updated.

  Returns:
    (model, optimizer, loss_list, epoch_i)
  """
  if path_pattern_checkpoint is None:
    path_pattern_checkpoint = PATH_PATTERN + "_model_epoch_*.pth"

  try:
    checkpoint_paths = sorted(glob.glob(path_pattern_checkpoint))
    if len(checkpoint_paths) == 0:
      return model, optimizer, loss_list, epoch_i

    latest_checkpoint_path = checkpoint_paths[-1]
    # weights_only=False: the stored loss history contains NumPy floats, which
    # torch>=2.6 rejects under its default weights_only=True. That would silently
    # restart training from scratch. Only load checkpoints you wrote yourself.
    checkpoint = torch.load(latest_checkpoint_path, map_location=DEVICE, weights_only=False)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch_i = checkpoint['epoch']
    loss_list = checkpoint['loss_list']

    print(f"Loaded checkpoint of epoch {epoch_i:08} from disk.")
  except Exception as e:
    print(f"Failed to load checkpoint. Error: {str(e)}")

  return model, optimizer, loss_list, epoch_i

def save_model(model, optimizer, loss_list, epoch_i, upload=False):
  """
  Save the epoch counter, loss history, model and optimizer state to
  f"{PATH_PATTERN}_model_epoch_{epoch_i:08}.pth", and upload the file to wandb
  when `upload` is set. Nothing is written for epoch_i == 0.
  """
  # Epoch 0 is never written (relevant for load_base_model()).
  if epoch_i == 0:
    return

  # Zero-padded epoch => lexicographic file order equals epoch order when loading.
  save_path = f"{PATH_PATTERN}_model_epoch_{epoch_i:08}.pth"
  state = dict(
    epoch=epoch_i,
    loss_list=loss_list,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
  )
  torch.save(state, save_path)

  if not upload:
    return
  try:
    wandb.log_artifact(save_path, name=f"weights_{SWEEP_ID}_{epoch_i:08}_weightfile", type="weight")
  except Exception as e:
    print(f"Failed to upload model. Error: {str(e)}")


In [None]:
def load_base_model(dataset_train, path_pattern_checkpoint=None):
  """Build a PNAnet denoiser on DEVICE and restore its weights from the latest
  checkpoint matching `path_pattern_checkpoint` (default pattern when None).

  The Adam optimizer exists only because the checkpoint loader expects one; it is
  discarded, and only the model is returned.
  """
  model = PNAnet(dataset_train).to(DEVICE)
  restored = load_latest_checkpoint(
    model,
    Adam(model.parameters(), lr = LEARNING_RATE_GEN),
    None,
    epoch_i=0,
    path_pattern_checkpoint=path_pattern_checkpoint,
  )
  return restored[0]

## Inference / Reverse Process

There are two inference methods, _normal_ and _restart_. Only the normal method is implemented in this notebook.

In [None]:
def denoise_one_step(model, g, i):
  """
  Performs one step of denoising using the provided model.

  Step `i` (0 = pure noise) corresponds to timestep t = TIMESTEPS - i - 1. The
  model predicts the clean entries x_0. That prediction is converted to a noise
  estimate eps_hat and plugged into the DDPM posterior mean
      mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_hat).
  For t > 0, Gaussian noise with the posterior variance
      beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t)
  is added.

  Args:
    model: denoiser called as model(x, t, edge_index).
    g: (batched) graph; `g.x_old` must hold the pre-noise features (set by
      overwrite_with_noise) because the update mask is derived from it.
    i: reverse-process step index in [0, TIMESTEPS).

  Returns:
    A copy of g.x with the masked entries replaced by their denoised values.
  """
  row_num = g.x.shape[0]

  # Generate and calculate betas, alphas, and related parameters
  betas = generate_schedule()
  t = TIMESTEPS - i - 1  # i=0 indicates full noise
  beta_t = betas[t]
  alphas = 1. - betas
  alphas_cumprod = torch.cumprod(alphas, axis=0)
  alphas_cumprod_t = alphas_cumprod[t]
  sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1. - alphas_cumprod_t)
  sqrt_recip_alphas_t = torch.sqrt(1.0 / alphas[t])
  # alphabar_{t-1}, with alphabar_{-1} := 1 (shift right by one, pad with 1).
  alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

  # Create the mask: column 0 is never diffused; in the other columns an entry is
  # diffusable where the original value x_old > -0.5 (same mask as overwrite_with_noise).
  mask = torch.concat(
      (torch.tensor([False] * g.x_old.shape[0], device=DEVICE).view(-1, 1),
       g.x_old[:, 1:] > -0.5),
      dim=1
  )

  # Define future_t for the model predictions (float column for the model input)
  future_t = torch.tensor([float(t)] * g.x.shape[0], device=DEVICE).view(-1, 1)
  original_pred = model(g.x, future_t, g.edge_index)

  # Extract noisy values and predict noise (integer timesteps are needed for gather)
  x_with_noise = g.x[mask].view(row_num, -1)
  future_t = torch.tensor([int(t)] * g.x.shape[0], device=DEVICE).view(-1)
  noise_pred = get_noise_from_pred(original_pred, x_with_noise, future_t)

  # Set endpoints values
  values_now = g.x[mask].view(row_num, -1)
  values_endpoint = noise_pred.view(row_num, -1)
  assert values_now.shape == values_endpoint.shape

  # Compute denoised values (posterior mean)
  model_mean = sqrt_recip_alphas_t * (values_now - beta_t * values_endpoint / sqrt_one_minus_alphas_cumprod_t)
  values_one_step_denoised = model_mean  # in case that t == 0

  if t != 0:
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)  # in the paper this is in 3.2. Note that sigma^2 is variance, not std.
    posterior_std_t = torch.sqrt(posterior_variance[t])
    noise = torch.randn_like(values_now, device=DEVICE)
    values_one_step_denoised = model_mean + posterior_std_t * noise

  # Clone and update with denoised values
  denoised_x = g.x.clone()
  denoised_x[mask] = values_one_step_denoised.flatten()

  return denoised_x


In [None]:
def overwrite_with_noise(g):
  """
  Replace the diffusable entries of `g.x` with standard-normal noise, in place.

  The original features are kept in `g.x_old`. Column 0 is never touched; in every
  other column, entries whose original value is > -0.5 are overwritten. Returns `g`.
  """
  g.x_old = g.x.clone()
  first_column = torch.tensor([False] * g.x_old.shape[0], device=DEVICE).view(-1, 1)
  remaining_columns = g.x_old[:, 1:] > -0.5
  mask = torch.cat((first_column, remaining_columns), dim=1)
  g.x[mask] = torch.randn_like(g.x[mask])
  return g


In [None]:
@torch.inference_mode()
def generate_examples(model, dataset_train, num=100, restart_inference_method=False):
  """
  Generate `num` graph samples in one batch using the provided model.

  `num` template graphs are drawn uniformly at random (with replacement) from
  `dataset_train`. Their diffusable feature entries are replaced with noise
  (overwrite_with_noise) and denoised for TIMESTEPS steps with denoise_one_step.
  The graph connectivity comes unchanged from the templates.

  Returns:
    list of `num` generated graphs on the CPU (an empty list for num <= 0).

  Raises:
    NotImplementedError: if restart_inference_method is True; the restart sampler
      is not implemented in this notebook.
  """
  if restart_inference_method:
    # Fail before doing any work instead of deep inside the sampling loop.
    raise NotImplementedError("The restart inference method is not implemented in this notebook.")
  if num <= 0:
    return []

  # Setup
  print("generate samples batched")
  model.eval()
  dataset_train_start = [random.choice(dataset_train).clone().to(DEVICE) for _ in range(num)]
  assert(len(dataset_train_start) == num)

  # batch_size=num: all templates are collated into a single batch.
  g = next(iter(DataLoader(dataset_train_start, batch_size=num))).to(DEVICE)
  print("load g", g, g.batch)
  g = overwrite_with_noise(g)

  # Inference
  for i in tqdm(range(TIMESTEPS)):
    g.x = denoise_one_step(model, g, i)

  graph_list = [graph.cpu() for graph in g.to_data_list()]

  print("generated graphs ", graph_list[:10])
  return graph_list


In [None]:
from tqdm import tqdm

def find_frac_correct(graphs):
  """
  Determine the fraction of correct graphs and the fraction of unique correct ones.

  A graph is correct when pyg_to_smiles returns a non-empty SMILES without '.'
  (a single fragment) that RDKit can parse.

  Returns:
    (frac_correct, smiles_list, unique_frac): the fraction of correct graphs, a list
    of (smiles, index) tuples for the correct graphs, and the number of distinct
    correct SMILES divided by the total number of graphs. For an empty input this
    is (0.0, [], 0.0).
  """
  if len(graphs) == 0:
    return 0.0, [], 0.0

  smiles_list = list()
  for i, g in tqdm(enumerate(graphs), total=len(graphs)):
    smiles = pyg_to_smiles(g)
    if smiles and '.' not in smiles:
      mol = Chem.MolFromSmiles(smiles)
      if mol:
        smiles_list.append((smiles, i))

  frac_correct = len(smiles_list) / len(graphs)
  unique_frac = len({s for s, _ in smiles_list}) / len(graphs)

  return frac_correct, smiles_list, unique_frac


In [None]:
def gen_graphs(num_per_generation=1000, num_generations=40, restart_inference_method=False, model_path=None):
  """
  Generate (or load cached) num_per_generation * num_generations graphs.

  The latest checkpoint matching `model_path` (default:
  PATH_PATTERN + "_model_epoch_*.pth") is used. Results are cached next to it in a
  '..._generated.pickle' file. The cache is also written every 5 generations, so a
  partial cache is resumed. In DEBUG mode num_generations is divided by 10 (at
  least 1 is kept).

  Raises:
    FileNotFoundError: if no checkpoint matches `model_path`.
  """
  if DEBUG:
    num_generations = max(1, num_generations // 10)
  print(f"Generate {num_generations*num_per_generation} graphs.")

  if model_path is None:
    model_path = PATH_PATTERN + "_model_epoch_*.pth"

  checkpoint_paths = sorted(glob.glob(model_path))
  if not checkpoint_paths:
    raise FileNotFoundError(f"No model checkpoint matches {model_path!r}.")
  path = checkpoint_paths[-1]
  num_samples = num_per_generation * num_generations
  filepath = path.replace(".pth", f'_{num_samples:06d}_w{restart_inference_method}_generated.pickle')

  results = list()
  try:
    results = load_file(filepath)
  except Exception:
    pass  # no (readable) cache yet: generate from scratch

  if len(results) >= num_samples:
    return results[:num_samples]

  dataset_base, dataset_base_test = build_dataset()
  model_base = load_base_model(dataset_base, path_pattern_checkpoint=path)

  i = 0
  while len(results) < num_samples:
    i += 1
    # Never overshoot: the last generation only fills the remaining gap.
    num = min(num_per_generation, num_samples - len(results))
    graphs = generate_examples(model_base, dataset_base, num=num, restart_inference_method=restart_inference_method)
    results.extend(graphs)
    if i % 5 == 0 or len(results) >= num_samples:
      write_file(filepath, results)

  assert(len(results) == num_samples)
  return results


def test_graph_generation(path_pattern=None, restart_inference_method=False):
  """Generate graphs from the checkpoint(s) matching `path_pattern` and score them.

  Returns find_frac_correct's (frac_correct, smiles_list, unique_frac).
  """
  return find_frac_correct(
    gen_graphs(model_path=path_pattern, restart_inference_method=restart_inference_method)
  )

## Discriminator NN

In [None]:
from torch_geometric.nn import PNA, GCN

class PNAdisc(torch.nn.Module):
  """Graph discriminator: maps node features to one score in (0, 1) per graph.

  NOTE(review): both a PNA network (`self.pnanet`) and a GCN (`self.gcnnet`) are
  built, but forward() only uses the GCN. The PNA's parameters are still part of the
  state_dict, so removing it would make existing discriminator checkpoints fail to
  load. `hidden_channels` only affects the PNA branch; the GCN uses 5 hidden channels.
  """
  def __init__(self, train_dataset=None, hidden_channels=HIDDEN_CHANNELS_DISC,
               depth=DEPTH_DISC, dropout=DROPOUT_DISC, towers=1,
               normalization=NORMALIZATION_DISC, pre_post_layers=1):
    super(PNAdisc, self).__init__()

    # Applied to the pooled per-graph score in forward().
    self.sigmoid = nn.Sigmoid()

    # Adjust hidden channels based on towers (rounds down, then adds one tower's worth)
    hidden_channels = towers * ((hidden_channels // towers) + 1)

    # Input: the feature columns only (no timestep column, unlike PNAnet).
    in_channels = INDICATOR_FEATURE_DIM + ATOM_FEATURE_DIM + BOND_FEATURE_DIM
    assert in_channels == 11

    deg = dataset_to_degree_bin(train_dataset).to(DEVICE)
    aggregators = ['mean', 'min', 'max', 'std']
    scalers = ['identity', 'amplification', 'attenuation']
    self.normalization = BatchNorm(hidden_channels) if normalization else None
    # Currently unused in forward() (see class docstring).
    self.pnanet = PNA(in_channels=in_channels,
                     hidden_channels=hidden_channels,
                     out_channels=1,
                     num_layers=depth,
                     aggregators=aggregators,
                     scalers=scalers,
                     deg=deg,
                     dropout=dropout,
                     towers=towers,
                     norm=self.normalization,
                     pre_layers=pre_post_layers,
                     post_layers=pre_post_layers)

    # The network actually used in forward(): one output score per node.
    self.gcnnet = GCN(in_channels=in_channels,
                     hidden_channels=5,
                     out_channels=1,
                     num_layers=depth,
                     dropout=dropout)

  def forward(self, x, edge_index, batch=None):
    # Gaussian input noise with std DISC_NOISE is added in both train and eval mode
    # (no self.training check), presumably to keep the discriminator from separating
    # real and generated graphs too easily.
    x = x + torch.randn_like(x) * DISC_NOISE
    x = self.gcnnet(x, edge_index) # or pna
    # Average node scores per graph (`batch` assigns nodes to graphs).
    x = global_mean_pool(x, batch)
    x = self.sigmoid(x)

    return x

In [None]:
def train_epoch_disc(model_disc, dataloader, optimizer):
  """Run one training epoch of the discriminator with binary cross-entropy.

  Returns (mean batch loss, mean per-graph accuracy, elapsed seconds). A prediction
  counts as accurate when it lies within 0.5 of its (soft) label `batch.y`.
  """
  model_disc.train()
  t0 = time.time()
  batch_losses, hits = [], []

  for data in dataloader:
    data = data.to(DEVICE)
    optimizer.zero_grad()
    scores = model_disc(data.x, data.edge_index, data.batch).flatten()
    targets = data.y.flatten()
    loss = F.binary_cross_entropy(scores, targets)
    loss.backward()
    optimizer.step()

    hits += (torch.abs(scores - targets) < 0.5).float().detach().cpu().tolist()
    batch_losses.append(loss.item())

  return np.mean(batch_losses), np.mean(hits), time.time() - t0

In [None]:
def test_disc(model_disc, dataloader):
  """
  Evaluate the discriminator on `dataloader` without updating it.

  Returns (mean batch BCE loss, mean per-graph accuracy, elapsed seconds). A
  prediction counts as accurate when it is within 0.5 of its label `batch.y`.
  """
  model_disc.eval()
  start_time = time.time()
  loss_list = list()
  acc_list = list()
  # Evaluation only: no autograd graph is needed, which saves memory and time.
  with torch.no_grad():
    for batch in dataloader:
      batch = batch.to(DEVICE)
      pred = model_disc(batch.x, batch.edge_index, batch.batch)
      loss = F.binary_cross_entropy(pred.flatten(), batch.y.flatten())
      acc = (torch.abs(pred.flatten()-batch.y.flatten()) < 0.5).float()
      acc_list.extend(acc.cpu().tolist())  # extend: linear instead of quadratic
      loss_list.append(loss.item())

  return np.mean(loss_list), np.mean(acc_list), time.time()-start_time

In [None]:
def train_disc_model(dataloader_disc, dataloader_disc_test, round_i):
  """
  Train the discriminator for round `round_i`, or load it if it was already trained.

  Weights are cached in f"{PATH_PATTERN}_discriminator_model_round_{round_i:05}.pth".
  If that file loads, training is skipped. Otherwise a fresh PNAdisc is trained for
  EPOCHS_DISC_MODEL epochs with Adam (lr 1e-4) and evaluated on
  `dataloader_disc_test` after every epoch. The loss curves are plotted to
  f"discriminator_model_{round_i:05}.png" and stored with the weights.

  Returns the discriminator on DEVICE.
  """
  # NOTE: the dataloader is passed where a dataset is expected; it is only iterated
  # to build the degree histogram when 'deg.pickle' is missing.
  model_disc = PNAdisc(dataloader_disc).to(DEVICE)
  weight_path = f"{PATH_PATTERN}_discriminator_model_round_{round_i:05}.pth"

  try:
    # map_location: GPU-written checkpoints must also load on CPU-only machines.
    # weights_only=False: the file also stores NumPy-float loss histories, which
    # torch>=2.6's default weights_only=True rejects. These are our own files.
    checkpoint = torch.load(weight_path, map_location=DEVICE, weights_only=False)
    model_disc.load_state_dict(checkpoint['model_state_dict'])
    print(f"found disc model in round {round_i:05}")
    return model_disc
  except FileNotFoundError:
    pass  # not trained yet
  except Exception as e:
    print(f"Could not load disc model in round {round_i:05} ({e}), train it now")

  epochs = []
  losses_train = []
  losses_test = []

  optimizer_disc = Adam(model_disc.parameters(), lr=0.0001)
  for epoch_i in range(EPOCHS_DISC_MODEL):
    loss_train, acc_train, t_train = train_epoch_disc(model_disc, dataloader_disc, optimizer_disc)
    loss_test, acc_test, t_test = test_disc(model_disc, dataloader_disc_test)
    print(f"train discriminator: epoch: {epoch_i:05}, loss: {loss_train:.4f}, loss test: {loss_test:.4f}, acc: {acc_train:.3f}, acc test: {acc_test:.3f}, time: {t_train:.3f}")
    log({
        "disc/step": epoch_i + (1+round_i) * EPOCHS_DISC_MODEL,
        "disc/epoch": epoch_i + (1+round_i) * EPOCHS_DISC_MODEL,
        "disc/loss_train": loss_train,
        'disc/loss_test': loss_test,
        "disc/acc_train": acc_train,
        "disc/acc_test": acc_test,
        "disc/time": t_train
    })
    epochs.append(epoch_i)
    losses_train.append(loss_train)
    losses_test.append(loss_test)

  # Plot losses on a dedicated figure (not whatever figure happens to be current).
  fig, ax = plt.subplots()
  ax.plot(epochs, losses_train, label='train')
  ax.plot(epochs, losses_test, label='test')
  ax.set(xlabel="epoch", ylabel="BCE loss", title=f"Discriminator round {round_i}")
  ax.legend()
  fig.savefig(f"discriminator_model_{round_i:05}.png")
  plt.close(fig)

  torch.save({
      'model_state_dict': model_disc.state_dict(),
      'epochs': epochs,
      "losses_train": losses_train,
      "losses_test": losses_test
  }, weight_path)

  return model_disc


In [None]:
def run_disc(round_i=1):
  """
  Build a real-vs-generated dataset and train (or load) the discriminator for `round_i`.

  Generated graphs get the soft label 0.1, and an equally sized random sample of real
  training graphs gets 0.9. The shuffled union is split 80/20 into train/test
  loaders. Returns the discriminator from train_disc_model.
  """
  print(f"Train discriminator round {round_i}.")
  fake_graphs = gen_graphs(restart_inference_method=False)
  dataset_base, dataset_base_test = build_dataset()
  real_graphs = random.sample(dataset_base, len(fake_graphs))

  def _with_label(graph, label):
    # Label a clone so that the shared dataset objects stay untouched.
    labelled_graph = graph.clone()
    labelled_graph.y = torch.tensor(label) # use 0.1 and 0.9 for better stability
    return labelled_graph

  dataset = [_with_label(g, 0.1) for g in fake_graphs] + [_with_label(g, 0.9) for g in real_graphs]
  random.shuffle(dataset)

  split_at = int(len(dataset) * 0.8)
  train_part, test_part = dataset[:split_at], dataset[split_at:]
  dataloader_train = DataLoader(train_part, batch_size = BATCH_SIZE, shuffle=True)
  dataloader_test = DataLoader(test_part, batch_size = BATCH_SIZE, shuffle=True)

  return train_disc_model(dataloader_train, dataloader_test, round_i)


## Train Jointly

In [None]:
def train_epoch(model, dataloader, optimizer, model_disc=None):
  """
  Train the denoiser for one epoch.

  For every batch, one timestep per graph is drawn uniformly from [0, TIMESTEPS).
  The diffusable feature entries are noised with forward_diffusion, and the model
  predicts the clean entries. The loss is the MSE to the clean entries. If a
  discriminator is given, the loss becomes
  (1 - GAMMA) * mse + GAMMA * mean((1 - D(x_pred))^2), where x_pred is the batch
  with the predicted entries filled in.

  The discriminator acts as a fixed critic: it is put into eval mode so that fresh
  and checkpoint-loaded discriminators behave the same way.

  Returns:
    (mean loss, mean discriminator loss, elapsed seconds)
  """
  model.train()
  if model_disc is not None:
    model_disc.eval()
  start_time = time.time()
  loss_list = []
  loss_list_disc = []

  for batch in tqdm(dataloader):
    if batch.x.shape[0] < 2:
      continue

    optimizer.zero_grad()
    batch = batch.to(DEVICE)
    row_num = batch.x.shape[0]

    # One timestep per graph, broadcast to all rows of that graph.
    num_graphs_in_batch = int(torch.max(batch.batch).item() + 1)
    future_t_select = torch.randint(0, TIMESTEPS, (num_graphs_in_batch,), device=DEVICE)
    future_t = torch.gather(future_t_select, 0, batch.batch)
    assert future_t.numel() == row_num

    # Column 0 is never diffused; other entries are diffusable where x > -0.5.
    mask = torch.cat((torch.tensor([False] * row_num, device=DEVICE).view(-1, 1), batch.x[:, 1:] > -0.5), dim=1)
    x_start_gt = batch.x[mask].view(row_num, FEATURE_DIM)
    x_with_noise, _ = forward_diffusion(x_start_gt, future_t)

    x_in = batch.x.clone()
    x_in[mask] = x_with_noise.flatten()
    x_start_pred = model(x_in, future_t, batch.edge_index)
    loss = F.mse_loss(x_start_pred, x_start_gt)

    disc_loss = torch.tensor(0.0, device=DEVICE)
    if model_disc is not None:
      # Build the discriminator input in a fresh tensor rather than writing into
      # x_in, which was already fed to the model.
      x_disc = x_in.clone()
      x_disc[mask] = x_start_pred.flatten()
      disc_loss = torch.mean((1.0 - model_disc(x_disc, batch.edge_index, batch=batch.batch))**2)
      loss = (1.0 - GAMMA) * loss + GAMMA * disc_loss

    loss.backward()
    loss_list.append(loss.item())
    loss_list_disc.append(disc_loss.item())
    optimizer.step()

  return np.mean(loss_list), np.mean(loss_list_disc), time.time() - start_time


In [None]:
def train_base_model(train_loader, epoch_num=EPOCHS_GEN, model_disc=None):
  """
  Train (or resume training of) the denoising model up to epoch `epoch_num`.

  Training resumes from the latest checkpoint matching PATH_PATTERN. At epochs
  divisible by 20 (after the first resumed epoch) and at the final epoch, the model
  is saved and evaluated by generating graphs. Only the final checkpoint is
  uploaded. With `model_disc`, train_epoch adds the adversarial term.
  In DEBUG mode epoch_num is divided by 10.

  Returns the trained PNAnet.
  """
  print("Train denoising model.")
  if DEBUG:
    epoch_num = int(epoch_num / 10)

  dataset_train = train_loader.dataset
  model_base = PNAnet(dataset_train).to(DEVICE)

  optimizer = Adam(model_base.parameters(), lr=LEARNING_RATE_GEN * 0.01) # the multiplication makes no real sense
  loss_list = []
  model_base, optimizer, loss_list, epoch_start = load_latest_checkpoint(model_base, optimizer, loss_list, epoch_i=0)
  epoch_start = min(epoch_start, epoch_num)
  print(f"from {epoch_start} to {epoch_num}")

  for epoch_i in range(epoch_start, epoch_num):
    try:
      loss, loss_disc, time_elapsed = train_epoch(model_base, train_loader, optimizer, model_disc=model_disc)
      loss_list.append((epoch_i, loss))
      # Mean over all recorded epochs; the current loss is already in loss_list.
      mean_loss = np.mean([y for _, y in loss_list])
      print(f"loss in epoch {epoch_i:07} is: {loss:05.4f} with mean loss {mean_loss:05.4f} with disc loss {loss_disc:05.4f} with runtime {time_elapsed:05.4f}")
      log({
        "gen/step": epoch_i,
        "gen/epoch": epoch_i,
        "gen/loss": loss,
        "gen/mean_loss": mean_loss,
        "gen/start_loss": loss_disc,  # NOTE: this key holds the discriminator loss
        "gen/runtime": time_elapsed
      })

      if (epoch_i % 20 == 0 and epoch_i > epoch_start) or epoch_i == epoch_num - 1 or BATCH_SIZE == 1:
        print("save")
        save_model(model_base, optimizer, loss_list, epoch_i + 1, upload = epoch_i == epoch_num - 1)
        time.sleep(0.01)
        frac, smiles_list, unique_frac = test_graph_generation(restart_inference_method=False)
        frac_restart, smiles_list_restart, unique_frac_restart = 0, list(), 0 # restart inference is not implemented
        print(f"Fraction of correct graphs: {frac}, with restart_inference_method inference {frac_restart}")
        log({
          "inference/step": epoch_i,
          "inference/epoch": epoch_i,
          "inference/frac_normal": frac,
          "inference/frac_restart": frac_restart,
          "inference/frac_normal_unique": unique_frac,
          "inference/frac_restart_unique": unique_frac_restart
        })
        log_smiles(smiles_list, f"{PATH_PATTERN}_smiles_{epoch_i}_normal.txt")
        log_smiles(smiles_list_restart, f"{PATH_PATTERN}_smiles_{epoch_i}_restart.txt")
        try:
          print(smiles_list[:20])
          print(smiles_list_restart[:20])
        except Exception as e:
          print(e)
    except Exception as e:
      print(f"An error occurred during training: \n{str(e)}")
      traceback.print_exc()
      raise

  return model_base


### Putting Everything Together

In [None]:
def start_experiments(rounds=6): #originally 5
  """Train the denoising (generator) model, then alternate with discriminator rounds.

  The generator is first trained with ``epoch_num = EPOCHS_GEN`` and no
  discriminator. For each ``round_i`` in ``1 .. rounds-1`` a discriminator is
  obtained from ``run_disc(round_i=round_i)`` (skipped, i.e. ``None``, when
  ``BASELINE`` is set) and ``train_base_model`` is called again with
  ``epoch_num = EPOCHS_GEN * (round_i + 1)``. With ``DEBUG`` set, ``rounds``
  is halved.

  Finally the source files are uploaded via ``save_src_file`` if that helper
  has been defined (it lives in the W&B section of the notebook; the non-W&B
  fallback cell calls this function before that cell has run).

  Args:
    rounds: total number of generator training phases, including the initial
      discriminator-free one.

  Returns:
    The model returned by the last ``train_base_model`` call.
  """
  global DISC_NOISE  # only needed if the noise-decay line below is re-enabled
  if DEBUG:
    rounds = rounds // 2
  dataset_base, _ = build_dataset()  # the held-out split is not used here
  dataloader_base = DataLoader(dataset_base, batch_size=BATCH_SIZE, shuffle=True)
  model_base = train_base_model(dataloader_base, epoch_num = EPOCHS_GEN*1)

  for round_i in range(1, rounds):
    if BASELINE:
      model_disc = None
    else:
      model_disc = run_disc(round_i=round_i)
    model_base = train_base_model(dataloader_base, epoch_num = EPOCHS_GEN*(round_i+1), model_disc=model_disc)
    #DISC_NOISE = DISC_NOISE*0.5

  # Look the helper up instead of calling it directly: calling an undefined
  # name would raise NameError after the whole training run has finished.
  upload_sources = globals().get("save_src_file")
  if upload_sources is not None:
    upload_sources()
  else:
    print("save_src_file is not defined; skipping source upload.")
  return  model_base


### Start Training

In [None]:
# Fallback when W&B cannot be imported: run our method and then the baseline
# directly with the module-level hyperparameters.
try:
  import wandb
except ImportError:  # a bare except would also swallow KeyboardInterrupt/SystemExit
  # Train with discriminator (our method)
  start_experiments(rounds=5)
  # Train without discriminator (baseline)
  BASELINE = True
  start_experiments(rounds=5)

## Training with WandB

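Alternatively, we can run the experiments as a Weights & Biases (WandB) sweep. WandB logs the training metrics and stores the source files and text outputs as artifacts.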

In [None]:
import wandb

# Sanity check: this should print something like
# ['/usr/local/lib/python3.10/dist-packages/wandb'].
# Make sure not to install wandb into the current working directory.
print(wandb.__path__)

['/usr/local/lib/python3.10/dist-packages/wandb']


In [None]:
# W&B API key. Supply it through the WANDB_API_KEY environment variable rather
# than pasting it here, so the key never ends up in a saved/shared notebook.
# An empty value makes get_wand_api_key fall back to an api_key.txt file.
WANDB_TOKEN = os.environ.get("WANDB_API_KEY", "")

In [None]:
# W&B sweep definition. Every parameter currently offers a single value, so
# all trials of this "random" sweep share one configuration; the trailing
# comments record previously used values. The metric name matches the key
# logged as "inference/frac_normal_unique" during training.
sweep_config = dict(
    name="AliaMol",
    method="random",
    metric=dict(
        name="inference/frac_normal_unique",
        goal="maximize",
    ),
    parameters=dict(
        BATCH_SIZE={"values": [64]},        # previously 256
        GAMMA={"values": [0.2]},            # previously 0.1
        DISC_NOISE={"values": [1.0]},       # 0.3 in generation for paper
        EPOCHS_DISC_MODEL={"values": [5]},
        EPOCHS_GEN={"values": [100]},
    ),
)

In [None]:
def save_src_file():
  """Upload source files and text outputs from the working directory to W&B.

  First writes ``pip list`` output to ``pip_list.txt`` so the package
  environment is captured. It then logs every ``*.txt``, ``*.ipynb`` and
  ``*.py`` file in the current working directory, in sorted order, with
  ``wandb.log_artifact``. Artifact names embed the global ``SWEEP_ID`` and the
  file name reduced to its alphanumeric characters.

  ``api_key.txt`` is never uploaded. ``get_wand_api_key`` reads the W&B key
  from that file in the working directory, and the ``*.txt`` glob would
  otherwise publish the secret as an artifact.

  Errors are printed, not raised:
    * If ``SWEEP_ID`` or ``wandb`` is unavailable, the upload is skipped.
    * A failure on one file does not stop the remaining uploads.
  """
  # (glob pattern, artifact-name prefix, artifact type)
  artifact_kinds = (
    ('*.txt', 'src_txt', 'my_dataset_txt'),
    ('*.ipynb', 'src_ipynb', 'my_dataset_ipynb'),
    ('*.py', 'src_py', 'my_dataset_py'),
  )
  # Files holding credentials; they must never be uploaded.
  secret_files = {'api_key.txt'}
  try:
    os.system("pip list > pip_list.txt 2>&1")
    sweep_id = SWEEP_ID  # only defined once start_with_wandb created a sweep
    log_artifact = wandb.log_artifact
  except Exception as e:
    print(e)
    return
  for pattern, prefix, artifact_type in artifact_kinds:
    for path in sorted(glob.glob(pattern)):
      if path in secret_files:
        continue
      safe_name = "".join(filter(str.isalnum, path))
      try:
        log_artifact(path, name=f"{prefix}_{sweep_id}_{safe_name}", type=artifact_type)
      except Exception as e:
        print(f"Could not upload {path}: {e}")




In [None]:
def get_wand_api_key():
  """Return the W&B API key used for ``wandb.login``.

  Resolution order:
    1. ``WANDB_TOKEN``, if it is a non-empty string.
    2. Outside Google Colab: ``~/api_key.txt``, if that file exists.
    3. ``api_key.txt`` in the current working directory.

  The home-directory key is read in place rather than copied into the working
  directory, where ``save_src_file`` globs ``*.txt`` files for upload.

  Returns:
    The key with surrounding whitespace stripped.

  Raises:
    FileNotFoundError: if no token is set and no key file is found.
  """
  import sys
  if len(WANDB_TOKEN) > 0:
    return WANDB_TOKEN
  file_path = 'api_key.txt'
  if 'google.colab' not in sys.modules:
    home_key_path = os.path.expanduser(os.path.join('~', 'api_key.txt'))
    if os.path.isfile(home_key_path):
      file_path = home_key_path
  with open(file_path, 'r') as file:
    api_key = file.read().strip()
  return api_key


def main():
  """W&B agent entry point: run a single sweep trial.

  Opens a W&B run and sets ``PATH_PATTERN`` to
  ``<PATH_PATTERN_BASE>_<run name>_<BASELINE>``. It then uploads the source
  files and copies each swept hyperparameter from the run config into the
  module globals (printing each assignment). Finally it starts the
  experiments.
  """
  global PATH_PATTERN
  with wandb.init() as run:
    PATH_PATTERN = '_'.join([PATH_PATTERN_BASE, str(run.name), str(BASELINE)])
    save_src_file()
    module_globals = globals()
    for param_name in sweep_config['parameters']:
      param_value = run.config[param_name]
      module_globals[param_name] = param_value
      print("set ", param_name, "=", param_value)
    start_experiments()

def start_with_wandb(set_baseline_true=False):
  """Create a W&B sweep from ``sweep_config`` and run up to 10 trials.

  Args:
    set_baseline_true: if truthy, run the sweep with ``BASELINE = True``
      (discriminator disabled); otherwise with ``BASELINE = False``.

  Side effects:
    * Sets the globals ``BASELINE``, ``USE_WANDB`` and ``SWEEP_ID``.
    * Sets ``WANDB_MODE=online`` in the environment.
    * On any exception from sweep creation or the agent, prints the
      traceback, appends it to ``_error_log.txt`` and sleeps 10 seconds
      instead of re-raising.
  """
  global SWEEP_ID, USE_WANDB, PATH_PATTERN, BASELINE
  BASELINE = bool(set_baseline_true)
  USE_WANDB = True
  os.environ["WANDB_MODE"] = "online"
  try:
    SWEEP_ID = wandb.sweep(sweep_config, project=PROJECT_NAME)
    wandb.agent(SWEEP_ID, function=main, count=10)
  except Exception:
    trace_text = traceback.format_exc()
    print("final error:\n", trace_text)
    with open('_error_log.txt', 'a') as log_file:
      log_file.write(trace_text + '\n')
    time.sleep(10)


In [None]:
# Authenticate with W&B, then launch the sweep with the discriminator enabled.
wandb.login(key=get_wand_api_key())
start_with_wandb(set_baseline_true=False)

# To run the discriminator-free baseline sweep instead, use:
# start_with_wandb(set_baseline_true=True)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Create sweep with ID: 7bs1xqew
Sweep URL: https://wandb.ai/nextaid/MoldDiffGAN_noise/sweeps/7bs1xqew


[34m[1mwandb[0m: Agent Starting Run: vrqdlze2 with config:
[34m[1mwandb[0m: 	BATCH_SIZE: 256
[34m[1mwandb[0m: 	DISC_NOISE: 0.4
[34m[1mwandb[0m: 	EPOCHS_DISC_MODEL: 10
[34m[1mwandb[0m: 	EPOCHS_GEN: 100
[34m[1mwandb[0m: 	GAMMA: 0.2
[34m[1mwandb[0m: Currently logged in as: [33mgerritgr[0m ([33mnextaid[0m). Use [1m`wandb login --relogin`[0m to force relogin


set  BATCH_SIZE = 256
set  GAMMA = 0.2
set  DISC_NOISE = 0.4
set  EPOCHS_DISC_MODEL = 10
set  EPOCHS_GEN = 100
Trying to read dataset.pickle
Train denoising model.
Trying to read deg.pickle
from 0 to 100


100%|██████████| 419/419 [00:37<00:00, 11.11it/s]


loss in epoch 0000000 is: 0.1748 with mean loss 0.1748 with disc loss 0.0000 with runtime 37.7110


100%|██████████| 419/419 [00:39<00:00, 10.74it/s]


loss in epoch 0000001 is: 0.0913 with mean loss 0.1191 with disc loss 0.0000 with runtime 39.0135


100%|██████████| 419/419 [00:36<00:00, 11.52it/s]


loss in epoch 0000002 is: 0.0865 with mean loss 0.1098 with disc loss 0.0000 with runtime 36.3803


100%|██████████| 419/419 [00:36<00:00, 11.46it/s]


loss in epoch 0000003 is: 0.0816 with mean loss 0.1032 with disc loss 0.0000 with runtime 36.5631


100%|██████████| 419/419 [00:36<00:00, 11.49it/s]


loss in epoch 0000004 is: 0.0780 with mean loss 0.0984 with disc loss 0.0000 with runtime 36.4876


100%|██████████| 419/419 [00:36<00:00, 11.40it/s]


loss in epoch 0000005 is: 0.0758 with mean loss 0.0948 with disc loss 0.0000 with runtime 36.7587


100%|██████████| 419/419 [00:36<00:00, 11.48it/s]


loss in epoch 0000006 is: 0.0744 with mean loss 0.0921 with disc loss 0.0000 with runtime 36.5148


100%|██████████| 419/419 [00:36<00:00, 11.35it/s]


loss in epoch 0000007 is: 0.0735 with mean loss 0.0899 with disc loss 0.0000 with runtime 36.9453


100%|██████████| 419/419 [00:36<00:00, 11.51it/s]


loss in epoch 0000008 is: 0.0721 with mean loss 0.0880 with disc loss 0.0000 with runtime 36.4275


100%|██████████| 419/419 [00:36<00:00, 11.34it/s]


loss in epoch 0000009 is: 0.0715 with mean loss 0.0864 with disc loss 0.0000 with runtime 36.9639


100%|██████████| 419/419 [00:36<00:00, 11.46it/s]


loss in epoch 0000010 is: 0.0708 with mean loss 0.0851 with disc loss 0.0000 with runtime 36.5773


100%|██████████| 419/419 [00:36<00:00, 11.33it/s]


loss in epoch 0000011 is: 0.0702 with mean loss 0.0839 with disc loss 0.0000 with runtime 36.9815


100%|██████████| 419/419 [00:36<00:00, 11.47it/s]


loss in epoch 0000012 is: 0.0699 with mean loss 0.0829 with disc loss 0.0000 with runtime 36.5416


100%|██████████| 419/419 [00:37<00:00, 11.30it/s]


loss in epoch 0000013 is: 0.0697 with mean loss 0.0820 with disc loss 0.0000 with runtime 37.0749


100%|██████████| 419/419 [00:36<00:00, 11.45it/s]


loss in epoch 0000014 is: 0.0696 with mean loss 0.0812 with disc loss 0.0000 with runtime 36.6018


100%|██████████| 419/419 [00:37<00:00, 11.28it/s]


loss in epoch 0000015 is: 0.0694 with mean loss 0.0805 with disc loss 0.0000 with runtime 37.1671


100%|██████████| 419/419 [00:36<00:00, 11.48it/s]


loss in epoch 0000016 is: 0.0692 with mean loss 0.0799 with disc loss 0.0000 with runtime 36.5046


100%|██████████| 419/419 [00:37<00:00, 11.30it/s]


loss in epoch 0000017 is: 0.0690 with mean loss 0.0793 with disc loss 0.0000 with runtime 37.0732


100%|██████████| 419/419 [00:36<00:00, 11.45it/s]


loss in epoch 0000018 is: 0.0687 with mean loss 0.0787 with disc loss 0.0000 with runtime 36.5924


100%|██████████| 419/419 [00:37<00:00, 11.28it/s]


loss in epoch 0000019 is: 0.0687 with mean loss 0.0783 with disc loss 0.0000 with runtime 37.1565


100%|██████████| 419/419 [00:36<00:00, 11.46it/s]


loss in epoch 0000020 is: 0.0686 with mean loss 0.0778 with disc loss 0.0000 with runtime 36.5676
save
Generate 40000 graphs.
Trying to read moldiffusion_run_royal-sweep-1_False_model_epoch_00000021_040000_wFalse_generated.pickle
An error occurred: [Errno 2] No such file or directory: 'moldiffusion_run_royal-sweep-1_False_model_epoch_00000021_040000_wFalse_generated.pickle'
Trying to read dataset.pickle
Trying to read deg.pickle
Loaded checkpoint of epoch 00000021 from disk.
generate samples batched
load g DataBatch(edge_index=[2, 137876], x=[43273, 11], batch=[43273], ptr=[1001]) tensor([  0,   0,   0,  ..., 999, 999, 999], device='cuda:0')


100%|██████████| 1000/1000 [01:57<00:00,  8.51it/s]


generated graphs  [Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11])]
generate samples batched
load g DataBatch(edge_index=[2, 137432], x=[43147, 11], batch=[43147], ptr=[1001]) tensor([  0,   0,   0,  ..., 999, 999, 999], device='cuda:0')


100%|██████████| 1000/1000 [01:56<00:00,  8.55it/s]


generated graphs  [Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 112], x=[36, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 112], x=[36, 11])]
generate samples batched
load g DataBatch(edge_index=[2, 137244], x=[43093, 11], batch=[43093], ptr=[1001]) tensor([  0,   0,   0,  ..., 999, 999, 999], device='cuda:0')


100%|██████████| 1000/1000 [01:56<00:00,  8.62it/s]


generated graphs  [Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11])]
generate samples batched
load g DataBatch(edge_index=[2, 136964], x=[43014, 11], batch=[43014], ptr=[1001]) tensor([  0,   0,   0,  ..., 999, 999, 999], device='cuda:0')


100%|██████████| 1000/1000 [01:55<00:00,  8.67it/s]


generated graphs  [Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 84], x=[28, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11])]
generate samples batched
load g DataBatch(edge_index=[2, 137672], x=[43214, 11], batch=[43214], ptr=[1001]) tensor([  0,   0,   0,  ..., 999, 999, 999], device='cuda:0')


100%|██████████| 1000/1000 [01:55<00:00,  8.62it/s]


generated graphs  [Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 112], x=[36, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11])]
Trying to write moldiffusion_run_royal-sweep-1_False_model_epoch_00000021_040000_wFalse_generated.pickle
generate samples batched
load g DataBatch(edge_index=[2, 137720], x=[43227, 11], batch=[43227], ptr=[1001]) tensor([  0,   0,   0,  ..., 999, 999, 999], device='cuda:0')


100%|██████████| 1000/1000 [01:56<00:00,  8.61it/s]


generated graphs  [Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11]), Data(edge_index=[2, 144], x=[45, 11])]
generate samples batched
load g DataBatch(edge_index=[2, 137816], x=[43256, 11], batch=[43256], ptr=[1001]) tensor([  0,   0,   0,  ..., 999, 999, 999], device='cuda:0')


 87%|████████▋ | 873/1000 [01:42<00:16,  7.52it/s]