# Purpose


### 2023-05-01
- Pull data for inference (some of it is cached)
- Load best model from previous run 
- Run inference on 3 subreddits (r/de, r/ich_iel, r/fragreddit)
    - Save inference data (raw)
- Reshape data for cache
    - 1 row = subreddit + country + subscribed status. Examples:
        - r/`de` | DE (Germany) | 1 (Subscribed) | {Nested struct of top N users}
        - r/`de` | DE (Germany) | 0 (viewed, but NOT Subscribed) |  {Nested struct of top N users}
        
Open questions:
- Other Geos:
    - How do we handle users that have NULL/no Geo?
    - How do we handle users from non-target Geos?
        - Maybe: create a `RoW` row and if you actually want their geo, you'll have to join it manually


# Imports & Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from datetime import datetime
import gc
import logging
from logging import info
import os
from pathlib import Path
import json

import polars as pl
import numpy as np
import pandas as pd
import plotly
import seaborn as sns

from tqdm.auto import tqdm
import dask
import mlflow

import subclu
from subclu.eda.aggregates import compare_raw_v_weighted_language
from subclu.utils import set_working_directory, get_project_subfolder
from subclu.utils.eda import (
    setup_logging, counts_describe, value_counts_and_pcts,
    notebook_display_config, print_lib_versions,
    style_df_numeric, reorder_array,
)
from subclu.utils.mlflow_logger import MlflowLogger
from subclu.utils.hydra_config_loader import LoadHydraConfig
from subclu.utils.data_irl_style import (
    get_colormap, theme_dirl, 
    get_color_dict, base_colors_for_manual_labels,
    check_colors_used,
)
from subclu.data.data_loaders import LoadPosts, LoadSubreddits, create_sub_level_aggregates


# ===
# imports specific to this notebook
import joblib
from typing import Tuple, Union

from subclu.models.nn_annoy import AnnoyIndex
from subclu.utils.eda import get_venn_sets2

from matplotlib_venn import venn2_unweighted

from google.cloud import bigquery
from subclu.pn_models import get_data


client = bigquery.Client()

print_lib_versions([bigquery, joblib, np, pd, pl, plotly, mlflow, subclu])

python		v 3.7.10
===
google.cloud.bigquery	v: 2.13.1
joblib		v: 1.0.1
numpy		v: 1.19.5
pandas		v: 1.2.4
polars		v: 0.17.1
plotly		v: 5.11.0
mlflow		v: 1.16.0
subclu		v: 0.6.1


In [3]:
# plotting
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.dates as mdates
plt.style.use('default')

setup_logging()
notebook_display_config()

# Define local path for this model outputs

In [4]:
manual_model_timestamp = datetime.utcnow().strftime('%Y-%m-%d_%H%M%S')
path_this_model = get_project_subfolder(
    f"data/models/pn_model/pn_manual_test_{manual_model_timestamp}"
)
Path.mkdir(path_this_model, parents=True, exist_ok=True)
print(path_this_model)

/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658


# Define key inputs
Use these throughout the process to filter/target specific subreddits, geos, & users

In [5]:
l_target_subreddits = ['ich_iel']

# We would want to add an automated way to pick these ANN subs, but look them up manually for now
l_target_ann_subreddits = ['fragreddit', 'de']

l_target_geos = [
    "MX", "ES", "AR"
    , "DE", "AT", "CH"
    , "US", "GB", "IN", "CA", "AU", "IE"
    , "FR", "NL", "IT"
    , "BR", "PT"
    , "PH"
]

# Load & reshape data

## Load training data from BQ

Here we mostly copy the query used for training. Note that we should match column names.

2 Minutes: ETA for **17 million** rows
- Note that this under-counts time because I was pulling data from some pre-computed tables

In [6]:
%%time
%%bigquery df_inference_raw --project data-science-prod-218515 

-- Select subreddit<>user data for INFERENCE. v2023-05-01
--   For this version pick data already processes for training
--   The main use case is to get the data outputs in the right shape/format so we can
--   hand something off to Engineering

DECLARE PT_FEATURES DATE DEFAULT "2022-12-01";
DECLARE PT_WINDOW_START DATE DEFAULT PT_FEATURES - 7;
DECLARE PT_VIEWS_START DATE DEFAULT PT_FEATURES - 29;

DECLARE TARGET_COUNTRY_CODES DEFAULT [
    "MX", "ES", "AR"
    , "DE", "AT", "CH"
    , "US", "GB", "IN", "CA", "AU", "IE"
    , "FR", "NL", "IT"
    , "BR", "PT"
    , "PH"
];

-- TODO(djb): Steps to create inference data:
--  Create CANDIDATES table users with subreddit<>users views in L30 days
--  Create TARGET table with selected subs<>users
--  Create exploded ToS table for users in user<>subreddit table


