<a href="https://colab.research.google.com/github/mahopman/IEBM-Net/blob/main/run_drug_ebmnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

In [1]:
# Notebook authorship metadata: original author and the editor of this version.
__author__, __editor__ = 'Qiao Jin', 'Mia Hopman'

In [2]:
# Root data directory (a Google Drive path under /content/drive) and the two
# dataset folders that live directly beneath it.
local_path = '/content/drive/MyDrive/MS_DataScience/DS595/IEBM-Net_Data'
pretraining_dataset_path = local_path + '/pretraining_dataset'
evidence_integration_path = local_path + '/evidence_integration'

In [3]:
import random as rd
import torch

def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: namespace providing ``seed`` (int) and ``n_gpu`` (int). The CUDA
            generators are seeded only when ``n_gpu > 0``.
    """
    # Local import: the notebook imports numpy only in a later cell, so relying
    # on a global ``np`` makes this function fail when called before that cell.
    import numpy as np

    rd.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

In [4]:
def to_list(tensor):
    """Return *tensor* as a (nested) Python list, after detaching it and moving it to the CPU."""
    host_copy = tensor.detach().cpu()
    return host_copy.tolist()

In [5]:
from torch.utils.data import DataLoader, RandomSampler
from transformers import AdamW, get_cosine_schedule_with_warmup
from tqdm import tqdm, trange

def train(args, train_picos, train_ctxs, model, tokenizer):
    """Train ``model`` on PICO/context pairs.

    Each PICO row carries a ``ctx_id`` used to index ``train_ctxs``; the context
    tokens are concatenated in front of the PICO tokens to form the passage the
    model sees.

    Args:
        args: run configuration (batch sizes, optimiser settings, ``device``,
            ``n_gpu``, ``logging_steps``, ``save_steps``, ``output_dir``, ...).
            ``args.train_batch_size`` is written back onto it, and so is
            ``args.num_train_epochs`` when ``args.max_steps > 0``.
        train_picos: dataset whose items are ``(ctx_id, pico_ids, pico_mask,
            pico_segment_ids, label, ...)`` tensors.
        train_ctxs: dataset indexed by ``ctx_id`` whose items are ``(ctx_id,
            ctx_ids, ctx_mask, ctx_segment_ids)`` tensors.
        model: module returning the training loss when ``result_labels`` is
            present in its input dict.
        tokenizer: saved next to every checkpoint.

    Returns:
        ``(global_step, average_loss)``. ``average_loss`` is the accumulated loss
        divided by the number of optimiser steps, or 0.0 if no step was taken.
    """
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_picos)
    train_dataloader = DataLoader(train_picos, sampler=train_sampler, batch_size=args.train_batch_size)

    # Optimiser steps per epoch; at least 1 so the epoch count below cannot divide
    # by zero when the loader is shorter than the gradient-accumulation window.
    steps_per_epoch = max(1, len(train_dataloader) // args.gradient_accumulation_steps)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // steps_per_epoch + 1
    else:
        t_total = steps_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (warmup followed by cosine decay).
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # multi-gpu training
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_picos))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=False)
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False)
        for step, batch in enumerate(epoch_iterator):
            model.train()

            batch = tuple(t.to(args.device) for t in batch)

            ctx_ids = to_list(batch[0])
            pico_token_ids = batch[1]  # B x max_pico_length
            pico_token_mask = batch[2]  # B x max_pico_length
            pico_segment_ids = batch[3]  # B x max_pico_length
            labels = batch[4]

            # Fetch each example's context row and transpose into per-field tuples.
            ctx_fields = list(zip(*[train_ctxs[ctx_id] for ctx_id in ctx_ids]))

            ctx_token_ids = torch.stack(ctx_fields[1]).to(args.device)  # B x max_ctx_length
            ctx_token_mask = torch.stack(ctx_fields[2]).to(args.device)  # B x max_ctx_length
            ctx_segment_ids = torch.stack(ctx_fields[3]).to(args.device)  # B x max_ctx_length

            inputs = {
                "passage_ids": torch.cat([ctx_token_ids, pico_token_ids], dim=1),
                "passage_mask": torch.cat([ctx_token_mask, pico_token_mask], dim=1),
                "passage_segment_ids": torch.cat([ctx_segment_ids, pico_segment_ids], dim=1),
                "result_labels": labels
            }

            # The model returns the loss itself (not a tuple) when result_labels is given.
            loss = model(inputs)

            if args.n_gpu > 1:
                loss = loss.mean()  # average the per-GPU losses from DataParallel
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info(
                        "  step %d: lr = %g, loss = %.4f",
                        global_step,
                        scheduler.get_last_lr()[0],
                        (tr_loss - logging_loss) / args.logging_steps,
                    )
                    logging_loss = tr_loss

                if args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Unwrap DataParallel so the underlying model's save_pretrained is used.
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # Stop as soon as the scheduled number of optimiser steps is reached
            # (">" would run one extra step beyond the scheduler's horizon).
            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss

In [6]:
import json
import numpy as np
from torch.utils.data import SequentialSampler
from sklearn.metrics import accuracy_score, f1_score

def evaluate(args, eval_picos, eval_ctxs, model, tokenizer, prefix=""):
    """Run ``model`` over the evaluation set, dump its outputs and score them.

    Writes ``<prefix>_all_example_idx.json``, ``<prefix>_all_labels.json``,
    ``<prefix>_all_preds.json`` and ``<prefix>_all_logits.npy`` into
    ``args.output_dir``. ``prefix`` falls back to ``'final'`` when empty.

    Args:
        args: run configuration (``output_dir``, ``per_gpu_eval_batch_size``,
            ``n_gpu``, ``device``, optionally ``num_labels``).
        eval_picos: dataset whose items are ``(ctx_id, pico_ids, pico_mask,
            pico_segment_ids, label, example_index)`` tensors.
        eval_ctxs: dataset indexed by ``ctx_id`` with items ``(ctx_id, ctx_ids,
            ctx_mask, ctx_segment_ids)``.
        model: module returning logits (B x num_labels) when no labels are given.
        tokenizer: unused; kept for interface compatibility.
        prefix: file-name prefix for the dumped outputs.

    Returns:
        dict with macro-averaged ``'f1'`` and accuracy ``'acc'``.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(eval_picos)
    eval_dataloader = DataLoader(eval_picos, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_picos))
    logger.info("  Batch size = %d", args.eval_batch_size)

    example_ids = []
    all_labels = []
    all_preds = []
    logit_chunks = []  # concatenated once at the end (avoids quadratic copying)

    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)

            ctx_ids = to_list(batch[0])
            pico_token_ids = batch[1]  # B x max_pico_length
            pico_token_mask = batch[2]  # B x max_pico_length
            pico_segment_ids = batch[3]  # B x max_pico_length
            labels = batch[4]

            ctx_fields = list(zip(*[eval_ctxs[ctx_id] for ctx_id in ctx_ids]))

            ctx_token_ids = torch.stack(ctx_fields[1]).to(args.device)  # B x max_ctx_length
            ctx_token_mask = torch.stack(ctx_fields[2]).to(args.device)  # B x max_ctx_length
            ctx_segment_ids = torch.stack(ctx_fields[3]).to(args.device)  # B x max_ctx_length

            inputs = {
                "passage_ids": torch.cat([ctx_token_ids, pico_token_ids], dim=1),
                "passage_mask": torch.cat([ctx_token_mask, pico_token_mask], dim=1),
                "passage_segment_ids": torch.cat([ctx_segment_ids, pico_segment_ids], dim=1)
            }

            logits = model(inputs)  # B x num_labels
            preds = torch.argmax(logits, dim=1)  # B

            # Index 5 is the example index; index 4 holds the labels.
            example_ids += to_list(batch[5])
            all_labels += to_list(labels)
            all_preds += to_list(preds)
            logit_chunks.append(logits.detach().cpu().numpy())

    if logit_chunks:
        # float64 matches the dtype the logits file has always been saved with.
        all_logits = np.concatenate(logit_chunks, axis=0).astype(np.float64)
    else:
        all_logits = np.zeros((0, getattr(args, "num_labels", 3)))

    if not prefix:
        prefix = 'final'

    with open(os.path.join(args.output_dir, '%s_all_example_idx.json' % prefix), 'w') as f:
        json.dump([int(example_id) for example_id in example_ids], f)
    with open(os.path.join(args.output_dir, '%s_all_labels.json' % prefix), 'w') as f:
        json.dump([int(label) for label in all_labels], f)
    with open(os.path.join(args.output_dir, '%s_all_preds.json' % prefix), 'w') as f:
        json.dump([int(pred) for pred in all_preds], f)
    np.save(os.path.join(args.output_dir, '%s_all_logits.npy' % prefix), all_logits)

    results = {}
    results['f1'] = f1_score(all_labels, all_preds, average='macro')
    results['acc'] = accuracy_score(all_labels, all_preds)

    return results

