In [1]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd

device = torch.device('cpu')

In [2]:
# List of unique models. Both each label_key and each model_version must be unique
model_list = [{
        'model_type': x[0],
        'label_key': x[1],
        'model_version': x[2],
        'train_state': x[3],
        'model': torch.jit.load(f'./saves/{x[2]}/{x[3]}.pt', map_location = device)
    }
    for x in [
        ('financial_health', 'financial_sentiment', 'financial_health_v1-financial_sentiment-20231112', 'epoch_002_step_00000'),
        ('financial_health', 'employment_status', 'financial_health_v1-employment_status-20231112', 'epoch_002_step_00000')
    ]
]

In [3]:
# Test inference on models
from helpers.loaders import TextDataset
from torch.utils.data import DataLoader
from transformers import AlbertTokenizer

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')

ds = TextDataset(tokenizer, [
    'The job market sucks! Unemployed and broke.', 
    'Just got a big pay raise! How do start saving for retirement?',
    'I like my job, but I am underpaid. What is the best way to negotiate for a higher salary?', 
    '$1 million in credit card debt due to vaseline addiction. Should I declare bankruptcy?',
    'I just inherited $100k from my great-great grandmother! What to do with the money?',
    'Recommendations for pet insurance? Dog smells.',
    'Laid off at work, what do?'
])
dl = DataLoader(ds, batch_size = 1, shuffle = False)

for model_obj in model_list:
    print(model_obj['label_key'])
    for b in dl:
        with torch.no_grad():
            logits = model_obj['model'](b['input_ids'].to(device), b['attention_mask'].to(device))['logits'].cpu()
            probs = F.softmax(logits, dim = 1)
            result = torch.argmax(probs, dim = 1).numpy()[0]
            print(probs.numpy().flatten()[result], result)

financial_sentiment
0.9192377 0
0.80933076 1
0.5541288 0
0.8097795 0
0.7933027 1
0.7505795 2
0.7638647 0
employment_status
0.9055272 0
0.5437704 2
0.7681394 1
0.7446556 2
0.81502765 2
0.84382254 2
0.9220889 0


In [4]:
# Get list of datasets to be scored under each model
from helpers.db import get_postgres_query

required_datasets = [
    f"SELECT '{m['model_type']}' AS model_type, '{m['label_key']}' AS label_key, '{m['model_version']}' AS model_version" 
    for m in model_list
]

raw_data = get_postgres_query(
    """
    /* Pulls the first instance of every post (by post id), even if scraped multiple times.
     * Then, cross joins on all desired combinations of model_type x model_versions x label_keys given above.
     * Finally, removes any posts where the existing model_type/label_key/model_version combinations already exist.
     * 
     * Notes on text_scraper_reddit_classifier_scores:
     *  - train_state is merely a descriptive column, and has no uniqueness properties.
     *  - The unique index is model_type x model_version x post_id x label_key
     */
    WITH desired_combinations_0 AS (
        SELECT
            scrape_id, post_id,
            CONCAT(TRIM(title), '\n', REGEXP_REPLACE(TRIM(selftext), '[\t\n\r]', ' ', 'g')) AS input_text,
            ARRAY_LENGTH(REGEXP_SPLIT_TO_ARRAY(trim(selftext), E'\\W+'), 1) * 1.33 AS n_tokens,
            ROW_NUMBER() OVER (PARTITION BY post_id ORDER BY created_dttm) AS rn
        FROM text_scraper_reddit_scrapes 
        WHERE
            selftext IS NOT NULL
            AND source_board IN ('jobs', 'careerguidance', 'personalfinance')
            AND (
                (DATE(created_dttm) BETWEEN '2018-01-01' AND '2022-12-31' AND scrape_method = 'pushshift_backfill')
                OR (DATE(created_dttm) BETWEEN '2023-01-01' AND '2023-08-20' AND scrape_method = 'pullpush_backfill')
                OR DATE(created_dttm) > '2023-08-20'
            )
    ), desired_combinations AS (
    	SELECT scrape_id, post_id, input_text, n_tokens
    	FROM desired_combinations_0
    	WHERE rn = 1 AND n_tokens >= 5 AND n_tokens <= 1024
    )
    SELECT a.*, b.*
    FROM desired_combinations a
    CROSS JOIN ({q}) b
    -- Anti join on any model type/version/post combinations that already exist
    LEFT JOIN text_scraper_reddit_classifier_scores c
    	ON b.model_type = c.model_type AND b.model_version = c.model_version AND b.model_type = c.model_type AND a.post_id = c.post_id
    WHERE c.score_id IS NULL
    LIMIT 2000
    """.format(q = '\nUNION ALL '.join(required_datasets))
)


datasets = [
    raw_data\
        .pipe(lambda df: df[(df['model_type'] == m['model_type']) & (df['label_key'] == m['label_key']) & (df['model_version'] == m['model_version'])])\
        .reset_index(drop = True)
    for m in model_list
]

for d in datasets:
    display(d)

