# In [None]:
import pandas as pd
import numpy as np

import copy
from typing import Optional, Any, Union, Callable

import torch
from torch import nn
from torch import Tensor
from torch import LongTensor

from torch.nn.init import xavier_normal_

from torch.nn import MultiheadAttention, ModuleList, Dropout, Linear, LayerNorm, functional as F

import math

class GPT(nn.Module):
    """Decoder-only transformer with a language-model head and a classification head.

    Token ids are embedded, summed with sinusoidal positional encodings, passed
    through a stack of self-attention decoder layers, and the resulting hidden
    states are projected by two independent linear layers.

    Args:
        d_model: width of the decoder layers built when ``custom_decoder`` is None.
        nhead: number of attention heads in each built decoder layer.
        num_decoder_layers: number of stacked decoder layers.
        dim_feedforward: hidden width of each layer's feed-forward block.
        dropout: dropout probability (positional encoding and decoder layers).
        activation: feed-forward activation, either a callable or a name
            understood by ``_get_activation_fn``.
        custom_decoder: optional ready-made decoder used instead of the
            built-in stack; it is called as
            ``decoder(x, tgt_mask=..., tgt_key_padding_mask=...)``.
        layer_norm_eps: epsilon for every ``LayerNorm``.
        norm_first: apply layer norm before (True) or after (False) each sub-block.
        vocab_size: number of token ids; sizes the embedding table and the LM head.
        class_num: number of outputs of the classification head.
        emb_size: embedding width; sizes the embedding table, the positional
            encoding and the input of both heads.

    ``vocab_size``, ``class_num`` and ``emb_size`` were historically read from
    module-level variables of the same names.  That lookup is kept as a
    fallback when the argument is omitted; ``emb_size`` additionally falls
    back to ``d_model`` when no such variable exists.

    Raises:
        NameError: if ``vocab_size`` or ``class_num`` is neither passed nor
            defined at module level.
    """

    def __init__(self, d_model: int = 512, nhead: int = 8, num_decoder_layers: int = 6, dim_feedforward: int = 2048,
                 dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 custom_decoder: Optional[Any] = None, layer_norm_eps: float = 1e-5,
                 norm_first: bool = False, vocab_size: Optional[int] = None,
                 class_num: Optional[int] = None, emb_size: Optional[int] = None) -> None:

        super().__init__()

        vocab_size = self._resolve_size("vocab_size", vocab_size)
        class_num = self._resolve_size("class_num", class_num)
        # NOTE: the built-in decoder stack operates on d_model features, so
        # emb_size has to equal d_model unless a custom decoder is supplied.
        emb_size = self._resolve_size("emb_size", emb_size, fallback=d_model)

        self.tok_emb = TokenEmbedding(vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
                                                    activation, layer_norm_eps, norm_first)
            decoder_norm = LayerNorm(d_model, eps=layer_norm_eps)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self.linear_lm = nn.Linear(emb_size, vocab_size)
        self.linear_cls = nn.Linear(emb_size, class_num)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    @staticmethod
    def _resolve_size(name: str, value: Optional[int], fallback: Optional[int] = None) -> int:
        """Return ``value``, else the module-level variable ``name``, else ``fallback``."""
        if value is not None:
            return value
        module_scope = globals()
        if name in module_scope:
            return module_scope[name]
        if fallback is not None:
            return fallback
        raise NameError(
            f"name '{name}' is not defined; pass {name}=... to GPT() "
            f"or define it at module level"
        )

    def forward(self, tgt: Tensor, tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None) -> "tuple[Tensor, Tensor]":
        """Run the full model.

        Args:
            tgt: token ids; the positional encoding is indexed by ``tgt.size(0)``,
                i.e. the first dimension is treated as the sequence dimension.
            tgt_mask: optional attention mask forwarded to the decoder.
            tgt_key_padding_mask: optional padding mask forwarded to the decoder.

        Returns:
            ``(lm_output, cls_output)``: the decoder output projected by the
            language-model head and by the classification head respectively.
        """
        output = self.positional_encoding(self.tok_emb(tgt))

        output = self.decoder(output, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)

        output1 = self.linear_lm(output)
        output2 = self.linear_cls(output)

        return output1, output2

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        # Every parameter with more than one dimension (weight matrices,
        # including the embedding table) is re-drawn with Xavier-normal init.
        for p in self.parameters():
            if p.dim() > 1:
                xavier_normal_(p)

    def decode(self, tgt: Tensor, tgt_mask: Tensor):
        """Return the decoder output for ``tgt`` without applying either head."""
        return self.decoder(self.positional_encoding(self.tok_emb(tgt)), tgt_mask)
                
                