In [7]:
def represent(args, model, tokenizer):
    """Compute [CLS] representations for the repr dataset and save them.

    Writes ``all_example_idx.json`` and ``all_reprs.npy`` (N x hidden_size)
    into ``args.output_dir``.

    NOTE(review): ``load_and_cache_examples`` is not defined in the visible
    notebook cells (only ``load_and_cache_ctxs``/``load_and_cache_picos``
    are). This function assumes it returns batches laid out as
    ``(passage_ids, passage_mask, passage_segment_ids, ?, example_id, ...)``.
    TODO confirm before using ``do_repr``.
    """
    dataset = load_and_cache_examples(args, tokenizer, evaluate=True, do_repr=True)

    os.makedirs(args.output_dir, exist_ok=True)

    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Eval!
    logger.info("***** Running Representations *****")
    logger.info("  Num examples = %d", len(dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)

    example_ids = []
    repr_chunks = []  # concatenated once at the end (avoids quadratic copying)

    model.eval()
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "passage_ids": batch[0],
                "passage_mask": batch[1],
                "passage_segment_ids": batch[2],
            }

            reprs = model(inputs, get_reprs=True)  # N x D

            example_ids += list(batch[4].detach().cpu().numpy())
            repr_chunks.append(reprs.detach().cpu().numpy())

    if repr_chunks:
        # float64 matches the dtype the reprs file has always been saved with.
        all_reprs = np.concatenate(repr_chunks, axis=0).astype(np.float64)
    else:
        all_reprs = np.zeros((0, model.bert.config.hidden_size))

    with open(os.path.join(args.output_dir, 'all_example_idx.json'), 'w') as f:
        json.dump([int(_id) for _id in example_ids], f)
    np.save(os.path.join(args.output_dir, 'all_reprs.npy'), all_reprs)

