In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import torchvision
import torchmetrics
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# Weights & Biases coordinates (account/team and project) of the runs
# inspected in this notebook.
entity, project = "haraghi", "DGCNN"

In [None]:
# IDs of the W&B runs to compare.
run_ids = ['9rrxu350', 'x4lf35wy']

# Fetch the checkpoint artifact tagged "best" for every run; each entry of
# `artifact_dirs` is whatever `download_artifact` returns for that run.
artifact_names = [f"{entity}/{project}/model-{run_id}:best" for run_id in run_ids]
artifact_dirs = [WandbLogger.download_artifact(artifact=name) for name in artifact_names]

In [None]:
# Client for the Weights & Biases public API; used below to read each run's
# stored configuration.
api = wandb.Api()

In [None]:
# Rebuild the configuration of every run on top of the bare defaults.
cfg_bare = OmegaConf.load("config_bare.yaml")

# A W&B run path is always "entity/project/run_id" with forward slashes, so it
# is built explicitly instead of with the OS-dependent os.path.join (which
# would insert backslashes on Windows).
configs = [api.run(f"{entity}/{project}/{run_id}").config for run_id in run_ids]

# cfgs: defaults overridden by the configuration logged to W&B for each run.
cfgs = [OmegaConf.merge(cfg_bare, OmegaConf.create(config)) for config in configs]

# cfg_files: defaults overridden by the YAML file the run points to through
# `cfg_path`; runs without a `cfg_path` entry reuse their W&B-derived config.
cfg_files = []
for cfg in cfgs:
    if "cfg_path" in cfg.keys():
        print(cfg.cfg_path)
        cfg_files.append(OmegaConf.merge(cfg_bare, OmegaConf.load(cfg.cfg_path)))
    else:
        cfg_files.append(cfg)

In [None]:
def recursive_dict_compare(all_cfg, other_cfg):
    """
    Collect the entries of ``other_cfg`` that are missing from, or differ
    from, ``all_cfg``.

    The comparison is one-directional: only the keys of ``other_cfg`` are
    visited, so keys that exist solely in ``all_cfg`` are never reported.
    When both sides hold a ``dict`` under the same key the comparison
    recurses, and the nested result is kept only if it is non-empty.

    A ``num_classes`` entry that is ``None`` in ``other_cfg`` while being set
    in ``all_cfg`` is deliberately not reported.

    Returns:
        A dict mapping each differing key to the value found in
        ``other_cfg`` (or to the nested difference dict).
    """
    delta = {}

    for name in other_cfg:
        theirs = other_cfg[name]

        # Key unknown to the reference: report it as-is.
        if name not in all_cfg:
            delta[name] = theirs
            continue

        ours = all_cfg[name]

        # Two nested mappings: descend and keep only non-empty differences.
        if isinstance(ours, dict) and isinstance(theirs, dict):
            sub_delta = recursive_dict_compare(ours, theirs)
            if sub_delta:
                delta[name] = sub_delta
            continue

        # Plain values: report a mismatch unless it is the tolerated
        # "num_classes left unset on the other side" case.
        if ours != theirs:
            unset_num_classes = (
                name == "num_classes" and theirs is None and ours is not None
            )
            if not unset_num_classes:
                delta[name] = theirs

    return delta


In [None]:
# For every run, list what its file-based config (cfg_files) contains that is
# missing from, or different in, its W&B-derived config (cfgs).
config_diffs = [
    recursive_dict_compare(OmegaConf.to_object(cfg), OmegaConf.to_object(cfg_file))
    for cfg, cfg_file in zip(cfgs, cfg_files)
]
print(config_diffs)

In [None]:
# Seed the random number generators from each run's configured seed.
# Seeding alone does not make training entirely deterministic.
for cfg in cfgs:
    pl.seed_everything(cfg.seed, workers=True)

# Compare every run's dataset settings against the first run's. A difference
# limited to `num_workers` is tolerated; anything else is printed for manual
# inspection (mismatches are reported only, no exception is raised).
for cfg in cfgs[1:]:
    compare_dict = recursive_dict_compare(
        OmegaConf.to_object(cfgs[0].dataset),
        OmegaConf.to_object(cfg.dataset),
    )
    only_num_workers_differs = set(compare_dict) == {'num_workers'}
    if compare_dict and not only_num_workers_differs:
        print(compare_dict)
        print(cfg.dataset)
        print(cfgs[0].dataset)

# Build the data module from the first run's config and copy the number of
# classes it reports into every run's config.
gdm = GraphDataModule(cfgs[0])
for cfg in cfgs:
    cfg.dataset.num_classes = gdm.num_classes

In [None]:
def percentile(t, q):
    """
    Return the per-sample ``q``-th percentile of a 4-D tensor.

    The ``C * H * W`` values of each sample are flattened and the k-th
    smallest one is selected, with
    ``k = 1 + round(q / 100 * (C * H * W - 1))`` (nearest rank, no
    interpolation between neighbouring values).

    Args:
        t: Tensor of shape ``(B, C, H, W)``.
        q: Percentile to pick; 0 selects the minimum, 100 the maximum.

    Returns:
        Tensor of shape ``(B, 1, 1, 1)`` so it broadcasts against ``t``.
    """
    B, C, H, W = t.shape
    k = 1 + round(.01 * float(q) * (C * H * W - 1))
    # reshape instead of view: view raises on non-contiguous inputs (e.g. a
    # permuted tensor), reshape copies only when it has to.
    result = t.reshape(B, -1).kthvalue(k).values
    return result[:, None, None, None]

