# import

In [None]:
import argparse
from tqdm import tqdm
import json
import time
from datetime import datetime, timedelta, timezone
import os
import math
import logging
from collections import Counter
import random

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data as Data
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
import torch
from torch import optim, nn, Tensor

# Base directory that every data / results path in this file is resolved against.
root_dir = '.'

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


# constants

In [None]:
# Maximum sentence length; it sizes the relative-position embedding table.
max_len = 100

# relation label -> id (ids follow the listing order, 0..41)
rel2id = {name: idx for idx, name in enumerate([
    'per:title',
    'per:stateorprovinces_of_residence',
    'per:stateorprovince_of_death',
    'per:stateorprovince_of_birth',
    'per:spouse',
    'per:siblings',
    'per:schools_attended',
    'per:religion',
    'per:parents',
    'per:other_family',
    'per:origin',
    'per:employee_of',
    'per:date_of_death',
    'per:date_of_birth',
    'per:country_of_death',
    'per:country_of_birth',
    'per:countries_of_residence',
    'per:city_of_death',
    'per:city_of_birth',
    'per:cities_of_residence',
    'per:children',
    'per:charges',
    'per:cause_of_death',
    'per:alternate_names',
    'per:age',
    'org:website',
    'org:top_members/employees',
    'org:subsidiaries',
    'org:stateorprovince_of_headquarters',
    'org:shareholders',
    'org:political/religious_affiliation',
    'org:parents',
    'org:number_of_employees/members',
    'org:members',
    'org:member_of',
    'org:founded_by',
    'org:founded',
    'org:dissolved',
    'org:country_of_headquarters',
    'org:city_of_headquarters',
    'org:alternate_names',
    'no_relation',
])}
# id -> relation label (inverse of rel2id)
id2rel = {idx: name for name, idx in rel2id.items()}

# part-of-speech tag -> id ('PAD' is 0 and 'UNK' is 1)
pos2id = {tag: idx for idx, tag in enumerate([
    'PAD', 'UNK', 'NNP', 'NN', 'IN', 'DT', ',', 'JJ', 'NNS', 'VBD',
    'CD', 'CC', '.', 'RB', 'VBN', 'PRP', 'TO', 'VB', 'VBG', 'VBZ',
    'PRP$', ':', 'POS', '\'\'', '``', '-RRB-', '-LRB-', 'VBP', 'MD', 'NNPS',
    'WP', 'WDT', 'WRB', 'RP', 'JJR', 'JJS', '$', 'FW', 'RBR', 'SYM',
    'EX', 'RBS', 'WP$', 'PDT', 'LS', 'UH', '#',
])}

# named-entity tag -> id ('PAD' is 0 and 'UNK' is 1)
ner2id = {tag: idx for idx, tag in enumerate([
    'PAD', 'UNK', 'O', 'PERSON', 'ORGANIZATION', 'LOCATION', 'DATE', 'NUMBER',
    'MISC', 'DURATION', 'MONEY', 'PERCENT', 'ORDINAL', 'TIME', 'SET',
])}


# args

In [None]:
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would turn every non-empty string (including "False")
    into True, so boolean flags are converted explicitly instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}.')


parser = argparse.ArgumentParser()

# batch_size
parser.add_argument('--batch_size', default=32,
                    type=int, help='Size of mini batch.')
# max_epochs
parser.add_argument('--max_epochs', default=30, type=int,
                    help='Max epochs for training.')
# use_glove
parser.add_argument('--use_glove', default=False, type=_str2bool,
                    help='Whether to use Glove as Pre_trained Embedding')
# word_embed_dim
parser.add_argument('--word_embed_dim', default=300, type=int,
                    help='Size of Word Embedding.')
# part-of-speech embed_dim
parser.add_argument('--pos_embed_dim', default=30, type=int,
                    help='Size of Part-of-Speech Embedding.')
# ner_embed_dim
parser.add_argument('--ner_embed_dim', default=30, type=int,
                    help='Size of Named-Entity-Recognition Embedding.')
# position_embed_dim
parser.add_argument('--position_embed_dim', default=30, type=int,
                    help='Size of Position Encoding Embedding.')
# hidden_dim
parser.add_argument('--hidden_dim', default=300, type=int,
                    help='Size of Hidden Layer.')
# attn_dim
parser.add_argument('--attn_dim', default=300, type=int,
                    help='Size of Attention Layer.')
# dropout
parser.add_argument('--dropout', default=0.5, type=float,
                    help='Dropout rate.')
# optimizer
parser.add_argument('--optimizer', default='sgd', type=str,
                    help='Choose Optimizer from SGD/ Adam/ Adadelta.')
