<a href="https://colab.research.google.com/github/mahopman/CTP/blob/main/CTPSetUp_10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

# Project root on Google Drive. Set the CTP_LOCAL_PATH environment variable to
# run the notebook from another location without editing this cell.
local_path = os.environ.get("CTP_LOCAL_PATH", "/content/drive/MyDrive/MS_DataScience/DS595/CTP")

Unzip materials

In [None]:
# -n: never overwrite existing files, so re-running the cell does not stop at
# unzip's interactive "replace?" prompt.
!unzip -n {local_path}/evidence_integration/materials.zip -d {local_path}/evidence_integration

Archive:  /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials.zip
   creating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/
  inflating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/sec2label.json  
  inflating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/split2ids.json  
  inflating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/pmc_contents.json  
  inflating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/pmcid2picoid.json  
  inflating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/secname2sec.json  
  inflating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/sec2count.json  
  inflating: /content/drive/MyDrive/MS_DataScience/DS595/CTP/evidence_integration/materials/prompt_labels.json  
  inflating: /content/drive/MyDrive/MS_

Generate Evidence Integration

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

import json

def generate(picoids, split_path):
    """Write the evidence-integration examples for ``picoids`` to ``split_path``.

    Prompts whose label is ``'invalid prompt'`` are skipped. Every other prompt
    becomes a dict holding its intervention/comparator/outcome texts, its class
    index from ``result2label`` and a ``passage``. The passage is the
    concatenation of the article's ``ABSTRACT*`` sections whose
    ``sec2label[secname2sec[secname]]`` is ``'1'``. It is empty when the PMCID
    has no entry in ``pmcid2content``.

    Reads the globals ``prompt_info``, ``result2label``, ``pmcid2content``,
    ``secname2sec`` and ``sec2label``, which the following cell loads.
    The result is written to ``split_path`` as indented JSON.
    """
    records = []
    for picoid in picoids:
        info = prompt_info[picoid]
        if info['label'] == 'invalid prompt':
            continue

        record = {
            'picoid': picoid,
            'pmcid': info['PMCID'],
            'i_text': info['I'],
            'c_text': info['C'],
            'o_text': info['O'],
            'label': result2label[info['label']],
        }

        # NOTE(review): section texts are joined with no separator; confirm the
        # texts already end with whitespace or that run-together text is intended.
        sections = pmcid2content.get(str(info['PMCID']), [])
        record['passage'] = ''.join(
            text
            for secname, text in sections
            if secname.startswith('ABSTRACT') and sec2label[secname2sec[secname]] == '1'
        )
        records.append(record)

    with open(split_path, 'w') as f:
        json.dump(records, f, indent=4)

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

import json

# Map each prompt's textual result onto the 3-way class index.
result2label = {'significantly decreased': 0,
                'no significant difference': 1,
                'significantly increased': 2}

materials_dir = local_path + '/evidence_integration/materials'


def _load_material(name):
    """Load ``<materials_dir>/<name>`` as JSON, closing the file afterwards."""
    with open(f"{materials_dir}/{name}") as f:
        return json.load(f)


prompt_info = _load_material('prompt_info.json')
split2ids = _load_material('split2ids.json')
pmcid2picoid = _load_material('pmcid2picoid.json')
pmcid2content = _load_material('pmc_contents.json')
secname2sec = _load_material('secname2sec.json')
sec2label = _load_material('sec2label.json')

# One output file per split. The loop body was previously indented with a mix
# of tabs and spaces, which Python 3 rejects (TabError / IndentationError).
for split, ids in split2ids.items():
    picoids = []
    for pmcid in ids:
        # The ids are converted to str before the pmcid2picoid lookup.
        pmcid = str(pmcid)
        if pmcid in pmcid2picoid:
            picoids += pmcid2picoid[pmcid]

    split_path = f"{local_path}/evidence_integration/{split}.json"
    generate(picoids, split_path)

Index evidence integration dataset

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

import json

# Split every evidence-integration file into two indexed files:
#   indexed_<split>_picos.json : one record per prompt (I/C/O texts, label, ctx_id)
#   indexed_<split>_ctxs.json  : one record per distinct article passage
# Prompts from the same PMCID share a single ctx_id, assigned in order of first use.
splits = ['train', 'validation', 'test']

for split in splits:
    split_path = local_path + '/evidence_integration/' + split + '.json'
    with open(split_path) as f:
        ori_data = json.load(f)

    picos = []
    ctxs = []
    pmcid2ctxid = {}

    for entry in ori_data:
        pico = {k: entry[k] for k in ['i_text', 'c_text', 'o_text', 'label']}
        pmcid = entry['pmcid']

        if pmcid not in pmcid2ctxid:
            pmcid2ctxid[pmcid] = len(ctxs)
            ctxs.append({'ctx_id': pmcid2ctxid[pmcid], 'passage': entry['passage']})

        pico['ctx_id'] = pmcid2ctxid[pmcid]
        picos.append(pico)

    with open(f"{local_path}/evidence_integration/indexed_{split}_picos.json", 'w') as f:
        json.dump(picos, f, indent=4)
    with open(f"{local_path}/evidence_integration/indexed_{split}_ctxs.json", 'w') as f:
        json.dump(ctxs, f, indent=4)

Download and unzip PubMed baseline splits

In [None]:
__author__ = 'Mia Hopman'

import os
from ftplib import FTP
import gzip
import shutil


def download_and_extract_gz_files(ftp_server, ftp_path, destination_path, max_files=None):
    """Mirror the ``.gz`` files of an FTP directory and decompress each one.

    Each remote ``<name>.gz`` is saved as ``destination_path/<name>.gz``.
    It is then decompressed next to it as ``destination_path/<name>``.
    Archives already present locally are not downloaded again. A missing
    decompressed copy is (re)created from the local archive.

    Args:
        ftp_server: FTP host name (anonymous login).
        ftp_path: remote directory to mirror.
        destination_path: local directory, created if it does not exist.
        max_files: optional cap on the number of ``.gz`` files handled, taken in
            sorted file-name order. ``None`` (default) handles all of them.
    """
    os.makedirs(destination_path, exist_ok=True)

    # The context manager closes the FTP connection even when a transfer fails.
    with FTP(ftp_server) as ftp:
        ftp.login()
        ftp.cwd(ftp_path)

        # NLST is not guaranteed to be sorted; sort for a deterministic order.
        gz_names = sorted(name for name in ftp.nlst() if name.endswith('.gz'))
        if max_files is not None:
            gz_names = gz_names[:max_files]

        for file_name in gz_names:
            local_file_path = os.path.join(destination_path, file_name)
            uncompressed_file_path = os.path.splitext(local_file_path)[0]

            downloaded = False
            if os.path.exists(local_file_path):
                print(f'{file_name} already exists. Skipping download.')
            else:
                # Download under a temporary name and rename when complete, so an
                # interrupted transfer is never mistaken for a finished archive.
                partial_path = local_file_path + '.part'
                with open(partial_path, 'wb') as f:
                    ftp.retrbinary(f'RETR {file_name}', f.write)
                os.replace(partial_path, local_file_path)
                downloaded = True

            if not os.path.exists(uncompressed_file_path):
                # Stream the decompression rather than reading the whole file into
                # memory. Use the same temp-then-rename pattern.
                partial_path = uncompressed_file_path + '.part'
                with gzip.open(local_file_path, 'rb') as gz_file, \
                        open(partial_path, 'wb') as uncompressed_file:
                    shutil.copyfileobj(gz_file, uncompressed_file)
                os.replace(partial_path, uncompressed_file_path)

            if downloaded:
                print(f'Downloaded and extracted {file_name}')

In [None]:
__author__ = 'Mia Hopman'

# Anonymous FTP location of the PubMed baseline.
ftp_server, ftp_path = 'ftp.ncbi.nlm.nih.gov', '/pubmed/baseline/'
# Local mirror; the preprocessing cell globs pubmed24n*.xml in this folder.
destination_path = f"{local_path}/pretraining_dataset/pubmed_baseline"

# NOTE(review): this mirrors every baseline file, although the later
# pseudo-labelling cells only read chunks 0001-0009.
download_and_extract_gz_files(
    ftp_server=ftp_server,
    ftp_path=ftp_path,
    destination_path=destination_path,
)

pubmed24n0002.xml.gz already exists. Skipping download.
pubmed24n0004.xml.gz already exists. Skipping download.
pubmed24n0003.xml.gz already exists. Skipping download.
pubmed24n0001.xml.gz already exists. Skipping download.
pubmed24n0006.xml.gz already exists. Skipping download.
pubmed24n0008.xml.gz already exists. Skipping download.
pubmed24n0005.xml.gz already exists. Skipping download.
pubmed24n0009.xml.gz already exists. Skipping download.
pubmed24n0007.xml.gz already exists. Skipping download.
pubmed24n0010.xml.gz already exists. Skipping download.
pubmed24n0013.xml.gz already exists. Skipping download.
pubmed24n0014.xml.gz already exists. Skipping download.
pubmed24n0011.xml.gz already exists. Skipping download.
pubmed24n0012.xml.gz already exists. Skipping download.
pubmed24n0015.xml.gz already exists. Skipping download.
pubmed24n0016.xml.gz already exists. Skipping download.
pubmed24n0018.xml.gz already exists. Skipping download.
pubmed24n0017.xml.gz already exists. Skipping do

Preprocess PubMed splits

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

import glob
import json
import os
import xml.etree.ElementTree as ET


def _element_text(elem):
    """All text inside ``elem``, including text nested in inline child tags.

    ``elem.text`` alone stops at the first child element (e.g. an inline
    ``<i>`` or ``<sup>``). That truncated abstracts, and it dropped any
    abstract whose text starts with such a tag.
    """
    return ''.join(elem.itertext())


# Convert each baseline XML file into <same name>.json holding, per citation:
# pmid, the list of texts (title first, then abstract parts) and, aligned with
# it, the section labels ('TITLE' or the AbstractText "Label" attribute).
for xml_path in sorted(glob.glob(f"{local_path}/pretraining_dataset/pubmed_baseline/pubmed24n*.xml")):
    print('Processing %s' % xml_path)
    output = []

    root = ET.parse(xml_path).getroot()

    for citation in root.iter('MedlineCitation'):
        pmid_elem = citation.find('PMID')
        if pmid_elem is None:
            continue
        pmid = pmid_elem.text

        texts = []
        sec_labels = []

        title = citation.find('Article/ArticleTitle')
        if title is not None:
            texts.append(_element_text(title))
            sec_labels.append('TITLE')

        for info in citation.iter('AbstractText'):
            text = _element_text(info)
            if text:
                texts.append(text)
                sec_labels.append(info.get('Label'))

        assert len(texts) == len(sec_labels)

        output.append({'pmid': pmid,
                       'texts': texts,
                       'sec_labels': sec_labels})

    # splitext only strips the extension. The old split('.')[0] cut the path at
    # the first dot anywhere in it, including in directory names.
    with open(os.path.splitext(xml_path)[0] + '.json', 'w') as f:
        json.dump(output, f, indent=4)

