In [1]:
%load_ext autoreload
%autoreload 2

import os
from utils import set_seed
import torch

set_seed(1234)
os.environ['HF_HOME'] = 'hf_cache' # Don't want model files in our home directory due to disk quota

torch.cuda.is_available()

True

In [2]:
import json

from transformers import LukeTokenizer

import const
from model import DocRedModel

MODEL_NAME = const.LUKE_LARGE

# Relation label -> integer id mapping. The `with` block guarantees the file
# handle is closed as soon as the JSON has been parsed.
with open('data/meta/rel2id.json') as rel2id_file:
    rel2id = json.load(rel2id_file)

# Inverse mapping (integer id -> relation label).
id2rel = {rel_id: rel_name for rel_name, rel_id in rel2id.items()}

# add_prefix_space=True: read_docred tokenizes word by word (DocRED ships
# pre-split words), so every word needs the leading-space treatment.
tokenizer = LukeTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)

# One output class per entry in rel2id; the model is moved to the project-configured device.
model = DocRedModel(model_name=MODEL_NAME,
                    tokenizer=tokenizer,
                    num_class=len(rel2id)).to(const.DEVICE)



In [3]:
from utils import read_docred
import json

# Single source of truth for where the DocRED splits live; change it here
# (rather than in every call) when running on another machine.
DOCRED_DIR = '/data2/nhanse02/thesis/data'

# Both splits are parsed with the same label mapping and tokenizer.
train_samples = read_docred(fp=f'{DOCRED_DIR}/train_annotated.json', rel2id=rel2id, tokenizer=tokenizer)
dev_samples = read_docred(fp=f'{DOCRED_DIR}/dev.json', rel2id=rel2id, tokenizer=tokenizer)

/data2/nhanse02/thesis/data/train_annotated.json: 100%|██████████| 3053/3053 [00:22<00:00, 136.23it/s]
/data2/nhanse02/thesis/data/dev.json: 100%|██████████| 998/998 [00:06<00:00, 152.62it/s]


In [4]:
from torch.utils.data import DataLoader
from utils import collate_fn

# Training batches: small batch size, shuffled order, and the incomplete
# trailing batch is discarded.
train_dataloader = DataLoader(
    train_samples,
    collate_fn=collate_fn,
    batch_size=4,
    shuffle=True,
    drop_last=True,
)

# Evaluation batches: original sample order and every sample retained, so
# the predictions cover all of dev_samples.
dev_dataloader = DataLoader(
    dev_samples,
    collate_fn=collate_fn,
    batch_size=8,
    shuffle=False,
    drop_last=False,
)

In [5]:
from train import train

# Run the full training schedule. Per the recorded output of this cell, each
# epoch is followed by a validation pass that reports F1 and F1 Ign.
train(
    model=model,
    num_epochs=30,
    # data used for optimisation
    train_dataloader=train_dataloader,
    # data and label mapping passed for the per-epoch evaluation
    dev_dataloader=dev_dataloader,
    dev_samples=dev_samples,
    id2rel=id2rel,
)

Train Epoch 1/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.109, run_loss=0.311] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.83batch/s]


Epoch 1/30 F1: 0.4178 F1 Ign: 0.4108


Train Epoch 2/30: 100%|██████████| 763/763 [07:06<00:00,  1.79batch/s, cur_loss=0.0813, run_loss=0.103] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.87batch/s]


Epoch 2/30 F1: 0.5249 F1 Ign: 0.5094


Train Epoch 3/30: 100%|██████████| 763/763 [07:07<00:00,  1.79batch/s, cur_loss=0.0878, run_loss=0.0868]
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.83batch/s]


Epoch 3/30 F1: 0.5326 F1 Ign: 0.5146


Train Epoch 4/30: 100%|██████████| 763/763 [07:09<00:00,  1.78batch/s, cur_loss=0.0341, run_loss=0.0738]
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.77batch/s]


Epoch 4/30 F1: 0.5612 F1 Ign: 0.5394


Train Epoch 5/30: 100%|██████████| 763/763 [07:11<00:00,  1.77batch/s, cur_loss=0.0395, run_loss=0.0631]
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.72batch/s]


Epoch 5/30 F1: 0.5863 F1 Ign: 0.5667


Train Epoch 6/30: 100%|██████████| 763/763 [06:52<00:00,  1.85batch/s, cur_loss=0.0459, run_loss=0.0552]
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.74batch/s]


Epoch 6/30 F1: 0.5920 F1 Ign: 0.5671


Train Epoch 7/30: 100%|██████████| 763/763 [07:09<00:00,  1.78batch/s, cur_loss=0.0622, run_loss=0.0475]
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.86batch/s]


Epoch 7/30 F1: 0.5943 F1 Ign: 0.5690


Train Epoch 8/30: 100%|██████████| 763/763 [07:10<00:00,  1.77batch/s, cur_loss=0.0443, run_loss=0.0409] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.83batch/s]