def create_image(representation):
    """
    Turn a batch of multi-channel representations into one uint8 image grid.

    The ``C`` channels are folded into 3 groups of ``C // 3`` channels that
    are summed, giving a 3-channel image per sample. Each sample is then
    normalised with its own 1st and 99th percentile (taken over all of its
    values), scaled to ``[0, 255]``, clamped and converted to bytes.

    Args:
        representation: Tensor of shape ``(B, C, H, W)``; ``C`` must be
            divisible by 3 for the channel folding to succeed.

    Returns:
        The uint8 tensor produced by ``torchvision.utils.make_grid`` for the
        normalised batch.
    """
    B, C, H, W = representation.shape
    representation = representation.reshape(B, 3, C // 3, H, W).sum(2)

    # Robust min-max normalisation: percentiles instead of the true extrema
    # so that a few outliers do not flatten the rest of the image.
    representation = representation.detach().cpu()
    robust_max_vals = percentile(representation, 99)
    robust_min_vals = percentile(representation, 1)

    # When both percentiles coincide (constant or very sparse sample) the
    # range is zero and the division would yield NaN/inf, which cannot be
    # converted to bytes meaningfully. Fall back to a unit range there.
    value_range = robust_max_vals - robust_min_vals
    value_range = torch.where(
        value_range == 0, torch.ones_like(value_range), value_range
    )

    representation = (representation - robust_min_vals) / value_range
    representation = torch.clamp(255 * representation, 0, 255).byte()

    return torchvision.utils.make_grid(representation)

In [None]:
# One freshly constructed network per run configuration.
models = list(map(model_factory.factory, cfgs))

# Restore every run's checkpoint into a Runner (the Lightning module that
# wraps the network); the checkpoint file sits inside the downloaded
# artifact directory.
checkpoint_paths = [osp.join(artifact_dir, "model.ckpt") for artifact_dir in artifact_dirs]
runners = [
    Runner.load_from_checkpoint(ckpt_path, cfg=cfg, model=model)
    for ckpt_path, cfg, model in zip(checkpoint_paths, cfgs, models)
]

# One data module per run; each supplies the settings needed to build that
# run's test dataset.
gdms = list(map(GraphDataModule, cfgs))
dss = []
for cfg, gdm in zip(cfgs, gdms):
    cfg.dataset.num_classes = gdm.num_classes

    test_dataset = create_dataset(
        dataset_path=gdm.dataset_path,
        dataset_name=gdm.dataset_name,
        dataset_type='test',
        transform=gdm.transform_dict['test'],
        num_workers=gdm.num_workers,
    )
    dss.append(test_dataset)

In [None]:
# Map every category of the first run's test set to the indices of the
# samples carrying that label.
reference_set = dss[0]
class2ind = {category: [] for category in reference_set.categories}
for sample_idx, sample in enumerate(reference_set):
    class2ind[sample.label[0]].append(sample_idx)

In [None]:
# Show the class names that can be picked for the Grad-CAM visualisation.
class2ind.keys()

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2

In [None]:
# Explain the second run's model, on GPU when one is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = runners[1].model.to(device)
model.eval()

class_name = 'speed_1'
# Grad-CAM is computed on the last block of the classifier's `layer4`.
target_layers = [model.classifier.layer4[-1]]

# Draw one random test sample of the requested class.
# NOTE(review): class2ind was built from dss[0] but is used to index dss[1];
# this relies on both test sets having the same ordering - confirm.
candidates = class2ind[class_name]
idx = candidates[torch.randint(0, len(candidates), (1,)).item()]
data = dss[1][idx]
# A single sample has no batch vector; assign every node to graph 0.
data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
input_tensor = data.to(device)
print(data.label[0])

# Construct the CAM object once and re-use it for many inputs. `use_cuda`
# follows `device` so the cell also runs on CPU-only machines instead of
# unconditionally requesting CUDA.
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=device.type == "cuda")

# Target for the class activation map: the sample's ground-truth class.
# (With targets=None the highest-scoring class would be used instead.)
targets = [ClassifierOutputTarget(data.y.item())]





In [None]:
# aug_smooth=True and eigen_smooth=True can be passed here to apply smoothing.
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

# Only one sample was passed in, so keep the first (and only) map.
grayscale_cam = grayscale_cam[0, :]

# Render the input as an image using the quantization layer of the model
# being explained, so the heat map and the background come from the same
# network (and the same device). The layer is called as a module rather than
# through `.forward`.
with torch.no_grad():
    vox = model.quantization_layer(input_tensor)

rep = create_image(vox)
# show_cam_on_image expects a float32 image in [0, 1], channels last.
img = np.float32(rep.permute(1, 2, 0).detach().cpu().numpy()) / 255

# cv2.resize takes the destination size as (width, height).
target_size = (img.shape[1], img.shape[0])

# Bring the activation map to the image resolution (bilinear interpolation).
resized_image = cv2.resize(grayscale_cam, target_size, interpolation=cv2.INTER_LINEAR)
visualization = show_cam_on_image(img, resized_image, use_rgb=True)
plt.imshow(visualization)
plt.gca().invert_yaxis()
plt.show()
