In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import pairwise_distances
from typing import List
from sklearn.feature_extraction.text import TfidfVectorizer
import imodelsx.embeddings
from sklearn.preprocessing import normalize

df = pd.read_pickle("../data/data_clean.pkl")
n = df.shape[0]

# set up text for prediction
def get_text_representation(row):
    # return f"""- Title: {row["title"]}
# - Description: {row["description"]}
# - Predictor variables: {str(row["feature_names"])[1:-1]}"""
    return f"""{row["title"]}. {row["description"]}."""  # Keywords: {str(row["info___keywords"])[1:-1]}"""
    # return f"""{row["title"]}. {row["description"]}. Keywords: {str(row["info___keywords"])[1:-1]}"""
df['text'] = df.apply(get_text_representation, axis=1)

In [None]:
df[['description', 'info___usage___use_case', 'info___usage___why_use',
    'info___usage___notes', 'info___next_steps___advice']]

In [None]:
def id_to_idx(id, df):
    return np.where(df.id == id)[0]


sims = np.zeros((n, n))
for r, row in tqdm(df.iterrows()):
    ids = row["info___related_cdi_ids"]
    for id in ids:
        c = id_to_idx(id, df)
        sims[r, c] += 1

    # for c, col in df.iterrows():
    #     for key in [
    #         "categorization___chief_complaint",
    #         "categorization___specialty",
    #         "categorization___purpose",
    #         "categorization___system",
    #         "categorization___disease",
    #     ]:
    #         if row[key] == col[key]:
    #             sims[r, c] += 1


# average values across the diagonal
sims = (sims + sims.T) / 2

# set diagonal to 1
# np.fill_diagonal(sims, max(sims))

# plot clustermap
# sns.clustermap(sims)

In [139]:
# checkpoint = "bert-base-uncased"
# checkpoint = 'microsoft/deberta-v2-xxlarge'
checkpoint = 'tf-idf'
embs = imodelsx.embeddings.get_embs(
    df["text"].tolist(), checkpoint, batch_size=32, aggregate="first")

In [140]:
# embs = normalize(embs, norm="l2", axis=1)

In [141]:
# calculate pairwise similarity between all embeddings in embs
s = pairwise_distances(embs, metric="cosine")

# matrix showing gt
sims_sym = (sims >= 1.0).astype(bool)
# sims_sym = np.where(sims >= 1.0)

print('median dist', np.median(s[sims_sym]).round(2),
      'versus baseline', np.median(s[~sims_sym]).round(2), '(distances, so lower is better)')


# ranks for largest values in each row (0 is best rank)
ranks_by_dist = np.argsort(s, axis=1)[::, ::-1]

# select ranks for gt
ranks_for_gt = []
for i in range(sims_sym.shape[0]):
    ranks_for_gt.append(ranks_by_dist[i, sims_sym[i]])

# look at ranks
all_ranks = np.concatenate(ranks_for_gt) - 1
mean_rank = np.mean(all_ranks).round(2)
print('mean rank (lower is better)', mean_rank,
      'random baseline', (ranks_by_dist.shape[1] - 1) / 2)
# plt.axvline(mean_rank, color="k", linestyle="dashed", linewidth=1)
# plt.hist(all_ranks)

median dist 0.71 versus baseline 0.99 (distances, so lower is better)
mean rank (lower is better) 347.29 random baseline 344.5
