# Purpose

### 2022-11-08
Run queries in parallel with `dask`. Now that we'll run ANN for 250+ subreddits, running in a single thread could take a loooong time.

New ETA for ~250k subreddits: ~50 minutes.


### 2022-08-01
Calculating precise nearest neighbors has become too expensive as we go over 40k subreddits. So instead let's calculate approx nearest neighbors (ANN). 

In this notebook we use [ANNOY](https://github.com/spotify/annoy).  Main reason for using annoy over FAISS is that annoy has official wheels in pypi, but FAISS only officially supports installation from conda. For now we don't want to depend on third-party wheels for FAISS b/c that can be messy to install & replicate in a VM. Maybe when we switch to kubeflow we can try FAISS.


# Notebook setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datetime import datetime
import gc
import os
import json
import logging
from logging import info
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import seaborn as sns

import dask
from dask import dataframe as dd
from tqdm import tqdm

import mlflow
import hydra
import annoy


import subclu
from subclu.models.nn_annoy import AnnoyIndex
from subclu.utils.hydra_config_loader import LoadHydraConfig
from subclu.data.data_loaders import LoadSubreddits
from subclu.utils.mlflow_logger import MlflowLogger, save_pd_df_to_parquet_in_chunks

from subclu.utils.big_query_utils import load_data_to_bq_table
from subclu.models.bq_embedding_schemas import embeddings_schema, similar_sub_schema


# General utils to display & set working directories
from subclu.utils import set_working_directory, get_project_subfolder
from subclu.utils.eda import (
    setup_logging, counts_describe, value_counts_and_pcts,
    notebook_display_config, print_lib_versions,
    style_df_numeric
)


print_lib_versions([annoy, dask, hydra, mlflow, np, pd, plotly, sns, subclu])

python		v 3.7.10
===
annoy		v: 1.17.0
dask		v: 2021.06.0
hydra		v: 1.1.0
mlflow		v: 1.16.0
numpy		v: 1.19.5
pandas		v: 1.2.4
plotly		v: 4.14.3
seaborn		v: 0.11.1
subclu		v: 0.6.1


In [3]:
# plotting
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.dates as mdates
plt.style.use('default')

setup_logging()
notebook_display_config()

# Set sqlite database as MLflow URI

In [4]:
# use new class to initialize mlflow
mlf = MlflowLogger(tracking_uri='sqlite')
mlflow.get_tracking_uri()

'sqlite:////home/jupyter/subreddit_clustering_i18n/mlflow_sync/djb-100-2021-04-28-djb-eda-german-subs/mlruns.db'

## Get list of experiments with new function

In [5]:
mlf.list_experiment_meta(output_format='pandas').tail(9)

Unnamed: 0,experiment_id,name,artifact_location,lifecycle_stage
35,35,v0.6.0_mUSE_aggregates,gs://i18n-subreddit-clustering/mlflow/mlruns/35,active
36,36,v0.6.0_mUSE_clustering_test,gs://i18n-subreddit-clustering/mlflow/mlruns/36,active
37,37,v0.6.0_mUSE_clustering,gs://i18n-subreddit-clustering/mlflow/mlruns/37,active
38,38,v0.6.0_nearest_neighbors,gs://i18n-subreddit-clustering/mlflow/mlruns/38,active
39,39,v0.6.1_mUSE_aggregates_test,gs://i18n-subreddit-clustering/mlflow/mlruns/39,active
40,40,v0.6.1_mUSE_aggregates,gs://i18n-subreddit-clustering/mlflow/mlruns/40,active
41,41,v0.6.1_mUSE_clustering_test,gs://i18n-subreddit-clustering/mlflow/mlruns/41,active
42,42,v0.6.1_mUSE_clustering,gs://i18n-subreddit-clustering/mlflow/mlruns/42,active
43,43,v0.6.1_nearest_neighbors,gs://i18n-subreddit-clustering/mlflow/mlruns/43,active


## Get runs from embeddings aggregation jobs

Want to make sure we can load these artifacts for other jobs

In [6]:
%%time

df_mlf_runs =  mlf.search_all_runs(experiment_ids=[40])
df_mlf_runs.shape

CPU times: user 56.9 ms, sys: 8.71 ms, total: 65.6 ms
Wall time: 65 ms


(4, 43)

In [7]:
df_mlf_runs[df_mlf_runs['status'] == 'FINISHED'].iloc[:5, :10]

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.memory_free,metrics.df_v_post_comments-cols,metrics.df_v_subs-rows,metrics.memory_used_percent
2,91ac7ca171024c779c0992f59470c81b,40,FINISHED,gs://i18n-subreddit-clustering/mlflow/mlruns/40/91ac7ca171024c779c0992f59470c81b/artifacts,2022-11-07 21:38:57.662000+00:00,2022-11-22 19:01:56.828000+00:00,1283359.0,515.0,781653.0,0.552656


### Check run artifacts for selected run

In [8]:
run_uuid_ = '91ac7ca171024c779c0992f59470c81b'
l_artifacts_top_level = mlf.list_run_artifacts(
    run_id=run_uuid_,
    only_top_level=True,
    verbose=True,
)
l_artifacts_all = mlf.list_run_artifacts(
    run_id=run_uuid_,
    only_top_level=False,
    verbose=False,
)

22:16:03 | INFO | "   288 <- Artifacts to check count"
22:16:03 | INFO | "   288 <- Artifacts clean count"
22:16:03 | INFO | "     6 <- Artifacts & folders at TOP LEVEL clean count"
22:16:09 | INFO | "   288 <- Artifacts clean count"
22:16:09 | INFO | "     6 <- Artifacts & folders at TOP LEVEL clean count"


In [9]:
for t_ in l_artifacts_top_level:
    l_ = [i for i in l_artifacts_all if t_ in i]
    print(f"=== Items in folder: {len(l_):,.0f} | {t_}  ===")
    for _ in l_[:3]:
        print(' ', '/'.join(_.split('/')[5:]))
    print('')

=== Items in folder: 63 | ann_df-2022-11-22_185903  ===
  ann_df-2022-11-22_185903/_common_metadata
  ann_df-2022-11-22_185903/_metadata
  ann_df-2022-11-22_185903/part.0.parquet

=== Items in folder: 211 | df_posts_agg_c1  ===
  df_posts_agg_c1/_common_metadata
  df_posts_agg_c1/_metadata
  df_posts_agg_c1/part.0.parquet

=== Items in folder: 14 | df_subs_agg_c1  ===
  df_subs_agg_c1/_common_metadata
  df_subs_agg_c1/_metadata
  df_subs_agg_c1/part.0.parquet

=== Items in folder: 1 | df_subs_agg_c1_ndjson  ===
  df_subs_agg_c1_ndjson/subreddit_embeddings_2022-11-18_171217.json

=== Items in folder: 7 | df_subs_agg_c1_unweighted  ===
  df_subs_agg_c1_unweighted/_common_metadata
  df_subs_agg_c1_unweighted/_metadata
  df_subs_agg_c1_unweighted/part.0.parquet

=== Items in folder: 1 | df_subs_agg_c1_unweighted_ndjson  ===
  df_subs_agg_c1_unweighted_ndjson/subreddit_embeddings_2022-11-18_165307.json



# Set run parameters to log for mlflow

This dictionary is equivalent to a config file for now. Use it as a bases for kubeflow re-write.

How to get active run:
```python
mlflow.active_run().info.run_id
```

In [10]:
d_mlf_params = {
    'run_name': f"ann_subreddit_test-{datetime.utcnow().strftime('%Y-%m-%d_%H%M%S')}",
    'mlflow_experiment_name': 'v0.6.1_nearest_neighbors',
    'embeddings_run_uuid': '91ac7ca171024c779c0992f59470c81b',
    'subreddit_embeddings_folder': 'df_subs_agg_c1',
    'post_embeddings_folder': 'df_posts_agg_c1',
    'n_min_post_per_sub': 6,

    # index columns for ANN df, JSON, & BQ table
    'index_cols': ['subreddit_id', 'subreddit_name'],
    'model_version': 'v0.6.1',
    'model_name': 'cau-text-mUSE',
    
    # sample number of subreddits to sample.
    #  Set to None to run on full data
    'n_sample_embedding_rows': None,
    
    # flag & params to upload to bigquery
    'upload_to_bq': False,
    'bq_project': 'reddit-employee-datasets',
    'bq_dataset': 'david_bermejo',
    'bq_table_name': 'cau_similar_subreddits_by_text',
}
d_ann_params = {
    'n_trees': 200,
    'metric': 'angular',
}
run_uuid = d_mlf_params['embeddings_run_uuid']

# Load aggregated embeddings

For subreddit-level embeddings, my python code (serial) is fine. 

Try `gsutil` to download **posts-level embeddings** b/c that can take a LONG time to download sequentially. `gsutil` makes parallel downloaidng much faster and reports download speeds above 500MB / s:

```bash
ents_sub_desc/part.67.parquet...
/ [2/197 files][ 61.7 GiB/ 75.4 GiB]  81% Done 632.0 MiB/s ETA 00:00:22
```

In [11]:
%%time

# mlf.set_experiment(d_mlf_params.mlflow_experiment_name)
t_start_job = datetime.utcnow()
info(f"== Start ANN job ==")

t_start_read_embeddings_ = datetime.utcnow()
df_agg_sub_c_raw = mlf.read_run_artifact(
    run_id=run_uuid,
    artifact_folder='df_subs_agg_c1',
    read_function='pd_parquet',
    verbose=False,
)


info(df_agg_sub_c_raw.shape)

22:16:09 | INFO | "== Start ANN job =="
22:16:15 | INFO | "Local folder to download artifact(s):
  /home/jupyter/subreddit_clustering_i18n/data/local_cache/i18n-subreddit-clustering/mlflow/mlruns/40/91ac7ca171024c779c0992f59470c81b/artifacts/df_subs_agg_c1"
100%|########################################| 14/14 [00:00<00:00, 32822.95it/s]
22:16:16 | INFO | "  Parquet files found:     4"
22:16:16 | INFO | "  Parquet files to use:     4"
22:16:17 | INFO | "(781653, 515)"


CPU times: user 10.4 s, sys: 4.06 s, total: 14.5 s
Wall time: 8.16 s


In [12]:
df_agg_sub_c_raw.iloc[:5, :7]

Unnamed: 0,subreddit_id,subreddit_name,posts_for_embeddings_count,embeddings_0,embeddings_1,embeddings_2,embeddings_3
0,t5_1001tl,jewel_xo,1,-0.028712,-0.027187,0.024826,0.046359
1,t5_1004au,tisbutafleshwound,3,0.010298,-0.000277,-0.004013,0.01762
2,t5_1006a0,sethigh,1,0.027356,0.032256,-0.022585,-0.004125
3,t5_1008xr,asiandiasporamusic,2,-0.011276,0.00072,-0.010621,0.021452
4,t5_1009a3,memesenespanol,299,-0.005113,-0.005898,-0.012267,0.006103


In [13]:
df_agg_sub_c_raw.iloc[-5:, :7]

Unnamed: 0,subreddit_id,subreddit_name,posts_for_embeddings_count,embeddings_0,embeddings_1,embeddings_2,embeddings_3
781648,t5_71dwdl,leagoldmining,0,0.027257,0.03655,-0.086116,-0.007389
781649,t5_6u3a0g,onlyfans_subscribers7,0,-0.05452,0.011804,-0.01482,0.02281
781650,t5_7a1p9b,xecauquan1,0,-0.013504,-0.092274,0.013367,0.032947
781651,t5_6xryrp,steroidsarmspeptide,0,-0.03319,0.058719,0.071802,0.065691
781652,t5_7a4axh,autonation_,0,-0.016285,-0.034998,-0.000827,0.058873


## Load subreddit metadata
We need the metadata to keep only subreddits that have a high enough activity.

See the `subreddit_seed_for_clusters` column definition for threshold.

In [14]:
# load config data that has the keys needed to load the subreddit meta
config_name = '/data_text_and_metadata/v0.6.1_model'

cfg_cluster_meta = LoadHydraConfig(
    config_name=config_name,
    config_path="../config",
)

print([k for k in cfg_cluster_meta.config_dict.keys()])

['data_text_and_metadata']


In [15]:
%%time

df_sub_meta = LoadSubreddits(
    bucket_name=cfg_cluster_meta.config_dict['data_text_and_metadata']['bucket_name'],
    folder_path=cfg_cluster_meta.config_dict['data_text_and_metadata']['folder_subreddits_text_and_meta'],
    columns=['subreddit_id', 'subreddit_name', 'primary_topic', 'subreddit_seed_for_clusters'],
).read_raw()

print(df_sub_meta.shape)

22:16:18 | INFO | "Reading raw data..."
22:16:19 | INFO | "  Local folder to download artifact(s):
  /home/jupyter/subreddit_clustering_i18n/data/local_cache/i18n_topic_model_batch/runs/20221107/subreddits_fix/text"
100%|################################| 4/4 [00:00<00:00, 15060.34it/s]


(781653, 4)
CPU times: user 741 ms, sys: 349 ms, total: 1.09 s
Wall time: 2.01 s


In [16]:
df_sub_meta.head()

Unnamed: 0,subreddit_id,subreddit_name,primary_topic,subreddit_seed_for_clusters
0,t5_2qh1i,askreddit,Learning and Education,True
1,t5_2qh33,funny,Funny/Humor,True
2,t5_35n7t,whitepeopletwitter,Internet Culture and Memes,True
3,t5_2qh0u,pics,Art,True
4,t5_2qh13,worldnews,World News,True


# Filter subreddits to use in ANN index

In a previous version we only kept subs that had embeddings AND clustering data. 
<br>Now that we cover 700k subreddits for v0.6.x, we need to be more thoughtful about how we'll select which subs to keep for ANN.

For v0.6.1 we'll keep only subs that have **4+ posts in L90 days**. From this mode dashboard we expect that number to be around 289k subreddits.

Mode Dashboards: 
- v0.6.0: https://app.mode.com/reddit/reports/e6cde33162c4 
- v0.6.1: https://app.mode.com/reddit/reports/87ce3abc9e37


## Apply filters

In v0.6.1, we already have the number of posts for embedding in the embedding file, so we don't need to load additional data (from mlflow or BQ) to apply post-count filters.

In [17]:
# use df_pc_counts because it has the counts for post+comment after filtering for length
value_counts_and_pcts(
  pd.cut(
      df_agg_sub_c_raw['posts_for_embeddings_count'],
      bins=[-1, 0, 1, 2, 3, 4, 5, 49, np.inf],
      labels=[
        "00 posts", "01 post", '02 posts', '03 posts',
        '04 posts', '05 posts'
        , '06-49 posts', '50+ posts'
      ]
  ).rename('posts_with_len_3+'),
  sort_index=True,
  add_col_prefix=False,
  count_type='subreddits',
  sort_index_ascending=False,
  cumsum_count=True,
  reset_index=True,
).hide_index().set_caption(f"<h4 align='left'>Post distribution for subreddits with 1 view & 1 attempted post in L90-days</h4>")

posts_with_len_3+,subreddits_count,percent_of_subreddits,cumulative_sum_of_subreddits,cumulative_percent_of_subreddits
50+ posts,74004,9.5%,74004,9.5%
06-49 posts,158397,20.3%,232401,29.7%
05 posts,23732,3.0%,256133,32.8%
04 posts,33723,4.3%,289856,37.1%
03 posts,57946,7.4%,347802,44.5%
02 posts,128068,16.4%,475870,60.9%
01 post,235794,30.2%,711664,91.0%
00 posts,69989,9.0%,781653,100.0%


## Include `active`/`seed` subreddits with 3+ posts

We need to include these because they make the core of the subreddits in cluster model for recommendations.

In [18]:
value_counts_and_pcts(
    df_sub_meta['subreddit_seed_for_clusters']
)

Unnamed: 0,subreddit_seed_for_clusters-count,subreddit_seed_for_clusters-percent,subreddit_seed_for_clusters-pct_cumulative_sum
False,672175,86.0%,86.0%
True,109478,14.0%,100.0%


In [23]:
# use a temp variable to prevent unintended mixups if we run cells out of order
_ = df_agg_sub_c_raw.merge(
    df_sub_meta[['subreddit_id', 'subreddit_seed_for_clusters']],
    how='left',
    on='subreddit_id'
).copy()

display(value_counts_and_pcts(
    pd.cut(
        _[_['subreddit_seed_for_clusters'] == True]['posts_for_embeddings_count'],
        bins=[-1, 0, 1, 2, 3, 4, 5, 49, np.inf],
        labels=[
        "00 posts", "01 post", '02 posts', '03 posts',
        '04 posts', '05 posts'
        , '06-49 posts', '50+ posts'
        ]
    ).rename('posts_with_len_3+'),
    sort_index=True,
    add_col_prefix=False,
    count_type='subreddits',
    sort_index_ascending=False,
    cumsum_count=True,
    reset_index=True,
).hide_index().set_caption(f"<h4 align='left'>Post distribution for subreddits to use for `seeds` in L90-days</h4>"))

del _

posts_with_len_3+,subreddits_count,percent_of_subreddits,cumulative_sum_of_subreddits,cumulative_percent_of_subreddits
50+ posts,57160,52.2%,57160,52.2%
06-49 posts,42587,38.9%,99747,91.1%
05 posts,2780,2.5%,102527,93.7%
04 posts,3164,2.9%,105691,96.5%
03 posts,1429,1.3%,107120,97.8%
02 posts,940,0.9%,108060,98.7%
01 post,724,0.7%,108784,99.4%
00 posts,694,0.6%,109478,100.0%


In [24]:
%%time

df_agg_sub_c = (
    df_sub_meta[['subreddit_id', 'subreddit_seed_for_clusters']]
    .merge(
        df_agg_sub_c_raw,
        how='right',
        on='subreddit_id'
    )
    .copy()
)

mask_subs_over_post_threshold = df_agg_sub_c['posts_for_embeddings_count'] >= d_mlf_params['n_min_post_per_sub']
mask_subs_seed_ = (
    (df_agg_sub_c['posts_for_embeddings_count'] >= 3) &
    (df_agg_sub_c['subreddit_seed_for_clusters'] == True)
)
info(f"{mask_subs_over_post_threshold.sum():,.0f} <- Subs above post threshold")
info(f"{mask_subs_seed_.sum():,.0f} <- Subs marked as `cluster seeds` (highly active)")

df_agg_sub_c = df_agg_sub_c[mask_subs_over_post_threshold | mask_subs_seed_]
df_agg_sub_c.shape

22:33:31 | INFO | "232,401 <- Subs above post threshold"
22:33:31 | INFO | "107,120 <- Subs marked as `cluster seeds` (highly active)"


CPU times: user 2.65 s, sys: 1.25 s, total: 3.9 s
Wall time: 3.9 s


(239774, 516)

# Build annoy index

I created a custom `AnnoyIndex` class with some extra methods to create outputs & (and calculate cosine distance) for BigQuery.

In [25]:
%%time

index_cols = ['subreddit_id', 'subreddit_name']
l_embedding_cols = [c for c in df_agg_sub_c.columns if c.startswith('embeddings_')]

nn_index = AnnoyIndex(
    df_agg_sub_c[l_embedding_cols + index_cols],
    index_cols=index_cols,
    metric=d_ann_params['metric'],
    n_trees=d_ann_params['n_trees'],
)

nn_index.build()

CPU times: user 35min 5s, sys: 1min 19s, total: 36min 25s
Wall time: 46.4 s


## Get df with all items

For 80k subreddits it took 1 hr & 17 minutes.

I had to create a new method b/c that method would've taken over 18 hours to get ANN for 250k subreddits.

New method should take ~40 minutes to get 250 ANN for 250k subreddits !!!.


```bash
# old method:
100%|██████████| 81973/81973 [1:17:02<00:00, 17.73it/s]
17:07:23 | INFO | "(8115327, 7) <- df_top_items shape"


# new method:
  7%|6         | 17098/250573 [02:40<36:22, 106.97it/s]
  
 80%|########  | 192439/239774 [24:58<06:33, 120.32it/s]
```

In [26]:
%%time

df_nn_top = nn_index.get_top_n_by_item_all_fast(
    k=200,
    search_k=-1,
    include_distances=True,
    append_i=True,
    cosine_similarity=True,
)

100%|##########| 239774/239774 [31:21<00:00, 127.45it/s]
23:07:29 | INFO | "Start combining all ANNs into a df..."
23:08:24 | INFO | "(47954800, 4) <- df_nn_top shape"
23:08:24 | INFO | "Adding index labels (subreddit ID & Name)"
23:08:35 | INFO | "Done adding index names"
23:08:35 | INFO | "(47954800, 8) <- df_nn_top shape"
23:08:35 | INFO | "Calculating cosine similarity..."


CPU times: user 32min 10s, sys: 21.1 s, total: 32min 31s
Wall time: 32min 30s


### Quick Checks

In [27]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'france']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
3314800,t5_2qhjz,france,16574,6584,0.489396,1,t5_29145x,francedigeste,0.880246
3314801,t5_2qhjz,france,16574,92786,0.500917,2,t5_4c3l03,france6,0.874541
3314802,t5_2qhjz,france,16574,21992,0.513589,3,t5_2rj8v,francais,0.868113
3314803,t5_2qhjz,france,16574,50277,0.536449,4,t5_2zkfk,askfrance,0.856111
3314804,t5_2qhjz,france,16574,90336,0.541022,5,t5_47quxa,yahooqr,0.853648
3314805,t5_2qhjz,france,16574,45568,0.558391,6,t5_2xe8t,paslegorafi,0.8441
3314806,t5_2qhjz,france,16574,130621,0.567863,7,t5_5yjd6o,france_actu_debats,0.838766
3314807,t5_2qhjz,france,16574,3015,0.608544,8,t5_22i0,de,0.814837
3314808,t5_2qhjz,france,16574,16544,0.608873,9,t5_2qhh9,quebec,0.814637
3314809,t5_2qhjz,france,16574,118628,0.612728,10,t5_5i39cu,lbaqr,0.812282