In [8]:
class CtxFeatures:  # same with pre-training utils
    """Token-level features for one context passage.

    Attributes:
        ctx_id: identifier of the context.
        tokens: the word-piece tokens.
        input_ids: token ids (padded).
        input_mask: 1 for real tokens, 0 for padding.
        segment_ids: token-type ids.
    """

    def __init__(self, ctx_id, tokens, input_ids, input_mask, segment_ids):
        self.ctx_id, self.tokens = ctx_id, tokens
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids

In [9]:
def convert_ctxs_to_features(
    examples,
    tokenizer,
    max_passage_length,
    permutation=None,
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token=0,
    sequence_a_segment_id=0,
    sequence_b_segment_id=1,
    cls_token_segment_id=0,
    pad_token_segment_id=0
):
    """Tokenise context passages into fixed-length ``CtxFeatures``.

    Each passage becomes ``cls_token + passage tokens + sep_token``, with the
    passage cut to ``max_passage_length - 2`` tokens. The sequence is converted
    to ids and right-padded up to ``max_passage_length``: ids with ``pad_token``,
    mask with 0, segments with ``pad_token_segment_id``. Real tokens get
    ``sequence_a_segment_id``. ``permutation``, ``sequence_b_segment_id`` and
    ``cls_token_segment_id`` are accepted but unused.

    Returns:
        list of ``CtxFeatures``, one per example, in input order.
    """
    features = []
    for example in tqdm(examples):
        ctx_id = example.ctx_id
        passage_tokens = tokenizer.tokenize(example.passage_text)
        tokens = [cls_token, *passage_tokens[:max_passage_length - 2], sep_token]

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        pad_len = max_passage_length - len(input_ids)
        input_mask = [1] * len(input_ids) + [0] * pad_len
        segment_ids = [sequence_a_segment_id] * len(tokens) + [pad_token_segment_id] * pad_len
        input_ids = input_ids + [pad_token] * pad_len

        assert len(input_ids) == max_passage_length
        assert len(input_mask) == max_passage_length
        assert len(segment_ids) == max_passage_length

        if ctx_id < 20:
            logger.info("*** Example ***")
            logger.info("ctx_id: %s" % (ctx_id))
            logger.info("tokens: %s" % " ".join(tokens))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

        features.append(
            CtxFeatures(
                ctx_id=ctx_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids
            )
        )

    return features

In [10]:
class CtxExample:  # same with pre-training utils
    """One context passage of the EBM-Net dataset (id plus raw text)."""

    def __init__(self, ctx_id, passage_text):
        self.ctx_id = ctx_id
        self.passage_text = passage_text

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return 'ctx_id: %s\n' % self.ctx_id + "passage: %s\n" % self.passage_text

In [11]:
def read_ctx_examples(input_file, adversarial=False):  # same with pre-training utils
    """Load context passages from an EBM-Net JSON file.

    The file holds a list of objects with ``ctx_id`` and ``passage`` keys.
    ``adversarial`` is accepted for parity with the PICO readers and ignored.

    Returns:
        list of ``CtxExample`` in file order.
    """
    with open(input_file, "r", encoding="utf-8") as reader:
        entries = json.load(reader)

    return [CtxExample(ctx_id=entry['ctx_id'], passage_text=entry['passage']) for entry in entries]

In [12]:
from torch.utils.data import TensorDataset

