In [105]:
import json
import os
import random
from itertools import combinations as comb

import pandas as pd
import torch
from transformers import AutoTokenizer

In [49]:
# Raw training corpus plus the train/test split files that are derived from it below.
DATAPATH, TRAINPATH, TESTPATH = (
    "../data/subtask3-coreference/" + fname
    for fname in ("en-train.json", "train.json", "test.json")
)

In [50]:
# Load the raw corpus: every line of DATAPATH is parsed as one JSON document.
with open(DATAPATH, "r") as raw_file:
    docs = list(map(json.loads, raw_file))

In [51]:
# Reproducible document-level train/test split of the raw corpus.
# Opening the split files with mode "w" truncates any previous version, which
# replaces the old `!rm` + append-mode combination (rm failed noisily when the
# files did not exist yet, and a skipped rm silently duplicated every record).
TEST_SIZE = 119  # number of held-out documents


def _write_jsonl(path, records):
    """Write each record as one JSON object per line, replacing any existing file."""
    with open(path, "w") as fh:
        for record in records:
            json.dump(record, fh)
            fh.write("\n")


random.seed(44)
test_idx = random.sample(range(len(docs)), TEST_SIZE)
test_set = set(test_idx)  # O(1) membership instead of scanning the list per index
train_idx = [i for i in range(len(docs)) if i not in test_set]

_write_jsonl(TRAINPATH, (docs[i] for i in train_idx))
_write_jsonl(TESTPATH, (docs[i] for i in test_idx))
    
# Sanity check: ratio of positive (1.0) to negative (0.0) within-document pairs per split.
# NOTE(review): parsedata/createpairs are defined in later cells (In [52]/[53]);
# on a fresh kernel this loop raises NameError -- move it below those definitions.
# Uses split-local names so the globals `data`/`sents` are no longer clobbered.
for split_path in [TRAINPATH, TESTPATH]:
    split_data, split_sents = parsedata(split_path)
    split_codes, _ = createpairs(split_data, split_sents)
    label_counts = pd.DataFrame(
        split_codes, columns=["index", "sent1", "sent2", "label"]
    )["label"].value_counts()
    print(label_counts[1.0] / label_counts[0.0])

2.155765340525961
2.234256926952141


In [52]:
def parsedata(path):
    """Load a JSON-lines coreference file into per-document metadata and a sentence lookup.

    Each non-blank line of ``path`` must be a JSON object with the keys
    ``id``, ``event_clusters``, ``sentence_no`` and ``sentences``. Sentence
    numbers are made globally unique by prefixing them with the document id,
    giving keys of the form ``"<doc_id>_<sentence_no>"``.

    Args:
        path: path to the JSON-lines file (read as UTF-8).

    Returns:
        (data, sents) where
        ``data`` is ``{doc_id: {"events": [[sent_key, ...], ...], "sents": [sent_key, ...]}}``
        and ``sents`` is ``{sent_key: sentence_text}``.
    """
    data = {}
    sents = {}
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            if not line.strip():
                continue  # tolerate blank lines instead of crashing in json.loads
            doc = json.loads(line)
            doc_id = str(doc["id"])
            prefix = doc_id + "_"
            doc_events = [
                [prefix + str(sent_no) for sent_no in cluster]
                for cluster in doc["event_clusters"]
            ]
            doc_sentno = [prefix + str(sent_no) for sent_no in doc["sentence_no"]]
            # sentence_no and sentences are paired positionally.
            for sent_key, sent in zip(doc_sentno, doc["sentences"]):
                sents[sent_key] = sent
            data[doc_id] = {"events": doc_events, "sents": doc_sentno}
    return data, sents

In [53]:
random.seed(44)
def createpairs(data, sents, cross=False, ratio=None):
    """Build labelled sentence pairs for event-coreference classification.

    Every unordered pair of sentences *within* a document becomes a candidate.
    A pair is labelled 1.0 when both sentences appear together in one of the
    document's event clusters, otherwise 0.0. With ``cross=True``, extra
    cross-document negatives are sampled until positives make up roughly
    ``ratio`` of all pairs. The result is shuffled with the global ``random``
    state and re-indexed from 0.

    Args:
        data: ``{doc_id: {"events": [[sent_key, ...], ...], "sents": [sent_key, ...]}}``
            as returned by ``parsedata``.
        sents: ``{sent_key: sentence_text}`` as returned by ``parsedata``.
        cross: whether to add cross-document negative pairs.
        ratio: target fraction of positive pairs; required when ``cross`` is True.

    Returns:
        (codes, texts), row-aligned: ``codes`` rows are
        ``[index, sent_key1, sent_key2, label]`` and ``texts`` rows are
        ``[sentence_text1, sentence_text2, label]``.

    Raises:
        ValueError: if ``cross`` is True and ``ratio`` is missing or not positive.
        KeyError: if an event cluster references a sentence not listed for its document.
    """
    pairs = []
    positive = []
    for doc in data:
        # Keys are sorted as strings ("x_10" < "x_2"); candidate and positive
        # pairs use the same ordering, so the lookups below line up.
        pairs.extend(list(p) for p in comb(sorted(data[doc]["sents"]), 2))
        for cluster in data[doc]["events"]:
            positive.extend(list(p) for p in comb(sorted(cluster), 2))

    pair_idx = {}
    for idx, p in enumerate(pairs):
        pair_idx[tuple(p)] = idx  # tuple key: unambiguous, unlike '_'.join
        pairs[idx] = [idx] + p + [0.0]
    for p in positive:
        pairs[pair_idx[tuple(p)]][-1] = 1.0

    if cross:
        if ratio is None or ratio <= 0:
            raise ValueError("createpairs: cross=True requires a positive `ratio`")
        # len(positive) counts every within-cluster pair (a pair shared by two
        # clusters is counted twice), matching the original target size.
        total_num = int(len(positive) / ratio)
        needed = total_num - len(pairs)
        if needed > 0:
            all_sents = sorted({key for row in pairs for key in row[1:-1]})
            candidates = list(comb(all_sents, 2))
            # Over-sample 2x because drawn pairs that already exist are skipped;
            # cap at the population size so random.sample cannot raise.
            drawn = random.sample(candidates, min((needed + 1) * 2, len(candidates)))
            taken = {tuple(row[1:-1]) for row in pairs}
            taken.update(tuple(p) for p in positive)
            for pair in drawn:
                if len(pairs) >= total_num:
                    break
                if pair not in taken:
                    pairs.append([len(pairs)] + list(pair) + [0.0])

    shuffled = random.sample(pairs, len(pairs))
    codes = [[idx] + row[1:] for idx, row in enumerate(shuffled)]
    texts = [[sents[a], sents[b], label] for _, a, b, label in codes]
    return codes, texts