In [28]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'finanzen']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
11864800,t5_35m5e,finanzen,59324,126925,0.392432,1,t5_5txdoj,finanzenat,0.922998
11864801,t5_35m5e,finanzen,59324,76249,0.470963,2,t5_3isqn,italiapersonalfinance,0.889097
11864802,t5_35m5e,finanzen,59324,61226,0.481969,3,t5_37aoh,vosfinances,0.883853
11864803,t5_35m5e,finanzen,59324,669,0.484936,4,t5_11cinh,befire,0.882418
11864804,t5_35m5e,finanzen,59324,42405,0.522513,5,t5_2w5jv,eupersonalfinance,0.86349
11864805,t5_35m5e,finanzen,59324,8816,0.530557,6,t5_2clhc5,literaciafinanceira,0.859254
11864806,t5_35m5e,finanzen,59324,233763,0.536577,7,t5_oe819,personalfinanceza,0.856042
11864807,t5_35m5e,finanzen,59324,37864,0.540692,8,t5_2uo3q,ausfinance,0.853826
11864808,t5_35m5e,finanzen,59324,63855,0.54347,9,t5_38zrx,personalfinancenz,0.85232
11864809,t5_35m5e,finanzen,59324,65493,0.543982,10,t5_39zkf,fiaustralia,0.852042


