In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import copy
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
import matplotlib.pyplot as plt
import torchmetrics


In [None]:
# W&B namespace that holds the runs and model artifacts used below.
entity, project = "haraghi", "DGCNN"

In [None]:
# Pair of W&B runs to compare: position SPARSE holds the model trained on sparse
# inputs, position DENSE the one trained on dense inputs.
# Earlier pairs (kept for reference):
#   ['syifhzlv', '1iyq2lum']  # 64 events, 20000 events
#   ['icey2rjl', '1iyq2lum']  # 8 events, 20000 events
#   ['7zx0vpka', '1iyq2lum']  # 8 events, 20000 events
#   ['02o9q7aq', '02o9q7aq']  # ShuffleNet 64 events
run_ids = ['9rrxu350', 'x4lf35wy']  # Fan data: 1024 events, 25000 events

SPARSE, DENSE = 0, 1  # positions in run_ids and in every per-run list built from it

# Local directory of each run's "best" model-checkpoint artifact.
artifact_dirs = []
for run_id in run_ids:
    artifact_dirs.append(
        WandbLogger.download_artifact(artifact=f"{entity}/{project}/model-{run_id}:best")
    )

In [None]:
# W&B public-API client; used below to fetch each run's logged config.
api = wandb.Api()

In [None]:
cfg_bare = OmegaConf.load("config_bare.yaml")

# W&B run paths are always "entity/project/run_id". os.path.join would insert
# backslashes on Windows and produce an invalid run path there.
configs = [api.run(f"{entity}/{project}/{run_id}").config for run_id in run_ids]

# Run configs as logged to W&B, layered over the bare defaults.
cfgs = [OmegaConf.merge(cfg_bare, OmegaConf.create(config)) for config in configs]

# The on-disk config file each run was launched from (when the run recorded a
# cfg_path), layered over the same defaults; otherwise fall back to the logged config.
cfg_files = []
for cfg in cfgs:
    if "cfg_path" in cfg.keys():
        print(cfg.cfg_path)
        cfg_files.append(OmegaConf.merge(cfg_bare, OmegaConf.load(cfg.cfg_path)))
    else:
        cfg_files.append(cfg)

In [None]:
def recursive_dict_compare(all_cfg, other_cfg):
    """
    Return the entries of ``other_cfg`` that are missing from, or differ from, ``all_cfg``.

    The comparison is one-directional: keys present only in ``all_cfg`` are
    ignored. When both sides hold a dict under the same key, the comparison
    recurses and only the differing nested entries are reported; empty nested
    diffs are dropped. A ``num_classes`` entry that is ``None`` in ``other_cfg``
    but set in ``all_cfg`` is not reported as a difference.

    Args:
        all_cfg: reference mapping.
        other_cfg: mapping whose differing entries are collected.

    Returns:
        dict mapping each differing key to its value in ``other_cfg`` (or to a
        nested diff dict for dict-valued keys).
    """
    diff = {}
    for key in other_cfg:
        new_value = other_cfg[key]
        if key not in all_cfg:
            # Key only exists on the other side.
            diff[key] = new_value
            continue

        base_value = all_cfg[key]
        if isinstance(base_value, dict) and isinstance(new_value, dict):
            nested = recursive_dict_compare(base_value, new_value)
            if nested:
                diff[key] = nested
            continue

        # An unset num_classes on the other side is not treated as a difference.
        if base_value != new_value and not (
            key == "num_classes" and new_value is None and base_value is not None
        ):
            diff[key] = new_value

    return diff


In [None]:
# Per run: settings in the launch config file that are missing from, or differ
# from, the config logged to W&B.
config_diffs = [
    recursive_dict_compare(OmegaConf.to_object(logged), OmegaConf.to_object(from_file))
    for logged, from_file in zip(cfgs, cfg_files)
]
print(config_diffs)

In [None]:
# Seed everything. Note that this does not make training entirely
# deterministic. pl.seed_everything sets global RNG state, so calling it once
# per config only left the LAST seed in effect; seed once, explicitly, and say
# so when the runs disagree.
run_seeds = [cfg.seed for cfg in cfgs]
if len(set(run_seeds)) > 1:
    print(f"Runs use different seeds {run_seeds}; seeding with the last run's seed {run_seeds[-1]}")
pl.seed_everything(run_seeds[-1], workers=True)

# Later cells index both test sets with the same sample idx, which assumes the
# runs use the same dataset. Report any dataset-config difference other than
# num_workers (which only affects loading, not the data).
reference_dataset_cfg = OmegaConf.to_object(cfgs[0].dataset)
for cfg in cfgs[1:]:
    dataset_diff = recursive_dict_compare(reference_dataset_cfg, OmegaConf.to_object(cfg.dataset))
    if dataset_diff and set(dataset_diff) != {'num_workers'}:
        print(dataset_diff)
        print(cfg.dataset)
        print(cfgs[0].dataset)
        # To abort on mismatch instead of warning:
        # raise Exception("Datasets are not the same")