WITH subreddit_per_user_count AS (
    SELECT
        tos.user_id
        , COUNT(DISTINCT subreddit_id) AS tos_sub_count
    FROM `reddit-employee-datasets.david_bermejo.pn_test_users_de_campaign_tos_30_pct_20230418` AS tos
    GROUP BY 1
)
, post_consumes_agg AS (
    SELECT
        user_id
        , SUM(num_post_consumes) AS num_post_consumes
        , SUM(num_post_consumes_home) AS num_post_consumes_home
        , SUM(num_post_consumes_community) AS num_post_consumes_community
        , SUM(num_post_consumes_post_detail) AS num_post_consumes_post_detail
        , SUM(IF(app_name = 'ios', num_post_consumes, 0)) AS num_post_consumes_ios
        , SUM(IF(app_name = 'android', num_post_consumes, 0)) AS num_post_consumes_android
        , SUM(num_post_consumes_nsfw) AS num_post_consumes_nsfw
        , SAFE_DIVIDE(SUM(num_post_consumes_nsfw), SUM(num_post_consumes)) AS pct_post_consumes_nsfw
        -- , SUM(num_post_consumes_sfw) AS num_post_consumes_sfw
    FROM `data-prod-165221.video.post_consumes_30d_agg`
    WHERE DATE(pt) = PT_FEATURES
    GROUP BY 1
)
, candidate_sub_users AS (
    SELECT
        -- Need to fill user_id where user_id is missing from new selection criteria
        COALESCE(f.user_id, act.user_id) AS user_id
        -- But for other ids, only keep the raw data (don't fill from training b/c that data has some dupes)
        , f.subreddit_name AS target_subreddit
        , f.subreddit_id AS target_subreddit_id

        , COALESCE(act.send, 0) AS send
        , COALESCE(act.receive, 0) AS receive
        , COALESCE(act.click, 0) AS click

        -- The subscribed column in the old test table was wrong (all 0)
        , f.* EXCEPT(pt, pt_window_start, user_id, subreddit_name, subreddit_id, subscribed)

    FROM (
        SELECT *
        FROM `reddit-employee-datasets.david_bermejo.pn_test_users_de_campaign_20230418`
    ) AS f
        -- TODO(djb): remove this join. For normal inference, we won't need to join on actual sends/clicks
        --  Only doing it for this query because we want to explore the users that the model ranks high, but didn't receive it
        LEFT JOIN `reddit-employee-datasets.david_bermejo.pn_training_data_test_20230428` AS act
            ON f.user_id = act.user_id
                AND f.subreddit_name = act.target_subreddit

    WHERE f.subreddit_id IS NOT NULL
        -- For inference, we don't need this clause. 
        --   Only need to keep sends|receives for TRAINING
        -- AND act.target_subreddit IS NOT NULL
)
, user_actions_t7 AS (
    SELECT
      pne.user_id,
      COUNT(receive_endpoint_timestamp) user_receives_pn_t7,
      COUNT(click_endpoint_timestamp) user_clicks_pn_t7,
      COUNT(
        CASE
          WHEN notification_type='lifecycle_post_suggestions'
            THEN click_endpoint_timestamp
          ELSE NULL
        END
    ) user_clicks_trnd_t7
    FROM `data-prod-165221.channels.push_notification_events` AS pne
    INNER JOIN candidate_sub_users AS c
        ON pne.user_id = c.user_id
    WHERE
        DATE(pt) BETWEEN PT_WINDOW_START AND PT_FEATURES
        AND NOT REGEXP_CONTAINS(notification_type, "email")
        AND receive_endpoint_timestamp IS NOT NULL
  GROUP BY user_id
)
, subscribes AS (
    SELECT
        -- We need distinct in case a user subscribes multiple times to the same sub
        DISTINCT
        u.user_id,
        su.subreddit_id AS subreddit_id
    from data-prod-165221.ds_v2_postgres_tables.account_subscriptions AS s
        LEFT JOIN UNNEST(subscriptions) AS su

        INNER JOIN candidate_sub_users AS u
            ON s.user_id = u.user_id

    WHERE DATE(_PARTITIONTIME) = (CURRENT_DATE() - 2)
        AND DATE(subscribe_date) <= PT_FEATURES
)

-- Select final data
SELECT
    ct.user_id
    , ct.target_subreddit
    , ct.target_subreddit_id
    , ct.send
    , ct.receive
    , ct.click
    , CASE
        WHEN ct.geo_country_code IS NULL THEN 'MISSING' 
        WHEN ct.geo_country_code IN UNNEST(TARGET_COUNTRY_CODES) THEN ct.geo_country_code
        ELSE 'ROW'
    END AS geo_country_code_top
    , IF(s.subreddit_id IS NOT NULL, 1, 0) subscribed
    , COALESCE(tsc.tos_sub_count, 0) AS tos_30_sub_count
    , COALESCE(tos.tos_30_pct, 0) AS tos_30_pct
    , COALESCE(sv.feature_value, 0) AS screen_view_count_14d
    , CASE
        WHEN cl.legacy_user_cohort = 'new' THEN 1
        WHEN cl.legacy_user_cohort = 'resurrected' THEN 2
        WHEN cl.legacy_user_cohort = 'casual' THEN 3
        WHEN cl.legacy_user_cohort IS NULL THEN 4  -- '_missing_' or 'dead'
        WHEN cl.legacy_user_cohort = 'core' THEN 5
        ELSE 0
    END AS legacy_user_cohort_ord
    , pna.* EXCEPT(user_id)
    , co.* EXCEPT(user_id)
    , ct.* EXCEPT(user_id, target_subreddit, target_subreddit_id, send, receive, click, user_in_actual_but_missing_from_new)

FROM candidate_sub_users AS ct
    -- Get count of subs in ToS
    LEFT JOIN subreddit_per_user_count AS tsc
        ON ct.user_id = tsc.user_id

    -- Recent PN activity
    LEFT JOIN user_actions_t7 AS pna
        ON ct.user_id = pna.user_id

    -- Get view counts (all subreddits)
    LEFT JOIN (
        SELECT entity_id, feature_value
        FROM `data-prod-165221.user_feature_platform.screen_views_count_over_14_days_v1`
        WHERE DATE(pt) = PT_FEATURES
    ) AS sv
        ON ct.user_id = sv.entity_id
    
    -- USER cohort, Legacy
    LEFT JOIN (
        SELECT user_id, legacy_user_cohort
        FROM `data-prod-165221.attributes_platform.user_rolling_legacy_user_cohorts`
        WHERE DATE(pt) = PT_FEATURES
    ) AS cl
        ON ct.user_id = cl.user_id
    
    -- USER consumes
    LEFT JOIN post_consumes_agg AS co
        ON ct.user_id = co.user_id

    -- Add ToS_pct for target subreddit
    LEFT JOIN `reddit-employee-datasets.david_bermejo.pn_test_users_de_campaign_tos_30_pct_20230418` AS tos
        ON ct.user_id = tos.user_id
            AND ct.target_subreddit_id = tos.subreddit_id

    -- Get flag for user subscribed/not subscribed to sub
    LEFT JOIN subscribes AS s
        ON ct.user_id = s.user_id
        AND ct.target_subreddit_id = s.subreddit_id

-- For Inference, there's no need for WHERE clause because we want to score ALL users, even those that didn't receive the PN
-- WHERE ct.receive = 1

-- Only order to check data, no need to spend time ordering for training or inference
-- ORDER BY tos_30_pct DESC, click DESC, tos_sub_count DESC
;

Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 1129.32query/s]
Downloading: 100%|██████████| 17486680/17486680 [00:37<00:00, 460834.10rows/s]


CPU times: user 23.5 s, sys: 12.4 s, total: 35.8 s
Wall time: 1min 28s


In [7]:
df_inference_raw.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 17486680 entries, 0 to 17486679
Data columns (total 30 columns):
 #   Column                         Dtype  
