In [7]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold

from evaluation import evaluate_matrices
from MatrixVectorizer import MatrixVectorizer

import sys
import os
# Add stp_gsr to the path
sys.path.append('stp_gsr')
from stp_gsr.src.models.stp_gsr import STPGSR
import hydra


In [24]:
class DictToObject:
    """Expose the keys of a (possibly nested) dictionary as instance attributes.

    Each key becomes an attribute of the instance. A value that is itself a
    ``dict`` is wrapped in another ``DictToObject`` so nested settings can be
    read with chained attribute access (e.g. ``cfg.model.name``). Any other
    value, including lists that contain dictionaries, is stored unchanged.
    """

    def __init__(self, dictionary):
        for attr_name, raw_value in dictionary.items():
            # Recurse only into plain dict values; everything else is kept as-is.
            wrapped = DictToObject(raw_value) if isinstance(raw_value, dict) else raw_value
            setattr(self, attr_name, wrapped)

# Node counts referenced by more than one config entry. They are defined once
# here so that dependent entries (node_feat_dim) cannot drift out of sync.
N_SOURCE_NODES = 160
N_TARGET_NODES = 268

# Nested configuration for the STP-GSR model, the dataset and the experiment.
# It is a plain dictionary; wrap it in DictToObject for attribute-style access.
config_dict = {
    "model": {
        "name": "stp_gsr",
        "target_edge_initializer": {
            "num_heads": 4,
            "edge_dim": 1,
            "hidden_dim": 256,
            "num_layers": 3,
            "dropout": 0.2,
            "beta": False
        },
        "dual_learner": {
            "in_dim": 1,
            "out_dim": 1,
            "num_heads": 1,
            "dropout": 0.2,
            "beta": False
        }
    },
    "dataset": {
        "name": "custom",
        "n_samples": 167,
        "n_source_nodes": N_SOURCE_NODES,
        "n_target_nodes": N_TARGET_NODES,
        "node_feat_init": "adj",
        # Kept equal to n_source_nodes; presumably because "adj" initialisation
        # gives one feature per source node — TODO confirm against the model code.
        "node_feat_dim": N_SOURCE_NODES
    },
    "experiment": {
        # NOTE(review): Hydra-style defaults list. Its "sbm" dataset entry differs
        # from the "custom" dataset configured above; presumably it is ignored
        # when the config is built by hand like this — TODO confirm.
        "defaults": [
            "_self_",
            {"dataset": "sbm"},
            {"model": "stp_gsr"}
        ],
        "n_epochs": 60,
        "batch_size": 16,
        "lr": 0.001,
        "log_val_loss": False,
        "base_dir": "results",
        "run_name": "run1",
        "kfold": {
            "n_splits": 3,
            "shuffle": True,
            "random_state": 42
        }
    }
}

In [25]:
# Wrap the nested dictionary so settings can be read with attribute access.
config = DictToObject(config_dict)

# Sanity check: print one value from each config section, one per line.
print(
    config.model.name,  # stp_gsr
    config.model.target_edge_initializer.num_heads,  # 4
    config.dataset.n_samples,  # 167
    config.experiment.n_epochs,  # 60
    config.experiment.kfold.n_splits,  # 3
    sep="\n",
)

stp_gsr
4
167
60
3


In [30]:
# `load_test` is called by the evaluation cell further down in this notebook.
# NOTE(review): the output recorded for this cell is an ImportError saying that
# `load_test` cannot be imported from stp_gsr/src/dataset.py — confirm the
# loader's actual name in that module before relying on this import.
from stp_gsr.src.dataset import load_test
# NOTE(review): this binds `eval` at notebook scope, shadowing Python's builtin eval().
from stp_gsr.src.train import eval

ImportError: cannot import name 'load_test' from 'stp_gsr.src.dataset' (/vol/bitbucket/zz1224/dgl/dgl-project/stp_gsr/src/dataset.py)

In [None]:
# Load the trained STP-GSR model and the test data, then run the model in
# inference mode and report the shape of its prediction.
sys.path.append('stp_gsr')
LR_TEST_DATA_FILE_NAME = 'data/lr_test.csv'
MODEL_FILE_NAME = 'stp_gsr/results/stp_gsr_modified/final/model.pth'