def load_and_cache_ctxs(args, tokenizer, evaluate=False, do_repr=False, pretraining=False):
    """Build (or load from cache) the context ``TensorDataset``.

    The input file is ``args.repr_ctx`` when ``do_repr`` is set, otherwise
    ``args.predict_ctx`` (``evaluate=True``) or ``args.train_ctx``. Features
    are cached next to the input file under a name keyed by split
    (``repr``/``dev``/``train``), ``args.adversarial`` and
    ``args.max_passage_length``. Set ``args.overwrite_cache`` to rebuild.
    ``pretraining`` is accepted but unused.

    Returns:
        ``TensorDataset`` of ``(ctx_ids, input_ids, input_mask, segment_ids)``.
        Callers index rows by ``ctx_id``, so ctx_ids are presumably 0..N-1 in
        file order. TODO confirm against the data files.
    """
    # A dedicated "repr" tag keeps the repr cache from colliding with the
    # dev/train cache when the files share a directory.
    if do_repr:
        input_file, split = args.repr_ctx, "repr"
    elif evaluate:
        input_file, split = args.predict_ctx, "dev"
    else:
        input_file, split = args.train_ctx, "train"

    cached_features_file = os.path.join(
        os.path.dirname(input_file),
        "cached_ctxs_adv{}_{}_{}".format(
            args.adversarial,
            split,
            str(args.max_passage_length)
        ),
    )

    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # The cache holds pickled CtxFeatures written by this function (only load
        # caches you created). torch>=2.6 defaults to weights_only=True, which
        # refuses such objects; older torch has no weights_only argument.
        try:
            features = torch.load(cached_features_file, weights_only=False)
        except TypeError:
            features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)

        examples = read_ctx_examples(input_file=input_file, adversarial=args.adversarial)

        features = convert_ctxs_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_passage_length=args.max_passage_length
        )

        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_ctx_ids = torch.tensor([f.ctx_id for f in features], dtype=torch.long)
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)

    return TensorDataset(
        all_ctx_ids,
        all_input_ids,
        all_input_mask,
        all_segment_ids
    )

In [13]:
class PicoFeatures:  # unchanged from pre-training utils
    """Token-level features for one PICO (evidence) example.

    Attributes:
        example_index: position of the source example.
        ctx_id: id of the context passage the example belongs to.
        tokens: the word-piece tokens.
        input_ids: token ids (padded).
        input_mask: 1 for real tokens, 0 for padding.
        segment_ids: token-type ids.
        label: the example's label.
    """

    def __init__(self, example_index, ctx_id, tokens, input_ids, input_mask, segment_ids, label):
        self.example_index, self.ctx_id, self.tokens = example_index, ctx_id, tokens
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
        self.label = label

In [14]:
def convert_picos_to_features_pretraining(
    examples,
    tokenizer,
    max_pico_length,
    permutation=None,
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token=0,
    sequence_a_segment_id=0,
    sequence_b_segment_id=1,
    cls_token_segment_id=0,
    pad_token_segment_id=0
):
    """Tokenise pre-training PICO examples into fixed-length ``PicoFeatures``.

    Each example's ``pico_text`` is cut to ``max_pico_length - 1`` tokens and
    terminated by ``sep_token`` (no leading CLS token). Real tokens get
    ``sequence_b_segment_id``. The result is right-padded up to
    ``max_pico_length``. ``example_index`` is the example's position in
    ``examples``.

    Returns:
        list of ``PicoFeatures`` in input order.
    """
    features = []

    for example_index, example in enumerate(examples):
        ctx_id = example.ctx_id

        tokens = tokenizer.tokenize(example.pico_text)[:max_pico_length - 1] + [sep_token]
        label = example.label

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        pad_len = max_pico_length - len(input_ids)
        input_mask = [1] * len(input_ids) + [0] * pad_len
        segment_ids = [sequence_b_segment_id] * len(tokens) + [pad_token_segment_id] * pad_len
        input_ids = input_ids + [pad_token] * pad_len

        assert len(input_ids) == max_pico_length
        assert len(input_mask) == max_pico_length
        assert len(segment_ids) == max_pico_length

        if example_index < 20:
            logger.info("*** Example ***")
            logger.info("ctx_id: %s" % (ctx_id))
            logger.info("example_index: %s" % (example_index))
            logger.info("tokens: %s" % " ".join(tokens))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            logger.info("label: %s" % label)

        features.append(
            PicoFeatures(
                ctx_id=ctx_id,
                example_index=example_index,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label=label
            )
        )

    return features

