# Cosine similarity between sentence vectors as a noisy label source

This notebook works at sentence level. Various `sentence-bert` models are available at https://www.sbert.net/docs/pretrained_models.html

We try representations trained for different tasks: 

* **asymmetric semantic search**, where a short query is provided to try to retrieve a longer paragraph
* **symmetric semantic search**, where a query is provided to try to retrieve a similar-sized phrase

## KD notes

- the symmetric model seems to be more effective than the asymmetric one
- the symmetric model also seems to do a good job at *not* assigning a sentence an instrument/sector label when there is no relevant instrument/sector
- I've left the thresholds intentionally low in the notebook so we can have a look at some False Positives
- it was much more effective to generate keyword vectors using `{KEYWORD} {SUBSECTOR}` rather than the keyword alone, because some keywords are not very topic-specific
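The labelling rule behind these notes can be sketched with plain NumPy. This is a minimal illustration only, assuming toy 3-dimensional vectors in place of real sentence-BERT embeddings; the vectors, the names `kw_a`/`kw_b`/`kw_c`, and the helper `cosine_similarity_matrix` are all made up for the example. A sentence receives a keyword's label whenever their cosine similarity exceeds the threshold.

```python
import numpy as np


def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of a and the rows of b."""
    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_norm @ b_norm.T


# Toy stand-ins for sentence and keyword embeddings (hypothetical values)
sentence_vecs = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
keyword_vecs = np.array([[1.0, 0.1, 0.0], [0.0, 0.0, 1.0], [-1.0, 0.0, 0.0]])
keyword_names = np.array(["kw_a", "kw_b", "kw_c"])

threshold = 0.35
scores = cosine_similarity_matrix(sentence_vecs, keyword_vecs)

# Every keyword scoring above the threshold becomes a (noisy) label for the sentence
noisy_labels = [keyword_names[row > threshold].tolist() for row in scores]
print(noisy_labels)
```

Lowering `threshold` can only add labels to each sentence, never remove them, so it trades precision for recall.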

In [16]:
!pip install sentence-transformers umap-learn seaborn "numpy<1.20"



In [139]:
from pathlib import Path
from typing import List, Callable
import math
import pickle

from sentence_transformers import SentenceTransformer
from sentence_transformers import util as sbert_utils
from tqdm.auto import tqdm
import pandas as pd
import matplotlib.pyplot as plt  # used by plot_projections below
import seaborn as sns  # used by plot_projections below
import numpy as np
import torch

from utils import Schema, load_policy_dataset

In [2]:
df = load_policy_dataset()

df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1666918 entries, 0 to 1666917
Data columns (total 4 columns):
 #   Column       Non-Null Count    Dtype 
---  ------       --------------    ----- 
 0   policy_id    1666918 non-null  int64 
 1   policy_name  1666918 non-null  object
 2   page_id      1666918 non-null  int64 
 3   text         1666918 non-null  object
dtypes: int64(2), object(2)
memory usage: 50.9+ MB


In [3]:
df.sample(10)

| | policy_id | policy_name | page_id | text |
|---|---|---|---|---|
| 605324 | 527 | Bipartisan Budget Act of 2018 (H.R.1892) | 23 | Oct 27, 2018 Jkt 079139 PO 00123 |
| 341113 | 266 | National Climate Change Response Strategy 2010... | 222 | The support and engagement of technical expert... |
| 891141 | 748 | The Vanuatu Climate Change and Disaster Risk R... | 20 | and facilitating arrangements within Vanuatu a... |
| 1609636 | 1529 | Land use and Building Decree enacted under the... | 47 | Section 165 Altering the natural course of wat... |
| 1566623 | 1475 | Austrian Climate Strategy Austria (2002) | 28 | Here, it is evident that the uncertainties wit... |
| 911864 | 774 | Joint National Action Plan on Climate Change A... | 30 | Key constraints or gaps identified in the Init... |
| 556995 | 467 | Energy Conservation Code of Practice (R-6) | 6 | 15 Table 6.1 Maximum Allowable U-values for Di... |
| 1157625 | 1035 | Three Year Action Agenda India (2017) | 44 | As the success of Gujarat and the original und... |
| 1638101 | 1557 | Alternative Fuel Tax Exemption United States o... | 9 | The position holder is liable for the per gall... |
| 490768 | 412 | Consolidated Appropriations Act, 2016 | 747 | 8 USC 1101 note. |


In [4]:
SCHEMA_FOLDER = Path("../../schema")

instruments = Schema.from_yaml_path(SCHEMA_FOLDER/"instruments.yml")
sectors = Schema.from_yaml_path(SCHEMA_FOLDER/"sectors.yml")

## 1. Different model types

In [37]:
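def plot_projections(emb_2d: np.ndarray, schema: Schema, start_end: List[int] = None):
    """Scatter-plot 2D keyword embeddings, coloured by subsector.

    `start_end` optionally restricts the plot to rows [start, end) after sorting by subsector.
    """
    twod_df = pd.DataFrame(emb_2d, columns=["x", "y"])
    twod_df['keyword'] = schema.all_keywords
    # keyword_subsector_mapping is an array aligned with all_keywords (it is indexed by
    # keyword position elsewhere in this notebook), so assign it directly
    twod_df['subsector'] = schema.keyword_subsector_mapping
    
    if start_end:
        # Slice rows start:end (the original passed a (row, column) pair to iloc)
        twod_df = twod_df.sort_values('subsector', ascending=True).iloc[start_end[0]:start_end[1]]

    _, ax = plt.subplots(figsize=(20,15))
    sns.scatterplot(x=twod_df["x"], y=twod_df["y"], hue=twod_df["subsector"], marker=".", ax=ax)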

# plotting the projections doesn't seem to expose any particularly useful structure but the code is here just in case!
# reducer = umap.UMAP()
# sector_keyword_embeddings_2d = reducer.fit_transform(sector_keyword_embeddings)

# plot_projections(sector_keyword_embeddings_2d, sectors)


In [10]:
instruments.all_keywords

array(['demonstration project', 'joint research', 'knowledge generation',
       'research and development', 'research body',
       'research collaboration', 'research funding', 'research programme',
       'research scheme', 'scientific inquiries', 'technology', 'advice',
       'consultation', 'education', 'indigenous knowledge',
       'knowledge distribution', 'knowledge sharing and dissemination',
       'professional training', 'projections', 'public information',
       'reporting information', 'technical assistance', 'training',
       'environmental assessment', 'federal adaptation programme',
       'national adaptation service', 'processes',
       'public-private partnerships', 'resource management',
       'risk assessments', 'stakeholder engagement', 'structures',
       'climate fund', 'debt support', 'finance', 'financial flows',
       'financial regulation', 'funding', 'grants', 'green finance',
       'interest', 'international climate finance', 'loan',
       'low-

In [13]:
instruments.all_keywords + instruments.keyword_subsector_mapping

UFuncTypeError: ufunc 'add' did not contain a loop with signature matching types (dtype('<U35'), dtype('<U35')) -> None
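The `+` operator fails here because this NumPy version defines no addition loop for fixed-width string arrays. One element-wise alternative is `np.char.add`. Below is a minimal sketch using stand-in arrays rather than the real `Schema` attributes; the keyword/subsector pairs are copied from the example outputs further down this notebook.

```python
import numpy as np

# Stand-ins for schema.all_keywords and schema.keyword_subsector_mapping
keywords = np.array(["power", "gas", "energy"])
subsectors = np.array(["Energy use", "Energy production", "Energy (general)"])

# np.char.add concatenates element-wise; the scalar " " is broadcast against the array
keyword_with_subsector = np.char.add(np.char.add(keywords, " "), subsectors)
print(keyword_with_subsector)
```

This yields the same `{KEYWORD} {SUBSECTOR}` strings that a per-element f-string would produce.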

In [178]:
class CosineDistanceClassifier:
    def __init__(self, schema: Schema, sbert_model: str, distance_measure: str, concat_keywords_with_subsectors: bool):
        assert distance_measure in ['dot_product', 'cosine']
        
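        # NOTE: predict() always scores with cos_sim, which L2-normalises its inputs internally,
        # so this flag does not change the scores produced below
        self._normalise_vectors = distance_measure == 'dot_product'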
        
        self.schema = schema
        self._keyword_subsector_mapping = self.schema.keyword_subsector_mapping
        
        self.sbert_model = sbert_model
        self.distance_measure = distance_measure
        
        self.encoder = self._get_sentence_encoder()
        self._keyword_embeddings = self._embed_keywords(concat_keywords_with_subsectors)
        
    def _get_sentence_encoder(self):
        return SentenceTransformer(self.sbert_model)
    
    def _keyword_subsector_concatenator(self, kwd: str, k_ix: int):
        """string modifier for _embed_keywords"""
        
        return f"{kwd} {self._keyword_subsector_mapping[k_ix]}"
    
    def _embed_keywords(self, concat_with_subsectors: bool):
        keywords = self.schema.all_keywords
        
        if concat_with_subsectors: 
            keywords = [self._keyword_subsector_concatenator(k, k_ix) for k_ix, k in enumerate(keywords)]
            
        keyword_embeddings = self.encoder.encode(keywords, convert_to_tensor=True)
        
        if self._normalise_vectors:
            keyword_embeddings = torch.nn.functional.normalize(keyword_embeddings, p=2, dim=1)
            
        return keyword_embeddings
            
    def _get_grouped_labels(self, pred, conf):
        """Collapse keyword-level predictions to unique subsector labels, keeping the best confidence for each"""

        # Order the keyword hits by descending confidence, so that the first occurrence
        # of each subsector below is its highest-scoring keyword
        order = np.argsort(conf)[::-1]
        pred, conf = pred[order], conf[order]

        # Get the predicted labels at subsector level
        pred_labels = self.schema.keyword_subsector_mapping[pred]

        # Find the unique labels at subsector level; return_index gives the first occurrence of each
        unique_labels, unique_label_idx = np.unique(pred_labels, return_index=True)
        unique_label_conf = conf[unique_label_idx]
        
        # Sort the predictions in reverse order of confidence
        sorted_conf_idx = np.argsort(unique_label_conf)[::-1]
        unique_labels = unique_labels[sorted_conf_idx]
        unique_label_conf = unique_label_conf[sorted_conf_idx]

        return unique_labels, unique_label_conf
        
    def _get_predicted_labels(self, cos_scores: torch.Tensor, threshold: float):
        # Move the scores to the CPU once, so that .numpy() also works when the encoder runs on a GPU
        cos_scores = cos_scores.cpu()

        all_preds = []
        all_conf = []
        for idx in range(cos_scores.shape[0]):
            # Get the predicted labels and confidences over threshold
            cls_idx = torch.where(cos_scores[idx] > threshold)[0]
            preds = cls_idx.numpy()
            conf = cos_scores[idx, cls_idx].numpy()

            # Get the unique subsector level predicted labels and confidences
            preds, conf = self._get_grouped_labels(preds, conf)
            
            all_preds.append(preds)
            all_conf.append(conf)

        return all_preds, all_conf

    def predict(self, query, threshold: float, return_embeddings: bool = False):
        """Score one string, or a batch of strings, against every keyword embedding"""
        query_embedding = self.encoder.encode(query, convert_to_tensor=True)
        if self._normalise_vectors:
            # dim=-1 normalises each embedding whether the input was one string (1D) or a batch (2D)
            query_embedding = torch.nn.functional.normalize(query_embedding, p=2, dim=-1)
        
        cos_scores = sbert_utils.cos_sim(query_embedding, self._keyword_embeddings)

        preds, conf = self._get_predicted_labels(cos_scores, threshold)

        if return_embeddings:
            result = (query_embedding, preds, conf)
        else:
            result = (preds, conf)
        
        return result
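
`_get_grouped_labels` collapses keyword-level hits to one entry per subsector with `np.unique(..., return_index=True)`, which returns the index of each label's first occurrence. The sketch below uses made-up hits and a hypothetical helper name to show how sorting by descending confidence beforehand makes that first occurrence the highest-scoring keyword for each subsector.

```python
import numpy as np


def group_by_best_confidence(labels: np.ndarray, conf: np.ndarray):
    """Return unique labels with their highest confidence, best first."""
    # Sort hits so the strongest one for each label comes first
    order = np.argsort(conf)[::-1]
    labels, conf = labels[order], conf[order]

    # return_index gives the position of the first occurrence of each label
    unique_labels, first_idx = np.unique(labels, return_index=True)
    best_conf = conf[first_idx]

    # Present the grouped labels in descending order of confidence
    rank = np.argsort(best_conf)[::-1]
    return unique_labels[rank], best_conf[rank]


# Hypothetical keyword-level hits: subsector of each matched keyword and its score
labels = np.array(["Energy production", "Energy use", "Energy production", "Forestry"])
conf = np.array([0.36, 0.37, 0.40, 0.35])

grouped_labels, grouped_conf = group_by_best_confidence(labels, conf)
print(grouped_labels, grouped_conf)
```

Without the initial sort, "Energy production" would be reported with 0.36 (its first hit in keyword order) rather than 0.40.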



In [119]:
# Prepare fixed number of examples; set parameters for the rest of this section

n_examples = 52
examples = df.sample(n_examples, random_state=99)['text'].tolist()

### 1a) Asymmetric semantic search

In [179]:
"""
Model choices (from https://www.sbert.net/docs/pretrained-models/msmarco-v3.html):
- msmarco-MiniLM-L-6-v3: tuned for cosine similarity, which prefers retrieving shorter passages
- msmarco-distilbert-base-v4: same, but larger
- msmarco-distilbert-base-tas-b: tuned for dot product, which prefers retrieval of longer passages
"""

clf_instruments_asymm_cosine = CosineDistanceClassifier(
    schema = instruments, 
    sbert_model = "msmarco-MiniLM-L-6-v3", 
    distance_measure= "cosine",
    concat_keywords_with_subsectors= True,
)

clf_sectors_asymm_cosine = CosineDistanceClassifier(
    schema = sectors, 
    sbert_model = "msmarco-MiniLM-L-6-v3", 
    distance_measure= "cosine",
    concat_keywords_with_subsectors= True,
)
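
The difference between scoring by cosine similarity and by dot product can be illustrated with plain NumPy. This is a minimal sketch with made-up vectors, not real model embeddings: the dot product grows with vector magnitude, while cosine similarity depends only on direction and equals the dot product of the L2-normalised vectors.

```python
import numpy as np


def l2_normalise(v: np.ndarray) -> np.ndarray:
    """Scale a vector to unit length."""
    return v / np.linalg.norm(v)


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Hypothetical embeddings; scaled_passage points the same way as passage but is 3x longer
query = np.array([1.0, 2.0, 0.5])
passage = np.array([2.0, 1.0, 1.0])
scaled_passage = 3.0 * passage

dot_plain = float(query @ passage)
dot_scaled = float(query @ scaled_passage)
cos_plain = cosine(query, passage)
cos_scaled = cosine(query, scaled_passage)

print(dot_plain, dot_scaled)   # the dot product changes with magnitude
print(cos_plain, cos_scaled)   # the cosine similarity does not
```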




### Batched predictions

In [180]:
def get_batches(text_arr, batch_size):
    """Yield consecutive slices of text_arr containing at most batch_size items"""
    for batch_idx in range(0, len(text_arr), batch_size):
        yield text_arr[batch_idx:batch_idx + batch_size]


In [181]:
def save_embeddings(batches_query_emb, save_path: Path):
    """Save query embeddings"""

    emb = torch.cat(batches_query_emb)

    with open(save_path, 'wb') as embedding_f:
        pickle.dump(emb, embedding_f, protocol=pickle.HIGHEST_PROTOCOL)

In [182]:
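def stack_ragged(array_list, axis=0):
    """Concatenate arrays of differing lengths; also return the indices that np.split needs to separate them again"""
    lengths = [np.shape(a)[axis] for a in array_list]
    idx = np.cumsum(lengths[:-1])
    stacked = np.concatenate(array_list, axis=axis)
    return stacked, idx
    
def concat_preds(pred_batches: List[List[np.ndarray]]):
    """Flatten per-batch lists of per-sentence arrays, then stack them into one array plus split indices"""
    # Flatten explicitly: np.concatenate over nested ragged lists depends on how NumPy coerces
    # each batch, and breaks when every array in a batch happens to have the same length
    per_sentence = [arr for batch in pred_batches for arr in batch]
    stacked_preds, pred_idx = stack_ragged(per_sentence)

    return stacked_preds, pred_idx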

In [183]:
def save_predictions(pred_filename: Path, **kwargs):
    """Save predictions and confidences in a stacked array in an npz file.
    This can be loaded again using np.load, which returns a dict-like NpzFile keyed by the array names used below
    """

    batches_instr_preds = kwargs['batches_instr_preds']
    batches_instr_conf = kwargs['batches_instr_conf']
    batches_sector_preds = kwargs['batches_sector_preds']
    batches_sector_conf = kwargs['batches_sector_conf']


    stacked_instr_preds, stacked_instr_pred_idx = concat_preds(batches_instr_preds)
    stacked_instr_conf, stacked_instr_conf_idx = concat_preds(batches_instr_conf)
    stacked_sector_preds, stacked_sector_pred_idx = concat_preds(batches_sector_preds)
    stacked_sector_conf, stacked_sector_conf_idx = concat_preds(batches_sector_conf)

    np.savez(
        str(pred_filename), 
        stacked_instr_preds=stacked_instr_preds, 
        stacked_instr_preds_index=stacked_instr_pred_idx,
        stacked_instr_conf=stacked_instr_conf,
        stacked_instr_conf_index=stacked_instr_conf_idx,
        stacked_sector_preds=stacked_sector_preds, 
        stacked_sector_preds_index=stacked_sector_pred_idx,
        stacked_sector_conf=stacked_sector_conf,
        stacked_sector_conf_index=stacked_sector_conf_idx,
    )

In [202]:
data_path = Path('../../data')
embedding_path = data_path / 'policy_text_embeddings.pkl'
predictions_path = data_path / 'policy_text_predictions.npz'

threshold = 0.35
save_every = 10
reset_batches = False
batch_size = 100

all_query_emb = []
all_instr_preds = []
all_instr_conf = []
all_sector_preds = []
all_sector_conf = []
for b_ix, b_text in enumerate(get_batches(df.text.values[0:5000], batch_size)):
    batch_query_emb, instrument_preds, instrument_conf = clf_instruments_asymm_cosine.predict(b_text, threshold, True)
    sector_preds, sector_conf = clf_sectors_asymm_cosine.predict(b_text, threshold, False)

    all_query_emb.append(batch_query_emb)
    all_instr_preds.append(instrument_preds)
    all_instr_conf.append(instrument_conf)
    all_sector_preds.append(sector_preds)
    all_sector_conf.append(sector_conf)

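    preds = {
        'batches_instr_preds': all_instr_preds, 
        'batches_instr_conf': all_instr_conf, 
        'batches_sector_preds': all_sector_preds,
        'batches_sector_conf': all_sector_conf
    }

    # Checkpoint everything accumulated so far
    if b_ix % save_every == 0:
        save_embeddings(all_query_emb, embedding_path)
        save_predictions(predictions_path, **preds)

# Save once more after the loop, so batches processed since the last checkpoint are not lost
save_embeddings(all_query_emb, embedding_path)
save_predictions(predictions_path, **preds)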



In [205]:
(400*88)/3600

9.777777777777779

In [195]:
preds = np.load('../../data/policy_text_predictions.npz')

In [198]:
instr_labels = preds['stacked_instr_preds']
instr_labels_idx = preds['stacked_instr_preds_index']
instr_conf = preds['stacked_instr_conf']
instr_conf_idx = preds['stacked_instr_conf_index']

sector_labels = preds['stacked_sector_preds']
sector_labels_idx = preds['stacked_sector_preds_index']
sector_conf = preds['stacked_sector_conf']
sector_conf_idx = preds['stacked_sector_conf_index']

In [199]:
instr_labels = np.split(instr_labels, instr_labels_idx)
instr_conf = np.split(instr_conf, instr_conf_idx)
sector_labels = np.split(sector_labels, sector_labels_idx)
sector_conf = np.split(sector_conf, sector_conf_idx)
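
The stacked-array-plus-split-indices layout used above can be checked end to end on a small example. This is a minimal sketch that assumes made-up ragged arrays and uses an in-memory buffer in place of the `.npz` file on disk.

```python
import io

import numpy as np

# Hypothetical per-sentence confidences: two hits, no hits, one hit
ragged = [np.array([0.41, 0.36]), np.array([]), np.array([0.52])]

# Stack into one flat array and record where each sentence's slice ends
lengths = [len(a) for a in ragged]
split_idx = np.cumsum(lengths[:-1])
stacked = np.concatenate(ragged)

# Save to and reload from an in-memory buffer (np.savez also accepts a file path)
buffer = io.BytesIO()
np.savez(buffer, stacked=stacked, split_idx=split_idx)
buffer.seek(0)
loaded = np.load(buffer)

# np.split at the saved indices restores the original ragged structure
restored = np.split(loaded["stacked"], loaded["split_idx"])
print([a.tolist() for a in restored])
```

Sentences with no predictions survive the round trip as empty arrays, because repeated split indices produce empty slices.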

TODO
- Concatenate and save batches into ragged numpy array - https://tonysyu.github.io/ragged-arrays.html#.YV7SzX3TWUk
- Output batches of query encodings, concatenate and save into numpy array
- Save the numpy arrays to disk

### Example predictions (superseded)

In [130]:
THRESHOLD = 0.35

for _str in examples:
    print(_str)
    print()
    print("INSTRUMENT PREDICTIONS")
    print("\n".join([f"- {pred}" for pred in clf_instruments_asymm_cosine.predict(_str, THRESHOLD)]))
    print()
    print("SECTOR PREDICTIONS")
    print("\n".join([f"- {pred}" for pred in clf_sectors_asymm_cosine.predict(_str, THRESHOLD)]))
    print("----------------------")

Sources of Energy Supply At present, Bangladesh has energy supply from both renewable and nonrenewable sources, 38 percent of which comes from biomass (Figure 3.1).

INSTRUMENT PREDICTIONS
- ('fossil fuel subsidy', 'Fiscal or financial incentives', 0.4097)
- ('tariff', 'Fiscal or financial incentives', 0.3522)

SECTOR PREDICTIONS
- ('fossil fuels', 'Energy production', 0.4043)
- ('fossil fuel fires', 'Energy production', 0.38)
- ('electricity subsidies', 'Energy use', 0.3708)
- ('renewable energy', 'Energy production', 0.3685)
- ('solar energy', 'Energy production', 0.362)
- ('forest tundra', 'Forestry', 0.3572)
- ('energy technology', 'Energy use', 0.3555)
----------------------
To put that in real-world context, roughly 35 jobs are created for each million board feet of wood processed.

INSTRUMENT PREDICTIONS


SECTOR PREDICTIONS

----------------------
Research on the likelihood of disasters and the assessment of the likely social, economic and environmental impacts will be conducte

### 1b) Symmetric semantic search

In [126]:
"""
Model choices (from https://www.sbert.net/docs/pretrained_models.html#semantic-search):
- multi-qa-MiniLM-L6-cos-v1: tuned for cosine similarity, which prefers retrieving shorter passages
- multi-qa-MiniLM-L6-dot-v1: tuned for dot product, which prefers retrieval of longer passages
"""

emb_symsearch = SentenceTransformer("multi-qa-MiniLM-L6-dot-v1")

clf_instruments_symm_cosine = CosineDistanceClassifier(
    schema = instruments, 
    sbert_model = "multi-qa-MiniLM-L6-cos-v1", 
    distance_measure= "cosine",
    concat_keywords_with_subsectors= True,
)

clf_sectors_symm_cosine = CosineDistanceClassifier(
    schema = sectors, 
    sbert_model = "multi-qa-MiniLM-L6-cos-v1", 
    distance_measure= "cosine",
    concat_keywords_with_subsectors= True,
)


In [128]:
THRESHOLD = 0.4

for _str in examples:
    print(_str)
    print()
    print("INSTRUMENT PREDICTIONS")
    print("\n".join([f"- {pred}" for pred in clf_instruments_symm_cosine.predict(_str, THRESHOLD)]))
    print()
    print("SECTOR PREDICTIONS")
    print("\n".join([f"- {pred}" for pred in clf_sectors_symm_cosine.predict(_str, THRESHOLD)]))
    print("----------------------")

Sources of Energy Supply At present, Bangladesh has energy supply from both renewable and nonrenewable sources, 38 percent of which comes from biomass (Figure 3.1).

INSTRUMENT PREDICTIONS


SECTOR PREDICTIONS
- ('renewables', 'Energy production', 0.4862)
- ('renewable energy', 'Energy production', 0.4853)
- ('energy demand', 'Energy production', 0.468)
- ('fossil fuels', 'Energy production', 0.4601)
- ('offshore', 'Energy production', 0.4592)
- ('onshore', 'Energy production', 0.4528)
- ('natural gas', 'Energy production', 0.4483)
- ('fossil fuel fires', 'Energy production', 0.4445)
- ('biofuels', 'Energy production', 0.4425)
- ('energy production', 'Energy production', 0.4246)
- ('energy', 'Energy (general)', 0.4242)
- ('energy industries', 'Energy production', 0.4236)
- ('fuels', 'Energy production', 0.4177)
- ('nuclear energy', 'Energy production', 0.4138)
- ('power', 'Energy use', 0.4129)
- ('gas', 'Energy production', 0.4089)
- ('energy use', 'Energy use', 0.4076)
- ('bioethanol'