In [4]:
import sys

In [5]:
# For Google Colab only: uncomment the lines below to install the dependencies and mount Google Drive.

# !pip install torch_geometric
# !pip install grakel
# from google.colab import drive
# drive.mount('/content/drive')

In [6]:
# Make the repository root and the baseline package importable from this notebook.
for path_to_modules in ('..', '../src/baseline'):
    sys.path.append(path_to_modules)

In [7]:
from torch_geometric.loader import DataLoader

from src.cond_autoencoder import CondVariationalAutoEncoder
from src.baseline.utils import construct_nx_from_adj, preprocess_dataset
from src.utils import get_features

from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error

import torch.nn.functional as F
import torch
import numpy as np

from sklearn.neighbors import NearestNeighbors

In [8]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [9]:
# Default hyper-parameters

# -- optimisation --
lr = 1e-3
batch_size = 256
epochs_autoencoder = 200
dropout = 0

# -- model sizes --
spectral_emb_dim = 10
hidden_dim_encoder = 64
hidden_dim_decoder = 256
latent_dim = 32
n_layers_encoder = 2
n_layers_decoder = 3

# -- data --
n_max_nodes = 50
# Width used when reshaping `data.stats` in the generation cells.
n_condition = 7

In [10]:
# Training constants and dataset loading.
EARLY_STOP_ROUNS = 10   # epochs without validation improvement before training stops
MAX_NODES = 50
SPECTR_EMB_DIM = 10
TRAIN_SIZE = 8000       # only referenced by the (commented-out) per-epoch log
VAL_SIZE = 1000         # only referenced by the (commented-out) per-epoch log

# Load the three splits in order: train, valid, test.
trainset, validset, testset = (
    preprocess_dataset(split, MAX_NODES, SPECTR_EMB_DIM)
    for split in ("train", "valid", "test")
)

Dataset ../data/dataset_train.pt loaded from file
Dataset ../data/dataset_valid.pt loaded from file
Dataset ../data/dataset_test.pt loaded from file


In [11]:
def train_autoencoder(spectral_emb_dim,
                      hidden_dim_encoder,
                      hidden_dim_decoder,
                      latent_dim,
                      n_layers_encoder,
                      n_layers_decoder,
                      lr,
                      train_loader,
                      val_loader,
                      epochs_autoencoder=300,
                      checkpoint_path='autoencoder.pth.tar',
                      verbose=False,
                      ):
  """Train a CondVariationalAutoEncoder with early stopping and return the best model.

  The model is built with ``spectral_emb_dim + 1`` as its first constructor
  argument (presumably the node-feature size -- confirm against the model class)
  and ``n_max_nodes=MAX_NODES``. After every epoch the summed validation loss
  drives both the ReduceLROnPlateau scheduler and early stopping; whenever it is
  at least as good as the best seen so far, the model and optimizer states are
  written to ``checkpoint_path``.

  Args:
    spectral_emb_dim, hidden_dim_encoder, hidden_dim_decoder, latent_dim,
    n_layers_encoder, n_layers_decoder: forwarded to the model constructor.
    lr: learning rate for Adam.
    train_loader: iterable of batches used for optimisation. Each batch must
      expose ``.to()``, ``.stats`` and ``.batch``.
    val_loader: iterable of batches used for model selection.
    epochs_autoencoder: maximum number of epochs.
    checkpoint_path: file the best weights are saved to and reloaded from.
    verbose: when True, print per-epoch losses divided by the number of graphs
      seen in the epoch.

  Returns:
    The autoencoder in eval mode, carrying the best checkpoint of this run
    (or the current weights if no checkpoint was written).
  """
  autoencoder = CondVariationalAutoEncoder(
      spectral_emb_dim + 1,
      hidden_dim_encoder,
      hidden_dim_decoder,
      latent_dim,
      n_layers_encoder,
      n_layers_decoder,
      n_max_nodes=MAX_NODES,
  ).to(DEVICE)
  optimizer = torch.optim.Adam(autoencoder.parameters(), lr=lr)
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
      optimizer, patience=5, factor=0.5, threshold=5e-3, min_lr=5e-6)

  best_val_loss = np.inf
  early_stopping_counts = 0
  # Guards the final reload: without it a run that never improves would load a
  # stale checkpoint left by a previous run (or fail on a missing file).
  checkpoint_saved = False

  for epoch in range(1, epochs_autoencoder + 1):
    # ---- training pass ----
    autoencoder.train()
    train_loss_all = 0.0
    train_loss_all_recon = 0.0
    train_loss_all_kld = 0.0
    train_graphs = 0

    for data in train_loader:
      data = data.to(DEVICE)
      optimizer.zero_grad()
      loss, recon, kld = autoencoder.loss_function(data, data.stats)
      loss.backward()
      optimizer.step()

      train_loss_all += loss.item()
      train_loss_all_recon += recon.item()
      train_loss_all_kld += kld.item()
      # Graph count as a plain int, so no device tensor is kept alive across batches.
      train_graphs += int(data.batch.max()) + 1

    # ---- validation pass (no gradients needed) ----
    autoencoder.eval()
    val_loss_all = 0.0
    val_loss_all_recon = 0.0
    val_loss_all_kld = 0.0
    val_graphs = 0

    with torch.no_grad():
      for data in val_loader:
        data = data.to(DEVICE)
        loss, recon, kld = autoencoder.loss_function(data, data.stats)
        val_loss_all += loss.item()
        val_loss_all_recon += recon.item()
        val_loss_all_kld += kld.item()
        val_graphs += int(data.batch.max()) + 1

    if verbose:
      n_train = max(train_graphs, 1)
      n_val = max(val_graphs, 1)
      print('Epoch: {:04d}, Train Loss: {:.5f}, Train Reconstruction Loss: {:.2f}, '
            'Train KLD Loss: {:.2f}, Val Loss: {:.5f}, Val Reconstruction Loss: {:.2f}, '
            'Val KLD Loss: {:.2f}'.format(
                epoch,
                train_loss_all / n_train, train_loss_all_recon / n_train, train_loss_all_kld / n_train,
                val_loss_all / n_val, val_loss_all_recon / n_val, val_loss_all_kld / n_val))

    scheduler.step(val_loss_all)

    if best_val_loss >= val_loss_all:
      # New best (ties count as improvement): checkpoint and reset patience.
      best_val_loss = val_loss_all
      torch.save({
          'state_dict': autoencoder.state_dict(),
          'optimizer': optimizer.state_dict(),
      }, checkpoint_path)
      checkpoint_saved = True
      early_stopping_counts = 0
    else:
      early_stopping_counts += 1

    if early_stopping_counts >= EARLY_STOP_ROUNS:
      break

  if checkpoint_saved:
    # map_location keeps the reload valid if the checkpoint device differs from DEVICE.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    autoencoder.load_state_dict(checkpoint['state_dict'])
  return autoencoder.eval()

