In [None]:
import pandas as pd
import astropy as ap
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns

In [None]:
import faiss

In [None]:
# Path of the pickled mapping that is read in the next cell; its keys and values
# are used further down as file keys and as the vectors that get indexed.
infile = 'vecs/tess_ode.pkl'

In [None]:
# Load the pickled mapping; the cells below iterate over its keys() and values().
# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source.
data = pd.read_pickle(infile)

In [None]:
# Stack the values of `data` along a new leading axis (one row per entry, in dict order).
vecs = np.stack([vec for vec in data.values()])

In [None]:
# Scale every row to unit L2 norm; keepdims keeps the norms as a column so they
# broadcast across each row.
vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

In [None]:
# Row position -> key of `data`, in the same order the values were stacked into `vecs`.
i2f = {position: key for position, key in enumerate(data)}

In [None]:
# Display the position -> key mapping.
# NOTE(review): this renders the entire dict; consider previewing only a few items.
i2f

### Build DB

In [None]:
# Length of each row of `vecs` (the vector dimensionality passed to the index below); displayed for reference.
d = vecs.shape[1]; d

In [None]:
# Flat (exhaustive) inner-product index of dimension d. Because the rows of `vecs`
# were scaled to unit norm above, the inner-product scores are cosine similarities.
db = faiss.IndexFlatIP(d)

In [None]:
# Add every row of `vecs` to the index; the ids returned by search are row positions,
# which is what `i2f` translates back to keys.
db.add(vecs)

In [None]:
# Number of vectors currently stored in the index.
db.ntotal

### Input vec

In [None]:
# Directory prefix of the .npy files; also used later to strip the prefix off again.
indir = 'tess/16_17/z_normalized/'
# One path per key of `data`, in dict order: <indir><key>.npy
filenames = [indir + key + '.npy' for key in data]

In [None]:
# Load the pickled array whose columns 0 and 1 are used as plot coordinates below.
# NOTE(review): presumably a 2-D t-SNE embedding with one row per entry of `data`,
# in the same order -- confirm. Unpickling is only safe for trusted files.
tsne_res = pd.read_pickle('tess_tsne.pkl')

In [None]:
# One row per file: its path plus the first two columns of `tsne_res` as x1 / x2.
# zip() stops at the shortest input, so rows are paired purely by position.
df = pd.DataFrame(
    zip(filenames, tsne_res[:, 0], tsne_res[:, 1]),
    columns=['file', 'x1', 'x2'],
)

In [None]:
# Selection window read by the next cell as (x1_min, x1_max, x2_min, x2_max); bounds are exclusive.
plot_bounds = (5,7, 18, 20)

In [None]:
# Rows whose (x1, x2) lie strictly inside the plot_bounds window.
s = df[
    df['x1'].gt(plot_bounds[0]) & df['x1'].lt(plot_bounds[1])
    & df['x2'].gt(plot_bounds[2]) & df['x2'].lt(plot_bounds[3])
]

In [None]:
def plot_one(sample):
    """Plot the series stored in the file named by the first row of ``sample``.

    Parameters
    ----------
    sample : DataFrame with a ``'file'`` column; only its first row is used.

    Returns
    -------
    tuple ``(ax, ts, obj_id, fname)``: the seaborn axes, the plotted array
    (element 1 of the loaded ``.npy`` array), the id extracted by ``file2id``,
    and the file path with ``indir`` and the 4-character extension removed.
    """
    path = sample['file'].values[0]
    series = np.load(path)[1]

    plt.figure(figsize=(5, 5))
    axes = sns.lineplot(x=np.arange(len(series)), y=series, color='k', alpha=0.8)
    # Hide both tick sets; only the shape of the curve is of interest.
    axes.set(yticks=[], xticks=[])

    key = path.split(indir)[-1][:-4]
    return axes, series, file2id(path), key

In [None]:
def file2id(file):
    """Return the third dash-separated field of ``file`` after stripping ``indir``.

    Everything up to and including the last occurrence of the module-level
    ``indir`` prefix is dropped, the remainder is split on ``'-'`` and the
    field at position 2 is returned. Raises ``IndexError`` when the name has
    fewer than three dash-separated fields.
    """
    relative_name = file.split(indir)[-1]
    fields = relative_name.split('-')
    return fields[2]

