In [None]:
from omegaconf import OmegaConf
import numpy as np
import os
import re
import os.path as osp
import torch
import pandas as pd
import pickle
import wandb

from transform_factory import factory as transforms
import model_factory
from graph_data_module import GraphDataModule
from train import Runner
from models.est import create_image
from datasets_torch_geometric.dataset_factory import create_dataset
from torch_geometric.loader import DataLoader
from utils.config_utils import get_checkpoint_file, get_config_file, show_cfg
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from models.est import create_image
# Shared W&B public-API client; the cells below use it to list projects and fetch runs.
api = wandb.Api()

In [None]:
# Print the name of every W&B project visible under the "haraghi" entity.
projects = api.projects(entity="haraghi")
for project in projects:
    print(project.name)

In [None]:
# Lookup table: dataset identifier -> display name and number of classes.
# Each row is (identifier, display name, number of classes).
datasets_name_and_num_classes = {
    dataset_key: {"name": display_name, "num_classes": class_count}
    for dataset_key, display_name, class_count in (
        ("NCARS", "N-Cars", 2),
        ("NASL", "N-ASL", 24),
        ("NCALTECH101", "N-Caltech101", 101),
        ("DVSGESTURE_TONIC", "DVS-Gesture", 11),
        ("FAN1VS3", "Fan1vs3", 2),
    )
}

In [None]:
# W&B projects that are queried for runs, one per dataset.
# NOTE(review): the "varyinig" spelling is reproduced on purpose; the strings are
# used verbatim as project names when querying the API, so presumably they must
# match the remote project names exactly -- confirm before "fixing" the typo.
dataset_projects = [
    f"FINAL-{project_tag}-varyinig-sparsity"
    for project_tag in (
        "NASL",
        "NCARS",
        "DVSGESTURE_TONIC-HP",
        "FAN1vs3",
        "NCALTECH101",
    )
]

In [None]:
def find_val_and_test_acc_keys(run):
    """Locate the summary keys holding the mean validation / test accuracy.

    A key qualifies for a split when it contains the split name ("val" or
    "test") together with both "acc" and "mean" as substrings.

    Args:
        run: object exposing a ``summary`` mapping with a ``keys()`` method.

    Returns:
        Tuple ``(val_key, test_key)``; each entry is the single matching key,
        or ``None`` when no key matches for that split.

    Raises:
        AssertionError: if more than one key matches for the same split.
    """
    summary_keys = list(run.summary.keys())

    def _accuracy_keys(split):
        # Plain substring matching, in the order the summary yields its keys.
        return [k for k in summary_keys if split in k and "acc" in k and "mean" in k]

    val_matches = _accuracy_keys("val")
    test_matches = _accuracy_keys("test")
    assert len(val_matches) <= 1, f"More than one val acc key found: {val_matches}"
    assert len(test_matches) <= 1, f"More than one test acc key found: {test_matches}"

    val_key = val_matches[0] if val_matches else None
    test_key = test_matches[0] if test_matches else None
    return val_key, test_key

In [None]:
# Collect, for every project in `dataset_projects`, the test accuracy of each
# finished run grouped by the number of events per training sample.
#
# Results:
#   test_dict[dataset_name] = [test_mean, test_max]
#     test_mean[num_event] -> list with one entry per run: its test accuracy,
#                             or None when the run has no usable test-accuracy
#                             summary entry (raw values, not an average)
#     test_max[num_event]  -> (best test accuracy, run that achieved it); the
#                             key is absent when no run reported an accuracy
#   num_events_list        -> sorted event budgets seen across all projects
folder_name = 'paper'
subfolder_name = os.path.join('images', folder_name, 'fig_1')
entity = 'haraghi'
# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs(subfolder_name, exist_ok=True)

test_dict = {}
num_events_set = set()