Unnamed: 0,scrape_id,post_id,input_text,n_tokens,model_type,label_key,model_version
0,1679796,7nf2fp,Weekday Help Thread for the week of January 01...,6.65,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
1,1679807,7ng6hg,Paying off my father's credit\nMy father has a...,5.32,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
2,1679818,7nh538,Recently married couple in early 30’s trying t...,5.32,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
3,1679823,7nhrsp,"Warning: AT&amp;T applying ""customer loyalty s...",6.65,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
4,1679827,7ni34j,Newly married couple with awful financials. I ...,6.65,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
...,...,...,...,...,...,...,...
995,1351656,8vo1u3,"Yet Another Robo-Horse-race update, comparing ...",10.64,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
996,1351659,8vop1y,Rental Car Hit &amp; Run - Hit with Admin Fees...,6.65,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
997,1447384,8vpv8z,I really need job help.\nI'm a 25 year old col...,5.32,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112
998,1351671,8vse9q,Canadian Relocating to the US\nI'm a 25 year o...,5.32,financial_health,financial_sentiment,financial_health_v1-financial_sentiment-20231112


Unnamed: 0,scrape_id,post_id,input_text,n_tokens,model_type,label_key,model_version
0,1679796,7nf2fp,Weekday Help Thread for the week of January 01...,6.65,financial_health,employment_status,financial_health_v1-employment_status-20231112
1,1679807,7ng6hg,Paying off my father's credit\nMy father has a...,5.32,financial_health,employment_status,financial_health_v1-employment_status-20231112
2,1679818,7nh538,Recently married couple in early 30’s trying t...,5.32,financial_health,employment_status,financial_health_v1-employment_status-20231112
3,1679823,7nhrsp,"Warning: AT&amp;T applying ""customer loyalty s...",6.65,financial_health,employment_status,financial_health_v1-employment_status-20231112
4,1679827,7ni34j,Newly married couple with awful financials. I ...,6.65,financial_health,employment_status,financial_health_v1-employment_status-20231112
...,...,...,...,...,...,...,...
995,1351656,8vo1u3,"Yet Another Robo-Horse-race update, comparing ...",10.64,financial_health,employment_status,financial_health_v1-employment_status-20231112
996,1351659,8vop1y,Rental Car Hit &amp; Run - Hit with Admin Fees...,6.65,financial_health,employment_status,financial_health_v1-employment_status-20231112
997,1447384,8vpv8z,I really need job help.\nI'm a 25 year old col...,5.32,financial_health,employment_status,financial_health_v1-employment_status-20231112
998,1351671,8vse9q,Canadian Relocating to the US\nI'm a 25 year o...,5.32,financial_health,employment_status,financial_health_v1-employment_status-20231112


In [None]:
@torch.no_grad()
def run_inference(ts_model, input_texts: list, device, batch_size = 16):
    """
    Run model inference
    
    Params:
        @model_obj: The torchscript model object
        @input_texts: A list with input texts to run
    """
    ts_model.eval()
    ds = TextDataset(tokenizer, input_texts)
    dl = DataLoader(ds, batch_size = batch_size, shuffle = False)

    values_list = []
    probs_list = []
    
    for i, b in tqdm(enumerate(dl), total = len(dl)):
        logits = ts_model(b['input_ids'].to(device), b['attention_mask'].to(device))['logits']
        probs = F.softmax(logits, dim = 1)
        argmax = torch.argmax(probs, dim = 1)
        argmax_probs = probs.gather(1, index = argmax.view(-1, 1)).squeeze()
        
        values_list.extend(argmax.numpy().tolist())
        probs_list.extend(argmax_probs.numpy().tolist())
        
    return (values_list, probs_list)

results_list = []
for model_obj, dataset in zip(model_list, datasets):
    values_list, probs_list = run_inference(model_obj['model'], dataset['input_text'].tolist(), device)
    model_df = pd.DataFrame({
        'model_type': model_obj['model_type'],
        'model_version': model_obj['model_version'],
        'train_state': model_obj['train_state'],
        'scrape_id': dataset['scrape_id'].tolist(),
        'post_id': dataset['post_id'].tolist(),
        'label_key': model_obj['label_key'],
        'label_value': values_list,
        'label_prob': probs_list
    })
    results_list.append(model_df)

output_df = pd.concat(results_list)

  2%|█▉                                                                                                                         | 1/63 [00:04<05:09,  4.99s/it]

In [None]:
from helpers.db import write_postgres_df

def split_df(df, chunk_size = 200):
   chunks = []
   num_chunks = len(df) // chunk_size + 1
   for i in range(num_chunks):
       chunks.append(df[i * chunk_size:(i + 1) * chunk_size])
   return chunks
    
for df in tqdm(split_df(output_df)):
    write_postgres_df(
        df, 
        'text_scraper_reddit_classifier_scores', 
        """
        ON CONFLICT (model_type, model_version, post_id, label_key) 
        DO UPDATE SET
            train_state=EXCLUDED.train_state,
            scrape_id=EXCLUDED.scrape_id,
            label_value=EXCLUDED.label_value,
            label_prob=EXCLUDED.label_prob
        """
    )