In [15]:
def convert_picos_to_features_ebmnet(
    examples,
    tokenizer,
    max_pico_length,
    permutation=None,
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token=0,
    sequence_a_segment_id=0,
    sequence_b_segment_id=1,
    cls_token_segment_id=0,
    pad_token_segment_id=0
):
    """Convert EBM-Net I/C/O examples into fixed-length ``PicoFeatures``.

    ``permutation`` is a string over ``{'i', 'c', 'o'}`` giving the order in which
    the intervention, comparator and outcome texts are concatenated. Elements
    are separated by ``[MASK]``, and the last separator is replaced by
    ``sep_token``. Several orders may be joined with ``'-'`` (e.g. ``'ioc-cio'``).
    Every example is then emitted once per order, and all copies keep the same
    ``example_index``.

    Sequences longer than ``max_pico_length`` are truncated, keeping a final
    ``sep_token``; shorter ones are right-padded. Real tokens get
    ``sequence_b_segment_id``.

    Raises:
        ValueError: if ``permutation`` is missing, or an order is empty or uses
            letters other than ``i``, ``c``, ``o``.
    """
    if not permutation:
        raise ValueError("permutation must be a non-empty string such as 'ioc'")

    # str.split('-') also covers the single-order case ('ioc' -> ['ioc']).
    perm_list = permutation.split('-')
    for perm in perm_list:
        if not perm or not set(perm) <= {'i', 'o', 'c'}:
            raise ValueError("invalid permutation %r in %r" % (perm, permutation))

    # Materialise once so the examples can be walked for every order, and
    # tokenise each text a single time instead of once per order.
    examples = list(examples)
    tokenized = [
        {
            'i': tokenizer.tokenize(example.i_text),
            'c': tokenizer.tokenize(example.c_text),
            'o': tokenizer.tokenize(example.o_text),
        }
        for example in examples
    ]

    features = []
    for perm in perm_list:
        for example_index, (example, ico_tokens) in enumerate(zip(examples, tokenized)):
            ctx_id = example.ctx_id
            label = example.label

            # Build a fresh list so the cached token lists are never mutated.
            tokens = []
            for element in perm:
                tokens += ico_tokens[element] + ['[MASK]']
            tokens[-1] = sep_token
            segment_ids = [sequence_b_segment_id] * len(tokens)

            if len(tokens) > max_pico_length:
                tokens = tokens[:max_pico_length - 1] + [sep_token]
                segment_ids = segment_ids[:max_pico_length - 1] + [sequence_b_segment_id]

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            pad_len = max_pico_length - len(input_ids)
            input_mask = [1] * len(input_ids) + [0] * pad_len
            input_ids = input_ids + [pad_token] * pad_len
            segment_ids = segment_ids + [pad_token_segment_id] * pad_len

            assert len(input_ids) == max_pico_length
            assert len(input_mask) == max_pico_length
            assert len(segment_ids) == max_pico_length

            if example_index < 20:
                logger.info("*** Example ***")
                logger.info("example_index: %s" % (example_index))
                logger.info("tokens: %s" % " ".join(tokens))
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

            features.append(
                PicoFeatures(
                    ctx_id=ctx_id,
                    example_index=example_index,
                    tokens=tokens,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    label=label
                )
            )

    return features

In [16]:
class PicoExamplePretraining:
    """One pre-training example: a context id, its PICO text and a label."""

    def __init__(self, ctx_id, pico_text, label):
        self.ctx_id = ctx_id
        self.pico_text = pico_text
        self.label = label

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return (
            "ctx_id: %s\n" % self.ctx_id
            + "pico_text: %s\n" % self.pico_text
            + "label: %s\n" % self.label
        )

In [17]:
def read_pico_examples_pretraining(input_file, adversarial=False):
    """Load pre-training PICO examples from an EBM-Net JSON file.

    Each entry needs ``ctx_id``, ``pico`` and ``label`` keys. With
    ``adversarial`` set, the entry's ``rev_pico``/``rev_label`` pair is emitted
    as a second example directly after the original one.

    Returns:
        list of ``PicoExamplePretraining``.
    """
    with open(input_file, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)

    def _variants(entry):
        # (text, label) pairs produced for one JSON entry, in output order.
        yield entry['pico'], entry['label']
        if adversarial:
            yield entry['rev_pico'], entry['rev_label']

    return [
        PicoExamplePretraining(ctx_id=entry['ctx_id'], pico_text=text, label=label)
        for entry in input_data
        for text, label in _variants(entry)
    ]

In [18]:
class PicoExampleEBMNet:
    """One EBM-Net example: context id, intervention/comparator/outcome texts and a label."""

    def __init__(self, ctx_id, i_text, c_text, o_text, label):
        self.ctx_id = ctx_id
        self.i_text = i_text
        self.c_text = c_text
        self.o_text = o_text
        self.label = label

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return (
            "ctx_id: %s\n" % self.ctx_id
            + "i_text: %s\n" % self.i_text
            + "c_text: %s\n" % self.c_text
            + "o_text: %s\n" % self.o_text
            + "label: %s\n" % self.label
        )

In [19]:
def read_pico_examples_ebmnet(input_file, adversarial=False):
    """Load EBM-Net I/C/O examples from a JSON file.

    Each entry needs ``ctx_id``, ``i_text``, ``c_text``, ``o_text`` and ``label``
    keys. With ``adversarial`` set, a second example follows each original one,
    with intervention and comparator swapped and the label mirrored as
    ``2 - label``.

    Returns:
        list of ``PicoExampleEBMNet``.
    """
    with open(input_file, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)

    examples = []
    for entry in input_data:
        ctx_id, i_text, c_text, o_text = entry['ctx_id'], entry['i_text'], entry['c_text'], entry['o_text']
        label = entry['label']
        examples.append(PicoExampleEBMNet(ctx_id=ctx_id, i_text=i_text, c_text=c_text, o_text=o_text, label=label))

        if not adversarial:
            continue
        examples.append(PicoExampleEBMNet(ctx_id=ctx_id, i_text=c_text, c_text=i_text, o_text=o_text, label=2 - label))

    return examples