for project_name in dataset_projects:
    runs = api.runs(f"{entity}/{project_name}")
    # Only finished runs that logged a transform config can be grouped below.
    runs = [r for r in runs if r.state == "finished" and "transform" in r.config]
    if len(runs) == 0:
        print(f"No runs found for {project_name}")
        continue
    num_events = np.unique([run.config['transform']['train']['num_events_per_sample'] for run in runs])
    runs_per_num_events = {num_event: [run for run in runs if run.config['transform']['train']['num_events_per_sample'] == num_event] for num_event in num_events}
    # The dataset name is taken from the first run; all runs of a project are
    # presumably trained on the same dataset -- TODO confirm.
    dataset_name = runs[0].config["dataset"]["name"]

    num_events_set = num_events_set.union(set(num_events))

    test_mean = {}
    test_max = {}
    for num_event in num_events:
        test_mean[num_event] = []
        test_max_val = -1
        for run in runs_per_num_events[num_event]:
            _, test_key = find_val_and_test_acc_keys(run)
            # test_key is None when the run has no matching summary entry, so
            # guard explicitly instead of relying on `None in run.summary`.
            # The summary is read once per run rather than on every use.
            test_acc = None
            if test_key is not None and test_key in run.summary:
                test_acc = run.summary[test_key]
            test_mean[num_event].append(test_acc)
            # A None value is treated like a missing entry rather than being
            # compared against the running maximum.
            if test_acc is not None and test_acc > test_max_val:
                test_max_val = test_acc
                test_max[num_event] = (test_acc, run)

        accs = test_mean[num_event]
        print(f"percentage of runs with test acc for {num_event} events: {np.sum([v is not None for v in accs]) / len(accs)} out of {len(accs)} runs")

    # NOTE: two projects reporting the same dataset name overwrite each other here.
    test_dict[dataset_name] = [test_mean, test_max]

num_events_list = sorted(list(num_events_set))

In [None]:
# Instantiate the test split of the datasets that get visualised further down.
# Every dataset name found in `test_dict` is printed, but only "NASL" is loaded.
datasets = {}
for dataset_name, test_results in test_dict.items():
    print(f"Dataset: {dataset_name}", flush=True)
    if dataset_name != "NASL":
        continue
    # No transform / pre_transform is attached at construction time.
    datasets[dataset_name] = create_dataset(
        dataset_path=os.path.join("datasets_torch_geometric", dataset_name, 'data'),
        dataset_name=dataset_name,
        dataset_type='test',
        transform=None,
        pre_transform=None,
        num_workers=3,
    )

In [None]:
# Hand-picked sample indices.
# This first list is immediately superseded by the concatenation at the end of
# the cell; it is kept only as a record of earlier candidates.
sample_ids = [13749, 13217, 14728, 5696, 1807, 3742, 3432]

# Presumably indices grouped by class label ('c', 'v', 'w') -- TODO confirm
# against the dataset.
c = [1428, 1670]
v = [12827, 13211]
w = [13520, 13297, 13785]

# Final selection: all 'c' samples, then 'v', then 'w'.
sample_ids = [*c, *v, *w]

In [None]:
# Visualise randomly drawn samples of one class in three panels:
#   panel 0: per-polarity event counts as a scatter plot (blue / red)
#   panel 1: event counts summed over both polarity channels
#   panel 2: output of the model's quantization layer (only when after_EST)
after_EST = True
class_label = 'w'
device = torch.device("cpu")
# Upper bound on random draws when searching for a sample of `class_label`,
# so a label that does not occur in the dataset cannot hang the cell.
max_sampling_attempts = 10_000

# Create custom colormaps
cmap_blue = mcolors.LinearSegmentedColormap.from_list("blue_cmap", ["white", "blue"])
cmap_red = mcolors.LinearSegmentedColormap.from_list("red_cmap", ["white", "red"])

# Config / checkpoint loading does not depend on the sample, so it is done
# once per run and reused across all drawn samples.
cfg_cache = {}
runner_cache = {}


def _voxelize_events(data, H, W):
    """Count events per (polarity, row, column) cell and return a (2, H, W) tensor.

    Uses data.pos[:, 0] / data.pos[:, 1] as x / y pixel coordinates and
    data.x[:, 0] as polarity (assumed to be -1 / +1 -- TODO confirm).
    """
    vox = torch.zeros(2 * H * W)
    x, y, p = data.pos[:, 0], data.pos[:, 1], data.x[:, 0]
    p = (p + 1) / 2  # maps polarity to 0, 1
    idx = x.int() + W * y.int() + W * H * p.int()
    # accumulate=True adds 1 for every event that falls into the same cell.
    vox.put_(idx.long(), vox.new_full([data.num_nodes, ], fill_value=1), accumulate=True)
    return vox.view(2, H, W)