In [None]:
# Draw one random row from the selected window, plot it, and fetch its stored vector by key.
# NOTE(review): s.sample() is unseeded, so a different row is picked on every run.
ax, ts, obj_id, fname = plot_one(s.sample(n=1))
input_vec = data[fname]

In [None]:
# Inspect the vector stored in `data` for the sampled entry.
input_vec

In [None]:
# Key of the sampled entry (file path without the `indir` prefix and extension).
fname

In [None]:
# L2 norm of the current query vector.
np.linalg.norm(input_vec)

In [None]:
# Scale the query to unit norm and shape it as a single-row batch for db.search.
input_vec = (input_vec / np.linalg.norm(input_vec)).reshape(1, -1)

In [None]:
# Time a single-query search returning the 5 best hits.
%timeit db.search(input_vec, k=5)

In [None]:
# db.search returns (scores, ids), each with one row per query; there is a single
# query here, so keep only row 0 of both. The "probs" are inner-product scores.
probs, ixs = (per_query[0] for per_query in db.search(input_vec, k=6))

In [None]:
# Key of each hit -> its similarity score.
{i2f[hit]: score for hit, score in zip(ixs, probs)}

### Quick sample

In [None]:
def roulette(sample, k=4):
    """Find the nearest neighbours of one sampled file and load their series.

    Parameters
    ----------
    sample : DataFrame with a ``'file'`` column; only its first row is used.
    k : int, default 4
        Number of hits requested from the index. The default preserves the
        previously hard-coded value.

    Returns
    -------
    tuple ``(obj_ids, curves)``: for every hit, in the ranking order returned
    by the index, the id extracted by ``file2id`` and element 1 of the hit's
    ``.npy`` array. The query vector is itself stored in the index, so it is
    expected (not guaranteed) to be the first hit.

    Relies on the notebook-level names ``indir``, ``data``, ``db`` and ``i2f``.
    """
    file = sample['file'].values[0]
    # The key into `data` is the path without the directory prefix and the '.npy' suffix.
    fname = file.split(indir)[-1][:-4]

    # Unit-normalise and shape as a one-row batch, matching how the indexed rows were prepared.
    query = data[fname]
    query = (query / np.linalg.norm(query)).reshape(1, -1)

    # Only the ids are needed; row 0 holds the hits of the single query.
    _, ixs = db.search(query, k=k)
    files = [indir + i2f[i] + '.npy' for i in ixs[0]]
    curves = [np.load(f)[1] for f in files]
    obj_ids = [file2id(f) for f in files]
    return obj_ids, curves

In [None]:
# Replace the selection window with a much wider one, (x1_min, x1_max, x2_min, x2_max).
# NOTE(review): presumably meant to cover every point of the embedding -- confirm.
plot_bounds = (-100,100, -100,100)

In [None]:
# Re-select the rows strictly inside the (now wider) plot_bounds window.
s = df[
    df['x1'].gt(plot_bounds[0]) & df['x1'].lt(plot_bounds[1])
    & df['x2'].gt(plot_bounds[2]) & df['x2'].lt(plot_bounds[3])
]

In [None]:
# Neighbour lookup for one random row of the selection; `res` is the (obj_ids, curves) pair.
# NOTE(review): s.sample() is unseeded, so the result changes on every run.
res = roulette(s.sample(n=1))

In [None]:
obj_ids, ys = res