Epoch 8/30 F1: 0.5940 F1 Ign: 0.5730


Train Epoch 9/30: 100%|██████████| 763/763 [07:09<00:00,  1.78batch/s, cur_loss=0.0257, run_loss=0.034]  
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.76batch/s]


Epoch 9/30 F1: 0.6119 F1 Ign: 0.5920


Train Epoch 10/30: 100%|██████████| 763/763 [07:12<00:00,  1.76batch/s, cur_loss=0.0348, run_loss=0.0295] 
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.78batch/s]


Epoch 10/30 F1: 0.5971 F1 Ign: 0.5800


Train Epoch 11/30: 100%|██████████| 763/763 [07:11<00:00,  1.77batch/s, cur_loss=0.0137, run_loss=0.025]  
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.83batch/s]


Epoch 11/30 F1: 0.6000 F1 Ign: 0.5815


Train Epoch 12/30: 100%|██████████| 763/763 [07:10<00:00,  1.77batch/s, cur_loss=0.0455, run_loss=0.022]  
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.84batch/s]


Epoch 12/30 F1: 0.6005 F1 Ign: 0.5833


Train Epoch 13/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.0267, run_loss=0.0187] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.93batch/s]


Epoch 13/30 F1: 0.6131 F1 Ign: 0.5922


Train Epoch 14/30: 100%|██████████| 763/763 [07:07<00:00,  1.78batch/s, cur_loss=0.0147, run_loss=0.0174] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.90batch/s]


Epoch 14/30 F1: 0.6122 F1 Ign: 0.5917


Train Epoch 15/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.0164, run_loss=0.0153] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.87batch/s]


Epoch 15/30 F1: 0.6060 F1 Ign: 0.5862


Train Epoch 16/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.0155, run_loss=0.0137] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.95batch/s]


Epoch 16/30 F1: 0.6100 F1 Ign: 0.5932


Train Epoch 17/30: 100%|██████████| 763/763 [07:07<00:00,  1.79batch/s, cur_loss=0.00231, run_loss=0.0119] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.87batch/s]


Epoch 17/30 F1: 0.6179 F1 Ign: 0.5970


Train Epoch 18/30: 100%|██████████| 763/763 [07:06<00:00,  1.79batch/s, cur_loss=0.0148, run_loss=0.0112]  
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.91batch/s]


Epoch 18/30 F1: 0.6212 F1 Ign: 0.6021


Train Epoch 19/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.00669, run_loss=0.0101] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.88batch/s]


Epoch 19/30 F1: 0.6250 F1 Ign: 0.6048


Train Epoch 20/30: 100%|██████████| 763/763 [07:10<00:00,  1.77batch/s, cur_loss=0.0112, run_loss=0.00904]  
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.85batch/s]


Epoch 20/30 F1: 0.6247 F1 Ign: 0.6068


Train Epoch 21/30: 100%|██████████| 763/763 [07:14<00:00,  1.76batch/s, cur_loss=0.00286, run_loss=0.00852] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 21/30 F1: 0.6268 F1 Ign: 0.6089


Train Epoch 22/30: 100%|██████████| 763/763 [07:09<00:00,  1.78batch/s, cur_loss=0.008, run_loss=0.00753]   
Validation: 100%|██████████| 125/125 [00:46<00:00,  2.71batch/s]


Epoch 22/30 F1: 0.6291 F1 Ign: 0.6082


Train Epoch 23/30: 100%|██████████| 763/763 [07:11<00:00,  1.77batch/s, cur_loss=0.00574, run_loss=0.00701] 
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.73batch/s]


Epoch 23/30 F1: 0.6244 F1 Ign: 0.6064


Train Epoch 24/30: 100%|██████████| 763/763 [07:09<00:00,  1.78batch/s, cur_loss=0.000482, run_loss=0.00656]
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 24/30 F1: 0.6288 F1 Ign: 0.6108


Train Epoch 25/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.00776, run_loss=0.00603] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 25/30 F1: 0.6296 F1 Ign: 0.6114


Train Epoch 26/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.00811, run_loss=0.00543] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.93batch/s]


Epoch 26/30 F1: 0.6334 F1 Ign: 0.6149


Train Epoch 27/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.00103, run_loss=0.00504] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.89batch/s]


Epoch 27/30 F1: 0.6310 F1 Ign: 0.6134


Train Epoch 28/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.00335, run_loss=0.00461] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.86batch/s]


Epoch 28/30 F1: 0.6326 F1 Ign: 0.6148


Train Epoch 29/30: 100%|██████████| 763/763 [07:10<00:00,  1.77batch/s, cur_loss=0.00685, run_loss=0.0043]  
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 29/30 F1: 0.6348 F1 Ign: 0.6166


Train Epoch 30/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.00419, run_loss=0.004]   
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.83batch/s]


Epoch 30/30 F1: 0.6336 F1 Ign: 0.6156
