In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import shutil
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import torchvision
import torchmetrics
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# Weights & Biases workspace ("<entity>/<project>") whose runs are analysed below.
entity, project = "haraghi", "DGCNN"

In [None]:
# Fetch the project's runs once and materialise them: `api.runs` returns a lazily
# paginated collection, and it is iterated several times in this notebook.
api = wandb.Api()
runs = list(api.runs(f"{entity}/{project}"))
# Layer each run's logged config over the defaults in config_bare.yaml.
cfg_bare = OmegaConf.load("config_bare.yaml")
cfgs = [OmegaConf.merge(cfg_bare, OmegaConf.create(run.config)) for run in runs]

# Selection criteria for the EST runs analysed below.
EST_NUM_BINS = 9
MIN_EPOCH = 51  # keep only runs whose logged epoch is strictly greater than this


def is_selected_est_run(run, cfg):
    """Return True if (run, cfg) passes the EST analysis filter.

    The run must be an EST model with EST_NUM_BINS bins and `resnet_pretrained` set,
    must have logged 'test/acc' and 'epoch', must use either no explicit cnn_type or
    "resnet34", and must have a logged epoch greater than MIN_EPOCH.
    """
    return (cfg.model.name == 'EST' and
            cfg.model.num_bins == EST_NUM_BINS and
            cfg.model.resnet_pretrained and
            'test/acc' in run.summary and
            'epoch' in run.summary and
            (not cfg.model.cnn_type or cfg.model.cnn_type == "resnet34") and
            run.summary['epoch'] > MIN_EPOCH)


# Sorted so the per-dataset output order is reproducible across kernel restarts
# (key=str keeps sorting safe if a config has no dataset name).
dataset_names = sorted({cfg.dataset.name for cfg in cfgs}, key=str)
dataset_runs = {}
for dataset_name in dataset_names:
    # run id -> (run, merged cfg) for this dataset's runs that pass the filter
    dataset_runs[dataset_name] = {
        run.id: (run, cfg)
        for run, cfg in zip(runs, cfgs)
        if cfg.dataset.name == dataset_name and is_selected_est_run(run, cfg)
    }
    print(dataset_name, len(dataset_runs[dataset_name]))


In [None]:
def get_sparsity_level(runner, gdm):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.cuda.empty_cache()
    model = runner.model.to(device)
    sparsity_level= {'train':[], 'val':[], 'test':[]}

    model.eval()

    for data in gdm.train_dataloader():   
        data = data.to(device)
        vox = model.quantization_layer.forward(data.to(device)).clone().detach()
        vox_cropped = model.crop_and_resize_to_resolution(vox, model.crop_dimension) 
        sparsity_level['train'].extend([(torch.count_nonzero(v) / torch.numel(v)).item() for v in vox_cropped])
    sparsity_level['train'] = np.array(sparsity_level['train'])[:]
    for data in gdm.val_dataloader():   
        data = data.to(device)
        vox = model.quantization_layer.forward(data.to(device)).clone().detach()
        vox_cropped = model.crop_and_resize_to_resolution(vox, model.crop_dimension) 
        sparsity_level['val'].extend([(torch.count_nonzero(v) / torch.numel(v)).item() for v in vox_cropped])
    sparsity_level['val'] = np.array(sparsity_level['val'])[:]
    for data in gdm.test_dataloader():   
        data = data.to(device)
        vox = model.quantization_layer.forward(data.to(device)).clone().detach()
        vox_cropped = model.crop_and_resize_to_resolution(vox, model.crop_dimension) 
        sparsity_level['test'].extend([(torch.count_nonzero(v) / torch.numel(v)).item() for v in vox_cropped])
    sparsity_level['test'] = np.array(sparsity_level['test'])[:]
    sparsity_level['all'] = np.concatenate([sparsity_level['train'], sparsity_level['val'], sparsity_level['test']])
    
    return sparsity_level

In [None]:
# For every selected run: download its best checkpoint, rebuild the model and
# record per-split sparsity statistics. Failures are reported and stored as None.
sparsity_results = {}
for dataset_name in dataset_names:
    print(dataset_name)
    print("-" * 50)
    sparsity_results[dataset_name] = []
    for run_id, (run, cfg) in dataset_runs[dataset_name].items():
        print(f"{dataset_name}:{run.summary['test/acc']:.4f} {run.summary['epoch']}")

        # Reset per run so a failed download can never make the cleanup below
        # act on an undefined name or on the previous run's (already deleted) directory.
        artifact_dir = None
        sparsity_level = None
        try:
            artifact_dir = WandbLogger.download_artifact(artifact=f"{entity}/{project}/model-{run_id}:best")

            gdm = GraphDataModule(cfg)
            if cfg.dataset.num_classes is None:
                cfg.dataset.num_classes = gdm.num_classes

            runner = Runner.load_from_checkpoint(osp.join(artifact_dir, "model.ckpt"), cfg=cfg, model=model_factory.factory(cfg))

            sparsity_level = get_sparsity_level(runner, gdm)
        except Exception as err:  # best-effort: report and continue with the next run
            print(f"Error for {run_id}: {dataset_name}, {cfg.model.name}, {cfg.transform.train.num_events_per_sample}: {err!r}")
        finally:
            # The downloaded checkpoint is only needed for this run; remove it even on failure.
            if artifact_dir is not None and osp.isdir(artifact_dir):
                shutil.rmtree(artifact_dir)
        sparsity_results[dataset_name].append((run, cfg, sparsity_level))
        
        
    
    
    