def _scatter_polarity(axis, counts, cmap):
    """Scatter the non-zero cells of one polarity channel; return a colorbar mappable.

    Colours are min-max normalised counts. A channel whose non-zero counts are
    all equal is drawn at full intensity (the plain min-max formula would
    divide by zero), and an empty channel draws nothing.
    """
    rows, cols = torch.nonzero(counts, as_tuple=True)
    if rows.numel() == 0:
        return plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0, vmax=1), cmap=cmap)
    values = counts[rows, cols]
    lowest, highest = values.min().item(), values.max().item()
    if highest > lowest:
        intensity = (values - lowest) / (highest - lowest)
    else:
        intensity = torch.ones_like(values)
    # Columns are the horizontal coordinate, rows the vertical one.
    axis.scatter(cols.numpy(), rows.numpy(), c=cmap(intensity.numpy()), s=1, alpha=.7)
    return plt.cm.ScalarMappable(norm=plt.Normalize(vmin=lowest, vmax=highest), cmap=cmap)


for _ in range(30):
# for sample_id in sample_ids:
    for dataset_name, test_results in test_dict.items():
        if dataset_name not in datasets:
            # Only datasets that were actually loaded can be visualised.
            continue
        dataset = datasets[dataset_name]

        # Rejection-sample an index whose label matches `class_label`.
        for _attempt in range(max_sampling_attempts):
            sample_id = np.random.randint(len(dataset))
            if dataset[sample_id].label[0].lower() == class_label.lower():
                break
        else:
            raise RuntimeError(
                f"No sample with class '{class_label}' found in {dataset_name} "
                f"after {max_sampling_attempts} random draws"
            )
        print(f"sample_id: {sample_id} class: {dataset[sample_id].label[0]}")

        # test_results[1] maps num_events -> (best test accuracy, best run).
        for num_event, test_max_run in test_results[1].items():
            if num_event < 10000:
                continue
            print(num_event)
            run = test_max_run[1]
            cache_key = (run.entity, run.project, run.id)
            if cache_key not in cfg_cache:
                cfg, _ = get_config_file(run.entity, run.project, run.id, verbose=False)
                cfg_cache[cache_key] = cfg
            cfg = cfg_cache[cache_key]

            H, W = cfg.dataset.image_resolution

            # Re-read the sample through the test-time transform of this run.
            dataset.transform = transforms(cfg.transform.test)
            data = dataset[sample_id]
            vox = _voxelize_events(data, H, W)
            vox_sum = vox.sum(0)
            fig, ax = plt.subplots(1, 3, figsize=(15, 5))

            # Panel 0: polarity channel 0 in blue, channel 1 in red.
            cbar_blue = _scatter_polarity(ax[0], vox[0], cmap_blue)
            cbar = _scatter_polarity(ax[0], vox[1], cmap_red)
            ax[0].set_aspect('equal', 'box')
            ax[0].set_xlim(-0.5, vox_sum.shape[1] - 0.5)
            ax[0].set_ylim(-0.5, vox_sum.shape[0] - 0.5)

            # Panel 1: counts summed over both polarities; the y axis is
            # inverted so row 0 sits at the bottom, like the scatter panel.
            accumulate_ax = ax[1].imshow(vox_sum)
            ax[1].set_aspect('equal', 'box')
            ax[1].invert_yaxis()

            # Panel 2 stays empty (and gets no colorbar) unless after_EST is set.
            est_ax = None
            if after_EST:
                if cache_key not in runner_cache:
                    model = model_factory.factory(cfg).to(device)
                    checkpoint_file = get_checkpoint_file(run.entity, run.project, run.id)
                    runner = Runner.load_from_checkpoint(checkpoint_path=checkpoint_file, cfg=cfg, model=model, map_location=device)
                    runner.model.eval()
                    runner_cache[cache_key] = runner
                runner = runner_cache[cache_key]
                # A single graph: every node belongs to batch element 0.
                data.batch = torch.zeros(data.num_nodes, dtype=torch.long)
                data = data.to(device)
                with torch.no_grad():
                    vox_after_est = runner.model.quantization_layer.forward(data)
                vox_after_est = create_image(vox_after_est)
                # .cpu() keeps the conversion valid if `device` is switched to a GPU.
                est_ax = ax[2].imshow(vox_after_est.cpu().numpy().transpose(1, 2, 0), cmap='viridis')
                ax[2].invert_yaxis()

            # NOTE(review): only the red mappable gets a colorbar on panel 0;
            # cbar_blue is kept available in case a second colorbar is wanted.
            ax_names = [cbar, accumulate_ax, est_ax]
            for panel, mappable in zip(ax, ax_names):
                # axis('off') hides ticks, labels and spines, so no further
                # tick / spine styling is needed.
                panel.axis('off')
                if mappable is not None:
                    fig.colorbar(mappable, ax=panel)
            plt.show()
            # Release the figure explicitly; many figures are created in this loop.
            plt.close(fig)
        break