---  ------                         -----  
 0   user_id                        object 
 1   target_subreddit               object 
 2   target_subreddit_id            object 
 3   send                           int64  
 4   receive                        int64  
 5   click                          int64  
 6   geo_country_code_top           object 
 7   subscribed                     int64  
 8   tos_30_sub_count               int64  
 9   tos_30_pct                     float64
 10  screen_view_count_14d          int64  
 11  legacy_user_cohort_ord         int64  
 12  user_receives_pn_t7            float64
 13  user_clicks_pn_t7              float64
 14  user_clicks_trnd_t7            float64
 15  num_post_consumes              int64  
 16  num_post_consumes_home         int64  
 17  num_post_consumes_community    int64  
 18  

In [8]:
counts_describe(df_inference_raw[['user_id', 'target_subreddit',  'click', 'subscribed']])

Unnamed: 0,dtype,count,unique,unique-percent,null-count,null-percent
user_id,object,17486680,11170619,63.88%,0,0.00%
target_subreddit,object,17486680,3,0.00%,0,0.00%
click,int64,17486680,2,0.00%,0,0.00%
subscribed,int64,17486680,2,0.00%,0,0.00%


There seems to be something off (again) with 2 user_ids... there must be some bad join where subreddit_name is messed up with the wrong subreddit_id.

For now, drop these but I'll need to include subreddit_id as much as possible to prevent such errors later.

In [9]:
%%time
# save the mask so we can re-use it when dropping dupes
mask_dupes_by_user_and_target_sub_ = df_inference_raw.duplicated(subset=['user_id', 'target_subreddit'], keep=False)
(
    df_inference_raw[mask_dupes_by_user_and_target_sub_]
    .sort_values(by=['user_id', 'target_subreddit'])
)

CPU times: user 13.5 s, sys: 413 ms, total: 13.9 s
Wall time: 13.9 s


Unnamed: 0,user_id,target_subreddit,target_subreddit_id,send,receive,click,geo_country_code_top,subscribed,tos_30_sub_count,tos_30_pct,screen_view_count_14d,legacy_user_cohort_ord,user_receives_pn_t7,user_clicks_pn_t7,user_clicks_trnd_t7,num_post_consumes,num_post_consumes_home,num_post_consumes_community,num_post_consumes_post_detail,num_post_consumes_ios,num_post_consumes_android,num_post_consumes_nsfw,pct_post_consumes_nsfw,geo_country_code,view_and_consume_unique_count,consume_unique_count,view_count,consume_count,consume_ios_count,consume_android_count
864499,t2_6yxrzyy,de,t5_13x7do,0,0,0,DE,0,121,0.0,21,5,416.0,8.0,8.0,8806,1351,3198,0,8806,0,185,0.021008,DE,1,1,0,2,2,0
17036271,t2_6yxrzyy,de,t5_22i0,0,0,0,DE,1,121,0.20337,21,5,416.0,8.0,8.0,8806,1351,3198,0,8806,0,185,0.021008,DE,744,742,380,2216,2216,0


In [10]:
%%time

(
    df_inference_raw[df_inference_raw.duplicated(subset=['user_id', 'target_subreddit_id'], keep=False)]
    .sort_values(by=['user_id', 'target_subreddit'])
)

CPU times: user 13.7 s, sys: 306 ms, total: 14 s
Wall time: 14 s


Unnamed: 0,user_id,target_subreddit,target_subreddit_id,send,receive,click,geo_country_code_top,subscribed,tos_30_sub_count,tos_30_pct,screen_view_count_14d,legacy_user_cohort_ord,user_receives_pn_t7,user_clicks_pn_t7,user_clicks_trnd_t7,num_post_consumes,num_post_consumes_home,num_post_consumes_community,num_post_consumes_post_detail,num_post_consumes_ios,num_post_consumes_android,num_post_consumes_nsfw,pct_post_consumes_nsfw,geo_country_code,view_and_consume_unique_count,consume_unique_count,view_count,consume_count,consume_ios_count,consume_android_count


In [11]:
# %%time
# # it can take a minute to check dupes in 17 million rows & ~30 columns
# (
#     df_inference_raw[df_inference_raw.duplicated(keep=False)]
#     .sort_values(by=['user_id'])
# )

In [12]:
%%time
# remove duplicates
df_inference_raw = df_inference_raw[~mask_dupes_by_user_and_target_sub_]
df_inference_raw.shape

CPU times: user 15.3 s, sys: 1.51 s, total: 16.8 s
Wall time: 16.8 s


(17486678, 30)

In [13]:
counts_describe(df_inference_raw[['user_id', 'target_subreddit', 'click', 'subscribed']])

Unnamed: 0,dtype,count,unique,unique-percent,null-count,null-percent
user_id,object,17486678,11170619,63.88%,0,0.00%
target_subreddit,object,17486678,3,0.00%,0,0.00%
click,int64,17486678,2,0.00%,0,0.00%
subscribed,int64,17486678,2,0.00%,0,0.00%


In [14]:
df_inference_raw.head()

Unnamed: 0,user_id,target_subreddit,target_subreddit_id,send,receive,click,geo_country_code_top,subscribed,tos_30_sub_count,tos_30_pct,screen_view_count_14d,legacy_user_cohort_ord,user_receives_pn_t7,user_clicks_pn_t7,user_clicks_trnd_t7,num_post_consumes,num_post_consumes_home,num_post_consumes_community,num_post_consumes_post_detail,num_post_consumes_ios,num_post_consumes_android,num_post_consumes_nsfw,pct_post_consumes_nsfw,geo_country_code,view_and_consume_unique_count,consume_unique_count,view_count,consume_count,consume_ios_count,consume_android_count
0,t2_phc7zhvs,ich_iel,t5_37k29,0,0,0,ROW,0,67,0.0,21,5,,,,27151,0,8277,5186,0,0,560,0.020625,HU,10,10,0,12,0,0
1,t2_ufrg38qv,fragreddit,t5_2r6ca,0,0,0,DE,0,4,0.0,0,4,,,,133,0,0,22,0,0,2,0.015038,DE,5,5,0,5,0,0
2,t2_9vz13y92,ich_iel,t5_37k29,0,0,0,CA,1,72,0.00013,21,5,,,,5750,4314,1055,0,0,769,132,0.022957,CA,118,117,10,144,0,0
3,t2_u8tvho25,de,t5_22i0,0,0,0,GB,0,46,0.0,0,4,,,,962,0,200,0,962,0,1,0.00104,GB,3,3,0,6,6,0
4,t2_fdqjopee,de,t5_22i0,0,0,0,DE,0,391,0.0,20,5,,,,5042,178,740,0,5042,0,3626,0.719159,DE,9,9,0,11,11,0


# Transform & EDA

## Test logic to filter out user<>subreddit pairs that are not likely to get clicks

