In [1]:
import torch
import os
import pandas as pd
import matchzoo as mz
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
TYPE = 'classification'

classification_task = mz.tasks.Classification(num_classes=3)
classification_task.metrics = ['acc']
print(classification_task.num_classes)
print(classification_task.output_shape)
print(classification_task.output_dtype)
print(classification_task)

3
(3,)
<class 'int'>
Classification Task with 3 classes


# Data Pack Sample

In [3]:
# pack
df = pd.DataFrame(data={'text_left': list('ABBCD'),
                       'text_right': list('aacbd'),
                       'label': [-1, 1, 0, 0, 1]})
mz.pack(df, task=TYPE).frame()

Unnamed: 0,id_left,text_left,id_right,text_right,label
0,L-0,A,R-0,a,-1
1,L-1,B,R-0,a,1
2,L-1,B,R-1,c,0
3,L-2,C,R-2,b,0
4,L-3,D,R-3,d,1


In [4]:
# data_pack
left = [
    ['artid1', 'A1'],
    ['artid2', 'A2'],
    ['artid3', 'A3']
]
right = [
    ['hypoid1', 'prompt1'],
    ['hypoid2', 'prompt2'],
    ['hypoid3', 'prompt3']
]
relation = [
    ['artid1', 'hypoid1', -1],
    ['artid1', 'hypoid3', 1],
    ['artid2', 'hypoid2', 0],
    ['artid3', 'hypoid3', 1]
]

relation_df = pd.DataFrame(relation)
# relation_df
left = pd.DataFrame(left)
right = pd.DataFrame(right)
dp = mz.DataPack(
    relation=relation_df,
    left=left,
    right=right
)
# print(len(dp))
# print(type(dp.frame))
# frame_slice = dp.frame[0:5]
# type(frame_slice)
# list(frame_slice.columns)
# full_frame = dp.frame()

In [5]:
data_pack = mz.datasets.toy.load_data(stage='train')
type(data_pack)

matchzoo.data_pack.data_pack.DataPack

# Prepare input data

In [6]:
annot = pd.read_csv('./annotations/annotations_merged.csv')
print(annot.dtypes)
annot.sort_values('PMCID').head(10)

UserID              int64
PromptID            int64
PMCID               int64
Valid Label          bool
Valid Reasoning      bool
Label              object
Annotations        object
Label Code          int64
In Abstract          bool
Evidence Start      int64
Evidence End        int64
dtype: object


Unnamed: 0,UserID,PromptID,PMCID,Valid Label,Valid Reasoning,Label,Annotations,Label Code,In Abstract,Evidence Start,Evidence End
13994,0,7902,60007,True,True,significantly increased,Mivacurium 250 μg/kg produced a maximal T bloc...,1,True,1280,1395
14000,3,7904,60007,True,True,no significant difference,Heart rate was similar between doses,0,True,1452,1488
12483,0,7905,60007,True,True,significantly decreased,while both AUC-SBP and AUC-DBP were significan...,-1,False,16737,16836
12484,1,7905,60007,True,True,significantly decreased,"In relation to the cardiovascular response, th...",-1,False,16641,16836
13999,0,7904,60007,True,True,no significant difference,"In relation to the cardiovascular response, th...",0,False,16641,16736
13998,3,7903,60007,True,True,no significant difference,Spontaneous recovery times were similar in bot...,0,True,1396,1450
13997,0,7903,60007,True,True,no significant difference,The times to OA and to spontaneous recovery of...,0,False,14755,15079
13995,3,7902,60007,True,True,significantly increased,Mivacurium 250 μg/kg produced a maximal T bloc...,1,True,1280,1395
2855,0,1806,111193,True,True,no significant difference,The amount of blood transfusion was identical ...,0,True,1418,1578
2334,0,1808,111193,True,True,no significant difference,There was no difference in time spent in hospi...,0,False,14701,14770


In [7]:
"""Process txt file for Articles input"""
TXT_PATH = './annotations/txt_files/'
TAR_PATH = './annotations/processed_txt_files/'
look_up = {}