Download Stanford POS tagger

In [None]:
# -nc / -n skip files that are already present, so re-running the cell neither
# re-downloads the archive nor stops at unzip's overwrite prompt.
!wget -nc https://nlp.stanford.edu/software/stanford-postagger-full-2015-04-20.zip -P {local_path}/stanford_pos
!unzip -n {local_path}/stanford_pos/stanford-postagger-full-2015-04-20.zip -d {local_path}/stanford_pos
!ls {local_path}/stanford_pos/stanford-postagger-full-2015-04-20

# Each `!` line runs in its own subshell, so the former `!export` lines were
# lost immediately. Set the variables on the kernel process so that commands
# and subprocesses started later from this kernel inherit them.
import os
stanford_tools_dir = f"{local_path}/stanford_pos/stanford-postagger-full-2015-04-20"
os.environ["STANFORDTOOLSDIR"] = stanford_tools_dir
os.environ["CLASSPATH"] = f"{stanford_tools_dir}/stanford-postagger.jar"
os.environ["STANFORD_MODELSDIR"] = f"{stanford_tools_dir}/models"

Tag dataset

In [None]:
__author__ = 'Qiao Jin'


def _than_cues(lowers):
    """Cue indices and direction for '<comparative> ... than ...', or None.

    The direction is 2 when only ``ups`` comparatives occur in the sentence and
    0 when only ``downs`` occur. The indices are the comparative positions
    followed by the position of "than".
    """
    # More than one "than" is mostly a purely quantitative relation.
    if lowers.count('than') != 1:
        return None
    than_idx = lowers.index('than')
    # A comparative cue word must occur before "than".
    if not set(lowers[:than_idx]).intersection(key_words):
        return None

    up_idx = [i for i, w in enumerate(lowers) if w in ups]
    down_idx = [i for i, w in enumerate(lowers) if w in downs]
    if up_idx and not down_idx:
        cue_idx, direction = up_idx, 2
    elif down_idx and not up_idx:
        cue_idx, direction = down_idx, 0
    else:
        return None

    # "more than 20" / "less than five" states a threshold, not a comparison.
    nxt = than_idx + 1
    if nxt < len(lowers) and (lowers[nxt].isnumeric() or lowers[nxt] in nums):
        return None
    return cue_idx + [than_idx], direction


def _similar_cues(lowers):
    """Cue indices for 'similar ... to ...' (direction 1), or None."""
    sim_idx = next((i for i, w in enumerate(lowers) if w in sims), None)
    if sim_idx is None or 'to' not in lowers[sim_idx:]:
        return None
    return [sim_idx, sim_idx + lowers[sim_idx:].index('to')], 1


def _no_difference_cues(lowers):
    """Cue indices for 'no/not ... differ... between ... and' (direction 1), or None."""
    diff_idx = next((i for i, w in enumerate(lowers) if w in diffs), None)
    if diff_idx is None:
        return None
    # Nearest negation left of the "differ" word. Matched on the lower-cased
    # tokens: the raw tokens never matched a sentence-initial "No".
    no_idx = next((i for i in range(diff_idx - 1, -1, -1) if lowers[i] in nos), None)
    if no_idx is None:
        return None
    bet_idx = [i for i, w in enumerate(lowers) if w == 'between' and i > diff_idx]
    if not bet_idx:
        return None
    tail = lowers[bet_idx[0]:]
    if 'and' not in tail:
        return None
    and_idx = bet_idx[0] + tail.index('and')
    return list(range(no_idx, diff_idx + 1)) + bet_idx + [and_idx], 1


def mask_and_label(sent):
    """Pseudo-label one sentence as a comparative evidence statement.

    Patterns are tried in this order (only the first whose surface test
    passes is used):
      * '<comparative> ... than ...'            -> direction 2 (up) or 0 (down)
      * 'similar ... to ...'                    -> direction 1
      * 'no/not ... differ... between ... and'  -> direction 1

    Uses the globals ``tokenize`` (nltk), ``exclude``, ``key_words``,
    ``ups``, ``downs``, ``nums``, ``sims``, ``diffs`` and ``nos``.

    Returns:
        ``[words, direction, indices]``. ``words`` is the token list of the
        sentence with a leading space prepended before tokenising. ``indices``
        are the cue-word positions. Returns ``False`` when no pattern matches
        or the sentence ends with '?'.
    """
    sent = ' ' + sent
    lower_sent = sent.lower()

    if ' than ' in lower_sent and all(exc not in lower_sent for exc in exclude):
        matcher = _than_cues
    elif ' similar' in lower_sent and ' to ' in lower_sent:
        matcher = _similar_cues
    elif ' no' in lower_sent and ' differ' in lower_sent and 'and' in lower_sent:
        matcher = _no_difference_cues
    else:
        return False

    words = tokenize.word_tokenize(sent)
    match = matcher([word.lower() for word in words])
    # Questions are not statements of evidence.
    if match is None or not words or words[-1] == '?':
        return False
    indices, direction = match
    return [words, direction, indices]

In [None]:
__author__ = 'Qiao Jin'

def process(item):
    """Split one preprocessed PubMed record into evidence sentences and a context.

    ``item`` has the keys ``pmid``, ``texts`` and ``sec_labels`` (aligned lists).
    The 'TITLE' entry is ignored.

    For each remaining text:
      * Labelled sections found in ``sec2label``:
        - value ``'1'``: the whole text is appended to the context.
        - other values: evidence is searched sentence by sentence, and any
          later unlabelled background sentences are no longer collected.
      * Unlabelled or unknown sections: each sentence is tested with
        ``mask_and_label``. Matches become evidence. Non-matching sentences are
        added to the context only until the first evidence/non-background
        section is seen.

    Returns:
        ``(evidence_list, {'pmid': pmid, 'ctx': context_text})``.
    """
    pmid = item['pmid']
    evidence = []
    context = {'pmid': pmid, 'ctx': ''}
    collecting_background = True

    def as_record(result):
        # mask_and_label returns [tokens, direction, cue_indices].
        return {'pmid': pmid, 'pos': result[0], 'label': result[1], 'indices': result[2]}

    for text, label in zip(item['texts'], item['sec_labels']):
        if label == 'TITLE':
            continue
        sents = tokenize.sent_tokenize(text)

        if label and label in sec2label:
            if sec2label[label] == '1':
                context['ctx'] += ' ' + text
                continue
            collecting_background = False
            for sent in sents:
                result = mask_and_label(sent)
                if result:
                    evidence.append(as_record(result))
            continue

        for sent in sents:
            result = mask_and_label(sent)
            if result:
                collecting_background = False
                evidence.append(as_record(result))
            elif collecting_background:
                context['ctx'] += ' ' + sent

    return evidence, context

In [None]:
## FIGURE OUT HOW TO DOWNLOAD sec2label.json and bad_pmids.json ##

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

# Pseudo-label the preprocessed PubMed chunks. For each chunk, `process` pulls
# out comparative evidence sentences, which are then POS-tagged with the Stanford
# tagger, plus the background context of every abstract. Both are cached as JSON,
# and a chunk whose evidence file exists is not recomputed.

import json
import os
import sys
import random as rd
from collections import Counter

import nltk
from nltk import tokenize
from nltk.tag import StanfordPOSTagger
from nltk.corpus import stopwords

nltk.download('punkt')

stanford_dir = local_path + '/stanford_pos/stanford-postagger-full-2015-04-20'
jar = stanford_dir + '/stanford-postagger.jar'
model = stanford_dir + '/models/english-left3words-distsim.tagger'

st = StanfordPOSTagger(model, jar, encoding='utf8')

# Word lists read by `mask_and_label`.
exclude = {'rather than', 'other than'}
ups = {'better', 'greater', 'higher', 'later', 'more', 'faster', 'older', 'longer',
       'larger', 'broader', 'wider', 'stronger', 'deeper', 'commoner', 'richer',
       'further', 'bigger'}
downs = {'worse', 'smaller', 'lower', 'earlier', 'less', 'slower', 'younger', 'shorter',
         'narrower', 'weaker', 'shallower', 'fewer', 'rarer', 'poorer', 'closer'}

key_words = ups | downs

diffs = {'difference', 'differences', 'different', 'differently', 'differ'}
sims = {'similar', 'similarly', 'similarity', 'similarities'}

nos = {'no', 'not'}
middles = {'significant', 'significantly', 'statistic', 'statistically', 'statistical'}

nums = {"twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety", "zero",
        "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "eleven",
        "twelve", "thirteen", "fourteen", "fifteen", "sixteen", "seventeen", "eighteen", "nineteen"}

with open(f"{local_path}/pretraining_dataset/sec2label.json") as f:
    sec2label = json.load(f)

evidence_dir = f"{local_path}/pretraining_dataset/evidence"
os.makedirs(evidence_dir, exist_ok=True)  # writing below fails if it is missing

# CHANGED FROM 1016: only the first 9 baseline chunks are processed.
chunks = list(range(1, 10))
rd.shuffle(chunks)

total = 0

for idx, chunk_id in enumerate(chunks):
    evi_path = f"{evidence_dir}/evidence_pos_{chunk_id:04d}_10.json"
    ctx_path = f"{evidence_dir}/contexts_{chunk_id:04d}_10.json"

    if os.path.exists(evi_path):
        with open(evi_path) as f:
            evi_output = json.load(f)
    else:
        data_path = f"{local_path}/pretraining_dataset/pubmed_baseline/pubmed24n{chunk_id:04d}.json"
        if not os.path.exists(data_path):
            continue
        with open(data_path) as f:
            data = json.load(f)

        evi_output = []
        ctx_output = []
        for item in data:
            item_evidence, item_context = process(item)
            evi_output += item_evidence
            ctx_output.append(item_context)

        # Replace each evidence token list with the tagger's output for it.
        # Skipped when there is nothing to tag, to avoid launching the tagger.
        if evi_output:
            pos_list = st.tag_sents(o['pos'] for o in evi_output)
            assert len(pos_list) == len(evi_output)
            for record, tagged in zip(evi_output, pos_list):
                record['pos'] = tagged

        # Contexts first: the evidence file is the "chunk done" marker checked above,
        # so an interruption between the two writes cannot leave contexts missing.
        with open(ctx_path, 'w') as f:
            json.dump(ctx_output, f)
        with open(evi_path, 'w') as f:
            json.dump(evi_output, f)

    total += len(evi_output)

    print('%d/%d; Processing %s; Number of evidence: %d; Total: %d' % (idx+1, len(chunks), chunk_id, len(evi_output), total))

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