In [12]:
# apply the pipeline and eval
def eval(loader, autoencoder):
  """Return the standardised MAE between requested and generated graph statistics.

  For each batch in ``loader`` a random latent of width ``latent_dim`` is drawn
  with ``torch.randn`` and decoded with ``autoencoder.decode_mu`` conditioned on
  ``batch.stats``. Every decoded adjacency matrix is turned into a graph and
  summarised with ``get_features``. Targets and predictions are then scaled with
  a StandardScaler fitted on the targets before the mean absolute error is taken.

  Note: this function shadows the ``eval`` builtin within the notebook.
  """
  target_batches = []
  generated_features = []

  with torch.no_grad():
    for batch in loader:
      batch = batch.to(DEVICE)
      cond = batch.stats
      target_batches.append(cond.cpu().numpy())

      n_graphs = cond.size(0)
      latents = torch.randn(n_graphs, latent_dim, device=DEVICE)
      adj_batch = autoencoder.decode_mu(latents, cond)

      generated_features.extend(
          get_features(construct_nx_from_adj(adj_batch[idx, :, :].detach().cpu().numpy()))
          for idx in range(n_graphs)
      )

  targets = np.concatenate(target_batches)
  preds = np.array(generated_features)

  # Fit the scaler on the targets only, then apply the same transform to the predictions.
  scaler = StandardScaler()
  targets_scaled = scaler.fit_transform(targets)
  preds_scaled = scaler.transform(preds)
  return mean_absolute_error(targets_scaled, preds_scaled)

In [13]:
from tqdm import tqdm

In [14]:
import warnings

# Silence warnings for the rest of the notebook: Python-level warnings are
# ignored and PyTorch's warn-always mode is turned off.
warnings.filterwarnings('ignore')
torch.set_warn_always(False)

In [15]:
# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validset, batch_size=batch_size, shuffle=False)

# NOTE: despite its name, `encoder` holds the full trained autoencoder
# (later cells call its .encoder / .decoder sub-modules).
encoder = train_autoencoder(SPECTR_EMB_DIM,
                          hidden_dim_encoder,
                          hidden_dim_decoder,
                          latent_dim,
                          n_layers_encoder,
                          n_layers_decoder,
                          lr,
                          train_loader,
                          val_loader,
                          # Pass the configured epoch budget explicitly; otherwise the
                          # function default (300) silently overrides the config value.
                          epochs_autoencoder=epochs_autoencoder)



