# Purpose

### 2022-09-09
Test running queries in parallel with `dask`. Now that we'll run ANN for 250k+ subreddits, running in a single thread could take a loooong time.


### 2022-08-01
Calculating precise nearest neighbors has become too expensive as we go over 40k subreddits. So instead let's calculate approx nearest neighbors (ANN). 

In this notebook we use [ANNOY](https://github.com/spotify/annoy).  Main reason for using annoy over FAISS is that annoy has official wheels in pypi, but FAISS only officially supports installation from conda. For now we don't want to depend on third-party wheels for FAISS b/c that can be messy to install & replicate in a VM. Maybe when we switch to kubeflow we can try FAISS.


# Notebook setup

In [1]:
# auto-reload edited project modules (e.g., `subclu`) without restarting the kernel
%load_ext autoreload
%autoreload 2

In [233]:
from datetime import datetime
import gc
import os
import json
import logging
from logging import info
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import seaborn as sns

import dask
from dask import dataframe as dd
from tqdm import tqdm

import mlflow
import hydra
import annoy

import subclu
from subclu.models.aggregate_embeddings import (
    AggregateEmbeddings, AggregateEmbeddingsConfig,
    load_config_agg_jupyter, get_dask_df_shape,
)
from subclu.models import aggregate_embeddings_pd

from subclu.utils import set_working_directory, get_project_subfolder
from subclu.utils.eda import (
    setup_logging, counts_describe, value_counts_and_pcts,
    notebook_display_config, print_lib_versions,
    style_df_numeric
)
from subclu.utils.mlflow_logger import MlflowLogger, save_pd_df_to_parquet_in_chunks
from subclu.eda.aggregates import (
    compare_raw_v_weighted_language
)
from subclu.utils.data_irl_style import (
    get_colormap, theme_dirl
)
from subclu.utils.big_query_utils import load_data_to_bq_table
from subclu.models.bq_embedding_schemas import embeddings_schema, similar_sub_schema


# record library versions so outputs can be traced back to the environment that produced them
print_lib_versions([annoy, dask, hydra, mlflow, np, pd, plotly, sns, subclu])

python		v 3.7.10
===
annoy		v: 1.17.0
dask		v: 2021.06.0
hydra		v: 1.1.0
mlflow		v: 1.16.0
numpy		v: 1.19.5
pandas		v: 1.2.4
plotly		v: 4.14.3
seaborn		v: 0.11.1
subclu		v: 0.6.0


In [3]:
# plotting
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.dates as mdates
plt.style.use('default')

# subclu.utils.eda helpers: logging setup & notebook display options
setup_logging()
notebook_display_config()

# Set sqlite database as MLflow URI

In [4]:
# use new class to initialize mlflow
mlf = MlflowLogger(tracking_uri='sqlite')
mlflow.get_tracking_uri()

'sqlite:////home/jupyter/subreddit_clustering_i18n/mlflow_sync/djb-100-2021-04-28-djb-eda-german-subs/mlruns.db'

## Get list of experiments with new function

In [5]:
# most recent mlflow experiments (id -> name), used to pick the experiment id below
mlf.list_experiment_meta(output_format='pandas').tail(9)

Unnamed: 0,experiment_id,name,artifact_location,lifecycle_stage
30,30,v0.5.0_mUSE_clustering_test,gs://i18n-subreddit-clustering/mlflow/mlruns/30,active
31,31,v0.5.0_mUSE_clustering,gs://i18n-subreddit-clustering/mlflow/mlruns/31,active
32,32,v0.5.0_nearest_neighbors_test,gs://i18n-subreddit-clustering/mlflow/mlruns/32,active
33,33,v0.5.0_nearest_neighbors,gs://i18n-subreddit-clustering/mlflow/mlruns/33,active
34,34,v0.6.0_mUSE_aggregates_test,gs://i18n-subreddit-clustering/mlflow/mlruns/34,active
35,35,v0.6.0_mUSE_aggregates,gs://i18n-subreddit-clustering/mlflow/mlruns/35,active
36,36,v0.6.0_mUSE_clustering_test,gs://i18n-subreddit-clustering/mlflow/mlruns/36,active
37,37,v0.6.0_mUSE_clustering,gs://i18n-subreddit-clustering/mlflow/mlruns/37,active
38,38,v0.6.0_nearest_neighbors,gs://i18n-subreddit-clustering/mlflow/mlruns/38,active


## Get runs from embeddings aggregation jobs

Want to make sure we can load these artifacts for other jobs

In [6]:
%%time

# experiment 35 == `v0.6.0_mUSE_aggregates` (see experiment list above)
df_mlf_runs =  mlf.search_all_runs(experiment_ids=[35])
df_mlf_runs.shape

CPU times: user 56.5 ms, sys: 8.27 ms, total: 64.8 ms
Wall time: 64 ms


(4, 43)

In [7]:
# compare status, params & metrics of the aggregation runs to pick one
df_mlf_runs.head()

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.df_subs_agg_c1_uw-cols,metrics.time_fxn-full_aggregation_fxn_minutes,metrics.time_fxn-data_loading_time,metrics.memory_free,metrics.time_fxn-df_posts_agg_c1_no_delay,metrics.time_fxn-df_subs_agg_c1_uw,metrics.memory_total,metrics.df_subs_agg_c1-cols,metrics.df_v_subs-rows,metrics.df_v_subs-cols,metrics.df_subs_agg_c1_uw-rows,metrics.memory_used,metrics.df_v_post_comments-cols,metrics.cpu_count,metrics.df_posts_agg_c1-cols,metrics.df_v_post_comments-rows,metrics.df_posts_agg_c1-rows,metrics.df_subs_agg_c1-rows,metrics.memory_used_percent,metrics.time_fxn-df_subs_agg_c1,params.mlflow_experiment,params.weight_subreddit_meta,params.mlflow_tracking_uri,params.embeddings_post_and_comments_path,params.cpu_count,params.host_name,params.weight_post_and_comments,params.bucket_output,params.embeddings_subreddit_path,params.memory_total,params.agg_style,params.embeddings_bucket,tags.mlflow.source.git.commit,tags.mlflow.user,tags.mlflow.source.name,tags.mlflow.source.type,tags.host_name
0,badc44b0e5ac467da14f710da0b410c6,35,FINISHED,gs://i18n-subreddit-clustering/mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts,2022-08-16 08:41:53.006000+00:00,2022-08-31 04:01:44.575000+00:00,515.0,820.674805,3.698822,1197814.0,544.288655,15.926672,1444961.0,515.0,771760.0,514.0,771760.0,774858.0,515.0,96.0,515.0,51906348.0,51906348.0,771760.0,0.536248,15.926672,v0.6.0_mUSE_aggregates,0.15,sqlite,i18n_topic_model_batch/runs/20220811/post_and_comment_text_combined/text_all/embedding/2022-08-11_084218,96,djb-100-2021-04-28-djb-eda-german-subs,0.85,i18n-subreddit-clustering,i18n_topic_model_batch/runs/20220811/subreddits/text/embedding/2022-08-11_082859,1444961,dask_delayed,i18n-subreddit-clustering,df6a30d80cfe36c1badb1531c7cbae7dd1046f21,jupyter,/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py,LOCAL,djb-100-2021-04-28-djb-eda-german-subs
1,ca79765b72c5428395b02926612d85fd,35,FINISHED,gs://i18n-subreddit-clustering/mlflow/mlruns/35/ca79765b72c5428395b02926612d85fd/artifacts,2022-08-16 08:41:31.162000+00:00,2022-08-31 03:13:27.187000+00:00,,,3.719202,1175788.0,,,1444961.0,,771760.0,514.0,,442282.0,515.0,96.0,,51906348.0,,,0.306086,,v0.6.0_mUSE_aggregates,0.15,sqlite,i18n_topic_model_batch/runs/20220811/post_and_comment_text_combined/text_all/embedding/2022-08-11_084218,96,djb-100-2021-04-28-djb-eda-german-subs,0.85,i18n-subreddit-clustering,i18n_topic_model_batch/runs/20220811/subreddits/text/embedding/2022-08-11_082859,1444961,serial,i18n-subreddit-clustering,df6a30d80cfe36c1badb1531c7cbae7dd1046f21,jupyter,/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py,LOCAL,djb-100-2021-04-28-djb-eda-german-subs
2,7552abcd785d4e229c6272aebf1beaf3,35,FAILED,gs://i18n-subreddit-clustering/mlflow/mlruns/35/7552abcd785d4e229c6272aebf1beaf3/artifacts,2022-08-16 08:35:25.596000+00:00,2022-08-16 08:40:12.627000+00:00,,,0.734633,1281046.0,,,1444961.0,,771760.0,514.0,,59844.0,515.0,96.0,,11644466.0,,,0.041416,,v0.6.0_mUSE_aggregates,0.15,sqlite,i18n_topic_model_batch/runs/20220811/post_and_comment_text_combined/text_all/embedding/2022-08-11_084218,96,djb-100-2021-04-28-djb-eda-german-subs,0.85,i18n-subreddit-clustering,i18n_topic_model_batch/runs/20220811/subreddits/text/embedding/2022-08-11_082859,1444961,dask_delayed,i18n-subreddit-clustering,df6a30d80cfe36c1badb1531c7cbae7dd1046f21,jupyter,/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py,LOCAL,djb-100-2021-04-28-djb-eda-german-subs
3,15b78241482e492cae90644e9733b50c,35,FAILED,gs://i18n-subreddit-clustering/mlflow/mlruns/35/15b78241482e492cae90644e9733b50c/artifacts,2022-08-16 07:37:37.787000+00:00,2022-08-16 07:43:09.168000+00:00,,,,1430384.0,,,1444961.0,,771760.0,514.0,,2822.0,,96.0,,,,,0.001953,,,0.15,,i18n_topic_model_batch/runs/20220811/post_and_comment_text_combined/text_all/embedding/2022-08-11_084218,96,djb-100-2021-04-28-djb-eda-german-subs,0.85,,i18n_topic_model_batch/runs/20220811/subreddits/text/embedding/2022-08-11_082859,1444961,,i18n-subreddit-clustering,df6a30d80cfe36c1badb1531c7cbae7dd1046f21,jupyter,/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py,LOCAL,djb-100-2021-04-28-djb-eda-german-subs


In [8]:
# FINISHED aggregation run with agg_style=dask_delayed (first row of the runs table above)
run_uuid = 'badc44b0e5ac467da14f710da0b410c6'

# Check run artifacts

In [9]:
# folders logged at the top level of the selected aggregation run
l_artifacts_top_level = mlf.list_run_artifacts(
    run_id=run_uuid,
    only_top_level=True,
    verbose=True,
)
len(l_artifacts_top_level)

08:11:54 | INFO | "   219 <- Artifacts to check count"
08:11:54 | INFO | "   219 <- Artifacts clean count"
08:11:54 | INFO | "     5 <- Artifacts & folders at TOP LEVEL clean count"


5

In [10]:
# every artifact file in the run (used below to inspect the parquet parts per folder)
l_artifacts_all = mlf.list_run_artifacts(
    run_id=run_uuid,
    only_top_level=False,
    verbose=False,
)
len(l_artifacts_all)

08:12:01 | INFO | "   219 <- Artifacts clean count"
08:12:01 | INFO | "     5 <- Artifacts & folders at TOP LEVEL clean count"


219

In [11]:
# names of the top-level artifact folders
l_artifacts_top_level

['df_posts_agg_c1',
 'df_subs_agg_c1',
 'df_subs_agg_c1_ndjson',
 'df_subs_agg_c1_unweighted',
 'df_subs_agg_c1_unweighted_ndjson']

In [12]:
# files for the subreddit-level aggregates
#  NOTE: substring match, so it also picks up the `_ndjson` & `_unweighted` folders
l_sub_c = [i for i in l_artifacts_all if 'df_subs_agg_c1' in i]
print(len(l_sub_c))
l_sub_c[:6]

14


['mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_subs_agg_c1/_common_metadata',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_subs_agg_c1/_metadata',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_subs_agg_c1/part.0.parquet',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_subs_agg_c1/part.1.parquet',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_subs_agg_c1/part.2.parquet',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_subs_agg_c1/part.3.parquet']

In [13]:
# files for the post-level aggregates (many parquet parts -> large download)
l_post_c = [i for i in l_artifacts_all if 'df_posts_agg_c1' in i]
print(len(l_post_c))
l_post_c[:6]

205


['mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_posts_agg_c1/_common_metadata',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_posts_agg_c1/_metadata',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_posts_agg_c1/part.0.parquet',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_posts_agg_c1/part.1.parquet',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_posts_agg_c1/part.10.parquet',
 'mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_posts_agg_c1/part.100.parquet']

# Load aggregated embeddings

Use `gsutil` to download the post embeddings, because downloading them sequentially can take a LONG time. `gsutil` downloads files in parallel, which is much faster, and reports download speeds above 500 MB/s:
```bash
ents_sub_desc/part.67.parquet...
/ [2/197 files][ 61.7 GiB/ 75.4 GiB]  81% Done 632.0 MiB/s ETA 00:00:22
```

In [14]:
%%time

# read the subreddit-level aggregated embeddings from the run's artifacts
#  (parquet parts are downloaded to a local cache first; see log below)
df_agg_sub_c = mlf.read_run_artifact(
    run_id=run_uuid,
    artifact_folder='df_subs_agg_c1',
    read_function='pd_parquet',
    verbose=False,
)
print(df_agg_sub_c.shape)

08:12:07 | INFO | "Local folder to download artifact(s):
  /home/jupyter/subreddit_clustering_i18n/data/local_cache/mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/df_subs_agg_c1"
100%|########################################| 14/14 [00:00<00:00, 51554.22it/s]
08:12:07 | INFO | "  Parquet files found:     4"
08:12:07 | INFO | "  Parquet files to use:     4"


(771760, 515)
CPU times: user 9.88 s, sys: 3.61 s, total: 13.5 s
Wall time: 7.63 s


In [15]:
# dtypes & memory footprint of the aggregated embeddings
df_agg_sub_c.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 771760 entries, 0 to 771759
Columns: 515 entries, subreddit_id to embeddings_511
dtypes: float32(512), int64(1), object(2)
memory usage: 1.5+ GB


In [16]:
# first rows: ID/name/post-count columns + the first embedding dimensions
df_agg_sub_c.iloc[:5, :25]

Unnamed: 0,subreddit_id,subreddit_name,posts_for_embeddings_count,embeddings_0,embeddings_1,embeddings_2,embeddings_3,embeddings_4,embeddings_5,embeddings_6,embeddings_7,embeddings_8,embeddings_9,embeddings_10,embeddings_11,embeddings_12,embeddings_13,embeddings_14,embeddings_15,embeddings_16,embeddings_17,embeddings_18,embeddings_19,embeddings_20,embeddings_21
0,t5_1001tl,jewel_xo,1,-0.013827,-0.022239,0.049441,0.04947,-0.003573,0.0403,-0.017904,0.008067,-0.037719,-0.002597,0.008067,0.063312,0.014693,-0.042116,0.039357,-0.001491,0.068194,-0.0228,0.044375,0.017662,0.055988,-0.039419
1,t5_10029e,milkyhentai,1,-0.023227,-0.002677,0.031942,-0.010885,-0.02366,0.031216,0.044461,-0.024292,0.015413,0.047858,0.066368,0.076699,-0.040204,-0.004243,-0.048563,0.01012,0.024157,-0.020109,0.05592,0.015352,0.033357,0.010606
2,t5_1006k8,badwouldyourather,1,-0.032487,0.024979,-0.021948,0.039006,0.053261,0.037567,0.036113,0.011514,0.012002,0.020055,0.052276,0.02545,0.050676,-8e-06,-0.005012,0.000559,0.058759,-0.002293,0.010347,0.009482,0.024522,-0.02372
3,t5_100806,jojojosiah,2,0.004711,0.005103,0.037912,0.023591,0.02459,0.029586,-0.012185,-0.031729,-0.016308,0.063303,-0.015289,0.008682,0.008985,0.006243,-0.000484,0.015052,0.003276,0.002508,0.009955,-0.004335,0.001302,-0.016699
4,t5_1009a3,memesenespanol,380,0.003731,-0.013876,-0.003987,0.002683,-0.010202,0.038552,0.012759,0.016535,-0.056693,0.001183,0.009329,-0.005247,0.00963,-0.001513,0.000606,-0.004258,-0.005582,-0.002777,-0.001939,0.002463,-0.003895,-0.006126


In [17]:
# last rows: these subs have posts_for_embeddings_count == 0 but still have embeddings
df_agg_sub_c.iloc[-10:, :25]

Unnamed: 0,subreddit_id,subreddit_name,posts_for_embeddings_count,embeddings_0,embeddings_1,embeddings_2,embeddings_3,embeddings_4,embeddings_5,embeddings_6,embeddings_7,embeddings_8,embeddings_9,embeddings_10,embeddings_11,embeddings_12,embeddings_13,embeddings_14,embeddings_15,embeddings_16,embeddings_17,embeddings_18,embeddings_19,embeddings_20,embeddings_21
771750,t5_6qo047,scienceno488,0,0.049827,0.008324,0.033611,0.032743,0.025569,0.08428,-0.016352,-0.054408,-0.030566,-0.041512,-0.073808,0.016681,-0.042703,0.033103,-0.078955,-0.012576,-0.057177,-0.055155,-0.023454,-0.001195,-0.053379,0.017656
771751,t5_6phqtd,jessivann,0,-0.03014,-0.02967,0.020029,0.082169,0.044654,0.040252,-0.059091,-0.060705,-0.063691,-0.000536,0.060795,0.091693,-0.052438,-0.0421,-0.031989,0.03837,-0.0137,-0.020486,0.059314,0.009449,0.082631,-0.055906
771752,t5_6ttq8w,emiii,0,-0.014652,-0.025286,0.043648,0.031067,0.015594,0.04937,-0.0435,-0.038448,-0.079371,0.045315,0.05531,0.089274,-0.030252,0.011965,-0.010014,-0.002786,0.049631,0.012006,0.05387,0.002287,0.051208,0.021202
771753,t5_6hetta,tinomantana,0,-0.039023,0.00349,0.019,0.030172,0.045525,0.027864,-0.06264,-0.071491,-0.060166,0.022435,0.082675,0.084102,-0.073745,0.023707,-0.021793,-0.006323,0.015783,-0.057229,0.059388,0.006303,0.070736,-0.043325
771754,t5_6t85bh,twistedspun,0,-0.041944,0.011073,0.008569,0.025391,0.04935,0.047871,-0.018863,0.007396,-0.059508,0.070213,0.088125,0.106481,-0.048567,0.022983,-0.02309,-0.011047,0.04233,-0.022907,0.015026,0.017319,0.03738,0.012689
771755,t5_wi2zh,autocadmemes,0,-0.031634,-0.03242,-0.076026,-0.027858,0.026071,-0.07847,0.050454,0.065865,-0.03423,-0.037524,0.073746,-0.062604,-0.046956,-0.072842,-0.072843,0.040903,0.066408,0.024241,-0.073683,0.017086,-0.057149,0.00301
771756,t5_6rhnln,gamestopcensorship,0,-0.03322,-0.046152,0.040127,-0.071822,0.037267,0.045015,-0.014639,0.002809,-0.014747,-0.03038,0.042313,-0.050563,-0.068255,-0.057389,0.008784,-0.008465,0.00622,-0.029101,-0.070836,-0.01327,0.021226,0.000765
771757,t5_6pg8i2,coffeepoblog,0,-0.023961,-0.00217,0.044053,-0.000543,0.048995,0.034303,-0.014573,0.051975,-0.068334,-0.032103,-0.026937,-0.056592,-0.054808,0.002123,-0.082676,0.071959,-0.061371,-0.02176,0.016584,-0.02959,0.019676,-0.072884
771758,t5_6cv9bz,threadtreatment,0,-0.020639,0.030801,-0.007754,0.075224,0.028055,0.026198,0.027853,0.057297,-0.058218,0.045436,0.031521,0.02764,-0.053154,0.035169,-0.021854,-0.052122,0.019887,-0.027726,-0.095366,-0.017542,-0.018356,-0.046369
771759,t5_4la7jo,spongebobcringememes,0,-0.004059,0.013185,-0.059028,0.054473,0.048189,0.0258,0.028421,0.038999,-0.09201,0.065643,0.035775,-0.049399,-0.010349,-0.042322,-0.045216,-0.043074,0.044714,-0.033876,0.0506,0.025916,0.04787,5.7e-05


# Filter subreddits to use in ANN index

In a previous version we only kept subs that had embeddings AND clustering data. 
<br>Now that we cover 700k subreddits for v0.6.0, we need to be more thoughtful about how we select which subs to keep for ANN.

For v0.6.0 we'll keep only subs that have 5+ posts in the last 90 days (L90). Based on the Mode dashboard below, we expect that to be around 250k subreddits.

Mode Dashboard: https://app.mode.com/reddit/reports/e6cde33162c4 

## Load metadata to apply other filters [optional]

If we want to filter subreddits based on other data, we'll need to pull data from mlflow or BigQuery.


In [18]:
# run_id_final_model = ''

In [19]:
# l_artifacts_top_level = mlf.list_run_artifacts(
#     run_id=run_id_final_model,
#     only_top_level=True,
#     verbose=True,
# )
# len(l_artifacts_top_level)

In [20]:
# l_artifacts_all = mlf.list_run_artifacts(
#     run_id=run_id_final_model,
#     only_top_level=False,
#     verbose=False,
# )
# len(l_artifacts_all)

In [21]:
# l_artifacts_top_level

In [22]:
# l_sub_c = [i for i in l_artifacts_all if 'df_labels' in i]
# print(len(l_sub_c))
# l_sub_c[:6]

In [23]:
# df_labels = mlf.read_run_artifact(
#     run_id=run_id_final_model,
#     artifact_folder='df_labels',
#     read_function='pd_parquet',
#     verbose=False,
# )
# print(df_labels.shape)

In [24]:
# df_labels.iloc[:5, :15]

## Apply filters

In v0.6.0, we already have the number of posts for embedding in the embedding file, so we don't need to load additional files to apply filters

In [25]:
# Distribution of `posts_for_embeddings_count` (posts used to build each sub's embedding),
#  bucketed so we can pick a min-post threshold for the ANN index.
#  pd.cut bins are right-inclusive: (-1, 0] -> "00 posts", (5, 49] -> "06-49 posts", etc.
#  Rows are sorted from most to fewest posts, so the cumulative columns show how many subs
#  we'd keep with each threshold.
#  NOTE(review): the 'posts_with_len_3+' label presumably refers to posts that passed a
#  min-text-length filter upstream -- TODO confirm.
value_counts_and_pcts(
    pd.cut(
        df_agg_sub_c['posts_for_embeddings_count'],
        bins=[-1, 0, 1, 2, 3, 4, 5, 49, np.inf],
        labels=[
            "00 posts", "01 post", '02 posts', '03 posts',
            '04 posts', '05 posts',
            '06-49 posts', '50+ posts',
        ]
    ).rename('posts_with_len_3+'),
    sort_index=True,
    add_col_prefix=False,
    count_type='subreddits',
    sort_index_ascending=False,
    cumsum_count=True,
    reset_index=True,
).hide_index().set_caption("<h4 align='left'>Post distribution for subreddits with 1 view & 1 attempted post in L90-days</h4>")

posts_with_len_3+,subreddits_count,percent_of_subreddits,cumulative_sum_of_subreddits,cumulative_percent_of_subreddits
50+ posts,72510,9.4%,72510,9.4%
06-49 posts,154858,20.1%,227368,29.5%
05 posts,23205,3.0%,250573,32.5%
04 posts,33084,4.3%,283657,36.8%
03 posts,56898,7.4%,340555,44.1%
02 posts,125070,16.2%,465625,60.3%
01 post,240338,31.1%,705963,91.5%
00 posts,65797,8.5%,771760,100.0%


In [26]:
# Keep only subs with 5+ posts for embeddings (~250k subs, see distribution above).
min_posts_for_ann = 5

# Boolean indexing below returns a NEW frame, so the unfiltered data can be kept by reference.
#  The previous `.copy()` duplicated ~1.5 GB of embeddings in memory for no benefit.
# NOTE: re-running this cell points `df_agg_sub_c_raw` at the already-filtered frame;
#  reload `df_agg_sub_c` first if you need the raw data again.
df_agg_sub_c_raw = df_agg_sub_c

df_agg_sub_c = df_agg_sub_c_raw[
    df_agg_sub_c_raw['posts_for_embeddings_count'] >= min_posts_for_ann
]
df_agg_sub_c.shape

(250573, 515)

# Build annoy index

I created a custom `AnnoyIndex` class with some extra methods to create outputs for BigQuery (including calculating cosine similarity).

In [27]:
# project wrapper around annoy (not the raw `annoy.AnnoyIndex`); adds BigQuery-oriented output helpers
from subclu.models.nn_annoy import AnnoyIndex

In [28]:
%%time

index_cols = ['subreddit_id', 'subreddit_name']
l_embedding_cols = [c for c in df_agg_sub_c.columns if c.startswith('embeddings_')]
n_trees = 200

nn_index = AnnoyIndex(
    df_agg_sub_c[l_embedding_cols + index_cols],
    index_cols=index_cols,
    metric='angular',
    n_trees=n_trees,
)

nn_index.build()

CPU times: user 1h 1min 57s, sys: 40.4 s, total: 1h 2min 37s
Wall time: 1min 2s


## Test `search_k`
`search_k` is the number of nodes annoy inspects per query. `search_k=-1` uses annoy's default budget (`n_trees * n`, where `n` is the number of neighbors requested). This gives more accurate results than a small budget, but it takes longer to compute.

Even with small changes, the examples below show a difference in run time. With a small `search_k` (e.g., `search_k=1`), we sometimes miss a neighbor even in the top 10 results.

In [29]:
%%time

n_test_i = 212
nn_index.get_top_n_by_item(n_test_i, k=9, search_k=-1, include_distances=True)

CPU times: user 275 ms, sys: 72 ms, total: 347 ms
Wall time: 346 ms


Unnamed: 0,subreddit_id_a,subreddit_name_a,distance_rank,subreddit_id_b,subreddit_name_b,distance
0,t5_10dzqu,godawfulmovies,0,t5_10dzqu,godawfulmovies,0.0
1,t5_10dzqu,godawfulmovies,1,t5_me7ba,podcastsharing,0.69861
2,t5_10dzqu,godawfulmovies,2,t5_2u29p,filmjunk,0.705137
3,t5_10dzqu,godawfulmovies,3,t5_2c7q0h,podcastpromoting,0.716215
4,t5_10dzqu,godawfulmovies,4,t5_n99oj,findthepathpodcast,0.716643
5,t5_10dzqu,godawfulmovies,5,t5_t6jv7,sinisterhood,0.716963
6,t5_10dzqu,godawfulmovies,6,t5_2zzeu,highersidechats,0.720665
7,t5_10dzqu,godawfulmovies,7,t5_np3is,letsgo2courtpodcast,0.721888
8,t5_10dzqu,godawfulmovies,8,t5_2t8p3,wehatemovies,0.723386


In [30]:
%%time
# same item with the minimal budget (search_k=1): compare with the cell above for missed neighbors
nn_index.get_top_n_by_item(n_test_i, k=9, search_k=1, include_distances=True)

CPU times: user 245 ms, sys: 12 µs, total: 245 ms
Wall time: 244 ms


Unnamed: 0,subreddit_id_a,subreddit_name_a,distance_rank,subreddit_id_b,subreddit_name_b,distance
0,t5_10dzqu,godawfulmovies,0,t5_10dzqu,godawfulmovies,0.0
1,t5_10dzqu,godawfulmovies,1,t5_2u29p,filmjunk,0.705137
2,t5_10dzqu,godawfulmovies,2,t5_2c7q0h,podcastpromoting,0.716215
3,t5_10dzqu,godawfulmovies,3,t5_n99oj,findthepathpodcast,0.716643
4,t5_10dzqu,godawfulmovies,4,t5_2zzeu,highersidechats,0.720665
5,t5_10dzqu,godawfulmovies,5,t5_np3is,letsgo2courtpodcast,0.721888
6,t5_10dzqu,godawfulmovies,6,t5_35xxi9,headgumpodcast,0.738195
7,t5_10dzqu,godawfulmovies,7,t5_2vo38,harmontown,0.741415
8,t5_10dzqu,godawfulmovies,8,t5_26gz8w,theteamhouse,0.746122


In [35]:
# Compare the top-N neighbors from the default search (search_k=-1) against progressively
#  smaller `search_k` budgets, one column per budget, to see how much recall we lose.
top_k_test_ = 20
cols_drop_ = ['subreddit_id_a', 'subreddit_id_b', 'distance']
cols_append_ = ['subreddit_name_b',]

# NOTE: annoy's `search_k` is a node-inspection budget (default n_trees * n), not a tree
#  count; values are scaled off `n_trees` only as convenient test points.
# dict.fromkeys() drops duplicate budgets (possible when n_trees is small) while keeping
#  order; budgets < 1 are skipped because they would inspect no nodes.
l_search_k_ = [
    k_ for k_ in dict.fromkeys([
        int(0.998 * n_trees), int(0.85 * n_trees),
        int(0.5 * n_trees), min(200, int(0.1 * n_trees)),
        1,
    ])
    if k_ >= 1
]

l_df_compare_ = [
    nn_index.get_top_n_by_item(
        n_test_i, k=top_k_test_, search_k=-1, include_distances=True
    ).drop(cols_drop_, axis=1)
]
for k_ in l_search_k_:
    l_df_compare_.append(
        nn_index.get_top_n_by_item(
            n_test_i, k=top_k_test_, search_k=k_, include_distances=True
        )[cols_append_].rename(columns={c: f"{c}_{k_}" for c in cols_append_})
    )

# concat once instead of growing the frame inside the loop
df_compare_sk = pd.concat(l_df_compare_, axis=1)
df_compare_sk

Unnamed: 0,subreddit_name_a,distance_rank,subreddit_name_b,subreddit_name_b_199,subreddit_name_b_170,subreddit_name_b_100,subreddit_name_b_20,subreddit_name_b_1
0,godawfulmovies,0,godawfulmovies,godawfulmovies,godawfulmovies,godawfulmovies,godawfulmovies,godawfulmovies
1,godawfulmovies,1,podcastsharing,filmjunk,filmjunk,filmjunk,filmjunk,filmjunk
2,godawfulmovies,2,filmjunk,podcastpromoting,podcastpromoting,podcastpromoting,podcastpromoting,podcastpromoting
3,godawfulmovies,3,podcastpromoting,findthepathpodcast,findthepathpodcast,findthepathpodcast,findthepathpodcast,findthepathpodcast
4,godawfulmovies,4,findthepathpodcast,highersidechats,highersidechats,highersidechats,highersidechats,highersidechats
5,godawfulmovies,5,sinisterhood,letsgo2courtpodcast,letsgo2courtpodcast,letsgo2courtpodcast,letsgo2courtpodcast,letsgo2courtpodcast
6,godawfulmovies,6,highersidechats,headgumpodcast,headgumpodcast,headgumpodcast,headgumpodcast,headgumpodcast
7,godawfulmovies,7,letsgo2courtpodcast,harmontown,harmontown,harmontown,harmontown,harmontown
8,godawfulmovies,8,wehatemovies,theteamhouse,theteamhouse,theteamhouse,theteamhouse,theteamhouse
9,godawfulmovies,9,weeklyplanetpodcast,headgum,headgum,headgum,headgum,headgum


## Get df with all items

For reference, a full run over ~80k subreddits took about 1h17m (see log below).

Here we first test a few values of `search_k` on a sample.
```bash
100%|██████████| 81973/81973 [1:17:02<00:00, 17.73it/s]
17:07:23 | INFO | "(8115327, 7) <- df_top_items shape"
```

In [189]:
%%time

# Test run on a sample of seed subs (n_sample) with a small search_k for speed.
#  Output has one row per (seed sub, neighbor) pair, e.g., 8000 seeds * k=3 -> 24000 rows.
df_nn_top = nn_index.get_top_n_by_item_all_fast(
    k=3,
    search_k=5,
    include_distances=True,
    append_i=True,
    cosine_similarity=True,
    n_sample=8000,
)

100%|##########| 8000/8000 [00:07<00:00, 1091.90it/s]
21:13:48 | INFO | "Start combining all ANNs into a df..."
21:13:50 | INFO | "(24000, 4) <- df_nn_top shape"
21:13:50 | INFO | "Adding index labels (subreddit ID & Name)"
21:13:50 | INFO | "Done adding index names"
21:13:50 | INFO | "(24000, 8) <- df_nn_top shape"
21:13:50 | INFO | "Calculating cosine similarity..."


CPU times: user 8.79 s, sys: 240 ms, total: 9.03 s
Wall time: 9.02 s


In [190]:
# first 3 seed subs x 3 neighbors each
df_nn_top.head(9)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
0,t5_1009a3,memesenespanol,0,240461,0.531268,1,t5_hc3xv,memesespanol,0.858877
1,t5_1009a3,memesenespanol,0,483,0.561,2,t5_10wycq,memesesp,0.842639
2,t5_1009a3,memesenespanol,0,90209,0.572214,3,t5_3qq2qy,beelcitosmemes,0.836285
3,t5_100a1y,karstcast,1,24691,0.831054,1,t5_2rmzr,cyr,0.654674
4,t5_100a1y,karstcast,1,127204,0.880385,2,t5_5bak6c,zachhonest,0.612461
5,t5_100a1y,karstcast,1,94757,0.885147,3,t5_42apvd,mindscepter,0.608258
6,t5_100cmu,mechanicaladvice,2,10357,0.426811,1,t5_2dos9x,mechanichelp,0.908916
7,t5_100cmu,mechanicaladvice,2,44382,0.442545,2,t5_2vsnf,askmechanics,0.902077
8,t5_100cmu,mechanicaladvice,2,40782,0.47093,3,t5_2upg7,carhelp,0.889112


In [191]:
# last 3 seed subs x 3 neighbors each
df_nn_top.tail(9)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
23991,t5_29sgdw,jesmonite,7997,52097,0.620637,1,t5_2yqob,resincasting,0.807405
23992,t5_29sgdw,jesmonite,7997,75637,0.639685,2,t5_3ej1m,moldmaking,0.795401
23993,t5_29sgdw,jesmonite,7997,6437,0.699211,3,t5_27ysfu,dicemaking,0.755552
23994,t5_29sgft,mshumor,7998,174914,0.611478,1,t5_6ebb0a,chronicillnessmemes,0.813047
23995,t5_29sgft,mshumor,7998,130606,0.715889,2,t5_5fd7m0,dopehumor,0.743751
23996,t5_29sgft,mshumor,7998,199584,0.736238,3,t5_6jxn97,justforfunsies,0.728977
23997,t5_29sl3d,yiffmilfs,7999,204936,0.315177,1,t5_6lbhid,hentaibabe,0.950332
23998,t5_29sl3d,yiffmilfs,7999,123686,0.316145,2,t5_56ww16,normanporncollection,0.950026
23999,t5_29sl3d,yiffmilfs,7999,95687,0.31766,3,t5_43yyr2,furryblowjobs,0.949546


# Check some example outputs

In [193]:
# spot-check: r/ich_iel neighbors (empty when the sub isn't among the sampled seeds; see n_sample above)
(
    df_nn_top[df_nn_top['subreddit_name'] == 'ich_iel']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity


In [194]:
# spot-check: r/vegetarischde neighbors (empty when the sub isn't among the sampled seeds)
(
    df_nn_top[df_nn_top['subreddit_name'] == 'vegetarischde']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity


In [195]:
# spot-check: r/antivegan neighbors (empty when the sub isn't among the sampled seeds)
(
    df_nn_top[df_nn_top['subreddit_name'] == 'antivegan']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity


In [196]:
# spot-check: r/mexico neighbors (empty when the sub isn't among the sampled seeds)
(
    df_nn_top[df_nn_top['subreddit_name'] == 'mexico']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity


In [197]:
# spot-check: r/memesenespanol neighbors
(
    df_nn_top[df_nn_top['subreddit_name'] == 'memesenespanol']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
0,t5_1009a3,memesenespanol,0,240461,0.531268,1,t5_hc3xv,memesespanol,0.858877
1,t5_1009a3,memesenespanol,0,483,0.561,2,t5_10wycq,memesesp,0.842639
2,t5_1009a3,memesenespanol,0,90209,0.572214,3,t5_3qq2qy,beelcitosmemes,0.836285


In [198]:
# spot-check: r/de neighbors
(
    df_nn_top[df_nn_top['subreddit_name'] == 'de']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity
10143,t5_22i0,de,3381,102150,0.449787,1,t5_4egnbw,dezwo,0.898846
10144,t5_22i0,de,3381,73325,0.506051,2,t5_3caax,600euro,0.871956
10145,t5_22i0,de,3381,83717,0.544782,3,t5_3jxvk,tja,0.851606


In [199]:
# NOTE(review): repeats the earlier r/ich_iel spot-check cell; candidate for removal
(
    df_nn_top[df_nn_top['subreddit_name'] == 'ich_iel']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity


In [200]:
# spot-check: r/cfb neighbors (empty when the sub isn't among the sampled seeds)
(
    df_nn_top[df_nn_top['subreddit_name'] == 'cfb']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity


In [201]:
# spot-check: r/finanzen neighbors (empty when the sub isn't among the sampled seeds)
(
    df_nn_top[df_nn_top['subreddit_name'] == 'finanzen']
    .head(15)
)

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity


# Add dt/pt column & metadata columns

In [210]:
# Metadata stamped on every row; `pt` is the UTC date used as the BigQuery partition column.
#  The same dict is reused later as the top-level fields of each NDJSON row.
d_topk_meta = {
    'pt': datetime.utcnow().strftime("%Y-%m-%d"),
    'mlflow_run_id': run_uuid,
    'model_name': 'cau-text-mUSE',
    'model_version': 'v0.6.0',
}
df_nn_top = df_nn_top.assign(**d_topk_meta)

In [213]:
# Drop a stale `dt` column, presumably left over from an earlier version of the metadata
#  cell (it now writes `pt`). errors='ignore' keeps this cell from raising KeyError on a
#  fresh kernel, where `dt` never exists.
df_nn_top = df_nn_top.drop(columns=['dt'], errors='ignore')

In [215]:
# confirm the metadata columns were added
df_nn_top.tail()

Unnamed: 0,subreddit_id,subreddit_name,seed_ix,nn_ix,distance,distance_rank,similar_subreddit_id,similar_subreddit_name,cosine_similarity,model_version,pt,mlflow_run_id,model_name
23995,t5_29sgft,mshumor,7998,130606,0.715889,2,t5_5fd7m0,dopehumor,0.743751,v0.6.0,2022-09-09,badc44b0e5ac467da14f710da0b410c6,cau-text-mUSE
23996,t5_29sgft,mshumor,7998,199584,0.736238,3,t5_6jxn97,justforfunsies,0.728977,v0.6.0,2022-09-09,badc44b0e5ac467da14f710da0b410c6,cau-text-mUSE
23997,t5_29sl3d,yiffmilfs,7999,204936,0.315177,1,t5_6lbhid,hentaibabe,0.950332,v0.6.0,2022-09-09,badc44b0e5ac467da14f710da0b410c6,cau-text-mUSE
23998,t5_29sl3d,yiffmilfs,7999,123686,0.316145,2,t5_56ww16,normanporncollection,0.950026,v0.6.0,2022-09-09,badc44b0e5ac467da14f710da0b410c6,cau-text-mUSE
23999,t5_29sl3d,yiffmilfs,7999,95687,0.31766,3,t5_43yyr2,furryblowjobs,0.949546,v0.6.0,2022-09-09,badc44b0e5ac467da14f710da0b410c6,cau-text-mUSE


In [216]:
# check dtypes & nulls before saving / exporting
df_nn_top.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 24000 entries, 0 to 23999
Data columns (total 13 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   subreddit_id            24000 non-null  object 
 1   subreddit_name          24000 non-null  object 
 2   seed_ix                 24000 non-null  int64  
 3   nn_ix                   24000 non-null  int64  
 4   distance                24000 non-null  float64
 5   distance_rank           24000 non-null  int64  
 6   similar_subreddit_id    24000 non-null  object 
 7   similar_subreddit_name  24000 non-null  object 
 8   cosine_similarity       24000 non-null  float64
 9   model_version           24000 non-null  object 
 10  pt                      24000 non-null  object 
 11  mlflow_run_id           24000 non-null  object 
 12  model_name              24000 non-null  object 
dtypes: float64(2), int64(3), object(8)
memory usage: 2.4+ MB


# Save DF to local & log to Mlflow

Instead of saving it to a random location in GCS, save the artifact locally & then log it to the mlflow run as a new artifact.

Make sure to append a timestamp in case we try different ANN approaches


In [217]:
# timestamped local folder so different ANN attempts don't overwrite each other
manual_model_timestamp = datetime.utcnow().strftime('%Y-%m-%d_%H%M%S')
path_this_model = get_project_subfolder(f"data/models/ann/manual_v060_{manual_model_timestamp}")
path_this_model.mkdir(parents=True, exist_ok=True)
path_this_model

PosixPath('/home/jupyter/subreddit_clustering_i18n/data/models/ann/manual_v060_2022-09-09_212342')

In [220]:
%%time

# local parquet folder for the ANN df; its name is reused as the mlflow artifact subfolder below
p_df_subfolder = path_this_model / f'ann_df-{manual_model_timestamp}'
subfolder_df = p_df_subfolder.name

save_pd_df_to_parquet_in_chunks(
    df_nn_top,
    p_df_subfolder,
    write_index=False
)

21:33:33 | INFO | "Converting pandas to dask..."
21:33:33 | INFO | "    13.7 MB <- Memory usage"
21:33:33 | INFO | "       1	<- target Dask partitions	  120.0 <- target MB partition size"


CPU times: user 76.2 ms, sys: 8.01 ms, total: 84.2 ms
Wall time: 103 ms


### Log to mlflow

In [231]:
%%time

d_mlflow_paths = dict()
with mlflow.start_run(run_id=run_uuid) as run:
    mlflow.log_artifacts(str(p_df_subfolder), subfolder_df)
    # get path to JSON file so that we can create a table from it
    d_mlflow_paths['mlflow_artifact_df'] = mlflow.get_artifact_uri(
        artifact_path=f"{subfolder_df}"
    )
info(f"Logging artifact complete!")

21:48:47 | INFO | "Logging artifact complete!"


CPU times: user 233 ms, sys: 236 ms, total: 469 ms
Wall time: 2.37 s


In [232]:
# remote (GCS) URIs of the logged artifacts
d_mlflow_paths

{'mlflow_artifact_df': 'gs://i18n-subreddit-clustering/mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/ann_df_test-2022-09-09_212342'}

# Save to JSON for BigQuery

In [237]:
%%time

p_local_json = path_this_model / f'ann_ndjson_test-{manual_model_timestamp}'
Path.mkdir(p_local_json, exist_ok=True, parents=True)
subfolder_json = p_local_json.name

f_local_json_name = f"ann_ndjson-{df_nn_top['subreddit_id'].nunique()}_subreddits.json"
f_local_json_full = p_local_json / f_local_json_name


with open(f_local_json_full, 'w') as f:
    for seed_sub_id_, df_seed_ in tqdm(df_nn_top.groupby(['subreddit_id'])):

        d_seed = {
            **d_topk_meta,
            **{
                'subreddit_id': seed_sub_id_,
                'subreddit_name': str(df_seed_['subreddit_name'].values[0]),
                'similar_subreddit': {
                    'subreddit_id': list(df_seed_[f'{prefix_similar_sub}_subreddit_id']),
                    'subreddit_name': list(df_seed_[f'{prefix_similar_sub}_subreddit_name']),
                    'cosine_similarity': list(df_seed_['cosine_similarity']),
                    'distance_rank': list(df_seed_['distance_rank']),
                }
            }
        }
        f.write(json.dumps(d_seed) + "\n")

100%|██████████| 8000/8000 [00:01<00:00, 4027.19it/s]


CPU times: user 2.05 s, sys: 36.3 ms, total: 2.09 s
Wall time: 2.09 s


In [242]:
%%time
# log to mlflow

with mlflow.start_run(run_id=run_uuid) as run:
    mlflow.log_artifacts(str(p_local_json), subfolder_json)
    # get path to JSON file so that we can create a table from it
    d_mlflow_paths['mlflow_artifact_json'] = mlflow.get_artifact_uri(
        artifact_path=f"{subfolder_json}/{f_local_json_name}"
    )
info(f"Logging artifact complete!")

22:14:00 | INFO | "Logging artifact complete!"


CPU times: user 245 ms, sys: 228 ms, total: 473 ms
Wall time: 2 s


In [248]:
# GCS URI of the NDJSON file to load into BigQuery
d_mlflow_paths['mlflow_artifact_json']

'gs://i18n-subreddit-clustering/mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/ann_ndjson_test-2022-09-09_212342/ann_ndjson-8000_subreddits.json'

# Upload JSON to BQ

Using `bq load` won't work with a JSON schema in BQ.

Instead, let's use the Python client. NOTE: the VM needs application-default credentials with the correct read & write access, e.g.:
```bash
# login
gcloud auth application-default login

# logout
gcloud auth application-default revoke
```

In [250]:
# Load the NDJSON file logged to mlflow into a BigQuery table, partitioned by `pt`.
uri_ndjson = d_mlflow_paths['mlflow_artifact_json']
info(f"Creating table from file:\n{uri_ndjson}")

bq_table_description = (
    "Table with most similar subreddits by the text (posts & comments) in each sub."
    "  It works across 16 languages. So finance (English), Finanzen(German), & financia(Spanish) will be clustered together."
    "  See wiki: https://reddit.atlassian.net/wiki/spaces/DataScience/pages/2404220935/CA+Embeddings+Topic+Model"
)

load_data_to_bq_table(
    uri=uri_ndjson,
    bq_project='reddit-employee-datasets',
    bq_dataset='david_bermejo',
    bq_table_name='cau_similar_subreddits_by_text',
    schema=similar_sub_schema(),
    partition_column='pt',
    table_description=bq_table_description,
    update_table_description=True,
)

22:30:52 | INFO | "Creating table from file:
gs://i18n-subreddit-clustering/mlflow/mlruns/35/badc44b0e5ac467da14f710da0b410c6/artifacts/ann_ndjson_test-2022-09-09_212342/ann_ndjson-8000_subreddits.json"
22:30:54 | INFO | "Loading data to table:
  reddit-employee-datasets.david_bermejo.cau_similar_subreddits_by_text"
22:30:54 | INFO | "Created table reddit-employee-datasets.david_bermejo.cau_similar_subreddits_by_text"
22:30:54 | INFO | "  0 rows in table BEFORE adding data"
22:30:57 | INFO | "Updating subreddit description from:
  Table with most similar subreddits by the text (posts & comments) in each sub.  It works across 16 languages. So finance (English), Finanzen(German), & financia(Spanish) will be clustered together.  See wiki: https://reddit.atlassian.net/wiki/spaces/DataScience/pages/2404220935/CA+Embeddings+Topic+Model
to:
  Table with most similar subreddits by the text (posts & comments) in each sub.  It works across 16 languages. So finance (English), Finanzen(German), & 

# Appendix & Scratch

In [None]:
# Intentionally stop "Run All" here: the cells below are appendix/scratch experiments,
#  not part of the pipeline, and may depend on (or overwrite) state from other cells.
raise RuntimeError("Stopping here on purpose: the cells below are appendix/scratch code.")

In [18]:
# %%time
# Scratch (disabled): copy an mlflow artifact folder from GCS into a local cache
#  with `gsutil -m` (parallel copy). Kept for reference only.
# # use gsutil to download post-level embeddings b/c it'll be much faster to run it in parallel

# remote_key =  'mlflow/mlruns/29/bfe6cbd59a21480c8c2b9923a3a9cbd6/artifacts/df_subs_agg_c1'

# # Need to remove the last part of the local path otherwise we'll get duplicate subfolders:
# #  top/2021-12-14/2021-12-14 instead of top/2021-12-14
# local_f = f"/home/jupyter/subreddit_clustering_i18n/data/local_cache/{'/'.join(remote_key.split('/')[:-1])}"
# Path(local_f).mkdir(parents=True, exist_ok=True)
# remote_gs_path = f"gs://i18n-subreddit-clustering/{remote_key}"
# print(f"Remote path:\n  {remote_gs_path}")
# print(f"Local path:\n  {local_f}")

# # `-n` flag means "no clobber", so it should skip existing files (only copy new files)
# !gsutil -m cp -r -n $remote_gs_path $local_f

In [21]:
## We'll do posts in a separate notebook
# (Disabled scratch: reads the `df_post_level_agg_c_post_comments_sub_desc` parquet artifact
#  from an mlflow run. If re-enabled, move `%%time` to the very first line of the cell,
#  because cell magics must come before anything else, including comments.)
# %%time

# df_agg_posts_c = mlf.read_run_artifact(
#     run_id=run_uuid,
#     artifact_folder='df_post_level_agg_c_post_comments_sub_desc',
#     read_function='pd_parquet',
#     verbose=False,
# )
# print(df_agg_posts_c.shape)

In [62]:
%%time

df_nn_top = df_nn_top.copy()

# append IDs & names for seed & nn (nearest neighbors)
df_nn_top = (
    nn_index.index_labels_df
    .rename(columns={c: f"{c}_a" for c in nn_index.index_labels_df.columns})
    .merge(
        df_nn_top.head(1000)
        .assign(subreddit_id_a=df_nn_top.head(1000)['seed_ix'].replace(d_index_to_sub_id)),
        how='right',
        on='subreddit_id_a'
    )
)
df_nn_top.shape

CPU times: user 10.6 s, sys: 468 ms, total: 11.1 s
Wall time: 11 s


(1000, 8)

In [63]:
%%time

df_nn_top = df_nn_top.copy()

# append IDs & names for seed & nn (nearest neighbors)
df_nn_top = (
    nn_index.index_labels_df
    .rename(columns={c: f"{c}_a" for c in nn_index.index_labels_df.columns})
    .set_index('subreddit_id_a')
    .merge(
        df_nn_top.head(1000)
        .assign(subreddit_id_a=df_nn_top.head(1000)['seed_ix'].replace(d_index_to_sub_id))
        .set_index('subreddit_id_a')
        ,
        how='right',
        left_index=True,
        right_index=True,
    )
)
df_nn_top.shape

CPU times: user 10.6 s, sys: 420 ms, total: 11.1 s
Wall time: 11 s


(1000, 7)

In [133]:
# Scratch: build one record per seed subreddit (shared metadata + its nearest neighbors).
# Looking up both sub_id & sub_name via Series.replace() was ~2x slower, so neighbor
#  sub_names are left out here. TODO: look up sub_id, then df.merge() to get sub_name.
d_topk_final = dict()  # NOTE(review): not used in this cell
l_topk_final = list()

# Metadata shared by every record. NOTE: the NDJSON export cell also reads `d_topk_meta`.
d_topk_meta = {
    'pt': datetime.utcnow().strftime("%Y-%m-%d"),
    'mlflow_run_id': run_uuid,
    'model_name': 'cau-text-mUSE',
    'model_version': 'v0.6.0',
}


# Group by a scalar key (not a 1-element list): with pandas>=2.0 a list key yields
#  1-tuples, which would break the `d_index_to_sub_id_all[seed_ix_]` lookups below.
for seed_ix_, df_seed_ in tqdm(df_nn_top.head(400).groupby('seed_ix')):
    # Map neighbor indices -> sub IDs without mutating the groupby chunk.
    #  .map(dict) is a hash lookup per row (NaN if missing) vs. the slow Series.replace(dict).
    nn_sub_ids = df_seed_['nn_ix'].map(d_index_to_sub_id_all)
    l_topk_final.append(
        {
            **d_topk_meta,
            'subreddit_id': d_index_to_sub_id_all[seed_ix_],
            'subreddit_name': d_index_to_sub_name_all[seed_ix_],
            'similar_subreddit': {
                # .tolist() yields native Python types (JSON-serializable), unlike list(series)
                'subreddit_id': nn_sub_ids.tolist(),
                'cosine_similarity': df_seed_['cosine_similarity'].tolist(),
                'distance_rank': df_seed_['distance_rank'].tolist(),
            },
        }
    )

100%|██████████| 2/2 [00:19<00:00,  9.82s/it]