# Create datasets using factory pattern



In [None]:
# One data module per run, built from that run's merged config.
gdms = [GraphDataModule(cfg) for cfg in cfgs]

dss = []
for run_cfg, data_module in zip(cfgs, gdms):
    # Write the class count discovered by the data module back into the run
    # config; later cells read it back via runners[...].cfg.dataset.num_classes.
    run_cfg.dataset.num_classes = data_module.num_classes
    test_split = create_dataset(
        dataset_path=data_module.dataset_path,
        dataset_name=data_module.dataset_name,
        dataset_type='test',
        transform=data_module.transform_dict['test'],
        num_workers=data_module.num_workers,
    )
    dss.append(test_split)

In [None]:
class SaveFeatures:
    """Forward hook that keeps a copy of a module's most recent output.

    Attributes:
        hook: handle returned by ``register_forward_hook``; released by ``close()``.
        device: optional device the captured output is moved to (``None`` keeps
            it where the module produced it).
        features: clone of the latest output, marked ``requires_grad``; only
            exists after the hooked module has run forward at least once.
    """

    def __init__(self, module, device=None):
        # The hook fires after every module.forward() call.
        self.hook = module.register_forward_hook(self.hook_fn)
        self.device = device

    def hook_fn(self, module, input, output):
        """Store a detached-from-storage copy of ``output`` (module/input unused)."""
        snapshot = output.clone()
        self.features = snapshot if self.device is None else snapshot.to(self.device)
        self.features.requires_grad_(True)

    def close(self):
        """Unregister the hook; ``features`` keeps its last captured value."""
        self.hook.remove()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild each architecture from its config.
models = [model_factory.factory(cfg) for cfg in cfgs]

# Tie it all together with PyTorch Lightning: Runner contains the model,
# optimizer, loss function and metrics. Each one is restored from the run's
# downloaded "best" checkpoint.
runners = []
for artifact_dir, run_cfg, net in zip(artifact_dirs, cfgs, models):
    checkpoint_path = osp.join(artifact_dir, "model.ckpt")
    runners.append(Runner.load_from_checkpoint(checkpoint_path, cfg=run_cfg, model=net))


In [None]:
# Re-running this cell used to stack a second set of forward hooks on the same
# layers (the old handles were lost but stayed registered, so every forward
# pass cloned each feature map several times). Remove any previous hooks first.
for _stale_hook in globals().get('feature_maps_sparse', []) + globals().get('feature_maps_dense', []):
    _stale_hook.close()


def _attach_feature_hooks(model):
    """Hook model.classifier.relu and the last block of classifier.layer1..layer4.

    Returns the SaveFeatures objects in that order (relu, layer1, ..., layer4);
    captured features are moved to the global ``device``.
    """
    classifier = model.classifier
    layers = [classifier.relu] + [getattr(classifier, f'layer{k}')[-1] for k in range(1, 5)]
    return [SaveFeatures(layer, device=device) for layer in layers]


model_sparse = runners[SPARSE].model.to(device)
feature_maps_sparse = _attach_feature_hooks(model_sparse)

model_dense = runners[DENSE].model.to(device)
feature_maps_dense = _attach_feature_hooks(model_dense)

# For shuffleNet
# model.classifier.conv1.register_forward_hook(hook_fn)
# model.classifier.stage2[-1].register_forward_hook(hook_fn)
# model.classifier.stage3[-1].register_forward_hook(hook_fn)
# model.classifier.stage4[-1].register_forward_hook(hook_fn)
# model.classifier.conv5.register_forward_hook(hook_fn)



In [None]:
torch.cuda.empty_cache()

# Subplot grid shapes (rows, cols). The voxel panels used to allow up to 256
# channels on a 9x9 grid, which makes plt.subplot raise past 81 channels.
VOXEL_GRID = (9, 9)
FEATURE_GRID = (6, 8)
# NOTE(review): the run_ids cell describes the current runs as
# "1024 events, 25000 events"; these labels may be stale — confirm.
SPARSE_LABEL = 'sparse input with 64 events per sample'
DENSE_LABEL = 'dense input with 20000 events per sample'


def _plot_channel_grid(volume, grid, figsize):
    """Show the leading channels of ``volume[0]`` as images on a subplot grid.

    volume: array where ``volume[0][c]`` is a 2-D image for channel c.
    grid: (rows, cols); at most rows * cols channels are drawn.
    """
    rows, cols = grid
    n_panels = min(volume[0].shape[0], rows * cols)
    plt.figure(figsize=figsize)
    for i in range(n_panels):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(volume[0][i], cmap='viridis')
        plt.gca().invert_yaxis()  # draw row 0 at the bottom
        plt.axis('off')
    plt.show()


