In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import shutil
import copy
import os.path as osp
import torch
import pytorch_lightning as pl
from tqdm import tqdm
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger

import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import torchvision
import torchmetrics
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
# Weights & Biases coordinates (account and project) used to list runs and
# to build the model-artifact names in the cells below.
entity, project = "haraghi", "DGCNN"

In [None]:
def _matches_selection(run, cfg, dataset_name):
    """Return True when a (run, cfg) pair belongs to `dataset_name` and passes the model filter.

    The checks are evaluated in a fixed order and stop at the first failure:
    dataset name, model name 'EST', 9 bins, pretrained ResNet, CNN type either
    unset or "resnet34", and finally the run id containing "1vlnuera".
    """
    if cfg.dataset.name != dataset_name:
        return False
    if cfg.model.name != 'EST':
        return False
    if cfg.model.num_bins != 9:
        return False
    if not cfg.model.resnet_pretrained:
        return False
    if cfg.model.cnn_type and cfg.model.cnn_type != "resnet34":
        return False
    return "1vlnuera" in run.id


# Fetch every run of the project and overlay each run's logged config on the
# bare default config so that all expected keys are present.
api = wandb.Api()
runs = api.runs(f"{entity}/{project}")
cfg_bare = OmegaConf.load("config_bare.yaml")
cfgs = [OmegaConf.merge(cfg_bare, OmegaConf.create(run.config)) for run in runs]

# Distinct dataset names that appear in the merged configs.
dataset_names = list(set([cfg.dataset.name for cfg in cfgs]))

# dataset name -> {run id -> (run, merged cfg)} for the runs that pass the filter.
dataset_runs = {}
for dataset_name in dataset_names:
    dataset_runs[dataset_name] = {
        run.id: (run, cfg)
        for run, cfg in zip(runs, cfgs)
        if _matches_selection(run, cfg, dataset_name)
    }
    print(dataset_name, len(dataset_runs[dataset_name]))


In [None]:
# Single Trainer instance shared by every evaluation cell below.
#
# `torch.cuda.device_count()` is 0 on a machine without CUDA, and Lightning
# rejects `devices=0`; fall back to one device so the notebook also runs on
# CPU-only hosts. With one or more GPUs the behaviour is unchanged.
trainer = pl.Trainer(
    enable_progress_bar=True,
    # Optional: enable DDP inside the notebook.
    # strategy="ddp_notebook",
    devices=torch.cuda.device_count() or 1,
    accelerator="auto",
)

In [None]:
# For every dataset, gather the test-time `num_events_per_sample` settings of the
# selected runs: power-of-two values below 1500 are kept as numbers, and runs
# with no fixed value but random sampling enabled are recorded as "random".
num_events_per_sample = {}
num_events_per_sample_unique = {}
for dataset_name in dataset_names:
    print(dataset_name)
    print("-" * 50)
    collected = []
    num_events_per_sample[dataset_name] = collected
    for run_id, (run, cfg) in dataset_runs[dataset_name].items():
        test_transform = cfg.transform.test
        neps = test_transform.num_events_per_sample
        if neps is None:
            # No fixed event count: only note the run if random sampling is on.
            if test_transform.random_num_events_per_sample.transform:
                collected.append("random")
            continue
        # `n & (n - 1) == 0` holds exactly when n has at most one bit set.
        if neps < 1500 and (neps & (neps - 1)) == 0:
            collected.append(neps)
    # De-duplicate; the order of the resulting list is whatever the set yields.
    unique_values = list(set(collected))
    num_events_per_sample_unique[dataset_name] = unique_values
    print(unique_values)

In [None]:
# Load the ":best" checkpoint artifact of one hand-picked FAN1VS3 run.
# NOTE(review): the run-selection cell currently keeps only runs whose id
# contains "1vlnuera", so this lookup looks like it would raise KeyError unless
# that filter is relaxed first -- confirm before relying on this cell.
run_id = 'sx3f1cu2'
cfg = dataset_runs['FAN1VS3'][run_id][1]