if not os.path.exists(TAR_PATH):
    os.mkdir(TAR_PATH)
# else: os.removedirs(TAR_PATH)
for file in os.listdir(TXT_PATH):
    fname = file[3:]
    look_up[int(file[3:-4])] = fname
    with open(TXT_PATH+file, 'r', encoding='utf-8') as f, open(TAR_PATH+fname, 'w', encoding='utf-8') as t:
        for line in f.readlines():
            t.write(line.strip())

In [8]:
"""Articles Map"""
left_articles = []
for file in os.listdir(TAR_PATH):
    with open(TAR_PATH+file, 'r', encoding='utf-8') as f:
        left_articles.append([int(file[:-4]), f.readlines()[0]])
# print(left_articles[0:5])

"""Prompts Map"""
right_prompts = [list(pair) for pair in zip(annot['PromptID'].values, annot['Annotations'].values)]
print(right_prompts[0:5])

"""Articles <-> Prompts Map"""
article_prompt_relations = [list(triplet) for triplet in zip(annot['PMCID'].values, 
                                                             annot['PromptID'].values, 
                                                             annot['Label Code'].values)]
print(article_prompt_relations[0:5])

"""Create Data-pack"""
left = pd.DataFrame(left_articles, columns=['id_left', 'text_left'])
right = pd.DataFrame(right_prompts, columns=['id_right', 'text_right'])
relation = pd.DataFrame(article_prompt_relations, columns=['id_left', 'id_right', 'label'])
dp = mz.DataPack(
    relation = relation,
    left = left,
    right = right
)

# print(left)
# print(len(dp))
# print(type(dp.frame))
# frame = dp.frame
# print(list(frame().columns))
# # frame_slice = dp.frame[0:5]
# # type(frame_slice)
# # list(frame_slice.columns)
# full_frame = dp.frame()

[[213, 'IL-6r (ng/ml)\t\t\t\t\t\t\t Group A\t43.6 (1.7–125.0)\t\t47.4 (0.7–109.5)\t56.2 (25.2–226.3)\t\t\t0.949† Group B\t40.7 (15.6–94.6)\t\t42.4 (22.2–100.5)\t50.2 (13.2–104.9)\t\t\t0.861†\tp = 0.607*\t\tp = 0.914*\tp = 0.304*'], [213, 'There was no significant difference in IL 6, IL-6r and C-reactive protein values between groups.'], [213, 'There was no significant difference in IL 6, IL-6r and C-reactive protein values between groups'], [213, 'There was no significant difference in IL 6, IL-6r and C-reactive protein values between groups'], [98, 'After two weeks of treatment, the reduction in ulcer area was doubled in the HBOT group (P = 0.037)']]
[[2206488, 213, 0], [2206488, 213, 0], [2206488, 213, 0], [2206488, 213, 0], [2858204, 98, 1]]


In [9]:
""" Read article by name """
def read_article(id) -> str:
    with open(TAR_PATH+look_up[id], 'r', encoding='utf-8') as f:
        return f.readlines()[0]

text_left = [read_article(id) for id in annot['PMCID']]
# print(text_left[0:5])

In [10]:
df = pd.DataFrame(data={
    'id_left': annot['PMCID'],
    'text_left': text_left,
    'id_right': annot['PromptID'],
    'text_right': annot['Annotations'],
    'label': annot['Label Code']
})

""" Split data pack into train/valid """
train, valid = train_test_split(df, test_size=0.2)
train_pack = mz.pack(train, task=TYPE)
valid_pack = mz.pack(valid, task=TYPE)
train_pack.frame().head(10) # DataFrame

