In [None]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from transformers import BertTokenizerFast, BertModel
import ast
from transformers import get_cosine_schedule_with_warmup
from torch.cuda import amp
import time

In [None]:
def change_input(tokenizer, text1, text2=None, text3=None, labels=None, max_length=512):
    '''
    Encode up to three text segments as one fixed-length BERT input, KG-BERT style.

    The token sequence is ``[CLS] text1 [SEP] text2 [SEP] text3 [SEP]``; a segment
    whose text is None is left out together with its [SEP]. The sequence is
    truncated to fit in ``max_length`` and then right-padded to exactly
    ``max_length``.

    Args:
        tokenizer: object exposing ``tokenize`` and ``convert_tokens_to_ids``
            (e.g. a Hugging Face BERT tokenizer). Its ``pad_token_id`` is used
            for padding ``input_ids`` when present, otherwise 0.
        text1: first segment (e.g. the head of a triple); token type 0.
        text2: optional second segment (e.g. the relation); token type 1.
        text3: optional third segment (e.g. the tail); token type 0.
        labels: returned unchanged under the "labels" key.
        max_length: total sequence length after padding.

    Returns:
        dict with "input_ids", "segment_ids" and "attention_mask" (lists of
        length ``max_length``) and "labels".

    Raises:
        ValueError: if ``max_length`` cannot even hold the special tokens.
    '''
    # Basic (word-piece) tokenization; ids are looked up after truncation.
    tokens_1 = tokenizer.tokenize(text1)
    tokens_2 = tokenizer.tokenize(text2) if text2 is not None else []
    tokens_3 = tokenizer.tokenize(text3) if text3 is not None else []

    # One [CLS] plus one [SEP] closing every segment that is present.
    num_special = 2 + (text2 is not None) + (text3 is not None)
    budget = max_length - num_special
    if budget < 0:
        raise ValueError("max_length=%d is too small for %d special tokens"
                         % (max_length, num_special))

    # As in KG-BERT, drop tokens one at a time from the end of the currently
    # longest segment. max() returns the first maximal list, so ties go to
    # text3, then text2, then text1. Whenever the total exceeds a non-negative
    # budget the longest list is non-empty, so pop() never fails.
    by_priority = (tokens_3, tokens_2, tokens_1)
    while sum(map(len, by_priority)) > budget:
        max(by_priority, key=len).pop()

    # Assemble tokens and token-type (segment) ids.
    final_token = ["[CLS]"] + tokens_1 + ["[SEP]"]
    segment_ids = [0] * len(final_token)
    if text2 is not None:
        final_token += tokens_2 + ["[SEP]"]
        segment_ids += [1] * (len(tokens_2) + 1)
    if text3 is not None:
        final_token += tokens_3 + ["[SEP]"]
        segment_ids += [0] * (len(tokens_3) + 1)

    input_ids = list(tokenizer.convert_tokens_to_ids(final_token))

    # Right-pad to max_length: padding is masked out and gets token type 0.
    pad_len = max_length - len(input_ids)
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0
    attention_mask = [1] * len(input_ids) + [0] * pad_len
    input_ids += [pad_id] * pad_len
    segment_ids += [0] * pad_len

    assert len(input_ids) == max_length
    assert len(attention_mask) == max_length
    assert len(segment_ids) == max_length
    return {"input_ids": input_ids,
            "segment_ids": segment_ids,
            "attention_mask": attention_mask,
            "labels": labels,
    }

In [None]:
class language_Dataset(torch.utils.data.Dataset):
    '''
    Map-style dataset turning rows of a triple DataFrame into BERT inputs.

    Each row must provide "head", "relation", "tail" and "labels" columns; the
    three text fields are encoded with ``change_input`` as
    ``[CLS] head [SEP] relation [SEP] tail [SEP]`` (padded to ``max_length``).
    '''
    def __init__(self, df, tokenizer_path='bert-base-uncased', max_length=512):
        '''
        Args:
            df: pandas DataFrame with "head", "relation", "tail" and "labels" columns.
            tokenizer_path: name or local path given to
                ``BertTokenizerFast.from_pretrained``.
            max_length: fixed sequence length of every encoded example.
        '''
        self.df = df
        self.max_length = max_length
        self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        '''
        Encode the row at position ``idx``.

        Returns:
            tuple of tensors ``(input_ids, segment_ids, attention_mask, labels)``;
            the first three have length ``max_length``.
        '''
        row = self.df.iloc[idx]  # positional lookup done once, not per column
        encoded = change_input(self.tokenizer, row["head"], row["relation"], row["tail"],
                               row["labels"], max_length=self.max_length)
        return (torch.tensor(encoded["input_ids"]),
                torch.tensor(encoded["segment_ids"]),
                torch.tensor(encoded["attention_mask"]),
                torch.tensor(encoded["labels"]))