# learning rate
parser.add_argument('--lr', default=1e-1, type=float,
                    help='Initial learning rate.')
# momentum
parser.add_argument('--momentum', default=0.9, type=float,
                    help='Momentum factor for SGD.')
# weight_decay
parser.add_argument('--weight_decay', default=1e-2, type=float,
                    help='Weight decay (L2 penalty).')
# grad_clip (a float: the default is 5.0 and fractional norms must be accepted)
parser.add_argument('--grad_clip', default=5.0, type=float,
                    help='Max norm of the gradients clipping.')
# random_seed
parser.add_argument('--random_seed', default=16, type=int,
                    help='Sets the seed for generating random numbers.')
# resume
parser.add_argument('--resume', default=False, type=_str2bool,
                    help='Whether to load the checkpoints to resume training.')

# An empty argument list is parsed on purpose, so only the defaults above apply.
args = parser.parse_args(args=[])


# data_processor

## Preprocess

In [None]:
def preprocess():
    """Convert the raw JSON splits into the files the training code reads.

    Reads the three splits under ``./data/original``, masks the subject and
    object mentions, maps relation / POS / NER labels to ids, builds the
    vocabulary from the training split only, and writes the embedding
    matrix, the vocabulary and the converted splits under ``./data``.
    """
    def counter(sequences, threshold):
        # Tokens seen at least `threshold` times, most frequent first.
        # Ties keep first-seen order because the underlying sort is stable.
        frequencies = Counter(tok for seq in sequences for tok in seq)
        return [tok for tok, cnt in frequencies.most_common() if cnt >= threshold]

    def entity_masks(_samp):
        # Lower-case all tokens, then overwrite the subject span and the
        # object span with their typed placeholder tokens.
        lowered = [tok.lower() for tok in _samp['token']]
        s_lo, s_hi = _samp['subj_start'], _samp['subj_end']
        o_lo, o_hi = _samp['obj_start'], _samp['obj_end']
        subj_mask = 'SUBJ-' + _samp['subj_type']
        lowered[s_lo: s_hi + 1] = [subj_mask] * (s_hi - s_lo + 1)
        obj_mask = 'OBJ-' + _samp['obj_type']
        lowered[o_lo: o_hi + 1] = [obj_mask] * (o_hi - o_lo + 1)
        return lowered

    def convert(sample):
        # One raw sample -> the id-encoded record stored on disk.
        return {'token': entity_masks(sample),
                'relation': rel2id[sample['relation']],
                'subj_start': sample['subj_start'],
                'subj_end': sample['subj_end'],
                'obj_start': sample['obj_start'],
                'obj_end': sample['obj_end'],
                'pos': [pos2id[tag] for tag in sample['stanford_pos']],
                'ner': [ner2id[tag] for tag in sample['stanford_ner']]}

    source_paths = ('./data/original/train.json',
                    './data/original/dev.json',
                    './data/original/test.json')
    target_paths = ('./data/train.json',
                    './data/dev.json',
                    './data/test.json')

    data = []
    for src in source_paths:
        with open(src, 'r', encoding='utf8') as f:
            raw_samples = json.load(f)
        data.append([convert(sample) for sample in raw_samples])

    # Vocabulary: the two reserved entries first, then training tokens by frequency.
    word2id = {'pad': 0, 'unk': 1}
    for word in counter([samp['token'] for samp in data[0]], threshold=1):
        # len(word2id) is the next free id; existing entries are left alone.
        word2id.setdefault(word, len(word2id))

    weight_matrix = load_glove_emb(glove_path='./data/glove.6B.300d.txt',
                                   embed_dim=300,
                                   word2id=word2id)
    np.save('./data/weight_matrix.npy', weight_matrix)

    with open('./data/vocab.json', 'w') as f:
        json.dump(word2id, f)

    for converted, dst in zip(data, target_paths):
        with open(dst, 'w') as f:
            json.dump(converted, f)


def load_glove_emb(glove_path, embed_dim, word2id):
    """Build an embedding matrix for ``word2id`` from a GloVe text file.

    Every row starts from a scaled random-normal initialisation; rows whose
    word appears in the GloVe file are overwritten with the stored vector.

    Args:
        glove_path: path to a text file with one ``word v1 ... vN`` entry per line.
        embed_dim: number of values per vector (N).
        word2id: mapping from word to row index in the returned matrix.

    Returns:
        A ``(len(word2id), embed_dim)`` numpy array.

    Raises:
        ValueError: if an in-vocabulary line does not hold ``embed_dim``
            parseable values.
    """
    pre_trained_weights = np.random.randn(len(word2id), embed_dim).astype(
        np.float32) * np.sqrt(2.0 / len(word2id))
    with open(glove_path, 'r', encoding="utf8") as f:
        for line in f:
            # Split from the right so a token that itself contains spaces
            # stays in one piece; everything after it is the vector.
            word, *values = line.rstrip().rsplit(' ', embed_dim)
            if word not in word2id:
                # Skip the float parsing for words we will never look up.
                continue
            pre_trained_weights[word2id[word]] = np.array(values).astype(np.float32)

    return pre_trained_weights