In [20]:
def load_and_cache_picos(args, tokenizer, evaluate=False, do_repr=False, pretraining=False):
    """Build (or load from cache) the PICO ``TensorDataset``.

    The input file is ``args.repr_pico`` when ``do_repr`` is set, otherwise
    ``args.predict_pico`` (``evaluate=True``) or ``args.train_pico``. Features
    are cached next to the input file under a name keyed by ``args.adversarial``,
    ``args.permutation``, split (``repr``/``dev``/``train``) and
    ``args.max_pico_length``. Set ``args.overwrite_cache`` to rebuild.

    The format comes from ``args.pretraining`` (pre-training PICO text). If that
    is false, the EBM-Net I/C/O format is used when ``args.ebmnet`` is true.
    When ``args`` has no ``ebmnet`` attribute, it defaults to
    ``not args.pretraining``. The ``pretraining`` parameter is accepted but
    unused.

    Returns:
        ``TensorDataset`` of ``(ctx_ids, input_ids, input_mask, segment_ids,
        labels, example_ids)``.

    Raises:
        ValueError: if features must be built and neither format is selected.
    """
    # A dedicated "repr" tag keeps the repr cache from colliding with the
    # dev/train cache when the files share a directory.
    if do_repr:
        input_file, split = args.repr_pico, "repr"
    elif evaluate:
        input_file, split = args.predict_pico, "dev"
    else:
        input_file, split = args.train_pico, "train"

    cached_features_file = os.path.join(
        os.path.dirname(input_file),
        "cached_picos_adv{}_{}_{}_{}".format(
            args.adversarial,
            args.permutation,
            split,
            str(args.max_pico_length)
        ),
    )

    # Callers that never set args.ebmnet (e.g. main) get EBM-Net format whenever
    # they are not pre-training, instead of an AttributeError.
    use_ebmnet = getattr(args, "ebmnet", not args.pretraining)

    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # The cache holds pickled PicoFeatures written by this function (only load
        # caches you created). torch>=2.6 defaults to weights_only=True, which
        # refuses such objects; older torch has no weights_only argument.
        try:
            features = torch.load(cached_features_file, weights_only=False)
        except TypeError:
            features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)

        if args.pretraining:
            examples = read_pico_examples_pretraining(input_file=input_file, adversarial=args.adversarial)
            features = convert_picos_to_features_pretraining(
                examples=examples,
                tokenizer=tokenizer,
                max_pico_length=args.max_pico_length,
                permutation=args.permutation
            )
        elif use_ebmnet:
            examples = read_pico_examples_ebmnet(input_file=input_file, adversarial=args.adversarial)
            features = convert_picos_to_features_ebmnet(
                examples=examples,
                tokenizer=tokenizer,
                max_pico_length=args.max_pico_length,
                permutation=args.permutation
            )
        else:
            raise ValueError("Set args.pretraining or args.ebmnet to choose the PICO data format")

        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)

    # Convert to Tensors and build dataset
    all_ctx_ids = torch.tensor([f.ctx_id for f in features], dtype=torch.long)
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)

    if args.num_labels == 3 and args.pretraining:
        # Pre-training data carries 34 labels; collapse them to 3 classes:
        # 0-14 -> 0, 15-16 -> 1, 17-33 -> 2 (labels outside 0-33 raise KeyError).
        mlm2cls = {i: 0 if i < 15 else (1 if i < 17 else 2) for i in range(34)}
        all_labels = torch.tensor([mlm2cls[f.label] for f in features], dtype=torch.long)
    else:
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)

    all_example_ids = torch.tensor([f.example_index for f in features], dtype=torch.long)

    return TensorDataset(
        all_ctx_ids,
        all_input_ids,
        all_input_mask,
        all_segment_ids,
        all_labels,
        all_example_ids
    )


In [21]:
import os
import torch
from torch import nn
from transformers import BertModel, BertConfig