1/9; Processing 9; Number of evidence: 1255; Total: 1255
2/9; Processing 8; Number of evidence: 4601; Total: 5856
3/9; Processing 1; Number of evidence: 4798; Total: 10654
4/9; Processing 7; Number of evidence: 4786; Total: 15440
5/9; Processing 2; Number of evidence: 3804; Total: 19244
6/9; Processing 6; Number of evidence: 4801; Total: 24045
7/9; Processing 3; Number of evidence: 3172; Total: 27217
8/9; Processing 4; Number of evidence: 3890; Total: 31107
9/9; Processing 5; Number of evidence: 3395; Total: 34502


Process tags

In [None]:
__author__ = 'Qiao Jin'

def reversed(words, label):
    """Reverse the order of the '[MASK]'-separated segments of ``words``.

    Tokens keep their order inside each segment; only the segments are swapped,
    e.g. ``['A', 'x', '[MASK]', 'B']`` becomes ``['B', '[MASK]', 'A', 'x']``.
    The label is mapped through ``up2down`` (checked first) or ``down2up`` when
    present there, and is otherwise returned unchanged.

    NOTE(review): this name shadows the builtin ``reversed`` for every later cell.

    Returns:
        ``(reversed_words, reversed_label)``.
    """
    segments = [[]]
    for word in words:
        if word == '[MASK]':
            segments.append([])
        else:
            segments[-1].append(word)

    flipped = []
    for position, segment in enumerate(segments[::-1]):
        if position:
            flipped.append('[MASK]')
        flipped.extend(segment)

    flipped_label = up2down.get(label, down2up.get(label, label))
    return flipped, flipped_label

In [None]:
__author__ = 'Qiao Jin'

def get_label(pos, indices, label2ctr):
    """Map the cue words of a POS-tagged evidence sentence to a label index.

    Args:
        pos: sequence of ``(token, tag)`` pairs.
        indices: positions of the cue words in ``pos``.
        label2ctr: counter dict, incremented for the label that is returned.

    Returns:
        ``label2idx[...]`` for the cue, or ``False`` when:
          * there are two cue words and the first one (lower-cased) is not a
            known label, or
          * there are more cue words and the last one is 'than' (several
            comparatives before "than").
        Any other longer cue sequence is counted and returned as 'nodiff'.
    """
    cue_words = [pos[i][0] for i in indices]

    if len(cue_words) == 2:
        key = cue_words[0].lower()
        if key not in label2idx:
            return False
        label2ctr[key] += 1
        return label2idx[key]

    if cue_words[-1] == 'than':
        return False

    label2ctr['nodiff'] += 1
    return label2idx['nodiff']

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

import json
import os

# Turn the POS-tagged evidence sentences into masked statements ("pico") and a
# direction-reversed copy ("rev_pico"). Only the clause holding the cue words is
# kept; every other span is collapsed to a single [MASK].

# --- label vocabulary ------------------------------------------------------
# NOTE(review): these rebind the module-level `ups`, `downs` and `sims` that
# `mask_and_label` reads (as sets, defined in an earlier cell).
ups = ['better', 'greater', 'higher', 'later', 'more', 'faster', 'older', 'longer',
       'larger', 'broader', 'wider', 'stronger', 'deeper', 'more', 'commoner', 'richer',
       'further', 'bigger']
downs = ['worse', 'smaller', 'lower', 'earlier', 'less', 'slower', 'younger', 'shorter',
         'smaller', 'narrower', 'narrower', 'weaker', 'shallower', 'fewer', 'rarer', 'poorer',
         'closer', 'smaller']
sims = ['nodiff', 'similar']

# dict.fromkeys deduplicates while keeping first-seen order. `list(set(...))`
# ordered labels by string hash, which Python randomises per session, so the
# integer labels written below changed between kernel runs.
label_set = list(dict.fromkeys(downs)) + list(dict.fromkeys(sims)) + list(dict.fromkeys(ups))

label2idx = {label: idx for idx, label in enumerate(label_set)}
# Variants share the index of their base label.
label2idx['similarly'] = label2idx['similar']
label2idx['similarity'] = label2idx['similar']
label2idx['similarities'] = label2idx['similar']
label2idx['farther'] = label2idx['further']

label2ctr = {k: 0 for k in label2idx}

# Opposite labels come from zipping the two lists position by position. A
# repeated key keeps its last pairing, e.g. 'more' maps to 'fewer'.
up2down = {label2idx[k]: label2idx[v] for k, v in zip(ups, downs)}
down2up = {label2idx[k]: label2idx[v] for k, v in zip(downs, ups)}

# --- masking rules ---------------------------------------------------------
output = []

removed = set(['CD'])  # POS tags always masked (Penn Treebank CD = cardinal number)
indicators = set(['significant', 'significantly', 'statistically', 'statistic', '%'])
sims = set(['similar', 'similarly', 'similarity', 'similarities'])

pmids = set()

## CHANGED FROM 1016
for chunk_id in range(1, 10):
    data_path = f"{local_path}/pretraining_dataset/evidence/evidence_pos_{chunk_id:04d}_10.json"
    if not os.path.exists(data_path):
        continue

    with open(data_path) as f:
        data = json.load(f)

    for item in data:
        pmid = item['pmid']
        pos = item['pos']          # list of (token, tag) pairs
        indices = item['indices']  # positions of the cue words

        label = get_label(pos, indices, label2ctr)
        # get_label signals "no label" with False. Compare by identity: the valid
        # label index 0 is falsy too, and `if not label` silently dropped it.
        if label is False or label is None:
            continue

        # Positions inside matched parentheses (inclusive). Unmatched ')' are ignored.
        within_par = set()
        open_positions = []
        for idx, info in enumerate(pos):
            if info[0] == '(':
                open_positions.append(idx)
            elif info[0] == ')' and open_positions:
                within_par.update(range(open_positions.pop(), idx + 1))

        # Detect irrelevant sub-sentences: keep only the comma-delimited span that
        # covers every cue word, and mask everything outside it.
        comma_positions = [idx for idx, info in enumerate(pos) if info[0] == ',']
        outer_idx = set()
        if comma_positions:
            left, right = min(indices), max(indices)
            bounds = [-1] + comma_positions + [len(pos)]
            for i in range(len(bounds) - 1):
                if bounds[i] <= left < bounds[i + 1]:
                    left_start = i
                if bounds[i] <= right < bounds[i + 1]:
                    right_start = i
            left = bounds[left_start]
            right = bounds[right_start + 1]
            outer_idx = {i for i in range(len(pos)) if i <= left or i >= right}

        # Select the tokens to keep. An RB (other than "not") or "times" right
        # before a comparative (JJR/RBR) is generally bad and is masked.
        # The first "that" before the cue words discards everything kept so far
        # (e.g. a leading "results show that").
        first_cue = min(indices)
        include_idx = set()
        that_judged = False  # only the first qualifying "that" is considered
        for idx, info in enumerate(pos):
            token_lower = info[0].lower()
            before_comparative = idx + 1 < len(pos) and pos[idx + 1][1] in ('JJR', 'RBR')
            if idx in outer_idx:
                continue
            if before_comparative and ((info[1] == 'RB' and token_lower != 'not') or token_lower == 'times'):
                continue
            if idx in indices or info[1] in removed or token_lower in indicators or idx in within_par:
                continue
            if not that_judged and token_lower == 'that':
                if idx < first_cue:
                    that_judged = True
                    include_idx = set()
                continue
            include_idx.add(idx)

        # Keep selected tokens; collapse each run of dropped tokens into one [MASK]
        # (no leading [MASK]).
        final_evidence = []
        for idx, info in enumerate(pos):
            if idx in include_idx:
                final_evidence.append(info[0])
            elif final_evidence and final_evidence[-1] != '[MASK]':
                final_evidence.append('[MASK]')

        if not final_evidence:
            continue

        if final_evidence[-1] in ['.', '[MASK]']:
            final_evidence = final_evidence[:-1]
        if not final_evidence:
            continue  # nothing but a trailing "." was kept

        # Capitalise the first word and every word that follows a [MASK].
        for idx, word in enumerate(final_evidence):
            if idx == 0 and word != '[MASK]':
                final_evidence[idx] = word[:1].upper() + word[1:]
            elif word == '[MASK]' and idx + 1 < len(final_evidence) and final_evidence[idx + 1]:
                nxt = final_evidence[idx + 1]
                final_evidence[idx + 1] = nxt[:1].upper() + nxt[1:]

        rev_evidence, rev_label = reversed(final_evidence, label)

        output.append({'pmid': pmid,
                       'pico': ' '.join(final_evidence), 'label': label,
                       'rev_pico': ' '.join(rev_evidence), 'rev_label': rev_label})

        pmids.add(pmid)

    print('Processed chunk #%d. Got %d insts' % (chunk_id, len(output)))

with open(f'{local_path}/pretraining_dataset/evidence_10.json', 'w') as f:
    json.dump(output, f, indent=4)

with open(f'{local_path}/pretraining_dataset/evidence_pmids_10.json', 'w') as f:
    json.dump(list(pmids), f, indent=4)

Processed chunk #1. Got 4541 insts
Processed chunk #2. Got 8151 insts
Processed chunk #3. Got 11161 insts
Processed chunk #4. Got 14847 insts
Processed chunk #5. Got 18087 insts
Processed chunk #6. Got 22672 insts
Processed chunk #7. Got 27234 insts
Processed chunk #8. Got 31605 insts
Processed chunk #9. Got 32773 insts


Aggregate contexts

In [None]:
## DELETE THIS BEFORE SUBMISSION ##
import glob

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

import glob
import json

# Aggregate the per-chunk contexts into a single pmid -> context mapping.
# Only PMIDs that produced evidence (evidence_pmids_10.json) and are not in
# bad_pmids.json are kept.

ctxs = sorted(glob.glob(f'{local_path}/pretraining_dataset/evidence/contexts_*'))

with open(f'{local_path}/pretraining_dataset/evidence_pmids_10.json') as f:
    pmids = set(json.load(f))
with open(f'{local_path}/pretraining_dataset/bad_pmids.json') as f:
    bad_pmids = set(json.load(f))

pmid2ctx = {}

# Separate names for the file path and the context text: the old loop rebound
# its own loop variable `ctx` to the text.
for ctx_path in ctxs:
    if '_10.json' not in ctx_path:
        continue
    print('Processing %s' % ctx_path)
    with open(ctx_path) as f:
        data = json.load(f)

    for item in data:
        pmid = item['pmid']
        if pmid in pmids and pmid not in bad_pmids:
            pmid2ctx[pmid] = item['ctx']

