# Preliminary Setup

In [9]:
import torch
import numpy as np
import datasets
import os
#import umap
import sys
import evaluate
import seaborn as sns
from pathlib import Path
from itertools import product
from IPython.core.debugger import set_trace
from datasets import Dataset, DatasetDict
from torch import nn
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
from nltk import sent_tokenize
from IPython.core.debugger import Pdb, set_trace
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import util
#from tqdm.notebook import tqdm
from tqdm import tqdm
from numpy.lib.stride_tricks import sliding_window_view
from pprint import pprint
from scipy.cluster.hierarchy import linkage

from nbtools.utils import files, strings
from nbtools.sent_encoders import from_hf

# HF `datasets`: do not write cache files for transformed datasets.
datasets.disable_caching()

# Filesystem locations.
cache_dir = '/data/john/cache'
proot = files.project_root()

# Global RNG seed (any integer works); applied to torch first, then numpy.
seed = 10
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(seed)

# IPython magics: automatically re-import modules before each cell executes
# (presumably so edits to local helpers such as nbtools are picked up), and
# render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load Datasets, Standardize Column Names, and Aggregate

In [10]:
# Raw inputs: AllSides document pairs (JSON) and privacy-policy pairs (CSV).
allsides_dir = '/data/john/datasets/all_sides/test.json'
ppp_dir = '/data/john/datasets/privacy_policy/3p_data.csv'

raw_ds = DatasetDict({
    'allsides': Dataset.from_json(allsides_dir),
    'ppp': Dataset.from_csv(ppp_dir),
})

# Per-dataset mapping from source column name -> standardized name:
# d1/d2 are the two documents of a pair, refN is the N-th reference text.
col_maps = {
    'allsides': {'Left': 'd1',
                 'Right': 'd2',
                 'Ahmed_Intersection': 'ref0',
                 'Naman_Intersection': 'ref1',
                 'Helen_Intersection': 'ref2',
                 'AllSides_Intersection': 'ref3'},
    'ppp': {'Company_1': 'd1',
            'Company_2': 'd2',
            'Annotator1': 'ref0',
            'Annotator2': 'ref1',
            'Annotator3': 'ref2'}
}

# Columns kept from each raw dataset: exactly the keys of its column map.
# Derived rather than written out twice so the two tables cannot drift apart.
cols = {ds_key: list(mapping) for ds_key, mapping in col_maps.items()}

ds = DatasetDict({})
for ds_key, ds_val in raw_ds.items():
    # remove extraneous columns; a sorted list matches remove_columns'
    # documented `str | list[str]` argument and keeps the call deterministic
    extra_cols = sorted(set(ds_val.column_names) - set(cols[ds_key]))
    ds_val = ds_val.remove_columns(extra_cols)
    # standardize column names
    ds_val = ds_val.rename_columns(col_maps[ds_key])
    # tag every row with the name of its source dataset
    ds_val = ds_val.add_column('name', [ds_key] * len(ds_val))
    ds[ds_key] = ds_val

# concatenate datasets (allsides rows first, then ppp rows).
# concatenate_datasets expects a list, not a dict view. ppp has no ref3 column;
# presumably those rows get ref3=None in the aggregate — confirm per datasets version.
ds['agg'] = datasets.concatenate_datasets(list(ds.values()))

print(ds)

DatasetDict({
    allsides: Dataset({
        features: ['d1', 'd2', 'ref0', 'ref1', 'ref2', 'ref3', 'name'],
        num_rows: 137
    })
    ppp: Dataset({
        features: ['d1', 'd2', 'ref0', 'ref1', 'ref2', 'name'],
        num_rows: 135
    })
    agg: Dataset({
        features: ['d1', 'd2', 'ref0', 'ref1', 'ref2', 'ref3', 'name'],
        num_rows: 272
    })
})


# Load Model


In [11]:
# Sentence encoder used for all embeddings below (1024-d output, 512-token cap).
model_name = 'mixedbread-ai/mxbai-embed-large-v1'
encoder_kwargs = dict(
    emb_dim=1024,
    max_seq_len=512,
    cache_dir=f'{proot}/cache',
)
model = from_hf(model_name, **encoder_kwargs)
print(model)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)


# Get Samples for Annotation

In [12]:
agg = ds['agg']

# `agg` holds all allsides rows followed by all ppp rows, so the first ppp row
# sits right after the last allsides row. Derive the boundary rather than
# hard-coding it (it was 137 for the current data).
ppp_start = len(ds['allsides'])
print(agg[ppp_start - 1]['name'])  # sanity check: last allsides row -> 'allsides'

# Two random picks per source. They are still drawn so the seeded global RNG
# stream is consumed exactly as before, even though the fixed ids below win.
all_sides_idx = np.random.randint(0, ppp_start, (2,))
ppp_idx = np.random.randint(ppp_start, ppp_start + len(ds['ppp']), (2,))

# To annotate a fresh random sample instead:
#   sample_ids = np.concatenate((all_sides_idx, ppp_idx))
sample_ids = np.array([9, 125, 152, 201])
print(sample_ids)

allsides
[  9 125 152 201]


In [13]:
samples = agg.select(sample_ids)
sid = 3
sample = samples[sid]


def _md_cell(text):
    """Return `text` as a single-line string safe inside one markdown table cell."""
    # newlines would end the row early and a bare '|' would start a new cell
    return str(text).strip().replace('\n', ' ').replace('|', r'\|')


# sent tokenize each doc
sents1 = sent_tokenize(sample['d1'])
sents2 = sent_tokenize(sample['d2'])

# even out list lengths so the two columns line up row by row
len_max = max(len(sents1), len(sents2))
sents1 += [''] * (len_max - len(sents1))
sents2 += [''] * (len_max - len(sents2))

