In [None]:
from ast import literal_eval
from pathlib import Path

import pandas as pd
import torch
from torch_geometric.data import Data

# Input CSVs produced by the feature-extraction step: one file per task
# (classification and regression), both living in the same directory.
_FEATURE_DIR = "../2_feature_extraction"
classification_csv = f"{_FEATURE_DIR}/TRPM8_graph_features_class_w_atomic_onehot_encoder.csv"
regression_csv = f"{_FEATURE_DIR}/TRPM8_graph_features_regression_w_atomic_onehot_encoder.csv"

def load_graph_data(csv_file, label_column, label_dtype=None):
    """Build a list of PyG ``Data`` graphs from a feature-extraction CSV.

    Each row must hold Python-literal strings (parsed with ``ast.literal_eval``)
    in the ``node_features``, ``edge_features`` and ``edge_indices`` columns,
    plus a target value in ``label_column``. Rows whose features fail to parse,
    or whose label is missing, are reported and skipped.

    Args:
        csv_file: Path to the CSV written by the feature-extraction step.
        label_column: Name of the column holding each molecule's target.
        label_dtype: dtype of the ``y`` tensor. When ``None`` (default) it is
            ``torch.float`` for the ``"pChEMBL Value"`` column and
            ``torch.long`` for any other column.

    Returns:
        list[Data]: One graph per successfully parsed row; ``y`` has shape ``(1,)``.

    Raises:
        KeyError: If a required column is absent from the CSV.
    """
    if label_dtype is None:
        label_dtype = torch.float if label_column == "pChEMBL Value" else torch.long

    df = pd.read_csv(csv_file)

    # A missing column is a schema problem, not a bad row: fail once and loudly
    # instead of printing one warning per row and returning an empty dataset.
    required = ["node_features", "edge_features", "edge_indices", label_column]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"{csv_file} is missing required column(s): {missing}")

    dataset = []
    for i, row in df.iterrows():
        # mol_id is only used in diagnostics; fall back to the row index if absent.
        mol_id = row.get("mol_id", i)

        label_value = row[label_column]
        if pd.isna(label_value):
            # A NaN target would otherwise be stored as y=nan and poison training.
            print(f"⚠️ Skipping row {i} ({mol_id}): missing '{label_column}'")
            continue

        try:
            node_features = torch.tensor(literal_eval(row["node_features"]), dtype=torch.float)
            edge_features = torch.tensor(literal_eval(row["edge_features"]), dtype=torch.float)

            # Edges are stored as a list of [src, dst] pairs, shape (E, 2), and
            # transposed to PyG's (2, E) layout. An edge-less molecule parses to
            # shape (0,), on which .t() is a no-op, so give it an explicit (0, 2)
            # shape first to end up with a valid (2, 0) edge_index.
            edge_pairs = torch.tensor(literal_eval(row["edge_indices"]), dtype=torch.long)
            if edge_pairs.numel() == 0:
                edge_pairs = edge_pairs.reshape(0, 2)
            edge_indices = edge_pairs.t().contiguous()

            label = torch.tensor([label_value], dtype=label_dtype)
        except (ValueError, SyntaxError, TypeError, RuntimeError, OverflowError) as e:
            # Malformed literal, NaN/non-string cell, ragged lists, or bad label type.
            print(f"⚠️ Error processing row {i} ({mol_id}): {e}")
            continue

        dataset.append(
            Data(x=node_features, edge_index=edge_indices, edge_attr=edge_features, y=label)
        )

    return dataset

# Load each task's graphs and serialize them for downstream training.
graph_data_dir = Path("../3_graph_data")
# torch.save does not create parent directories; ensure the target exists so a
# fresh checkout doesn't fail with FileNotFoundError after the (slow) parsing.
graph_data_dir.mkdir(parents=True, exist_ok=True)

classification_data = load_graph_data(classification_csv, "class_label")
torch.save(classification_data, graph_data_dir / "TRPM8_classification_graph_dataset_v2.pt")
print(f"Saved {len(classification_data)} classification graphs")

regression_data = load_graph_data(regression_csv, "pChEMBL Value")
torch.save(regression_data, graph_data_dir / "TRPM8_regression_graph_dataset_v2.pt")
print(f"Saved {len(regression_data)} regression graphs")