In [29]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'de']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
603000,t5_22i0,de,3015,93929,0.389254,1,t5_4egnbw,dezwo,0.924241
603001,t5_22i0,de,3015,68278,0.477014,2,t5_3caax,600euro,0.886229
603002,t5_22i0,de,3015,77935,0.481225,3,t5_3jxvk,tja,0.884211
603003,t5_22i0,de,3015,231106,0.493754,4,t5_irnzx,dachschaden,0.878103
603004,t5_22i0,de,3015,18131,0.509679,5,t5_2qo9i,austria,0.870114
603005,t5_22i0,de,3015,64009,0.517639,6,t5_392ha,asozialesnetzwerk,0.866025
603006,t5_22i0,de,3015,40631,0.571133,7,t5_2vk0m,nachrichten,0.836903
603007,t5_22i0,de,3015,96508,0.598626,8,t5_4juf8o,politpro,0.820823
603008,t5_22i0,de,3015,233371,0.607671,9,t5_nls07,belgium2,0.815368
603009,t5_22i0,de,3015,231620,0.608017,10,t5_jsyzh,poldersocialisme,0.815158


In [30]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'mexico']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
3352400,t5_2qhv7,mexico,16762,106045,0.43765,1,t5_4ywzju,askmexico,0.904231
3352401,t5_2qhv7,mexico,16762,17695,0.520065,2,t5_2qm06,monterrey,0.864766
3352402,t5_2qhv7,mexico,16762,26092,0.541566,3,t5_2sbh1,mexicali,0.853353
3352403,t5_2qhv7,mexico,16762,37938,0.546211,4,t5_2up3k,ticos,0.850827
3352404,t5_2qhv7,mexico,16762,79115,0.55256,5,t5_3la4d,mujico,0.847338
3352405,t5_2qhv7,mexico,16762,34838,0.555084,6,t5_2tw1p,mexicocity,0.845941
3352406,t5_2qhv7,mexico,16762,37443,0.576658,7,t5_2ujoy,memexico,0.833733
3352407,t5_2qhv7,mexico,16762,13792,0.579503,8,t5_2lxxle,mexicow,0.832088
3352408,t5_2qhv7,mexico,16762,25891,0.588112,9,t5_2samk,guatemala,0.827062
3352409,t5_2qhv7,mexico,16762,101331,0.609711,10,t5_4sbz8m,cdmx,0.814126


