In [1]:
import re
import pandas as pd
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams

from torch import nn
import torch.nn.functional as F
from transformers import RobertaConfig, BertTokenizerFast, BertModel, RobertaTokenizer, RobertaModel, AdamW,RobertaTokenizerFast
from tqdm import tqdm

from model_amazon_self_attention import SRoberta,DNNSelfAttention,AttentionPooling

from sklearn.decomposition import PCA
from scipy.stats import spearmanr
from sklearn.utils import shuffle
from pingouin import ttest

In [2]:
device = 'cuda'
model_path = '/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/authorship_models/final-roberta-cosine-modified_anchor-mask-0.1-delta-0.4-0.6-alpha-30.0/model-5'
model = torch.load(model_path).to(device)
model.eval()
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

In [7]:
def extract_emb(text,mask=False,prob=0.05):
    tokenized = tokenizer.encode_plus(text,add_special_tokens=True, max_length=102,truncation=True,return_tensors="pt")
    if mask == True:
        tokenized['input_ids'] = masking(tokenized['input_ids'],mlm_prob=0.1)
    hidden = model(tokenized['input_ids'].to(device),tokenized['attention_mask'].to(device))
    hidden = F.normalize(hidden,dim=-1)
    hidden = hidden.cpu().detach().numpy()
    return hidden

In [11]:
def get_analogy(texta,textb,mask=False):
    
    # Extract embeddings
    emba = np.stack([extract_emb(i) for i in tqdm(texta)])
    if mask == True:
        embb = np.stack([extract_emb(i,mask=True) for i in tqdm(textb)])
    else:
        embb = np.stack([extract_emb(i) for i in tqdm(textb)])
    
    emba = torch.tensor(emba).to(device).squeeze()
    embb = torch.tensor(embb).to(device).squeeze()
    print(emba.shape)
    print(embb.shape)
    # Get pairwise similarity
    diff = F.normalize(emba - embb,dim=-1)
    sim = torch.matmul(diff,diff.transpose(1,0))
    sim = torch.tril(sim, diagonal=-1)
    sim = sim.cpu().detach().numpy()
    # Get mean similarity
    count = np.sum(np.where(sim!=0,1,0))
    if count == 0.0:
        return 0.0
    else:
        mean = np.sum(sim)/count
        return mean
    
    
def masking(text,mlm_prob=0.1,mask=50264):
    
    indices_replaced = torch.bernoulli(torch.full(text.shape, mlm_prob)).bool()
    text[indices_replaced] = mask
    return text

In [3]:
data = pd.read_csv('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/authorship/amazon_test_samples',sep='\t')

In [None]:
def masking(text,mlm_prob=0.1,mask=50264):
    
    indices_replaced = torch.bernoulli(torch.full(text.shape, mlm_prob)).bool()
    text[indices_replaced] = mask
    return text
    


# baseline
texta = [i for i in data.sample(1000)['text']]
textb = [i for i in data.sample(1000)['text']]

out = get_analogy(texta,texta,mask=True)
print(out)