def make_plot():
    """Draw the curves in the notebook-level ``ys`` side by side and return the figure.

    One panel is created per curve (5x5 inches each, shared y axis), so the
    figure no longer breaks when ``ys`` holds more than four curves. Each panel
    is titled with the matching entry of the notebook-level ``obj_ids``. The
    first panel is drawn in red and all others in black.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # At least one panel so plt.subplots still works when ys is empty.
    n_panels = max(len(ys), 1)
    # squeeze=False keeps `axes` two-dimensional even for a single panel.
    f, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5),
                           sharey=True, squeeze=False)
    for i, (ax, curve) in enumerate(zip(axes[0], ys)):
        ax.set_xticks(ticks=[])
        ax.set_yticks(ticks=[])
        ax.set_title('id: ' + str(obj_ids[i]))
        color, alpha = ('r', 0.6) if i == 0 else ('k', 0.8)
        sns.lineplot(x=np.arange(len(curve)), y=curve, color=color, alpha=alpha, ax=ax)
    return f

In [None]:
# Build the comparison figure. `f` was never assigned anywhere above, so on a fresh
# kernel the bare `f` raised NameError; make_plot() has to be called first.
# With the inline backend the new figure is rendered at the end of this cell.
f = make_plot()

In [None]:
# Ids of the plotted curves, in the ranking order returned by the index.
obj_ids

In [None]:
# Write the figure to disk.
# NOTE(review): savefig does not create directories, so 'plots/' must already exist.
f.savefig('plots/2.png')

### Random plots

In [None]:
# Disabled: export 50 random neighbour figures, one PNG per query id, into 'plots/'.
# for _ in range(50):
#     obj_ids, ys = roulette(s.sample(n=1))
#     inp_id = obj_ids[0]
#     f = make_plot()
#     f.savefig('plots/'+inp_id+'.png')

### Batch Query

In [None]:
# Shape of the indexed matrix.
vecs.shape

In [None]:
# Number of position -> key entries; both this and `vecs` are built from `data`,
# so it should equal the first dimension of `vecs`.
len(i2f)

In [None]:
# Time the batch search in which every indexed vector is used as a query (2 hits each).
%timeit db.search(vecs, k=2)

In [None]:
# Batch query: one row of scores / ids per indexed vector, two hits per row.
# NOTE(review): the queries are the indexed vectors themselves, so column 0 is
# presumably each vector matching itself and column 1 its closest other vector -- confirm.
probs, ixs = db.search(vecs, k=2)

In [None]:
# Score of the second-ranked hit (column 1) for every query row.
max_probs = probs[:,1]

In [None]:
# Row position -> second-ranked similarity score.
temp = {position: score for position, score in enumerate(max_probs)}

In [None]:
# Same scores, re-keyed from row position to the key of `data`.
res = dict((i2f[position], score) for position, score in temp.items())

In [None]:
# (key, score) pairs ordered by ascending score, i.e. the entries least similar
# to their nearest other entry come first.
topk = list(res.items())
topk.sort(key=lambda pair: pair[1])

In [None]:
# File paths of the 20 lowest-scoring entries, lowest first.
top20 = [indir + key + '.npy' for key, _score in topk[:20]]

In [None]:
# Show the 20 selected file paths.
top20

In [None]:
# Spot-check the id extraction on the first selected path.
file2id(top20[0])

In [None]:
# Plot the files in `top20` on a 5x4 grid and save the figure.
# Changes from the previous version:
#   * Panels are filled in row-major order so the grid reads in the same order as
#     `top20`; the coupled i/j modulo counters walked the grid diagonally.
#   * The loop variable is no longer called `f`, which overwrote the figure handle
#     `f` used by the earlier cells.
#   * `tight_layout` is not a savefig parameter; the layout is tightened on the
#     figure itself before saving.
fig, axes = plt.subplots(5, 4, figsize=(40, 40), sharey=False)
for ax, path in zip(axes.ravel(), top20):
    y = np.load(path)[1]
    ax.set_xticks(ticks=[])
    ax.set_yticks(ticks=[])
    ax.set_title('id: ' + str(file2id(path)))
    sns.lineplot(x=np.arange(len(y)), y=y, color='r', alpha=0.6, ax=ax)
fig.tight_layout()
fig.savefig('most_dissimilar.png')

In [None]:
# One stand-alone figure per file in `top20`.
# The loop variables used to be `f` and `d`, which overwrote the notebook-level
# figure handle `f` and the index dimensionality `d`; re-running an earlier cell
# such as `faiss.IndexFlatIP(d)` afterwards would then fail.
for path in top20:
    curve = np.load(path)[1]
    plt.figure()
    sns.lineplot(x=np.arange(len(curve)), y=curve)
    plt.show()