In [None]:
%load_ext autoreload
%autoreload 2

import os

# Export environment variables BEFORE importing anything that might read them.
# Libraries such as OpenBLAS and the Hugging Face hub resolve these settings when
# they are first imported, so a value set afterwards can be silently ignored.
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['HF_HOME'] = 'hf_cache' # Don't want model files in our home directory due to disk quota

from utils import set_seed
import torch

# Seed once, up front, so the rest of the notebook runs from a known RNG state.
set_seed(1234)

# Last expression of the cell: displays whether a CUDA device is visible.
torch.cuda.is_available()

In [None]:
from transformers import LukeTokenizer
import const
from model import DocRedModel
import json
from utils import get_holdouts, remove_holdouts, read_docred, collate_fn, add_pseudolabels
from train import train_official, train_contr_candidates, train_contr_cluster
from torch.utils.data import DataLoader

# Encoder checkpoint and the learning rate registered for it in const.LRS.
MODEL_NAME = const.LUKE_BASE
ENCODER_LR = const.LRS[MODEL_NAME]

# Batch sizes count documents to encode, not entity-pair samples.
TRAIN_BATCH_SIZE, DEV_BATCH_SIZE = 4, 8

# Epoch budget for each stage: candidate -> cluster -> official.
CAND_EPOCHS, CLUST_EPOCHS, OFFICIAL_EPOCHS = 20, 30, 30

# Contrastive candidate stage settings.
CAND_TMP = 0.01
CONTR_CAND_SUP_WT = 0.5
CONTR_CAND_EMBED_SIZE = 768
CONTR_CAND_NORMALIZE = True

# Contrastive cluster stage settings.
CLUST_TMP = 0.01
CONTR_CLUST_SUP_WT = 0.5
CONTR_CLUST_EMBED_SIZE = 768
CONTR_CLUST_NORMALIZE = True

# Output subdirectory organization (each stage's directory is nested inside its parent's)
# Contrastive Model
# --- Dual-Objective Model (uses candidate mask from parent)
# ------ Official Holdout Model (uses candidate mask/clusters from parent)

def _load_json(path):
    """Parse the JSON file at `path`, closing the file handle once it has been read."""
    with open(path) as json_file:
        return json.load(json_file)


