# Show reconstructions (in-class variation)
- Focus on the highest-confidence examples within a single class for the target model.

In [None]:
import numpy as np
import torch
from torchvision.datasets import ImageFolder
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import dejavu_utils.reconstruction_utils as ru
import dejavu_utils.plot_utils as pu
import json
import pickle

import os 
from pathlib import Path

# 'science' is not a built-in matplotlib style; it presumably comes from the
# SciencePlots package (recent versions need `import scienceplots` to register it) -- TODO confirm.
plt.style.use('science')

Before running this notebook, fill in the following paths:

In [None]:
logging_folder = ''
imgnet_dir = ''
bbox_dir = ''
bbox_idxs = ''
model_A_pth = f'{logging_folder}/{model}/{model}_dssweep_{ds}pc_A/model_ep{epoch}.pth'
model_B_pth = f'{logging_folder}/{model}/{model}_dssweep_{ds}pc_B/model_ep{epoch}.pth'
rcdm_A_pth = f'{logging_folder}/RCDM/{model}/rcdm_{model}_{epoch}ep_{ds}pc_A/model600000.pt'
rcdm_B_pth = f'{logging_folder}/{model}/rcdm_{model}_{epoch}ep_{ds}pc_B/model600000.pt'

In [None]:
model = 'vicreg'
attk_set = 'A'
ref_set = 'B' if attk_set == 'A' else 'A'
epoch = 1000
ds = 300
k_neighb = 100
conf_gap_thresh = 3

#ssl params 
mlp = '8192-8192-8192'
gpu = 1

with open("imgnet_classes.json") as f:
    imgnet_classes = json.load(f)

In [None]:
# Load the attack results for this run configuration, then print per-class
# statistics for the attacked split sorted by confidence (k=40 presumably limits
# how many classes are listed -- TODO confirm in ru).
attk_data = ru.get_attack_data(model, ds, epoch, k_neighb)
ru.print_class_statistics_sort_conf(
    attk_data,
    attk_set,
    epoch,
    ds,
    imgnet_classes,
    k=40,
)

### Load SSL and RCDM models

In [None]:
# Set up a one-process NCCL group (world_size=1, rank=0) via file:// rendezvous,
# unless this kernel already initialised one. Presumably needed because the
# SSL helpers expect torch.distributed to be available -- TODO confirm.
if not torch.distributed.is_initialized():
    init_file = Path('/scratch', 'interactive_init')
    # The file:// init method never deletes its rendezvous file, so a stale
    # one from an earlier kernel must be removed before re-initialising.
    if init_file.exists():
        init_file.unlink()
    dist_url = init_file.as_uri()

    torch.distributed.init_process_group(
        backend='nccl',
        init_method=dist_url,
        world_size=1,
        rank=0,
    )

torch.cuda.set_device(gpu)

# SSL models trained on splits A and B.
ssl_model_A, ssl_model_B = ru.load_ssl_models(model_A_pth, model_B_pth, mlp, model)

# Conditioning size for RCDM: representation_size + num_features of the SSL
# network (presumably backbone + projector outputs -- TODO confirm in ru).
ssl_dim = ssl_model_A.module.representation_size + ssl_model_A.module.num_features

# Load the A checkpoint first, then B; each call returns (model, diffusion).
(RCDM_A, diff_A), (RCDM_B, diff_B) = (
    ru.load_rcdm_model(ckpt_path, ssl_dim) for ckpt_path in (rcdm_A_pth, rcdm_B_pth)
)

### Load datasets

In [None]:
# Indices stored for the attacked split of this (epoch, ds) run.
attack_idxs = attk_data[f'set_{attk_set}_idxs_{epoch}ep_{ds}pc']

# Dataset used to load cropped images (bbox_dir) for attack_idxs; with
# return_im_and_tgt=True it presumably yields (image, target) -- verify in ru.
crop_ds = ru.aux_dataset(
    imgnet_dir,
    bbox_dir,
    attack_idxs,
    return_im_and_tgt=True,
)

### Look at confident examples/patches in a given class

In [None]:
# Class 362 = badger: show the top-confidence examples/patches for this class.
cl = 362
badger_patches, badger_idxs = ru.top_conf_show_class_examples(
    attk_data, attk_set, epoch, ds,
    cl, crop_ds, imgnet_classes,
    k=40,
)

In [None]:
# Hand-picked example indices to reconstruct with RCDM
# (presumably chosen from the badger_idxs displayed by the previous cell).
selected_badger_idx = [42118, 126765, 55913, 16995]

#### Then use RCDM to sample images using the NN

In [None]:
# Generate RCDM samples for the selected badger examples, passing both the
# A and B SSL/RCDM models. The attacked split comes from the config cell
# (previously hard-coded to 'A', which desynchronised this cell from
# attack_idxs / crop_ds whenever attk_set was changed to 'B').
im_dict_badger = ru.gen_samples(
    selected_badger_idx,
    diff_A, diff_B,
    ssl_model_A, ssl_model_B,
    RCDM_A, RCDM_B,
    epoch, ds,
    attk_data,
    attk_set=attk_set,
    just_neighbs=False,  # presumably: sample images too, not only return neighbours -- TODO confirm
)