In [19]:
import sys

In [20]:
# For colab
# !pip install torch_geometric
# !pip install grakel
# from google.colab import drive
# drive.mount('/content/drive')

In [21]:
# Make the parent directory (which holds the `src` package imported below)
# and the baseline sub-package directory importable from this notebook.
# NOTE(review): '../src/baseline' is presumably needed because the baseline
# modules import each other by bare name -- confirm.
# The membership guard keeps the cell idempotent: re-running it does not
# pile duplicate entries onto sys.path.
for path_to_modules in ('..', '../src/baseline'):
    if path_to_modules not in sys.path:
        sys.path.append(path_to_modules)

In [22]:
from torch_geometric.loader import DataLoader

from src.cond_autoencoder import CondVariationalAutoEncoder
from src.baseline.denoise_model import DenoiseNN, p_losses, sample
from src.baseline.utils import linear_beta_schedule, construct_nx_from_adj, preprocess_dataset
from src.utils import get_features

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error

import torch.nn.functional as F
import torch
import numpy as np


# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [23]:
# Default hyperparameters for the notebook.

# Optimisation / batching
lr = 1e-3
batch_size = 256
dropout = 0

# Graph size and spectral embedding width
n_max_nodes = 50
spectral_emb_dim = 10

# Autoencoder architecture and training length
latent_dim = 32
hidden_dim_encoder = 64
hidden_dim_decoder = 256
n_layers_encoder = 2
n_layers_decoder = 3
epochs_autoencoder = 200

# Denoising (diffusion) model architecture and training length
timesteps = 500
hidden_dim_denoise = 512
n_layers_denoise = 3
epochs_denoise = 100

# Conditioning vector: number of statistics and their embedding width
n_condition = 7
dim_condition = 128

In [24]:
# Notebook-wide constants and the three preprocessed dataset splits.
EARLY_STOP_ROUNS = 10   # epochs without validation improvement before training stops
MAX_NODES = 50          # forwarded to preprocessing and to the autoencoder (n_max_nodes)
SPECTR_EMB_DIM = 10     # forwarded to preprocessing; the autoencoder input is this + 1
TRAIN_SIZE = 8000       # only referenced by the commented-out logging lines
VAL_SIZE = 1000         # only referenced by the commented-out logging lines


def _load_split(split_name):
    """Preprocess (or load) one dataset split with the shared size settings."""
    return preprocess_dataset(split_name, MAX_NODES, SPECTR_EMB_DIM)


trainset = _load_split("train")
validset = _load_split("valid")
testset = _load_split("test")

Dataset ../data/dataset_train.pt loaded from file
Dataset ../data/dataset_valid.pt loaded from file
Dataset ../data/dataset_test.pt loaded from file