In [None]:
#and to &
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace('and','&') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# . to space
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace('.',' ') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# ! to !!!
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace('!','!!!!') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# lower case
texta = [i for i in data.sample(1000)['text']]
textb = [i.lower() for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# upper case
texta = [i for i in data.sample(1000)['text']]
textb = [i.upper() for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace(' I ',' ') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# going to to gonna

texta = [i for i in data['text'] if len(re.findall(r'going to',' '.join(i.split(' ')[:100]))) >=2 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('going to','gonna') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# want to to wanna

texta = [i for i in data['text'] if len(re.findall(r'want to',' '.join(i.split(' ')[:100]))) >=2 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('want to','wonna') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# -ing to -in'

texta = [i for i in data['text'] if len(re.findall(r'ing ',' '.join(i.split(' ')[:100]))) >=3 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('ing ',"in' ") for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
texta = [i for i in data['text'] if len(re.findall(r'good',' '.join(i.split(' ')[:100]))) >=3 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('good','goooooood') for i in texta]

out = get_analogy(texta,textb)
print(out)


In [15]:
# random baseline
original = [i for i in data.sample(2000)['text']]
less = [i for i in data.sample(2000)['text']]
more = [i for i in data.sample(2000)['text']]

bsim, bmag = get_quant_analogy(original, original, original,mask=True)
print(np.mean(bsim))
print(np.mean(bmag))




  0%|          | 0/2000 [00:00<?, ?it/s][A[A

  0%|          | 8/2000 [00:00<00:25, 77.26it/s][A[A

  1%|          | 16/2000 [00:00<00:25, 77.03it/s][A[A

  1%|          | 24/2000 [00:00<00:25, 76.98it/s][A[A

  2%|▏         | 32/2000 [00:00<00:25, 76.02it/s][A[A

  2%|▏         | 40/2000 [00:00<00:25, 76.24it/s][A[A

  2%|▏         | 48/2000 [00:00<00:25, 76.37it/s][A[A

  3%|▎         | 56/2000 [00:00<00:25, 75.99it/s][A[A

  3%|▎         | 64/2000 [00:00<00:25, 75.47it/s][A[A

  4%|▎         | 72/2000 [00:00<00:25, 75.67it/s][A[A

  4%|▍         | 80/2000 [00:01<00:25, 76.06it/s][A[A

  4%|▍         | 88/2000 [00:01<00:25, 75.18it/s][A[A

  5%|▍         | 96/2000 [00:01<00:25, 75.35it/s][A[A

  5%|▌         | 104/2000 [00:01<00:25, 75.53it/s][A[A

  6%|▌         | 112/2000 [00:01<00:24, 75.85it/s][A[A

  6%|▌         | 120/2000 [00:01<00:24, 76.35it/s][A[A

  6%|▋         | 128/2000 [00:01<00:24, 75.89it/s][A[A

  7%|▋         | 136/2000 [00:01<00

 56%|█████▋    | 1128/2000 [00:14<00:11, 76.80it/s][A[A

 57%|█████▋    | 1136/2000 [00:14<00:11, 76.63it/s][A[A

 57%|█████▋    | 1144/2000 [00:14<00:11, 76.62it/s][A[A

 58%|█████▊    | 1152/2000 [00:15<00:11, 76.42it/s][A[A

 58%|█████▊    | 1160/2000 [00:15<00:10, 76.43it/s][A[A

 58%|█████▊    | 1168/2000 [00:15<00:10, 76.37it/s][A[A

 59%|█████▉    | 1176/2000 [00:15<00:10, 76.59it/s][A[A

 59%|█████▉    | 1184/2000 [00:15<00:10, 76.48it/s][A[A

 60%|█████▉    | 1192/2000 [00:15<00:10, 76.66it/s][A[A

 60%|██████    | 1200/2000 [00:15<00:10, 76.54it/s][A[A

 60%|██████    | 1208/2000 [00:15<00:10, 76.44it/s][A[A

 61%|██████    | 1216/2000 [00:15<00:10, 76.45it/s][A[A

 61%|██████    | 1224/2000 [00:15<00:10, 76.29it/s][A[A

 62%|██████▏   | 1232/2000 [00:16<00:10, 75.81it/s][A[A

 62%|██████▏   | 1240/2000 [00:16<00:10, 75.07it/s][A[A

 62%|██████▏   | 1248/2000 [00:16<00:09, 75.80it/s][A[A

 63%|██████▎   | 1256/2000 [00:16<00:09, 75.41it/s][A[

  6%|▌         | 116/2000 [00:02<00:47, 39.34it/s][A[A

  6%|▌         | 120/2000 [00:03<00:47, 39.43it/s][A[A

  6%|▌         | 124/2000 [00:03<00:47, 39.14it/s][A[A

  6%|▋         | 128/2000 [00:03<00:47, 39.18it/s][A[A

  7%|▋         | 132/2000 [00:03<00:47, 39.26it/s][A[A

  7%|▋         | 136/2000 [00:03<00:47, 39.06it/s][A[A

  7%|▋         | 140/2000 [00:03<00:47, 39.19it/s][A[A

  7%|▋         | 144/2000 [00:03<00:47, 39.18it/s][A[A

  7%|▋         | 148/2000 [00:03<00:47, 39.22it/s][A[A

  8%|▊         | 152/2000 [00:03<00:46, 39.33it/s][A[A

  8%|▊         | 156/2000 [00:03<00:46, 39.26it/s][A[A

  8%|▊         | 160/2000 [00:04<00:46, 39.22it/s][A[A

  8%|▊         | 164/2000 [00:04<00:46, 39.30it/s][A[A

  8%|▊         | 168/2000 [00:04<00:46, 39.26it/s][A[A

  9%|▊         | 172/2000 [00:04<00:46, 39.29it/s][A[A

  9%|▉         | 176/2000 [00:04<00:46, 39.28it/s][A[A

  9%|▉         | 180/2000 [00:04<00:46, 39.19it/s][A[A

  9%|▉        

 34%|███▍      | 680/2000 [00:17<00:33, 39.05it/s][A[A

 34%|███▍      | 684/2000 [00:17<00:33, 39.10it/s][A[A

 34%|███▍      | 688/2000 [00:17<00:33, 39.04it/s][A[A

 35%|███▍      | 692/2000 [00:17<00:33, 39.14it/s][A[A

 35%|███▍      | 696/2000 [00:17<00:33, 39.22it/s][A[A

 35%|███▌      | 700/2000 [00:17<00:33, 39.24it/s][A[A

 35%|███▌      | 704/2000 [00:17<00:33, 39.26it/s][A[A

 35%|███▌      | 708/2000 [00:18<00:32, 39.31it/s][A[A

 36%|███▌      | 712/2000 [00:18<00:32, 39.27it/s][A[A

 36%|███▌      | 716/2000 [00:18<00:32, 39.28it/s][A[A

 36%|███▌      | 720/2000 [00:18<00:32, 39.33it/s][A[A

 36%|███▌      | 724/2000 [00:18<00:32, 39.18it/s][A[A

 36%|███▋      | 728/2000 [00:18<00:32, 39.19it/s][A[A

 37%|███▋      | 732/2000 [00:18<00:32, 39.22it/s][A[A

 37%|███▋      | 736/2000 [00:18<00:32, 39.12it/s][A[A

 37%|███▋      | 740/2000 [00:18<00:32, 39.20it/s][A[A

 37%|███▋      | 744/2000 [00:18<00:31, 39.26it/s][A[A

 37%|███▋     

 62%|██████▏   | 1240/2000 [00:31<00:19, 38.33it/s][A[A

 62%|██████▏   | 1244/2000 [00:31<00:19, 38.58it/s][A[A

 62%|██████▏   | 1248/2000 [00:31<00:19, 38.77it/s][A[A

 63%|██████▎   | 1252/2000 [00:31<00:19, 38.62it/s][A[A

 63%|██████▎   | 1256/2000 [00:32<00:19, 38.82it/s][A[A

 63%|██████▎   | 1260/2000 [00:32<00:19, 38.89it/s][A[A

 63%|██████▎   | 1264/2000 [00:32<00:18, 38.94it/s][A[A

 63%|██████▎   | 1268/2000 [00:32<00:18, 38.96it/s][A[A

 64%|██████▎   | 1272/2000 [00:32<00:18, 39.04it/s][A[A

 64%|██████▍   | 1276/2000 [00:32<00:18, 38.68it/s][A[A

 64%|██████▍   | 1280/2000 [00:32<00:18, 38.68it/s][A[A

 64%|██████▍   | 1284/2000 [00:32<00:18, 38.72it/s][A[A

 64%|██████▍   | 1288/2000 [00:32<00:18, 38.78it/s][A[A

 65%|██████▍   | 1292/2000 [00:33<00:18, 38.85it/s][A[A

 65%|██████▍   | 1296/2000 [00:33<00:18, 38.96it/s][A[A

 65%|██████▌   | 1300/2000 [00:33<00:17, 38.92it/s][A[A

 65%|██████▌   | 1304/2000 [00:33<00:17, 38.95it/s][A[

 90%|████████▉ | 1792/2000 [00:45<00:05, 38.88it/s][A[A

 90%|████████▉ | 1796/2000 [00:45<00:05, 38.84it/s][A[A

 90%|█████████ | 1800/2000 [00:46<00:05, 38.83it/s][A[A

 90%|█████████ | 1804/2000 [00:46<00:05, 38.93it/s][A[A

 90%|█████████ | 1808/2000 [00:46<00:04, 38.91it/s][A[A

 91%|█████████ | 1812/2000 [00:46<00:04, 38.96it/s][A[A

 91%|█████████ | 1816/2000 [00:46<00:04, 39.04it/s][A[A

 91%|█████████ | 1820/2000 [00:46<00:04, 39.03it/s][A[A

 91%|█████████ | 1824/2000 [00:46<00:04, 39.02it/s][A[A

 91%|█████████▏| 1828/2000 [00:46<00:04, 39.04it/s][A[A

 92%|█████████▏| 1832/2000 [00:46<00:04, 39.07it/s][A[A

 92%|█████████▏| 1836/2000 [00:47<00:04, 38.94it/s][A[A

 92%|█████████▏| 1840/2000 [00:47<00:04, 38.93it/s][A[A

 92%|█████████▏| 1844/2000 [00:47<00:04, 38.90it/s][A[A

 92%|█████████▏| 1848/2000 [00:47<00:03, 38.86it/s][A[A

 93%|█████████▎| 1852/2000 [00:47<00:03, 38.84it/s][A[A

 93%|█████████▎| 1856/2000 [00:47<00:03, 38.93it/s][A[

0.68611115
0.05959391





In [14]:
# I to null

samples = [i for i in data['text'] if len(re.findall(r'I ',' '.join(i.split(' ')[:100])))>=10]
samples = shuffle(samples[:2000])


original = [i.lower() for i in samples]
less = [re.sub(r'i ',' ',i.lower(),count=4) for i in original]
more = [re.sub(r'i ',' ',i.lower()) for i in original]


sim, mag = get_quant_analogy(original, less, more)
print(np.mean(sim))
print(np.mean(mag))




  0%|          | 0/1804 [00:00<?, ?it/s][A[A

  0%|          | 8/1804 [00:00<00:23, 77.96it/s][A[A

  1%|          | 16/1804 [00:00<00:23, 77.31it/s][A[A

  1%|▏         | 24/1804 [00:00<00:23, 76.41it/s][A[A

  2%|▏         | 32/1804 [00:00<00:23, 76.12it/s][A[A

  2%|▏         | 40/1804 [00:00<00:23, 76.03it/s][A[A

  3%|▎         | 48/1804 [00:00<00:23, 76.14it/s][A[A

  3%|▎         | 56/1804 [00:00<00:23, 75.72it/s][A[A

  4%|▎         | 64/1804 [00:00<00:22, 75.85it/s][A[A

  4%|▍         | 72/1804 [00:00<00:22, 75.80it/s][A[A

  4%|▍         | 80/1804 [00:01<00:22, 75.74it/s][A[A

  5%|▍         | 88/1804 [00:01<00:22, 75.37it/s][A[A

  5%|▌         | 96/1804 [00:01<00:22, 75.74it/s][A[A

  6%|▌         | 104/1804 [00:01<00:22, 75.41it/s][A[A

  6%|▌         | 112/1804 [00:01<00:22, 75.42it/s][A[A

  7%|▋         | 120/1804 [00:01<00:22, 75.58it/s][A[A

  7%|▋         | 128/1804 [00:01<00:22, 75.32it/s][A[A

  8%|▊         | 136/1804 [00:01<00

 63%|██████▎   | 1128/1804 [00:14<00:08, 75.74it/s][A[A

 63%|██████▎   | 1136/1804 [00:15<00:08, 75.73it/s][A[A

 63%|██████▎   | 1144/1804 [00:15<00:08, 75.70it/s][A[A

 64%|██████▍   | 1152/1804 [00:15<00:08, 75.70it/s][A[A

 64%|██████▍   | 1160/1804 [00:15<00:08, 75.86it/s][A[A

 65%|██████▍   | 1168/1804 [00:15<00:08, 75.49it/s][A[A

 65%|██████▌   | 1176/1804 [00:15<00:08, 75.27it/s][A[A

 66%|██████▌   | 1184/1804 [00:15<00:08, 75.27it/s][A[A

 66%|██████▌   | 1192/1804 [00:15<00:08, 75.43it/s][A[A

 67%|██████▋   | 1200/1804 [00:15<00:08, 75.37it/s][A[A

 67%|██████▋   | 1208/1804 [00:15<00:07, 75.50it/s][A[A

 67%|██████▋   | 1216/1804 [00:16<00:07, 75.44it/s][A[A

 68%|██████▊   | 1224/1804 [00:16<00:07, 75.52it/s][A[A

 68%|██████▊   | 1232/1804 [00:16<00:07, 75.60it/s][A[A

 69%|██████▊   | 1240/1804 [00:16<00:07, 75.73it/s][A[A

 69%|██████▉   | 1248/1804 [00:16<00:07, 76.01it/s][A[A

 70%|██████▉   | 1256/1804 [00:16<00:07, 75.85it/s][A[

 24%|██▍       | 440/1804 [00:05<00:18, 74.61it/s][A[A

 25%|██▍       | 448/1804 [00:05<00:18, 75.02it/s][A[A

 25%|██▌       | 456/1804 [00:06<00:17, 75.00it/s][A[A

 26%|██▌       | 464/1804 [00:06<00:17, 74.52it/s][A[A

 26%|██▌       | 472/1804 [00:06<00:17, 74.82it/s][A[A

 27%|██▋       | 480/1804 [00:06<00:17, 75.17it/s][A[A

 27%|██▋       | 488/1804 [00:06<00:17, 75.53it/s][A[A

 27%|██▋       | 496/1804 [00:06<00:17, 75.78it/s][A[A

 28%|██▊       | 504/1804 [00:06<00:17, 75.13it/s][A[A

 28%|██▊       | 512/1804 [00:06<00:17, 75.27it/s][A[A

 29%|██▉       | 520/1804 [00:06<00:16, 75.54it/s][A[A

 29%|██▉       | 528/1804 [00:07<00:16, 75.41it/s][A[A

 30%|██▉       | 536/1804 [00:07<00:16, 75.53it/s][A[A

 30%|███       | 544/1804 [00:07<00:16, 75.26it/s][A[A

 31%|███       | 552/1804 [00:07<00:16, 75.36it/s][A[A

 31%|███       | 560/1804 [00:07<00:16, 75.11it/s][A[A

 31%|███▏      | 568/1804 [00:07<00:16, 75.13it/s][A[A

 32%|███▏     

 86%|████████▋ | 1560/1804 [00:20<00:03, 75.34it/s][A[A

 87%|████████▋ | 1568/1804 [00:20<00:03, 75.48it/s][A[A

 87%|████████▋ | 1576/1804 [00:20<00:03, 75.02it/s][A[A

 88%|████████▊ | 1584/1804 [00:20<00:02, 75.06it/s][A[A

 88%|████████▊ | 1592/1804 [00:21<00:02, 75.31it/s][A[A

 89%|████████▊ | 1600/1804 [00:21<00:02, 75.22it/s][A[A

 89%|████████▉ | 1608/1804 [00:21<00:02, 75.28it/s][A[A

 90%|████████▉ | 1616/1804 [00:21<00:02, 75.50it/s][A[A

 90%|█████████ | 1624/1804 [00:21<00:02, 75.60it/s][A[A

 90%|█████████ | 1632/1804 [00:21<00:02, 75.72it/s][A[A

 91%|█████████ | 1640/1804 [00:21<00:02, 75.99it/s][A[A

 91%|█████████▏| 1648/1804 [00:21<00:02, 75.53it/s][A[A

 92%|█████████▏| 1656/1804 [00:21<00:01, 74.99it/s][A[A

 92%|█████████▏| 1664/1804 [00:22<00:01, 74.86it/s][A[A

 93%|█████████▎| 1672/1804 [00:22<00:01, 75.35it/s][A[A

 93%|█████████▎| 1680/1804 [00:22<00:01, 75.43it/s][A[A

 94%|█████████▎| 1688/1804 [00:22<00:01, 74.95it/s][A[

 49%|████▉     | 880/1804 [00:11<00:12, 75.92it/s][A[A

 49%|████▉     | 888/1804 [00:11<00:12, 76.02it/s][A[A

 50%|████▉     | 896/1804 [00:11<00:11, 75.84it/s][A[A

 50%|█████     | 904/1804 [00:11<00:11, 75.92it/s][A[A

 51%|█████     | 912/1804 [00:12<00:11, 75.93it/s][A[A

 51%|█████     | 920/1804 [00:12<00:11, 75.86it/s][A[A

 51%|█████▏    | 928/1804 [00:12<00:11, 75.77it/s][A[A

 52%|█████▏    | 936/1804 [00:12<00:11, 75.98it/s][A[A

 52%|█████▏    | 944/1804 [00:12<00:11, 76.16it/s][A[A

 53%|█████▎    | 952/1804 [00:12<00:11, 76.16it/s][A[A

 53%|█████▎    | 960/1804 [00:12<00:11, 76.26it/s][A[A

 54%|█████▎    | 968/1804 [00:12<00:10, 76.03it/s][A[A

 54%|█████▍    | 976/1804 [00:12<00:10, 75.62it/s][A[A

 55%|█████▍    | 984/1804 [00:13<00:10, 75.66it/s][A[A

 55%|█████▍    | 992/1804 [00:13<00:10, 74.83it/s][A[A

 55%|█████▌    | 1000/1804 [00:13<00:10, 75.21it/s][A[A

 56%|█████▌    | 1008/1804 [00:13<00:10, 75.40it/s][A[A

 56%|█████▋ 

0.8596929
0.27536327





In [10]:
samples = [i for i in data['text'] if len(re.findall(r'and ',' '.join(i.split(' ')[:100])))>=6]
samples = shuffle(samples[:2000])


original = [i for i in samples]
less = [re.sub(r'and ','& ',i,count=3) for i in original]
more = [re.sub(r'and ','& ',i) for i in original]


sim, mag = get_quant_analogy(original, less, more)

print(np.mean(sim))
print(np.mean(mag))




  0%|          | 0/2000 [00:00<?, ?it/s][A[A

  0%|          | 8/2000 [00:00<00:25, 77.33it/s][A[A

  1%|          | 16/2000 [00:00<00:25, 76.58it/s][A[A

  1%|          | 24/2000 [00:00<00:25, 76.34it/s][A[A

  2%|▏         | 32/2000 [00:00<00:26, 75.57it/s][A[A

  2%|▏         | 40/2000 [00:00<00:25, 75.98it/s][A[A

  2%|▏         | 48/2000 [00:00<00:25, 75.72it/s][A[A

  3%|▎         | 56/2000 [00:00<00:25, 75.77it/s][A[A

  3%|▎         | 64/2000 [00:00<00:25, 75.93it/s][A[A

  4%|▎         | 72/2000 [00:00<00:25, 76.17it/s][A[A

  4%|▍         | 80/2000 [00:01<00:25, 75.60it/s][A[A

  4%|▍         | 88/2000 [00:01<00:25, 75.52it/s][A[A

  5%|▍         | 96/2000 [00:01<00:25, 75.10it/s][A[A

  5%|▌         | 104/2000 [00:01<00:25, 75.40it/s][A[A

  6%|▌         | 112/2000 [00:01<00:25, 75.39it/s][A[A

  6%|▌         | 120/2000 [00:01<00:24, 75.57it/s][A[A

  6%|▋         | 128/2000 [00:01<00:24, 75.68it/s][A[A

  7%|▋         | 136/2000 [00:01<00

 56%|█████▋    | 1128/2000 [00:14<00:11, 75.22it/s][A[A

 57%|█████▋    | 1136/2000 [00:15<00:11, 75.48it/s][A[A

 57%|█████▋    | 1144/2000 [00:15<00:11, 75.75it/s][A[A

 58%|█████▊    | 1152/2000 [00:15<00:11, 75.69it/s][A[A

 58%|█████▊    | 1160/2000 [00:15<00:11, 76.01it/s][A[A

 58%|█████▊    | 1168/2000 [00:15<00:11, 75.02it/s][A[A

 59%|█████▉    | 1176/2000 [00:15<00:10, 75.34it/s][A[A

 59%|█████▉    | 1184/2000 [00:15<00:10, 75.65it/s][A[A

 60%|█████▉    | 1192/2000 [00:15<00:10, 76.02it/s][A[A

 60%|██████    | 1200/2000 [00:15<00:10, 76.27it/s][A[A

 60%|██████    | 1208/2000 [00:15<00:10, 76.31it/s][A[A

 61%|██████    | 1216/2000 [00:16<00:10, 76.17it/s][A[A

 61%|██████    | 1224/2000 [00:16<00:10, 76.16it/s][A[A

 62%|██████▏   | 1232/2000 [00:16<00:10, 75.96it/s][A[A

 62%|██████▏   | 1240/2000 [00:16<00:10, 75.86it/s][A[A

 62%|██████▏   | 1248/2000 [00:16<00:09, 76.14it/s][A[A

 63%|██████▎   | 1256/2000 [00:16<00:09, 76.19it/s][A[

 12%|█▏        | 230/2000 [00:03<00:23, 74.37it/s][A[A

 12%|█▏        | 238/2000 [00:03<00:23, 74.61it/s][A[A

 12%|█▏        | 246/2000 [00:03<00:23, 75.19it/s][A[A

 13%|█▎        | 254/2000 [00:03<00:23, 75.61it/s][A[A

 13%|█▎        | 262/2000 [00:03<00:22, 75.84it/s][A[A

 14%|█▎        | 270/2000 [00:03<00:22, 76.02it/s][A[A

 14%|█▍        | 278/2000 [00:03<00:22, 75.66it/s][A[A

 14%|█▍        | 286/2000 [00:03<00:22, 75.83it/s][A[A

 15%|█▍        | 294/2000 [00:03<00:22, 75.97it/s][A[A

 15%|█▌        | 302/2000 [00:04<00:22, 75.77it/s][A[A

 16%|█▌        | 310/2000 [00:04<00:22, 75.47it/s][A[A

 16%|█▌        | 318/2000 [00:04<00:22, 75.26it/s][A[A

 16%|█▋        | 326/2000 [00:04<00:22, 75.05it/s][A[A

 17%|█▋        | 334/2000 [00:04<00:22, 75.15it/s][A[A

 17%|█▋        | 342/2000 [00:04<00:22, 75.22it/s][A[A

 18%|█▊        | 350/2000 [00:04<00:21, 75.37it/s][A[A

 18%|█▊        | 358/2000 [00:04<00:21, 75.48it/s][A[A

 18%|█▊       

 68%|██████▊   | 1350/2000 [00:17<00:08, 75.93it/s][A[A

 68%|██████▊   | 1358/2000 [00:17<00:08, 75.76it/s][A[A

 68%|██████▊   | 1366/2000 [00:18<00:08, 76.03it/s][A[A

 69%|██████▊   | 1374/2000 [00:18<00:08, 76.19it/s][A[A

 69%|██████▉   | 1382/2000 [00:18<00:08, 76.44it/s][A[A

 70%|██████▉   | 1390/2000 [00:18<00:08, 76.23it/s][A[A

 70%|██████▉   | 1398/2000 [00:18<00:07, 76.03it/s][A[A

 70%|███████   | 1406/2000 [00:18<00:07, 75.84it/s][A[A

 71%|███████   | 1414/2000 [00:18<00:07, 75.97it/s][A[A

 71%|███████   | 1422/2000 [00:18<00:07, 76.03it/s][A[A

 72%|███████▏  | 1430/2000 [00:18<00:07, 75.77it/s][A[A

 72%|███████▏  | 1438/2000 [00:18<00:07, 75.50it/s][A[A

 72%|███████▏  | 1446/2000 [00:19<00:07, 75.87it/s][A[A

 73%|███████▎  | 1454/2000 [00:19<00:07, 75.88it/s][A[A

 73%|███████▎  | 1462/2000 [00:19<00:07, 75.90it/s][A[A

 74%|███████▎  | 1470/2000 [00:19<00:07, 75.52it/s][A[A

 74%|███████▍  | 1478/2000 [00:19<00:06, 75.05it/s][A[

 23%|██▎       | 464/2000 [00:06<00:20, 75.66it/s][A[A

 24%|██▎       | 472/2000 [00:06<00:20, 75.95it/s][A[A

 24%|██▍       | 480/2000 [00:06<00:20, 75.87it/s][A[A

 24%|██▍       | 488/2000 [00:06<00:20, 75.29it/s][A[A

 25%|██▍       | 496/2000 [00:06<00:19, 75.69it/s][A[A

 25%|██▌       | 504/2000 [00:06<00:19, 75.51it/s][A[A

 26%|██▌       | 512/2000 [00:06<00:19, 75.96it/s][A[A

 26%|██▌       | 520/2000 [00:06<00:19, 74.97it/s][A[A

 26%|██▋       | 528/2000 [00:06<00:19, 75.27it/s][A[A

 27%|██▋       | 536/2000 [00:07<00:19, 75.75it/s][A[A

 27%|██▋       | 544/2000 [00:07<00:19, 75.50it/s][A[A

 28%|██▊       | 552/2000 [00:07<00:19, 75.39it/s][A[A

 28%|██▊       | 560/2000 [00:07<00:19, 75.44it/s][A[A

 28%|██▊       | 568/2000 [00:07<00:18, 75.65it/s][A[A

 29%|██▉       | 576/2000 [00:07<00:18, 75.44it/s][A[A

 29%|██▉       | 584/2000 [00:07<00:18, 75.78it/s][A[A

 30%|██▉       | 592/2000 [00:07<00:18, 75.80it/s][A[A

 30%|███      

 79%|███████▉  | 1584/2000 [00:20<00:05, 75.22it/s][A[A

 80%|███████▉  | 1592/2000 [00:20<00:05, 75.48it/s][A[A

 80%|████████  | 1600/2000 [00:21<00:05, 75.12it/s][A[A

 80%|████████  | 1608/2000 [00:21<00:05, 75.57it/s][A[A

 81%|████████  | 1616/2000 [00:21<00:05, 75.72it/s][A[A

 81%|████████  | 1624/2000 [00:21<00:04, 75.76it/s][A[A

 82%|████████▏ | 1632/2000 [00:21<00:04, 74.75it/s][A[A

 82%|████████▏ | 1640/2000 [00:21<00:04, 75.33it/s][A[A

 82%|████████▏ | 1648/2000 [00:21<00:04, 75.26it/s][A[A

 83%|████████▎ | 1656/2000 [00:21<00:04, 75.49it/s][A[A

 83%|████████▎ | 1664/2000 [00:21<00:04, 75.82it/s][A[A

 84%|████████▎ | 1672/2000 [00:22<00:04, 75.83it/s][A[A

 84%|████████▍ | 1680/2000 [00:22<00:04, 76.19it/s][A[A

 84%|████████▍ | 1688/2000 [00:22<00:04, 76.13it/s][A[A

 85%|████████▍ | 1696/2000 [00:22<00:04, 75.76it/s][A[A

 85%|████████▌ | 1704/2000 [00:22<00:03, 75.91it/s][A[A

 86%|████████▌ | 1712/2000 [00:22<00:03, 75.94it/s][A[

0.9573367
0.17594744





In [12]:
samples = [i for i in data['text'] if len(re.findall(r'. ',' '.join(i.split(' ')[:100])))>=6]
samples = shuffle(samples[:2000])


original = [i for i in samples]
less = [re.sub(r'\. ','!!!! ',i,count=3) for i in original]
more = [re.sub(r'\. ','!!!! ',i) for i in original]


sim, mag = get_quant_analogy(original, less, more)

print(np.mean(sim))
print(np.mean(mag))




  0%|          | 0/2000 [00:00<?, ?it/s][A[A

  0%|          | 8/2000 [00:00<00:25, 77.83it/s][A[A

  1%|          | 16/2000 [00:00<00:25, 77.57it/s][A[A

  1%|          | 24/2000 [00:00<00:25, 77.28it/s][A[A

  2%|▏         | 32/2000 [00:00<00:25, 76.60it/s][A[A

  2%|▏         | 40/2000 [00:00<00:25, 76.59it/s][A[A

  2%|▏         | 48/2000 [00:00<00:25, 76.80it/s][A[A

  3%|▎         | 56/2000 [00:00<00:25, 76.82it/s][A[A

  3%|▎         | 64/2000 [00:00<00:25, 76.81it/s][A[A

  4%|▎         | 72/2000 [00:00<00:25, 76.93it/s][A[A

  4%|▍         | 80/2000 [00:01<00:24, 77.03it/s][A[A

  4%|▍         | 88/2000 [00:01<00:24, 76.50it/s][A[A

  5%|▍         | 96/2000 [00:01<00:24, 76.62it/s][A[A

  5%|▌         | 104/2000 [00:01<00:24, 76.63it/s][A[A

  6%|▌         | 112/2000 [00:01<00:24, 76.37it/s][A[A

  6%|▌         | 120/2000 [00:01<00:24, 76.05it/s][A[A

  6%|▋         | 128/2000 [00:01<00:24, 76.08it/s][A[A

  7%|▋         | 136/2000 [00:01<00

 56%|█████▋    | 1128/2000 [00:14<00:11, 76.09it/s][A[A

 57%|█████▋    | 1136/2000 [00:14<00:11, 75.83it/s][A[A

 57%|█████▋    | 1144/2000 [00:14<00:11, 75.98it/s][A[A

 58%|█████▊    | 1152/2000 [00:15<00:11, 76.34it/s][A[A

 58%|█████▊    | 1160/2000 [00:15<00:10, 76.57it/s][A[A

 58%|█████▊    | 1168/2000 [00:15<00:10, 76.09it/s][A[A

 59%|█████▉    | 1176/2000 [00:15<00:10, 76.27it/s][A[A

 59%|█████▉    | 1184/2000 [00:15<00:10, 76.09it/s][A[A

 60%|█████▉    | 1192/2000 [00:15<00:10, 76.18it/s][A[A

 60%|██████    | 1200/2000 [00:15<00:10, 76.19it/s][A[A

 60%|██████    | 1208/2000 [00:15<00:10, 76.34it/s][A[A

 61%|██████    | 1216/2000 [00:15<00:10, 76.66it/s][A[A

 61%|██████    | 1224/2000 [00:16<00:10, 75.66it/s][A[A

 62%|██████▏   | 1232/2000 [00:16<00:10, 75.94it/s][A[A

 62%|██████▏   | 1240/2000 [00:16<00:09, 76.12it/s][A[A

 62%|██████▏   | 1248/2000 [00:16<00:09, 76.18it/s][A[A

 63%|██████▎   | 1256/2000 [00:16<00:09, 76.29it/s][A[

 12%|█▏        | 232/2000 [00:03<00:23, 76.44it/s][A[A

 12%|█▏        | 240/2000 [00:03<00:23, 76.44it/s][A[A

 12%|█▏        | 248/2000 [00:03<00:22, 76.53it/s][A[A

 13%|█▎        | 256/2000 [00:03<00:22, 76.63it/s][A[A

 13%|█▎        | 264/2000 [00:03<00:22, 76.60it/s][A[A

 14%|█▎        | 272/2000 [00:03<00:22, 76.31it/s][A[A

 14%|█▍        | 280/2000 [00:03<00:22, 76.64it/s][A[A

 14%|█▍        | 288/2000 [00:03<00:22, 76.50it/s][A[A

 15%|█▍        | 296/2000 [00:03<00:22, 76.45it/s][A[A

 15%|█▌        | 304/2000 [00:03<00:22, 75.12it/s][A[A

 16%|█▌        | 312/2000 [00:04<00:22, 75.57it/s][A[A

 16%|█▌        | 320/2000 [00:04<00:22, 75.89it/s][A[A

 16%|█▋        | 328/2000 [00:04<00:21, 76.04it/s][A[A

 17%|█▋        | 336/2000 [00:04<00:21, 76.07it/s][A[A

 17%|█▋        | 344/2000 [00:04<00:21, 76.16it/s][A[A

 18%|█▊        | 352/2000 [00:04<00:21, 76.14it/s][A[A

 18%|█▊        | 360/2000 [00:04<00:21, 76.24it/s][A[A

 18%|█▊       

 68%|██████▊   | 1352/2000 [00:17<00:08, 76.55it/s][A[A

 68%|██████▊   | 1360/2000 [00:17<00:08, 75.95it/s][A[A

 68%|██████▊   | 1368/2000 [00:17<00:08, 76.19it/s][A[A

 69%|██████▉   | 1376/2000 [00:18<00:08, 75.83it/s][A[A

 69%|██████▉   | 1384/2000 [00:18<00:08, 75.85it/s][A[A

 70%|██████▉   | 1392/2000 [00:18<00:08, 75.60it/s][A[A

 70%|███████   | 1400/2000 [00:18<00:07, 76.10it/s][A[A

 70%|███████   | 1408/2000 [00:18<00:07, 76.37it/s][A[A

 71%|███████   | 1416/2000 [00:18<00:07, 76.21it/s][A[A

 71%|███████   | 1424/2000 [00:18<00:07, 75.84it/s][A[A

 72%|███████▏  | 1432/2000 [00:18<00:07, 75.54it/s][A[A

 72%|███████▏  | 1440/2000 [00:18<00:07, 75.69it/s][A[A

 72%|███████▏  | 1448/2000 [00:18<00:07, 75.81it/s][A[A

 73%|███████▎  | 1456/2000 [00:19<00:07, 76.00it/s][A[A

 73%|███████▎  | 1464/2000 [00:19<00:07, 76.16it/s][A[A

 74%|███████▎  | 1472/2000 [00:19<00:06, 76.53it/s][A[A

 74%|███████▍  | 1480/2000 [00:19<00:06, 76.21it/s][A[

 23%|██▎       | 464/2000 [00:06<00:20, 76.48it/s][A[A

 24%|██▎       | 472/2000 [00:06<00:20, 76.10it/s][A[A

 24%|██▍       | 480/2000 [00:06<00:20, 75.88it/s][A[A

 24%|██▍       | 488/2000 [00:06<00:19, 75.89it/s][A[A

 25%|██▍       | 496/2000 [00:06<00:19, 76.00it/s][A[A

 25%|██▌       | 504/2000 [00:06<00:19, 76.27it/s][A[A

 26%|██▌       | 512/2000 [00:06<00:19, 76.26it/s][A[A

 26%|██▌       | 520/2000 [00:06<00:19, 76.23it/s][A[A

 26%|██▋       | 528/2000 [00:06<00:19, 76.61it/s][A[A

 27%|██▋       | 536/2000 [00:07<00:19, 76.75it/s][A[A

 27%|██▋       | 544/2000 [00:07<00:18, 76.82it/s][A[A

 28%|██▊       | 552/2000 [00:07<00:18, 76.74it/s][A[A

 28%|██▊       | 560/2000 [00:07<00:18, 76.73it/s][A[A

 28%|██▊       | 568/2000 [00:07<00:18, 76.40it/s][A[A

 29%|██▉       | 576/2000 [00:07<00:18, 76.74it/s][A[A

 29%|██▉       | 584/2000 [00:07<00:18, 76.53it/s][A[A

 30%|██▉       | 592/2000 [00:07<00:18, 76.33it/s][A[A

 30%|███      

 79%|███████▉  | 1584/2000 [00:20<00:05, 76.42it/s][A[A

 80%|███████▉  | 1592/2000 [00:20<00:05, 76.43it/s][A[A

 80%|████████  | 1600/2000 [00:20<00:05, 76.63it/s][A[A

 80%|████████  | 1608/2000 [00:21<00:05, 76.80it/s][A[A

 81%|████████  | 1616/2000 [00:21<00:04, 76.94it/s][A[A

 81%|████████  | 1624/2000 [00:21<00:04, 76.99it/s][A[A

 82%|████████▏ | 1632/2000 [00:21<00:04, 76.81it/s][A[A

 82%|████████▏ | 1640/2000 [00:21<00:04, 76.71it/s][A[A

 82%|████████▏ | 1648/2000 [00:21<00:04, 76.62it/s][A[A

 83%|████████▎ | 1656/2000 [00:21<00:04, 76.67it/s][A[A

 83%|████████▎ | 1664/2000 [00:21<00:04, 76.82it/s][A[A

 84%|████████▎ | 1672/2000 [00:21<00:04, 76.64it/s][A[A

 84%|████████▍ | 1680/2000 [00:22<00:04, 76.44it/s][A[A

 84%|████████▍ | 1688/2000 [00:22<00:04, 75.85it/s][A[A

 85%|████████▍ | 1696/2000 [00:22<00:04, 75.76it/s][A[A

 85%|████████▌ | 1704/2000 [00:22<00:03, 75.83it/s][A[A

 86%|████████▌ | 1712/2000 [00:22<00:03, 75.74it/s][A[

0.9640015
0.08512825





In [13]:
# -ing to -in'
samples = [i for i in data['text'] if len(re.findall(r'ing ',' '.join(i.split(' ')[:100])))>=6]
samples = shuffle(samples[:2000])


original = [i for i in samples]
less = [re.sub(r'ing ',"in' ",i,count=3) for i in original]
more = [re.sub(r'ing ',"in' ",i) for i in original]


sim, mag = get_quant_analogy(original, less, more)

print(np.mean(sim))
print(np.mean(mag))




  0%|          | 0/2000 [00:00<?, ?it/s][A[A

  0%|          | 8/2000 [00:00<00:25, 77.68it/s][A[A

  1%|          | 16/2000 [00:00<00:26, 76.28it/s][A[A

  1%|          | 24/2000 [00:00<00:25, 76.09it/s][A[A

  2%|▏         | 32/2000 [00:00<00:25, 76.15it/s][A[A

  2%|▏         | 40/2000 [00:00<00:25, 75.89it/s][A[A

  2%|▏         | 48/2000 [00:00<00:25, 75.78it/s][A[A

  3%|▎         | 56/2000 [00:00<00:25, 75.85it/s][A[A

  3%|▎         | 64/2000 [00:00<00:25, 75.89it/s][A[A

  4%|▎         | 72/2000 [00:00<00:25, 75.32it/s][A[A

  4%|▍         | 80/2000 [00:01<00:25, 75.63it/s][A[A

  4%|▍         | 88/2000 [00:01<00:25, 75.64it/s][A[A

  5%|▍         | 96/2000 [00:01<00:25, 75.35it/s][A[A

  5%|▌         | 104/2000 [00:01<00:25, 75.01it/s][A[A

  6%|▌         | 112/2000 [00:01<00:25, 75.13it/s][A[A

  6%|▌         | 120/2000 [00:01<00:24, 75.35it/s][A[A

  6%|▋         | 128/2000 [00:01<00:25, 74.57it/s][A[A

  7%|▋         | 136/2000 [00:01<00

 56%|█████▋    | 1128/2000 [00:14<00:11, 75.95it/s][A[A

 57%|█████▋    | 1136/2000 [00:15<00:11, 75.95it/s][A[A

 57%|█████▋    | 1144/2000 [00:15<00:11, 76.02it/s][A[A

 58%|█████▊    | 1152/2000 [00:15<00:11, 76.11it/s][A[A

 58%|█████▊    | 1160/2000 [00:15<00:11, 76.26it/s][A[A

 58%|█████▊    | 1168/2000 [00:15<00:10, 76.30it/s][A[A

 59%|█████▉    | 1176/2000 [00:15<00:10, 75.33it/s][A[A

 59%|█████▉    | 1184/2000 [00:15<00:10, 75.45it/s][A[A

 60%|█████▉    | 1192/2000 [00:15<00:10, 75.56it/s][A[A

 60%|██████    | 1200/2000 [00:15<00:10, 75.37it/s][A[A

 60%|██████    | 1208/2000 [00:15<00:10, 75.65it/s][A[A

 61%|██████    | 1216/2000 [00:16<00:10, 75.72it/s][A[A

 61%|██████    | 1224/2000 [00:16<00:10, 75.80it/s][A[A

 62%|██████▏   | 1232/2000 [00:16<00:10, 75.65it/s][A[A

 62%|██████▏   | 1240/2000 [00:16<00:10, 75.83it/s][A[A

 62%|██████▏   | 1248/2000 [00:16<00:09, 75.98it/s][A[A

 63%|██████▎   | 1256/2000 [00:16<00:09, 76.01it/s][A[

 12%|█▏        | 232/2000 [00:03<00:23, 75.57it/s][A[A

 12%|█▏        | 240/2000 [00:03<00:23, 75.51it/s][A[A

 12%|█▏        | 248/2000 [00:03<00:23, 75.73it/s][A[A

 13%|█▎        | 256/2000 [00:03<00:22, 75.83it/s][A[A

 13%|█▎        | 264/2000 [00:03<00:22, 76.12it/s][A[A

 14%|█▎        | 272/2000 [00:03<00:22, 76.34it/s][A[A

 14%|█▍        | 280/2000 [00:03<00:22, 76.36it/s][A[A

 14%|█▍        | 288/2000 [00:03<00:22, 76.21it/s][A[A

 15%|█▍        | 296/2000 [00:03<00:22, 76.06it/s][A[A

 15%|█▌        | 304/2000 [00:04<00:22, 75.83it/s][A[A

 16%|█▌        | 312/2000 [00:04<00:22, 75.83it/s][A[A

 16%|█▌        | 320/2000 [00:04<00:22, 75.44it/s][A[A

 16%|█▋        | 328/2000 [00:04<00:22, 75.55it/s][A[A

 17%|█▋        | 336/2000 [00:04<00:22, 75.47it/s][A[A

 17%|█▋        | 344/2000 [00:04<00:21, 75.96it/s][A[A

 18%|█▊        | 352/2000 [00:04<00:21, 75.42it/s][A[A

 18%|█▊        | 360/2000 [00:04<00:21, 75.46it/s][A[A

 18%|█▊       

 68%|██████▊   | 1352/2000 [00:17<00:08, 75.80it/s][A[A

 68%|██████▊   | 1360/2000 [00:17<00:08, 75.54it/s][A[A

 68%|██████▊   | 1368/2000 [00:18<00:08, 75.69it/s][A[A

 69%|██████▉   | 1376/2000 [00:18<00:08, 75.76it/s][A[A

 69%|██████▉   | 1384/2000 [00:18<00:08, 75.74it/s][A[A

 70%|██████▉   | 1392/2000 [00:18<00:08, 75.37it/s][A[A

 70%|███████   | 1400/2000 [00:18<00:07, 75.71it/s][A[A

 70%|███████   | 1408/2000 [00:18<00:07, 75.70it/s][A[A

 71%|███████   | 1416/2000 [00:18<00:07, 75.14it/s][A[A

 71%|███████   | 1424/2000 [00:18<00:07, 75.49it/s][A[A

 72%|███████▏  | 1432/2000 [00:18<00:07, 75.16it/s][A[A

 72%|███████▏  | 1440/2000 [00:19<00:07, 75.54it/s][A[A

 72%|███████▏  | 1448/2000 [00:19<00:07, 75.29it/s][A[A

 73%|███████▎  | 1456/2000 [00:19<00:07, 75.50it/s][A[A

 73%|███████▎  | 1464/2000 [00:19<00:07, 75.63it/s][A[A

 74%|███████▎  | 1472/2000 [00:19<00:07, 75.25it/s][A[A

 74%|███████▍  | 1480/2000 [00:19<00:06, 75.04it/s][A[

 23%|██▎       | 464/2000 [00:06<00:20, 75.01it/s][A[A

 24%|██▎       | 472/2000 [00:06<00:20, 75.28it/s][A[A

 24%|██▍       | 480/2000 [00:06<00:20, 75.44it/s][A[A

 24%|██▍       | 488/2000 [00:06<00:20, 75.34it/s][A[A

 25%|██▍       | 496/2000 [00:06<00:20, 75.15it/s][A[A

 25%|██▌       | 504/2000 [00:06<00:19, 75.08it/s][A[A

 26%|██▌       | 512/2000 [00:06<00:19, 74.60it/s][A[A

 26%|██▌       | 520/2000 [00:06<00:19, 74.87it/s][A[A

 26%|██▋       | 528/2000 [00:07<00:20, 73.40it/s][A[A

 27%|██▋       | 536/2000 [00:07<00:19, 73.42it/s][A[A

 27%|██▋       | 544/2000 [00:07<00:19, 74.08it/s][A[A

 28%|██▊       | 552/2000 [00:07<00:19, 74.37it/s][A[A

 28%|██▊       | 560/2000 [00:07<00:19, 75.05it/s][A[A

 28%|██▊       | 568/2000 [00:07<00:19, 74.42it/s][A[A

 29%|██▉       | 576/2000 [00:07<00:19, 74.90it/s][A[A

 29%|██▉       | 584/2000 [00:07<00:18, 74.91it/s][A[A

 30%|██▉       | 592/2000 [00:07<00:18, 75.43it/s][A[A

 30%|███      

 79%|███████▉  | 1584/2000 [00:21<00:05, 75.91it/s][A[A

 80%|███████▉  | 1592/2000 [00:21<00:05, 75.71it/s][A[A

 80%|████████  | 1600/2000 [00:21<00:05, 75.64it/s][A[A

 80%|████████  | 1608/2000 [00:21<00:05, 75.36it/s][A[A

 81%|████████  | 1616/2000 [00:21<00:05, 75.76it/s][A[A

 81%|████████  | 1624/2000 [00:21<00:04, 75.20it/s][A[A

 82%|████████▏ | 1632/2000 [00:21<00:04, 74.10it/s][A[A

 82%|████████▏ | 1640/2000 [00:21<00:04, 74.13it/s][A[A

 82%|████████▏ | 1648/2000 [00:21<00:04, 74.02it/s][A[A

 83%|████████▎ | 1656/2000 [00:22<00:04, 74.64it/s][A[A

 83%|████████▎ | 1664/2000 [00:22<00:04, 73.81it/s][A[A

 84%|████████▎ | 1672/2000 [00:22<00:04, 73.80it/s][A[A

 84%|████████▍ | 1680/2000 [00:22<00:04, 74.56it/s][A[A

 84%|████████▍ | 1688/2000 [00:22<00:04, 74.69it/s][A[A

 85%|████████▍ | 1696/2000 [00:22<00:04, 75.13it/s][A[A

 85%|████████▌ | 1704/2000 [00:22<00:03, 75.28it/s][A[A

 86%|████████▌ | 1712/2000 [00:22<00:03, 75.53it/s][A[

0.9330549
0.10285112





In [None]:
device = 'cuda'
model_path = '/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/authorship_models/redditroberta-cosine-modified_anchor-mask-0.1-delta-0.4-0.6-alpha-30.0/model-4'
model = torch.load(model_path).to(device)
model.eval()
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

data = pd.read_csv('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/authorship/amazon_test_samples',sep='\t')

In [None]:
texta = [i for i in data.sample(1000)['text']]
textb = [i for i in data.sample(1000)['text']]

out = get_analogy(texta,textb,mask=True)
print(out)

In [None]:
#and to &
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace('and','&') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# . to space
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace('.',' ') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# ! to !!!
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace('!','!!!!') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# lower case
texta = [i for i in data.sample(1000)['text']]
textb = [i.lower() for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# upper case
texta = [i for i in data.sample(1000)['text']]
textb = [i.upper() for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
texta = [i for i in data.sample(1000)['text']]
textb = [i.replace(' I ',' ') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# going to to gonna

texta = [i for i in data['text'] if len(re.findall(r'going to',' '.join(i.split(' ')[:100]))) >=2 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('going to','gonna') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# want to to wanna

texta = [i for i in data['text'] if len(re.findall(r'want to',' '.join(i.split(' ')[:100]))) >=2 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('want to','wonna') for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
# -ing to -in'

texta = [i for i in data['text'] if len(re.findall(r'ing ',' '.join(i.split(' ')[:100]))) >=3 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('ing ',"in' ") for i in texta]

out = get_analogy(texta,textb)
print(out)

In [None]:
texta = [i for i in data['text'] if len(re.findall(r'good',' '.join(i.split(' ')[:100]))) >=3 ]
texta = [i for i in shuffle(texta)[:1000]]
textb = [i.replace('good','goooooood') for i in texta]

out = get_analogy(texta,textb)
print(out)


In [4]:

def get_quant_analogy(original, less, more, mask=False):
    
    anchor = np.stack([extract_emb(i) for i in tqdm(original)])
    if mask == True:
        emb = [extract_quant_emb(i,mask=True,prob=0.05) for i in tqdm(original)]
        emba = np.stack([i[0] for i in emb])
        embb = np.stack([i[1] for i in emb])
        
    else:
        emba = np.stack([extract_emb(i) for i in tqdm(less)])
        embb = np.stack([extract_emb(i) for i in tqdm(more)])
    
    anchor = torch.tensor(anchor).to(device)
    emba = torch.tensor(emba).to(device)
    embb = torch.tensor(embb).to(device)

    # Get pairwise similarity
    diff_less = anchor - emba
    diff_more = anchor - embb
    
    distance = F.normalize(diff_less,dim=-1) * F.normalize(diff_more,dim=-1)
    distance = torch.sum(distance,dim=-1).cpu().detach().numpy()

    less_norm = torch.norm(diff_less,dim=-1)
    more_norm = torch.norm(diff_more,dim=-1)
    mag_diff = (more_norm - less_norm).cpu().detach().numpy()

    return distance, mag_diff

In [5]:
def extract_quant_emb(text,mask=False,prob=0.05):
    tokenized = tokenizer.encode_plus(text,add_special_tokens=True, max_length=102,truncation=True,return_tensors="pt")
    if mask == True:
        tokenized['input_ids'] = masking(tokenized['input_ids'],mlm_prob=0.05)
        hidden = model(tokenized['input_ids'].to(device),tokenized['attention_mask'].to(device))
        hidden = F.normalize(hidden,dim=-1)
        hidden_less = hidden.cpu().detach().numpy()
        
        tokenized['input_ids'] = masking(tokenized['input_ids'],mlm_prob=0.05)
        hidden = model(tokenized['input_ids'].to(device),tokenized['attention_mask'].to(device))
        hidden = F.normalize(hidden,dim=-1)
        hidden_more = hidden.cpu().detach().numpy()
        
        return hidden_less, hidden_more
        
    else:
        hidden = model(tokenized['input_ids'].to(device),tokenized['attention_mask'].to(device))
        hidden = F.normalize(hidden,dim=-1)
        hidden = hidden.cpu().detach().numpy()
        return hidden

In [None]:
# random baseline
original = [i for i in data.sample(2000)['text']]
less = [i for i in data.sample(2000)['text']]
more = [i for i in data.sample(2000)['text']]

bsim, bmag = get_quant_analogy(original, original, original,mask=True)
print(np.mean(bsim))
print(np.mean(bmag))


In [None]:
# I to null

samples = [i for i in data['text'] if len(re.findall(r'I ',' '.join(i.split(' ')[:100])))>=10]
samples = shuffle(samples[:2000])


original = [i.lower() for i in samples]
less = [re.sub(r'i ',' ',i.lower(),count=4) for i in original]
more = [re.sub(r'i ',' ',i.lower()) for i in original]


sim, mag = get_quant_analogy(original, less, more)
print(np.mean(sim))
print(np.mean(mag))


In [None]:
samples = [i for i in data['text'] if len(re.findall(r'and ',' '.join(i.split(' ')[:100])))>=6]
samples = shuffle(samples[:2000])


original = [i for i in samples]
less = [re.sub(r'and ','& ',i,count=3) for i in original]
more = [re.sub(r'and ','& ',i) for i in original]


sim, mag = get_quant_analogy(original, less, more)

print(np.mean(sim))
print(np.mean(mag))


In [None]:
samples = [i for i in data['text'] if len(re.findall(r'. ',' '.join(i.split(' ')[:100])))>=6]
samples = shuffle(samples[:2000])


original = [i for i in samples]
less = [re.sub(r'\. ','!!!! ',i,count=3) for i in original]
more = [re.sub(r'\. ','!!!! ',i) for i in original]


sim, mag = get_quant_analogy(original, less, more)

print(np.mean(sim))
print(np.mean(mag))


In [None]:
# -ing to -in'
samples = [i for i in data['text'] if len(re.findall(r'ing ',' '.join(i.split(' ')[:100])))>=6]
samples = shuffle(samples[:2000])


original = [i for i in samples]
less = [re.sub(r'ing ',"in' ",i,count=3) for i in original]
more = [re.sub(r'ing ',"in' ",i) for i in original]


sim, mag = get_quant_analogy(original, less, more)

print(np.mean(sim))
print(np.mean(mag))