In [31]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'formula1']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
3401400,t5_2qimj,formula1,17007,7135,0.334023,1,t5_29o8ec,grandprixracing,0.944214
3401401,t5_2qimj,formula1,17007,54108,0.39302,2,t5_31vs7,scuderiaferrari,0.922768
3401402,t5_2qimj,formula1,17007,52816,0.409588,3,t5_316st,f1feederseries,0.916119
3401403,t5_2qimj,formula1,17007,1721,0.432823,4,t5_13t1oy,mclarenformula1,0.906332
3401404,t5_2qimj,formula1,17007,26493,0.459663,5,t5_2sdeq,indycar,0.894355
3401405,t5_2qimj,formula1,17007,40840,0.467478,6,t5_2vmby,lewishamilton,0.890732
3401406,t5_2qimj,formula1,17007,56939,0.467626,7,t5_33n2v1,astonmartinformula1,0.890663
3401407,t5_2qimj,formula1,17007,61317,0.471614,8,t5_37co3,haasf1team,0.88879
3401408,t5_2qimj,formula1,17007,41108,0.474223,9,t5_2vpfj,formulae,0.887556
3401409,t5_2qimj,formula1,17007,80978,0.491059,10,t5_3ndbi,formuladank,0.879431


In [32]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'worldcup']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
4253800,t5_2rdrs,worldcup,21269,225717,0.482058,1,t5_7ak4bi,world_cup_tv,0.88381
4253801,t5_2rdrs,worldcup,21269,192418,0.503103,2,t5_70yedu,footballwc2022,0.873444
4253802,t5_2rdrs,worldcup,21269,179489,0.510172,3,t5_6xljiq,worldcupfifa,0.869862
4253803,t5_2rdrs,worldcup,21269,7436,0.542714,4,t5_2a5u5m,worldcup_2022,0.85273
4253804,t5_2rdrs,worldcup,21269,230175,0.548586,5,t5_h0487,worldcupbetting,0.849526
4253805,t5_2rdrs,worldcup,21269,219987,0.556651,6,t5_78kzrs,fantasywc,0.84507
4253806,t5_2rdrs,worldcup,21269,222959,0.563547,7,t5_79mmmn,qatarworldcupnews,0.841207
4253807,t5_2rdrs,worldcup,21269,62821,0.596152,8,t5_38ae8,boycottqatarworldcup,0.822301
4253808,t5_2rdrs,worldcup,21269,29242,0.617381,9,t5_2sr4p,womenssoccer,0.80942
4253809,t5_2rdrs,worldcup,21269,173548,0.617591,10,t5_6vz7bd,fifawordcup2022,0.809291


In [33]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'soccer']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
3378400,t5_2qi58,soccer,16892,23674,0.381617,1,t5_2rxse,reddevils,0.927184
3378401,t5_2qi58,soccer,16892,23147,0.388552,2,t5_2rsl6,chelseafc,0.924514
3378402,t5_2qi58,soccer,16892,17434,0.40744,3,t5_2qkr5,football,0.916997
3378403,t5_2qi58,soccer,16892,23777,0.417122,4,t5_2ryq7,coys,0.913004
3378404,t5_2qi58,soccer,16892,27889,0.419106,5,t5_2sk2p,ussoccer,0.912175
3378405,t5_2qi58,soccer,16892,22536,0.421663,6,t5_2rnmt,acmilan,0.9111
3378406,t5_2qi58,soccer,16892,24031,0.424434,7,t5_2s14k,mcfc,0.909928
3378407,t5_2qi58,soccer,16892,16690,0.427173,8,t5_2qhqt,gunners,0.908761
3378408,t5_2qi58,soccer,16892,24716,0.436714,9,t5_2s561,barca,0.90464
3378409,t5_2qi58,soccer,16892,39434,0.438652,10,t5_2v6bc,atletico,0.903792


