In [6]:
%load_ext autoreload
%autoreload 2

import os
from utils import set_seed
import torch

set_seed(1234)
os.environ['HF_HOME'] = 'hf_cache' # Don't want model files in our home directory due to disk quota

torch.cuda.is_available()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


True

In [7]:
from transformers import LukeTokenizer
import const
from model import DocRedModel
import json

MODEL_NAME = const.LUKE_LARGE

# Relation label -> class index; its size fixes the classifier's output width.
# Use a context manager so the file handle is closed, and read as UTF-8 (the
# JSON standard encoding) instead of the platform locale default.
with open('data/meta/rel2id.json', encoding='utf-8') as rel2id_file:
    rel2id = json.load(rel2id_file)
# Inverse mapping, used to turn predicted class indices back into relation labels.
id2rel = {idx: rel for rel, idx in rel2id.items()}

# NOTE(review): add_prefix_space=True is presumably required because
# `read_docred` feeds pre-split words to this BPE tokenizer; confirm in utils.
tokenizer = LukeTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)
model = DocRedModel(model_name=MODEL_NAME,
                    tokenizer=tokenizer,
                    num_class=len(rel2id)).to(const.DEVICE)



In [8]:
from utils import read_docred

# Directory holding the DocRED split files; change it here instead of in each path.
# NOTE(review): cell above reads 'data/meta/rel2id.json' relative to the CWD —
# confirm whether this is the same data root and could be made relative too.
DOCRED_DIR = '/data2/nhanse02/thesis/data'

# Paths are passed as plain strings (read_docred also shows them as the progress-bar label).
train_samples = read_docred(fp=f'{DOCRED_DIR}/train_annotated.json', rel2id=rel2id, tokenizer=tokenizer)
dev_samples = read_docred(fp=f'{DOCRED_DIR}/dev.json', rel2id=rel2id, tokenizer=tokenizer)

/data2/nhanse02/thesis/data/train_annotated.json: 100%|██████████| 3053/3053 [00:21<00:00, 140.01it/s]
/data2/nhanse02/thesis/data/dev.json: 100%|██████████| 998/998 [00:08<00:00, 121.84it/s]


In [9]:
from torch.utils.data import DataLoader
from utils import collate_fn

TRAIN_BATCH_SIZE = 4
DEV_BATCH_SIZE = 8

# Training loader: reshuffled every epoch; the trailing partial batch is discarded.
train_dataloader = DataLoader(
    train_samples,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

# Evaluation loader: deterministic order and every document kept.
# NOTE(review): order likely matters because `train` also receives dev_samples
# and presumably aligns predictions to them by position — confirm.
dev_dataloader = DataLoader(
    dev_samples,
    batch_size=DEV_BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    drop_last=False,
)

In [10]:
from train import train

NUM_EPOCHS = 30

# Fine-tune the model, validating on the dev split after each epoch.
# NOTE(review): nothing in this cell persists the trained weights; confirm
# whether `train` checkpoints internally before relying on this run's results.
train(
    model=model,
    train_dataloader=train_dataloader,
    dev_dataloader=dev_dataloader,
    dev_samples=dev_samples,
    id2rel=id2rel,
    num_epochs=NUM_EPOCHS,
)

Train Epoch 1/30: 100%|██████████| 763/763 [07:15<00:00,  1.75batch/s, cur_loss=0.106, run_loss=0.524] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.82batch/s]


Epoch 1/30 F1: 0.5179 F1 Ign: 0.5013


Train Epoch 2/30: 100%|██████████| 763/763 [07:07<00:00,  1.78batch/s, cur_loss=0.0539, run_loss=0.097] 
Validation: 100%|██████████| 125/125 [00:46<00:00,  2.69batch/s]


Epoch 2/30 F1: 0.5106 F1 Ign: 0.4912


Train Epoch 3/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.135, run_loss=0.0839] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.80batch/s]


Epoch 3/30 F1: 0.5350 F1 Ign: 0.5198


Train Epoch 4/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.0462, run_loss=0.07]  
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.82batch/s]


Epoch 4/30 F1: 0.5554 F1 Ign: 0.5399


Train Epoch 5/30: 100%|██████████| 763/763 [07:06<00:00,  1.79batch/s, cur_loss=0.0347, run_loss=0.0589]
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.90batch/s]


Epoch 5/30 F1: 0.5539 F1 Ign: 0.5435


Train Epoch 6/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.0669, run_loss=0.0513] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.80batch/s]


Epoch 6/30 F1: 0.6043 F1 Ign: 0.5819


Train Epoch 7/30: 100%|██████████| 763/763 [07:08<00:00,  1.78batch/s, cur_loss=0.0278, run_loss=0.0444]
Validation: 100%|██████████| 125/125 [00:46<00:00,  2.71batch/s]


Epoch 7/30 F1: 0.5981 F1 Ign: 0.5764


Train Epoch 8/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.0263, run_loss=0.0389] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.82batch/s]


Epoch 8/30 F1: 0.6013 F1 Ign: 0.5775


Train Epoch 9/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.0365, run_loss=0.0334] 
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.77batch/s]


Epoch 9/30 F1: 0.6018 F1 Ign: 0.5800


