In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from relation_modeling_utils import load_data

train_df = load_data("data/atomic2020_data-feb2021/train.tsv", multi_label=True)
val_df = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True)
test_df = load_data("data/atomic2020_data-feb2021/test.tsv", multi_label=True)

In [70]:
len(train_df), len(val_df), len(test_df)

(36940, 2962, 6569)

## Original data lexical overlap

In [3]:
from relation_modeling_utils import create_vocab
train_vocab, val_vocab, test_vocab = create_vocab(train_df), create_vocab(val_df), create_vocab(test_df)

100%|██████████| 140049/140049 [00:00<00:00, 903496.24it/s]
100%|██████████| 14524/14524 [00:00<00:00, 940616.26it/s]
100%|██████████| 27270/27270 [00:00<00:00, 892746.41it/s]


### Lexical overlap with stopwords

In [4]:
train_test_overlap = set(train_vocab).intersection(set(test_vocab))
len(train_test_overlap) / len(train_vocab), len(train_test_overlap) / len(test_vocab)

(0.2782412405535381, 0.8130197877831947)

In [5]:
train_val_overlap = set(train_vocab).intersection(set(val_vocab))
len(train_val_overlap) / len(train_vocab), len(train_val_overlap) / len(val_vocab)

(0.15889684954362548, 0.8837336244541485)

### Lexical overlap without stopwords

In [6]:
from spacy.lang.en.stop_words import STOP_WORDS

train_vocab_nostp = {word: freq for word, freq in train_vocab.items() if word not in STOP_WORDS}
val_vocab_nostp = {word: freq for word, freq in val_vocab.items() if word not in STOP_WORDS}
test_vocab_nostp = {word: freq for word, freq in test_vocab.items() if word not in STOP_WORDS}

In [7]:
train_test_overlap_nostp = set(train_vocab_nostp).intersection(set(test_vocab_nostp))
len(train_test_overlap_nostp) / len(train_vocab_nostp), len(train_test_overlap_nostp) / len(test_vocab_nostp)

(0.26732176877569436, 0.8037383177570093)

In [8]:
train_val_overlap_nostp = set(train_vocab_nostp).intersection(set(val_vocab_nostp))
len(train_val_overlap_nostp) / len(train_vocab_nostp), len(train_val_overlap_nostp) / len(val_vocab_nostp)

(0.14739797453123432, 0.8739595719381689)

## Create new ATOMIC datasets

In [9]:
import pandas as pd

atomic_df = pd.concat([train_df, val_df, test_df])

In [10]:
len(atomic_df)

46471

In [11]:
atomic_df.duplicated(subset=["text"]).sum()

2513

In [12]:
atomic_df = atomic_df.drop_duplicates(subset=["text"])

In [13]:
len(atomic_df)

43958

In [14]:
import spacy
from collections import defaultdict
from tqdm import tqdm
from relation_modeling_utils import IGNORE_WORDS, create_vocab
from spacy.lang.en.stop_words import STOP_WORDS

def make_docs(data, vocab, exclude_stopwords=False):
    nlp = spacy.load("en_core_web_sm", exclude=["ner"])
    docs = []

    for row in tqdm(data.itertuples(), total=len(data)):
        doc = nlp(row.text)
        words = set()

        for token in doc:
            if token.text not in IGNORE_WORDS and (not exclude_stopwords or token.text not in STOP_WORDS):
                words.add(token.lemma_)
        
        doc.user_data['words'] = words
        doc.user_data['label'] = row.label
        docs.append(doc)
    
    for doc in docs:
        freqs = 0
        for word in doc.user_data['words']:
            freqs += vocab.get(word, 0) - 1
        
        doc.user_data['relative_freq'] = freqs
    
    return docs

In [15]:
atomic_vocab = create_vocab(atomic_df)

100%|██████████| 169425/169425 [00:00<00:00, 911559.59it/s]


In [16]:
atomic_docs = make_docs(atomic_df, atomic_vocab, exclude_stopwords=True)

100%|██████████| 43958/43958 [02:45<00:00, 266.21it/s]


In [17]:
sorted(atomic_vocab.items(), key=lambda i: i[1])[-5:]

[('in', 1555), ('a', 3031), ('to', 3141), ('the', 4831), ("'s", 6519)]

In [18]:
class1_docs = [doc for doc in atomic_docs if doc.user_data['label'][0] == 1]
class2_docs = [doc for doc in atomic_docs if doc.user_data['label'][1] == 1]
class3_docs = [doc for doc in atomic_docs if doc.user_data['label'][2] == 1]

In [84]:
FREQUENCY_THRESHOLD = 1
class1_freq1_docs = [doc for doc in class1_docs if doc.user_data['relative_freq'] < 1][:500]
class2_freq1_docs = [doc for doc in class2_docs if doc.user_data['relative_freq'] < FREQUENCY_THRESHOLD][:1000]
class3_freq1_docs = [doc for doc in class3_docs if doc.user_data['relative_freq'] < FREQUENCY_THRESHOLD][:1000]

In [85]:
len(class1_freq1_docs), len(class2_freq1_docs), len(class3_freq1_docs)

(500, 213, 241)

In [86]:
test_samples = [doc.text for doc in class1_freq1_docs+class2_freq1_docs+class3_freq1_docs]

In [87]:
len(test_samples)

954

In [88]:
new_train_data, new_test_data = [], []

