In [18]:
from kogito.core.relation import PHYSICAL_RELATIONS, SOCIAL_RELATIONS, EVENT_RELATIONS

def load_data(datapath, event_relations=None, social_relations=None):
    """
    Load unique (head, label) pairs from an ATOMIC-style tab-separated file.

    Each row is expected to have exactly three tab-separated fields:
    head, relation, tail. The relation is mapped to an integer class:
    1 if it is in ``event_relations``, 2 if it is in ``social_relations``,
    and 0 (physical) otherwise. Duplicate (head, label) pairs are dropped,
    keeping first-occurrence order. Rows that do not have exactly three
    fields are skipped.

    datapath: path to the .tsv file (read as UTF-8)
    event_relations: container of relations labelled 1;
        defaults to kogito's EVENT_RELATIONS
    social_relations: container of relations labelled 2;
        defaults to kogito's SOCIAL_RELATIONS

    returns: list(tuple(head, label))
    """
    if event_relations is None:
        event_relations = EVENT_RELATIONS
    if social_relations is None:
        social_relations = SOCIAL_RELATIONS

    data = []
    head_label_set = set()

    # Explicit encoding so the result does not depend on the platform default.
    with open(datapath, encoding="utf-8") as f:
        for line in f:
            fields = line.split('\t')
            # Skip malformed rows that are not exactly head/relation/tail;
            # any other error is allowed to surface instead of being swallowed.
            if len(fields) != 3:
                continue
            head, relation, _ = fields

            # NOTE: anything that is neither an event nor a social relation,
            # including relations missing from both containers, is labelled
            # 0 (physical); PHYSICAL_RELATIONS is not consulted.
            if relation in event_relations:
                label = 1
            elif relation in social_relations:
                label = 2
            else:
                label = 0

            if (head, label) not in head_label_set:
                data.append((head, label))
                head_label_set.add((head, label))

    return data

In [23]:
import os

# Root of the ATOMIC 2020 data release. Set the ATOMIC2020_DIR environment
# variable to point elsewhere instead of editing a machine-specific path.
ATOMIC_DATA_DIR = os.environ.get(
    "ATOMIC2020_DIR",
    "/Users/mismayil/Desktop/EPFL/nlplab/comet-atomic-2020/data/atomic2020_data-feb2021",
)

train_data = load_data(os.path.join(ATOMIC_DATA_DIR, "train.tsv"))
dev_data = load_data(os.path.join(ATOMIC_DATA_DIR, "dev.tsv"))

In [24]:
# Number of unique (head, label) examples in each split: (train, dev).
tuple(len(split) for split in (train_data, dev_data))

(53891, 4823)

In [34]:
import spacy
from tqdm import tqdm
# spaCy pipeline that make_docs runs over the head texts.
nlp = spacy.load("en_core_web_lg")

# Category names indexed by the integer labels produced by load_data:
# 0 -> physical, 1 -> event, 2 -> social. Order matters.
label_map = "physical event social".split()

def make_docs(data):
    """
    Turn (text, label) pairs into spaCy Docs carrying one-hot category labels.

    Each text is processed with the module-level ``nlp`` pipeline. Every
    resulting Doc gets a ``cats`` entry for each name in ``label_map``:
    the entry for ``label_map[label]`` is 1 and all others are 0.

    data: list(tuple(text, label)), where label is an index into label_map

    returns: list(spacy.tokens.Doc)
    """
    docs = []
    processed = nlp.pipe(data, as_tuples=True)

    for doc, label in tqdm(processed, total=len(data)):
        positive = label_map[label]
        # One-hot encoding over all categories, in label_map order.
        doc.cats.update({name: int(name == positive) for name in label_map})
        docs.append(doc)

    return docs

In [35]:
from spacy.tokens import DocBin
import os

# Destination for the serialized splits consumed by spaCy training.
OUTPUT_DIR = "./data"

train_docs = make_docs(train_data)
dev_docs = make_docs(dev_data)

# DocBin.to_disk opens the target file directly, so make sure the
# directory exists (it may not on a fresh checkout).
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Save each split (text plus doc.cats labels) to spaCy's binary format.
train_doc_bin = DocBin(docs=train_docs)
train_doc_bin.to_disk(os.path.join(OUTPUT_DIR, "train.spacy"))

dev_doc_bin = DocBin(docs=dev_docs)
dev_doc_bin.to_disk(os.path.join(OUTPUT_DIR, "dev.spacy"))

100%|██████████| 53891/53891 [00:45<00:00, 1176.59it/s]
100%|██████████| 4823/4823 [00:04<00:00, 1183.67it/s]