with open(f'{local_path}/pretraining_dataset/pmid2ctx_10.json', 'w') as f:
    json.dump(pmid2ctx, f, indent=4)

Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0009_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0008_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0001_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0007_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0002_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0006_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0003_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0004_10.json
Processing /content/drive/MyDrive/MS_DataScience/DS595/CTP/pretraining_dataset/evidence/contexts_0005_10.json


Index pretraining dataset

In [None]:
## DELETE THIS BEFORE SUBMISSION ##
import json
import os

# Same project root as the first cell; CTP_LOCAL_PATH overrides the default.
local_path = os.environ.get("CTP_LOCAL_PATH", "/content/drive/MyDrive/MS_DataScience/DS595/CTP")

In [None]:
__author__ = 'Qiao Jin'
__editor__ = 'Mia Hopman'

import json

# Attach a `ctx_id` to every pseudo-labelled evidence record, and store each
# distinct context once, in order of first use. Evidence whose PMID has no
# aggregated context is dropped.
with open(f'{local_path}/pretraining_dataset/pmid2ctx_10.json') as f:
    pmid2ctx = json.load(f)
with open(f'{local_path}/pretraining_dataset/evidence_10.json') as f:
    evidence = json.load(f)

pmid2ctxid = {}
indexed_evidence = []
indexed_contexts = []

for entry in evidence:
    pmid = entry['pmid']
    if pmid not in pmid2ctx:
        continue

    if pmid not in pmid2ctxid:
        pmid2ctxid[pmid] = len(pmid2ctxid)
        indexed_contexts.append({'passage': pmid2ctx[pmid], 'ctx_id': pmid2ctxid[pmid]})

    entry['ctx_id'] = pmid2ctxid[pmid]
    indexed_evidence.append(entry)

with open(f'{local_path}/pretraining_dataset/indexed_evidence_10.json', 'w') as f:
    json.dump(indexed_evidence, f, indent=4)
with open(f'{local_path}/pretraining_dataset/indexed_contexts_10.json', 'w') as f:
    json.dump(indexed_contexts, f, indent=4)

Run EBM-Net

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [None]:
##ONE TIME IMPORTS##
import random as rd
import json
import os

# Same project root as the first cell; CTP_LOCAL_PATH overrides the default.
local_path = os.environ.get("CTP_LOCAL_PATH", "/content/drive/MyDrive/MS_DataScience/DS595/CTP")

In [None]:
import numpy as np
import torch

def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    When ``args.n_gpu`` is positive, the CUDA generators of all devices are
    seeded as well.
    """
    seed = args.seed
    for seed_fn in (rd.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

In [None]:
def to_list(tensor):
    """Detach ``tensor``, move it to the CPU and return it as Python lists.

    A 0-dim tensor yields a plain Python number.
    """
    cpu_tensor = tensor.detach().cpu()
    return cpu_tensor.tolist()

In [None]:
# The duplicate `to_list` that used to be defined here has been removed; the
# identical definition in the previous cell is the one in use.

In [None]:
from torch.utils.data import DataLoader, RandomSampler
from transformers import AdamW, get_cosine_schedule_with_warmup
from tqdm import tqdm, trange

def train(args, train_picos, train_ctxs, model, tokenizer):
    """Fine-tune ``model`` and return ``(global_step, average_training_loss)``.

    Args:
        args: run configuration. Fields read include the batch sizes,
            ``n_gpu``, ``device``, ``max_steps``, ``num_train_epochs``,
            ``gradient_accumulation_steps``, the optimiser/scheduler settings,
            ``logging_steps``, ``save_steps`` and ``output_dir``. The function
            overwrites ``args.train_batch_size`` and, when ``max_steps > 0``,
            ``args.num_train_epochs``.
        train_picos: dataset whose batches unpack as
            ``(ctx_id, pico_token_ids, pico_token_mask, pico_segment_ids, label)``.
        train_ctxs: indexable by ``ctx_id``. Items 1-3 of each entry are the
            passage token-id / mask / segment-id tensors.
        model: called with a dict of inputs; its return value is used directly
            as the loss.
        tokenizer: saved next to every checkpoint.

    NOTE(review): uses a module-level ``logger`` plus ``set_seed`` and ``to_list``
    from other cells. ``logger`` is not defined in the reviewed part of the
    notebook, so confirm it exists before calling.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_picos)
    train_dataloader = DataLoader(train_picos, sampler=train_sampler, batch_size=args.train_batch_size)

    # max(1, ...) avoids a ZeroDivisionError when there are fewer batches than
    # accumulation steps.
    updates_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // updates_per_epoch + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup, cosine decay).
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    # torch.optim.AdamW replaces the deprecated transformers.AdamW.
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # multi-gpu training
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_picos))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=False)
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False)
        for step, batch in enumerate(epoch_iterator):
            model.train()

            batch = tuple(t.to(args.device) for t in batch)

            ctx_ids = to_list(batch[0])
            pico_token_ids = batch[1]    # B x max_pico_length
            pico_token_mask = batch[2]   # B x max_pico_length
            pico_segment_ids = batch[3]  # B x max_pico_length
            labels = batch[4]

            # Look up each example's context record and transpose to per-field lists.
            ctx_batch = [train_ctxs[ctx_id] for ctx_id in ctx_ids]
            ctx_batch = list(map(list, zip(*ctx_batch)))

            ctx_token_ids = torch.stack(ctx_batch[1]).to(args.device)    # B x max_ctx_length
            ctx_token_mask = torch.stack(ctx_batch[2]).to(args.device)   # B x max_ctx_length
            ctx_segment_ids = torch.stack(ctx_batch[3]).to(args.device)  # B x max_ctx_length

            # The model sees the passage followed by the PICO prompt as one sequence.
            inputs = {
                "passage_ids": torch.cat([ctx_token_ids, pico_token_ids], dim=1),
                "passage_mask": torch.cat([ctx_token_mask, pico_token_mask], dim=1),
                "passage_segment_ids": torch.cat([ctx_segment_ids, pico_segment_ids], dim=1),
                "result_labels": labels
            }

            loss = model(inputs)  # the model's output is used directly as the loss

            if args.n_gpu > 1:
                loss = loss.mean()  # average the per-replica losses from DataParallel
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info("  step %d: lr %.3e, loss %.4f", global_step,
                                scheduler.get_last_lr()[0],
                                (tr_loss - logging_loss) / args.logging_steps)
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Unwrap DataParallel before saving.
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop once max_steps optimisation steps are done. `>` ran one step
            # past the length the scheduler was built for.
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    # max(..., 1) avoids ZeroDivisionError when no optimisation step was taken.
    return global_step, tr_loss / max(global_step, 1)

In [None]:
from torch.utils.data import SequentialSampler
from sklearn.metrics import accuracy_score, f1_score

def evaluate(args, eval_picos, eval_ctxs, model, tokenizer, prefix=""):
	"""Predict on an evaluation set, dump the predictions, and score them.

	Every PICO row is paired with its context passage (`eval_ctxs[ctx_id]`).
	The context tokens are concatenated in front of the PICO tokens, and the
	argmax of the model's logits is taken as the prediction.

	Files written to `args.output_dir` (prefix defaults to "final"):
	  <prefix>_all_example_idx.json, <prefix>_all_labels.json,
	  <prefix>_all_preds.json, <prefix>_all_logits.npy

	Args:
		args: reads output_dir, per_gpu_eval_batch_size, n_gpu and device;
			sets args.eval_batch_size.
		eval_picos: TensorDataset whose columns are assumed to follow
			load_and_cache_picos: (ctx_ids, input_ids, input_mask,
			segment_ids, labels, example_ids).
		eval_ctxs: TensorDataset whose columns are assumed to follow
			load_and_cache_ctxs: (ctx_ids, input_ids, input_mask,
			segment_ids). It is indexed directly by ctx id.
		model: called as model(inputs); must return (B, num_classes) logits.
		tokenizer: unused; kept for interface compatibility.
		prefix: filename prefix for the dumped files and the log header.

	Returns:
		dict with 'f1' (macro-averaged F1) and 'acc' (accuracy).
	"""
	if not os.path.exists(args.output_dir):
		os.makedirs(args.output_dir)

	args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
	# Sequential order keeps the dumped arrays aligned with the dataset order.
	eval_sampler = SequentialSampler(eval_picos)
	eval_dataloader = DataLoader(eval_picos, sampler=eval_sampler, batch_size=args.eval_batch_size)

	# Eval!
	logger.info("***** Running evaluation {} *****".format(prefix))
	logger.info("  Num examples = %d", len(eval_picos))
	logger.info("  Batch size = %d", args.eval_batch_size)

	example_ids = []
	all_labels = []
	all_preds = []
	# Collect per-batch logits and concatenate once at the end. The width is
	# taken from the model output, so 34-way models work as well as 3-way ones.
	logit_chunks = []

	model.eval()
	for batch in tqdm(eval_dataloader, desc="Evaluating"):
		batch = tuple(t.to(args.device) for t in batch)
		with torch.no_grad():
			ctx_ids = to_list(batch[0])
			pico_token_ids = batch[1]  # B x max_pico_length
			pico_token_mask = batch[2]  # B x max_pico_length
			pico_segment_ids = batch[3]  # B x max_pico_length
			labels = batch[4]
			# Column 5 is the example index; column 4 holds the gold labels.
			batch_example_ids = batch[5]

			# Look up each row's context features, then regroup them column-wise.
			ctx_batch = [eval_ctxs[ctx_id] for ctx_id in ctx_ids]
			ctx_columns = list(zip(*ctx_batch))

			ctx_token_ids = torch.stack(ctx_columns[1]).to(args.device)  # B x max_ctx_length
			ctx_token_mask = torch.stack(ctx_columns[2]).to(args.device)  # B x max_ctx_length
			ctx_segment_ids = torch.stack(ctx_columns[3]).to(args.device)  # B x max_ctx_length

			inputs = {
				"passage_ids": torch.cat([ctx_token_ids, pico_token_ids], dim=1),
				"passage_mask": torch.cat([ctx_token_mask, pico_token_mask], dim=1),
				"passage_segment_ids": torch.cat([ctx_segment_ids, pico_segment_ids], dim=1)
			}

			logits = model(inputs)  # B x num_classes
			preds = torch.argmax(logits, dim=1)  # B

			example_ids += batch_example_ids.detach().cpu().tolist()
			all_labels += labels.detach().cpu().tolist()
			all_preds += preds.detach().cpu().tolist()
			logit_chunks.append(logits.detach().cpu().numpy())

	# With no batches there is nothing to infer a width from; keep the old (0, 3).
	all_logits = np.concatenate(logit_chunks, axis=0) if logit_chunks else np.zeros((0, 3))

	if not prefix:
		prefix = 'final'

	with open(os.path.join(args.output_dir, '%s_all_example_idx.json' % prefix), 'w') as f:
		json.dump([int(example_id) for example_id in example_ids], f)
	with open(os.path.join(args.output_dir, '%s_all_labels.json' % prefix), 'w') as f:
		json.dump([int(label) for label in all_labels], f)
	with open(os.path.join(args.output_dir, '%s_all_preds.json' % prefix), 'w') as f:
		json.dump([int(pred) for pred in all_preds], f)
	np.save(os.path.join(args.output_dir, '%s_all_logits.npy' % prefix), np.array(all_logits))

	results = {}
	results['f1'] = f1_score(all_labels, all_preds, average='macro')
	results['acc'] = accuracy_score(all_labels, all_preds)

	return results

