In [None]:
import os
import sys
# Root of the local GLEAMS checkout. `expanduser` resolves the home directory
# even when the HOME variable itself is not defined (e.g. on Windows), where
# a direct `os.environ['HOME']` lookup would raise a KeyError.
os.environ['GLEAMS_HOME'] = os.path.join(os.path.expanduser('~'),
                                         'Projects/gleams')
# Make the GLEAMS sources importable by adding them to `sys.path`. The
# membership check keeps `sys.path` free of duplicate entries when this cell
# is executed more than once in the same kernel.
_gleams_src = os.path.normpath(os.path.join(os.environ['GLEAMS_HOME'], 'src'))
if _gleams_src not in sys.path:
    sys.path.append(_gleams_src)

In [None]:
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.spatial.distance as ssd
import seaborn as sns

In [None]:
import logging
# Fetch the logger registered under the name 'gleams' and lower its threshold
# to DEBUG so that no record is filtered out at the logger level.
logger = logging.getLogger(name='gleams')
logger.setLevel(level=logging.DEBUG)

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
from gleams.dag import dag

from gleams import config

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to
# 'seaborn-v0_8-<name>' and later releases dropped the old names. Use the
# legacy name when the installed Matplotlib still ships it and fall back to
# the renamed style sheet otherwise, so the cell works on both.
_style_names = ['seaborn-white', 'seaborn-paper']
plt.style.use([
    name if name in plt.style.available
    else name.replace('seaborn-', 'seaborn-v0_8-', 1)
    for name in _style_names])
plt.rc('font', family='serif')
sns.set_palette('Set1')
sns.set_context('paper', font_scale=1.3)    # Single-column figure.

In [None]:
# Name of the data split; it is interpolated into the file names of the
# feature, embedding and pair files loaded in the cells below.
split = 'test'

In [None]:
# Both metadata tables are indexed on the same three key columns.
_key_columns = ['dataset', 'filename', 'scan']

# Feature metadata for the selected split: every column is read.
_feature_parquet = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'feature',
    f'feature_{config.massivekb_task_id}_{split}.parquet')
metadata_feature = pd.read_parquet(_feature_parquet).set_index(_key_columns)

# Embedding metadata for the selected split: only the key columns are read.
_embed_parquet = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'embed',
    f'embed_{config.massivekb_task_id}_{split}.parquet')
metadata_embed = (pd.read_parquet(_embed_parquet, columns=_key_columns)
                  .set_index(_key_columns))

In [None]:
# Right join on the shared (dataset, filename, scan) index, so every row of
# the embedding metadata is kept. After `reset_index` the row labels are the
# positions in the joined table; `dropna` then removes rows with any missing
# value but leaves the labels of the surviving rows untouched.
_joined = metadata_feature.merge(metadata_embed, how='right',
                                 left_index=True, right_index=True)
metadata = _joined.reset_index().dropna()

In [None]:
# Open the embedding matrix read-only as a memory map instead of loading the
# whole array into RAM.
_embed_npy = os.path.join(
    os.environ['GLEAMS_HOME'], 'data', 'embed',
    f'embed_{config.massivekb_task_id}_{split}.npy')
embeddings = np.load(_embed_npy, mmap_mode='r')

In [None]:
# Memory-map the arrays of positive and negative pairs for this split.
pairs_pos = np.load(
    os.path.join(
        os.environ['GLEAMS_HOME'], 'data', 'feature',
        f'feature_{config.massivekb_task_id}_{split}_pairs_pos.npy'),
    mmap_mode='r')
pairs_neg = np.load(
    os.path.join(
        os.environ['GLEAMS_HOME'], 'data', 'feature',
        f'feature_{config.massivekb_task_id}_{split}_pairs_neg.npy'),
    mmap_mode='r')