for row in atomic_df.itertuples():
    if row.text in test_samples:
        new_test_data.append((row.text, row.label))
    else:
        new_train_data.append((row.text, row.label))

In [89]:
from sklearn.model_selection import train_test_split

new_train_df, new_val_df = train_test_split(pd.DataFrame(new_train_data, columns=["text", "label"]), test_size=0.1, random_state=42)
new_test_df = pd.DataFrame(new_test_data, columns=["text", "label"])

In [90]:
len(new_train_df), len(new_val_df), len(new_test_df)

(38822, 4314, 822)

In [91]:
from relation_modeling_utils import create_vocab
new_train_vocab, new_val_vocab, new_test_vocab = create_vocab(new_train_df), create_vocab(new_val_df), create_vocab(new_test_df)

100%|██████████| 151030/151030 [00:00<00:00, 665154.35it/s]
100%|██████████| 16802/16802 [00:00<00:00, 756469.47it/s]
100%|██████████| 1593/1593 [00:00<00:00, 655244.31it/s]


### New lexical overlap with stopwords

In [92]:
new_train_val_overlap = set(new_train_vocab).intersection(set(new_val_vocab))
len(new_train_val_overlap) / len(new_train_vocab), len(new_train_val_overlap) / len(new_val_vocab)

(0.24600900532132625, 0.8215994531784006)

In [93]:
new_train_test_overlap = set(new_train_vocab).intersection(set(new_test_vocab))
len(new_train_test_overlap) / len(new_train_vocab), len(new_train_test_overlap) / len(new_test_vocab)

(0.010233319688907082, 0.12004801920768307)

### New lexical overlap without stopwords

In [94]:
new_train_vocab_nostp = {word: freq for word, freq in new_train_vocab.items() if word not in STOP_WORDS}
new_val_vocab_nostp = {word: freq for word, freq in new_val_vocab.items() if word not in STOP_WORDS}
new_test_vocab_nostp = {word: freq for word, freq in new_test_vocab.items() if word not in STOP_WORDS}

In [95]:
new_train_val_overlap_nostp = set(new_train_vocab_nostp).intersection(set(new_val_vocab_nostp))
len(new_train_val_overlap_nostp) / len(new_train_vocab_nostp), len(new_train_val_overlap_nostp) / len(new_val_vocab_nostp)

(0.23518944944525852, 0.8114842903575298)

In [96]:
new_train_test_overlap_nostp = set(new_train_vocab_nostp).intersection(set(new_test_vocab_nostp))
len(new_train_test_overlap_nostp) / len(new_train_vocab_nostp), len(new_train_test_overlap_nostp) / len(new_test_vocab_nostp)

(0.0015700230270043961, 0.020053475935828877)

### Class distributions

In [97]:
from relation_modeling_utils import explode_labels
new_train_df, new_val_df, new_test_df = explode_labels(new_train_df), explode_labels(new_val_df), explode_labels(new_test_df)

In [98]:
new_val_df.label_0.value_counts(), new_val_df.label_1.value_counts(), new_val_df.label_2.value_counts()

(0    2729
 1    1585
 Name: label_0, dtype: int64,
 1    2261
 0    2053
 Name: label_1, dtype: int64,
 1    2557
 0    1757
 Name: label_2, dtype: int64)

In [99]:
new_test_df.label_0.value_counts(), new_test_df.label_1.value_counts(), new_test_df.label_2.value_counts()

(1    526
 0    296
 Name: label_0, dtype: int64,
 0    609
 1    213
 Name: label_1, dtype: int64,
 0    581
 1    241
 Name: label_2, dtype: int64)

In [68]:
from relation_modeling_utils import get_class_dist_report

get_class_dist_report(new_test_df)

{('class_0', 0): 0.48237476808905383,
 ('class_0', 'class_0', 0, 0): 0.48237476808905383,
 ('class_0', 'class_0', 0, 1): 0.0,
 ('class_0', 1): 0.5176252319109462,
 ('class_0', 'class_0', 1, 0): 0.0,
 ('class_0', 'class_0', 1, 1): 0.5176252319109462,
 ('class_0', 'class_1', 0, 0): 0.19851576994434136,
 ('class_0', 'class_1', 0, 1): 0.28385899814471244,
 ('class_0', 'class_1', 1, 0): 0.45732838589981445,
 ('class_0', 'class_1', 1, 1): 0.06029684601113173,
 ('class_0', 'class_2', 0, 0): 0.07792207792207792,
 ('class_0', 'class_2', 0, 1): 0.4044526901669759,
 ('class_0', 'class_2', 1, 0): 0.5111317254174397,
 ('class_0', 'class_2', 1, 1): 0.006493506493506494,
 ('class_1', 0): 0.6558441558441559,
 ('class_1', 'class_0', 0, 0): 0.19851576994434136,
 ('class_1', 'class_0', 0, 1): 0.45732838589981445,
 ('class_1', 1): 0.34415584415584416,
 ('class_1', 'class_0', 1, 0): 0.28385899814471244,
 ('class_1', 'class_0', 1, 1): 0.06029684601113173,
 ('class_1', 'class_1', 0, 0): 0.6558441558441559,
 

In [69]:
new_train_df.to_csv("data/atomic_ood/f1/train_f1.csv")
new_val_df.to_csv("data/atomic_ood/f1/val_f1.csv")
new_test_df.to_csv("data/atomic_ood/f1/test_f1.csv")