In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import torchvision
import torchmetrics
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# Weights & Biases entity and project names; other cells combine them with the
# run ids to build artifact paths and run paths.
entity = "haraghi"
project = "DGCNN"

In [None]:
# Ids of the W&B runs whose checkpoints are evaluated in this notebook.
run_ids = ['dbdjw7lr']

# Fetch the ":best" model artifact of every run. artifact_dirs[k] is whatever
# download_artifact returns for run_ids[k] (used later as a directory).
artifact_dirs = []
for run_id in run_ids:
    artifact_dirs.append(
        WandbLogger.download_artifact(artifact=f"{entity}/{project}/model-{run_id}:best")
    )

In [None]:
# W&B public API client; used to look up each run and read its stored config.
api = wandb.Api()

In [None]:
# Baseline configuration that every other configuration is merged on top of.
cfg_bare = OmegaConf.load("config_bare.yaml")

# Configuration stored with each W&B run. W&B run paths are always
# "/"-separated, so the path is built explicitly rather than with osp.join,
# which would insert the OS separator ("\\" on Windows).
configs = [api.run(f"{entity}/{project}/{run_id}").config for run_id in run_ids]
cfgs = [OmegaConf.merge(cfg_bare, OmegaConf.create(config)) for config in configs]

# For runs that recorded a `cfg_path`, also load that config file (merged on
# the same baseline) so it can be compared against the configuration stored in
# W&B. Runs without a `cfg_path` reuse their W&B configuration unchanged.
cfg_files = []
for cfg in cfgs:
    if "cfg_path" in cfg.keys():
        print(cfg.cfg_path)
        cfg_files.append(OmegaConf.merge(cfg_bare, OmegaConf.load(cfg.cfg_path)))
    else:
        cfg_files.append(cfg)

In [None]:
def recursive_dict_compare(all_cfg, other_cfg):
    """
    Return the entries of ``other_cfg`` that disagree with ``all_cfg``.

    The comparison is one-sided: only keys present in ``other_cfg`` are
    examined, so keys that exist solely in ``all_cfg`` are never reported.

    Args:
        all_cfg: Reference dictionary.
        other_cfg: Dictionary whose entries are checked against the reference.

    Returns:
        A dict holding, for every key of ``other_cfg`` that is missing from
        ``all_cfg`` or maps to a different value, the value from ``other_cfg``.
        When both values are dicts they are compared recursively and only a
        non-empty nested difference is kept. A ``"num_classes"`` entry that is
        ``None`` in ``other_cfg`` but set in ``all_cfg`` is not reported.
    """
    differences = {}

    for name, other_value in other_cfg.items():
        # Key unknown to the reference: report it as-is.
        if name not in all_cfg:
            differences[name] = other_value
            continue

        reference_value = all_cfg[name]

        # Two nested dicts: descend and keep only a non-empty result.
        if isinstance(reference_value, dict) and isinstance(other_value, dict):
            sub_differences = recursive_dict_compare(reference_value, other_value)
            if sub_differences:
                differences[name] = sub_differences
            continue

        if reference_value != other_value:
            # An unset "num_classes" on the other side does not count as a
            # difference when the reference has a value for it.
            if name != "num_classes" or other_value is not None or reference_value is None:
                differences[name] = other_value

    return differences


In [None]:
# For every run, list the entries of its config file that differ from the
# configuration stored in W&B (one dict per run, empty when they agree).
cfg_file_diffs = [
    recursive_dict_compare(OmegaConf.to_object(run_cfg), OmegaConf.to_object(file_cfg))
    for run_cfg, file_cfg in zip(cfgs, cfg_files)
]
print(cfg_file_diffs)

In [None]:
# Seed the random number generators with the seed of every run configuration
# in turn. Seeding does not make everything fully deterministic.
for cfg in cfgs:
    pl.seed_everything(cfg.seed, workers=True)

# Compare the dataset section of every further run against the first run's.
# A difference is reported unless 'num_workers' is the only differing key.
# Mismatches are only printed; nothing is raised.
for cfg in cfgs[1:]:
    compare_dict = recursive_dict_compare(
        OmegaConf.to_object(cfgs[0].dataset),
        OmegaConf.to_object(cfg.dataset),
    )
    if compare_dict and set(compare_dict.keys()) != {'num_workers'}:
        print(compare_dict)
        print(cfg.dataset)
        print(cfgs[0].dataset)

# Build the data module from the first configuration and copy its class count
# into the dataset section of every configuration.
gdm = GraphDataModule(cfgs[0])
for cfg in cfgs:
    cfg.dataset.num_classes = gdm.num_classes

In [None]:
# One freshly constructed model per run configuration.
models = list(map(model_factory.factory, cfgs))

# Restore a Runner for every run from the "model.ckpt" file inside its
# downloaded artifact directory, handing it the matching config and model.
runners = []
for ckpt_dir, run_cfg, net in zip(artifact_dirs, cfgs, models):
    ckpt_path = osp.join(ckpt_dir, "model.ckpt")
    runners.append(Runner.load_from_checkpoint(ckpt_path, cfg=run_cfg, model=net))

# One data module per run configuration.
gdms = list(map(GraphDataModule, cfgs))

# Test-split dataset of every run, created with the settings of its data
# module. The loop variables `cfg` and `gdm` are deliberately left bound to
# the last run afterwards; later cells read them.
dss = []
for cfg, gdm in zip(cfgs, gdms):
    cfg.dataset.num_classes = gdm.num_classes

    test_dataset = create_dataset(
        dataset_path=gdm.dataset_path,
        dataset_name=gdm.dataset_name,
        dataset_type='test',
        transform=gdm.transform_dict['test'],
        num_workers=gdm.num_workers,
    )
    dss.append(test_dataset)