class TransformerDecoder(nn.Module):
    """Sequential stack of decoder layers with an optional final normalisation.

    Args:
        decoder_layer: prototype layer; ``num_layers`` copies are created with
            ``_get_clones``.
        num_layers: number of copies in the stack.
        norm: optional callable applied to the output of the last layer.
    """

    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Feed ``tgt`` through every layer in order, then through ``norm`` if set.

        Both masks are passed unchanged, by keyword, to each layer.
        """
        hidden = tgt
        for layer in self.layers:
            hidden = layer(
                hidden,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
            )

        if self.norm is None:
            return hidden
        return self.norm(hidden)
    
    
class TransformerDecoderLayer(nn.Module):
    """Single decoder layer: self-attention followed by a feed-forward block.

    Each sub-block is wrapped in a residual connection.  With ``norm_first``
    the layer norm is applied to the sub-block input (pre-norm); otherwise it
    is applied to the residual sum (post-norm).

    Args:
        d_model: feature width of the layer input and output.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward block.
        dropout: dropout probability used inside attention and after each sub-block.
        activation: callable, or a name resolved through ``_get_activation_fn``.
        layer_norm_eps: epsilon of both layer norms.
        norm_first: select pre-norm (True) or post-norm (False).

    NOTE: ``multihead_attn`` is constructed and registered but is not used by
    ``forward`` in this class.
    """

    __constants__ = ['norm_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
                 layer_norm_eps=1e-5, norm_first=False) -> None:
        super().__init__()

        # Attention modules.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Position-wise feed-forward network.
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        # Normalisation, one per sub-block.
        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)

        # Dropout applied to each sub-block output before the residual sum.
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation) if isinstance(activation, str) else activation

    def __setstate__(self, state):
        # State saved without an 'activation' entry gets the ReLU default.
        state.setdefault('activation', F.relu)
        super().__setstate__(state)

    def forward(self, tgt: Tensor, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Apply the self-attention and feed-forward sub-blocks to ``tgt``."""
        if not self.norm_first:
            # Post-norm: normalise after adding the residual.
            attended = self.norm1(tgt + self._sa_block(tgt, tgt_mask, tgt_key_padding_mask))
            return self.norm2(attended + self._ff_block(attended))

        # Pre-norm: normalise the sub-block input, leave the residual path untouched.
        attended = tgt + self._sa_block(self.norm1(tgt), tgt_mask, tgt_key_padding_mask)
        return attended + self._ff_block(self.norm2(attended))

    # self-attention block
    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        """Self-attention over ``x`` (query = key = value), followed by dropout."""
        attn_result = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        # The module returns (output, weights); only the output is needed.
        return self.dropout1(attn_result[0])

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """linear1 -> activation -> dropout -> linear2 -> dropout."""
        expanded = self.activation(self.linear1(x))
        projected = self.linear2(self.dropout(expanded))
        return self.dropout2(projected)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to token embeddings, then apply dropout.

    A table of shape ``(maxlen, 1, emb_size)`` is precomputed and stored as the
    buffer ``pos_embedding``: even feature columns hold sines and odd columns
    hold cosines of ``position * 10000 ** (-2i / emb_size)``.

    Args:
        emb_size: feature width of the embeddings (even or odd).
        dropout: dropout probability applied to the sum.
        maxlen: number of positions precomputed in the table.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For an odd emb_size there is one fewer odd column than even columns,
        # so the cosine table is trimmed to fit; for an even emb_size the
        # slice is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        # Insert a singleton middle dimension so the table broadcasts over it.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        # Buffer: follows the module across devices and is saved with it,
        # but is not a trainable parameter.
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + encodings)``.

        The first ``token_embedding.size(0)`` rows of the table are used, so
        the first dimension of the input is treated as the position index.
        """
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is scaled by ``sqrt(emb_size)``.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: width of each embedding vector.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        """Look up ``tokens`` (cast to int64 first) and scale the vectors."""
        indices = tokens.long()
        vectors = self.embedding(indices)
        scale = math.sqrt(self.emb_size)
        return vectors * scale

    
    
def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))        
    
    
def generate_square_subsequent_mask(sz, device=None):
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` when ``j > i``, so
    that position ``i`` cannot attend to later positions.

    Args:
        sz: sequence length (side of the square mask).
        device: device to create the mask on.  When omitted, the module-level
            ``DEVICE`` is used, which is the historical behaviour.

    Returns:
        A float32 tensor of shape ``(sz, sz)``.
    """
    if device is None:
        device = DEVICE
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32, device=device)
    # Keep -inf strictly above the diagonal; triu zeroes everything else.
    return torch.triu(blocked, diagonal=1)


def create_mask(tgt):
    """Build the causal mask and the padding mask for a batch of token ids.

    Args:
        tgt: tensor of token ids; its first dimension is used as the sequence
            length and its first two dimensions are swapped for the padding mask.

    Returns:
        ``(tgt_mask, tgt_padding_mask)``: the square mask from
        ``generate_square_subsequent_mask`` and a boolean tensor that is True
        wherever ``tgt`` equals ``PAD_IDX``, with dimensions 0 and 1 transposed.
    """
    seq_len = tgt.shape[0]
    causal_mask = generate_square_subsequent_mask(seq_len)

    is_padding = tgt == PAD_IDX
    padding_mask = is_padding.transpose(0, 1)

    return causal_mask, padding_mask

    
def greedy_decode(model, ys, max_len):
    """Extend ``ys`` one token at a time, always taking the highest-scoring token.

    Args:
        model: object exposing ``decode(tokens, mask)`` and ``linear_lm``.
        ys: prompt token ids.  Tokens are appended along dimension 0 as
            ``(1, 1)`` tensors and the prediction is read with ``.item()``,
            so a single sequence of shape ``(length, 1)`` is expected.
        max_len: at most ``max_len - 1`` tokens are appended.

    Returns:
        ``ys`` (moved to ``DEVICE``) with the generated tokens appended.
        Generation stops early once ``EOS_IDX`` has been appended.

    NOTE: the model's train/eval mode is left untouched; call ``model.eval()``
    beforehand if dropout should be disabled.
    """
    ys = ys.to(DEVICE)

    # Pure inference: skip autograd bookkeeping so no graph is built per step.
    with torch.no_grad():
        for _ in range(max_len - 1):
            # Boolean causal mask: True marks positions that may NOT be attended.
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(DEVICE)

            out = model.decode(ys, tgt_mask)

            # Score the vocabulary from the hidden state of the last position.
            out = out.transpose(0, 1)
            prob = model.linear_lm(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
            ys = torch.cat([ys, next_token], dim=0)

            if next_word == EOS_IDX:
                break

    return ys