# preprocess()


## TacRedDataset

In [None]:
class TacRedDataset(Data.Dataset):
    """In-memory dataset of pre-processed relation-extraction samples.

    Each sample dict is converted once, at construction time, into tensors
    placed on ``device`` plus plain integers for the label and span indices.
    """

    def __init__(self, data, word2id):
        """
        data: iterable of dicts with the keys 'token', 'relation',
              'subj_start', 'subj_end', 'obj_start', 'obj_end', 'pos', 'ner'
        word2id: token -> id mapping; unknown tokens map to word2id['unk']
        """
        super().__init__()

        self.sent_tensors_list = []
        self.rel_list = []
        self.subj_start_list = []
        self.subj_end_list = []
        self.obj_start_list = []
        self.obj_end_list = []
        self.pos_tensors_list = []
        self.ner_tensors_list = []

        for sample in data:
            # Look up each token; 'unk' is only consulted for missing tokens.
            token_ids = [word2id[tok] if tok in word2id else word2id['unk']
                         for tok in sample['token']]
            self.sent_tensors_list.append(
                torch.tensor(token_ids, dtype=torch.long).to(device))

            self.rel_list.append(int(sample['relation']))

            self.subj_start_list.append(int(sample['subj_start']))
            self.subj_end_list.append(int(sample['subj_end']))
            self.obj_start_list.append(int(sample['obj_start']))
            self.obj_end_list.append(int(sample['obj_end']))

            pos_ids = torch.tensor(sample['pos'], dtype=torch.long)
            self.pos_tensors_list.append(pos_ids.to(device))

            ner_ids = torch.tensor(sample['ner'], dtype=torch.long)
            self.ner_tensors_list.append(ner_ids.to(device))

    def __len__(self):
        # One relation label is stored per sample.
        return len(self.rel_list)

    def __getitem__(self, index):
        # Tuple order is the order the collate function unpacks.
        return (self.sent_tensors_list[index],
                self.rel_list[index],
                self.subj_start_list[index],
                self.subj_end_list[index],
                self.obj_start_list[index],
                self.obj_end_list[index],
                self.pos_tensors_list[index],
                self.ner_tensors_list[index])


## DataLoader

In [None]:
def get_data_loader(dataset: Data.Dataset, batch_size: int, shuffle: bool):
    """Wrap ``dataset`` in a DataLoader that pads each batch to its longest sentence.

    Every batch is the tuple
    ``(sent, rel, subj_pos, obj_pos, pos, ner, tensor_len)``; all tensors are
    placed on ``device`` except ``tensor_len``, which stays on the CPU.
    """
    def collate_fn(batch):
        sents, rels, s_starts, s_ends, o_starts, o_ends, pos_tags, ner_tags = zip(*batch)
        lengths = [len(s) for s in sents]
        longest = max(lengths)

        def encode(starts, ends):
            # One row of relative offsets per sample, measured from the entity span.
            encoded = torch.zeros((len(batch), longest), dtype=torch.long).to(device)
            for row, (first, last) in enumerate(zip(starts, ends)):
                encoded[row] = get_positions(first, last, longest)
            return encoded

        # sentence tensor
        padded_sent = pad_sequence(sents, batch_first=True)
        # relation tensor
        rel_tensor = torch.tensor(rels, dtype=torch.long).to(device)
        # subject / object position encodings
        subj_pos = encode(s_starts, s_ends)
        obj_pos = encode(o_starts, o_ends)
        # part-of-speech and named-entity tensors
        padded_pos = pad_sequence(pos_tags, batch_first=True)
        padded_ner = pad_sequence(ner_tags, batch_first=True)
        # lengths must be on the CPU if provided as a tensor
        tensor_len = torch.tensor(lengths, dtype=torch.long)

        return padded_sent, rel_tensor, subj_pos, obj_pos, padded_pos, padded_ner, tensor_len

    return Data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           collate_fn=collate_fn,
                           shuffle=shuffle)


