In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from openai.gpt1 import OpenAIGPTModel
from utils import read_json, save_json, make_dot_dict
from en_tokenizer import CharTokenizer
from dataset import CustomDataset
from dataset import CustomDataset

In [2]:
class OpenAIGPTLMHeadModel(nn.Module):
    """Language-model head on top of ``OpenAIGPTModel``.

    The transformer body produces hidden states of width ``hp.d_model``; a
    bias-free linear layer maps them to ``hp.vocab_size`` logits.

    Args:
        hp: hyper-parameter object, read here for ``d_model`` and
            ``vocab_size`` and passed whole to ``OpenAIGPTModel``.
        pad_token_id: padding token id forwarded to ``OpenAIGPTModel``.
    """

    def __init__(self, hp, pad_token_id):
        super().__init__()
        # Registration order (transformer first, then lm_head) determines the
        # printed module tree and the state_dict key order.
        self.transformer = OpenAIGPTModel(hp, pad_token_id)
        # NOTE(review): GPT-1 commonly ties this weight to the token embedding;
        # here it is an independent parameter.
        self.lm_head = nn.Linear(hp.d_model, hp.vocab_size, bias=False)

    def load_pretrained_model(self, path):
        """Placeholder (TODO: load weights from ``path``); currently a no-op."""
        return None

    def save_weight(self):
        """Placeholder (TODO: persist the weights); currently a no-op."""
        return None

    def forward(self, x):
        """Run the transformer on ``x`` and project its output to vocab logits.

        NOTE(review): ``x`` is presumably ``[batch, seq_len]`` token ids and the
        result ``[batch, seq_len, vocab_size]``; confirm against OpenAIGPTModel.
        """
        hidden_states = self.transformer(x)
        return self.lm_head(hidden_states)
    

In [10]:
class CustomLoss(object):
    """Causal-LM cross-entropy averaged over non-padding target positions.

    The logits at position ``t`` are scored against the token at ``t + 1``.
    Positions whose target equals ``pad_token_id`` contribute neither to the
    loss sum nor to the normalising count.
    """

    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id

    def __call__(self, logits, target):
        """Return the mean next-token loss over non-padding targets.

        Args:
            logits: ``[batch, seq_len, vocab_size]`` scores.
            target: ``[batch, seq_len]`` token ids.

        Returns:
            Scalar tensor. If no shifted target is a real token (all padding,
            or ``seq_len < 2``) the loss is 0 instead of NaN.
        """
        vocab_size = logits.size(-1)
        shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
        shift_true = target[:, 1:].reshape(-1)

        # ignore_index gives padded positions a loss of exactly 0 and never
        # looks the pad id up in the vocab dimension.
        loss = F.cross_entropy(shift_logits, shift_true,
                               ignore_index=self.pad_token_id,
                               reduction='none')
        n_valid = (shift_true != self.pad_token_id).sum()
        # clamp avoids 0/0 = NaN on an all-padding batch.
        return loss.sum() / n_valid.clamp(min=1).to(loss.dtype)
    

In [11]:
class CustomAccuracy(object):
    """Next-token accuracy over non-padding target positions.

    The argmax prediction at position ``t`` is compared to the token at
    ``t + 1``. Targets equal to ``pad_token_id`` are excluded from both the
    hit count and the denominator.
    """

    def __init__(self, pad_token_id):
        self.pad_token_id = pad_token_id

    def __call__(self, logits, target):
        """Return the fraction of correctly predicted non-padding tokens.

        Args:
            logits: ``[batch, seq_len, vocab_size]`` scores.
            target: ``[batch, seq_len]`` token ids.

        Returns:
            Scalar float tensor; 0 (not NaN) when there are no non-padding
            shifted targets.
        """
        vocab_size = logits.size(-1)
        shift_pred = logits[:, :-1, :].reshape(-1, vocab_size).argmax(dim=1)
        shift_true = target[:, 1:].reshape(-1)

        valid = shift_true != self.pad_token_id
        n_correct = ((shift_pred == shift_true) & valid).sum()
        # clamp avoids 0/0 = NaN on an all-padding batch.
        return n_correct.float() / valid.sum().clamp(min=1).float()