# The data module is built first because it supplies the class count when the
# config does not specify one.
gdm = GraphDataModule(cfg)
if cfg.dataset.num_classes is None:
    cfg.dataset.num_classes = gdm.num_classes

artifact_dir = WandbLogger.download_artifact(
    artifact=f"{entity}/{project}/model-{run_id}:best"
)

runner = Runner.load_from_checkpoint(
    osp.join(artifact_dir, "model.ckpt"),
    cfg=cfg,
    model=model_factory.factory(cfg),
)

In [None]:
from glob import glob
# Load the ":best" checkpoint artifact of one hand-picked NASL run.
run_id = '1vlnuera'
cfg = dataset_runs['NASL'][run_id][1]

# The data module is built first because it supplies the class count when the
# config does not specify one.
gdm = GraphDataModule(cfg)
if cfg.dataset.num_classes is None:
    cfg.dataset.num_classes = gdm.num_classes

artifact_dir = WandbLogger.download_artifact(
    artifact=f"{entity}/{project}/model-{run_id}:best"
)
# Alternative (disabled): pick the single local checkpoint under
# DGCNN/<run_id>/checkpoints/ with glob instead of downloading the artifact.

runner = Runner.load_from_checkpoint(
    osp.join(artifact_dir, "model.ckpt"),
    cfg=cfg,
    model=model_factory.factory(cfg),
)

In [None]:
# Evaluate the loaded runner 400 times on the test split and keep the accuracy
# and loss reported by each pass.
acc_results, loss_results = [], []
for _ in range(400):
    results = trainer.test(model=runner, datamodule=gdm, verbose=0)
    metrics = results[0]  # first (index 0) entry of the list returned by trainer.test
    acc_results.append(metrics['test/acc'])
    loss_results.append(metrics['test/loss'])

# Histogram of the collected test losses (left as the cell's last expression).
plt.hist(loss_results, bins=40)

In [None]:
# Re-draw the 40-bin histogram of the test losses collected by the repeated evaluation.
plt.hist(x=loss_results, bins=40)

In [None]:
# Sweep the number of test-time events per sample for every selected NASL run.
# The raw `trainer.test` outputs are written to `file_path` after every trial so
# an interrupted sweep can be resumed by re-running this cell.
file_path = "test_sparsity_sensitivity_nasl_randome_inverse_sampling.pt"
num_try = 5  # evaluations per (run, event count) pair
tested_num_events_grid = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]

# NOTE: torch.load unpickles the file -- only load result files written by this notebook.
if os.path.isfile(file_path):
    sparsity_results = torch.load(file_path)
else:
    sparsity_results = {}

for dataset_name in dataset_names:
    print("-"*50)
    print(dataset_name)
    print("-"*50)
    if 'NASL' not in dataset_name:
        continue
    dataset_results = sparsity_results.setdefault(dataset_name, {})
    for run_id, (run, cfg) in dataset_runs[dataset_name].items():
        # A run is finished only when every event count holds `num_try` trials.
        # Checking mere presence of the run id would permanently skip a run
        # whose sweep was interrupted after its first partial save.
        already_tested = dataset_results.get(run_id, {}).get('tested_num_events', {})
        pending = [
            n for n in tested_num_events_grid
            if len(already_tested.get(n, [])) < num_try
        ]
        if not pending:
            continue

        trained_num_events = cfg.transform.train.num_events_per_sample
        print(f'trained #events: {trained_num_events}')
        run_results = dataset_results.setdefault(run_id, {})
        run_results['trained_num_events'] = trained_num_events
        run_results['model'] = cfg.model.name
        tested = run_results.setdefault('tested_num_events', {})

        artifact_dir = WandbLogger.download_artifact(artifact=f"{entity}/{project}/model-{run_id}:best")

        # Built once with the original config only to resolve the class count.
        gdm = GraphDataModule(cfg)
        if cfg.dataset.num_classes is None:
            cfg.dataset.num_classes = gdm.num_classes

        runner = Runner.load_from_checkpoint(osp.join(artifact_dir,"model.ckpt"), cfg=cfg, model=model_factory.factory(cfg))

        for neps in pending:
            trials = tested.setdefault(neps, [])
            print(f'tested #events: {neps}')
            # Work on a copy so the per-run config stays untouched.
            cfg_new = copy.deepcopy(cfg)
            cfg_new.transform.test.num_events_per_sample = neps
            gdm = GraphDataModule(cfg_new)
            if cfg_new.dataset.num_classes is None:
                # The class count comes from the data module, as for `cfg` above
                # (the original read a top-level `cfg_new.num_classes` key).
                cfg_new.dataset.num_classes = gdm.num_classes

            # Only run the trials that are still missing for this event count.
            for _ in range(num_try - len(trials)):
                results = trainer.test(model=runner, datamodule=gdm, verbose=0)
                trials.append(results)
                torch.save(sparsity_results, file_path)