Logic to select user<>subreddit TARGETS:
- 2+ views (any subscription status)
- 2+ consumes (any subscription status)
- subscribed AND (1+ view OR consume)
- subscribed AND (3+ PN clicks in L7 days) 

TODO(djb): apply this logic in SQL once we know it's good so that we can save a lot of compute on getting training data & on inference.


In [15]:
%%time

df_inf = (
    df_inference_raw.copy()
)
df_inf.shape

CPU times: user 1.17 s, sys: 824 ms, total: 1.99 s
Wall time: 1.99 s


(17486678, 30)

In [16]:
# mask_subscribes_and_activity = ()

# info(f"{mask_subscribes_and_activity.sum():,.0f} <- Subscribers with some activity")
# # df_inf['target'] = 

## Some EDA

In [17]:
%%time

# counts_describe(df_inf)

CPU times: user 4 µs, sys: 1e+03 ns, total: 5 µs
Wall time: 9.06 µs


## Click (CTR)
Since we're looking only at user that RECEIVED the PN, this CTR should be the same as the one computed in the overall dashboard.

In [18]:
value_counts_and_pcts(
    df_inf[df_inf['receive'] == 1],
    ['click'],
    sort_index=True,
    top_n=None,
    pct_digits=3,
)

Unnamed: 0_level_0,count,percent,cumulative_percent
click,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1,9746,6.543%,6.543%
0,139204,93.457%,100.000%


## Geo + Subreddit counts

As expected, `DE` contributes a large # of users, but the US provides even more! RoW (Rest of world) also provides 9% of the sample with a longer tail of other countries.

In [19]:
%%time

value_counts_and_pcts(
    df_inf,
    ['geo_country_code_top',],
#     sort_index=True,
#     sort_index_ascending=False,
#     reset_index=True,
    top_n=None,
#     return_df=True
)

CPU times: user 1.65 s, sys: 284 ms, total: 1.93 s
Wall time: 1.93 s


Unnamed: 0_level_0,count,percent,cumulative_percent
geo_country_code_top,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
US,6420802,36.7%,36.7%
DE,6006565,34.3%,71.1%
ROW,1571790,9.0%,80.1%
MISSING,653018,3.7%,83.8%
CA,618069,3.5%,87.3%
GB,599161,3.4%,90.8%
AU,282175,1.6%,92.4%
AT,240252,1.4%,93.7%
NL,235947,1.3%,95.1%
CH,144896,0.8%,95.9%


In [20]:
%%time

for sub_ in sorted(df_inf['target_subreddit'].unique()):
    display(
        value_counts_and_pcts(
            df_inf[df_inf['target_subreddit'] == sub_],
            ['target_subreddit', 'geo_country_code_top',],
            top_n=None,
        #     sort_index=True,
        #     sort_index_ascending=True,
        #     reset_index=True,
        #     return_df=True
        )
    )

Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,geo_country_code_top,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
de,US,2783799,38.3%,38.3%
de,DE,2248853,30.9%,69.2%
de,ROW,697542,9.6%,78.8%
de,CA,282069,3.9%,82.6%
de,GB,273896,3.8%,86.4%
de,MISSING,259009,3.6%,90.0%
de,AU,124913,1.7%,91.7%
de,AT,109768,1.5%,93.2%
de,NL,103158,1.4%,94.6%
de,CH,66792,0.9%,95.5%


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,geo_country_code_top,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
fragreddit,DE,1734045,78.1%,78.1%
fragreddit,US,137501,6.2%,84.3%
fragreddit,ROW,95051,4.3%,88.6%
fragreddit,MISSING,65363,2.9%,91.5%
fragreddit,AT,57727,2.6%,94.1%
fragreddit,CH,24840,1.1%,95.2%
fragreddit,GB,23853,1.1%,96.3%
fragreddit,CA,18344,0.8%,97.1%
fragreddit,NL,13302,0.6%,97.7%
fragreddit,AU,9514,0.4%,98.2%


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,geo_country_code_top,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ich_iel,US,3499502,43.8%,43.8%
ich_iel,DE,2023667,25.3%,69.1%
ich_iel,ROW,779197,9.8%,78.9%
ich_iel,MISSING,328646,4.1%,83.0%
ich_iel,CA,317656,4.0%,87.0%
ich_iel,GB,301412,3.8%,90.7%
ich_iel,AU,147748,1.8%,92.6%
ich_iel,NL,119487,1.5%,94.1%
ich_iel,AT,72757,0.9%,95.0%
ich_iel,FR,63970,0.8%,95.8%


CPU times: user 10.6 s, sys: 1.59 s, total: 12.2 s
Wall time: 12.2 s


## Subscribed


In [21]:
%%time

l_cols_check_ = [
    'subscribed',
    'legacy_user_cohort_ord',
]

for c_ in l_cols_check_:
    display(
        value_counts_and_pcts(
            df_inf.fillna(-1),
            [c_],
            top_n=None,
        )
    )

Unnamed: 0_level_0,count,percent,cumulative_percent
subscribed,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,16848688,96.4%,96.4%
1,637990,3.6%,100.0%


Unnamed: 0_level_0,count,percent,cumulative_percent
legacy_user_cohort_ord,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
4,8943858,51.1%,51.1%
5,7620522,43.6%,94.7%
3,470244,2.7%,97.4%
1,296305,1.7%,99.1%
2,155749,0.9%,100.0%


CPU times: user 16.2 s, sys: 2.93 s, total: 19.1 s
Wall time: 19.1 s


In [22]:
%%time

for sub_ in sorted(df_inf['target_subreddit'].unique()):
    print(f"\n=== {sub_} ===")
    mask_sub_ = df_inf['target_subreddit'] == sub_
    
    for c_ in l_cols_check_:
        display(
            value_counts_and_pcts(
                df_inf[mask_sub_],
                ['target_subreddit', c_],
                top_n=None,
            )
        )


=== de ===


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,subscribed,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
de,0,7008489,96.3%,96.3%
de,1,267140,3.7%,100.0%


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,legacy_user_cohort_ord,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
de,4,3586198,49.3%,49.3%
de,5,3298732,45.3%,94.6%
de,3,191579,2.6%,97.3%
de,1,136057,1.9%,99.1%
de,2,63063,0.9%,100.0%



=== fragreddit ===


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,subscribed,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
fragreddit,0,2113974,95.2%,95.2%
fragreddit,1,106379,4.8%,100.0%


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,legacy_user_cohort_ord,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
fragreddit,4,1443433,65.0%,65.0%
fragreddit,5,611858,27.6%,92.6%
fragreddit,3,80087,3.6%,96.2%
fragreddit,1,48611,2.2%,98.4%
fragreddit,2,36364,1.6%,100.0%