In [12]:
def move_to(obj, device):
    """Recursively move every tensor in ``obj`` to ``device``.

    Supports tensors and arbitrarily nested dicts, lists and tuples
    (namedtuples keep their type). Containers are rebuilt rather than mutated.

    Raises:
        AssertionError: if ``obj`` (or a nested element) is of any other type.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {key: move_to(value, device) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to(value, device) for value in obj]
    if isinstance(obj, tuple):
        moved = [move_to(value, device) for value in obj]
        # namedtuples take positional fields; plain tuples take an iterable.
        return type(obj)(*moved) if hasattr(obj, "_fields") else tuple(moved)
    # AssertionError is kept (rather than TypeError) for existing callers.
    raise AssertionError(f"move_to: unsupported type {type(obj).__name__}")


class AverageMeter(object):
    """Weighted running average of a scalar metric.

    ``val`` is the most recent value; ``sum`` and ``count`` accumulate
    ``val * n`` and ``n``; ``avg`` is ``sum / count``.
    """
    # Class-level placeholders; reset() sets real per-instance values.
    val, avg, sum, count = [None] * 4

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean and its batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against n == 0 on a fresh meter (would be ZeroDivisionError).
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        template = "{} {:.3f} ({:.3f})"
        return template.format(self.name, self.val, self.avg)

In [17]:
# Seed every RNG used in this notebook before anything random happens.
# torch drives weight init and the model's Dropout layers; numpy drives batch
# shuffling in the training cell. This makes Restart & Run All reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Model hyper-parameters; vocab_size is taken from the tokenizer below.
hp = read_json('config.json')
hp = make_dot_dict(hp)

tokenizer = CharTokenizer(model_path='sentencepiece_models/cp-char.model')
dataset = CustomDataset(hp=hp, root='data/docs', tokenizer=tokenizer)

hp.vocab_size = tokenizer.vocab_size

# Optimisation settings (these names are read by the training cell).
_lr = 1e-4
_epochs = 2
_batch_size = 2

model = OpenAIGPTLMHeadModel(hp, tokenizer.pad_token_id)
loss_fn = CustomLoss(tokenizer.pad_token_id)
opt = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=_lr,
    betas=(0.9, 0.999),
    weight_decay=1e-4,
)

metrics = {'acc': CustomAccuracy(tokenizer.pad_token_id)}

In [None]:
print(model)
model.train()  # ensure the model's Dropout layers are active

for epoch in range(_epochs):
    for x in dataset:
        # NOTE(review): assumes each dataset item is a tensor of token-id rows
        # that supports indexing with a numpy index array; confirm.
        n_data = len(x)
        # Number of full batches; a trailing partial batch is dropped.
        n_step = n_data // _batch_size
        indices = np.random.permutation(n_data)

        # Iterate over batch numbers (not row offsets) so every full batch is
        # visited; stepping by _batch_size up to n_step skipped most rows.
        for step in range(n_step):
            c_idx = indices[step * _batch_size:(step + 1) * _batch_size]
            c_x = x[c_idx]

            lm_logits = model(c_x)
            loss = loss_fn(lm_logits, c_x)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Metrics are for logging only; don't build an autograd graph.
            with torch.no_grad():
                m = {key: metric(lm_logits, c_x) for key, metric in metrics.items()}

            print(f"loss : {loss.item():.3f} "
                  + " ".join([f"{key} : {value.item():.3f}" for key, value in m.items()])
                  , end='\r')
    # Keep the last progress line of each epoch instead of overwriting it.
    print()
        

OpenAIGPTLMHeadModel(
  (transformer): OpenAIGPTModel(
    (embed): Embedding(
      (words_embed): Embedding(47, 128)
      (positions_embed): Embedding(512, 128)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0): Block(
        (attn): Attention(
          (W_QKV): Linear(in_features=128, out_features=384, bias=True)
          (WO): Linear(in_features=128, out_features=128, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (proj_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (fc1): Linear(in_features=128, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (1): Block(
        (attn): Attention(
          (W_QKV): Linear(in_features

In [23]:
class Trainer(object):
    def __init__(self, model, loss_fn, optimizer, metrics):
        """Store the training components and create running-average meters.

        Args:
            model: the model to train.
            loss_fn: loss callable (presumably ``(logits, target) -> scalar``).
            optimizer: optimizer over the model's parameters.
            metrics: dict of metric name -> callable. A ``train_<name>`` and a
                ``dev_<name>`` meter is created for each key, after the
                ``train_loss`` / ``dev_loss`` meters.
        """
        self.model, self.loss_fn = model, loss_fn
        self.optimizer, self.metrics = optimizer, metrics

        def build_meters(prefix):
            # "<prefix>_loss" first, then one meter per metric in dict order.
            names = [f"{prefix}_loss"] + [f"{prefix}_{key}" for key in metrics.keys()]
            return {name: AverageMeter(name) for name in names}

        self.train_eval = build_meters("train")
        self.dev_eval = build_meters("dev")
    
    def fit(self, train_dataset, epochs, batch_size, dev_dataset=None, callbacks=None):
        """Train the model by dispatching to ``map_fn``.

        If a global ``xmp`` is defined (presumably
        ``torch_xla.distributed.xla_multiprocessing``; confirm), ``map_fn`` is
        launched through ``xmp.spawn``, which calls it as
        ``map_fn(index, *args)``. Otherwise ``map_fn`` runs once in this
        process with ``index=None``.
        """
        args = (train_dataset, epochs, batch_size, dev_dataset, callbacks)
        try:
            # Only the name lookup is guarded; errors raised by training
            # itself must not be swallowed.
            spawn = xmp.spawn
        except NameError:
            self.map_fn(None, *args)
        else:
            # Previously this branch silently did nothing when xmp existed.
            spawn(self.map_fn, args=args)
    
    def map_fn(self, index, train_dataset, epochs, batch_size, dev_dataset=None, callbacks):
        
        
    