In [None]:
# Sweep the test-time number of events per sample for every selected FAN run,
# using the event counts gathered in `num_events_per_sample_unique`. Results are
# written to `results_path` after each event count so the sweep can be resumed.
results_path = "sparsity_sensitivity_results_all_.pt"
num_try = 5  # evaluations per (run, event count) pair

# Start from an empty result set when no cache file exists yet.
# NOTE: torch.load unpickles the file -- only load result files written by this notebook.
if os.path.isfile(results_path):
    sparsity_results = torch.load(results_path)
else:
    sparsity_results = {}

for dataset_name in dataset_names:
    print("-"*50)
    print(dataset_name)
    print("-"*50)
    if 'FAN' not in dataset_name:
        continue
    dataset_results = sparsity_results.setdefault(dataset_name, {})
    # NOTE(review): this list may contain the string "random", which is then
    # written to `transform.test.num_events_per_sample` below -- confirm the
    # data module accepts that value.
    event_counts = num_events_per_sample_unique[dataset_name]
    for run_id, (run, cfg) in dataset_runs[dataset_name].items():
        # A run is finished only when every event count holds `num_try` trials.
        # Checking mere presence of the run id would permanently skip a run
        # whose sweep was interrupted after a partial save.
        already_tested = dataset_results.get(run_id, {}).get('tested_num_events', {})
        pending = [
            n for n in event_counts
            if len(already_tested.get(n, [])) < num_try
        ]
        if not pending:
            continue

        print(f"{dataset_name}:{run.summary['test/acc']:.4f} {run.summary['epoch']}")
        trained_num_events = cfg.transform.train.num_events_per_sample
        print(f'trained #events: {trained_num_events}')
        run_results = dataset_results.setdefault(run_id, {})
        run_results['trained_num_events'] = trained_num_events
        run_results['summary_test_acc'] = run.summary['test/acc']
        run_results['model'] = cfg.model.name
        tested = run_results.setdefault('tested_num_events', {})

        artifact_dir = WandbLogger.download_artifact(artifact=f"{entity}/{project}/model-{run_id}:best")

        # Built once with the original config only to resolve the class count.
        gdm = GraphDataModule(cfg)
        if cfg.dataset.num_classes is None:
            cfg.dataset.num_classes = gdm.num_classes

        runner = Runner.load_from_checkpoint(osp.join(artifact_dir,"model.ckpt"), cfg=cfg, model=model_factory.factory(cfg))

        for neps in pending:
            trials = tested.setdefault(neps, [])
            print(f'tested #events: {neps}')
            # Work on a copy so the per-run config stays untouched.
            cfg_new = copy.deepcopy(cfg)
            cfg_new.transform.test.num_events_per_sample = neps
            gdm = GraphDataModule(cfg_new)
            if cfg_new.dataset.num_classes is None:
                # The class count comes from the data module, as for `cfg` above
                # (the original read a top-level `cfg_new.num_classes` key).
                cfg_new.dataset.num_classes = gdm.num_classes

            # Only run the trials that are still missing for this event count.
            for _ in range(num_try - len(trials)):
                results = trainer.test(model=runner, datamodule=gdm, verbose=0)
                trials.append(results)
            torch.save(sparsity_results, results_path)

        # Remove the downloaded checkpoint once this run has been evaluated.
        shutil.rmtree(artifact_dir)
        
        
    
    
    