In [25]:
def train_autoencoder(spectral_emb_dim,
                      hidden_dim_encoder,
                      hidden_dim_decoder,
                      latent_dim,
                      n_layers_encoder,
                      n_layers_decoder,
                      lr,
                      train_loader,
                      val_loader,
                      epochs_autoencoder=300,
                      ):
  """Train a CondVariationalAutoEncoder with early stopping; return the best model.

  The model is optimised with Adam, and the learning rate is halved by
  ReduceLROnPlateau when the summed validation loss stalls. Whenever the
  validation loss is at least as good as the best seen so far, the model and
  optimizer state are written to 'autoencoder.pth.tar'. Training stops after
  EARLY_STOP_ROUNS consecutive epochs without such an improvement.

  Relies on the notebook-level names DEVICE, MAX_NODES and EARLY_STOP_ROUNS.

  Args:
    spectral_emb_dim: spectral embedding width; the model input dimension is
      spectral_emb_dim + 1.
    hidden_dim_encoder, hidden_dim_decoder, latent_dim, n_layers_encoder,
      n_layers_decoder: forwarded positionally to CondVariationalAutoEncoder.
    lr: initial Adam learning rate.
    train_loader: iterable of training batches exposing `.to()` and `.stats`.
    val_loader: iterable of validation batches with the same interface.
    epochs_autoencoder: maximum number of epochs.

  Returns:
    The autoencoder in eval mode, carrying the best checkpointed weights. If
    no checkpoint was written during this call (e.g. zero epochs, or a
    validation loss that was never comparable such as NaN), the weights from
    the end of training are returned, not whatever an older checkpoint file
    may contain.
  """
  checkpoint_path = 'autoencoder.pth.tar'

  autoencoder = CondVariationalAutoEncoder(
      spectral_emb_dim + 1, hidden_dim_encoder, hidden_dim_decoder, latent_dim,
      n_layers_encoder, n_layers_decoder, n_max_nodes=MAX_NODES,
  ).to(DEVICE)
  optimizer = torch.optim.Adam(autoencoder.parameters(), lr=lr)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
      optimizer, patience=5, factor=0.5, threshold=5e-3, min_lr=5e-6)

  best_val_loss = np.inf
  early_stopping_counts = 0
  checkpoint_saved = False

  for epoch in range(1, epochs_autoencoder + 1):
    # ---- training pass ----
    autoencoder.train()
    for data in train_loader:
      data = data.to(DEVICE)
      optimizer.zero_grad()
      loss, _, _ = autoencoder.loss_function(data, data.stats)
      loss.backward()
      optimizer.step()

    # ---- validation pass ----
    # No gradients are needed here, so skip building the autograd graph.
    autoencoder.eval()
    val_loss_all = 0
    with torch.no_grad():
      for data in val_loader:
        data = data.to(DEVICE)
        loss, _, _ = autoencoder.loss_function(data, data.stats)
        val_loss_all += loss.item()

    scheduler.step(val_loss_all)
    early_stopping_counts += 1

    # Ties count as an improvement, matching the `>=` comparison.
    if best_val_loss >= val_loss_all:
      best_val_loss = val_loss_all
      torch.save({
          'state_dict': autoencoder.state_dict(),
          'optimizer' : optimizer.state_dict(),
      }, checkpoint_path)
      checkpoint_saved = True
      early_stopping_counts = 0

    if early_stopping_counts == EARLY_STOP_ROUNS:
      break

  # Only reload when this run wrote the file; otherwise a stale checkpoint
  # from a previous run could silently replace the model just trained.
  if checkpoint_saved:
    checkpoint = torch.load(checkpoint_path)
    autoencoder.load_state_dict(checkpoint['state_dict'])
  return autoencoder.eval()

