In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from relation_modeling_utils import load_data

train_df = load_data("data/atomic2020_data-feb2021/train.tsv", multi_label=True)
val_df = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True)
test_df = load_data("data/atomic2020_data-feb2021/test.tsv", multi_label=True)

In [3]:
import pandas as pd

atomic_df = pd.concat([train_df, val_df, test_df])

In [4]:
len(atomic_df)

46471

In [5]:
atomic_df.duplicated(subset=["text"]).sum()

2513

In [6]:
atomic_df = atomic_df.drop_duplicates(subset=["text"])

In [7]:
len(atomic_df)

43958

In [20]:
import spacy
from collections import defaultdict
from tqdm import tqdm
from spacy.lang.en.stop_words import STOP_WORDS

STOPWORDS = STOP_WORDS | set(["PersonX", "PersonY", "PersonZ", "_", "'", "-"])

def build_doclist(data):
    nlp = spacy.load("en_core_web_sm", exclude=["ner"])
    doclist = []
    vocab = defaultdict(int)

    for row in tqdm(data.itertuples(), total=len(data)):
        doc = nlp(row.text)
        words = set()

        for token in doc:
            if token.text not in STOPWORDS:
                vocab[token.lemma_] += 1
                words.add(token.lemma_)
        
        doc.user_data['words'] = words
        doc.user_data['label'] = row.label
        doclist.append(doc)
    
    for doc in doclist:
        freqs = 0
        for word in doc.user_data['words']:
            freqs += vocab.get(word, 0)
        
        doc.user_data['freqs'] = freqs
        doc.user_data['freq_diff'] = abs(freqs - len(doc.user_data['words']))
    
    return doclist

In [21]:
atomic_doclist = build_doclist(atomic_df)

100%|██████████| 43958/43958 [02:22<00:00, 308.47it/s]


In [22]:
class1_doclist = [doc for doc in atomic_doclist if doc.user_data['label'][0] == 1]
class2_doclist = [doc for doc in atomic_doclist if doc.user_data['label'][1] == 1]
class3_doclist = [doc for doc in atomic_doclist if doc.user_data['label'][2] == 1]

In [43]:
class1_freq1_docs = [doc for doc in class1_doclist if doc.user_data['freq_diff'] < 1][:1000]
class2_freq1_docs = [doc for doc in class2_doclist if doc.user_data['freq_diff'] < 5][:1000]
class3_freq1_docs = [doc for doc in class3_doclist if doc.user_data['freq_diff'] < 5][:1000]

In [44]:
len(class1_freq1_docs), len(class2_freq1_docs), len(class3_freq1_docs)

(1000, 525, 652)

In [45]:
test_samples = [doc.text for doc in class1_freq1_docs+class2_freq1_docs+class3_freq1_docs]

In [46]:
len(test_samples)

2177

In [47]:
train_data, test_data = [], []

for row in atomic_df.itertuples():
    if row.text in test_samples:
        test_data.append((row.text, row.label))
    else:
        train_data.append((row.text, row.label))

In [48]:
new_train_df = pd.DataFrame(train_data, columns=["text", "label"])
new_test_df = pd.DataFrame(test_data, columns=["text", "label"])

In [49]:
from relation_modeling_utils import create_vocab
train_vocab, test_vocab = create_vocab(new_train_df), create_vocab(new_test_df)

100%|██████████| 165332/165332 [00:00<00:00, 862240.57it/s]
100%|██████████| 4093/4093 [00:00<00:00, 735782.88it/s]


In [50]:
len(train_vocab.intersection(test_vocab)) / len(train_vocab), len(train_vocab.intersection(test_vocab)) / len(test_vocab)

(0.03989473684210526, 0.22060535506402795)

In [51]:
from relation_modeling_utils import explode_labels
new_train_df, new_test_df = explode_labels(new_train_df), explode_labels(new_test_df)

In [52]:
new_test_df.label_0.value_counts(), new_test_df.label_1.value_counts(), new_test_df.label_2.value_counts()

(1    1068
 0     743
 Name: label_0, dtype: int64,
 0    1286
 1     525
 Name: label_1, dtype: int64,
 0    1159
 1     652
 Name: label_2, dtype: int64)

In [53]:
from relation_modeling_utils import get_class_dist_report

get_class_dist_report(new_test_df)

{('class_0', 0): 0.4102705687465489,
 ('class_0', 'class_0', 0, 0): 0.4102705687465489,
 ('class_0', 'class_0', 0, 1): 0.0,
 ('class_0', 1): 0.5897294312534511,
 ('class_0', 'class_0', 1, 0): 0.0,
 ('class_0', 'class_0', 1, 1): 0.5897294312534511,
 ('class_0', 'class_1', 0, 0): 0.16344561016013254,
 ('class_0', 'class_1', 0, 1): 0.24682495858641634,
 ('class_0', 'class_1', 1, 0): 0.5466593042517945,
 ('class_0', 'class_1', 1, 1): 0.04307012700165654,
 ('class_0', 'class_2', 0, 0): 0.05632247377139702,
 ('class_0', 'class_2', 0, 1): 0.35394809497515184,
 ('class_0', 'class_2', 1, 0): 0.5836554389839868,
 ('class_0', 'class_2', 1, 1): 0.006073992269464384,
 ('class_1', 0): 0.7101049144119271,
 ('class_1', 'class_0', 0, 0): 0.16344561016013254,
 ('class_1', 'class_0', 0, 1): 0.5466593042517945,
 ('class_1', 1): 0.28989508558807286,
 ('class_1', 'class_0', 1, 0): 0.24682495858641634,
 ('class_1', 'class_0', 1, 1): 0.04307012700165654,
 ('class_1', 'class_1', 0, 0): 0.7101049144119271,
 ('c