In [None]:
## TODO: review the helper functions below and delete any that are no longer called. ##

In [None]:
def represent(args, model, tokenizer):
	"""Run the model in representation mode and save one vector per example.

	The dataset comes from `load_and_cache_examples(args, tokenizer,
	evaluate=True, do_repr=True)`. Batch columns 0-2 are fed as passage ids,
	mask and segment ids to `model(inputs, get_reprs=True)`. For EBM_Net this
	returns the [CLS] embeddings.

	Files written to `args.output_dir`:
	  all_example_idx.json - column 4 of every batch, cast to int
	  all_reprs.npy        - stacked representations, one row per example

	Side effect: sets `args.eval_batch_size`.
	"""
	repr_dataset = load_and_cache_examples(args, tokenizer, evaluate=True, do_repr=True)

	if not os.path.exists(args.output_dir):
		os.makedirs(args.output_dir)

	args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
	# Sequential order keeps the saved rows aligned with the dataset order.
	loader = DataLoader(
		repr_dataset,
		sampler=SequentialSampler(repr_dataset),
		batch_size=args.eval_batch_size,
	)

	logger.info("***** Running Representations *****")
	logger.info("  Num examples = %d", len(repr_dataset))
	logger.info("  Batch size = %d", args.eval_batch_size)

	collected_ids = []
	reprs_so_far = np.zeros((0, model.bert.config.hidden_size))

	for raw_batch in tqdm(loader, desc="Evaluating"):
		model.eval()
		on_device = tuple(t.to(args.device) for t in raw_batch)
		with torch.no_grad():
			batch_reprs = model(
				{
					"passage_ids": on_device[0],
					"passage_mask": on_device[1],
					"passage_segment_ids": on_device[2],
				},
				get_reprs=True,
			)  # N x D

			collected_ids.extend(on_device[4].detach().cpu().numpy())
			reprs_so_far = np.concatenate([reprs_so_far, batch_reprs.detach().cpu().numpy()], axis=0)

	with open(os.path.join(args.output_dir, 'all_example_idx.json'), 'w') as f:
		json.dump([int(_id) for _id in collected_ids], f)
	np.save(os.path.join(args.output_dir, 'all_reprs.npy'), np.array(reprs_so_far))

In [None]:
class CtxFeatures(object): # same with pre-training utils
	"""Feature container for one context passage.

	Stores the passage id, its token strings, and the parallel token-id,
	attention-mask and segment-id lists exactly as given.
	"""

	def __init__(self, ctx_id, tokens, input_ids, input_mask, segment_ids):
		(
			self.ctx_id,
			self.tokens,
			self.input_ids,
			self.input_mask,
			self.segment_ids,
		) = (ctx_id, tokens, input_ids, input_mask, segment_ids)

In [None]:
def convert_ctxs_to_features(
	examples,
	tokenizer,
	max_passage_length,
	permutation=None,
	cls_token="[CLS]",
	sep_token="[SEP]",
	pad_token=0,
	sequence_a_segment_id=0,
	sequence_b_segment_id=1,
	cls_token_segment_id=0,
	pad_token_segment_id=0
):
	"""Turn context examples into fixed-length `CtxFeatures`.

	Each passage is tokenized and laid out as `[CLS] passage [SEP]`, with the
	passage cut to `max_passage_length - 2` tokens. It is then right-padded
	with `pad_token` up to `max_passage_length`.

	Real tokens (including [CLS]) get `sequence_a_segment_id` and mask 1.
	Padding gets `pad_token_segment_id` and mask 0. Examples with
	`ctx_id < 20` are logged for inspection.

	`permutation`, `sequence_b_segment_id` and `cls_token_segment_id` are
	accepted for signature compatibility but are not used.

	Returns:
		list of `CtxFeatures`, in input order.
	"""
	features = []
	for example in tqdm(examples):
		ctx_id = example.ctx_id
		passage_tokens = tokenizer.tokenize(example.passage_text)

		tokens = [cls_token, *passage_tokens[:max_passage_length - 2], sep_token]
		token_ids = tokenizer.convert_tokens_to_ids(tokens)

		# Right-pad ids, mask and segments by the same amount (none if already full).
		pad_len = max_passage_length - len(token_ids)
		input_ids = token_ids + [pad_token] * pad_len
		input_mask = [1] * len(token_ids) + [0] * pad_len
		segment_ids = [sequence_a_segment_id] * len(tokens) + [pad_token_segment_id] * pad_len

		assert len(input_ids) == max_passage_length
		assert len(input_mask) == max_passage_length
		assert len(segment_ids) == max_passage_length

		if ctx_id < 20:
			logger.info("*** Example ***")
			logger.info("ctx_id: %s" % (ctx_id))
			logger.info("tokens: %s" % " ".join(tokens))
			logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
			logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
			logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

		features.append(
			CtxFeatures(
				ctx_id=ctx_id,
				tokens=tokens,
				input_ids=input_ids,
				input_mask=input_mask,
				segment_ids=segment_ids
			)
		)

	return features

In [None]:
class CtxExample(object): # same with pre-training utils
	"""One raw context passage, identified by `ctx_id`."""

	def __init__(self, ctx_id, passage_text):
		self.ctx_id, self.passage_text = ctx_id, passage_text

	def __str__(self):
		return self.__repr__()

	def __repr__(self):
		parts = [
			'ctx_id: %s\n' % self.ctx_id,
			"passage: %s\n" % self.passage_text,
		]
		return "".join(parts)

In [None]:
def read_ctx_examples(input_file, adversarial=False): # same with pre-training utils
	"""Read a JSON list of context entries into `CtxExample` objects.

	Each entry must have 'ctx_id' and 'passage' keys. `adversarial` is
	accepted for signature parity with the PICO readers and is ignored.

	Returns:
		list of `CtxExample`, in file order.
	"""
	with open(input_file, "r", encoding="utf-8") as reader:
		input_data = json.load(reader)

	return [
		CtxExample(ctx_id=entry['ctx_id'], passage_text=entry['passage'])
		for entry in input_data
	]

In [None]:
from torch.utils.data import TensorDataset

def load_and_cache_ctxs(args, tokenizer, evaluate=False, do_repr=False, pretraining=False):
	"""Build (or load from cache) the context-passage TensorDataset.

	Which source file is read:
	  - `args.repr_ctx` when `do_repr`
	  - else `args.predict_ctx` when `evaluate`
	  - else `args.train_ctx`

	Features are cached next to that file. The cache name is keyed on
	`args.adversarial`, the split tag ("repr", "dev" or "train") and
	`args.max_passage_length`. Set `args.overwrite_cache` to rebuild.

	`pretraining` is accepted for signature parity and is not used.

	Returns:
		TensorDataset of (ctx_ids, input_ids, input_mask, segment_ids),
		all long tensors.
	"""
	import inspect

	# Load data features from cache or dataset file.
	# The repr split gets its own tag so its cache cannot collide with a dev
	# cache living in the same directory.
	if do_repr:
		input_file = args.repr_ctx
		split = "repr"
	elif evaluate:
		input_file = args.predict_ctx
		split = "dev"
	else:
		input_file = args.train_ctx
		split = "train"

	cached_features_file = os.path.join(
		os.path.dirname(input_file),
		"cached_ctxs_adv{}_{}_{}".format(
			args.adversarial,
			split,
			str(args.max_passage_length)
		),
	)

	if os.path.exists(cached_features_file) and not args.overwrite_cache:
		logger.info("Loading features from cached file %s", cached_features_file)
		# The cache holds pickled CtxFeatures objects written by this function.
		# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
		# to unpickle them, so opt out where the parameter exists.
		load_kwargs = {"weights_only": False} if "weights_only" in inspect.signature(torch.load).parameters else {}
		features = torch.load(cached_features_file, **load_kwargs)
	else:
		logger.info("Creating features from dataset file at %s", input_file)

		examples = read_ctx_examples(input_file=input_file, adversarial=args.adversarial)

		features = convert_ctxs_to_features(
			examples=examples,
			tokenizer=tokenizer,
			max_passage_length=args.max_passage_length
		)

		logger.info("Saving features into cached file %s", cached_features_file)
		torch.save(features, cached_features_file)

	# Convert to Tensors and build dataset
	all_ctx_ids = torch.tensor([f.ctx_id for f in features], dtype=torch.long)
	all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
	all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
	all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)

	return TensorDataset(
		all_ctx_ids,
		all_input_ids,
		all_input_mask,
		all_segment_ids
	)

In [None]:
class PicoFeatures(object): # unchanged from pre-training utils
	"""Feature container for one PICO example.

	Stores the example's position, the id of its context passage, its token
	strings, the parallel token-id, attention-mask and segment-id lists, and
	its label, exactly as given.
	"""

	def __init__(self, example_index, ctx_id, tokens, input_ids, input_mask, segment_ids, label):
		(
			self.example_index,
			self.ctx_id,
			self.tokens,
			self.input_ids,
			self.input_mask,
			self.segment_ids,
			self.label,
		) = (example_index, ctx_id, tokens, input_ids, input_mask, segment_ids, label)