In [None]:
# Lightning trainer; this notebook only calls trainer.test() on it.
# Use every visible CUDA device. On a machine without CUDA,
# torch.cuda.device_count() is 0, which is not a valid `devices` value, so fall
# back to a single device of whatever accelerator "auto" resolves to.
trainer = pl.Trainer(
    enable_progress_bar=True,
    # A multi-GPU notebook session may additionally need:
    # strategy="ddp_notebook",
    devices=torch.cuda.device_count() or 1,
    accelerator="auto",
)

In [None]:
# Number of times every checkpoint is evaluated on the test split.
num_test = 5

# acc_lists[k] collects the `num_test` test accuracies measured for runners[k].
acc_lists = []

# `runners` and `gdms` are both built from `cfgs` in the same order, so zipping
# them evaluates every model on the data module of its own configuration
# instead of on whichever data module a previous loop left in `gdm`.
for runner, runner_gdm in zip(runners, gdms):
    print(runner.cfg.model)
    acc_list = []
    for trial in range(num_test):
        test_results = trainer.test(runner, datamodule=runner_gdm, verbose=False)
        acc_list.append(test_results[0]['test/acc'])
    acc_lists.append(acc_list)

# After the loop, `runner` and `acc_list` refer to the last run.

In [None]:
# Histogram of the repeated test accuracies collected in `acc_list`.
fig, ax = plt.subplots()
ax.hist(acc_list)
ax.set_xlabel('Accuracy')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Accuracy')
plt.show()

# Mean and standard deviation of the same accuracies.
print(np.mean(acc_list))
print(np.std(acc_list))

In [None]:
# Test-split dataset described by `gdm`. It is created once and shared by the
# loader below and by the direct indexing done in later cells, instead of
# constructing the identical dataset twice.
# `gdm` and `cfg` are the values left bound by the loops of earlier cells.
ds = create_dataset(
    dataset_path=gdm.dataset_path,
    dataset_name=gdm.dataset_name,
    dataset_type='test',
    transform=gdm.transform_dict['test'],
    num_workers=gdm.num_workers,
)

# Unshuffled loader over that test split.
dl = DataLoader(
    ds,
    batch_size=cfg.train.batch_size,
    shuffle=False,
    num_workers=gdm.num_workers,
)

In [None]:
# Manual evaluation of the first run's model on the test loader of `gdm`,
# keeping per-batch outputs so individual predictions can be inspected.
# NOTE(review): `gdm` is whatever data module earlier cells left bound; with
# several runs, confirm it belongs to the same run as runners[0].
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Take the model AND the class count from the same runner. The class count was
# previously read from the loop variable `runner` of an earlier cell, i.e. the
# last run, which need not be the run whose model is evaluated here.
eval_runner = runners[0]
model = eval_runner.model.to(device)
model.eval()

correct = 0
total = 0

y = torch.tensor([], device=device)      # ground-truth labels, concatenated over batches
y_hat = torch.tensor([], device=device)  # predicted labels (argmax), concatenated over batches
preds = []    # raw model outputs per batch, moved to CPU
targets = []  # labels per batch, taken before the batch is moved to `device`
files = []    # file ids of the evaluated samples, in loader order
with torch.no_grad():
    for data in tqdm(gdm.test_dataloader()):
        files.extend(data.file_id)
        targets.append(data.y)
        data = data.to(device)
        y = torch.cat((y, data.y))
        out = model(data)
        preds.append(out.clone().detach().cpu())
        label = torch.argmax(out, dim=1)
        y_hat = torch.cat((y_hat, label))
        correct += torch.sum(label == data.y)
        total += data.y.shape[0]

y = y.clone().detach().cpu().numpy()
y_hat = y_hat.clone().detach().cpu().numpy()
preds_ = torch.cat(preds, dim=0)
targets_ = torch.cat(targets, dim=0)

# Top-1 multiclass accuracy over the whole test split.
metrics = torchmetrics.classification.Accuracy(
    num_classes=eval_runner.cfg.dataset.num_classes,
    task="multiclass",
    top_k=1,
)

acc = metrics(preds_, targets_).detach().cpu().numpy()
print(acc)

In [None]:
# Map every entry of ds.categories to the dataset indices of the samples
# whose label[0] equals that entry.
class2ind = {}
for category in ds.categories:
    class2ind[category] = []

for sample_index, sample in enumerate(ds):
    sample_category = sample.label[0]
    class2ind[sample_category].append(sample_index)

In [None]:
# Look-up table from a sample's file_id to its index in `ds`; if a file_id
# occurs more than once, the last index wins.
name2ind = {}
for sample_index, sample in enumerate(ds):
    name2ind[sample.file_id] = sample_index

In [None]:
# Draw one random test sample and classify it 1000 times, re-fetching it from
# `ds` on every pass (presumably to see how stable the prediction is when the
# test transform is not deterministic -- TODO confirm).
idx = torch.randint(0, len(ds), (1,))
# Alternatives for picking a fixed sample instead:
# idx = [6]
# idx = (name2ind['l_0011.mat'],)
correct = 0
total = 0

y1_hat = []  # predicted class index of every pass
model.eval()

with torch.no_grad():
    for ii in tqdm(range(1000)):
        data = ds[idx][0]
        # A single graph: assign all of its nodes to batch element 0.
        data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
        data = data.to(device)

        out = model(data)
        label = out.argmax(dim=1)
        y1_hat.append(label.cpu().item())
        correct += (label == data.y).sum()
        total += data.y.shape[0]

# File id and label of the sample, the fraction of passes that were correct,
# and the names of all distinct classes that were predicted.
print(data.file_id)
print(data.label[0])
print(correct/total)
predicted_classes = np.unique(np.array(y1_hat))
print([ds.categories[int(c)] for c in predicted_classes])