In [27]:
import pickle
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.neighbors import KernelDensity, NearestNeighbors
from tqdm import tqdm

import warnings
# Suppress every warning category for the rest of the session; this also hides
# deprecation notices emitted by the libraries imported above.
warnings.filterwarnings('ignore')
# Seeds Python's built-in `random` module only. NumPy generators created with
# np.random.default_rng() take their own seed and are not affected by this call.
random.seed(42)

In [22]:
def get_classes(csv):
    """Return the class labels of ``csv`` that should be processed.

    The unique values of the ``"Class"`` column are taken in order of first
    appearance. The first unique value is dropped, and so is every label that
    contains the substring ``"data"``.
    """
    candidates = csv["Class"].unique()[1:]
    return [label for label in candidates if "data" not in label]

In [3]:
def calc_distribution(same, diff):
    """Fit one Gaussian kernel density estimate per sample.

    Each input is converted to a single-column array (``reshape(-1, 1)``) and
    fitted with a ``KernelDensity`` estimator using a Gaussian kernel and a
    bandwidth of 0.75.

    Returns:
        tuple: ``(kde_same, kde_diff)``, the fitted estimators for ``same``
        and ``diff`` in that order.
    """
    fitted = []
    for values in (same, diff):
        column = np.array(values).reshape(-1, 1)
        estimator = KernelDensity(kernel="gaussian", bandwidth=0.75)
        fitted.append(estimator.fit(column))
    kde_same, kde_diff = fitted
    return kde_same, kde_diff

In [4]:
def make_plot(path, same, diff, title, feature=None):
    """Save a KDE plot comparing the ``same`` and ``diff`` samples.

    The figure is written to ``<path>/same-diff-dist.png``; ``path`` is
    created (including parents) when it does not exist yet.

    Args:
        path: Output directory. A trailing separator is optional.
        same: 1-D sequence of values drawn as the "Same" curve.
        diff: 1-D sequence of values drawn as the "Diff" curve.
        title: Figure title.
        feature: Unused; kept so existing callers keep working.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots()
    try:
        for label, values in (("Same", same), ("Diff", diff)):
            # kdeplot is the supported replacement for the deprecated
            # distplot(kde=True, hist=False, rug=False) call.
            sns.kdeplot(np.asarray(values, dtype=float), ax=ax, label=label)
        # Fixed axis limits keep the plots of different classes comparable.
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 12.5])
        ax.legend()
        fig.suptitle(title, fontsize=10)

        # Joining with Path avoids writing to "<path>same-diff-dist.png" when
        # the caller passes a directory without a trailing separator.
        fig.savefig(out_dir / "same-diff-dist.png")
    finally:
        # Close only this figure (even on failure) instead of every open one.
        plt.close(fig)

In [5]:
def load_pckl(path):
    """Unpickle and return the object stored in the file at ``path``.

    Note: ``pickle`` can execute arbitrary code while loading, so only files
    from a trusted source should be passed in.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)

In [6]:
def my_distance(x, y, **kwargs):
    """Cosine distance ``1 - cos(x, y)`` between two 1-D vectors.

    Both inputs are converted with ``torch.tensor`` and compared along
    dimension 0 with ``eps=1e-6``. The result is returned as a Python float.
    Extra keyword arguments are accepted and ignored.
    """
    left = torch.tensor(x)
    right = torch.tensor(y)
    similarity = nn.functional.cosine_similarity(left, right, dim=0, eps=1e-6)
    return 1 - float(similarity)

In [24]:
def analysis(mode, cl, closest_same_class, closest_diff_class):
    """Plot and persist the similarity distributions of one class.

    The class name is looked up in the module-level ``ade_classes`` mapping
    via ``int(cl)``. The plot is written by ``make_plot`` into
    ``data/figures/<mode>-take-2/closest_same_closest_diff/<class name>/``,
    and the two estimators returned by ``calc_distribution`` are pickled as a
    two-element list to ``distribution.pckl`` in the same directory.
    """
    class_name = ade_classes[int(cl)]
    out_dir = "data/figures/{}/closest_same_closest_diff/{}/".format(mode+"-take-2", class_name)
    plot_title = "{} - Same-Diff Class".format(class_name)
    make_plot(out_dir, closest_same_class, closest_diff_class, plot_title)

    estimators = list(calc_distribution(closest_same_class, closest_diff_class))
    with open(out_dir + "distribution.pckl", "wb") as sink:
        pickle.dump(estimators, sink)

