In [35]:
import sys
sys.path.append('..')
from tqdm import tqdm

import pandas as pd
import numpy as np

from adat.utils import calculate_normalized_wer
from adat.masker import get_default_masker

In [60]:
train = pd.read_csv('../data/ag_news_mini/train.csv')
test = pd.read_csv('../data/ag_news_mini/test.csv')

In [61]:
data = pd.concat([train, test])

In [62]:
masker = get_default_masker()

In [63]:
data.columns

Index(['sequences', 'labels'], dtype='object')

In [64]:
transactions = data.sequences.values

In [65]:
close_to_zero_examples = []

num_close_to_zero = 20000
close_to_zero_indexes = np.random.randint(0, len(transactions), size=(num_close_to_zero, 2))

for id1, id2 in tqdm(close_to_zero_indexes):
    tr1 = transactions[id1]
    tr2 = transactions[id2]
    wer_sim = 1 - calculate_normalized_wer(tr1, tr2)
    close_to_zero_examples.append((tr1, tr2, wer_sim))

100%|██████████| 20000/20000 [00:00<00:00, 84722.72it/s]


In [66]:
some_examples = []
num_some_examples = 100000
some_examples_indexes = np.random.randint(0, len(transactions), size=num_some_examples)

for idx in tqdm(some_examples_indexes):
    tr1 = transactions[idx]
    tr2, applied = masker.mask(tr1)
    if applied:
        wer_sim = 1 - calculate_normalized_wer(tr1, tr2)
        some_examples.append((tr1, tr2, wer_sim))

100%|██████████| 100000/100000 [01:09<00:00, 1431.92it/s]


In [67]:
len(some_examples)

47617

In [68]:
len(close_to_zero_examples)

20000

In [69]:
examples = []
examples.extend(close_to_zero_examples)
examples.extend(some_examples)

In [70]:
examples = pd.DataFrame(examples, columns=['seq_a', 'seq_b', 'similarity'])

In [71]:
examples.head()

Unnamed: 0,seq_a,seq_b,similarity
0,space station food supply lower than thought,greek weightlifter stripped of olympic medal ...,0.0
1,bombs hit us british targets in turkish citie...,four confirmed dead in peru police shootout,0.111111
2,before the bell merck rises pct shares traded,aussie misses out on miss world,0.0
3,schroeder to meet european commission chief,dollar continues slide vs euro reuters,0.0
4,nikkei at week closing low,twins keep us alive in davis cup finals,0.0


In [72]:
np.median(examples.similarity)

0.7142857142857143

In [73]:
examples.similarity.mean()

0.5390905456218099

In [74]:
examples.similarity.max()

1.0

In [75]:
examples.similarity.min()

0.0

In [76]:
examples = examples.sample(frac=1).reset_index(drop=True)

In [77]:
examples.head()

Unnamed: 0,seq_a,seq_b,similarity
0,livewire travel search sites look for bargain...,livewire travel search sites look for bargains...,0.888889
1,typhoon meari passing through japan s norther...,typhoon meari passing through check s northern...,0.888889
2,google groups get going,goals google groups get going,0.8
3,pfizer reports positive trial results,cocacola pfizer reports positive trial results,0.833333
4,local search missing pieces falling into place,local local search missing pieces falling into...,0.875


In [78]:
examples.shape

(67617, 3)

In [55]:
from sklearn.model_selection import train_test_split

In [56]:
tr, te = train_test_split(examples, test_size=0.07, random_state=23)

In [57]:
tr.shape, te.shape

((67685, 3), (5095, 3))

In [None]:
tr.to_csv('../data/ag_news_mini/levenshtein/train.csv', index=False)
te.to_csv('../data/ag_news_mini/levenshtein/test.csv', index=False)

In [29]:
# tr.to_csv('../data/deep_lev/train.csv', index=False)
# te.to_csv('../data/deep_lev/test.csv', index=False)

In [30]:
te

Unnamed: 0,seq_a,seq_b,similarity
119458,bin2_trans1 bin2_trans1 bin1_trans1 bin3_trans...,bin2_trans1 bin2_trans1 bin3_trans3 bin3_trans3,0.800000
181328,bin1_trans20 bin3_trans24 bin3_trans1 bin3_tra...,bin1_trans20 bin3_trans24 bin3_trans1 bin3_tra...,0.857143
229909,bin3_trans3 bin3_trans1 bin3_trans3 bin1_trans...,bin2_trans42 bin0_trans31 bin1_trans25 bin3_tr...,0.000000
176646,bin4_trans1 bin1_trans1 bin4_trans34 bin3_tran...,bin3_trans1 bin1_trans1 bin4_trans34 bin4_tran...,0.666667
676003,bin1_trans4 bin4_trans0 bin3_trans15 bin4_tran...,bin3_trans3 bin4_trans1 bin1_trans11 bin4_tran...,0.111111
...,...,...,...
298405,bin1_trans3 bin3_trans1 bin0_trans18 bin2_tran...,bin2_trans1 bin3_trans1 bin0_trans18 bin1_tran...,0.750000
449464,bin3_trans1 bin2_trans15 bin1_trans1 bin1_tran...,bin3_trans1 bin2_trans15 bin1_trans1 bin1_tran...,0.857143
3076,bin1_trans1 bin1_trans1 bin2_trans11 bin2_tran...,bin1_trans1 bin1_trans1 bin2_trans11 bin2_tran...,0.875000
484385,bin3_trans1 bin2_trans1 bin4_trans1 bin2_trans...,bin3_trans1 bin2_trans1 bin2_trans18 bin4_tran...,0.857143