# Load the test data
# test_data = pd.read_csv(LR_TEST_DATA_FILE_NAME)

# Load the model
model = STPGSR(config)
# map_location="cpu" lets a checkpoint that was saved from a GPU run be
# deserialised on a CPU-only host; load_state_dict then copies the values into
# the model's own parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(MODEL_FILE_NAME, map_location="cpu"))
model.eval()  # switch layers such as dropout to inference behaviour

# Change test data to torch tensor
# test_data = torch.tensor(test_data.values, dtype=torch.float32)

# NOTE(review): the import cell for `load_test` has an ImportError recorded as
# its output, so this call needs that import to be resolved first.
test_data = load_test(config)

# Predict the test data
# NOTE(review): the traceback recorded under this cell reports that
# STPGSR.forward() requires a `target_mat` argument, so this zero-argument call
# cannot succeed as written. The inputs forward() expects, and how `test_data`
# maps onto them, are not visible in this notebook — TODO pass them here.
with torch.no_grad():  # no gradients are needed for evaluation
    y_pred = model()
print(y_pred.shape)



TargetEdgeInitializer(
  (convs): ModuleList(
    (0): TransformerConv(160, 64, heads=4)
    (1): TransformerConv(256, 64, heads=4)
    (2): TransformerConv(256, 67, heads=4)
  )
  (bns): ModuleList(
    (0-1): 2 x GraphNorm(256)
    (2): GraphNorm(268)
  )
  (residual_proj): Sequential(
    (0): Linear(in_features=160, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=268, bias=True)
  )
)
--------------------------------------------------
DualGraphLearner(
  (conv1): TransformerConv(1, 1, heads=1)
  (bn1): GraphNorm(1)
)


TypeError: STPGSR.forward() missing 1 required positional argument: 'target_mat'

In [5]:
# Bare last expression: the notebook renders `test_data` as this cell's output.
test_data

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,12710,12711,12712,12713,12714,12715,12716,12717,12718,12719
0,0.381556,0.008404,0.704672,0.459716,0.638838,0.429525,0.654569,0.670515,0.430826,0.767706,...,0.044062,0.000000,0.027265,0.074954,0.048790,0.521427,0.030686,0.348275,0.431899,0.393250
1,0.697967,0.000000,0.000000,0.661155,0.841448,0.073916,0.739038,0.643867,0.175173,0.590896,...,0.258033,0.377616,0.360398,0.241123,0.476553,0.436636,0.081685,0.587660,0.670123,0.491657
2,0.355362,0.244858,0.249740,0.511696,0.628156,0.123349,0.691025,0.193289,0.214599,0.081607,...,0.180335,0.295169,0.337348,0.392652,0.205985,0.545668,0.376009,0.463320,0.488176,0.565206
3,0.313823,0.278309,0.000000,0.485659,0.755216,0.087324,0.506895,0.681023,0.175964,0.542607,...,0.000000,0.000000,0.045189,0.110365,0.000000,0.669933,0.342143,0.221058,0.614712,0.161293
4,0.561839,0.110753,0.486447,0.335666,0.679871,0.408250,0.491914,0.758904,0.526859,0.534782,...,0.000000,0.279874,0.000000,0.000000,0.000000,0.575274,0.074075,0.566045,0.392693,0.227096
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
107,0.612729,0.022177,0.707731,0.617289,0.799931,0.747803,0.559406,0.914120,0.933610,0.908674,...,0.226344,0.840599,0.672141,0.526567,0.893927,0.795447,0.797217,0.861304,0.721687,0.820384
108,0.441047,0.176477,0.113680,0.423133,0.448606,0.336747,0.426157,0.488343,0.328519,0.696958,...,0.374385,0.308081,0.534995,0.274425,0.326984,0.653298,0.259562,0.735426,0.666054,0.541001
109,0.372749,0.125366,0.389338,0.412203,0.774193,0.356878,0.684710,0.611649,0.374065,0.640083,...,0.277828,0.491823,0.325917,0.450591,0.462834,0.523909,0.642218,0.599594,0.619622,0.474699
110,0.535746,0.418461,0.656996,0.772941,0.751991,0.490863,0.638072,0.517141,0.366644,0.659283,...,0.124788,0.115926,0.098516,0.040183,0.245019,0.434175,0.288351,0.533405,0.254862,0.151185