In [16]:
# Score the trained model on the validation split. `eval` draws fresh random
# latents with torch.randn and no seed is set in the visible cells, so the
# displayed value can differ between runs.
mae = eval(val_loader, encoder)

mae

0.38683516567940307

In [17]:
# Stack the statistics of every training graph into a single array.
train_stat_np = np.concatenate([graph.stats.numpy() for graph in trainset])

# Standardise the statistics and build a 1-nearest-neighbour index over them.
# `scaler` and `nn` are reused by the generation cells below.
scaler = StandardScaler()
nn = NearestNeighbors(n_neighbors=1)
nn.fit(scaler.fit_transform(train_stat_np))

In [18]:
import csv

# Generate one graph per test item. For each test graph the nearest training
# graph (in standardised statistics space) is encoded, and the sampled latent is
# decoded conditioned on the TEST graph's own statistics.
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
with open("knn_condVAE_test_stat.csv", "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    # Write the header
    writer.writerow(["graph_id", "edge_list"])

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Processing test set',):
            # kNN lookup runs on CPU numpy arrays, so do it before moving the batch.
            _, indices_neigh = nn.kneighbors(scaler.transform(data.stats.numpy()))
            data = data.to(DEVICE)
            graph_ids = data.filename

            closest_neigh_from_train = [trainset[ind] for ind in indices_neigh.reshape(-1)]

            # Collate all neighbours into exactly one batch, whatever `batch_size` is.
            loader_closest = DataLoader(closest_neigh_from_train,
                                        batch_size=len(closest_neigh_from_train),
                                        shuffle=False)
            batch_closest = next(iter(loader_closest)).to(DEVICE)
            stat = data.stats

            x_g = encoder.encoder(batch_closest, stat)
            mu = encoder.fc_mu(x_g)
            logvar = encoder.fc_logvar(x_g)
            x_g = encoder.reparameterize(mu, logvar)
            adj = encoder.decoder(x_g, stat)

            # One device-to-host copy per batch instead of one per graph.
            adj_np = adj.detach().cpu().numpy()

            for i in range(stat.size(0)):
                Gs_generated = construct_nx_from_adj(adj_np[i, :, :])

                # Define a graph ID
                graph_id = graph_ids[i]

                # Convert the edge list to a single string
                edge_list_text = ", ".join([f"({u}, {v})" for u, v in Gs_generated.edges()])
                # Write the graph ID and the full edge list as a single row
                writer.writerow([graph_id, edge_list_text])

Processing test set: 100%|██████████| 4/4 [00:02<00:00,  1.82it/s]


In [19]:
import csv

# Generate one graph per test item. For each test graph the nearest training
# graph (in standardised statistics space) is encoded, and the sampled latent is
# decoded conditioned on that NEIGHBOUR's statistics (not the test graph's own).
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
with open("knn_condVAE_train_stat.csv", "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    # Write the header
    writer.writerow(["graph_id", "edge_list"])

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Processing test set',):
            # kNN lookup runs on CPU numpy arrays, so do it before moving the batch.
            _, indices_neigh = nn.kneighbors(scaler.transform(data.stats.numpy()))
            data = data.to(DEVICE)
            graph_ids = data.filename

            closest_neigh_from_train = [trainset[ind] for ind in indices_neigh.reshape(-1)]

            # Collate all neighbours into exactly one batch, whatever `batch_size` is.
            loader_closest = DataLoader(closest_neigh_from_train,
                                        batch_size=len(closest_neigh_from_train),
                                        shuffle=False)
            batch_closest = next(iter(loader_closest)).to(DEVICE)
            stat = batch_closest.stats

            x_g = encoder.encoder(batch_closest, stat)
            mu = encoder.fc_mu(x_g)
            logvar = encoder.fc_logvar(x_g)
            x_g = encoder.reparameterize(mu, logvar)
            adj = encoder.decoder(x_g, stat)

            # One device-to-host copy per batch instead of one per graph.
            adj_np = adj.detach().cpu().numpy()

            for i in range(stat.size(0)):
                Gs_generated = construct_nx_from_adj(adj_np[i, :, :])

                # Define a graph ID
                graph_id = graph_ids[i]

                # Convert the edge list to a single string
                edge_list_text = ", ".join([f"({u}, {v})" for u, v in Gs_generated.edges()])
                # Write the graph ID and the full edge list as a single row
                writer.writerow([graph_id, edge_list_text])

Processing test set: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s]
