# Running embeddings inference

This notebook produces `sBERT` embeddings and sector and instrument classifications (based on cosine distance) for text in a dataset. It saves these predictions in batches and then loads them to s3.

In [5]:
!pip install sentence-transformers torch tqdm boto3 fsspec s3fs



In [41]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from typing import List, Callable
import math
import pickle
import gzip

from sentence_transformers import SentenceTransformer
from sentence_transformers import util as sbert_utils
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import torch
from tqdm.notebook import tqdm
import boto3

from weak_sentence_classification.utils import Schema, CosineDistanceClassifier

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
MODEL_CONFIG = [
    {
        "model_name": "multi-qa-MiniLM-L6-cos-v1",
        "distance_measure": "cosine",
        "predict_sectors_instruments": True,
    },
    {
        "model_name": "msmarco-distilbert-dot-v5",
        "distance_measure": "dot_product",
        "predict_sectors_instruments": False,
    },
]

S3_BUCKET_URL = 's3://cpr-policy-bucket/'
DATASET_FILENAME = 'policy_dataset.csv.gz'

LOCAL_OUTPUT_PATH = Path('../data')

## 1. Load data and set up models

In [13]:
# s3_dataset_path = S3_BUCKET_URL + DATASET_FILENAME
s3_dataset_path = "../../s3-buckets/cpr-datasets/policy_dataset.csv.gz"
df = pd.read_csv(s3_dataset_path)

# Add unique id for each text entry
df['text_id'] = range(0, len(df))

df.info()


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1666918 entries, 0 to 1666917
Data columns (total 6 columns):
 #   Column       Non-Null Count    Dtype 
---  ------       --------------    ----- 
 0   Unnamed: 0   1666918 non-null  int64 
 1   policy_id    1666918 non-null  int64 
 2   policy_name  1666918 non-null  object
 3   page_id      1666918 non-null  int64 
 4   text         1666918 non-null  object
 5   text_id      1666918 non-null  int64 
dtypes: int64(4), object(2)
memory usage: 76.3+ MB


### 1.1 Filter dataset by length

In [5]:
# not sure whether this is the right thing to do here, because we need all embeddings for search. TODO: have a think about it.

### 1.2 Get sector and instrument schemas

In [6]:
SCHEMA_FOLDER = Path("../schema")

instruments = Schema.from_yaml_path(SCHEMA_FOLDER/"instruments.yml")
sectors = Schema.from_yaml_path(SCHEMA_FOLDER/"sectors.yml")

### 1.3 Get models

In [29]:
for idx, config in enumerate(MODEL_CONFIG):
    MODEL_CONFIG[idx]["sbert_model"] = SentenceTransformer(config["model_name"])
    

## 2. Run models

In [49]:
def total_batches(text_arr, batch_size):
    n = len(text_arr)
    batch_total = math.ceil(n / batch_size)
    
    return batch_total

def get_batches(text_arr, batch_size):
    n = len(text_arr)
    batch_total_idx = math.ceil(n / batch_size) * batch_size
    for batch_idx in range(0, batch_total_idx, batch_size):
        yield text_arr[batch_idx:batch_idx + batch_size]

def save_embeddings(batches_query_emb, save_path: Path):
    """Save compressed embeddings"""

    emb = torch.cat(batches_query_emb)
    
    file = gzip.GzipFile(save_path, 'wb')
    file.write(pickle.dumps(emb, protocol=pickle.HIGHEST_PROTOCOL))
    file.close()
    
def load_embeddings(filename):
    """Load compressed embeddings
    """
    file = gzip.GzipFile(filename, 'rb')
    data = file.read()
    obj = pickle.loads(data)
    file.close()
    return obj
        
def normalise_predictions(b_df, preds, conf, schema):
    b_preds = []
    for ix in range(len(b_df)):
        for p_ix, p in enumerate(preds[ix]):
            b_preds.append(
                {
                    'policy_id': b_df.iloc[ix]['policy_id'],
                    'page_id': b_df.iloc[ix]['page_id'],
                    'text_id': b_df.iloc[ix]['text_id'],
                    'schema': schema,
                    'pred': p,
                    'conf': conf[ix][p_ix]
                }
            )

    return pd.DataFrame(b_preds)

In [45]:
for idx, config in enumerate(MODEL_CONFIG):
    print(f"Running {config['model_name']}")
    
    embedding_path = LOCAL_OUTPUT_PATH / f"policy_text_embeddings_{config['model_name']}.pkl.gz"
    predictions_path = LOCAL_OUTPUT_PATH / f"policy_text_predictions_{config['model_name']}.csv.gz"

    MODEL_CONFIG[idx]["embedding_path"] = embedding_path
    MODEL_CONFIG[idx]["predictions_path"] = predictions_path
    
    if predictions_path.exists():
        predictions_path.unlink()

    threshold = 0.35
    save_every = 1
    reset_batches = False
    batch_size = 100

    n_batches = total_batches(df, batch_size)

    instrument_clf = CosineDistanceClassifier(
        instruments, 
        sbert_model= config["sbert_model"], 
        distance_measure= config["distance_measure"],
        concat_keywords_with_subsectors=True
    )
    sector_clf = CosineDistanceClassifier(
        sectors, 
        sbert_model= config["sbert_model"], 
        distance_measure= config["distance_measure"],
        concat_keywords_with_subsectors=True
    )


    all_query_emb = []
    all_instr_preds = []
    all_instr_conf = []
    all_sector_preds = []
    all_sector_conf = []

    predictions = []

    for b_ix, b_df in enumerate(tqdm(get_batches(df, batch_size), total=n_batches)):
        batch_query_emb, instrument_preds, instrument_conf = instrument_clf.predict(b_df.text.values, threshold, True)
        sector_preds, sector_conf = sector_clf.predict(b_df.text.values, threshold, False)

        all_query_emb.append(batch_query_emb)

        if config["predict_sectors_instruments"]:
            predictions = pd.concat(
                [
                    normalise_predictions(b_df, instrument_preds, instrument_conf, 'instrument'),
                    normalise_predictions(b_df, sector_preds, sector_conf, 'sector')
                ]
            )

            if b_ix % save_every == 0:
                if predictions_path.exists():
                    mode = 'at'
                    header = False
                else:
                    mode = 'wt'
                    header = True
                predictions.to_csv(predictions_path, mode=mode, header=header, index=False), 


    save_embeddings(all_query_emb, embedding_path)


Running multi-qa-MiniLM-L6-cos-v1


  0%|          | 0/2 [00:00<?, ?it/s]

Running msmarco-distilbert-dot-v5


  0%|          | 0/2 [00:00<?, ?it/s]

## 3. Save predictions and embeddings to s3

### 3.1 Predictions

In [52]:
for config in MODEL_CONFIG:       
    s3_client = boto3.client('s3')
    s3_client.upload_file(config['embedding_path'], 'cpr-datasets', config['embedding_path'].name)
    
    if config['predict_sectors_instruments']:
        s3_client.upload_file(config['predictions_path'], 'cpr-datasets', config['predictions_path'].name)
        
    print(f"{config['model_name']} uploaded")


ValueError: Filename must be a string