Unnamed: 0,id_left,text_left,id_right,text_right,label
0,5122106,TITLE:ABSTRACT.BACKGROUND::Obesity is a worldw...,6192,There was no severe drug reaction or other com...,0
1,4472927,TITLE: Effect ofABSTRACT:Different amounts of ...,9556,The serum IL-10 in group A was significantly h...,0
2,4065280,"TITLE: A Randomized, Double-Blind, Placebo-Con...",5881,There was a significant mean reduction in tota...,-1
3,3290117,TITLE: Bootcamp During Neoadjuvant Chemotherap...,3623,Final tumor size was 2.59 cm in the NC + BC gr...,0
4,524504,TITLE: Comparison of Misoprostol and Dinoprost...,56,"<td colspan=""6""><hr></td>",0
5,4883760,TITLE: The Effectiveness of an Educational Gam...,6021,Post-intervention test scores increased signif...,1
6,5716426,TITLE: Improvement in children’s fine motor sk...,8869,There was no significant difference between gr...,0
7,3395326,TITLE: The Use of Epoetin-ABSTRACT:Introductio...,3758,There was no difference in average hidden bloo...,0
8,5341634,TITLE: Perineural Nalbuphine in Ambulatory Upp...,7820,time to first analgesic use were significantly...,1
9,4532714,TITLE: Synchronized personalized music audio-p...,7544,Patients randomized to personalized audio-play...,1


In [11]:
dp = mz.pack(df, task=TYPE)
print(type(dp.frame))
frame_slice = dp.frame[0:5]
print(list(frame_slice.columns))
full_frame = dp.frame()
len(full_frame) == len(dp)

<class 'matchzoo.data_pack.data_pack.DataPack.FrameView'>
['id_left', 'text_left', 'id_right', 'text_right', 'label']


True

# Model and Train 

In [12]:
preprocessor = mz.models.DIIN.get_default_preprocessor()
train_processed = preprocessor.fit_transform(train_pack)
valid_processed = preprocessor.transform(valid_pack)

Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2398/2398 [01:28<00:00, 27.16it/s]
Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 9443/9443 [00:03<00:00, 2425.95it/s]
Processing text_right with append: 100%|██████████| 9443/9443 [00:00<00:00, 652812.92it/s]
Building FrequencyFilter from a datapack.: 100%|██████████| 9443/9443 [00:00<00:00, 116663.56it/s]
Processing text_right with transform: 100%|██████████| 9443/9443 [00:00<00:00, 115986.09it/s]
Processing text_left with extend: 100%|██████████| 2398/2398 [00:00<00:00, 14620.35it/s]
Processing text_right with extend: 100%|██████████| 9443/9443 [00:00<00:00, 467728.87it/s]
Building Vocabulary from a datapack.: 100%|██████████| 8811439/8811439 [00:02<00:00, 3138053.76it/s]
Processing text_left with transform: 100%|██████████| 2398/2398 [00:26<00:00, 90.98it/s] 
Processing text_right with transform: 100%|██████████| 9443/9443 [00:01<

In [14]:
trainset = mz.dataloader.Dataset(
    data_pack=train_processed,
    mode='pair',
    num_dup=1,
    num_neg=4
)
validset = mz.dataloader.Dataset(
    data_pack=valid_processed,
    mode='point'
)

In [15]:
padding_callback = mz.models.DIIN.get_default_padding_callback()

trainloader = mz.dataloader.DataLoader(
    dataset=trainset,
    stage='train',
    callback=padding_callback
)
validloader = mz.dataloader.DataLoader(
    dataset=validset,
    stage='dev',
    callback=padding_callback
)

In [17]:
model = mz.models.DIIN()
model.params['task'] = classification_task
model.params['embedding_output_dim'] = 100
model.params['embedding_input_dim'] = preprocessor.context['embedding_input_dim']
model.guess_and_fill_missing_params()
model.build()

In [18]:
optimizer = torch.optim.Adam(model.parameters())

trainer = mz.trainers.Trainer(
    model=model,
    optimizer=optimizer,
    trainloader=trainloader,
    validloader=validloader,
    epochs=10
)

trainer.run()

HBox(children=(IntProgress(value=0, max=209), HTML(value='')))




KeyError: 'ngram_left'

In [19]:
pwd

'/Users/shanewang/Desktop/Codes/evidence-inference'