In [None]:
def convert_picos_to_features_pretraining(
	examples,
	tokenizer,
	max_pico_length,
	permutation=None,
	cls_token="[CLS]",
	sep_token="[SEP]",
	pad_token=0,
	sequence_a_segment_id=0,
	sequence_b_segment_id=1,
	cls_token_segment_id=0,
	pad_token_segment_id=0
):
	"""Tokenize pre-training PICO texts into fixed-length `PicoFeatures`.

	Each text becomes `pico_tokens [SEP]`, truncated so the whole sequence
	fits in `max_pico_length`, and is right-padded with `pad_token`. No [CLS]
	is added here.

	Real tokens get `sequence_b_segment_id` and mask 1. Padding gets
	`pad_token_segment_id` and mask 0. `example_index` is the example's
	position in `examples`; the first 20 examples are logged.

	`permutation`, `cls_token`, `sequence_a_segment_id` and
	`cls_token_segment_id` are accepted for signature compatibility but are
	not used.

	Returns:
		list of `PicoFeatures`, in input order.
	"""
	features = []

	for example_index, example in enumerate(examples):
		ctx_id = example.ctx_id
		pico_tokens = tokenizer.tokenize(example.pico_text)
		label = example.label

		tokens = [*pico_tokens[:max_pico_length - 1], sep_token]
		token_ids = tokenizer.convert_tokens_to_ids(tokens)

		# Right-pad ids, mask and segments by the same amount (none if already full).
		pad_len = max_pico_length - len(token_ids)
		input_ids = token_ids + [pad_token] * pad_len
		input_mask = [1] * len(token_ids) + [0] * pad_len
		segment_ids = [sequence_b_segment_id] * len(tokens) + [pad_token_segment_id] * pad_len

		assert len(input_ids) == max_pico_length
		assert len(input_mask) == max_pico_length
		assert len(segment_ids) == max_pico_length

		if example_index < 20:
			logger.info("*** Example ***")
			logger.info("ctx_id: %s" % (ctx_id))
			logger.info("example_index: %s" % (example_index))
			logger.info("tokens: %s" % " ".join(tokens))
			logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
			logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
			logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
			logger.info("label: %s" % label)

		features.append(
			PicoFeatures(
				ctx_id=ctx_id,
				example_index=example_index,
				tokens=tokens,
				input_ids=input_ids,
				input_mask=input_mask,
				segment_ids=segment_ids,
				label=label
			)
		)

	return features

In [None]:
def convert_picos_to_features_ebmnet(
	examples,
	tokenizer,
	max_pico_length,
	permutation=None,
	cls_token="[CLS]",
	sep_token="[SEP]",
	pad_token=0,
	sequence_a_segment_id=0,
	sequence_b_segment_id=1,
	cls_token_segment_id=0,
	pad_token_segment_id=0
):
	"""Tokenize intervention/comparator/outcome triples into `PicoFeatures`.

	`permutation` is a string over the letters 'i', 'c', 'o' that gives the
	order in which the intervention, comparator and outcome texts are
	concatenated, e.g. "ioc". Several orders can be requested by joining them
	with '-' (e.g. "ioc-cio"). Every example is then emitted once per order,
	and `example_index` restarts at 0 for each order. `None` falls back to
	"ioc".

	Layout of each sequence:
	  - Elements are separated by a literal '[MASK]' token, and the final
	    separator is replaced by `sep_token`.
	  - If too long, the sequence is cut to `max_pico_length` while keeping a
	    trailing `sep_token`.
	  - It is right-padded with `pad_token`.
	  - Real tokens get `sequence_b_segment_id`; padding gets
	    `pad_token_segment_id` and mask 0.

	`cls_token`, `sequence_a_segment_id` and `cls_token_segment_id` are
	accepted for signature compatibility but are not used.

	Returns:
		list of `PicoFeatures`, grouped by order, then in input order.

	Raises:
		ValueError: if an order is empty or uses letters other than i/c/o.
	"""
	if permutation is None:
		permutation = "ioc"

	# str.split returns [permutation] when there is no '-', so one call covers both cases.
	perm_list = permutation.split('-')
	for perm in perm_list:
		if not perm or not set(perm).issubset({'i', 'o', 'c'}):
			raise ValueError(
				"invalid permutation %r: each '-'-separated order must be a "
				"non-empty string over 'i', 'c', 'o'" % (permutation,)
			)

	# Tokenize every example once; the tokens do not depend on the order.
	examples = list(examples)
	tokenized = [
		{
			'i': tokenizer.tokenize(example.i_text),
			'c': tokenizer.tokenize(example.c_text),
			'o': tokenizer.tokenize(example.o_text),
		}
		for example in examples
	]

	features = []

	for perm in perm_list:
		for example_index, (example, ico_tokens) in enumerate(zip(examples, tokenized)):
			ctx_id = example.ctx_id
			label = example.label

			# Build a fresh list so the cached token lists are never mutated.
			tokens = []
			for element in perm:
				tokens += ico_tokens[element] + ['[MASK]']
			tokens[-1] = sep_token
			segment_ids = [sequence_b_segment_id] * len(tokens)

			if len(tokens) > max_pico_length:
				tokens = tokens[:max_pico_length - 1] + [sep_token]
				segment_ids = segment_ids[:max_pico_length - 1] + [sequence_b_segment_id]

			input_ids = tokenizer.convert_tokens_to_ids(tokens)
			input_mask = [1] * len(input_ids)

			# Zero-pad up to the sequence length.
			pad_len = max_pico_length - len(input_ids)
			input_ids = input_ids + [pad_token] * pad_len
			input_mask = input_mask + [0] * pad_len
			segment_ids = segment_ids + [pad_token_segment_id] * pad_len

			assert len(input_ids) == max_pico_length
			assert len(input_mask) == max_pico_length
			assert len(segment_ids) == max_pico_length

			if example_index < 20:
				logger.info("*** Example ***")
				logger.info("example_index: %s" % (example_index))
				logger.info("tokens: %s" % " ".join(tokens))
				logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
				logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
				logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

			features.append(
				PicoFeatures(
					ctx_id=ctx_id,
					example_index=example_index,
					tokens=tokens,
					input_ids=input_ids,
					input_mask=input_mask,
					segment_ids=segment_ids,
					label=label
				)
			)

	return features

In [None]:
class PicoExamplePretraining(object):
	"""One pre-training PICO example: a context id, a PICO text, and a label."""

	def __init__(self, ctx_id, pico_text, label):
		self.ctx_id, self.pico_text, self.label = ctx_id, pico_text, label

	def __str__(self):
		return self.__repr__()

	def __repr__(self):
		parts = [
			"ctx_id: %s\n" % self.ctx_id,
			"pico_text: %s\n" % self.pico_text,
			"label: %s\n" % self.label,
		]
		return "".join(parts)

In [None]:
def read_pico_examples_pretraining(input_file, adversarial=False):
	"""Load pre-training PICO examples from a JSON list of entries.

	Each entry must provide 'ctx_id', 'pico' and 'label'. With
	`adversarial=True`, each entry also contributes a second example built
	from its 'rev_pico' / 'rev_label' fields, placed right after the first.

	Returns:
		list of `PicoExamplePretraining`.
	"""
	with open(input_file, "r", encoding="utf-8") as reader:
		input_data = json.load(reader)

	# (text key, label key) pairs to emit for every entry, in order.
	field_pairs = (('pico', 'label'), ('rev_pico', 'rev_label')) if adversarial else (('pico', 'label'),)

	return [
		PicoExamplePretraining(
			ctx_id=entry['ctx_id'],
			pico_text=entry[text_key],
			label=entry[label_key]
		)
		for entry in input_data
		for text_key, label_key in field_pairs
	]

In [None]:
class PicoExampleEBMNet(object):
	"""One EBM-Net example: intervention, comparator and outcome texts plus a label,
	tied to a context passage by `ctx_id`."""

	def __init__(self, ctx_id, i_text, c_text, o_text, label):
		(self.ctx_id, self.i_text, self.c_text, self.o_text, self.label) = (
			ctx_id, i_text, c_text, o_text, label
		)

	def __str__(self):
		return self.__repr__()

	def __repr__(self):
		parts = [
			"ctx_id: %s\n" % self.ctx_id,
			"i_text: %s\n" % self.i_text,
			"c_text: %s\n" % self.c_text,
			"o_text: %s\n" % self.o_text,
			"label: %s\n" % self.label,
		]
		return "".join(parts)


In [None]:
def read_pico_examples_ebmnet(input_file, adversarial=False):
	"""Load EBM-Net I/C/O examples from a JSON list of entries.

	Each entry needs 'ctx_id', 'i_text', 'c_text', 'o_text' and 'label'. With
	`adversarial=True`, every entry is followed by a mirrored copy: the
	intervention and comparator texts are swapped, and the label becomes
	`2 - label`.

	Returns:
		list of `PicoExampleEBMNet`.
	"""
	with open(input_file, "r", encoding="utf-8") as reader:
		input_data = json.load(reader)

	def build(entry, swapped):
		i_key, c_key = ('c_text', 'i_text') if swapped else ('i_text', 'c_text')
		return PicoExampleEBMNet(
			ctx_id=entry['ctx_id'],
			i_text=entry[i_key],
			c_text=entry[c_key],
			o_text=entry['o_text'],
			label=(2 - entry['label']) if swapped else entry['label']
		)

	variants = (False, True) if adversarial else (False,)
	return [build(entry, swapped) for entry in input_data for swapped in variants]

In [None]:
def load_and_cache_picos(args, tokenizer, evaluate=False, do_repr=False, pretraining=False):
	"""Build (or load from cache) the PICO TensorDataset.

	Which source file is read:
	  - `args.repr_pico` when `do_repr`
	  - else `args.predict_pico` when `evaluate`
	  - else `args.train_pico`

	Input format: examples are parsed and featurized in pre-training format
	when `args.pretraining` is set, otherwise in EBM-Net format when
	`args.ebmnet` is set.

	Caching: features are cached next to the source file. The cache name is
	keyed on `args.adversarial`, `args.permutation`, the split tag ("repr",
	"dev" or "train") and `args.max_pico_length`. Set `args.overwrite_cache`
	to rebuild.

	Labels: when `args.num_labels == 3` in pre-training mode, the 34
	fine-grained labels are collapsed to 3 classes
	(0-14 -> 0, 15-16 -> 1, 17-33 -> 2).

	The `pretraining` argument is accepted for signature parity; the mode is
	always taken from `args.pretraining` / `args.ebmnet`.

	Returns:
		TensorDataset of (ctx_ids, input_ids, input_mask, segment_ids, labels,
		example_ids), all long tensors.

	Raises:
		ValueError: if features must be built but neither `args.pretraining`
			nor `args.ebmnet` is set.
	"""
	import inspect

	# Load data features from cache or dataset file.
	# The repr split gets its own tag so its cache cannot collide with a dev
	# cache living in the same directory.
	if do_repr:
		input_file = args.repr_pico
		split = "repr"
	elif evaluate:
		input_file = args.predict_pico
		split = "dev"
	else:
		input_file = args.train_pico
		split = "train"

	cached_features_file = os.path.join(
		os.path.dirname(input_file),
		"cached_picos_adv{}_{}_{}_{}".format(
			args.adversarial,
			args.permutation,
			split,
			str(args.max_pico_length)
		),
	)

	if os.path.exists(cached_features_file) and not args.overwrite_cache:
		logger.info("Loading features from cached file %s", cached_features_file)
		# The cache holds pickled PicoFeatures objects written by this function.
		# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
		# to unpickle them, so opt out where the parameter exists.
		load_kwargs = {"weights_only": False} if "weights_only" in inspect.signature(torch.load).parameters else {}
		features = torch.load(cached_features_file, **load_kwargs)
	else:
		logger.info("Creating features from dataset file at %s", input_file)

		if args.pretraining:
			read_examples = read_pico_examples_pretraining
			to_features = convert_picos_to_features_pretraining
		elif args.ebmnet:
			read_examples = read_pico_examples_ebmnet
			to_features = convert_picos_to_features_ebmnet
		else:
			# Without a mode there is no parser for the input file.
			raise ValueError(
				"load_and_cache_picos: set args.pretraining or args.ebmnet "
				"to choose the input format"
			)

		examples = read_examples(input_file=input_file, adversarial=args.adversarial)
		features = to_features(
			examples=examples,
			tokenizer=tokenizer,
			max_pico_length=args.max_pico_length,
			permutation=args.permutation
		)

		logger.info("Saving features into cached file %s", cached_features_file)
		torch.save(features, cached_features_file)

	# Convert to Tensors and build dataset
	all_ctx_ids = torch.tensor([f.ctx_id for f in features], dtype=torch.long)
	all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
	all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
	all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)

	# Pre-training data carries 34 fine-grained labels; collapse them to the
	# 3-way scheme when a 3-label model is trained on it. A label outside
	# 0-33 raises KeyError instead of being silently bucketed.
	mlm2cls = {i: 0 if i < 15 else (1 if i < 17 else 2) for i in range(34)}

	if args.num_labels == 3 and args.pretraining:
		all_labels = torch.tensor([mlm2cls[f.label] for f in features], dtype=torch.long)
	else:
		all_labels = torch.tensor([f.label for f in features], dtype=torch.long)

	all_example_ids = torch.tensor([f.example_index for f in features], dtype=torch.long)

	return TensorDataset(
		all_ctx_ids,
		all_input_ids,
		all_input_mask,
		all_segment_ids,
		all_labels,
		all_example_ids
	)