def get_positions(start_idx, end_idx, length):
    """ Relative offsets of each of `length` tokens to the span [start_idx, end_idx].

    Tokens before the span get negative offsets, tokens inside get 0 and
    tokens after get positive offsets; the result is a long tensor on `device`.
    """
    before = range(-start_idx, 0)
    inside = [0] * (end_idx - start_idx + 1)
    after = range(1, length - end_idx)
    return torch.tensor([*before, *inside, *after], dtype=torch.long).to(device)


# early stopping

In [None]:
class EarlyStopping(object):
    """Signal when a monitored metric has stopped improving.

    Args:
        mode: 'min' if lower metric values are better, 'max' if higher are.
        min_delta: margin a value must beat the best-so-far by to count as
            an improvement; a percentage of the best value when
            ``percentage`` is true, an absolute amount otherwise.
        patience: number of consecutive non-improving steps tolerated before
            ``step`` returns True. ``0`` disables early stopping entirely.
        percentage: interpret ``min_delta`` as a percentage of the best value.

    Raises:
        ValueError: if ``mode`` is neither 'min' nor 'max'.
    """

    def __init__(self, mode='min', min_delta=0.001, patience=5, percentage=False):
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best = None
        self.num_bad_epochs = 0
        self.is_better = None
        self._init_is_better(mode, min_delta, percentage)

        if patience == 0:
            # Disabled: everything counts as an improvement and step never stops.
            self.is_better = lambda a, b: True
            self.step = lambda a: False

    def step(self, metrics):
        """Record one metric value; return True when training should stop."""
        # Check NaN before anything else: a NaN stored as `best` would make
        # every later comparison False and hide real improvements.
        if math.isnan(metrics):
            return True
        if self.best is None:
            self.best = metrics
            return False
        if self.is_better(metrics, self.best):
            self.num_bad_epochs = 0
            self.best = metrics
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs >= self.patience

    def _init_is_better(self, mode, min_delta, percentage):
        """Install the comparison matching ``mode`` and ``percentage``."""
        if mode not in {'min', 'max'}:
            raise ValueError(f'illegal mode: {mode}')
        if not percentage:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - min_delta
            else:
                self.is_better = lambda a, best: a > best + min_delta
        else:
            if mode == 'min':
                self.is_better = lambda a, best: a < best - (best * min_delta / 100)
            else:
                self.is_better = lambda a, best: a > best + (best * min_delta / 100)


# utils

