# In [8]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import json
import os

# In [2]:
class EncoderModel(nn.Module):
    """Transformer-encoder classifier with learned attention pooling.

    Pipeline: token embedding -> optional linear projection to ``d_model`` ->
    ``nn.TransformerEncoder`` stack -> attention pooling against a single
    learned query vector -> linear layer followed by a sigmoid, producing one
    independent score in ``[0, 1]`` per class.

    Args:
        embedding_size: Width of the token embedding table.
        d_model: Hidden width of the transformer encoder. When it differs
            from ``embedding_size`` a linear projection bridges the two.
        num_heads: Attention heads per encoder layer (must divide ``d_model``).
        num_layers: Number of stacked encoder layers.
        vocab_size: Number of rows in the token embedding table.
        num_classes: Number of output scores per example.
        seq_len: Stored and round-tripped through the config only; no layer
            reads it.
        pe: Stored and round-tripped through the config only; no layer reads
            it. NOTE(review): presumably a maximum position count, but this
            model applies no positional encoding -- confirm the intent.
        rate: Dropout probability used inside the encoder layers.
    """

    # File names written by save() and read back by load().  The weights file
    # keeps its historical ".h5" name for path compatibility, but it is a
    # ``torch.save`` state dict, not an HDF5 file.
    CONFIG_FILE = "config.json"
    WEIGHTS_FILE = "model_weights.h5"

    def __init__(
        self,
        embedding_size: int,
        d_model: int,
        num_heads: int,
        num_layers: int,
        vocab_size: int,
        num_classes: int,
        seq_len: int,
        pe: int = 1024,
        rate: float = 0.1,
    ):
        super().__init__()
        self.embedding_size = embedding_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.seq_len = seq_len
        self.pe = pe
        self.rate = rate

        # `pe` is deliberately NOT forwarded here: the third positional
        # argument of nn.Embedding is `padding_idx`, not a sequence length.
        self.embedding = nn.Embedding(vocab_size, embedding_size)

        # Project embeddings up/down to the encoder width only when needed.
        if embedding_size != d_model:
            self.embedding_intermediate = nn.Linear(embedding_size, d_model)
        else:
            self.embedding_intermediate = None

        encoder_layer = nn.TransformerEncoderLayer(
            d_model,
            num_heads,
            dim_feedforward=2048,
            dropout=rate,
            activation="gelu",
            layer_norm_eps=1e-05,
            batch_first=True,
        )
        # Nested-tensor conversion is disabled so the encoder output always
        # keeps the full (batch, seq, d_model) shape that pooling relies on.
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers, enable_nested_tensor=False
        )

        # Single learned query vector used to pool the sequence.
        self.pooling_embedding = nn.Embedding(1, d_model)
        # Sigmoid is applied in forward(); it is a function, not a layer arg.
        self.output_layer = nn.Linear(d_model, num_classes)

    def dot_attention(self, q, k, v, mask=None):
        """Unscaled dot-product attention.

        Args:
            q: Query tensor whose last dimension matches ``k``'s last
                dimension; leading dimensions broadcast against ``k``.
            k: Key tensor of shape ``(..., keys, dim)``.
            v: Value tensor of shape ``(..., keys, dim_v)``.
            mask: Optional ``(batch, keys)`` tensor; non-zero / ``True``
                marks positions that may be attended to.

        Returns:
            ``softmax(q @ k^T) @ v``.
        """
        logits = torch.matmul(q, k.transpose(-2, -1))

        if mask is not None:
            # Push masked-out keys to a very negative logit so softmax gives
            # them (numerically) zero weight.  Going through bool makes this
            # work for int, float and bool masks alike.
            blocked = ~mask.to(torch.bool)
            logits = logits + blocked[:, None, :].to(logits.dtype) * -1e9

        attention_weights = torch.softmax(logits, dim=-1)
        return torch.matmul(attention_weights, v)

    def forward(
        self,
        inputs=None,
        input_ids=None,
        attention_mask=None,
        training=False,
        **kwargs
    ):
        """Compute per-class sigmoid scores for a batch of token ids.

        Args:
            inputs: Either a mapping with key ``"input_ids"`` (and optionally
                ``"attention_mask"``) or the ``input_ids`` tensor itself.
            input_ids: ``(batch, seq)`` integer tensor; used when ``inputs``
                is not given.
            attention_mask: Optional ``(batch, seq)`` tensor, non-zero for
                real tokens and zero for padding.  A mask found in ``inputs``
                takes precedence.
            training: Accepted for signature compatibility and ignored;
                dropout is controlled by ``train()`` / ``eval()``.
            **kwargs: Ignored.

        Returns:
            ``(batch, num_classes)`` tensor of values in ``[0, 1]``.

        Raises:
            ValueError: If no ``input_ids`` were supplied.
        """
        if torch.is_tensor(inputs):
            input_ids = inputs
        elif inputs is not None:
            input_ids = inputs["input_ids"]
            attention_mask = inputs.get("attention_mask", attention_mask)
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        hidden = self.embedding(input_ids)
        if self.embedding_intermediate is not None:
            hidden = self.embedding_intermediate(hidden)

        # nn.TransformerEncoder expects True where a key must be IGNORED,
        # the opposite of the 1-means-keep attention_mask convention.
        padding_mask = None
        if attention_mask is not None:
            padding_mask = ~attention_mask.to(torch.bool)
        hidden = self.transformer_encoder(
            hidden, src_key_padding_mask=padding_mask
        )

        # The (1, d_model) pooling vector broadcasts over the batch, giving
        # a (batch, 1, d_model) result that is squeezed to (batch, d_model).
        pooled = self.dot_attention(
            self.pooling_embedding.weight, hidden, hidden, mask=attention_mask
        )
        pooled = pooled.squeeze(1)

        return torch.sigmoid(self.output_layer(pooled))

    def call(
        self,
        inputs,
        input_ids=None,
        attention_mask=None,
        training=False,
        **kwargs
    ):
        """Keras-style alias kept for existing callers; same as ``self(...)``."""
        return self(
            inputs,
            input_ids=input_ids,
            attention_mask=attention_mask,
            training=training,
            **kwargs
        )

    def get_config(self):
        """Return the constructor arguments as a JSON-serializable dict."""
        return {
            "embedding_size": self.embedding_size,
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "num_layers": self.num_layers,
            "vocab_size": self.vocab_size,
            "num_classes": self.num_classes,
            "seq_len": self.seq_len,
            "pe": self.pe,
            "rate": self.rate,
        }

    def save(self, save_dir):
        """Write the config and weights into ``save_dir``.

        The directory (including missing parents) is created if necessary.

        Returns:
            The directory listing of ``save_dir`` after writing.
        """
        os.makedirs(save_dir, exist_ok=True)

        with open(
            os.path.join(save_dir, self.CONFIG_FILE), "w", encoding="utf-8"
        ) as f:
            json.dump(self.get_config(), f)

        # Parameters exist as soon as the module is constructed, so no
        # warm-up forward pass is needed before serializing.
        torch.save(
            self.state_dict(), os.path.join(save_dir, self.WEIGHTS_FILE)
        )

        return os.listdir(save_dir)

    @classmethod
    def load(cls, save_dir):
        """Rebuild a model from a directory written by :meth:`save`.

        Weights are loaded onto the CPU; move the returned model with
        ``.to(device)`` as needed.  The model is returned in its default
        (training) mode.
        """
        with open(
            os.path.join(save_dir, cls.CONFIG_FILE), "r", encoding="utf-8"
        ) as f:
            config = json.load(f)

        model = cls(**config)
        # weights_only=True restricts unpickling to tensors/containers, so an
        # untrusted checkpoint cannot execute arbitrary code on load.
        state = torch.load(
            os.path.join(save_dir, cls.WEIGHTS_FILE),
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(state)

        return model
        