In [56]:
# Reseed so the pair shuffling (and train-side negative sampling) is reproducible.
# Train pairs are topped up with cross-document negatives towards a 25% positive
# share; the test split keeps only its within-document pairs.
random.seed(44)
train_codes, train_sents = createpairs(*parsedata(TRAINPATH), cross=True, ratio=0.25)
test_codes, test_sents = createpairs(*parsedata(TESTPATH), cross=False, ratio=0.25)

In [57]:
# Tabular view of the training pairs (sentence keys, not texts).
train_df = pd.DataFrame(data=train_codes, columns=["index", "sent1", "sent2", "label"])
train_df

Unnamed: 0,index,sent1,sent2,label
0,0,55474_1,55474_6,0.0
1,1,55165_1,55165_14,1.0
2,2,55349_3,55400_3,0.0
3,3,55057_1,55608_1,0.0
4,4,55225_11,55225_7,1.0
...,...,...,...,...
12783,12783,55044_17,55093_13,0.0
12784,12784,55169_11,55169_21,1.0
12785,12785,55328_2,55690_1,0.0
12786,12786,55034_5,55540_22,0.0


In [58]:
train_df["label"].value_counts()

0.0    9591
1.0    3197
Name: label, dtype: int64

In [59]:
# Tabular view of the test pairs (sentence keys, not texts).
test_df = pd.DataFrame(data=test_codes, columns=["index", "sent1", "sent2", "label"])
test_df

Unnamed: 0,index,sent1,sent2,label
0,0,55617_1,55617_5,1.0
1,1,55498_4,55498_8,1.0
2,2,55542_17,55542_9,0.0
3,3,55172_11,55172_2,1.0
4,4,55279_1,55279_9,0.0
...,...,...,...,...
1279,1279,55097_4,55097_7,0.0
1280,1280,55154_16,55154_2,1.0
1281,1281,55533_1,55533_14,1.0
1282,1282,55381_15,55381_16,0.0


In [60]:
test_df["label"].value_counts()

1.0    887
0.0    397
Name: label, dtype: int64

## Create Model Inputs (tokenize sentence pairs)

In [96]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="roberta-base")

In [103]:
def gimme_tensors(sents, tokenizer, max_length=256):
    """Tokenize labelled sentence pairs into padded tensors.

    Args:
        sents: sequence of rows whose last element is the label and whose two
            preceding elements are the sentence texts, e.g. ``[text1, text2, label]``
            as returned by ``createpairs``. Rows with a leading index column,
            ``[idx, text1, text2, label]``, are handled the same way.
        tokenizer: tokenizer object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
        max_length: every encoded pair is truncated/padded to this many tokens.

    Returns:
        (II, AM, LABEL): stacked input ids, stacked attention masks (presumably
        shape ``(len(sents), max_length)`` each -- depends on the tokenizer), and
        labels as a float32 tensor of shape ``(len(sents), 1)``.
    """
    input_ids, attention_masks, labels = [], [], []
    for row in sents:
        # Bug fix: the old row[1:-1] slice kept only text2 for the
        # [text1, text2, label] rows createpairs produces. Indexing from the end
        # works for both the 3- and the 4-column layout.
        text1, text2, label = row[-3], row[-2], row[-1]
        # NOTE(review): the pair is joined with a literal "</s>" and encoded with
        # add_special_tokens=False, so no <s> token is prepended, and truncation
        # presumably cuts from the end of the joined string (i.e. from text2).
        # encode_plus(text1, text2, ...) would give RoBERTa's native pair layout --
        # confirm against how the downstream model is trained before switching.
        encoded = tokenizer.encode_plus(
            "</s>".join([text1, text2]),
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids.append(encoded["input_ids"])
        attention_masks.append(encoded["attention_mask"])
        labels.append(label)

    # Each encoding presumably carries a leading batch axis of size 1; stacking
    # adds the row axis and squeeze(1) drops that singleton axis.
    II = torch.stack(input_ids).squeeze(1)
    AM = torch.stack(attention_masks).squeeze(1)
    LABEL = torch.tensor(labels).view(-1, 1).to(torch.float32)
    return II, AM, LABEL

In [106]:
II, AM, LABEL = gimme_tensors(sents=test_sents, tokenizer=tokenizer)