class EBM_Net(nn.Module):
    """BERT encoder with a 34-way result head.

    With ``args.num_labels == 3``, the [CLS] embedding goes through
    ``res_linear`` (hidden -> 34), a softmax, and then ``final_linear``
    (34 -> 3). For any other ``num_labels``, ``res_linear``'s 34 logits are used
    directly, and ``final_linear`` (hidden -> num_labels) is created but not
    used in ``forward``.
    """

    def __init__(self, args, path=None):
        """Build the network and optionally restore weights.

        Args:
            args: namespace with ``model_name_or_path`` (BERT config/weights) and
                ``num_labels``.
            path: directory containing a ``pytorch_model.bin`` state dict. Only
                tensors whose name and shape match this model are loaded. If
                anything is missing or mismatched, ``final_linear`` is first
                re-initialised (Xavier weights, zero bias).
        """
        super().__init__()
        self.args = args
        self.config = BertConfig.from_pretrained(args.model_name_or_path)
        self.bert = BertModel.from_pretrained(args.model_name_or_path)

        num_cls = 34  # width of the result head (the 34 pre-training labels)
        self.res_linear = nn.Linear(self.config.hidden_size, num_cls)
        if args.num_labels == 3:
            # Projects the 34-way result distribution down to 3 classes.
            self.final_linear = nn.Linear(num_cls, args.num_labels)
        else:
            self.final_linear = nn.Linear(self.config.hidden_size, args.num_labels)

        self.relu = nn.ReLU()
        self.m = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()
        self.softmax = nn.Softmax(dim=1)

        if path is not None:
            self._load_compatible_weights(path)

    def _load_compatible_weights(self, path):
        """Copy every tensor from ``path``/pytorch_model.bin whose name and shape match."""
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts; the
        # caller moves the model to its device afterwards.
        pretrained_dict = torch.load(os.path.join(path, 'pytorch_model.bin'), map_location='cpu')
        model_dict = self.state_dict()
        to_load = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }

        if len(to_load) != len(model_dict):
            # Part of the model is not covered by the checkpoint (e.g. a head of a
            # different size): start final_linear from a fresh initialisation.
            nn.init.xavier_uniform_(self.final_linear.weight)
            nn.init.zeros_(self.final_linear.bias)

        # Always load the matching tensors, including when the checkpoint covers
        # every key (e.g. a saved EBM_Net), so res_linear/final_linear are restored.
        self.load_state_dict(to_load, strict=False)

    def forward(self, inputs, get_reprs=False):
        """Encode a passage batch and return representations, logits or the loss.

        Args:
            inputs: dict with ``passage_ids``, ``passage_mask`` and
                ``passage_segment_ids`` (B x L each), plus optional
                ``result_labels`` (B) to request the loss.
            get_reprs: if true, return the [CLS] embeddings (B x D).

        Returns:
            The [CLS] embeddings; or logits (B x 3 when ``num_labels == 3``,
            else B x 34); or the mean NLL loss when ``result_labels`` is given.
        """
        cls_embeds = self.bert(
            input_ids=inputs['passage_ids'],
            attention_mask=inputs['passage_mask'],
            token_type_ids=inputs['passage_segment_ids'],
        )[0][:, 0, :]  # B x D

        if get_reprs:
            return cls_embeds

        if self.args.num_labels == 3:
            logits = self.final_linear(self.softmax(self.res_linear(cls_embeds)))  # B x 3
        else:
            logits = self.res_linear(cls_embeds)  # B x 34

        if 'result_labels' in inputs:
            return self.loss(self.m(logits), inputs['result_labels'])
        return logits

    def save_pretrained(self, path):
        """Save the BERT weights/config and the full state dict into ``path``.

        The vocabulary is saved separately by the caller.
        """
        # BERT writes its own files first: older transformers releases (or
        # safe_serialization=False) write them to pytorch_model.bin, which would
        # otherwise overwrite the full EBM_Net state dict saved below.
        self.bert.save_pretrained(path)
        self.config.save_pretrained(path)
        torch.save(self.state_dict(), os.path.join(path, 'pytorch_model.bin'))


In [22]:
import types
import logging

from transformers import BertTokenizer

# Module-level logger shared by the data, training and evaluation helpers.
logger = logging.getLogger(name=__name__)