In [26]:
def train_denoiser(
    n_layers_denoise,
    timesteps,
    n_condition,
    dim_condition,
    latent_dim,
    hidden_dim_denoise,
    autoencoder,
    epochs_denoise,
    lr,
    beta_fn = linear_beta_schedule,
):
  """Train a DenoiseNN on autoencoder latents with early stopping.

  Each batch is encoded with the (frozen) autoencoder, a random timestep is
  drawn per graph, and `p_losses(..., loss_type="huber")` is minimised with
  Adam. The learning rate is halved by ReduceLROnPlateau when the validation
  loss stalls. The best model (validation loss less than or equal to the best
  so far) is written to 'denoise_model.pth.tar'. Training stops after
  EARLY_STOP_ROUNS consecutive epochs without improvement.

  NOTE: batches come from the notebook-level `train_loader` and `val_loader`
  (they are not parameters), so those must be defined before calling. Also
  relies on the notebook-level DEVICE and EARLY_STOP_ROUNS.

  Args:
    n_layers_denoise: forwarded to DenoiseNN as `n_layers`.
    timesteps: number of diffusion steps; also the exclusive upper bound for
      the sampled timestep indices.
    n_condition: forwarded to DenoiseNN as `n_cond`.
    dim_condition: forwarded to DenoiseNN as `d_cond`.
    latent_dim: forwarded to DenoiseNN as `input_dim`.
    hidden_dim_denoise: forwarded to DenoiseNN as `hidden_dim`.
    autoencoder: trained model exposing `encode(data, stats)`; it is only used
      for inference and is not updated here.
    epochs_denoise: maximum number of epochs.
    lr: initial Adam learning rate.
    beta_fn: callable `beta_fn(timesteps=...)` returning the beta schedule.

  Returns:
    (denoise_model, betas): the denoiser in eval mode and the beta schedule
    used for training. If no checkpoint was written during this call, the
    weights from the end of training are returned and no older checkpoint
    file is loaded.
  """
  checkpoint_path = 'denoise_model.pth.tar'

  # Forward-diffusion coefficients derived from the beta schedule.
  betas = beta_fn(timesteps=timesteps)
  alphas_cumprod = torch.cumprod(1. - betas, axis=0)
  sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
  sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

  denoise_model = DenoiseNN(input_dim=latent_dim, hidden_dim=hidden_dim_denoise, n_layers=n_layers_denoise, n_cond=n_condition, d_cond=dim_condition).to(DEVICE)
  optimizer = torch.optim.Adam(denoise_model.parameters(), lr=lr)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5, threshold=5e-3, min_lr=5e-6)

  best_val_loss = np.inf
  early_stopping_counts = 0
  checkpoint_saved = False

  for epoch in range(1, epochs_denoise + 1):
    # ---- training pass ----
    denoise_model.train()
    for data in train_loader:
      data = data.to(DEVICE)
      optimizer.zero_grad()
      # The autoencoder is not being trained here, so encode without
      # autograd: backward() then only traverses the denoiser.
      with torch.no_grad():
        x_g = autoencoder.encode(data, data.stats)
      t = torch.randint(0, timesteps, (x_g.size(0),), device=DEVICE).long()
      loss = p_losses(denoise_model, x_g, t, data.stats, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, loss_type="huber")
      loss.backward()
      optimizer.step()

    # ---- validation pass (no gradients needed) ----
    denoise_model.eval()
    val_loss_all = 0
    with torch.no_grad():
      for data in val_loader:
        data = data.to(DEVICE)
        x_g = autoencoder.encode(data, data.stats)
        t = torch.randint(0, timesteps, (x_g.size(0),), device=DEVICE).long()
        loss = p_losses(denoise_model, x_g, t, data.stats, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, loss_type="huber")
        # Weight each batch loss by its number of graphs.
        val_loss_all += x_g.size(0) * loss.item()

    scheduler.step(val_loss_all)
    early_stopping_counts += 1

    # Ties count as an improvement, matching the `>=` comparison.
    if best_val_loss >= val_loss_all:
      best_val_loss = val_loss_all
      torch.save({
          'state_dict': denoise_model.state_dict(),
          'optimizer' : optimizer.state_dict(),
      }, checkpoint_path)
      checkpoint_saved = True
      early_stopping_counts = 0

    if early_stopping_counts == EARLY_STOP_ROUNS:
      break

  # Only reload when this run wrote the file; otherwise a stale checkpoint
  # from a previous run could silently replace the model just trained.
  if checkpoint_saved:
    checkpoint = torch.load(checkpoint_path)
    denoise_model.load_state_dict(checkpoint['state_dict'])
  denoise_model.eval()
  return denoise_model, betas

In [27]:
# apply the pipeline and eval
from tqdm import tqdm