# Use the same number of positive and negative pairs, capped so that both
# sets together never exceed `max_num_pairs`.
num_pairs = min(len(pairs_pos), len(pairs_neg))
max_num_pairs = 1_000_000
num_pairs = min(num_pairs, max_num_pairs // 2)
logger.info('Using %d positive and negative feature pairs each', num_pairs)
# Draw the subsample from a dedicated, seeded generator so that re-running the
# notebook selects the same pairs and therefore reproduces the same figure.
random_seed = 42
_rng = np.random.RandomState(random_seed)
idx_pos = _rng.choice(pairs_pos.shape[0], num_pairs, replace=False)
idx_neg = _rng.choice(pairs_neg.shape[0], num_pairs, replace=False)
# Indexing the memory maps with index arrays copies the selected rows into
# memory; from here on `pairs_pos` / `pairs_neg` hold only the subsample.
pairs_pos, pairs_neg = pairs_pos[idx_pos], pairs_neg[idx_neg]

In [None]:
def _pair_distances(pairs, row_labels, vectors, batch_size=65_536):
    """Euclidean distance between the two embeddings of every pair.

    Parameters
    ----------
    pairs : array-like of int, shape (n, 2)
        Positional row numbers into the metadata table, two per pair.
    row_labels : np.ndarray
        Index labels of the metadata table in positional order; entry ``i``
        is the row of `vectors` that belongs to metadata row ``i``.
    vectors : array-like, shape (num_rows, dim)
        Embedding matrix (may be a read-only memory map).
    batch_size : int
        Number of pairs processed per step, bounding the size of the
        temporary arrays gathered from `vectors`.

    Returns
    -------
    np.ndarray of float64, shape (n,)
        One distance per pair, in the order of `pairs`.
    """
    pairs = np.asarray(pairs)
    distances = np.empty(len(pairs), dtype=np.float64)
    for start in range(0, len(pairs), batch_size):
        stop = start + batch_size
        # Translate metadata positions into embedding rows for both members.
        rows = row_labels[pairs[start:stop]]
        diff = (np.asarray(vectors[rows[:, 0]], dtype=np.float64)
                - np.asarray(vectors[rows[:, 1]], dtype=np.float64))
        distances[start:stop] = np.linalg.norm(diff, axis=1)
    return distances


# `metadata.iloc[pair].index` is just the index labels at the pair's positions,
# so a single label array lets all pairs be resolved with NumPy indexing
# instead of one pandas lookup per pair.
_row_labels = np.asarray(metadata.index)
dist_pos = _pair_distances(pairs_pos, _row_labels, embeddings)
dist_neg = _pair_distances(pairs_neg, _row_labels, embeddings)

In [None]:
# Persist both distance collections to a file in the working directory.
joblib.dump(value=[dist_pos, dist_neg], filename='pairs_dist.joblib')

In [None]:
# Uncomment to reload previously saved distances instead of recomputing them:
# dist_pos, dist_neg = joblib.load('pairs_dist.joblib')

In [None]:
# Figure size: fixed width, with the height derived from the golden ratio.
width = 7
height = width / 1.618    # golden ratio
fig, ax = plt.subplots(figsize=(width, height))

# One density estimate per pair type, drawn on the same axes.
# NOTE: newer seaborn releases rename `shade` to `fill`; `shade` is kept here
# for compatibility with the seaborn version this notebook was written for.
sns.kdeplot(dist_pos, shade=True, label='Positive pairs', ax=ax)
sns.kdeplot(dist_neg, shade=True, label='Negative pairs', ax=ax)

ax.set_xlabel('Embedded distance')
ax.set_ylabel('Density')
# Request the legend explicitly: newer seaborn releases no longer draw one
# automatically from `label`, and without it the two curves cannot be told
# apart. On older releases this simply redraws the same legend.
ax.legend()

sns.despine()

# Save before showing so the file is written from the still-open figure.
plt.savefig('pairs_dist.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Final cell: shut the logging system down, flushing and closing its handlers.
logging.shutdown()