In [1]:
# !git init .
# !git remote add origin https://github.com/kashparty/STP-GSR.git
# !git pull origin discriminator

In [2]:
# %pip install -r requirements.txt

# 3-Fold Cross-Validation

In [None]:
import os
import hydra
import torch
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import KFold
from hydra import compose, initialize

from src.train import train, eval
from src.plot_utils import plot_adj_matrices
from src.dataset import load_dataset

# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def main():
    """Run K-fold cross-validation for the configured experiment.

    Composes the Hydra config ``experiment`` from ``configs/``, loads the
    paired source/target dataset and, for every fold:

    1. trains a model on the training split,
    2. evaluates it on the held-out split,
    3. saves the evaluation output plus the validation source/target
       matrices as ``.npy`` files, and
    4. plots one validation sample.

    Per-fold results are written under
    ``{base_dir}/{model_name}/{dataset_type}/{run_name}/fold_{k}/``.
    """
    with initialize(version_base=None, config_path="configs"):
        config = compose(config_name="experiment")

    torch.cuda.empty_cache()

    if torch.cuda.is_available():
        print("Running on GPU")
    else:
        print("Running on CPU")

    # Read the split count once so the progress banner always matches the config
    n_splits = config.experiment.kfold.n_splits
    kf = KFold(n_splits=n_splits,
               shuffle=config.experiment.kfold.shuffle,
               random_state=config.experiment.kfold.random_state)

    # Initialize folder structure for this run
    base_dir = config.experiment.base_dir
    model_name = config.model.name
    dataset_type = config.dataset.name
    run_name = config.experiment.run_name
    run_dir = f'{base_dir}/{model_name}/{dataset_type}/{run_name}/'

    # Load dataset
    source_data, target_data = load_dataset(config)

    # Preferred validation sample to visualise in each fold
    plot_sample_idx = 6

    for fold, (train_idx, val_idx) in enumerate(kf.split(source_data)):
        print(f"Training Fold {fold+1}/{n_splits}")

        # Initialize results directory. The trailing slash is kept because
        # res_dir is handed to train() and plot_adj_matrices() unchanged.
        res_dir = f'{run_dir}fold_{fold+1}/'
        os.makedirs(res_dir, exist_ok=True)

        # Fetch training and val data for this fold
        source_data_train = [source_data[i] for i in train_idx]
        target_data_train = [target_data[i] for i in train_idx]
        source_data_val = [source_data[i] for i in val_idx]
        target_data_val = [target_data[i] for i in val_idx]

        # Train model for this fold
        train_output = train(config,
                             source_data_train,
                             target_data_train,
                             source_data_val,
                             target_data_val,
                             res_dir)

        # Evaluate model for this fold
        eval_output, eval_loss = eval(config,
                                      train_output['model'],
                                      source_data_val,
                                      target_data_val,
                                      train_output['criterion_L1'])

        # Final evaluation loss for this fold
        print(f"Final Validation Loss (Target): {eval_loss}")

        # Save source, target, and eval output for this fold
        np.save(os.path.join(res_dir, 'eval_output.npy'), np.array(eval_output))
        np.save(os.path.join(res_dir, 'source.npy'),
                np.array([s['mat'] for s in source_data_val]))
        np.save(os.path.join(res_dir, 'target.npy'),
                np.array([t['mat'] for t in target_data_val]))

        # Plot predictions for one fixed validation sample, clamped so that a
        # small validation split cannot raise an IndexError.
        idx = min(plot_sample_idx, len(source_data_val) - 1)
        source_mat_test = source_data_val[idx]['mat']
        target_mat_test = target_data_val[idx]['mat']
        eval_output_t = eval_output[idx]

        plot_adj_matrices(source_mat_test,
                          target_mat_test,
                          eval_output_t,
                          idx,
                          res_dir,
                          file_name=f'eval_sample{idx}')