In [None]:
def save_model(experiment_time, model, optimizer):
    """Write a checkpoint holding the model and optimizer state dicts.

    The file is saved as ``{root_dir}/results/checkpoints/<experiment_time>.pth``,
    so saving again with the same ``experiment_time`` overwrites it.
    """
    checkpoints_dir = f'{root_dir}/results/checkpoints'
    mkdir(checkpoints_dir)
    checkpoint_path = checkpoints_dir + '/' + experiment_time + '.pth'
    checkpoint = {
        'model': model.state_dict(),
        # state_dict must be *called*: storing the bound method would pickle
        # the optimizer object and break optimizer.load_state_dict on resume.
        'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, checkpoint_path)


def load_model(latest, file_name=None):
    """Load a checkpoint from ``{root_dir}/results/checkpoints``.

    Args:
        latest: when true, load the most recently modified file in the
            checkpoint directory and ignore ``file_name``.
        file_name: checkpoint file name, required when ``latest`` is false.

    Returns:
        ``(checkpoint, file_name)`` where ``checkpoint`` is the object
        stored by ``torch.save``.

    Raises:
        FileNotFoundError: ``latest`` is true but the directory holds no files.
        ValueError: ``latest`` is false and ``file_name`` is None.
    """
    checkpoints_dir = f'{root_dir}/results/checkpoints'
    if latest:
        candidates = [fn for fn in os.listdir(checkpoints_dir)
                      if os.path.isfile(checkpoints_dir + '/' + fn)]
        if not candidates:
            raise FileNotFoundError(f'no checkpoint found in {checkpoints_dir}')
        newest = sorted(candidates, key=lambda fn: os.path.getmtime(
            checkpoints_dir + '/' + fn))[-1]
        # map_location lets a checkpoint written on a GPU load on a CPU-only host.
        checkpoint = torch.load(checkpoints_dir + '/' + newest, map_location=device)
        return checkpoint, str(newest)

    if file_name is None:
        raise ValueError('checkpoint_path cannot be empty!')
    checkpoint = torch.load(checkpoints_dir + '/' + file_name, map_location=device)
    return checkpoint, file_name


def weights_init(m: nn.Module):
    """Re-initialise the parameters of ``m`` (and of its sub-modules).

    * parameters whose name contains 'weight' and that have at least two
      dimensions get Xavier-normal initialisation;
    * one-dimensional 'weight' parameters (e.g. a LayerNorm scale) are left
      as they are, because Xavier initialisation is undefined for them and
      raises ValueError;
    * every other parameter (biases) is set to zero.

    NOTE: embedding matrices are named 'weight' too, so a pre-trained
    embedding is overwritten when this is applied to the module owning it.
    """
    for name, param in m.named_parameters():
        if 'weight' in name:
            if param.dim() >= 2:
                nn.init.xavier_normal_(param.data, gain=1.0)
        else:
            nn.init.constant_(param.data, 0)


def fix_seed(seed):
    """ Seed Python's, NumPy's and torch's (CPU and all CUDA devices) generators. """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for seeder in seeders:
        seeder(seed)


def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes to ``filename`` and to the console.

    Args:
        filename: log file path, opened in append mode. Missing parent
            directories are created.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: logger name; ``None`` selects the root logger.

    Calling this again for the same logger does not stack duplicate
    handlers, so re-running a notebook cell does not repeat every message.

    Raises:
        KeyError: if ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(" %(message)s")
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    log_dir = os.path.dirname(filename)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # FileHandler stores the absolute path, so compare against that.
    target = os.path.abspath(filename)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers)
    if not has_file_handler:
        fh = logging.FileHandler(filename, "a")
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    # Exact type match on purpose: FileHandler (and capture handlers added
    # by other tools) subclass StreamHandler and must not count as a console.
    has_console_handler = any(
        type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console_handler:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        logger.addHandler(sh)
    return logger


def record_time(start_time, end_time):
    """ Split the span between two timestamps (seconds) into (minutes, seconds). """
    span = end_time - start_time
    minutes = int(span / 60)
    # leftover seconds once the whole minutes are removed
    seconds = int(span - 60 * minutes)
    return minutes, seconds


def mkdir(dir_path):
    """ Create the folder, and any missing parent folders, if it does not exist.

    Nested paths such as 'results/checkpoints' work on a fresh checkout, and
    there is no race between checking for the folder and creating it.
    """
    os.makedirs(dir_path, exist_ok=True)


def json_load(paths):
    """ Parse each UTF-8 JSON file in `paths`; results keep the order of `paths`. """
    def read_one(path):
        with open(path, 'r', encoding='utf8') as f:
            return json.load(f)

    return [read_one(path) for path in paths]


## eval

In [None]:
def compute_score(key, prediction):
    """Score predicted relation ids against gold relation ids.

    'no_relation' is treated as the negative class: it is never counted as a
    guess or as a gold label, so only real relations contribute.

    Args:
        key: gold relation ids, one per sample.
        prediction: predicted relation ids, aligned with ``key``.

    Returns:
        ``(precision, recall, f1, verbose_scores)`` where the first three are
        micro-averaged percentages and ``verbose_scores`` maps each gold
        relation name to its own ``{'p', 'r', 'f1'}`` fractions.
    """
    def prf(correct, guessed, gold):
        # Precision defaults to 1.0 when nothing was guessed, recall to 0.0
        # when there is no gold label, f1 to 0.0 when both are zero.
        prec = float(correct) / float(guessed) if guessed > 0 else 1.0
        recall = float(correct) / float(gold) if gold > 0 else 0.0
        f1 = 2.0 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
        return prec, recall, f1

    negative = rel2id['no_relation']
    correct_by_relation = Counter()
    guessed_by_relation = Counter()
    gold_by_relation = Counter()

    # Tally guesses, gold labels and exact hits, ignoring the negative class.
    for idx, gold in enumerate(key):
        guess = prediction[idx]
        if guess != negative:
            guessed_by_relation[guess] += 1
        if gold != negative:
            gold_by_relation[gold] += 1
            if gold == guess:
                correct_by_relation[guess] += 1

    # verbose information: one entry per relation that occurs in the gold labels
    verbose_scores = {}
    for relation in sorted(gold_by_relation.keys()):
        prec, recall, f1 = prf(correct_by_relation[relation],
                               guessed_by_relation[relation],
                               gold_by_relation[relation])
        verbose_scores[id2rel[relation]] = {'p': prec, 'r': recall, 'f1': f1}

    # aggregate score
    prec_micro, recall_micro, f1_micro = prf(sum(correct_by_relation.values()),
                                             sum(guessed_by_relation.values()),
                                             sum(gold_by_relation.values()))

    return prec_micro * 100, recall_micro * 100, f1_micro * 100, verbose_scores


# model

In [None]:
class RelationModel(nn.Module):
    """LSTM relation classifier with an entity-selection gate and position-aware attention.

    Word, part-of-speech and NER embeddings are concatenated and encoded by a
    single-layer unidirectional LSTM. The LSTM outputs pass through a
    self-attention gate, are pooled by an attention layer that also sees the
    final hidden state and the relative-position embeddings of the subject
    and object, and the pooled vector is classified with a log-softmax layer.
    """

    def __init__(self, use_glove, word_size, word_embed_dim, pos_size, pos_embed_dim,
                 ner_size, ner_embed_dim, max_len, position_embed_dim,
                 hidden_dim, rel_size, attn_dim, dropout, ) -> None:
        """
        use_glove: load the word embedding from the saved GloVe weight matrix
        word_size / pos_size / ner_size: vocabulary sizes of the three inputs
        word_embed_dim / pos_embed_dim / ner_embed_dim: their embedding sizes
        max_len: largest absolute relative position that can be embedded
        position_embed_dim: size of one relative-position embedding
        hidden_dim: LSTM hidden size
        rel_size: number of relation classes
        attn_dim: size of the attention projection space
        dropout: dropout probability
        """
        super(RelationModel, self).__init__()
        self.embed_dim = word_embed_dim + pos_embed_dim + ner_embed_dim
        self.max_len = max_len

        if use_glove:
            self.word_embedding = self.load_glove()
        else:
            self.word_embedding = nn.Embedding(word_size, word_embed_dim, padding_idx=0)

        self.pos_embedding = nn.Embedding(pos_size, pos_embed_dim, padding_idx=0)

        self.ner_embedding = nn.Embedding(ner_size, ner_embed_dim, padding_idx=0)

        # Offsets in [-max_len, max_len] are shifted by max_len before lookup.
        self.position_embedding = nn.Embedding(max_len * 2 + 1, position_embed_dim)

        self.lstm = nn.LSTM(self.embed_dim, hidden_dim, batch_first=True, bidirectional=False)

        self.weight_reply = nn.Parameter(torch.randn(hidden_dim, hidden_dim), requires_grad=True)
        self.fc_gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.layernorm_gate = nn.LayerNorm(hidden_dim, elementwise_affine=True)

        self.fc_q = nn.Linear(hidden_dim, attn_dim)
        self.fc_h = nn.Linear(hidden_dim, attn_dim)
        self.fc_k = nn.Linear(position_embed_dim * 2, attn_dim)
        self.fc_v = nn.Linear(attn_dim, 1)

        self.fc_out = nn.Linear(hidden_dim, rel_size)

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, seq_input: Tensor, pos: Tensor, ner: Tensor, subj_pos: Tensor, obj_pos: Tensor, tensor_len: Tensor, ) -> Tensor:
        """
        # pos: part-of-speech, position: relative position
        seq_input, pos, ner, subj_pos, obj_pos: (batch, seq_len)
        tensor_len: (batch) true sentence lengths, on the CPU
        returns: (batch, rel_size) log-probabilities
        """
        word_embed = self.word_embedding(seq_input)
        pos_embed = self.pos_embedding(pos)
        ner_embed = self.ner_embedding(ner)

        encoder_embed = self.dropout(
            torch.cat((word_embed, pos_embed, ner_embed), dim=-1))

        packed = pack_padded_sequence(encoder_embed, lengths=tensor_len, batch_first=True, enforce_sorted=False)
        lstm_output, (h_n, _) = self.lstm(packed)
        lstm_output, _ = pad_packed_sequence(lstm_output, batch_first=True)
        # single layer, one direction: drop the leading (num_layers * dirs) axis
        h_n = self.dropout(h_n.squeeze(0))
        lstm_output = self.dropout(lstm_output)

        # NOTE: the shifted indices are valid only while every relative
        # offset lies within [-max_len, max_len].
        subj_pos_embed = self.position_embedding(subj_pos + self.max_len)
        obj_pos_embed = self.position_embedding(obj_pos + self.max_len)
        position_embed = torch.cat((subj_pos_embed, obj_pos_embed), dim=-1)

        entity_selection = self.entity_selection_gate(lstm_output, tensor_len)

        output = self.weighted_rep(entity_selection, tensor_len, h_n, position_embed)

        output = F.log_softmax(self.fc_out(output), dim=-1)

        return output

    def entity_selection_gate(self, reply_pre: Tensor, tensor_len: Tensor = None) -> Tensor:
        """
        Gate every hidden state by a self-attention summary of the sentence.

        reply_pre: (batch, seq_len, hidden_dim)
        tensor_len: (batch) true lengths; when given, padded positions are
            excluded as attention targets. With no padding in the batch the
            result is the same as without it.
        returns: (batch, seq_len, hidden_dim)
        """
        reply_post = reply_pre.permute(0, 2, 1)

        # bilinear scores between every pair of positions: (batch, seq_len, seq_len)
        alpha = torch.matmul(
            torch.matmul(reply_pre, self.weight_reply), reply_post)
        # mask padding: padded columns must receive zero attention weight
        if tensor_len is not None:
            pad_mask = self.build_mask(tensor_len).to(alpha.device)
            alpha = alpha.masked_fill(pad_mask.unsqueeze(1), -float('inf'))
        alpha = torch.softmax(alpha, dim=-1)
        # weighted average
        reply_c = torch.bmm(alpha, reply_pre)
        attn_gate = torch.sigmoid(self.layernorm_gate(
            self.fc_gate(torch.cat((reply_pre, reply_c), dim=2))))

        reply_cs = attn_gate * reply_pre

        return reply_cs

    def weighted_rep(self, hidden_states: Tensor, tensor_len: Tensor, h_n: Tensor, pe_features: Tensor) -> Tensor:
        """
        Pool the hidden states into one vector per sentence.

        hidden_states: (batch, seq_len, hidden_dim)
        tensor_len: (batch)
        h_n: (batch, hidden_dim)
        pe_features: (batch, seq_len, position_embed_dim * 2)
        returns: (batch, hidden_dim)
        """
        _q = self.fc_q(hidden_states)
        _h = self.fc_h(h_n).unsqueeze(1).expand_as(_q)
        _k = self.fc_k(pe_features)
        attn_score = self.fc_v(torch.tanh(sum([_q, _h, _k]))).squeeze(-1)

        # mask padding (out of place, so autograd tracks the operation)
        pad_mask = self.build_mask(tensor_len).to(attn_score.device)
        attn_score = attn_score.masked_fill(pad_mask, -float('inf'))
        attn_weights = torch.softmax(attn_score, dim=-1)
        # weighted average input vectors
        output = attn_weights.unsqueeze(dim=1).bmm(hidden_states).squeeze(dim=1)

        return output

    @staticmethod
    def load_glove(glove_path=f'{root_dir}/data/weight_matrix.npy'):
        """Build a trainable embedding layer from the saved weight matrix."""
        weight = torch.from_numpy(np.load(glove_path)).float()
        pre_trained_embedding = nn.Embedding.from_pretrained(weight)
        # from_pretrained freezes the weights by default; keep them trainable
        pre_trained_embedding.weight.requires_grad = True

        return pre_trained_embedding

    @staticmethod
    def build_mask(x: Tensor) -> Tensor:
        """
        x: (batch) sequence lengths
        returns: (batch, max(x)) bool tensor on `device`, True at padded positions
        """
        lengths = torch.as_tensor(x, dtype=torch.long)
        positions = torch.arange(int(lengths.max()), device=lengths.device)
        # a position is padding when its index is >= the sentence length
        mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        return mask.to(device)