# create md table for docs; the delimiter row needs one cell per header column
# (3 here), otherwise markdown renderers do not treat it as a table
doc_rows = ['| id | d1 | d2 |', '| - | - | - |']
doc_rows += [
    f'| {i} | {_md_cell(s1)} | {_md_cell(s2)} |'
    for i, (s1, s2) in enumerate(zip(sents1, sents2))
]
doc_table = '\n'.join(doc_rows) + '\n'
print(doc_table)

# create md table for refs; ref3 is only included when present and not None
# (presumably None for ppp rows, which have no fourth reference)
ref_keys = ['ref0', 'ref1', 'ref2']
if sample.get('ref3') is not None:
    ref_keys.append('ref3')
ref_table = '| refs             |\n| ----             |\n'
ref_table += ''.join(f'| {_md_cell(sample[key])} |\n' for key in ref_keys)
print(ref_table)


| id | d1 | d2 |
| - | - |
| 0 | Privacy Policy   Important Update In September 2012, we announced that Instagram had been acquired by Facebook. | Welcome to the Google Privacy Policy When you use Google services, you trust us with your information. |
| 1 | We knew that by teaming up with Facebook, we could build a better Instagram for you. | This Privacy Policy is meant to help you understand what data we collect, why we collect it, and what we do with it. |
| 2 | Since then, we've been collaborating with Facebook's team on ways to do just that. | This is important; we hope you will take time to read it carefully. |
| 3 | As part of our new collaboration, we've learned that by being able to share insights and information with each other, we can build better experiences for our users. | And remember, you can find controls to manage your information and protect your privacy and security at My Account. |
| 4 | We're updating our Privacy Policy to highlight this new collaboration, but we 

In [14]:
import json

# Dump the raw (pre-rename) AllSides record at index 9 (sample_ids[0]) to see
# every original field.
raw_record = raw_ds['allsides'][9]
print(json.dumps(raw_record, indent=4))

{
    "Key": "news9",
    "Left": "The Russian government successfully obtained access to U.S. voter registration databases in multiple states prior to the 2016 election, the federal official responsible for monitoring hacking said. Jeannette Manfra, the head of cybersecurity at the Department of Homeland Security, told NBC News Thursday that Russia targeted 21 states and managed to actually penetrate \"an exceptionally small number of them\" in an interview published Thursday. \"We were able to determine that the scanning and probing of voter registration databases was coming from the Russian government,\" Manfra added. Five states, including Texas and California, denied that they ever suffered attacks. Jeh Johnson, who was DHS secretary at the time, told NBC that states and the federal government should \"do something about it,\" though he lamented many of the targeted states haven't taken action since the election. Manfra disagreed, claiming that \"they have all taken it seriously.\

# Build Groups

In [25]:
def l2_dist(a, b):
    """Return pairwise Euclidean (L2) distances between the rows of `a` and `b`.

    Uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y so the whole
    matrix comes from one matrix product. For 2-D inputs of shape (n, d) and
    (m, d) the result has shape (n, m).

    Args:
        a: torch.Tensor, np.ndarray, or anything `np.asarray` accepts.
        b: same kind of input as `a`, with the same trailing dimension.

    Returns:
        The distance matrix: a torch.Tensor when `a` is a tensor (with `b`
        moved to `a`'s dtype/device), otherwise an np.ndarray.
    """
    if isinstance(a, torch.Tensor):
        b = torch.as_tensor(b, dtype=a.dtype, device=a.device)
        a_sqr = torch.sum(a**2, dim=-1, keepdim=True)
        b_sqr = torch.sum(b**2, dim=-1, keepdim=True).T
        sq_dists = a_sqr + b_sqr - 2 * a @ b.T
        # floating-point cancellation can make the squared distance of
        # (near-)identical vectors slightly negative, and sqrt would give NaN
        return torch.sqrt(torch.clamp(sq_dists, min=0))

    a = np.asarray(a)
    b = np.asarray(b)
    a_sqr = np.sum(a**2, keepdims=True, axis=-1)
    b_sqr = np.sum(b**2, keepdims=True, axis=-1).T
    sq_dists = a_sqr + b_sqr - 2 * a @ b.T
    # clamp for the same reason as above (avoids NaN on the diagonal of x vs x)
    return np.sqrt(np.maximum(sq_dists, 0))

trgt = 1
sample = samples[trgt]

# decision thresholds: a sentence pair "matches" when cosine similarity
# exceeds st, or (separately) when L2 distance falls below dt
st = 0.73
dt = 13.5

# set print options up front so every matrix printed below is compact,
# not only the last two
np.set_printoptions(precision=2, linewidth=400)

print(sample['d1'])
print(sample['d2'])
s1 = sent_tokenize(sample['d1'])
s2 = sent_tokenize(sample['d2'])

emb1 = model.encode(s1)
emb2 = model.encode(s2)

# pairwise scores between every sentence of d1 and every sentence of d2
sim_scores = util.cos_sim(emb1, emb2).cpu().numpy()
dists = l2_dist(emb1, emb2)

sim_preds = (sim_scores > st).astype(int)
dist_preds = (dists < dt).astype(int)

print(f'\npairs within cosine similarity threshold t={st}')
print(sim_preds.shape)
print(sim_preds)

print(f'\npairs within distance threshold t={dt}')
print(dist_preds.shape)
print(dist_preds)

print(f'\nsim scores (avg={np.mean(sim_scores)}):\n{sim_scores}')
print(f'\ndists (avg={np.mean(dists)}):\n{dists}')

IndentationError: expected an indented block after function definition on line 24 (1298193447.py, line 25)