In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import torchvision
import torchmetrics
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# wandb entity (user/team) and project that hold the runs analysed below.
entity, project = "haraghi", "DGCNN"

In [None]:
# wandb run ids of the runs to compare. Previously used pair: ['syifhzlv','1iyq2lum']
run_ids = ['9rrxu350', 'x4lf35wy']

# Download each run's model artifact aliased "best"; each entry is the local
# directory that holds the downloaded checkpoint.
artifact_refs = [f"{entity}/{project}/model-{run_id}:best" for run_id in run_ids]
artifact_dirs = [WandbLogger.download_artifact(artifact=ref) for ref in artifact_refs]

In [None]:
# Public wandb API client; used below to fetch each run's logged config.
api = wandb.Api()

In [None]:
# Base config that both the logged run configs and the on-disk config files
# are merged on top of.
cfg_bare = OmegaConf.load("config_bare.yaml")

# wandb run paths are always "entity/project/run_id". os.path.join would use
# "\\" on Windows, so build the path explicitly (same as the artifact names).
configs = [api.run(f"{entity}/{project}/{run_id}").config for run_id in run_ids]
cfgs = [OmegaConf.merge(cfg_bare, OmegaConf.create(config)) for config in configs]

# For runs that recorded the config file they were launched from, reload that
# file too so it can be diffed against the logged config; otherwise reuse the
# logged config itself.
cfg_files = []
for cfg in cfgs:
    if "cfg_path" in cfg:
        print(cfg.cfg_path)
        cfg_files.append(OmegaConf.merge(cfg_bare, OmegaConf.load(cfg.cfg_path)))
    else:
        cfg_files.append(cfg)

In [None]:
def recursive_dict_compare(all_cfg, other_cfg):
    """
    Return the entries of ``other_cfg`` that are missing from, or differ from,
    ``all_cfg``.

    The comparison is one-directional: only keys of ``other_cfg`` are visited,
    so keys present solely in ``all_cfg`` never appear in the result. When both
    sides hold a dict under the same key the comparison recurses and only the
    differing sub-entries are kept. A ``num_classes`` that is ``None`` in
    ``other_cfg`` while set in ``all_cfg`` is not reported as a difference.

    Returns:
        dict mapping each differing key to its value in ``other_cfg`` (or to a
        nested dict of differences); empty when nothing differs.
    """
    differences = {}
    for name in other_cfg:
        other_value = other_cfg[name]

        # Key only exists on the "other" side: report its whole value.
        if name not in all_cfg:
            differences[name] = other_value
            continue

        base_value = all_cfg[name]

        # Both sides are dicts: recurse and keep only non-empty sub-diffs.
        if isinstance(base_value, dict) and isinstance(other_value, dict):
            nested = recursive_dict_compare(base_value, other_value)
            if nested:
                differences[name] = nested
            continue

        # Leaf values: report a mismatch, except an unset (None) num_classes on
        # the other side while the base side has one.
        if base_value != other_value and not (
            name == "num_classes" and other_value is None and base_value is not None
        ):
            differences[name] = other_value

    return differences


In [None]:
# For each run, list the settings in its config file that differ from the
# logged (wandb-merged) run config. An empty dict means they agree.
config_diffs = [
    recursive_dict_compare(OmegaConf.to_object(run_cfg), OmegaConf.to_object(file_cfg))
    for run_cfg, file_cfg in zip(cfgs, cfg_files)
]
print(config_diffs)

In [None]:
# Seed everything. Note that this does not make training entirely
# deterministic. seed_everything sets *global* RNG state, so only one seed can
# be in effect at a time; re-seeding once per run just meant the last run's
# seed won. Keep that behaviour explicitly and flag runs whose seeds disagree.
seeds = [cfg.seed for cfg in cfgs]
if seeds:
    if len(set(seeds)) > 1:
        print(f"Runs use different seeds {seeds}; seeding with the last one ({seeds[-1]})")
    pl.seed_everything(seeds[-1], workers=True)

# Check that every run used the same dataset configuration as the first run.
# num_workers only affects loading speed, so a difference there is ignored.
for cfg in cfgs[1:]:
    compare_dict = recursive_dict_compare(
        OmegaConf.to_object(cfgs[0].dataset), OmegaConf.to_object(cfg.dataset)
    )
    if set(compare_dict) - {"num_workers"}:
        print(compare_dict)
        print(cfg.dataset)
        print(cfgs[0].dataset)
        # raise Exception("Datasets are not the same")
# Create datasets using factory pattern



In [None]:
# Instantiate one network per run, then restore each run's "best" checkpoint
# into it. Runner bundles the model with optimiser, loss function and metrics.
models = [model_factory.factory(cfg) for cfg in cfgs]
runners = [
    Runner.load_from_checkpoint(osp.join(ckpt_dir, "model.ckpt"), cfg=run_cfg, model=net)
    for ckpt_dir, run_cfg, net in zip(artifact_dirs, cfgs, models)
]

# One data module per run, plus that run's test split.
gdms = [GraphDataModule(cfg) for cfg in cfgs]
dss = []
# NOTE: `cfg` and `gdm` remain bound to the last run after this loop.
for cfg, gdm in zip(cfgs, gdms):
    # Store the class count reported by the data module on the run config.
    cfg.dataset.num_classes = gdm.num_classes
    dss.append(
        create_dataset(
            dataset_path=gdm.dataset_path,
            dataset_name=gdm.dataset_name,
            dataset_type='test',
            transform=gdm.transform_dict['test'],
            num_workers=gdm.num_workers,
        )
    )