# train

## Load DataLoader

In [None]:
fix_seed(args.random_seed)

# Timestamp in UTC+8, used to name this run's log file and checkpoint.
cur_time = datetime.now(timezone(timedelta(hours=8))).strftime('%Y_%m_%d_%H_%M')

# The file handler cannot create missing folders, so make sure it exists.
log_dir = f'{root_dir}/results/logs'
os.makedirs(log_dir, exist_ok=True)
logger = get_logger(log_dir + '/' + cur_time + '.log')

# Print arguments
for arg in vars(args):
    logger.info("{} = {}".format(arg, getattr(args, arg)))

start_time = time.time()

# Load DataLoader
train_data, dev_data, test_data, word2id = json_load([f'{root_dir}/data/debug/train.json',
                                                      f'{root_dir}/data/debug/dev.json',
                                                      f'{root_dir}/data/debug/test.json',
                                                      f'{root_dir}/data/vocab.json'])
train_data_loader = get_data_loader(dataset=TacRedDataset(data=train_data, word2id=word2id),
                                    batch_size=args.batch_size,
                                    shuffle=True)
# dev / test are evaluated one sample at a time, in file order
dev_data_loader = get_data_loader(dataset=TacRedDataset(data=dev_data, word2id=word2id),
                                  batch_size=1,
                                  shuffle=False)