In [None]:
from torch import nn
from transformers import BertModel, BertConfig

class EBM_Net(nn.Module):
    """BERT [CLS] classifier with a 34-way result head and an optional 3-way head.

    `res_linear` maps the [CLS] embedding to 34 result classes. When
    `args.num_labels == 3`, `final_linear` maps the softmax over those 34
    classes to 3 logits.

    Args:
        args: run configuration; reads `model_name_or_path` and `num_labels`.
        path: optional checkpoint directory containing `pytorch_model.bin`.
            Every tensor whose name and shape match this model is loaded
            from it. If the checkpoint cannot supply `final_linear` (e.g. a
            34-way pre-training checkpoint) and the 3-way head exists, that
            head is initialised to pool classes 0-14, 15-16 and 17-33
            instead.
    """

    def __init__(self, args, path=None):
        super(EBM_Net, self).__init__()
        self.args = args
        self.config = BertConfig.from_pretrained(args.model_name_or_path)
        self.bert = BertModel.from_pretrained(args.model_name_or_path)

        num_cls = 34
        self.res_linear = nn.Linear(self.config.hidden_size, num_cls)
        if args.num_labels == 3:
            self.final_linear = nn.Linear(num_cls, args.num_labels)

        self.relu = nn.ReLU()
        self.m = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()
        self.softmax = nn.Softmax(dim=1)

        # `path` is optional: without it the freshly initialised weights are kept.
        if path is not None:
            self._load_checkpoint(path)

    def _load_checkpoint(self, path):
        """Copy name- and shape-matching tensors from `path`/pytorch_model.bin."""
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the values onto each parameter's device.
        pretrained_dict = torch.load(os.path.join(path, 'pytorch_model.bin'), map_location='cpu')
        model_dict = self.state_dict()
        to_load = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.size() == model_dict[k].size()}

        # Seed the 3-way head only when it exists and the checkpoint cannot
        # supply it. A trained head is therefore never overwritten just
        # because some unrelated key is missing from the checkpoint.
        if 'final_linear.weight' in model_dict and 'final_linear.weight' not in to_load:
            # Row 0 ("down") favours classes 0-14 over 17-33, row 1 ("mid")
            # averages classes 15-16, row 2 ("up") favours 17-33 over 0-14.
            down_weights = [1 / 15] * 15 + [0] * 2 + [-1 / 17] * 17
            mid_weights = [0] * 15 + [1 / 2] * 2 + [0] * 17
            up_weights = [-1 / 15] * 15 + [0] * 2 + [1 / 17] * 17
            weight_like = model_dict['final_linear.weight']
            to_load['final_linear.weight'] = torch.tensor(
                [down_weights, mid_weights, up_weights], dtype=weight_like.dtype)
            to_load['final_linear.bias'] = torch.zeros_like(model_dict['final_linear.bias'])

        # strict=False: keys the checkpoint cannot supply keep their current values.
        self.load_state_dict(to_load, strict=False)

    def forward(self, inputs, get_reprs=False):
        """Encode `inputs` with BERT and score the first-position ([CLS]) embedding.

        Args:
            inputs: dict with 'passage_ids', 'passage_mask',
                'passage_segment_ids' and, optionally, 'result_labels'.
            get_reprs: if True, return the (B, hidden) [CLS] embeddings.

        Returns:
            If 'result_labels' is present: NLLLoss of LogSoftmax(logits)
            against those labels.
            Otherwise: the logits, shaped (B, 3) when num_labels == 3 and
            (B, 34) otherwise.
        """
        cls_embeds = self.bert(input_ids=inputs['passage_ids'],
                               attention_mask=inputs['passage_mask'],
                               token_type_ids=inputs['passage_segment_ids'])[0][:, 0, :]  # B x D

        if get_reprs:
            return cls_embeds

        if self.args.num_labels == 3:
            logits = self.final_linear(self.softmax(self.res_linear(cls_embeds)))  # B x 3
        else:
            logits = self.res_linear(cls_embeds)  # B x 34

        if 'result_labels' in inputs:
            return self.loss(self.m(logits), inputs['result_labels'])
        else:
            return logits

    def save_pretrained(self, path):
        """Save the BERT encoder, the config and the full state dict into `path`.

        The full state dict is written last as `pytorch_model.bin`.
        Depending on the transformers version, `self.bert.save_pretrained`
        may write its own encoder-only weights under that same file name;
        writing last ensures those never replace the complete model.
        """
        self.bert.save_pretrained(path)
        # then save the config (vocab saved outside)
        self.config.save_pretrained(path)
        torch.save(self.state_dict(), os.path.join(path, 'pytorch_model.bin'))

In [None]:
import os
import torch
from torch import nn
from transformers import BertModel, BertConfig

class EBM_Net(nn.Module):
    """BERT [CLS] classifier with a 34-way result head and an optional 3-way head.

    NOTE(review): this cell redefines EBM_Net. Whichever definition ran last
    is the one in use.

    `res_linear` maps the [CLS] embedding to 34 result classes. When
    `args.num_labels == 3`, `final_linear` maps the softmax over those 34
    classes to 3 logits.

    Args:
        args: run configuration; reads `model_name_or_path` and `num_labels`.
        path: optional checkpoint directory containing `pytorch_model.bin`.
            Every tensor whose name and shape match this model is loaded
            from it. If the checkpoint cannot supply `final_linear`, that
            layer is Xavier-initialised with a zero bias.
    """

    def __init__(self, args, path=None):
        super(EBM_Net, self).__init__()
        self.args = args
        self.config = BertConfig.from_pretrained(args.model_name_or_path)
        self.bert = BertModel.from_pretrained(args.model_name_or_path)

        num_cls = 34
        self.res_linear = nn.Linear(self.config.hidden_size, num_cls)
        if args.num_labels == 3:
            self.final_linear = nn.Linear(num_cls, args.num_labels)
        else:
            # NOTE(review): forward() does not use this layer when
            # num_labels != 3 (it returns res_linear logits). It is kept so
            # the state-dict keys stay the same as before.
            self.final_linear = nn.Linear(self.config.hidden_size, args.num_labels)

        self.relu = nn.ReLU()
        self.m = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()
        self.softmax = nn.Softmax(dim=1)

        if path is not None:
            self._load_checkpoint(path)

    def _load_checkpoint(self, path):
        """Copy name- and shape-matching tensors from `path`/pytorch_model.bin."""
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the values onto each parameter's device.
        pretrained_dict = torch.load(os.path.join(path, 'pytorch_model.bin'), map_location='cpu')
        model_dict = self.state_dict()
        to_load = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and v.size() == model_dict[k].size()}

        if 'final_linear.weight' not in to_load:
            # The checkpoint has no compatible head: start it from Xavier/zeros.
            nn.init.xavier_uniform_(self.final_linear.weight)
            nn.init.zeros_(self.final_linear.bias)

        # Always load what matched; keys the checkpoint cannot supply keep
        # their current values (strict=False). The matched tensors are now
        # loaded even when the checkpoint covers every key.
        self.load_state_dict(to_load, strict=False)

    def forward(self, inputs, get_reprs=False):
        """Encode `inputs` with BERT and score the first-position ([CLS]) embedding.

        Args:
            inputs: dict with 'passage_ids', 'passage_mask',
                'passage_segment_ids' and, optionally, 'result_labels'.
            get_reprs: if True, return the (B, hidden) [CLS] embeddings.

        Returns:
            If 'result_labels' is present: NLLLoss of LogSoftmax(logits)
            against those labels.
            Otherwise: the logits, shaped (B, 3) when num_labels == 3 and
            (B, 34) otherwise.
        """
        cls_embeds = self.bert(input_ids=inputs['passage_ids'],
                               attention_mask=inputs['passage_mask'],
                               token_type_ids=inputs['passage_segment_ids'])[0][:, 0, :]  # B x D

        if get_reprs:
            return cls_embeds

        if self.args.num_labels == 3:
            logits = self.final_linear(self.softmax(self.res_linear(cls_embeds)))  # B x 3
        else:
            logits = self.res_linear(cls_embeds)  # B x 34

        if 'result_labels' in inputs:
            return self.loss(self.m(logits), inputs['result_labels'])
        else:
            return logits

    def save_pretrained(self, path):
        """Save the BERT encoder, the config and the full state dict into `path`.

        The full state dict is written last as `pytorch_model.bin`.
        Depending on the transformers version, `self.bert.save_pretrained`
        may write its own encoder-only weights under that same file name;
        writing last ensures those never replace the complete model.
        """
        self.bert.save_pretrained(path)
        # then save the config (vocab saved outside)
        self.config.save_pretrained(path)
        torch.save(self.state_dict(), os.path.join(path, 'pytorch_model.bin'))


In [None]:
import types
import logging

from transformers import BertTokenizer

logger = logging.getLogger(__name__)