def run_pipline():
    """Run the candidate -> cluster -> official training stages for each holdout batch.

    For every selected holdout relation batch the function:
      1. Re-reads the train/dev samples and strips the held-out relations.
      2. Trains the contrastive candidate model, unless its output directory already exists.
      3. Trains the contrastive cluster model, resuming from the latest checkpoint when a
         previous run recorded fewer than CLUST_EPOCHS epochs, or skipping when it finished.
      4. Adds the pseudolabels written by the cluster stage to the samples (in place) and
         trains the official model, unless its output directory already exists.

    Stage outputs are nested: out/holdout-batch-{i}/<candidate dir>/<cluster dir>/<official dir>.
    All configuration comes from the module-level constants; nothing is returned.
    """
    rel2id_original = _load_json('data/meta/rel2id.json')
    id2rel_original = {v: k for k, v in rel2id_original.items()}
    tokenizer = LukeTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True) # Prefix space since we are doing word-level tokenization in read_docred (docred is weird)

    holdout_rel_batches = get_holdouts(train_samples_fp=const.TRAIN_SAMPLES_FP, # Get rel batches
                                       dev_samples_fp=const.DEV_SAMPLES_FP,
                                       rel2id=rel2id_original,
                                       id2rel=id2rel_original,
                                       tokenizer=tokenizer)

    for i, holdout_rels in enumerate(holdout_rel_batches): # NOTE: Only do first batch for now
        # NOTE(review): this skips batches 1 and 2 only; any batch with index >= 3 would
        # still run. Confirm that matches the "first batch only" intent.
        if i == 1 or i == 2:
            continue

        out_dir = os.path.join('out', f'holdout-batch-{i}')

        # Samples are re-read for every batch because add_pseudolabels mutates them in place below.
        train_samples = read_docred(fp=const.TRAIN_SAMPLES_FP, rel2id=rel2id_original, tokenizer=tokenizer)
        dev_samples = read_docred(fp=const.DEV_SAMPLES_FP, rel2id=rel2id_original, tokenizer=tokenizer)

        train_samples, dev_samples, rel2id_holdout, id2rel_holdout = remove_holdouts(train_samples=train_samples,
                                                                                     dev_samples=dev_samples,
                                                                                     holdout_rels=holdout_rels,
                                                                                     rel2id=rel2id_original,
                                                                                     id2rel=id2rel_original)

        train_dataloader = DataLoader(train_samples, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate_fn, drop_last=True)
        val_train_dataloader = DataLoader(train_samples, batch_size=DEV_BATCH_SIZE, shuffle=False, collate_fn=collate_fn, drop_last=False)
        dev_dataloader = DataLoader(dev_samples, batch_size=DEV_BATCH_SIZE, shuffle=False, collate_fn=collate_fn, drop_last=False)

        # -=-=-=-=-==-=-=-=-=-
        # CANDIDATE TRAINING
        # -=-=-=-=-==-=-=-=-=-
        cand_dir = os.path.join(out_dir, f"{const.MODE_CONTRASTIVE_CANDIDATES}_{MODEL_NAME.split('/')[-1]}_tmp-{CAND_TMP}_supw-{CONTR_CAND_SUP_WT}_embed-{CONTR_CAND_EMBED_SIZE}_norm-{CONTR_CAND_NORMALIZE}")
        if not os.path.exists(cand_dir): # Train only when no candidate output directory exists yet
            print("CANDIDATE TRAINING...")
            cand_model = DocRedModel(model_name=MODEL_NAME,
                                     tokenizer=tokenizer,
                                     num_class=len(rel2id_holdout),
                                     mode=const.MODE_CONTRASTIVE_CANDIDATES,
                                     contr_tmp=CAND_TMP,
                                     contr_cand_sup_wt=CONTR_CAND_SUP_WT,
                                     out_embed_size=CONTR_CAND_EMBED_SIZE).to(const.DEVICE)

            train_contr_candidates(model=cand_model,
                                   train_dataloader=train_dataloader,
                                   val_train_dataloader=val_train_dataloader,
                                   id2rel_original=id2rel_original,
                                   normalize_embeds=CONTR_CAND_NORMALIZE,
                                   encoder_lr=ENCODER_LR,
                                   num_epochs=CAND_EPOCHS,
                                   out_dir=cand_dir)
            # Drop the model and release cached GPU memory before the next stage builds its own.
            del cand_model
            with torch.cuda.device(const.DEVICE):
                torch.cuda.empty_cache()
        else:
            print("CONTRASTIVE CANDIDATE TRAINING COMPLETED PREVIOUSLY, SKIPPING...")

        # -=-=-=-=-==-=-=-=-=-
        # CLUSTER TRAINING
        # -=-=-=-=-==-=-=-=-=-
        clust_dir = os.path.join(cand_dir, f"{const.MODE_CONTRASTIVE_CLUSTER}_{MODEL_NAME.split('/')[-1]}_tmp-{CLUST_TMP}_supw-{CONTR_CLUST_SUP_WT}_embed-{CONTR_CLUST_EMBED_SIZE}_norm-{CONTR_CLUST_NORMALIZE}")
        clust_model = DocRedModel(model_name=MODEL_NAME,
                                  tokenizer=tokenizer,
                                  num_class=len(rel2id_holdout),
                                  mode=const.MODE_CONTRASTIVE_CLUSTER,
                                  contr_tmp=CLUST_TMP,
                                  contr_clust_sup_wt=CONTR_CLUST_SUP_WT,
                                  out_embed_size=CONTR_CLUST_EMBED_SIZE).to(const.DEVICE)
        clust_train = True
        clust_last_epoch = 0 # Load last epoch if exists
        if os.path.exists(clust_dir):
            # The last entry of stats.json records the most recent epoch that was logged.
            clust_last_epoch = _load_json(os.path.join(clust_dir, 'stats', 'stats.json'))[-1]['epoch']
            if clust_last_epoch >= CLUST_EPOCHS:
                print("CONTRASTIVE CLUSTER TRAINING COMPLETED PREVIOUSLY, SKIPPING...")
                clust_train = False
            else:
                print("LOADING LAST CLUSTERING MODEL CHECKPOINT...")
                clust_model.load_state_dict(torch.load(os.path.join(clust_dir, 'checkpoints-lower-epoch', 'latest-checkpoint.pt'), weights_only=True))

        if clust_train:
            print("CLUSTER TRAINING...")
            train_contr_cluster(model=clust_model,
                                train_samples=train_samples,
                                train_dataloader=train_dataloader,
                                val_train_dataloader=val_train_dataloader,
                                cand_mask=torch.load(os.path.join(cand_dir, 'checkpoints', 'latest-contr-cand-mask.pt')),
                                id2rel_original=id2rel_original,
                                id2rel_holdout=id2rel_holdout,
                                normalize_embeds=CONTR_CLUST_NORMALIZE,
                                encoder_lr=ENCODER_LR,
                                num_epochs=CLUST_EPOCHS,
                                last_epoch=clust_last_epoch,
                                out_dir=clust_dir)
        del clust_model
        with torch.cuda.device(const.DEVICE):
            torch.cuda.empty_cache()

        # -=-=-=-=-==-=-=-=-=-
        # OFFICIAL TRAINING
        # -=-=-=-=-==-=-=-=-=-
        # Load pseudolabels to train samples and dev samples -> modifies samples in place
        id2rel_holdout_update_raw = _load_json(os.path.join(clust_dir, 'checkpoints', 'latest-id2rel-holdout-update.json'))
        rel2id_holdout, id2rel_holdout = add_pseudolabels(train_samples=train_samples,
                                                          dev_samples=dev_samples,
                                                          pseudolabels=torch.load(os.path.join(clust_dir, 'checkpoints', 'latest-pseudolabels.pt')),
                                                          id2rel_holdout_update={int(k): v for k, v in id2rel_holdout_update_raw.items()}, # NOTE: have to cast int keys to int, json doesn't support int keys
                                                          id2rel_holdout=id2rel_holdout,
                                                          id2rel_original=id2rel_original)

        official_dir = os.path.join(clust_dir, f"{const.MODE_OFFICIAL}_{MODEL_NAME.split('/')[-1]}")
        if not os.path.exists(official_dir):
            print("OFFICIAL TRAINING...")

            # num_class uses the rel2id_holdout returned by add_pseudolabels just above.
            official_model = DocRedModel(model_name=MODEL_NAME,
                                         tokenizer=tokenizer,
                                         num_class=len(rel2id_holdout),
                                         mode=const.MODE_OFFICIAL).to(const.DEVICE)

            train_official(model=official_model,
                           train_dataloader=train_dataloader,
                           dev_dataloader=dev_dataloader,
                           dev_samples=dev_samples,
                           id2rel_holdout=id2rel_holdout,
                           encoder_lr=ENCODER_LR,
                           num_epochs=OFFICIAL_EPOCHS,
                           out_dir=official_dir)

            del official_model
            with torch.cuda.device(const.DEVICE):
                torch.cuda.empty_cache()
        else:
            print("OFFICIAL TRAINING COMPLETED PREVIOUSLY, SKIPPING...")