In [34]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'ligamx'] 
    .head(15)
    # 'r/davidochoa' is relevant to r/ligamx, but very small (last post over 1 month ago). 
    # Might still need to add a filter based on recent activity... otherwise we'll send people to dead subs
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
7434800,t5_2uh0l,ligamx,37174,36323,0.441022,1,t5_2u8t3,chivas,0.90275
7434801,t5_2uh0l,ligamx,37174,37444,0.468274,2,t5_2ujqj,clubamerica,0.89036
7434802,t5_2uh0l,ligamx,37174,234377,0.510977,3,t5_pnkp1,newmexicounited,0.869451
7434803,t5_2uh0l,ligamx,37174,27889,0.551077,4,t5_2sk2p,ussoccer,0.848157
7434804,t5_2uh0l,ligamx,37174,16892,0.555678,5,t5_2qi58,soccer,0.845611
7434805,t5_2uh0l,ligamx,37174,138089,0.557142,6,t5_69129a,davidochoa,0.844797
7434806,t5_2uh0l,ligamx,37174,47519,0.558865,7,t5_2y7k1,cruzazul,0.843835
7434807,t5_2uh0l,ligamx,37174,39434,0.567277,8,t5_2v6bc,atletico,0.839099
7434808,t5_2uh0l,ligamx,37174,22536,0.569231,9,t5_2rnmt,acmilan,0.837988
7434809,t5_2uh0l,ligamx,37174,20999,0.573692,10,t5_2rbnb,mls,0.835439


# Add dt/pt column & metadata columns

In [35]:
d_topk_meta = {
    'pt': datetime.utcnow().strftime("%Y-%m-%d"),
    'mlflow_run_id': run_uuid, 
    'model_name': d_mlf_params['model_name'],
    'model_version': d_mlf_params['model_version'],
}
print(d_topk_meta)
for k, v in d_topk_meta.items():
    df_nn_top[k] = v

{'pt': '2022-11-22', 'mlflow_run_id': '91ac7ca171024c779c0992f59470c81b', 'model_name': 'cau-text-mUSE', 'model_version': 'v0.6.1'}


In [36]:
df_nn_top.tail()

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
47954795,t5_zzw6f,missourisingles,239773,59069,0.672466,196,t5_35fwjx,phlgbtr4r,0.773895,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
47954796,t5_zzw6f,missourisingles,239773,158900,0.672747,197,t5_6qg84c,ohiohookup740,0.773706,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
47954797,t5_zzw6f,missourisingles,239773,5500,0.672805,198,t5_27btg2,gayyoungolddating,0.773667,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
47954798,t5_zzw6f,missourisingles,239773,61432,0.672903,199,t5_37f87,virginityexchange,0.773601,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
47954799,t5_zzw6f,missourisingles,239773,34062,0.673172,200,t5_2tpjl,euro4euro,0.77342,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [37]:
df_nn_top.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 47954800 entries, 0 to 47954799
Data columns (total 13 columns):
 #   Column                  Dtype  
---  ------                  -----  
 0   subreddit_id            object 
 1   subreddit_name          object 
 2   seed_ix                 int64  
 3   nn_ix                   int64  
 4   distance                float64
 5   distance_rank           int64  
 6   similar_subreddit_id    object 
 7   similar_subreddit_name  object 
 8   cosine_similarity       float64
 9   pt                      object 
 10  mlflow_run_id           object 
 11  model_name              object 
 12  model_version           object 
dtypes: float64(2), int64(3), object(8)
memory usage: 4.6+ GB


# Save DF to local & log to Mlflow

Instead of saving it to random location in GCS, save artifact locally & then log it to mlflow job as a new artifact.

Make sure to append a timestamp in case we try different ANN approaches


In [38]:
manual_model_timestamp = datetime.utcnow().strftime('%Y-%m-%d_%H%M%S')
path_this_model = get_project_subfolder(
    f"data/models/ann/manual_v061_{manual_model_timestamp}"
)
Path.mkdir(path_this_model, parents=True, exist_ok=True)
path_this_model

PosixPath('/home/jupyter/subreddit_clustering_i18n/data/models/ann/manual_v061_2022-11-22_230904')

In [39]:
%%time

p_df_subfolder = path_this_model / f"ann_df-{df_nn_top['subreddit_id'].nunique()}-{manual_model_timestamp}"
subfolder_df = p_df_subfolder.name

save_pd_df_to_parquet_in_chunks(
    df_nn_top,
    p_df_subfolder,
    write_index=False
)

23:09:09 | INFO | "Converting pandas to dask..."
23:09:43 | INFO | "  27,406.5 MB <- Memory usage"
23:09:43 | INFO | "      50	<- target Dask partitions	  550.0 <- target MB partition size"


CPU times: user 1min 52s, sys: 13.4 s, total: 2min 5s
Wall time: 1min 31s


### Log to mlflow

In [40]:
%%time

d_mlflow_paths = dict()
info(f"Start logging parquet to mlflow...")
with mlflow.start_run(run_id=run_uuid) as run:
    mlflow.log_artifacts(str(p_df_subfolder), subfolder_df)
    # get path to JSON file so that we can create a table from it
    d_mlflow_paths['mlflow_artifact_df'] = mlflow.get_artifact_uri(
        artifact_path=f"{subfolder_df}"
    )
info(f"Logging artifact complete!")

23:10:36 | INFO | "Start logging parquet to mlflow..."
23:11:22 | INFO | "Logging artifact complete!"


CPU times: user 1.73 s, sys: 1.57 s, total: 3.3 s
Wall time: 45.3 s


In [41]:
d_mlflow_paths

{'mlflow_artifact_df': 'gs://i18n-subreddit-clustering/mlflow/mlruns/40/91ac7ca171024c779c0992f59470c81b/artifacts/ann_df-239774-2022-11-22_230904'}

# Save to JSON for BigQuery


Fixed (2022-11-22) to correct format:
- WANTED: the `similar_subreddits` field should be: 
    - a list of dictionaries
        - each dict is a subreddit

See example of format we want here:
https://github.snooguts.net/reddit/gazette-models/blob/cf324c18d974d0b01bb40c71c7f6425d7ff16576/similar_subreddit/embeddings/local_write.py#L32

```python
def write_similar_subreddit_file(
    date_today: str,
    model_name: str,
    model_version: str,
    filename_path_top_k: Path,
    topk_dict: Dict,
    subreddit2id: Dict,
) -> List:
    with open(filename_path_top_k, "w") as f:
        for sr, sim_sr_pairs in topk_dict.items():
            line_dict: Dict[str, Any] = dict()
            if sim_sr_pairs:  # make sure subreddit list is not empty
                line_dict["pt"] = date_today
                line_dict["model_name"] = model_name
                line_dict["model_version"] = model_version
                line_dict["subreddit_name"] = sr
                line_dict["subreddit_id"] = subreddit2id[sr]

                if sr != sim_sr_pairs[0][0]:
                    raise ValueError(
                        f"Inconsistent subreddit name {sim_sr_pairs[0][0]} with searched name {sr}"
                    )

                sim_srs = []
                for sim_sr, sim_score in sim_sr_pairs[1:]:
                    sim_sr_dict = {
                        "subreddit_name": sim_sr,
                        "subreddit_id": subreddit2id[sim_sr],
                        "score": sim_score.astype(float),
                    }
                    sim_srs.append(sim_sr_dict)

                line_dict["similar_subreddit"] = sim_srs

                line = json.dumps(line_dict)
                f.write(line + "\n")
```

In [42]:
%%time

p_local_json = path_this_model / f"ann_ndjson-{df_nn_top['subreddit_id'].nunique()}-{manual_model_timestamp}"

Path.mkdir(p_local_json, exist_ok=True, parents=True)
subfolder_json = p_local_json.name