test_data_loader = get_data_loader(dataset=TacRedDataset(data=test_data, word2id=word2id),
                                   batch_size=1,
                                   shuffle=False)
logger.info(f'data processing consumes: {(time.time() - start_time):.2f}s')

## Define Model

In [None]:
model = RelationModel(use_glove=args.use_glove,
                      word_size=len(word2id),
                      word_embed_dim=args.word_embed_dim,
                      pos_size=len(pos2id),
                      pos_embed_dim=args.pos_embed_dim,
                      ner_size=len(ner2id),
                      ner_embed_dim=args.ner_embed_dim,
                      max_len=max_len,
                      position_embed_dim=args.position_embed_dim,
                      hidden_dim=args.hidden_dim,
                      rel_size=len(rel2id),
                      attn_dim=args.attn_dim,
                      dropout=args.dropout).to(device)

# weights_init re-initialises every parameter whose name contains 'weight',
# which includes the word embedding. Keep a copy of the pre-trained vectors
# and put them back afterwards so --use_glove actually has an effect.
glove_weight = model.word_embedding.weight.detach().clone() if args.use_glove else None
model.apply(weights_init)
if glove_weight is not None:
    with torch.no_grad():
        model.word_embedding.weight.copy_(glove_weight)

# Optimizer (matched case-insensitively, as the help text spells them 'SGD/ Adam/ Adadelta')
optimizer_name = args.optimizer.lower()
if optimizer_name == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
elif optimizer_name == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif optimizer_name == 'adadelta':
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
else:
    raise ValueError('Choose Optimizer from SGD/ Adam/ Adadelta. !')