def _voxelize_sample(model, sample):
    """Run ``model`` on one test sample (refreshing its SaveFeatures hooks) and
    return (voxel grid, cropped voxel grid) as NumPy arrays."""
    sample.batch = torch.zeros(sample.num_nodes, dtype=torch.long)  # single-graph batch
    sample = sample.to(device)
    model.eval()
    with torch.no_grad():
        model(sample)
        vox = model.quantization_layer.forward(sample)
        vox_cropped = model.crop_and_resize_to_resolution(vox, model.crop_dimension)
    return vox.detach().cpu().numpy(), vox_cropped.detach().cpu().numpy()


def _show_voxels(vox, vox_cropped, label):
    """Plot the full and the cropped voxel grid of one sample."""
    print('voxel grid: ', vox[0].shape)
    print(label, flush=True)
    _plot_channel_grid(vox, VOXEL_GRID, (10, 10))
    print('voxel cropped grid: ', vox_cropped[0].shape)
    print(label, flush=True)
    _plot_channel_grid(vox_cropped, VOXEL_GRID, (10, 10))


# The same idx is used for both test sets, which assumes they list the same
# samples in the same order (the dataset-config check only warns on mismatch).
idx = np.random.randint(len(dss[SPARSE]))
# idx = 10799

vox_sparse, vox_cropped_sparse = _voxelize_sample(model_sparse, dss[SPARSE][idx])
_show_voxels(vox_sparse, vox_cropped_sparse, SPARSE_LABEL)

vox_dense, vox_cropped_dense = _voxelize_sample(model_dense, dss[DENSE][idx])
_show_voxels(vox_dense, vox_cropped_dense, DENSE_LABEL)

# Hooked feature maps from the two forward passes above, layer by layer.
for number, (fm_sparse, fm_dense) in enumerate(zip(feature_maps_sparse, feature_maps_dense)):
    for label, fm in ((SPARSE_LABEL, fm_sparse), (DENSE_LABEL, fm_dense)):
        print(label, flush=True)
        feature_map = fm.features.detach().cpu().numpy()
        print(number, feature_map[0].shape)
        _plot_channel_grid(feature_map, FEATURE_GRID, (12, 8))

In [None]:
# Pick the pair of volumes to compare; other options:
# tensor_1, tensor_2 = vox_sparse, vox_dense
# tensor_1 = feature_maps_sparse[1].features.clone().detach().cpu().numpy()
# tensor_2 = feature_maps_dense[1].features.clone().detach().cpu().numpy()
tensor_1, tensor_2 = vox_cropped_sparse, vox_cropped_dense

# Each panel shows channel i + channel (i + N_BINS_TO_SUM) of the first sample.
# NOTE(review): presumably the volume stacks two groups of N_BINS_TO_SUM
# channels (e.g. one per polarity) — confirm against the quantization layer.
N_BINS_TO_SUM = 9

print(tensor_1.shape)
print(tensor_2.shape)
for name, volume in (('sparse', tensor_1), ('dense', tensor_2)):
    assert volume.shape[1] >= 2 * N_BINS_TO_SUM, (
        f"{name} volume has {volume.shape[1]} channels; need at least {2 * N_BINS_TO_SUM}"
    )

for i in range(N_BINS_TO_SUM):
    plt.figure(figsize=(15, 5))
    for panel, (name, volume) in enumerate((('sparse', tensor_1), ('dense', tensor_2)), start=1):
        plt.subplot(1, 2, panel)
        plt.imshow(volume[0, i, ...] + volume[0, i + N_BINS_TO_SUM, ...], cmap='binary')
        plt.gca().invert_yaxis()
        plt.colorbar()
        plt.title(f'{name}: channels {i} + {i + N_BINS_TO_SUM}')
        plt.axis('off')
    plt.show()

In [None]:
save_folder = os.path.join('images', 'feature_maps', 'fan')
# plt.savefig does not create missing directories.
os.makedirs(save_folder, exist_ok=True)

# Figure 0: first channel of the cropped sparse input voxel grid.
plt.figure(figsize=(12, 12))
plt.imshow(vox_cropped_sparse[0, 0, ...] + 1e-6, cmap='viridis')
orig_size = vox_cropped_sparse.shape[2]
print(orig_size)
plt.gca().invert_yaxis()
plt.colorbar(fraction=0.046, pad=0.04)
plt.axis('off')
plt.savefig(os.path.join(save_folder, 'nasl_sparse_0.png'), bbox_inches='tight')
plt.show()