f_local_json_name = f"ann_ndjson-{df_nn_top['subreddit_id'].nunique()}_subreddits.json"
f_local_json_full = p_local_json / f_local_json_name

# If we run this multiple times, make sure we don't append duplicated lines
try:
    info(f"Deleting existing file...")
    f_local_json_full.unlink()
except FileNotFoundError as e:
    info(f"NVM, file does not exist yet...\n {e}")

prefix_similar_sub = 'similar'

# These are the cols to nest for similar subreddits
cols_for_similar_sub_ = [
    'subreddit_id',
    'subreddit_name',
    'cosine_similarity',
    'distance_rank',
]

info(f"Start saving df as ndJSON...")
with open(f_local_json_full, 'w') as f:
    for seed_sub_id_, df_seed_ in tqdm(df_nn_top.groupby(['subreddit_id']), mininterval=2):
        d_seed = {
            **d_topk_meta,
            **{
                'subreddit_id': seed_sub_id_,
                'subreddit_name': str(df_seed_['subreddit_name'].values[0]),
                
                # 2022-11-22: fixed the logic for similar_subreddit 
                #   each subreddit should be its own dict
                'similar_subreddit': df_seed_[cols_for_similar_sub_].to_dict(orient='records')
            }
        }
        f.write(json.dumps(d_seed) + "\n")
info(f"Done saving as ndJSON")

23:11:32 | INFO | "Deleting existing file..."
23:11:32 | INFO | "NVM, file does not exist yet...
 [Errno 2] No such file or directory: '/home/jupyter/subreddit_clustering_i18n/data/models/ann/manual_v061_2022-11-22_230904/ann_ndjson-239774-2022-11-22_230904/ann_ndjson-239774_subreddits.json'"
23:11:32 | INFO | "Start saving df as ndJSON..."
100%|██████████| 239774/239774 [09:00<00:00, 443.48it/s]
23:20:39 | INFO | "Done saving as ndJSON"


CPU times: user 8min 57s, sys: 15.2 s, total: 9min 12s
Wall time: 9min 16s


In [43]:
%%time
# log to mlflow

with mlflow.start_run(run_id=run_uuid) as run:
    mlflow.log_artifacts(str(p_local_json), subfolder_json)
    # get path to JSON file so that we can create a table from it
    d_mlflow_paths['mlflow_artifact_json'] = mlflow.get_artifact_uri(
        artifact_path=f"{subfolder_json}/{f_local_json_name}"
    )
info(f"Logging artifact complete!")

23:21:25 | INFO | "Logging artifact complete!"


CPU times: user 3.4 s, sys: 3.52 s, total: 6.91 s
Wall time: 46.6 s


In [44]:
d_mlflow_paths['mlflow_artifact_json']

'gs://i18n-subreddit-clustering/mlflow/mlruns/40/91ac7ca171024c779c0992f59470c81b/artifacts/ann_ndjson-239774-2022-11-22_230904/ann_ndjson-239774_subreddits.json'

# Upload JSON to BQ

Example `schema` here:
- https://github.snooguts.net/reddit/gazette-models/blob/cf324c18d974d0b01bb40c71c7f6425d7ff16576/similar_subreddit/embeddings/bq_write.py

using `bq load` won't work with a JSON schema in BQ.

Instead, let's try using the python client. NOTE: we'll need to get the right authentication in the VM that has the correct read & write access, e.g.,:
```bash
# login
gcloud auth application-default login

# logout
gcloud auth application-default revoke
```

---
example format for path:
```
d_mlflow_paths['mlflow_artifact_json'] = (
    'gs://i18n-subreddit-clustering/mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/ann_ndjson-2022-09-10_003611/ann_ndjson-250573_subreddits.json'
)
```

In [None]:
BREAK

In [45]:
info(f"Creating table from file:\n{d_mlflow_paths['mlflow_artifact_json']}")

load_data_to_bq_table(
    uri=d_mlflow_paths['mlflow_artifact_json'],
    bq_project='reddit-employee-datasets',
    bq_dataset='david_bermejo',
    bq_table_name='cau_similar_subreddits_by_text',
    schema=similar_sub_schema(),
    partition_column='pt',
    table_description=(
        "Table with most similar subreddits by the text (posts & comments) in each sub."
        "  It works across 16 languages. So finance (English), Finanzen(German), & financia(Spanish) will be clustered together."
        "  See wiki: https://reddit.atlassian.net/wiki/spaces/DataScience/pages/2404220935/CA+Embeddings+Topic+Model"
    ),
    update_table_description=True,
)

23:21:25 | INFO | "Creating table from file:
gs://i18n-subreddit-clustering/mlflow/mlruns/40/91ac7ca171024c779c0992f59470c81b/artifacts/ann_ndjson-239774-2022-11-22_230904/ann_ndjson-239774_subreddits.json"
23:21:27 | INFO | "Loading data to table:
  reddit-employee-datasets.david_bermejo.cau_similar_subreddits_by_text"
23:21:27 | INFO | "Table reddit-employee-datasets.david_bermejo.cau_similar_subreddits_by_text already exist"
23:21:28 | INFO | "  0 rows in table BEFORE adding data"
23:22:58 | INFO | "Updating subreddit description from:
  Table with most similar subreddits by the text (posts & comments) in each sub.  It works across 16 languages. So finance (English), Finanzen(German), & financia(Spanish) will be clustered together.  See wiki: https://reddit.atlassian.net/wiki/spaces/DataScience/pages/2404220935/CA+Embeddings+Topic+Model