In [None]:
class KGBERT(nn.Module):
    '''
    KG-BERT style classifier: a BERT encoder whose pooled output is fed to a
    single linear layer producing ``num_class`` logits.
    '''
    def __init__(self, num_class, path=None):
        '''
        Args:
            num_class: number of output logits.
            path: name or local path of a pretrained checkpoint for
                ``BertModel.from_pretrained``; 'bert-base-uncased' is used when None.
        '''
        super().__init__()
        self.model = BertModel.from_pretrained(path if path is not None else 'bert-base-uncased')
        # Size the head from the encoder's config rather than hard-coding 768,
        # so checkpoints with another hidden size (e.g. BERT-large) also work.
        self.ln1 = nn.Linear(self.model.config.hidden_size, num_class)

    def forward(self, input_ids, segment_ids, attention_mask):
        '''
        Score a batch of encoded inputs.

        Args:
            input_ids: token ids.
            segment_ids: token type ids (passed to BERT as ``token_type_ids``).
            attention_mask: 1 for real tokens, 0 for padding.

        Returns:
            Logits from ``self.ln1`` (last dimension ``num_class``).
        '''
        # autocast state is thread-local and nn.DataParallel runs each replica's
        # forward in its own thread, so autocast must be entered here as well.
        # It is only enabled when CUDA exists (CUDA autocast would otherwise
        # disable itself with a warning).
        with amp.autocast(enabled=torch.cuda.is_available()):
            pooled = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                token_type_ids=segment_ids)["pooler_output"]
            return self.ln1(pooled)

In [None]:
def try_gpu(i=0):
    """Pick CUDA device ``i`` if the machine has it, otherwise fall back to the CPU.

    Args:
        i: zero-based CUDA device index.

    Returns:
        ``torch.device('cuda:i')`` when more than ``i`` CUDA devices are visible,
        else ``torch.device('cpu')``.
    """
    has_device_i = torch.cuda.device_count() > i
    return torch.device(f'cuda:{i}') if has_device_i else torch.device('cpu')