# Figures 1..5: the first side*side channels of each hooked layer, in
# feature_maps_sparse order, on a side x side grid.
grid_sides = [2, 3, 4, 5, 6]
for level, side in enumerate(grid_sides, start=1):
    maps = feature_maps_sparse[level - 1].features.detach().cpu().numpy()
    is_last = level == len(grid_sides)
    plt.figure(figsize=(12, 12))
    for i in range(side * side):
        # Per-channel max normalisation; 1e-8 avoids division by zero.
        channel = maps[0, i, ...] / (np.max(maps[0, i, ...]) + 1e-8)
        plt.subplot(side, side, i + 1)
        if is_last:
            # pcolor already draws row 0 at the bottom (what imshow +
            # invert_yaxis gives the other layers); inverting here flipped
            # this layer upside down relative to the rest.
            plt.pcolor(channel, cmap='viridis', vmin=0, vmax=1)
        else:
            plt.imshow(channel, cmap='viridis')
            plt.gca().invert_yaxis()
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.axis('off')
        if is_last:
            plt.axis('equal')
    plt.savefig(os.path.join(save_folder, f'nasl_sparse_{level}.png'), bbox_inches='tight')
    plt.show()

In [None]:
save_folder = os.path.join('images', 'feature_maps', 'fan')
# plt.savefig does not create missing directories.
os.makedirs(save_folder, exist_ok=True)

# Figure 0: first channel of the cropped dense input voxel grid.
plt.figure(figsize=(12, 12))
plt.imshow(vox_cropped_dense[0, 0, ...] + 1e-6, cmap='viridis')
orig_size = vox_cropped_dense.shape[2]
print(orig_size)
plt.gca().invert_yaxis()
plt.colorbar(fraction=0.046, pad=0.04)
plt.axis('off')
plt.savefig(os.path.join(save_folder, 'nasl_dense_0.png'), bbox_inches='tight')
plt.show()

# Figures 1..5: the first side*side channels of each hooked layer, in
# feature_maps_dense order, on a side x side grid.
grid_sides = [2, 3, 4, 5, 6]
for level, side in enumerate(grid_sides, start=1):
    maps = feature_maps_dense[level - 1].features.detach().cpu().numpy()
    is_last = level == len(grid_sides)
    plt.figure(figsize=(12, 12))
    for i in range(side * side):
        # Per-channel max normalisation; 1e-8 avoids division by zero.
        channel = maps[0, i, ...] / (np.max(maps[0, i, ...]) + 1e-8)
        plt.subplot(side, side, i + 1)
        if is_last:
            # pcolor already draws row 0 at the bottom (what imshow +
            # invert_yaxis gives the other layers); inverting here flipped
            # this layer upside down relative to the rest.
            plt.pcolor(channel, cmap='viridis', vmin=0, vmax=1)
        else:
            plt.imshow(channel, cmap='viridis')
            plt.gca().invert_yaxis()
        plt.colorbar(fraction=0.046, pad=0.04)
        plt.axis('off')
        if is_last:
            plt.axis('equal')
    plt.savefig(os.path.join(save_folder, f'nasl_dense_{level}.png'), bbox_inches='tight')
    plt.show()

In [None]:
# Display the sparse run's classifier architecture (SPARSE == 0).
runners[SPARSE].model.classifier

In [None]:
# List every parameter name of the sparse run's model (values are not printed).
for param_name, _ in runners[SPARSE].model.named_parameters():
    print(param_name)

In [None]:
torch.cuda.empty_cache()

# NOTE(review): this evaluates the DENSE model on the SPARSE test set
# (cross-density evaluation). Confirm that is intended; use dss[DENSE] for
# in-distribution accuracy.
eval_model = model_dense
eval_dataset = dss[SPARSE]

eval_model.eval()

preds = []    # per-sample logits, moved to CPU
targets = []  # per-sample labels, as stored in the dataset
files = []
with torch.no_grad():
    for data in tqdm(eval_dataset):
        data.batch = torch.zeros(data.num_nodes, dtype=torch.long)  # single-graph batch
        # NOTE(review): extend() assumes file_id is list-like; a plain string
        # would be added one character at a time — confirm.
        files.extend(data.file_id)
        targets.append(data.y)
        data = data.to(device)
        out = eval_model(data)
        preds.append(out.detach().cpu())

if not preds:
    raise ValueError("evaluation dataset is empty")

# Concatenate once at the end; growing a tensor with torch.cat on every
# iteration made accumulation quadratic in the dataset size.
preds_ = torch.cat(preds, dim=0)
targets_ = torch.cat(targets, dim=0).cpu()

y = targets_.numpy()
y_hat = preds_.argmax(dim=1).numpy()
correct = int((y_hat == y).sum())
total = len(y)
print(f"{correct}/{total} correct")

# Use the evaluated model's logit width so num_classes always matches preds_.
metrics = torchmetrics.classification.Accuracy(num_classes=preds_.shape[1], task="multiclass", top_k=1)

acc = metrics(preds_, targets_).detach().cpu().numpy()
print(acc)