to:
  Table with most similar subreddits by the text (posts & comments) in each sub.  It works across 16 languages. So finance (English), Finanzen(G

# Appendix

## Check more example outputs

In [46]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'ich_iel']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
12328400,t5_37k29,ich_iel,61642,17868,0.47322,1,t5_2qmr6,aeiou,0.888031,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328401,t5_37k29,ich_iel,61642,64451,0.507657,2,t5_39bxv,ik_ihe,0.871142,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328402,t5_37k29,ich_iel,61642,74678,0.514094,3,t5_3hn0l,deutschememes,0.867854,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328403,t5_37k29,ich_iel,61642,78816,0.525753,4,t5_3kr89k,gekte,0.861792,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328404,t5_37k29,ich_iel,61642,237378,0.527394,5,t5_w2zxy,okoidawappler,0.860928,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328405,t5_37k29,ich_iel,61642,1961,0.562436,6,t5_17d5ey,ichbin40undlustig,0.841833,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328406,t5_37k29,ich_iel,61642,233782,0.563186,7,t5_ofkj1,okbrudimongo,0.841411,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328407,t5_37k29,ich_iel,61642,45396,0.565225,8,t5_2xbtv,buenzli,0.84026,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328408,t5_37k29,ich_iel,61642,52936,0.566677,9,t5_318w4,cirkeltrek,0.839439,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
12328409,t5_37k29,ich_iel,61642,65302,0.567824,10,t5_39uv3,kopiernudeln,0.838788,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [47]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'ireland']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
3294800,t5_2qhb9,ireland,16474,17660,0.354353,1,t5_2qlve,northernireland,0.937217,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294801,t5_2qhb9,ireland,16474,230741,0.417049,2,t5_i25jp,casualireland,0.913035,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294802,t5_2qhb9,ireland,16474,34057,0.462139,3,t5_2tphj,irishproblems,0.893214,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294803,t5_2qhb9,ireland,16474,19816,0.524114,4,t5_2r1hz,dublin,0.862652,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294804,t5_2qhb9,ireland,16474,16614,0.54789,5,t5_2qhma,newzealand,0.849908,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294805,t5_2qhb9,ireland,16474,27124,0.551471,6,t5_2sgbm,irishpolitics,0.84794,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294806,t5_2qhb9,ireland,16474,21068,0.571349,7,t5_2rc51,belfast,0.83678,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294807,t5_2qhb9,ireland,16474,76630,0.573016,8,t5_3j2jr,casualuk,0.835826,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294808,t5_2qhb9,ireland,16474,21187,0.586199,9,t5_2rd5j,auckland,0.828185,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3294809,t5_2qhb9,ireland,16474,17397,0.59645,10,t5_2qkli,scotland,0.822124,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [48]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'vegetarischde']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
18547000,t5_4c06em,vegetarischde,92735,62003,0.443374,1,t5_37ruc,vegande,0.90171,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547001,t5_4c06em,vegetarischde,92735,16837,0.547077,2,t5_2qhzr,vegetarianism,0.850353,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547002,t5_4c06em,vegetarischde,92735,4808,0.576836,3,t5_25v3wn,kreisvegs,0.83363,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547003,t5_4c06em,vegetarischde,92735,16671,0.579313,4,t5_2qhpm,vegan,0.832198,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547004,t5_4c06em,vegetarischde,92735,17750,0.579923,5,t5_2qm7x,vegetarian,0.831845,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547005,t5_4c06em,vegetarischde,92735,35382,0.612797,6,t5_2u0f5t,vegfr,0.81224,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547006,t5_4c06em,vegetarischde,92735,57372,0.627332,7,t5_33xgk,veganuk,0.803227,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547007,t5_4c06em,vegetarischde,92735,38099,0.628234,8,t5_2uquu,askvegans,0.802661,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547008,t5_4c06em,vegetarischde,92735,107,0.633577,9,t5_109235,exvegans,0.79929,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
18547009,t5_4c06em,vegetarischde,92735,40161,0.635087,10,t5_2ven0,antivegan,0.798332,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [49]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'antivegan']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
8032200,t5_2ven0,antivegan,40161,27141,0.283415,1,t5_2sgfh,vegancirclejerk,0.959838,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032201,t5_2ven0,antivegan,40161,16671,0.303616,2,t5_2qhpm,vegan,0.953909,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032202,t5_2ven0,antivegan,40161,232177,0.322566,3,t5_kycqf,veganforcirclejerkers,0.947976,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032203,t5_2ven0,antivegan,40161,52285,0.332124,4,t5_30wk6,veganmemes,0.944847,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032204,t5_2ven0,antivegan,40161,136219,0.33339,5,t5_675dds,vegancirclejerkchat,0.944425,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032205,t5_2ven0,antivegan,40161,107,0.365696,6,t5_109235,exvegans,0.933133,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032206,t5_2ven0,antivegan,40161,38099,0.370972,7,t5_2uquu,askvegans,0.93119,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032207,t5_2ven0,antivegan,40161,25819,0.408916,8,t5_2sa7z,debateavegan,0.916394,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032208,t5_2ven0,antivegan,40161,16837,0.455829,9,t5_2qhzr,vegetarianism,0.89611,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8032209,t5_2ven0,antivegan,40161,183332,0.462295,10,t5_6ylhsr,noreason2bvegan,0.893142,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [50]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'mexico']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
3352400,t5_2qhv7,mexico,16762,106045,0.43765,1,t5_4ywzju,askmexico,0.904231,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352401,t5_2qhv7,mexico,16762,17695,0.520065,2,t5_2qm06,monterrey,0.864766,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352402,t5_2qhv7,mexico,16762,26092,0.541566,3,t5_2sbh1,mexicali,0.853353,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352403,t5_2qhv7,mexico,16762,37938,0.546211,4,t5_2up3k,ticos,0.850827,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352404,t5_2qhv7,mexico,16762,79115,0.55256,5,t5_3la4d,mujico,0.847338,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352405,t5_2qhv7,mexico,16762,34838,0.555084,6,t5_2tw1p,mexicocity,0.845941,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352406,t5_2qhv7,mexico,16762,37443,0.576658,7,t5_2ujoy,memexico,0.833733,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352407,t5_2qhv7,mexico,16762,13792,0.579503,8,t5_2lxxle,mexicow,0.832088,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352408,t5_2qhv7,mexico,16762,25891,0.588112,9,t5_2samk,guatemala,0.827062,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3352409,t5_2qhv7,mexico,16762,101331,0.609711,10,t5_4sbz8m,cdmx,0.814126,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [51]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'memesenespanol']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
0,t5_1009a3,memesenespanol,0,429,0.536542,1,t5_10wycq,memesesp,0.856062,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
1,t5_1009a3,memesenespanol,0,230340,0.564098,2,t5_hc3xv,memesespanol,0.840896,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
2,t5_1009a3,memesenespanol,0,83885,0.568305,3,t5_3qq2qy,beelcitosmemes,0.838515,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3,t5_1009a3,memesenespanol,0,231468,0.584952,4,t5_jhy39,yo_ctm,0.828915,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
4,t5_1009a3,memesenespanol,0,232337,0.602862,5,t5_lana1,memesbr,0.818279,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
5,t5_1009a3,memesenespanol,0,8441,0.603264,6,t5_2buupx,memeitaliani,0.818036,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
6,t5_1009a3,memesenespanol,0,85274,0.607121,7,t5_3wam26,latesitoo,0.815702,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
7,t5_1009a3,memesenespanol,0,234584,0.615698,8,t5_q1xei,memefrancais,0.810458,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
8,t5_1009a3,memesenespanol,0,106292,0.617525,9,t5_4z9yto,memesbuenaonda,0.809332,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
9,t5_1009a3,memesenespanol,0,37443,0.618534,10,t5_2ujoy,memexico,0.808708,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [52]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'de']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
603000,t5_22i0,de,3015,93929,0.389254,1,t5_4egnbw,dezwo,0.924241,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603001,t5_22i0,de,3015,68278,0.477014,2,t5_3caax,600euro,0.886229,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603002,t5_22i0,de,3015,77935,0.481225,3,t5_3jxvk,tja,0.884211,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603003,t5_22i0,de,3015,231106,0.493754,4,t5_irnzx,dachschaden,0.878103,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603004,t5_22i0,de,3015,18131,0.509679,5,t5_2qo9i,austria,0.870114,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603005,t5_22i0,de,3015,64009,0.517639,6,t5_392ha,asozialesnetzwerk,0.866025,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603006,t5_22i0,de,3015,40631,0.571133,7,t5_2vk0m,nachrichten,0.836903,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603007,t5_22i0,de,3015,96508,0.598626,8,t5_4juf8o,politpro,0.820823,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603008,t5_22i0,de,3015,233371,0.607671,9,t5_nls07,belgium2,0.815368,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
603009,t5_22i0,de,3015,231620,0.608017,10,t5_jsyzh,poldersocialisme,0.815158,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [53]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'askfrance']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
10055400,t5_2zkfk,askfrance,50277,16574,0.536449,1,t5_2qhjz,france,0.856111,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055401,t5_2zkfk,askfrance,50277,17418,0.630384,2,t5_2qkoi,paris,0.801308,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055402,t5_2zkfk,askfrance,50277,32495,0.630685,3,t5_2tdpb,suisse,0.801118,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055403,t5_2zkfk,askfrance,50277,75629,0.634011,4,t5_3iawa,pasdequestionidiote,0.799015,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055404,t5_2zkfk,askfrance,50277,20364,0.650897,5,t5_2r6ca,fragreddit,0.788167,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055405,t5_2zkfk,askfrance,50277,59931,0.673849,6,t5_3622g,askargentina,0.772964,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055406,t5_2zkfk,askfrance,50277,29070,0.67707,7,t5_2sq2i,nantes,0.770788,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055407,t5_2zkfk,askfrance,50277,237154,0.681338,8,t5_vnwft,perguntereddit,0.767889,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055408,t5_2zkfk,askfrance,50277,40732,0.684963,9,t5_2vl55,wallonia,0.765413,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
10055409,t5_2zkfk,askfrance,50277,91674,0.706629,10,t5_4amb3y,casualit,0.750338,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


