# Cosine similarity between sentence vectors as a noisy label source

Labelling is done at sentence level. Various `sentence-bert` models are available at https://www.sbert.net/docs/pretrained_models.html

We try representations trained for different tasks: 

* **asymmetric semantic search**, where a short query is provided to try to retrieve a longer paragraph
* **symmetric semantic search**, where a query is provided to try to retrieve a similar-sized phrase

## KD notes

- it seems that the symmetric search model is more effective than the asymmetric one
- the symmetric model also seems to do a good job at *not* assigning a sentence an instrument/sector label when there is no relevant instrument/sector
- I've left the thresholds intentionally low in the notebook so we can have a look at some False Negatives
- it was much more effective to generate keyword vectors using `{KEYWORD} {SUBSECTOR}` rather than just the keyword, because some keywords are not very topic-specific on their own

In [16]:
!pip install sentence-transformers umap-learn seaborn "numpy<1.20"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting seaborn
  Using cached seaborn-0.11.2-py3-none-any.whl (292 kB)
Collecting matplotlib>=2.2
  Downloading matplotlib-3.4.3-cp38-cp38-macosx_10_9_x86_64.whl (7.2 MB)
[K     |████████████████████████████████| 7.2 MB 7.2 MB/s eta 0:00:01
Collecting kiwisolver>=1.0.1
  Downloading kiwisolver-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl (61 kB)
[K     |████████████████████████████████| 61 kB 1.5 MB/s  eta 0:00:01
[?25hCollecting cycler>=0.10
  Using cached cycler-0.10.0-py2.py3-none-any.whl (6.5 kB)
Installing collected packages: kiwisolver, cycler, matplotlib, seaborn
Successfully installed cycler-0.10.0 kiwisolver-1.3.2 matplotlib-3.4.3 seaborn-0.11.2


In [40]:
from pathlib import Path
from typing import List, Callable

from sentence_transformers import SentenceTransformer
from sentence_transformers import util as sbert_utils
import umap
from tqdm.auto import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import torch

from utils import Schema, load_policy_dataset

In [2]:
# Load the policy dataset through the helper imported from the local `utils`
# module, then print a summary of its columns, dtypes and memory usage.
df = load_policy_dataset()

df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1666918 entries, 0 to 1666917
Data columns (total 4 columns):
 #   Column       Non-Null Count    Dtype 
---  ------       --------------    ----- 
 0   policy_id    1666918 non-null  int64 
 1   policy_name  1666918 non-null  object
 2   page_id      1666918 non-null  int64 
 3   text         1666918 non-null  object
dtypes: int64(2), object(2)
memory usage: 50.9+ MB


In [3]:
# Folder holding the YAML schema definitions, relative to the notebook's
# working directory.
SCHEMA_FOLDER = Path("../../schema")

# One Schema per label family; both are used below as sources of keywords
# and of the keyword -> subsector mapping.
instruments = Schema.from_yaml_path(SCHEMA_FOLDER/"instruments.yml")
sectors = Schema.from_yaml_path(SCHEMA_FOLDER/"sectors.yml")

## 1. Different model types

In [37]:
def plot_projections(emb_2d: np.ndarray, schema: Schema, start_end: List[int] = None):
    """Scatter-plot 2-D keyword embeddings, coloured by subsector.

    Args:
        emb_2d: two-column array (x, y) with one row per entry of
            ``schema.all_keywords``, in the same order.
        schema: provides ``all_keywords`` and ``keyword_subsector_mapping``.
        start_end: optional ``[start, end]`` pair. When given, the rows are
            sorted by subsector and only the positional slice ``start:end``
            (end exclusive) is plotted. ``None`` (or an empty list) plots
            every keyword.

    Returns:
        None. The plot is drawn on a newly created 20x15 matplotlib figure.
    """
    twod_df = pd.DataFrame(emb_2d, columns=["x", "y"])
    twod_df["keyword"] = schema.all_keywords
    twod_df["subsector"] = twod_df["keyword"].map(schema.keyword_subsector_mapping)

    if start_end:
        start, end = start_end[0], start_end[1]
        # Slice a range of rows. `.iloc[start, end]` would instead select the
        # single cell at row `start`, column `end`.
        twod_df = twod_df.sort_values("subsector", ascending=True).iloc[start:end]

    _, ax = plt.subplots(figsize=(20, 15))
    # NOTE(review): seaborn only uses `markers` together with a `style`
    # variable; `marker="."` may have been intended - confirm before changing.
    sns.scatterplot(
        x=twod_df["x"],
        y=twod_df["y"],
        hue=twod_df["subsector"],
        markers=".",
        ax=ax,
    )

# plotting the projections doesn't seem to expose any particularly useful structure but the code is here just in case!
# reducer = umap.UMAP()
# sector_keyword_embeddings_2d = reducer.fit_transform(sector_keyword_embeddings)

# plot_projections(sector_keyword_embeddings_2d, sectors)


In [112]:
class CosineDistanceClassifier:
    """Label a sentence with the schema keywords whose embeddings score close to it.

    Every keyword of ``schema`` is embedded once, at construction time, with a
    sentence-transformers model. ``predict`` embeds a query, scores it against
    all keyword embeddings and returns the keywords scoring at or above a
    threshold.

    Args:
        schema: source of ``all_keywords`` and ``keyword_subsector_mapping``.
        sbert_model: model name or path handed to ``SentenceTransformer``.
        distance_measure: ``"cosine"`` scores with cosine similarity;
            ``"dot_product"`` scores with the raw (unnormalised) dot product,
            as intended for models tuned for dot product.
        concat_keywords_with_subsectors: when true, each keyword is embedded
            as ``"{keyword} {subsector}"`` instead of the keyword alone.

    Raises:
        ValueError: if ``distance_measure`` is not a supported value.
    """

    _DISTANCE_MEASURES = ("dot_product", "cosine")

    def __init__(self, schema: Schema, sbert_model: str, distance_measure: str, concat_keywords_with_subsectors: bool):
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if distance_measure not in self._DISTANCE_MEASURES:
            raise ValueError(
                f"distance_measure must be one of {self._DISTANCE_MEASURES}, got {distance_measure!r}"
            )

        self.schema = schema
        self._keyword_subsector_mapping = self.schema.keyword_subsector_mapping
        # Snapshot the keyword list once so that row i of the embedding matrix
        # always corresponds to self._keywords[i] when predictions are decoded.
        self._keywords = list(self.schema.all_keywords)

        self.sbert_model = sbert_model
        self.distance_measure = distance_measure

        self.encoder = self._get_sentence_encoder()
        self._keyword_embeddings = self._embed_keywords(concat_keywords_with_subsectors)

    def _get_sentence_encoder(self):
        """Instantiate the sentence-transformers model named by ``sbert_model``."""
        return SentenceTransformer(self.sbert_model)

    def _keyword_subsector_concatenator(self, kwd: str):
        """string modifier for _embed_keywords: returns ``"{keyword} {subsector}"``"""
        return f"{kwd} {self._keyword_subsector_mapping[kwd]}"

    def _embed_keywords(self, concat_with_subsectors: bool):
        """Embed all schema keywords as a tensor, one row per keyword.

        Args:
            concat_with_subsectors: embed ``"{keyword} {subsector}"`` rather
                than the bare keyword.
        """
        keywords = self._keywords

        if concat_with_subsectors:
            keywords = [self._keyword_subsector_concatenator(k) for k in keywords]

        return self.encoder.encode(keywords, convert_to_tensor=True)

    def _score(self, query_embedding, keyword_embeddings):
        """Score a query embedding against every keyword embedding.

        Uses the similarity function selected by ``distance_measure``. No
        manual normalisation is applied: ``cos_sim`` normalises internally and
        a dot product is only scale-sensitive on unnormalised vectors.
        """
        if self.distance_measure == "dot_product":
            return sbert_utils.dot_score(query_embedding, keyword_embeddings)
        return sbert_utils.cos_sim(query_embedding, keyword_embeddings)

    def predict(self, query: str, threshold: float) -> List[tuple]:
        """Return the keywords whose similarity to ``query`` is >= ``threshold``.

        Args:
            query: text to classify.
            threshold: minimum score, on the scale of the configured
                ``distance_measure``.

        Returns:
            List of ``(keyword, subsector, score)`` tuples, with the score
            rounded to 4 decimal places, sorted by score in descending order.
            Empty when no keyword reaches the threshold.
        """
        query_embedding = self.encoder.encode(query, convert_to_tensor=True)

        # The similarity helpers return a (1, n_keywords) matrix for one query.
        scores = self._score(query_embedding, self._keyword_embeddings)[0]
        matching_idxs = torch.where(scores >= threshold)[0].tolist()

        results = []
        for idx in matching_idxs:
            keyword = self._keywords[idx]
            results.append(
                (keyword, self._keyword_subsector_mapping[keyword], round(scores[idx].item(), 4))
            )

        results.sort(key=lambda x: x[2], reverse=True)

        return results


In [129]:
# Draw a fixed sample of texts from the dataset; these examples are reused by
# the prediction cells in the rest of this section.

n_examples = 30
# random_state is fixed so the same examples are drawn on every run.
examples = df.sample(n_examples, random_state=99)['text'].tolist()

### 1a) Asymmetric semantic search

In [124]:
"""
Model choices (from https://www.sbert.net/docs/pretrained-models/msmarco-v3.html):
- msmarco-MiniLM-L-6-v3: tuned for cosine similarity, which prefers retrieving shorter passages
- msmarco-distilbert-base-v4: same, but larger
- msmarco-distilbert-base-tas-b: tuned for dot product, which prefers retrieval of longer passages
"""

# The two asymmetric-search classifiers share the same model and settings and
# differ only in the schema; instruments is built first, then sectors.
clf_instruments_asymm_cosine, clf_sectors_asymm_cosine = [
    CosineDistanceClassifier(
        schema=_schema,
        sbert_model="msmarco-MiniLM-L-6-v3",
        distance_measure="cosine",
        concat_keywords_with_subsectors=True,
    )
    for _schema in (instruments, sectors)
]


In [130]:
# Minimum similarity score for a keyword to be reported as a prediction.
THRESHOLD = 0.35

# For each example: the text, a blank line, the instrument predictions, a
# blank line, the sector predictions and a separator line.
for _str in examples:
    print(_str, end="\n\n")
    print("INSTRUMENT PREDICTIONS")
    print(
        *(f"- {pred}" for pred in clf_instruments_asymm_cosine.predict(_str, THRESHOLD)),
        sep="\n",
        end="\n\n",
    )
    print("SECTOR PREDICTIONS")
    print(
        *(f"- {pred}" for pred in clf_sectors_asymm_cosine.predict(_str, THRESHOLD)),
        sep="\n",
    )
    print("----------------------")

Sources of Energy Supply At present, Bangladesh has energy supply from both renewable and nonrenewable sources, 38 percent of which comes from biomass (Figure 3.1).

INSTRUMENT PREDICTIONS
- ('fossil fuel subsidy', 'Fiscal or financial incentives', 0.4097)
- ('tariff', 'Fiscal or financial incentives', 0.3522)

SECTOR PREDICTIONS
- ('fossil fuels', 'Energy production', 0.4043)
- ('fossil fuel fires', 'Energy production', 0.38)
- ('electricity subsidies', 'Energy use', 0.3708)
- ('renewable energy', 'Energy production', 0.3685)
- ('solar energy', 'Energy production', 0.362)
- ('forest tundra', 'Forestry', 0.3572)
- ('energy technology', 'Energy use', 0.3555)
----------------------
To put that in real-world context, roughly 35 jobs are created for each million board feet of wood processed.

INSTRUMENT PREDICTIONS


SECTOR PREDICTIONS

----------------------
Research on the likelihood of disasters and the assessment of the likely social, economic and environmental impacts will be conducte

### 1b) Symmetric semantic search

In [126]:
"""
Model choices (from https://www.sbert.net/docs/pretrained_models.html#semantic-search):
- multi-qa-MiniLM-L6-cos-v1: tuned for cosine similarity, which prefers retrieving shorter passages
- multi-qa-MiniLM-L6-dot-v1: tuned for dot product, which prefers retrieval of longer passages
"""

# NOTE(review): this loads the "dot" variant and is not referenced by the
# classifiers below, which load the "cos" variant themselves - confirm whether
# a later cell still needs it.
emb_symsearch = SentenceTransformer("multi-qa-MiniLM-L6-dot-v1")

# The two classifiers share the same model and settings and differ only in
# the schema; instruments is built first, then sectors.
clf_instruments_symm_cosine, clf_sectors_symm_cosine = [
    CosineDistanceClassifier(
        schema=_schema,
        sbert_model="multi-qa-MiniLM-L6-cos-v1",
        distance_measure="cosine",
        concat_keywords_with_subsectors=True,
    )
    for _schema in (instruments, sectors)
]


In [128]:
# Minimum similarity score for a keyword to be reported as a prediction.
THRESHOLD = 0.4

# For each example: the text, a blank line, the instrument predictions, a
# blank line, the sector predictions and a separator line.
for _str in examples:
    print(_str, end="\n\n")
    print("INSTRUMENT PREDICTIONS")
    print(
        *(f"- {pred}" for pred in clf_instruments_symm_cosine.predict(_str, THRESHOLD)),
        sep="\n",
        end="\n\n",
    )
    print("SECTOR PREDICTIONS")
    print(
        *(f"- {pred}" for pred in clf_sectors_symm_cosine.predict(_str, THRESHOLD)),
        sep="\n",
    )
    print("----------------------")

Sources of Energy Supply At present, Bangladesh has energy supply from both renewable and nonrenewable sources, 38 percent of which comes from biomass (Figure 3.1).

INSTRUMENT PREDICTIONS


SECTOR PREDICTIONS
- ('renewables', 'Energy production', 0.4862)
- ('renewable energy', 'Energy production', 0.4853)
- ('energy demand', 'Energy production', 0.468)
- ('fossil fuels', 'Energy production', 0.4601)
- ('offshore', 'Energy production', 0.4592)
- ('onshore', 'Energy production', 0.4528)
- ('natural gas', 'Energy production', 0.4483)
- ('fossil fuel fires', 'Energy production', 0.4445)
- ('biofuels', 'Energy production', 0.4425)
- ('energy production', 'Energy production', 0.4246)
- ('energy', 'Energy (general)', 0.4242)
- ('energy industries', 'Energy production', 0.4236)
- ('fuels', 'Energy production', 0.4177)
- ('nuclear energy', 'Energy production', 0.4138)
- ('power', 'Energy use', 0.4129)
- ('gas', 'Energy production', 0.4089)
- ('energy use', 'Energy use', 0.4076)
- ('bioethanol'