=== ich_iel ===


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,subscribed,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ich_iel,0,7726225,96.7%,96.7%
ich_iel,1,264471,3.3%,100.0%


Unnamed: 0_level_0,Unnamed: 1_level_0,count,percent,cumulative_percent
target_subreddit,legacy_user_cohort_ord,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ich_iel,4,3914227,49.0%,49.0%
ich_iel,5,3709932,46.4%,95.4%
ich_iel,3,198578,2.5%,97.9%
ich_iel,1,111637,1.4%,99.3%
ich_iel,2,56322,0.7%,100.0%


CPU times: user 15.3 s, sys: 2.36 s, total: 17.7 s
Wall time: 17.6 s


# Reshape data for inference

In [23]:
%%time

# Stop using categorical columns, instead do the numeric processing in SQL to make sure
#  that the same ordinal or numeric encoding is applied upfront and on all train/test & inference values


# ordinal & boolean columns should NOT be re-scaled
l_train_cols_num_no_scale = [
    'legacy_user_cohort_ord',
]

# These numeric columns are candidates to be rescaled under some pipelines
l_train_cols_num = [
    # user-only columns
    'tos_30_sub_count',
    'user_receives_pn_t7',
    'user_clicks_pn_t7',
    'user_clicks_trnd_t7',
    # 'subscribed',  # Subscribed wasn't used in the previous model, but should be used going forward
    
    'screen_view_count_14d',
    'num_post_consumes',
    'num_post_consumes_home',
    'num_post_consumes_community',
    'num_post_consumes_post_detail',
    'num_post_consumes_ios',
    'num_post_consumes_android',
    'num_post_consumes_nsfw',
    'pct_post_consumes_nsfw',
    
    # user<> subreddit cols
    'view_and_consume_unique_count',
    'consume_unique_count',
    'view_count',
    'consume_count',
    'consume_ios_count',
    'consume_android_count',
    'tos_30_pct',  # use 30_pct instead of cosine distance/similarity  
]
# For some features we want to flag nulls as negative to distinguish data missing
#  instead of filling as zeros
l_col_fill_neg = [
    'tos_30_sub_count',
    'user_receives_pn_t7',
    'user_clicks_pn_t7',
    'user_clicks_trnd_t7',
    
    'screen_view_count_14d',
    'num_post_consumes',
    'num_post_consumes_home',
    'num_post_consumes_community',
    'num_post_consumes_post_detail',
    'num_post_consumes_ios',
    'num_post_consumes_android',
    'num_post_consumes_nsfw',
    'pct_post_consumes_nsfw',
]
d_fillna_ = {c: -1 for c in l_col_fill_neg}
for c_ in l_train_cols_num:
    if c_ not in l_col_fill_neg:
        d_fillna_[c_] = 0

X = (
    # simple strategy for nulls
    # use column transformer to handle category cols as part of pipeline
    df_inf[l_train_cols_num_no_scale + l_train_cols_num].fillna(d_fillna_).copy()
)

print(X.shape)

(17486678, 21)
CPU times: user 1.86 s, sys: 1.7 s, total: 3.56 s
Wall time: 3.56 s


# Load pre-trained model

In [None]:
TODO

In [24]:
%%time

model = joblib.load(
    "/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-01_160154/model-xgboost_pre-train_202866_21.gz"
)

model

CPU times: user 3.91 s, sys: 278 ms, total: 4.19 s
Wall time: 1.3 s


