In [3]:
import sys
sys.path.append('..')
from tqdm import tqdm

import pandas as pd
import numpy as np

from adat.utils import calculate_normalized_wer
from adat.masker import get_default_masker

In [4]:
train = pd.read_csv('../data/ag_news_mini/train.csv')
test = pd.read_csv('../data/ag_news_mini/test.csv')

In [5]:
data = pd.concat([train, test])

In [6]:
masker = get_default_masker()

In [7]:
data.columns

Index(['sequences', 'labels'], dtype='object')

In [8]:
transactions = data.sequences.values

In [9]:
close_to_zero_examples = []

num_close_to_zero = 20000
close_to_zero_indexes = np.random.randint(0, len(transactions), size=(num_close_to_zero, 2))

for id1, id2 in tqdm(close_to_zero_indexes):
    tr1 = transactions[id1]
    tr2 = transactions[id2]
    wer_sim = 1 - calculate_normalized_wer(tr1, tr2)
    close_to_zero_examples.append((tr1, tr2, wer_sim))

100%|██████████| 20000/20000 [00:00<00:00, 63599.74it/s]


In [10]:
some_examples = []
num_some_examples = 100000
some_examples_indexes = np.random.randint(0, len(transactions), size=num_some_examples)

for idx in tqdm(some_examples_indexes):
    tr1 = transactions[idx]
    tr2, applied = masker.mask(tr1)
    if applied:
        wer_sim = 1 - calculate_normalized_wer(tr1, tr2)
        some_examples.append((tr1, tr2, wer_sim))

100%|██████████| 100000/100000 [01:11<00:00, 1398.25it/s]


In [9]:
len(some_examples)

52903

In [10]:
len(close_to_zero_examples)

20000

In [11]:
examples = []
examples.extend(close_to_zero_examples)
examples.extend(some_examples)

In [12]:
examples = pd.DataFrame(examples, columns=['seq_a', 'seq_b', 'similarity'])

In [13]:
examples.head()

Unnamed: 0,seq_a,seq_b,similarity
0,bin3_trans3 bin2_trans12 bin2_trans82 bin2_tra...,bin4_trans34 bin3_trans64 bin3_trans4 bin0_tra...,0.0
1,bin0_trans34 bin0_trans36 bin4_trans1 bin1_tra...,bin1_trans3 bin2_trans1 bin2_trans4 bin3_trans...,0.066667
2,bin4_trans1 bin3_trans1 bin3_trans3 bin4_trans...,bin4_trans3 bin2_trans1 bin3_trans11 bin4_tran...,0.0
3,bin0_trans4 bin3_trans1 bin0_trans20 bin4_tran...,bin2_trans9 bin4_trans2 bin4_trans9 bin4_trans...,0.142857
4,bin1_trans15 bin3_trans1 bin0_trans61 bin0_tra...,bin4_trans1 bin0_trans1 bin3_trans50 bin1_tran...,0.0


In [14]:
np.median(examples.similarity)

0.7692307692307692

In [15]:
examples.similarity.mean()

0.5967572830005643

In [16]:
examples.similarity.max()

1.0

In [17]:
examples.similarity.min()

0.0

In [18]:
examples = examples.sample(frac=1).reset_index(drop=True)

In [19]:
examples.head()

Unnamed: 0,seq_a,seq_b,similarity
0,bin1_trans1 bin0_trans1 bin0_trans1 bin0_trans...,bin1_trans1 bin1_trans1 bin0_trans1 bin0_trans...,0.625
1,bin3_trans1 bin4_trans9 bin4_trans1 bin4_trans...,bin3_trans1 bin4_trans9 bin4_trans9 bin4_trans...,0.785714
2,bin0_trans3 bin0_trans1 bin0_trans9 bin3_trans...,bin2_trans3 bin1_trans29 bin1_trans1 bin2_tran...,0.083333
3,bin3_trans4 bin3_trans3 bin3_trans36 bin4_tran...,bin3_trans4 bin3_trans3 bin3_trans36 bin4_tran...,0.888889
4,bin0_trans20 bin0_trans3 bin1_trans18 bin0_tra...,bin0_trans47 bin0_trans20 bin0_trans3 bin1_tra...,0.8125


In [20]:
examples.shape

(72903, 3)

In [21]:
from sklearn.model_selection import train_test_split

In [22]:
tr, te = train_test_split(examples, test_size=0.07, random_state=23)

In [23]:
tr.shape, te.shape

((67799, 3), (5104, 3))

In [24]:
tr.to_csv('../data/ai_academy_data_mini/levenshtein/train.csv', index=False)
te.to_csv('../data/ai_academy_data_mini/levenshtein/test.csv', index=False)

In [29]:
# tr.to_csv('../data/deep_lev/train.csv', index=False)
# te.to_csv('../data/deep_lev/test.csv', index=False)

In [30]:
te

Unnamed: 0,seq_a,seq_b,similarity
119458,bin2_trans1 bin2_trans1 bin1_trans1 bin3_trans...,bin2_trans1 bin2_trans1 bin3_trans3 bin3_trans3,0.800000
181328,bin1_trans20 bin3_trans24 bin3_trans1 bin3_tra...,bin1_trans20 bin3_trans24 bin3_trans1 bin3_tra...,0.857143
229909,bin3_trans3 bin3_trans1 bin3_trans3 bin1_trans...,bin2_trans42 bin0_trans31 bin1_trans25 bin3_tr...,0.000000
176646,bin4_trans1 bin1_trans1 bin4_trans34 bin3_tran...,bin3_trans1 bin1_trans1 bin4_trans34 bin4_tran...,0.666667
676003,bin1_trans4 bin4_trans0 bin3_trans15 bin4_tran...,bin3_trans3 bin4_trans1 bin1_trans11 bin4_tran...,0.111111
...,...,...,...
298405,bin1_trans3 bin3_trans1 bin0_trans18 bin2_tran...,bin2_trans1 bin3_trans1 bin0_trans18 bin1_tran...,0.750000
449464,bin3_trans1 bin2_trans15 bin1_trans1 bin1_tran...,bin3_trans1 bin2_trans15 bin1_trans1 bin1_tran...,0.857143
3076,bin1_trans1 bin1_trans1 bin2_trans11 bin2_tran...,bin1_trans1 bin1_trans1 bin2_trans11 bin2_tran...,0.875000
484385,bin3_trans1 bin2_trans1 bin4_trans1 bin2_trans...,bin3_trans1 bin2_trans1 bin2_trans18 bin4_tran...,0.857143