In [32]:
def closest(csv, cl, mode, same, diff):
    """Collect nearest same-class and different-class similarities for one class.

    Up to 500 rows of ``same`` are sampled. For each sampled row, the nearest
    neighbour among the *other* rows of ``same`` and the nearest neighbour in
    ``diff`` are looked up with ``my_distance``. Each distance ``d`` is stored
    as the similarity ``1 - d``, and both lists are handed to ``analysis``.

    Args:
        csv: Unused; kept so existing callers keep working.
        cl: Class label, forwarded to ``analysis``.
        mode: Forwarded to ``analysis``.
        same: Array-like with one feature vector of class ``cl`` per row.
        diff: Fitted nearest-neighbour index exposing ``kneighbors``; it is
            queried for the closest different-class feature.

    Returns:
        None. Nothing is analysed when ``same`` has fewer than two rows,
        because a leave-one-out search needs at least one other row.
    """
    same = np.asarray(same)
    if len(same) < 2:
        return

    sample_length = min(500, len(same))
    # Sample row *positions* rather than rows. The position is needed below to
    # remove exactly the query row from the search set; deleting by the
    # position inside the sample (as before) left the query in the set, so it
    # matched itself. Sampling without replacement avoids duplicate queries.
    sampled_rows = rng.choice(len(same), sample_length, replace=False)

    closest_same_class = []
    closest_diff_class = []
    for row in tqdm(sampled_rows):
        query = same[row].reshape(1, -1)

        # Leave-one-out: search every same-class row except the query itself.
        knn = NearestNeighbors(algorithm="brute", n_neighbors=1, metric=my_distance)
        knn.fit(np.delete(same, row, 0))
        closest_same_class.append(1 - knn.kneighbors(query)[0][0][0])

        closest_diff_class.append(1 - diff.kneighbors(query)[0][0][0])

    analysis(mode, cl, closest_same_class, closest_diff_class)

In [19]:
# Seed the generator so the random sub-sampling is repeatable between runs.
# random.seed(42) in the import cell only seeds the stdlib `random` module and
# has no effect on NumPy generators, so the same seed is passed explicitly.
rng = np.random.default_rng(42)
# NOTE(review): my_distance builds its own CosineSimilarity; this module-level
# instance looks unused in the cells shown here - confirm before removing.
cos = nn.CosineSimilarity(dim=0, eps=1e-6)

# Map each class index to its name, with ";" replaced by "-" (the names are
# later used as directory names by `analysis`).
ade = pd.read_csv("data/features_150.csv")
ade_classes = {row["Idx"]: row["Name"].replace(";", "-") for _, row in ade.iterrows()}

# Name of the data sub-directory (data/<mode>/...) processed by the notebook.
mode = "train_non_torch"

In [None]:
csv = pd.read_csv("data/{}/features.csv".format(mode), names=["Idx", "Class", "Path"], low_memory=False)
classes = get_classes(csv)
knn_diff = NearestNeighbors(algorithm="brute", n_neighbors=1, metric=my_distance)

# Unpickle both indexes once, instead of re-reading the files for every class.
# A missing or corrupt file now fails loudly here rather than being silently
# skipped once per class.
same_index = load_pckl('data/{}/same_index.pckl'.format(mode))
diff_index = load_pckl('data/{}/diff_index.pckl'.format(mode))

for cl in tqdm(classes, desc="Total"):
    # Skip classes that have no same-class features. Only the missing-key case
    # is skipped, so real errors and Ctrl-C are no longer swallowed.
    try:
        same = same_index[cl]
    except KeyError:
        continue
    # Re-fit the shared index on 20000 randomly drawn different-class features.
    knn_diff.fit(rng.choice(diff_index[cl], 20000, axis=0))
    closest(csv, cl, mode, same, knn_diff)