def _build_args(model_name_or_path, output_dir, kwargs):
    """Turn ``main``'s keyword arguments into an argparse-style namespace.

    Every option has a default, so callers only pass what they want to change.
    Unrecognised keys are logged rather than silently dropped, so typos surface.

    Args:
        model_name_or_path: model/tokenizer directory passed to the loaders.
        output_dir: directory that receives checkpoints and results.
        kwargs: dict of optional overrides given to ``main``.

    Returns:
        types.SimpleNamespace holding every option this script reads.
    """
    defaults = {
        "train_ctx": None,
        "predict_ctx": None,
        "repr_ctx": None,
        "train_pico": None,
        "predict_pico": None,
        "repr_pico": None,
        "permutation": "ioc",
        "tokenizer_name": "",
        "cache_dir": "",
        "max_passage_length": 256,
        "max_pico_length": 128,
        "do_train": False,
        "do_eval": False,
        "do_repr": False,
        "evaluate_during_training": False,
        "do_lower_case": False,
        "per_gpu_train_batch_size": 24,
        "per_gpu_eval_batch_size": 24,
        "learning_rate": 5e-5,
        "gradient_accumulation_steps": 1,
        "weight_decay": 0.0,
        "adam_epsilon": 1e-8,
        "max_grad_norm": 1.0,
        "num_train_epochs": 24.0,
        "max_steps": -1,
        "warmup_steps": 400,
        "logging_steps": 50,
        "save_steps": 25,
        "eval_all_checkpoints": False,
        "no_cuda": False,
        "overwrite_cache": False,
        "seed": 42,
        "local_rank": -1,
        "pretraining": False,
        "num_labels": 3,
        "adversarial": False,
        # Used to be forced to True (making the output-dir guard in main dead code).
        # Still defaults to True; pass False to protect a non-empty output_dir.
        "overwrite_output_dir": True,
    }
    unknown = sorted(set(kwargs) - set(defaults))
    if unknown:
        logger.warning("main(): ignoring unknown keyword argument(s): %s", ", ".join(unknown))

    args = types.SimpleNamespace(model_name_or_path=model_name_or_path, output_dir=output_dir)
    for name, default in defaults.items():
        setattr(args, name, kwargs.get(name, default))
    return args


def _evaluate_checkpoints(args, tokenizer):
    """Evaluate the final model, or every saved checkpoint when ``eval_all_checkpoints``.

    Returns:
        dict of metric name -> value; names get a ``_<step>`` suffix per checkpoint
        (``_final`` for a directory whose path does not contain "checkpoint").
    """
    eval_ctxs = load_and_cache_ctxs(args, tokenizer, evaluate=True)
    eval_picos = load_and_cache_picos(args, tokenizer, evaluate=True)

    checkpoints = [args.output_dir] if args.do_train else [args.model_name_or_path]
    if args.eval_all_checkpoints:
        checkpoints = [
            os.path.dirname(c)
            for c in sorted(glob.glob(args.output_dir + "/**/" + 'pytorch_model.bin', recursive=True))
        ]
        logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)

    logger.info("Evaluate the following checkpoints: %s", checkpoints)

    results = {}
    for checkpoint in checkpoints:
        if 'checkpoint' not in checkpoint:
            global_step = 'final'
        else:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""

        model = EBM_Net(args, path=checkpoint)
        model.to(args.device)

        result = evaluate(args, eval_picos, eval_ctxs, model, tokenizer, prefix=global_step)
        suffix = "_{}".format(global_step) if global_step else ""
        results.update((k + suffix, v) for k, v in result.items())

        if 'checkpoint' in checkpoint and args.do_train and args.eval_all_checkpoints:
            # Best-effort disk cleanup. Either file may be absent (EBM_Net.save_pretrained
            # writes pytorch_model.bin but not full_model.bin), and os.remove would raise.
            for weights_file in ('full_model.bin', 'pytorch_model.bin'):
                weights_path = os.path.join(checkpoint, weights_file)
                if os.path.exists(weights_path):
                    os.remove(weights_path)
    return results


def main(model_name_or_path, output_dir, **kwargs):
    """Notebook entry point: train, evaluate and/or compute representations with EBM_Net.

    Args:
        model_name_or_path: directory of the model (and tokenizer, unless
            ``tokenizer_name`` is given) to start from.
        output_dir: directory for checkpoints, ``training_args.bin`` and ``results.json``.
        **kwargs: overrides for any option in ``_build_args`` (e.g. ``do_train=True``,
            ``train_pico=...``, ``num_labels=34``). Unknown keys are logged and ignored.

    Raises:
        ValueError: when ``do_train`` is set, ``overwrite_output_dir=False`` and
            ``output_dir`` already exists and is not empty.
    """
    args = _build_args(model_name_or_path, output_dir, kwargs)

    if (
        args.do_train
        and not args.overwrite_output_dir
        and os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. "
            "Pass overwrite_output_dir=True to overcome.".format(args.output_dir)
        )

    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    # Count only GPUs that will actually be used: with no_cuda=True the device is the
    # CPU, so reporting every visible GPU here would be inconsistent.
    args.n_gpu = torch.cuda.device_count() if use_cuda else 0
    args.device = device

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1)
    )

    # tokenizer_name was previously accepted but ignored; an empty value still
    # falls back to the model directory.
    tokenizer = BertTokenizer.from_pretrained(
        args.tokenizer_name or args.model_name_or_path,
        do_lower_case=args.do_lower_case
    )

    model = EBM_Net(args, path=args.model_name_or_path)
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.do_train:
        train_ctxs = load_and_cache_ctxs(args, tokenizer, evaluate=False, pretraining=args.pretraining)
        train_picos = load_and_cache_picos(args, tokenizer, evaluate=False, pretraining=args.pretraining)
        global_step, tr_loss = train(args, train_picos, train_ctxs, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    if args.do_eval:
        results = _evaluate_checkpoints(args, tokenizer)
        logger.info("Results: {}".format(results))
        # Convert before opening so a failed conversion cannot truncate an existing file.
        serializable = {k: float(v) for k, v in results.items()}
        with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
            json.dump(serializable, f, indent=4)

    if args.do_repr:
        logger.info("Representing...")

        model = EBM_Net(args, path=args.model_name_or_path)
        model.to(args.device)

        represent(args, model, tokenizer)

Pretraining

In [None]:
# Pretraining run configuration.
# output_dir must be a sub-directory of local_path: the fine-tuning cell loads the
# pretrained model from f'{local_path}/output_dir' (the old value lacked the '/',
# producing a sibling directory named "...CTPoutput_dir").
model_name_or_path  = f'{local_path}/biobert-v1.1'
output_dir          = f'{local_path}/output_dir'
do_train            = True
train_pico          = f'{local_path}/pretraining_dataset/indexed_evidence_10.json'
train_ctx           = f'{local_path}/pretraining_dataset/indexed_contexts_10.json'
num_labels          = 34
pretraining         = True
adversarial         = True

main(
    model_name_or_path  = model_name_or_path,
    output_dir          = output_dir,
    do_train            = do_train,
    train_pico          = train_pico,
    train_ctx           = train_ctx,
    num_labels          = num_labels,
    pretraining         = pretraining,
    adversarial         = adversarial
)



HERE odict_keys(['bert.embeddings.word_embeddings.weight', 'bert.embeddings.position_embeddings.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias', 'bert.encoder.layer.0.attention.self.query.weight', 'bert.encoder.layer.0.attention.self.query.bias', 'bert.encoder.layer.0.attention.self.key.weight', 'bert.encoder.layer.0.attention.self.key.bias', 'bert.encoder.layer.0.attention.self.value.weight', 'bert.encoder.layer.0.attention.self.value.bias', 'bert.encoder.layer.0.attention.output.dense.weight', 'bert.encoder.layer.0.attention.output.dense.bias', 'bert.encoder.layer.0.attention.output.LayerNorm.weight', 'bert.encoder.layer.0.attention.output.LayerNorm.bias', 'bert.encoder.layer.0.intermediate.dense.weight', 'bert.encoder.layer.0.intermediate.dense.bias', 'bert.encoder.layer.0.output.dense.weight', 'bert.encoder.layer.0.output.dense.bias', 'bert.encoder.layer.0.output.LayerNorm.weight', 'bert.encoder.layer.0.o

Epoch:   0%|          | 0/24 [00:00<?, ?it/s]
Iteration:   0%|          | 0/2732 [00:00<?, ?it/s][A
Iteration:   0%|          | 1/2732 [00:01<1:24:28,  1.86s/it][A
Iteration:   0%|          | 2/2732 [00:02<55:02,  1.21s/it]  [A
Iteration:   0%|          | 3/2732 [00:03<45:50,  1.01s/it][A
Iteration:   0%|          | 4/2732 [00:04<41:40,  1.09it/s][A
Iteration:   0%|          | 5/2732 [00:04<39:17,  1.16it/s][A
Iteration:   0%|          | 6/2732 [00:05<37:47,  1.20it/s][A
Iteration:   0%|          | 7/2732 [00:06<36:46,  1.23it/s][A
Iteration:   0%|          | 8/2732 [00:07<36:13,  1.25it/s][A
Iteration:   0%|          | 9/2732 [00:08<35:48,  1.27it/s][A
Iteration:   0%|          | 10/2732 [00:08<35:30,  1.28it/s][A
Iteration:   0%|          | 11/2732 [00:09<35:16,  1.29it/s][A
Iteration:   0%|          | 12/2732 [00:10<35:14,  1.29it/s][A
Iteration:   0%|          | 13/2732 [00:11<35:05,  1.29it/s][A
Iteration:   1%|          | 14/2732 [00:11<35:02,  1.29it/s][A
Iteratio

Fine-tuning

In [None]:
# Original command-line invocation (run_ebmnet.py), kept for reference only.
# Shell syntax is not valid Python, so it must stay commented out in this cell.
# python -u run_ebmnet.py --model_name_or_path ${PRETRAINED_MODEL} \
# --do_train --train_pico evidence_integration/indexed_train_picos.json --train_ctx evidence_integration/indexed_train_ctxs.json \
# --do_eval --predict_pico evidence_integration/indexed_validation_picos.json --predict_ctx evidence_integration/indexed_validation_ctxs.json \
# --output_dir ${OUTPUT_DIR}

# Fine-tuning starts from the pretraining cell's output directory.
model_name_or_path  = f'{local_path}/output_dir'
do_train            = True
train_pico          = f'{local_path}/evidence_integration/indexed_train_picos.json'
train_ctx           = f'{local_path}/evidence_integration/indexed_train_ctxs.json'
do_eval             = True
predict_pico        = f'{local_path}/evidence_integration/indexed_validation_picos.json'
predict_ctx         = f'{local_path}/evidence_integration/indexed_validation_ctxs.json'
output_df           = f'{local_path}/ebmnet_output'
# NOTE(review): `output_df` looks like a typo for `output_dir`. Alias it so that a later
# `output_dir=output_dir` call writes to ebmnet_output instead of reusing the
# pretraining directory (which is also model_name_or_path above).
output_dir          = output_df