# Run the full cross-validation experiment (trains every fold, so this cell is long-running).
main()

  warn(


Running on GPU
Training Fold 1/3
Model parameters: 317,599
STPGSR(
  (target_edge_initializer): TargetEdgeInitializer(
    (conv1): TransformerConv(160, 67, heads=4)
    (bn1): GraphNorm(268)
  )
  (dual_learner): DualGraphLearner(
    (conv1): TransformerConv(3, 1, heads=1)
    (bn1): GraphNorm(1)
  )
  (discriminator): Discriminator(
    (dense_1): Dense()
    (relu_1): ReLU(inplace=True)
    (dense_2): Dense()
    (relu_2): ReLU(inplace=True)
    (dense_3): Dense()
    (sigmoid): Sigmoid()
  )
)


100%|██████████| 111/111 [05:27<00:00,  2.95s/it]


Epoch 1/60, Generator Loss: 0.18148562526917672, Discriminator Loss: 47.14263756049646


100%|██████████| 111/111 [05:32<00:00,  2.99s/it]


Epoch 2/60, Generator Loss: 0.16111744376453194, Discriminator Loss: 49.664641268618475


  5%|▍         | 5/111 [00:14<05:15,  2.98s/it]

# Loading a trained model from a saved fold checkpoint

In [None]:
from hydra import compose, initialize
from src.models.stp_gsr import STPGSR
import torch
from tqdm import tqdm

# Compose the `experiment` config so the model can be constructed from it.
with initialize(version_base=None, config_path="configs"):
    config = compose(config_name="experiment")

model = STPGSR(config)

# Map the checkpoint tensors to the GPU only when one is present; a hard-coded
# "cuda" target makes torch.load raise on CPU-only machines.
checkpoint_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load("results/stp_gsr/train/run4/fold_3/model.pth", map_location=checkpoint_device))

# Preparing the test data and running predictions with the loaded model

In [None]:
from src.matrix_vectorizer import MatrixVectorizer
from src.dataset import create_pyg_graph
from functools import partial
import numpy as np

# Size passed to anti_vectorize, create_pyg_graph and node_feat_dim alike
# (presumably the node count of each source graph; confirm against the dataset).
N_SOURCE_NODES = 160

# One vectorized source matrix per CSV row; the header row is skipped.
source_vectorized = np.genfromtxt("lr_test.csv", delimiter=",", skip_header=1)

# genfromtxt collapses a single data row into a 1-D array. Restore the row axis
# so the loop below iterates over samples rather than over scalar entries.
if source_vectorized.ndim == 1 and source_vectorized.size > 0:
    source_vectorized = source_vectorized[np.newaxis, :]

# Rebuild each matrix from its vectorized form and convert it to a float tensor.
source_mat_all = [
    torch.tensor(MatrixVectorizer.anti_vectorize(vec, N_SOURCE_NODES), dtype=torch.float)
    for vec in source_vectorized
]

# Graph builder with the node-feature options fixed for every sample.
pyg_partial = partial(create_pyg_graph, node_feature_init="adj", node_feat_dim=N_SOURCE_NODES)

source_pyg_all = [pyg_partial(mat, N_SOURCE_NODES) for mat in source_mat_all]

# Pair each graph with its matrix under the 'pyg' / 'mat' keys read downstream.
source_data = [{'pyg': graph, 'mat': mat} for graph, mat in zip(source_pyg_all, source_mat_all)]

In [None]:
from src.dual_graph_utils import revert_dual

# Size handed to revert_dual to rebuild each (n_t, n_t) prediction matrix.
N_TARGET_NODES = 268

model.eval()
eval_output = []

# Inference only: no gradients are needed.
with torch.no_grad():
    for source in tqdm(source_data):
        source_g = source['pyg']

        # No target is supplied at prediction time; the second return value is unused.
        model_pred, _unused_target = model(source_g, None)
        pred_m = revert_dual(model_pred, N_TARGET_NODES)    # (n_t, n_t)
        eval_output.append(pred_m.cpu().numpy())

# Show only the prediction count; echoing the list itself would dump every
# predicted matrix into the cell output.
len(eval_output)

# Submission Generation code

In [None]:
from src.matrix_vectorizer import MatrixVectorizer
import pandas as pd

# Vectorize every predicted matrix and chain the pieces into one long array.
vectorized_preds = [MatrixVectorizer.vectorize(pred) for pred in eval_output]
test_array = np.concatenate(vectorized_preds)

# Rows are numbered from 1; that index is written out as the "ID" column.
submission_ids = np.arange(1, test_array.size + 1)
output_df = pd.DataFrame({"Predicted": test_array.flatten()}, index=submission_ids)

output_df.to_csv("submission.csv", index_label="ID")
output_df