In [None]:
# Inference Trainer. accelerator="auto" picks the GPU when one is available.
# torch.cuda.device_count() is 0 on CPU-only machines, and Lightning rejects
# devices=0, so fall back to a single device in that case.
trainer = pl.Trainer(
    enable_progress_bar=True,
    # DDP is disabled in the notebook; re-enable with:
    # strategy="ddp_notebook",
    devices=torch.cuda.device_count() or 1,
    accelerator="auto",
)

In [None]:

# Function to visualize feature maps and filters in intermediate layers
def visualize_layers(model, layer_num, input_image):
    """Run one forward pass and capture intermediate feature maps of ``model``.

    Forward hooks are attached to ``model.classifier.conv1`` and to the
    ``conv2`` of the last block in each of ``classifier.layer1`` ...
    ``classifier.layer4`` (presumably a ResNet-style classifier -- TODO confirm
    for other model types). The hooks are removed again before returning, so
    repeated calls do not stack hooks, and later forward passes (e.g. a full
    evaluation loop over the same model) do not keep copying activations into
    a stale list.

    Args:
        model: module exposing the ``classifier`` submodules listed above. It
            is switched to eval mode as a side effect.
        layer_num: currently unused; kept so existing calls keep working.
        input_image: passed unchanged to ``model(...)``.

    Returns:
        list of numpy arrays, one per hooked module, in the order the hooked
        modules executed during the forward pass.
    """
    feature_maps = []

    def hook_fn(_module, _inputs, output):
        # Copy to host memory so captured activations do not hold GPU memory.
        feature_maps.append(output.detach().cpu().numpy())

    torch.cuda.empty_cache()
    model.eval()

    # Resolve every target module before registering anything, so a missing
    # attribute cannot leave a partial set of hooks behind.
    classifier = model.classifier
    hooked_modules = [
        classifier.conv1,
        classifier.layer1[-1].conv2,
        classifier.layer2[-1].conv2,
        classifier.layer3[-1].conv2,
        classifier.layer4[-1].conv2,
    ]
    handles = [module.register_forward_hook(hook_fn) for module in hooked_modules]
    try:
        with torch.no_grad():
            model(input_image)
    finally:
        # Always detach the hooks, even if the forward pass fails.
        for handle in handles:
            handle.remove()

    return feature_maps


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _as_single_graph_batch(dataset, index):
    """Return ``dataset[index]`` moved to ``device`` with an all-zero ``batch`` vector.

    The zero vector assigns every node to graph 0, so the model receives the
    sample as a batch of size one.
    """
    graph = dataset[index]
    graph.batch = torch.zeros(graph.num_nodes, dtype=torch.long)
    return graph.to(device)


def _show_channel_grid(channel_maps, caption):
    """Print ``caption`` and draw up to 256 channels of ``channel_maps`` on a 16x16 grid."""
    print(caption)
    plt.figure(figsize=(10, 10))
    for channel in range(min(channel_maps.shape[0], 256)):
        plt.subplot(16, 16, channel + 1)
        plt.imshow(channel_maps[channel], cmap='viridis')
        plt.axis('off')
    plt.show()


# Capture intermediate feature maps of both runs for test sample 5 of each
# run's own test split.
feature_maps_sparse = visualize_layers(
    runners[0].model.to(device), 'classifier.conv1', _as_single_graph_batch(dss[0], 5)
)
feature_maps_dense = visualize_layers(
    runners[1].model.to(device), 'classifier.conv1', _as_single_graph_batch(dss[1], 5)
)

# Side-by-side: for each hooked layer, first sample of the sparse run, then the dense run.
for layer_pos, sparse_maps in enumerate(feature_maps_sparse):
    print(layer_pos, sparse_maps[0].shape)
    _show_channel_grid(sparse_maps[0], 'sparse input with 64 events per sample')
    _show_channel_grid(feature_maps_dense[layer_pos][0], 'dense input with 20000 events per sample')

In [None]:
# List the parameter names of the second run's classifier, then display it.
classifier = runners[1].model.classifier
for param_name, _param in classifier.named_parameters():
    print(param_name)
classifier

In [None]:
# Evaluate one run's best checkpoint on the test split of that same run.
# The run is chosen explicitly. The previous version relied on `gdm` leaking
# from an earlier loop, which is the *last* data module rather than the one
# matching the model. It also referenced an undefined `runner`.
eval_idx = 0

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = runners[eval_idx].model.to(device)
model.eval()

preds = []    # per-batch model outputs, moved to the CPU
targets = []  # per-batch ground-truth labels, as yielded by the loader
files = []    # file ids of every evaluated sample, in loader order
with torch.no_grad():
    for data in tqdm(gdms[eval_idx].test_dataloader()):
        files.extend(data.file_id)
        targets.append(data.y)
        data = data.to(device)
        out = model(data)
        preds.append(out.detach().cpu())

# Concatenate once at the end instead of torch.cat-ing a growing tensor on
# every batch (quadratic copying).
preds_ = torch.cat(preds, dim=0)
targets_ = torch.cat(targets, dim=0)

# Integer label arrays (previously these were float tensors).
y = targets_.cpu().numpy()
y_hat = preds_.argmax(dim=1).numpy()
correct = int((y_hat == y).sum())
total = y.shape[0]

metrics = torchmetrics.classification.Accuracy(
    num_classes=cfgs[eval_idx].dataset.num_classes, task="multiclass", top_k=1
)
acc = metrics(preds_, targets_).detach().cpu().numpy()
print(acc)