def eval(loader, autoencoder, denoiser, betas):
  """Generate one graph per conditioning vector and score it against the target.

  For every batch, latents are drawn with `sample(...)`, the last element of
  the returned sequence is decoded with `autoencoder.decode_mu`, each
  adjacency matrix is turned into a graph via `construct_nx_from_adj`, and
  `get_features` is computed on it. Targets and predictions are standardised
  with a StandardScaler fitted on the targets, and the mean absolute error
  between the two is returned.

  Reads `latent_dim`, `timesteps` and `DEVICE` from the notebook namespace.
  Note that this function shadows Python's built-in `eval` in the notebook.

  Args:
    loader: iterable of batches exposing `.to()` and `.stats`.
    autoencoder: model exposing `decode_mu(latent, stats)`.
    denoiser: trained denoising model passed to `sample`.
    betas: beta schedule passed to `sample`.

  Returns:
    Mean absolute error between standardised target and generated features.
  """
  target_batches = []
  generated_features = []

  with torch.no_grad():
    for batch in tqdm(loader):
      batch = batch.to(DEVICE)
      cond = batch.stats
      n_graphs = cond.size(0)
      target_batches.append(cond.cpu().numpy())

      diffusion_chain = sample(denoiser, cond, latent_dim=latent_dim, timesteps=timesteps, betas=betas, batch_size=n_graphs)
      # Only the final element of the sampled sequence is decoded.
      adjacency = autoencoder.decode_mu(diffusion_chain[-1], cond)

      generated_features.extend(
          get_features(construct_nx_from_adj(adjacency[idx, :, :].detach().cpu().numpy()))
          for idx in range(n_graphs)
      )

  predicted = np.array(generated_features)
  expected = np.concatenate(target_batches)

  # Fit the scaler on the targets only, then apply the same transform to both.
  scaler = StandardScaler()
  expected_scaled = scaler.fit_transform(expected)
  predicted_scaled = scaler.transform(predicted)
  return mean_absolute_error(expected_scaled, predicted_scaled)

In [28]:
# Build the loaders. train_denoiser() reads `train_loader` / `val_loader`
# from the notebook namespace, so they must exist before it is called.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validset, batch_size=batch_size, shuffle=False)

# Keyword arguments keep the long hyperparameter lists from being mixed up.
# `epochs_autoencoder` is not passed, so the function default (300) applies.
encoder = train_autoencoder(
    spectral_emb_dim=SPECTR_EMB_DIM,
    hidden_dim_encoder=hidden_dim_encoder,
    hidden_dim_decoder=hidden_dim_decoder,
    latent_dim=latent_dim,
    n_layers_encoder=n_layers_encoder,
    n_layers_decoder=n_layers_decoder,
    lr=lr,
    train_loader=train_loader,
    val_loader=val_loader,
)

# NOTE(review): epochs_denoise=300 overrides the configured `epochs_denoise`
# default (100) -- confirm which value is intended.
denoiser, betas = train_denoiser(
    n_layers_denoise=n_layers_denoise,
    timesteps=timesteps,
    n_condition=n_condition,
    dim_condition=dim_condition,
    latent_dim=latent_dim,
    hidden_dim_denoise=hidden_dim_denoise,
    autoencoder=encoder,
    epochs_denoise=300,
    lr=lr,
    beta_fn=linear_beta_schedule,
)

# Validation MAE between standardised target and generated graph features.
mae = eval(val_loader, encoder, denoiser, betas)
mae

100%|██████████| 4/4 [00:37<00:00,  9.43s/it]


0.34078903035047065

In [29]:
import csv
# Same value as the default set earlier in the notebook; not used below.
n_condition = 7

# Generate one graph per test conditioning vector and write the edge lists
# to output.csv (columns: graph_id, edge_list).
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
with open("output.csv", "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    # Write the header
    writer.writerow(["graph_id", "edge_list"])

    # Pure inference: disable autograd, as eval() does for the validation set.
    with torch.no_grad():
        for data in test_loader:
            data = data.to(DEVICE)

            stat = data.stats
            bs = stat.size(0)
            graph_ids = data.filename

            samples = sample(denoiser, stat, latent_dim=latent_dim, timesteps=timesteps, betas=betas, batch_size=bs)
            # Decode only the final element of the sampled sequence.
            adj = encoder.decode_mu(samples[-1], stat)

            for i in range(bs):
                Gs_generated = construct_nx_from_adj(adj[i,:,:].detach().cpu().numpy())

                # Convert the edge list to a single string
                edge_list_text = ", ".join([f"({u}, {v})" for u, v in Gs_generated.edges()])
                # Write the graph ID and the full edge list as a single row
                writer.writerow([graph_ids[i], edge_list_text])