In [54]:
(
    df_nn_top[df_nn_top['subreddit_name'] == 'cfb']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,pt,mlflow_run_id,model_name,model_version
3551600,t5_2qm9d,cfb,17758,30337,0.341123,1,t5_2sy54,fcs,0.941818,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551601,t5_2qm9d,cfb,17758,46902,0.372416,2,t5_2xys7,fsusports,0.930653,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551602,t5_2qm9d,cfb,17758,24818,0.399144,3,t5_2s5kg,lsufootball,0.920342,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551603,t5_2qm9d,cfb,17758,38181,0.410466,4,t5_2urol,cfbmemes,0.915759,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551604,t5_2qm9d,cfb,17758,37244,0.421497,5,t5_2uhr8,notredamefootball,0.91117,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551605,t5_2qm9d,cfb,17758,52621,0.423696,6,t5_31327,cfbball,0.910241,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551606,t5_2qm9d,cfb,17758,42943,0.432273,7,t5_2wcz4,theonlycolors,0.90657,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551607,t5_2qm9d,cfb,17758,20292,0.435893,8,t5_2r5u7,ohiostatefootball,0.904999,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551608,t5_2qm9d,cfb,17758,63942,0.444592,9,t5_3918y,cfbvegas,0.901169,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1
3551609,t5_2qm9d,cfb,17758,20265,0.449284,10,t5_2r5kj,sooners,0.899072,2022-11-22,91ac7ca171024c779c0992f59470c81b,cau-text-mUSE,v0.6.1


# Test `search_k`
`search_k=-1` will search all trees and get the most accurate results but it will take longer to compute.

Recommendation: 
<br>use k=-1  or 


Even with small changes we can see in the examples below that there is a time difference and sometimes even in the top10 results we will miss a neighbor when we set k<=3 -- i.e., k=3 -> only search 3 trees).

In [29]:
%%time

n_test_i = 212
nn_index.get_top_n_by_item(n_test_i, k=9, search_k=-1, include_distances=True)

CPU times: user 277 ms, sys: 32.1 ms, total: 310 ms
Wall time: 308 ms


Unnamed: 0,subreddit_id_a,subreddit_name_a,distance_rank,subreddit_id_b,subreddit_name_b,distance
0,t5_10dzqu,godawfulmovies,0,t5_10dzqu,godawfulmovies,0.0
1,t5_10dzqu,godawfulmovies,1,t5_me7ba,podcastsharing,0.69861
2,t5_10dzqu,godawfulmovies,2,t5_2u29p,filmjunk,0.705137
3,t5_10dzqu,godawfulmovies,3,t5_2c7q0h,podcastpromoting,0.716215
4,t5_10dzqu,godawfulmovies,4,t5_n99oj,findthepathpodcast,0.716643
5,t5_10dzqu,godawfulmovies,5,t5_t6jv7,sinisterhood,0.716963
6,t5_10dzqu,godawfulmovies,6,t5_2zzeu,highersidechats,0.720665
7,t5_10dzqu,godawfulmovies,7,t5_np3is,letsgo2courtpodcast,0.721888
8,t5_10dzqu,godawfulmovies,8,t5_2t8p3,wehatemovies,0.723386


In [30]:
%%time
nn_index.get_top_n_by_item(n_test_i, k=9, search_k=1, include_distances=True)

CPU times: user 244 ms, sys: 0 ns, total: 244 ms
Wall time: 243 ms


Unnamed: 0,subreddit_id_a,subreddit_name_a,distance_rank,subreddit_id_b,subreddit_name_b,distance
0,t5_10dzqu,godawfulmovies,0,t5_10dzqu,godawfulmovies,0.0
1,t5_10dzqu,godawfulmovies,1,t5_2u29p,filmjunk,0.705137
2,t5_10dzqu,godawfulmovies,2,t5_2c7q0h,podcastpromoting,0.716215
3,t5_10dzqu,godawfulmovies,3,t5_n99oj,findthepathpodcast,0.716643
4,t5_10dzqu,godawfulmovies,4,t5_2zzeu,highersidechats,0.720665
5,t5_10dzqu,godawfulmovies,5,t5_np3is,letsgo2courtpodcast,0.721888
6,t5_10dzqu,godawfulmovies,6,t5_35xxi9,headgumpodcast,0.738195
7,t5_10dzqu,godawfulmovies,7,t5_2vo38,harmontown,0.741415
8,t5_10dzqu,godawfulmovies,8,t5_26gz8w,theteamhouse,0.746122


In [31]:
top_k_test_ = 20
cols_drop_ = ['subreddit_id_a', 'subreddit_id_b', 'distance']
cols_append_ = ['subreddit_name_b',]
df_compare_sk = nn_index.get_top_n_by_item(
    n_test_i, k=top_k_test_, search_k=-1, include_distances=True
).drop(cols_drop_, axis=1)

for k_ in [int(0.998 * n_trees), int(0.85 * n_trees), 
           int(0.5 * n_trees), min([200, int(0.1 * n_trees)]),
           1]:
    df_compare_sk = pd.concat(
        [
            df_compare_sk,
            nn_index.get_top_n_by_item(
                n_test_i, k=top_k_test_, search_k=k_, include_distances=True
            )[cols_append_].rename(columns={c: f"{c}_{k_}" for c in df_compare_sk.columns})
        ],
        axis=1,
    )
df_compare_sk

Unnamed: 0,subreddit_name_a,distance_rank,subreddit_name_b,subreddit_name_b_199,subreddit_name_b_170,subreddit_name_b_100,subreddit_name_b_20,subreddit_name_b_1
0,godawfulmovies,0,godawfulmovies,godawfulmovies,godawfulmovies,godawfulmovies,godawfulmovies,godawfulmovies
1,godawfulmovies,1,podcastsharing,filmjunk,filmjunk,filmjunk,filmjunk,filmjunk
2,godawfulmovies,2,filmjunk,podcastpromoting,podcastpromoting,podcastpromoting,podcastpromoting,podcastpromoting
3,godawfulmovies,3,podcastpromoting,findthepathpodcast,findthepathpodcast,findthepathpodcast,findthepathpodcast,findthepathpodcast
4,godawfulmovies,4,findthepathpodcast,highersidechats,highersidechats,highersidechats,highersidechats,highersidechats
5,godawfulmovies,5,sinisterhood,letsgo2courtpodcast,letsgo2courtpodcast,letsgo2courtpodcast,letsgo2courtpodcast,letsgo2courtpodcast
6,godawfulmovies,6,highersidechats,headgumpodcast,headgumpodcast,headgumpodcast,headgumpodcast,headgumpodcast
7,godawfulmovies,7,letsgo2courtpodcast,harmontown,harmontown,harmontown,harmontown,harmontown
8,godawfulmovies,8,wehatemovies,theteamhouse,theteamhouse,theteamhouse,theteamhouse,theteamhouse
9,godawfulmovies,9,weeklyplanetpodcast,headgum,headgum,headgum,headgum,headgum