In [None]:
# NOTE(review): the variable name says NASL but the filter reads the 'NCARS'
# results — confirm which dataset is intended; the key is isolated here for that.
filtered_dataset_name = 'NCARS'
# Keep only runs whose sparsity computation succeeded (failures were stored as None).
sparsity_results_filtered_NASL = [sr for sr in sparsity_results[filtered_dataset_name] if sr[2] is not None]
for run, cfg, sparsity in sparsity_results_filtered_NASL:
    # run id, test accuracy, events per training sample, mean / std of per-sample non-zero fraction
    print(run.id, run.summary['test/acc'], cfg.transform.train.num_events_per_sample, np.mean(sparsity['all']), np.std(sparsity['all']))

In [None]:
# Reload one run's best checkpoint for closer inspection.
run_id = '2yqeh948'
run = next((r for r in runs if r.id == run_id), None)
if run is None:
    # Fail loudly instead of silently reusing a stale artifact_dir / cfg from earlier cells.
    raise ValueError(f"Run {run_id} not found in {entity}/{project}")
artifact_dir = WandbLogger.download_artifact(artifact=f"{entity}/{project}/model-{run_id}:best")
# Build the config the same way as the run-selection cell (bare defaults + logged
# run config) so it supports attribute access like the configs used elsewhere.
cfg = OmegaConf.merge(cfg_bare, OmegaConf.create(run.config))
if cfg.dataset.num_classes is None:
    cfg.dataset.num_classes = GraphDataModule(cfg).num_classes

runner = Runner.load_from_checkpoint(osp.join(artifact_dir, "model.ckpt"), cfg=cfg, model=model_factory.factory(cfg))



In [None]:
# Look the statistics up by run id instead of relying on the `sparsity_level` the
# batch loop leaked (that belongs to the last run processed and may be None).
sparsity_level = next(
    (level for results in sparsity_results.values() for r, _, level in results if r.id == run_id),
    None,
)
if sparsity_level is None:
    raise ValueError(f"No sparsity statistics for run {run_id}: it was not selected or its computation failed")
for key, values in sparsity_level.items():
    print(f"{key} mean: {np.mean(values)}, std: {np.std(values)}")


In [None]:
for dataset_name in dataset_names:
    print(dataset_name)
    # dataset_runs[name] maps run id -> (run, cfg). Iterate the values: iterating
    # the dict itself yields only the id strings, which cannot unpack into (run, cfg).
    for run, cfg in dataset_runs[dataset_name].values():
        print(cfg.transform.train.num_events_per_sample, run.summary['epoch'], cfg.wandb.experiment_name, run.id, run.summary['test/acc'])

In [None]:
def percentile(t, q):
    """Per-sample q-th percentile of a batched tensor.

    Args:
        t: tensor whose first dimension is the batch; all remaining elements of a
            sample take part in that sample's percentile.
        q: percentile in [0, 100].

    Returns:
        Tensor of shape (B, 1, ..., 1) with the same rank as `t`, so it broadcasts
        against `t` (shape (B, 1, 1, 1) for a 4-D input).
    """
    # reshape (not view) so non-contiguous inputs (sliced / transposed) work as well
    flat = t.reshape(t.shape[0], -1)
    # 1-based rank of the q-th percentile among the per-sample elements
    k = 1 + round(.01 * float(q) * (flat.shape[1] - 1))
    result = flat.kthvalue(k, dim=1).values
    return result.reshape((-1,) + (1,) * (t.dim() - 1))

def create_image(representation):
    """Render a batch of event representations as a single uint8 image grid.

    The C channels of a (B, C, H, W) input are split into 3 groups of C // 3 and
    summed per group to give RGB (C must be divisible by 3). Each sample is then
    min-max normalised between its 1st and 99th percentile, scaled to [0, 255],
    converted to uint8 and tiled with torchvision's make_grid.
    """
    B, C, H, W = representation.shape
    # reshape (not view) so non-contiguous inputs are accepted too
    rgb = representation.reshape(B, 3, C // 3, H, W).sum(2)

    # do robust min max norm
    rgb = rgb.detach().cpu()
    robust_max_vals = percentile(rgb, 99)
    robust_min_vals = percentile(rgb, 1)

    # A near-constant sample (e.g. an all-zero sparse grid) has max == min; divide
    # by 1 there instead of 0 so it does not produce NaN/inf before the uint8 cast.
    value_range = robust_max_vals - robust_min_vals
    value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    rgb = (rgb - robust_min_vals) / value_range
    rgb = torch.clamp(255 * rgb, 0, 255).byte()

    return torchvision.utils.make_grid(rgb)