def train_with_amp(net, train_set, criterion, optimizer, epochs, batch_size, scheduler,
                   gradient_accumulate_step, max_grad_norm, num_gpu, shuffle=False):
    """Train ``net`` with automatic mixed precision (AMP) and gradient accumulation.

    The model is wrapped in ``nn.DataParallel`` over up to ``num_gpu`` CUDA
    devices; without CUDA, AMP is disabled. Every ``gradient_accumulate_step``
    batches the accumulated gradients are unscaled, clipped to
    ``max_grad_norm`` and applied, and ``scheduler`` is stepped once. Every
    1000 batches (and on the last batch of each epoch) the loss is printed and
    a checkpoint is written to ``./checkpoint.params``; per-batch timings are
    appended to ``train_log``.

    Args:
        net: model called as ``net(input_ids, segment_ids, attention_mask)``.
        train_set: dataset yielding ``(input_ids, segment_ids, attention_mask, labels)``.
        criterion: loss applied to ``(output, labels.view(-1, 1).float())``.
        optimizer: optimizer over ``net``'s parameters.
        epochs: number of passes over ``train_set``.
        batch_size: DataLoader batch size.
        scheduler: LR scheduler, stepped once per optimizer step.
        gradient_accumulate_step: batches accumulated per optimizer step (>= 1).
        max_grad_norm: max total gradient norm (of the unscaled gradients).
        num_gpu: number of CUDA devices to try to use.
        shuffle: reshuffle ``train_set`` every epoch; the default False keeps
            the previous fixed-order behaviour.

    Returns:
        list of the (accumulation-scaled) loss values recorded at each logging step.

    Raises:
        ValueError: if ``gradient_accumulate_step`` < 1.
    """
    if gradient_accumulate_step < 1:
        raise ValueError("gradient_accumulate_step must be >= 1, got %r" % (gradient_accumulate_step,))
    net.train()

    ls = []  # loss history, also stored in every checkpoint
    # try_gpu falls back to the CPU when fewer than num_gpu GPUs exist; keep
    # only real GPUs so the device list is never a cpu/cuda mix.
    gpu_ids = [d for d in (try_gpu(i) for i in range(max(num_gpu, 1))) if d.type == "cuda"]
    device = gpu_ids[0] if gpu_ids else torch.device("cpu")
    device_ids = gpu_ids or [device]
    print("\ntrain on %s\n" % str(device_ids))
    enable_amp = device.type == "cuda"
    # The scaler multiplies the loss before backward so small fp16 gradients
    # do not underflow; with enabled=False it does nothing.
    scaler = amp.GradScaler(enabled=enable_amp)
    net = nn.DataParallel(net, device_ids=device_ids)
    # DataParallel requires the parameters to live on device_ids[0].
    net.to(device)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)

    gradient_norm = None  # total norm from the most recent optimizer step
    for epoch in range(epochs):
        for idx, batch in enumerate(train_iter):
            ini_time = time.time()
            input_ids, seg_ids, att_mask, labels = (t.to(device) for t in batch)
            # Mixed-precision forward pass.
            with amp.autocast(enabled=enable_amp):
                output = net(input_ids, seg_ids, att_mask)
            # Loss is computed outside autocast; .mean() is a no-op for an
            # already-reduced loss and reduces a per-element one.
            loss = criterion(output, labels.view(-1, 1).float()).mean()
            # Average over the micro-batches that form one optimizer step
            # (dividing by 1 leaves the value unchanged).
            loss = loss / gradient_accumulate_step
            scaler.scale(loss).backward()

            if (idx + 1) % gradient_accumulate_step == 0:
                # Gradients are still multiplied by the loss scale: unscale them
                # in place first so clipping compares the true norm.
                scaler.unscale_(optimizer)
                gradient_norm = nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
                # Skips optimizer.step() if the gradients contain inf/NaN.
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # Log and checkpoint every 1000 batches and at the end of each epoch.
            if idx % 1000 == 0 or idx == len(train_iter) - 1:
                loss_value = loss.item()
                norm_value = None if gradient_norm is None else float(gradient_norm)
                print("==============Epochs " + str(epoch) + " ======================")
                print("loss: " + str(loss_value) + "; grad_norm: " + str(norm_value))
                ls.append(loss_value)
                # state_dict of the DataParallel wrapper: parameter names
                # carry a "module." prefix.
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'param_groups': optimizer.state_dict()["param_groups"],
                    'loss': ls
                }, "./checkpoint.params")
            with open("train_log", "a") as f:
                f.write("Epoch %s, Batch %s: %.4f sec\n" % (epoch, idx, time.time() - ini_time))
    return ls

In [None]:
if __name__ == "__main__":
    # Run configuration, kept in one place so it is easy to tune.
    data_path = "../input/train-valid-test-dataset/train.csv"
    seed = 42
    num_class = 1
    batch_size = 2
    lr = 2e-6
    epochs = 3
    gradient_accumulate_step = 1
    max_grad_norm = 1000
    num_gpu = 1

    # Seed torch so the classifier-head initialisation is reproducible.
    torch.manual_seed(seed)

    train = pd.read_csv(data_path).drop("Unnamed: 0", axis=1)
    # index_where is parsed from its string form back into a Python literal.
    train["index_where"] = train["index_where"].apply(ast.literal_eval)
    train_set = language_Dataset(train)
    net = KGBERT(num_class)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr)

    # train_with_amp steps the scheduler once per optimizer step, i.e. once
    # every gradient_accumulate_step batches, for `epochs` epochs. The cosine
    # schedule must span all of those steps: sized to a single epoch, the
    # learning rate would climb back up after the first epoch.
    batches_per_epoch = len(torch.utils.data.DataLoader(train_set, batch_size=batch_size))
    total_steps = max(1, (batches_per_epoch // gradient_accumulate_step) * epochs)
    scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0,
                                                num_training_steps=total_steps, num_cycles=0.5)
    train_with_amp(net, train_set, criterion, optimizer, epochs, batch_size, scheduler,
                   gradient_accumulate_step, max_grad_norm, num_gpu)