Total:   0%|                                                                                   | 0/104 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:01,  1.64s/it][A
2it [00:03,  1.64s/it][A
3it [00:04,  1.64s/it][A
4it [00:06,  1.64s/it][A
5it [00:08,  1.64s/it][A
6it [00:09,  1.64s/it][A
7it [00:11,  1.64s/it][A
8it [00:13,  1.64s/it][A
9it [00:14,  1.64s/it][A
10it [00:16,  1.64s/it][A
11it [00:18,  1.64s/it][A
12it [00:19,  1.64s/it][A
13it [00:21,  1.64s/it][A
14it [00:22,  1.64s/it][A
15it [00:24,  1.64s/it][A
16it [00:26,  1.64s/it][A
17it [00:27,  1.64s/it][A
18it [00:29,  1.63s/it][A
19it [00:31,  1.63s/it][A
20it [00:32,  1.64s/it][A
21it [00:34,  1.64s/it][A
22it [00:36,  1.64s/it][A
23it [00:37,  1.64s/it][A
24it [00:39,  1.64s/it][A
25it [00:40,  1.64s/it][A
26it [00:42,  1.64s/it][A
27it [00:44,  1.64s/it][A
28it [00:45,  1.64s/it][A
29it [00:47,  1.64s/it][A
30it [00:49,  1.64s/it][A
31it [00:50,  1.64s/it][A
32it [00:52,  1.64s/it][A
33i

82it [03:26,  2.57s/it][A
83it [03:29,  2.61s/it][A
84it [03:31,  2.64s/it][A
85it [03:34,  2.62s/it][A
86it [03:37,  2.64s/it][A
87it [03:39,  2.64s/it][A
88it [03:42,  2.66s/it][A
89it [03:45,  2.64s/it][A
90it [03:47,  2.58s/it][A
91it [03:49,  2.55s/it][A
92it [03:52,  2.55s/it][A
93it [03:54,  2.52s/it][A
94it [03:57,  2.50s/it][A
95it [03:59,  2.51s/it][A
96it [04:02,  2.52s/it][A
97it [04:04,  2.50s/it][A
98it [04:07,  2.53s/it][A
99it [04:10,  2.53s/it][A
100it [04:12,  2.52s/it][A
101it [04:15,  2.52s/it][A
102it [04:17,  2.53s/it][A
103it [04:20,  2.52s/it][A
104it [04:22,  2.52s/it][A
105it [04:25,  2.52s/it][A
106it [04:27,  2.50s/it][A
107it [04:30,  2.53s/it][A
108it [04:32,  2.55s/it][A
109it [04:35,  2.56s/it][A
110it [04:37,  2.57s/it][A
111it [04:40,  2.57s/it][A
112it [04:43,  2.57s/it][A
113it [04:45,  2.56s/it][A
114it [04:48,  2.58s/it][A
115it [04:50,  2.56s/it][A
116it [04:53,  2.55s/it][A
117it [04:55,  2.55s/it][A
118it [04:

166it [07:09,  2.46s/it][A
167it [07:11,  2.48s/it][A
168it [07:14,  2.47s/it][A
169it [07:16,  2.49s/it][A
170it [07:19,  2.50s/it][A
171it [07:21,  2.53s/it][A
172it [07:24,  2.53s/it][A
173it [07:26,  2.51s/it][A
174it [07:29,  2.54s/it][A
175it [07:32,  2.57s/it][A
176it [07:34,  2.57s/it][A
177it [07:37,  2.62s/it][A
178it [07:40,  2.63s/it][A
179it [07:42,  2.61s/it][A
180it [07:45,  2.58s/it][A
181it [07:47,  2.56s/it][A
182it [07:50,  2.53s/it][A
183it [07:52,  2.54s/it][A
184it [07:55,  2.54s/it][A
185it [07:57,  2.51s/it][A
186it [08:00,  2.52s/it][A
187it [08:02,  2.51s/it][A
188it [08:05,  2.53s/it][A
189it [08:07,  2.53s/it][A
190it [08:10,  2.52s/it][A
191it [08:12,  2.55s/it][A
192it [08:15,  2.57s/it][A
193it [08:18,  2.58s/it][A
194it [08:20,  2.66s/it][A
195it [08:23,  2.60s/it][A
196it [08:25,  2.58s/it][A
197it [08:28,  2.60s/it][A
198it [08:31,  2.57s/it][A
199it [08:33,  2.57s/it][A
200it [08:36,  2.58s/it][A
201it [08:38,  2.58s

249it [10:52,  2.63s/it][A
250it [10:54,  2.59s/it][A
251it [10:57,  2.60s/it][A
252it [11:00,  2.60s/it][A
253it [11:02,  2.64s/it][A
254it [11:05,  2.64s/it][A
255it [11:08,  2.66s/it][A
256it [11:10,  2.67s/it][A
257it [11:13,  2.68s/it][A
258it [11:16,  2.72s/it][A
259it [11:19,  2.68s/it][A
260it [11:21,  2.67s/it][A
261it [11:24,  2.66s/it][A
262it [11:26,  2.67s/it][A
263it [11:29,  2.68s/it][A
264it [11:32,  2.69s/it][A
265it [11:35,  2.68s/it][A
266it [11:37,  2.64s/it][A
267it [11:40,  2.63s/it][A
268it [11:42,  2.61s/it][A
269it [11:45,  2.57s/it][A
270it [11:47,  2.59s/it][A
271it [11:50,  2.58s/it][A
272it [11:52,  2.56s/it][A
273it [11:55,  2.56s/it][A
274it [11:58,  2.57s/it][A
275it [12:00,  2.56s/it][A
276it [12:03,  2.58s/it][A
277it [12:05,  2.59s/it][A
278it [12:08,  2.58s/it][A
279it [12:11,  2.60s/it][A
280it [12:13,  2.57s/it][A
281it [12:16,  2.59s/it][A
282it [12:18,  2.58s/it][A
283it [12:21,  2.59s/it][A
284it [12:23,  2.56s