# The model emits log-probabilities, so NLLLoss gives the cross-entropy loss
criterion = nn.NLLLoss()

# Stop once dev F1 has not improved for 5 consecutive epochs
early_stop = EarlyStopping(mode='max', min_delta=0.0001, patience=5)

# Load Checkpoint
if args.resume:
    checkpoint, cp_name = load_model(latest=True)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger.info(f'load checkpoint: [{cp_name}]')

## Functions for train/ dev/ test

In [None]:
def train():
    """Run one pass over `train_data_loader`; return the mean loss per batch."""
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_data_loader):
        sent, rel, subj_pos, obj_pos, pos, ner, tensor_len = batch

        optimizer.zero_grad()
        log_probs = model(sent, pos, ner, subj_pos, obj_pos, tensor_len)
        batch_loss = criterion(log_probs, rel)
        batch_loss.backward()
        # clip the gradient norm before the parameter update
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        running_loss += batch_loss.item()
    return running_loss / len(train_data_loader)


def evaluate(eval_data_loader):
    """Score the model on every batch of `eval_data_loader`.

    Works for any batch size: predictions and gold labels are collected
    per sample before scoring.

    Returns:
        The `(precision, recall, f1, verbose_scores)` tuple from `compute_score`.
    """
    model.eval()
    gold_labels, pred_labels = [], []
    with torch.no_grad():
        for sent, rel, subj_pos, obj_pos, pos, ner, tensor_len in eval_data_loader:
            output = model(sent, pos, ner, subj_pos, obj_pos, tensor_len)
            # most likely relation for each sample in the batch
            rel_pred = torch.argmax(output, dim=1)
            pred_labels.extend(rel_pred.cpu().tolist())
            gold_labels.extend(rel.cpu().tolist())

    return compute_score(key=gold_labels, prediction=pred_labels)


## Train Model

In [None]:
best_f1 = 0
for epoch in range(1, int(args.max_epochs + 1)):
    start_time = time.time()

    train_loss = train()
    _, _, dev_f1, _ = evaluate(dev_data_loader)
    epoch_min, epoch_sec = record_time(start_time, time.time())

    logger.info(
        f'epoch: [{epoch:02}/{args.max_epochs}]  train_loss={train_loss:.3f}  '
        f'dev_f1={dev_f1:.2f}  duration: {epoch_min}m {epoch_sec}s')

    # Checkpoint only on a new best dev F1. The best score must be recorded,
    # otherwise a later, worse epoch overwrites the best checkpoint.
    if dev_f1 > best_f1:
        best_f1 = dev_f1
        save_model(experiment_time=cur_time, model=model, optimizer=optimizer)

    if early_stop.step(dev_f1):
        logger.info(f'early stop at [{epoch:02}/{args.max_epochs}]')
        break

# NOTE: the test set is scored with the weights from the last epoch run,
# not with the saved best checkpoint.
test_prec, test_recall, test_f1, verbose_info = evaluate(test_data_loader)

logger.info(
    f'precision: {test_prec:.2f}  recall: {test_recall:.2f}  f1: {test_f1:.2f}')
logger.info(verbose_info)


## End