GridSearchCV(cv=StratifiedKFold(n_splits=4, random_state=42, shuffle=True),
             estimator=Pipeline(steps=[('preprocess',
                                        ColumnTransformer(remainder='passthrough',
                                                          transformers=[('scale',
                                                                         StandardScaler(),
                                                                         ['tos_30_sub_count',
                                                                          'user_receives_pn_t7',
                                                                          'user_clicks_pn_t7',
                                                                          'user_clicks_trnd_t7',
                                                                          'screen_view_count_14d',
                                                                          'num_post_consumes',
                                     

# Run inference

In [None]:
TODO

In [27]:
[_ for _ in df_inf.columns if 'geo' in _]

['geo_country_code_top', 'geo_country_code']

In [28]:
model.predict_proba(X.head())[:,1]

array([0.00789476, 0.30935174, 0.6716812 , 0.21024369, 0.21130966],
      dtype=float32)

In [38]:
%%time

# set cols we'll need to use for caching
c_pred_proba = 'click_proba'

l_ix_cols = [
    'target_subreddit',
    'target_subreddit_id',
    'geo_country_code_top',  # We'll maybe need to rename it in final table
    'user_id',
]

info(f"Create new df for predictions...")
df_pred = df_inf[l_ix_cols].copy()

info(f"Get click predictions...")
df_pred[c_pred_proba] = model.predict_proba(X)[:,1]

15:13:17 | INFO | "Create new df for predictions..."
15:13:19 | INFO | "Get click predictions..."


CPU times: user 5min 8s, sys: 2.65 s, total: 5min 11s
Wall time: 10.7 s


In [48]:
l_ix_cache = [
    'target_subreddit',
    'target_subreddit_id',
    'geo_country_code_top',
]
c_user_rank = 'user_rank_by_sub_and_geo'

In [71]:
%%time

# for the cache we only keep the top 500k per subreddit+country
# create new column for ranking so that it's easier to filter down the road
df_pred[c_user_rank] = (
    df_pred
    .groupby(l_ix_cache)
    [c_pred_proba]
    .rank(method='dense', ascending=False)
    .astype(int)
)

CPU times: user 5.29 s, sys: 553 ms, total: 5.84 s
Wall time: 5.84 s


In [63]:
%%time
# Sort the df so that the final file is in the right order
df_pred = (
    df_pred
    .sort_values(by=l_ix_cache + [c_user_rank], ascending=True)
)


CPU times: user 17.2 s, sys: 1.34 s, total: 18.5 s
Wall time: 18.5 s


In [72]:
df_pred.shape

(17486678, 6)

In [73]:
df_pred.head()

Unnamed: 0,target_subreddit,target_subreddit_id,geo_country_code_top,user_id,click_proba,user_rank_by_sub_and_geo
2087226,de,t5_22i0,AR,t2_d6qpz9xs,0.981573,1
16406019,de,t5_22i0,AR,t2_85vfhpms,0.978377,2
273547,de,t5_22i0,AR,t2_5zhd4aq9,0.976157,3
4370746,de,t5_22i0,AR,t2_480ucypc,0.974954,4
11948584,de,t5_22i0,AR,t2_rlxxb,0.973342,5


# Save raw inference

In [74]:
%%time

# for prod: might need to use dask or polars to save to multiple files
r_, c_ = df_pred.shape
df_pred.to_parquet(
    path_this_model / f"df_pred-{r_}_{c_}.parquet"
    , index=False
)
del r_, c_

CPU times: user 7.66 s, sys: 596 ms, total: 8.25 s
Wall time: 10.3 s


In [75]:
!du -Lsh $path_this_model/* | sort -hr 

277M	/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658/df_pred-17486678_6.parquet
248M	/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658/df_pred-17486678_5.parquet


# Select cache & Reshape for BigQuery
We won't save all predictions -- keep those above threshold that gets us over 90% recall (~0.300 in latest XGBoost-pre model)

In [None]:
TODO

In [81]:
style_df_numeric(
    df_pred[df_pred[c_user_rank] <= 4]
    .head(15)
    ,
    float_round=4,
)

Unnamed: 0,target_subreddit,target_subreddit_id,geo_country_code_top,user_id,click_proba,user_rank_by_sub_and_geo
2087226,de,t5_22i0,AR,t2_d6qpz9xs,0.9816,1
16406019,de,t5_22i0,AR,t2_85vfhpms,0.9784,2
273547,de,t5_22i0,AR,t2_5zhd4aq9,0.9762,3
4370746,de,t5_22i0,AR,t2_480ucypc,0.975,4
2588090,de,t5_22i0,AT,t2_zfcszu6,0.9873,1
13469891,de,t5_22i0,AT,t2_16jzrqba,0.9844,2
12307300,de,t5_22i0,AT,t2_rvlq3,0.9841,3
3258035,de,t5_22i0,AT,t2_aomm5arj,0.9838,4
12691808,de,t5_22i0,AU,t2_3tkby4wi,0.9853,1
4526865,de,t5_22i0,AU,t2_14tz0q,0.9814,2


## Project how many users per sub+country with some thresholds
This is helpful to know how many we'll have in the final cache table before writing to it

In [107]:
%%time

mask_pred_above_0250 = (df_pred[c_pred_proba] >= 0.250)
mask_pred_above_0200 = (df_pred[c_pred_proba] >= 0.200)
mask_pred_above_0150 = (df_pred[c_pred_proba] >= 0.150)
mask_pred_above_0100 = (df_pred[c_pred_proba] >= 0.100)

mask_rank_below_500k = (df_pred[c_user_rank] <= 500000)

print(f"{mask_pred_above_0100.sum():,.0f} Users above 0.100")
print(f"{mask_pred_above_0150.sum():,.0f} Users above 0.150")
print(f"{mask_pred_above_0200.sum():,.0f} Users above 0.200")
print(f"{mask_pred_above_0250.sum():,.0f} Users above 0.250")

print(f"{mask_rank_below_500k.sum():,.0f} Users below 500k rank")

11,849,601 Users above 0.100
10,763,902 Users above 0.150
9,807,071 Users above 0.200
8,753,948 Users above 0.250
8,383,463 Users below 500k rank
CPU times: user 136 ms, sys: 3.79 ms, total: 140 ms
Wall time: 138 ms


In [110]:
display(
    style_df_numeric(
        df_pred[c_pred_proba]
        .describe(percentiles=[0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]).to_frame().T,
        float_round=4,
        l_bar_simple=['mean', '50%'],
    ).set_caption(f"All users")
)

display(
    style_df_numeric(
        df_pred
        [mask_rank_below_500k]
        [c_pred_proba]
        .describe(percentiles=[0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]).to_frame().T,
        float_round=4,
        l_bar_simple=['mean', '50%']
    ).set_caption(f"Users below 500k rank")
)

display(
    style_df_numeric(
        df_pred
        [mask_pred_above_0100 & mask_rank_below_500k]
        [c_pred_proba]
        .describe(percentiles=[0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]).to_frame().T,
        float_round=4,
        l_bar_simple=['mean', '50%']
    ).set_caption(f"Users below 500k rank AND above 0.100 Threshold")
)

Unnamed: 0,count,mean,std,min,10%,20%,25%,30%,40%,50%,60%,70%,75%,80%,90%,95%,99%,max
click_proba,17486678,0.2807,0.2276,0.0,0.0242,0.0468,0.0617,0.0857,0.1639,0.2505,0.3294,0.403,0.4454,0.4906,0.6011,0.6797,0.8826,0.9975


Unnamed: 0,count,mean,std,min,10%,20%,25%,30%,40%,50%,60%,70%,75%,80%,90%,95%,99%,max
click_proba,8383463,0.4011,0.2365,0.0,0.0468,0.1391,0.2162,0.2828,0.3612,0.4107,0.4774,0.5486,0.5767,0.6063,0.6852,0.795,0.924,0.9975


Unnamed: 0,count,mean,std,min,10%,20%,25%,30%,40%,50%,60%,70%,75%,80%,90%,95%,99%,max
click_proba,6938439,0.4755,0.188,0.1,0.2238,0.3203,0.3488,0.3649,0.4131,0.4664,0.5268,0.5777,0.6021,0.6276,0.7133,0.8192,0.9316,0.9975


In [116]:
display(
    style_df_numeric(
        df_pred
        .groupby(['target_subreddit'])
        [c_pred_proba]
        .describe(
            # percentiles=[0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]
        ),
        float_round=4,
        l_bar_simple=['mean', '50%']
    ).set_caption(f"All users")
)

display(
    style_df_numeric(
        df_pred
        [mask_pred_above_0100 & mask_rank_below_500k]
        .groupby(['target_subreddit'])
        [c_pred_proba]
        .describe(
            # percentiles=[0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]
        ),
        float_round=4,
        l_bar_simple=['mean', '50%']
    ).set_caption(f"Users below 500k rank AND above 0.100 Threshold")
)

Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
target_subreddit,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
de,7275629,0.2777,0.2277,0.0,0.0588,0.2464,0.4417,0.9975
fragreddit,2220353,0.2607,0.2186,0.0,0.056,0.2329,0.4045,0.997
ich_iel,7990696,0.289,0.2295,0.0,0.0662,0.2607,0.4609,0.9969


Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
target_subreddit,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
de,2881638,0.4775,0.1871,0.1,0.3542,0.4786,0.6013,0.9975
fragreddit,1043029,0.4463,0.1685,0.1,0.3367,0.4158,0.5378,0.997
ich_iel,3013772,0.4837,0.1941,0.1,0.35,0.4839,0.6158,0.9969


In [117]:
display(
    style_df_numeric(
        df_pred
        .groupby(['geo_country_code_top'])
        [c_pred_proba]
        .describe(
            # percentiles=[0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]
        )
        .sort_values(by=['50%'], ascending=False)
        ,
        float_round=4,
        l_bar_simple=['mean', '50%']
    ).set_caption(f"All users")
)

display(
    style_df_numeric(
        df_pred
        [mask_pred_above_0100 & mask_rank_below_500k]
        .groupby(['geo_country_code_top'])
        [c_pred_proba]
        .describe(
            # percentiles=[0.1, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 0.95, 0.99]
        )
        .sort_values(by=['50%'], ascending=False)
        ,
        float_round=4,
        l_bar_simple=['mean', '50%']
    ).set_caption(f"Users below 500k rank AND above 0.100 Threshold")
)

Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
geo_country_code_top,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
US,6420802,0.2992,0.2313,0.0,0.0753,0.2706,0.4837,0.9964
MISSING,653018,0.2667,0.1476,0.0003,0.1526,0.2676,0.37,0.9247
IN,131692,0.3027,0.2552,0.0,0.0526,0.2671,0.4927,0.9948
MX,58348,0.2959,0.2414,0.0001,0.0569,0.2638,0.4851,0.9838
GB,599161,0.2969,0.2525,0.0,0.0492,0.2583,0.5007,0.9949
PH,42307,0.2983,0.259,0.0002,0.0496,0.2551,0.4998,0.9845
PT,38341,0.2989,0.2641,0.0001,0.0476,0.2484,0.5139,0.9858
IT,58498,0.2945,0.2544,0.0001,0.0509,0.248,0.49,0.9902
ROW,1571790,0.2842,0.2423,0.0,0.0507,0.2464,0.4638,0.9933
BR,104395,0.2911,0.248,0.0001,0.0549,0.2459,0.4752,0.9889


Unnamed: 0_level_0,count,mean,std,min,25%,50%,75%,max
geo_country_code_top,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
US,1302789,0.6333,0.1264,0.1,0.5703,0.6189,0.686,0.9964
DE,2272361,0.4895,0.1441,0.2914,0.3808,0.4515,0.5571,0.9975
PT,23799,0.4576,0.2137,0.1,0.2875,0.4452,0.6103,0.9858
GB,383955,0.4414,0.2027,0.1,0.2784,0.4355,0.584,0.9949
PH,26629,0.4498,0.2105,0.1,0.284,0.4346,0.595,0.9845
AU,175534,0.4333,0.1978,0.1,0.2704,0.4334,0.5775,0.9853
IE,38211,0.4363,0.2084,0.1,0.2657,0.422,0.5814,0.9919
CA,399919,0.4234,0.1978,0.1,0.2604,0.4182,0.5638,0.9902
IT,37518,0.4371,0.2097,0.1,0.2671,0.4172,0.5846,0.9902
IN,86455,0.4399,0.2102,0.1,0.2735,0.4165,0.5825,0.9948


## Create metadata 
This metadata will be used for all items in this batch. Set it in a dict and save it only to the final data because it can add a ton of overhead to save it for all rows of inference

In [80]:
%%time

d_topk_meta = {
    'pt': '2022-12-01',
    'model_name': 'PN click subreddit-user',
    'model_version': 'v0.1 2023-05-02',
}

# Only check if we've already written it to the file
# info(f"Checking keys for ndjson...")
# for k in tqdm(d_topk_meta.keys()):
#     assert 1 == df_pred[k].nunique()
#     d_topk_meta[k] = df_pred[k].values[0]
#     print(f"  {k}: {d_topk_meta[k]}")

CPU times: user 13 µs, sys: 1e+03 ns, total: 14 µs
Wall time: 27.4 µs


In [82]:
df_pred.columns

Index(['target_subreddit', 'target_subreddit_id', 'geo_country_code_top', 'user_id', 'click_proba', 'user_rank_by_sub_and_geo'], dtype='object')

In [79]:
%%time

n_unique_subs = df_inf['target_subreddit_id'].nunique()

CPU times: user 1.67 s, sys: 83.6 ms, total: 1.75 s
Wall time: 1.75 s


## Save as ndJSON

When we apply cache filters, it makes sense to do it upfront (on the whole df) instead of per each subreddit+country

```bash
# time doing filters per group
CPU times: user 1min 2s, sys: 5.5 s, total: 1min 8s
Wall time: 1min 8s
    
# time applying mask upfront - we save ~16 seconds (which will add up when we run on 50k subreddits)
CPU times: user 48.4 s, sys: 3.71 s, total: 52.2 s
Wall time: 52.1 s
```

In [120]:
%%time

# Set limits for what to save in BQ
n_user_limit_per_sub_and_geo = 500000  # initial goal: 500k
threshold_min = 0.100  # We get 90% recall around 0.3, but we can set it lower for other countries

# apply the mask before the groupby to speed up the whole process
mask_cache_ = (
    (df_pred[c_user_rank] <= n_user_limit_per_sub_and_geo) &
    (df_pred[c_pred_proba] >= threshold_min)
)
info(f"{mask_cache_.sum():,.0f} <- Rows to process")

verbose_ = False  # set to True to display shapes or other info to help debug

# These are the cols to nest for top_users
cols_for_nested_users = [
    'user_id', 
    'click_proba', 
    'user_rank_by_sub_and_geo',
]
# Rename dict, if needed
# d_rename_for_nested = {
#     f"{prefix_similar_sub}_subreddit_id": "subreddit_id",
#     f"{prefix_similar_sub}_subreddit_name": "subreddit_name",
# }


# Create local paths & file
info(f"Creating paths for file...")
p_local_json = path_this_model / f"click_proba_ndjson-{n_unique_subs}-{manual_model_timestamp}"
Path.mkdir(p_local_json, exist_ok=True, parents=True)
subfolder_json = p_local_json.name

f_local_json_name = f"click_proba_ndjson-{n_unique_subs}_subreddits.json"
f_local_json_full = p_local_json / f_local_json_name

# If we run this multiple times, make sure we don't append duplicated lines
try:
    info(f"  Deleting existing file...")
    f_local_json_full.unlink()
except FileNotFoundError as e:
    info(f"  NVM, file does not exist yet...\n {e}")

info(f"Start saving df as ndJSON...")
with open(f_local_json_full, 'w') as f:
    for l_ix_vals_, df_seed_ in tqdm(df_pred[mask_cache_].groupby(l_ix_cache), mininterval=2):
        # NOTE: Assumes we already applied rank & threshold limits to df_seed_!
        d_seed = {
            **d_topk_meta,
            **{k: v for k, v in zip(l_ix_cache, l_ix_vals_)},
            **{
                # each USER should be its own dict
                'top_users': (
                    df_seed_[cols_for_nested_users]
                    .to_dict(orient='records')
                )
            }
        }
        if verbose_:
            info(f"{df_seed_[cols_for_nested_users].shape} <- Output shape for {l_ix_vals_}")
        f.write(json.dumps(d_seed) + "\n")


info(f"Done saving as ndJSON")
print(f"Example subreddit:")
for k, v in d_seed.items():
    if isinstance(v, list):
        print(f"{k}:")
        for _ in v[:5]:
            print(f"    {_}")
    else:
        print(f"{k}:  {v}")

17:36:50 | INFO | "6,938,439 <- Rows to process"
17:36:50 | INFO | "Creating paths for file..."
17:36:50 | INFO | "  Deleting existing file..."
17:36:50 | INFO | "Start saving df as ndJSON..."


  0%|          | 0/60 [00:00<?, ?it/s]

17:37:41 | INFO | "Done saving as ndJSON"


Example subreddit:
pt:  2022-12-01
model_name:  PN click subreddit-user
model_version:  v0.1 2023-05-02
target_subreddit:  ich_iel
target_subreddit_id:  t5_37k29
geo_country_code_top:  US
top_users:
    {'user_id': 't2_55rwha3h', 'click_proba': 0.9963794350624084, 'user_rank_by_sub_and_geo': 1}
    {'user_id': 't2_99knt8yi', 'click_proba': 0.9922695159912109, 'user_rank_by_sub_and_geo': 2}
    {'user_id': 't2_4i9weg61', 'click_proba': 0.9921581745147705, 'user_rank_by_sub_and_geo': 3}
    {'user_id': 't2_o26uh', 'click_proba': 0.9918698072433472, 'user_rank_by_sub_and_geo': 4}
    {'user_id': 't2_axh9ys0g', 'click_proba': 0.9917759895324707, 'user_rank_by_sub_and_geo': 5}
CPU times: user 48.4 s, sys: 3.71 s, total: 52.2 s
Wall time: 52.1 s


In [121]:
!du -Lsh $path_this_model/* | sort -hr 

652M	/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658
277M	/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658/df_pred-17486678_6.parquet
248M	/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658/df_pred-17486678_5.parquet


## Upload to GCS
BigQuery expects the data in GCS 

In [126]:
f_local_json_full

PosixPath('/home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658/click_proba_ndjson-3_subreddits.json')

In [125]:
BREAK
remote_gs_path_json = (
    "gs://i18n-subreddit-clustering/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658/"
)

In [127]:
!gsutil -m cp -r -n $f_local_json_full $remote_gs_path_json

Copying file:///home/jupyter/subreddit_clustering_i18n/data/models/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658/click_proba_ndjson-3_subreddits.json [Content-Type=application/json]...
==> NOTE: You are uploading one or more large file(s), which would run          
significantly faster if you enable parallel composite uploads. This
feature can be enabled by editing the
"parallel_composite_upload_threshold" value in your .boto
configuration file. However, note that if you do this large files will
be uploaded as `composite objects
<https://cloud.google.com/storage/docs/composite-objects>`_,which
means that any user who downloads such objects will need to have a
compiled crcmod installed (see "gsutil help crcmod"). This is because
without a compiled crcmod, computing checksums on composite objects is
so slow that gsutil disables downloads of composite objects.

/ [1/1 files][651.1 MiB/651.1 MiB] 100% Done                                    
Operation com

In [128]:
!gsutil ls $remote_gs_path_json

gs://i18n-subreddit-clustering/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658/click_proba_ndjson-3_subreddits.json


In [131]:
remote_gs_file_json = (
    "gs://i18n-subreddit-clustering/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658/click_proba_ndjson-3_subreddits.json"
)

## Upload to BigQuery

In [130]:
from subclu.pn_models.bq_schemas import pn_model_schema
from subclu.utils.big_query_utils import load_data_to_bq_table

In [None]:
BREAK

In [135]:
%%time

info(f"Updating table from file:\n{remote_gs_file_json}")

load_data_to_bq_table(
    uri=remote_gs_file_json,
    bq_project='reddit-employee-datasets',
    bq_dataset='david_bermejo',
    bq_table_name='pn_model_test',
    schema=pn_model_schema(),
    partition_column='pt',
    table_description=(
        "Cache the users from a country that are most likely to click on a PN from a target subreddit"
    ),
    update_table_description=True,
)

20:39:00 | INFO | "Updating table from file:
gs://i18n-subreddit-clustering/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658/click_proba_ndjson-3_subreddits.json"
20:39:02 | INFO | "Loading this URI:
  gs://i18n-subreddit-clustering/pn_model/pn_manual_test_2023-05-02_145658/click_proba_ndjson-3-2023-05-02_145658/click_proba_ndjson-3_subreddits.json
Into this table:
  reddit-employee-datasets.david_bermejo.pn_model_test"
20:39:02 | INFO | "Created table reddit-employee-datasets.david_bermejo.pn_model_test"
20:39:03 | INFO | "  0 rows in table BEFORE adding data"
20:39:33 | INFO | "Original Table Expiration: 2023-08-10 20:39:02.503000+00:00"
20:39:33 | INFO | "NEW Table Expiration: None"
20:39:33 | INFO | "Updating subreddit description from:
  Cache the users from a country that are most likely to click on a PN from a target subreddit
to:
  Cache the users from a country that are most likely to click on a PN from a target subreddit"
20:39:33 | INFO | "  6

CPU times: user 64.3 ms, sys: 140 ms, total: 204 ms
Wall time: 33.5 s


# Appendix / EDA
Check whether the filtering logic I had in mind makes sense instead of scoring ALL the users...?

But maybe scoring all the users is fast/cheap enough that it's fine to score a few million users to get the top 500k per country?

---

Cuts to check for proba for ALL users
- by subscribe v. not subscribe
- by country (all)
- by legacy cohort
- by subreddit + 
    - country
    - cohort
    - by subscribe

Cuts to check for proba for only top 500k users per country (the ones that will be cached).
<br>Note that some countries won't have that many users per subreddit because not enough people visit them

In [None]:
TODO