In [3]:
# Cell 0
import os, sys, random
import numpy as np
import torch

# Make the repository's `src` package importable. This assumes the kernel was
# started from the repository root (REPO_ROOT is the current working directory).
REPO_ROOT = os.getcwd()   # or simply "."

# Drop any earlier copies first so re-running this cell keeps exactly one
# entry, still at the front of the import search path.
while REPO_ROOT in sys.path:
    sys.path.remove(REPO_ROOT)
sys.path.insert(0, REPO_ROOT)

print("Repo root:", REPO_ROOT)
print("CUDA:", torch.cuda.is_available())

Repo root: c:\Users\qshah\Documents\Spring 2026\COMBINEX
CUDA: True


In [12]:
# Cell 1
from omegaconf import OmegaConf
from hydra import initialize, compose

from torch.nn import functional as F

from src.utils.dataset import get_dataset
from src.datasets.dataset import DataInfo
from src.utils.models import get_model

# If you want to reuse your training pipeline:
from src.oracles.train.train import Trainer

In [13]:
# Cell 2
# Compose the experiment config with Hydra, overriding task and dataset.
# Tune these two values to switch experiments.
TASK_NAME = "node"
DATASET_NAME = "citeseer"   # or cora / pubmed / etc

CFG_PATH = os.path.join(REPO_ROOT, "config", "config.yaml")  # adjust if different

# Fail fast with a clear message rather than a Hydra search-path error.
# (The previous OmegaConf.load here was discarded immediately by compose().)
if not os.path.isfile(CFG_PATH):
    raise FileNotFoundError(f"Hydra root config not found: {CFG_PATH}")

# NOTE(review): Hydra resolves `config_path` relative to the caller. In a
# notebook this is presumably the kernel's working directory (REPO_ROOT).
# Confirm this if the notebook is moved.
with initialize(config_path="config", version_base="1.3"):
    cfg = compose(
        config_name="config",
        overrides=[
            f"task={TASK_NAME}",
            f"dataset={DATASET_NAME}",
        ],
    )

print(cfg.task.name)      # e.g. "Node"
print(cfg.dataset.name)   # e.g. "citeseer"

Node
citeseer


In [15]:
# Cell 3

def seed_everything(seed: int, deterministic: bool = False) -> None:
    """Seed Python's, NumPy's and PyTorch's (CPU + all CUDA devices) RNGs.

    Args:
        seed: Seed value applied to every generator.
        deterministic: If True, also make cuDNN choose deterministic kernels
            and disable its autotuner. This trades GPU speed for run-to-run
            reproducibility. Defaults to False, which only seeds, as before.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    if deterministic:
        # Seeding alone does not fix cuDNN's algorithm choice; pin it explicitly.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

seed_everything(int(cfg.general.seed))

# Honour a CUDA request from the config (including an indexed one such as
# "cuda:1") only when CUDA is actually available. Anything else runs on CPU.
requested_device = str(cfg.device)
cuda_requested = requested_device.startswith("cuda")
cuda_usable = cuda_requested and torch.cuda.is_available()
if cuda_requested and not cuda_usable:
    print(f"Config requested {requested_device!r} but CUDA is unavailable; using CPU.")
device = requested_device if cuda_usable else "cpu"
print("Device:", device)

data = get_dataset(cfg.dataset.name, test_size=cfg.test_size)
data = data.to(device)

# Per the original author's note, DataInfo deletes its `self.data` internally.
# Only its summary attributes (feature / class counts) are read below.
datainfo = DataInfo(cfg, data)
print("num_features:", datainfo.num_features)
print("num_classes:", datainfo.num_classes)

Device: cuda
num_features: 3703
num_classes: 6


In [22]:
# Cell 4
# Build the oracle model class selected by cfg.model.name / cfg.task.name.
# Instantiate it for this dataset's feature/class counts, train it with the
# project Trainer using cross-entropy, and switch it to eval mode.
OracleClass = get_model(name=cfg.model.name, task=cfg.task.name)
oracle = OracleClass(
    num_features=datainfo.num_features,
    num_classes=datainfo.num_classes,
    cfg=cfg,
).to(device)

trainer = Trainer(cfg=cfg, dataset=data, model=oracle, loss=F.cross_entropy)
# FIXME: the recorded run of this cell aborts after epoch 0 with
# "You must call wandb.init() before wandb.log()". Trainer appears to log to
# Weights & Biases, so a wandb run presumably has to be started before
# start_training() (e.g. wandb.init(..., mode="disabled") to run without
# logging). TODO confirm against src/oracles/train/train.py.
trainer.start_training()

# Take the trainer's model reference, presumably the same module as `oracle`.
# TODO confirm that Trainer does not replace it (e.g. by reloading a checkpoint).
oracle = trainer.model
oracle.eval()  # inference mode: affects layers such as dropout / batch-norm

print("Oracle ready.")

name='CHEB', task='Node'
Epoch:    0 Train Loss: 1.7756 Train Acc: 0.2015 Test Loss: 1.7727 Test Acc: 0.1892


Error: You must call wandb.init() before wandb.log()