Train Epoch 10/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.0198, run_loss=0.029]  
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 10/30 F1: 0.6036 F1 Ign: 0.5833


Train Epoch 11/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.0234, run_loss=0.0253] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.91batch/s]


Epoch 11/30 F1: 0.6064 F1 Ign: 0.5864


Train Epoch 12/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.0116, run_loss=0.0222] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.88batch/s]


Epoch 12/30 F1: 0.6095 F1 Ign: 0.5905


Train Epoch 13/30: 100%|██████████| 763/763 [07:05<00:00,  1.80batch/s, cur_loss=0.0133, run_loss=0.0191] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.78batch/s]


Epoch 13/30 F1: 0.6127 F1 Ign: 0.5920


Train Epoch 14/30: 100%|██████████| 763/763 [06:58<00:00,  1.82batch/s, cur_loss=0.0236, run_loss=0.0175] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.91batch/s]


Epoch 14/30 F1: 0.6145 F1 Ign: 0.5945


Train Epoch 15/30: 100%|██████████| 763/763 [07:00<00:00,  1.81batch/s, cur_loss=0.00377, run_loss=0.0151]
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.86batch/s]


Epoch 15/30 F1: 0.6146 F1 Ign: 0.5941


Train Epoch 16/30: 100%|██████████| 763/763 [07:06<00:00,  1.79batch/s, cur_loss=0.00913, run_loss=0.0137]
Validation: 100%|██████████| 125/125 [00:47<00:00,  2.61batch/s]


Epoch 16/30 F1: 0.6125 F1 Ign: 0.5914


Train Epoch 17/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.00824, run_loss=0.0122] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.87batch/s]


Epoch 17/30 F1: 0.6147 F1 Ign: 0.5966


Train Epoch 18/30: 100%|██████████| 763/763 [07:06<00:00,  1.79batch/s, cur_loss=0.0166, run_loss=0.0112]  
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.86batch/s]


Epoch 18/30 F1: 0.6154 F1 Ign: 0.5942


Train Epoch 19/30: 100%|██████████| 763/763 [07:02<00:00,  1.81batch/s, cur_loss=0.00712, run_loss=0.0102] 
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.90batch/s]


Epoch 19/30 F1: 0.6222 F1 Ign: 0.6040


Train Epoch 20/30: 100%|██████████| 763/763 [07:03<00:00,  1.80batch/s, cur_loss=0.0125, run_loss=0.00952]  
Validation: 100%|██████████| 125/125 [00:43<00:00,  2.88batch/s]


Epoch 20/30 F1: 0.6179 F1 Ign: 0.5982


Train Epoch 21/30: 100%|██████████| 763/763 [07:03<00:00,  1.80batch/s, cur_loss=0.0158, run_loss=0.00836]  
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 21/30 F1: 0.6237 F1 Ign: 0.6058


Train Epoch 22/30: 100%|██████████| 763/763 [07:05<00:00,  1.79batch/s, cur_loss=0.0124, run_loss=0.00762]  
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.83batch/s]


Epoch 22/30 F1: 0.6216 F1 Ign: 0.6038


Train Epoch 23/30: 100%|██████████| 763/763 [07:03<00:00,  1.80batch/s, cur_loss=0.00942, run_loss=0.00709] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 23/30 F1: 0.6245 F1 Ign: 0.6047


Train Epoch 24/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.0132, run_loss=0.00666]  
Validation: 100%|██████████| 125/125 [00:45<00:00,  2.73batch/s]


Epoch 24/30 F1: 0.6280 F1 Ign: 0.6088


Train Epoch 25/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.00877, run_loss=0.006]   
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.92batch/s]


Epoch 25/30 F1: 0.6301 F1 Ign: 0.6112


Train Epoch 26/30: 100%|██████████| 763/763 [07:06<00:00,  1.79batch/s, cur_loss=0.00775, run_loss=0.00559] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.91batch/s]


Epoch 26/30 F1: 0.6310 F1 Ign: 0.6120


Train Epoch 27/30: 100%|██████████| 763/763 [07:02<00:00,  1.81batch/s, cur_loss=0.00525, run_loss=0.00512] 
Validation: 100%|██████████| 125/125 [00:44<00:00,  2.78batch/s]


Epoch 27/30 F1: 0.6314 F1 Ign: 0.6133


Train Epoch 28/30: 100%|██████████| 763/763 [07:03<00:00,  1.80batch/s, cur_loss=0.00426, run_loss=0.00471] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.93batch/s]


Epoch 28/30 F1: 0.6313 F1 Ign: 0.6126


Train Epoch 29/30: 100%|██████████| 763/763 [07:01<00:00,  1.81batch/s, cur_loss=0.00309, run_loss=0.00443] 
Validation: 100%|██████████| 125/125 [00:42<00:00,  2.93batch/s]


Epoch 29/30 F1: 0.6329 F1 Ign: 0.6140


Train Epoch 30/30: 100%|██████████| 763/763 [07:04<00:00,  1.80batch/s, cur_loss=0.00407, run_loss=0.00408] 
Validation: 100%|██████████| 125/125 [00:46<00:00,  2.71batch/s]


Epoch 30/30 F1: 0.6333 F1 Ign: 0.6149
