In [1]:
%load_ext autoreload
%autoreload 2

import os

# Don't want model files in our home directory due to disk quota.
# This is set BEFORE importing `utils`/`torch`: the Hugging Face hub client resolves
# HF_HOME when it is first imported, so assigning it afterwards can be silently ignored.
# NOTE(review): `utils` is not visible here; if it never imports transformers this
# reordering is simply a no-op.
os.environ['HF_HOME'] = 'hf_cache'

from utils import set_seed
import torch

# Fix the RNG seed up front so the run is repeatable.
set_seed(1234)

# Last expression of the cell: shows whether a CUDA device is usable.
torch.cuda.is_available()

True

In [2]:
from transformers import LukeTokenizer
import const
from model import DocRedModel
import json

# --- Experiment configuration -------------------------------------------------
MODEL_NAME = const.LUKE_BASE
MODE = const.MODE_CONTRASTIVE
EPOCHS = 10
# NOTE: LOOK INTO LEARNING RATES
# NOTE(review): ENCODER_LR is not passed to anything in the cells shown here — confirm it is used.
ENCODER_LR = 3e-5
# Contrastive mode trains with a smaller batch (3) than the other modes (4).
TRAIN_BATCH_SIZE = 3 if MODE == const.MODE_CONTRASTIVE else 4
DEV_BATCH_SIZE = 8

# Output directory is keyed by the model's short name (text after the last '/') and the mode.
out_dir = os.path.join('out', 'holdout-batch-0', f"{MODEL_NAME.split('/')[-1]}_{MODE}")

# Load the relation -> id map with a context manager so the file handle is closed
# deterministically (the previous `json.load(open(...))` left it to the garbage collector).
with open('data/meta/rel2id.json') as rel2id_file:
    rel2id_original = json.load(rel2id_file)
# Inverse map: id -> relation.
id2rel_original = {v: k for k, v in rel2id_original.items()}
# Prefix space since we are doing word-level tokenization in read_docred (docred is weird)
tokenizer = LukeTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)



In [3]:
from utils import read_docred
import json

# Parse both DocRED splits with the same relation map and tokenizer.
train_samples = read_docred(
    fp='data/train_annotated.json',
    tokenizer=tokenizer,
    rel2id=rel2id_original,
)
dev_samples = read_docred(
    fp='data/dev.json',
    tokenizer=tokenizer,
    rel2id=rel2id_original,
)

In [None]:
from utils import get_holdouts, remove_holdouts

# Compute the batches of held-out relations from the train/dev samples.
holdout_rel_batches = get_holdouts(
    train_samples=train_samples,
    dev_samples=dev_samples,
    rel2id=rel2id_original,
    id2rel=id2rel_original,
)

# NOTE: only testing with first batch
holdout_rels = holdout_rel_batches[0]

# Strip the held-out relations from the training set and get the reduced label maps.
# CAUTION: `train_samples` is rebound here, so re-running this cell feeds the already
# filtered samples back into get_holdouts above; re-run the loading cell first.
train_samples, rel2id_holdout, id2rel_holdout = remove_holdouts(
    samples=train_samples,
    holdout_rels=holdout_rels,
    rel2id=rel2id_original,
    id2rel=id2rel_original,
)

# Display the selected held-out relations as the cell output.
holdout_rels

In [5]:
from torch.utils.data import DataLoader
from utils import collate_fn
from train import train

# Training loader: shuffled, and the final incomplete batch is dropped.
train_dataloader = DataLoader(
    train_samples,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

# Dev loader: fixed order, every sample kept.
dev_dataloader = DataLoader(
    dev_samples,
    batch_size=DEV_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    drop_last=False,
)

In [None]:
import gc

# Resolve the training device once. The cache clean-up below must target the same device
# the model lives on; the previous hard-coded 'cuda:1' fails on single-GPU/CPU hosts and
# may point at a different GPU than const.DEVICE.
train_device = torch.device(const.DEVICE)

# Sweep the contrastive temperature: each value gets a fresh model and its own output dir.
for contr_tmp in [0.01, 0.05, 0.1]:
    model = DocRedModel(model_name=MODEL_NAME,
                        tokenizer=tokenizer,
                        num_class=len(rel2id_holdout),
                        contrastive_tmp=contr_tmp).to(const.DEVICE)

    train(model=model,
          train_dataloader=train_dataloader,
          dev_dataloader=dev_dataloader,
          train_samples=train_samples,
          dev_samples=dev_samples,
          id2rel_holdout=id2rel_holdout,
          id2rel_original=id2rel_original,
          num_epochs=EPOCHS,
          mode=MODE,
          out_dir=out_dir + f'-tmp-{contr_tmp}')

    # Release this run's model before the next one is built.
    del model
    # Collect reference cycles first so their tensors are actually freed and the
    # allocator can hand the memory back in empty_cache().
    gc.collect()
    if train_device.type == 'cuda':
        with torch.cuda.device(train_device):
            torch.cuda.empty_cache()