# Execute the full pipeline when this cell runs.
run_pipline()

In [None]:
# def run_official():
#     model_name = const.LUKE_BASE
#     out_dir = os.path.join('out', 'official_base')

#     rel2id_original = json.load(open('data/meta/rel2id.json'))
#     id2rel_original = {v: k for k, v in rel2id_original.items()}
#     tokenizer = LukeTokenizer.from_pretrained(model_name, add_prefix_space=True) # Prefix space since we are doing word-level tokenization in read_docred (docred is weird)

#     train_samples = read_docred(fp=const.TRAIN_SAMPLES_FP, rel2id=rel2id_original, tokenizer=tokenizer)
#     dev_samples = read_docred(fp=const.DEV_SAMPLES_FP, rel2id=rel2id_original, tokenizer=tokenizer)  

#     train_dataloader = DataLoader(train_samples, batch_size=4, shuffle=True, collate_fn=collate_fn, drop_last=True)
#     dev_dataloader = DataLoader(dev_samples, batch_size=8, shuffle=False, collate_fn=collate_fn, drop_last=False)

#     model = DocRedModel(model_name=model_name,
#                         tokenizer=tokenizer,
#                         num_class=len(rel2id_original),
#                         mode=const.MODE_OFFICIAL).to(const.DEVICE)
    
#     train_official(model=model,
#                    train_dataloader=train_dataloader,
#                    dev_dataloader=dev_dataloader,
#                    dev_samples=dev_samples,
#                    id2rel_holdout=id2rel_original,
#                    encoder_lr=const.LRS[model_name],
#                    num_epochs=30,
#                    out_dir=out_dir)
# run_official()