In [None]:
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ModuleNotFoundError as e:
    print("not in colab")
    pass
import os
base_dir = "/content/drive/MyDrive/semeval2022"
if not os.path.exists(base_dir):
  !pip install -r requirements.txt
  base_dir = ""
else:
  !cp -rf /content/drive/MyDrive/semeval2022/*.py . 
  !cp -rf /content/drive/MyDrive/semeval2022/utils .
  !cp -rf /content/drive/MyDrive/semeval2022/model .
  !pip install -r /content/drive/MyDrive/semeval2022/requirements.tx

In [None]:
from utils.util import get_entity_vocab, get_reader, train_model, test_model, val_model, write_result, create_model, save_model, parse_args, get_tagset, wnut_iob, write_submit_result, write_result, load_model, luke_iob, k_fold, vote, wait_gc
import time

In [None]:
# Pretrained encoder used for every fold. Previously this name was assigned
# several times in a row and only the last value took effect; the alternatives
# that were listed are: "distilbert-base-uncased", "roberta-base",
# "bert-base-uncased".
encoder_model = "bert-large-uncased-whole-word-masking"

# Track identifier; it is spliced into file names below, so it also acts as a
# relative path fragment.
track = "EN-English/en"

# Input data files for the selected track.
train_file = os.path.join(base_dir, "training_data/{}_train.conll".format(track))
dev_file = os.path.join(base_dir, "training_data/{}_dev.conll".format(track))
test_file = os.path.join(base_dir, "training_data/{}_test.conll".format(track))

# Entity file handed to get_entity_vocab. The previously listed alternative was
# "data/wiki_def/wiki_abstract.vocab".
wiki_file = os.path.join(base_dir, "data/wiki_def/wiki_pkl.zip")

# Output locations: model directory and the final submission file.
output_dir = os.path.join(base_dir, track, "{}-train".format(encoder_model))
submission_file = os.path.join(base_dir, "submission", "{}.pred.conll".format(track))

# Tag set and decoding options shared by the training and test helpers.
iob_tagging = wnut_iob
use_crf = False

In [None]:
def train_for_k_fold(train_file, dev_file):
  """Train a model on one fold, reload its best checkpoint and score the dev split.

  Relies on the notebook-level settings ``wiki_file``, ``iob_tagging``,
  ``encoder_model``, ``use_crf``, ``output_dir`` and ``track``.

  Args:
    train_file: path of the CoNLL file to train on.
    dev_file: path of the CoNLL file used as the validation split.

  Returns:
    A tuple ``(best_checkpoint, out_model_path, submission_dev_file)``: the
    first two come from ``save_model``; the third is the path passed to
    ``write_result`` for the 'val' stage.
  """
  # Reader settings that are identical for the train and dev readers.
  shared_reader_args = dict(target_vocab=iob_tagging, encoder_model=encoder_model,
                            max_instances=-1, max_length=100)

  # The training vocabulary is built from the wiki file alone, while the dev
  # vocabulary additionally receives the training CoNLL file.
  train_vocab = get_entity_vocab(conll_files=[], entity_files=[wiki_file])
  train_set = get_reader(file_path=train_file, entity_vocab=train_vocab,
                         augment=[], **shared_reader_args)
  dev_vocab = get_entity_vocab(conll_files=[train_file], entity_files=[wiki_file])
  dev_set = get_reader(file_path=dev_file, entity_vocab=dev_vocab,
                       augment=[], **shared_reader_args)

  net = create_model(
      train_data=train_set,
      dev_data=dev_set,
      tag_to_id=iob_tagging,
      dropout_rate=0.1,
      batch_size=32,
      stage='fit',
      lr=2e-5,
      encoder_model=encoder_model,
      num_gpus=1,
      use_crf=use_crf,
  )
  fold_trainer = train_model(model=net, out_dir=output_dir, epochs=20,
                             monitor="val_micro@F1")

  # Persist the run, then continue with the checkpoint save_model reports as best.
  out_model_path, best_checkpoint = save_model(
      trainer=fold_trainer, out_dir=output_dir, model_name=encoder_model,
      timestamp=time.time())
  net = load_model(best_checkpoint, iob_tagging, use_crf=use_crf)
  net.dev_data = dev_set

  # Predictions are written next to the saved model.
  submission_dev_file = os.path.join(os.path.dirname(out_model_path),
                                     "{}.pred.conll.dev".format(track))
  val_model(net)
  write_result(net, submission_dev_file, 'val')

  # Drop the references to the large objects before handing control back.
  del train_set, dev_set, net, fold_trainer, dev_vocab, train_vocab
  return best_checkpoint, out_model_path, submission_dev_file

def test(best_checkpoint, out_model_path, test_entity_vocab):
  """Evaluate a saved checkpoint on ``test_file`` and write the result file.

  Relies on the notebook-level settings ``test_file``, ``iob_tagging``,
  ``encoder_model``, ``use_crf`` and ``track``.

  Args:
    best_checkpoint: checkpoint reference passed to ``load_model``.
    out_model_path: path of the saved model; the result file is placed in
      the same directory.
    test_entity_vocab: entity vocabulary handed to the test reader.

  Returns:
    Path of the file passed to ``write_result`` for the 'test' stage.
  """
  net = load_model(best_checkpoint, iob_tagging, use_crf=use_crf)
  net.test_data = get_reader(
      file_path=test_file,
      target_vocab=iob_tagging,
      encoder_model=encoder_model,
      max_instances=-1,
      max_length=100,
      entity_vocab=test_entity_vocab,
      augment=[],
  )
  test_model(net)

  model_dir = os.path.dirname(out_model_path)
  submission_test_file = os.path.join(model_dir, "{}.pred.conll.test".format(track))
  write_result(net, submission_test_file, 'test')

  # Release the model (and the reader it holds) before returning.
  del net
  return submission_test_file

In [None]:
# Build the per-fold (train, dev) file pairs from the configured data files.
# The third argument is presumably the number of folds - confirm against k_fold.
output_files = k_fold(train_file, dev_file, 10)
dev_pred_files = []
dev_label_files = []
test_files = []
out_model_paths = []
# Never populated in this cell; kept in case a later cell refers to it.
submission_dev_files = []
# The test-time vocabulary uses the full (unsplit) train and dev files.
test_entity_vocab = get_entity_vocab(conll_files=[train_file, dev_file], entity_files=[wiki_file])

# Loop variables are fold-specific on purpose: reusing the names `train_file`
# and `dev_file` would overwrite the notebook-level configuration and make a
# re-run of this cell split the last fold's files instead of the real data.
for fold_train_file, fold_dev_file in output_files:
    best_checkpoint, out_model_path, submission_dev_file = train_for_k_fold(fold_train_file, fold_dev_file)
    out_model_paths.append(out_model_path)
    dev_label_files.append(fold_dev_file)
    dev_pred_files.append(submission_dev_file)
    wait_gc()
    submission_test_file = test(best_checkpoint, out_model_path, test_entity_vocab)
    wait_gc()
    test_files.append(submission_test_file)

# Combine the per-fold outputs into the final submission, using the same tag
# set that the training and test helpers were configured with.
vote(dev_label_files, dev_pred_files, test_files, submission_file, iob_tagging)