def main(model_name_or_path, output_dir, **kwargs):
    """Configure and run EBM-Net training, evaluation and/or representation.

    Args:
        model_name_or_path: directory with the BERT config/weights and tokenizer
            files. Its weights initialise the model, and it is also the
            checkpoint evaluated when ``do_train`` is false.
        output_dir: directory for checkpoints, predictions and ``results.json``.
        **kwargs: run options. Each option read below falls back to the default
            in its ``kwargs.get`` call; unknown keys are ignored.
    """
    import glob  # needed to enumerate checkpoints for eval_all_checkpoints

    args = types.SimpleNamespace()
    args.model_name_or_path = model_name_or_path
    args.output_dir = output_dir

    args.train_ctx = kwargs.get("train_ctx", None)
    args.predict_ctx = kwargs.get("predict_ctx", None)
    args.repr_ctx = kwargs.get("repr_ctx", None)
    args.train_pico = kwargs.get("train_pico", None)
    args.predict_pico = kwargs.get("predict_pico", None)
    args.repr_pico = kwargs.get("repr_pico", None)
    args.permutation = kwargs.get("permutation", "ioc")
    args.tokenizer_name = kwargs.get("tokenizer_name", "")
    args.cache_dir = kwargs.get("cache_dir", "")
    args.max_passage_length = kwargs.get("max_passage_length", 256)
    args.max_pico_length = kwargs.get("max_pico_length", 128)
    args.do_train = kwargs.get("do_train", False)
    args.do_eval = kwargs.get("do_eval", False)
    args.do_repr = kwargs.get("do_repr", False)
    args.evaluate_during_training = kwargs.get("evaluate_during_training", False)
    args.do_lower_case = kwargs.get("do_lower_case", False)
    args.per_gpu_train_batch_size = kwargs.get("per_gpu_train_batch_size", 24)
    args.per_gpu_eval_batch_size = kwargs.get("per_gpu_eval_batch_size", 24)
    args.learning_rate = kwargs.get("learning_rate", 5e-5)
    args.gradient_accumulation_steps = kwargs.get("gradient_accumulation_steps", 1)
    args.weight_decay = kwargs.get("weight_decay", 0.0)
    args.adam_epsilon = kwargs.get("adam_epsilon", 1e-8)
    args.max_grad_norm = kwargs.get("max_grad_norm", 1.0)
    args.num_train_epochs = kwargs.get("num_train_epochs", 3)
    args.max_steps = kwargs.get("max_steps", -1)
    args.warmup_steps = kwargs.get("warmup_steps", 400)
    args.logging_steps = kwargs.get("logging_steps", 200)
    args.save_steps = kwargs.get("save_steps", 100)
    args.eval_all_checkpoints = kwargs.get("eval_all_checkpoints", False)
    args.no_cuda = kwargs.get("no_cuda", False)
    args.overwrite_cache = kwargs.get("overwrite_cache", False)
    args.seed = kwargs.get("seed", 42)
    args.local_rank = kwargs.get("local_rank", -1)
    args.pretraining = kwargs.get("pretraining", False)
    args.num_labels = kwargs.get("num_labels", 3)
    args.adversarial = kwargs.get("adversarial", False)
    # load_and_cache_picos reads args.ebmnet; default to the EBM-Net format
    # whenever this is not a pre-training run.
    args.ebmnet = kwargs.get("ebmnet", not args.pretraining)

    args.overwrite_output_dir = kwargs.get("overwrite_output_dir", True)
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    # Count only GPUs that will actually be used, so no_cuda never triggers
    # DataParallel or CUDA seeding.
    args.n_gpu = torch.cuda.device_count() if use_cuda else 0
    args.device = device

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1)
    )

    # Seed before building the model so any freshly initialised head is reproducible.
    set_seed(args)

    tokenizer = BertTokenizer.from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case
    )

    model = EBM_Net(args, path=args.model_name_or_path)
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    os.makedirs(args.output_dir, exist_ok=True)

    if args.do_train:
        train_ctxs = load_and_cache_ctxs(args, tokenizer, evaluate=False, pretraining=args.pretraining)
        train_picos = load_and_cache_picos(args, tokenizer, evaluate=False, pretraining=args.pretraining)
        global_step, tr_loss = train(args, train_picos, train_ctxs, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    if args.do_eval:
        eval_ctxs = load_and_cache_ctxs(args, tokenizer, evaluate=True)
        eval_picos = load_and_cache_picos(args, tokenizer, evaluate=True)

        results = {}

        checkpoints = [args.output_dir] if args.do_train else [args.model_name_or_path]

        if args.eval_all_checkpoints:
            checkpoints = [
                os.path.dirname(c)
                for c in sorted(glob.glob(args.output_dir + "/**/" + 'pytorch_model.bin', recursive=True))
            ]

            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            if 'checkpoint' not in checkpoint:
                global_step = 'final'
            else:
                global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""

            model = EBM_Net(args, path=checkpoint)
            model.to(args.device)

            result = evaluate(args, eval_picos, eval_ctxs, model, tokenizer, prefix=global_step)

            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
            results.update(result)

            if 'checkpoint' in checkpoint and args.do_train and args.eval_all_checkpoints:
                # Free disk space; skip files that were never written.
                for weights_file in ('full_model.bin', 'pytorch_model.bin'):
                    weights_path = os.path.join(checkpoint, weights_file)
                    if os.path.exists(weights_path):
                        os.remove(weights_path)

        logger.info("Results: {}".format(results))
        with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
            results = {k: float(v) for k, v in results.items()}
            json.dump(results, f, indent=4)

    if args.do_repr:
        logger.info("Representing...")

        model = EBM_Net(args, path=args.model_name_or_path)
        model.to(args.device)

        represent(args, model, tokenizer)

Pretraining

In [None]:
# Stage 1: pre-train from the BioBERT checkpoint on the indexed evidence/context
# pairs. num_labels=34 trains directly on the 34-way result labels, and
# adversarial=True also adds each entry's reversed ('rev_pico'/'rev_label') example.
model_name_or_path = f'{local_path}/biobert-v1.1'
output_dir = f'{local_path}/pretrained_model'
train_ctx = f'{pretraining_dataset_path}/indexed_contexts.json'
train_pico = f'{pretraining_dataset_path}/indexed_evidence.json'
num_labels = 34
do_train = pretraining = adversarial = True

main(
    model_name_or_path,
    output_dir,
    do_train=do_train,
    train_pico=train_pico,
    train_ctx=train_ctx,
    num_labels=num_labels,
    pretraining=pretraining,
    adversarial=adversarial,
)



Finetuning

In [None]:
# Stage 2: fine-tune the pre-trained model on the evidence-integration training
# split, then evaluate it on the validation split.
model_name_or_path = f'{local_path}/pretrained_model'
output_dir = f'{local_path}/drug_ebmnet_model'
train_pico = f'{evidence_integration_path}/indexed_train_picos.json'
train_ctx = f'{evidence_integration_path}/indexed_train_ctxs.json'
predict_pico = f'{evidence_integration_path}/indexed_validation_picos.json'
predict_ctx = f'{evidence_integration_path}/indexed_validation_ctxs.json'
do_train = do_eval = True

main(
    model_name_or_path,
    output_dir,
    do_train=do_train,
    train_pico=train_pico,
    train_ctx=train_ctx,
    do_eval=do_eval,
    predict_pico=predict_pico,
    predict_ctx=predict_ctx,
)