In [1]:
import torch
import torch.nn as nn
from transformers import ElectraModel, ElectraTokenizer

class LanguageModel(nn.Module):
    """Sentence encoder: KoELECTRA -> BiLSTM -> global max pooling -> linear projection.

    Raw text is tokenized inside ``forward``, so the module is called directly with a
    string or a list of strings and returns a ``(batch, output_size)`` tensor.

    Args:
        model_name: Hugging Face checkpoint used to load the ELECTRA backbone and its
            tokenizer (only used for whichever of ``electra`` / ``tokenizer`` is not given).
        lstm_hidden_size: Hidden size of each LSTM direction.
        output_size: Width of the final linear projection.
        max_length: Maximum number of tokens kept per input after truncation.
        electra: Optional pre-built backbone. It must expose ``config.hidden_size`` and,
            when called with ``input_ids`` / ``attention_mask``, return an object with a
            ``last_hidden_state`` attribute.
        tokenizer: Optional pre-built tokenizer with the Hugging Face call signature,
            returning ``input_ids`` and ``attention_mask`` tensors.
    """

    def __init__(
        self,
        model_name="monologg/koelectra-base-v3-discriminator",
        lstm_hidden_size=50,
        output_size=50,
        max_length=512,
        electra=None,
        tokenizer=None,
    ):
        super().__init__()

        # KoELECTRA backbone and matching tokenizer (downloaded unless injected).
        self.electra = electra if electra is not None else ElectraModel.from_pretrained(model_name)
        self.tokenizer = tokenizer if tokenizer is not None else ElectraTokenizer.from_pretrained(model_name)
        self.max_length = max_length

        # Bi-directional LSTM over ELECTRA's contextual token embeddings.
        self.lstm = nn.LSTM(
            input_size=self.electra.config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        # Global max pooling over the sequence axis, then drop the singleton dim.
        self.global_max_pooling = nn.AdaptiveMaxPool1d(1)
        self.flatten = nn.Flatten()

        # A bidirectional LSTM concatenates both directions, so the pooled feature
        # width is 2 * lstm_hidden_size, not lstm_hidden_size.
        self.dense = nn.Linear(2 * lstm_hidden_size, output_size)

    def forward(self, inputs):
        """Encode ``inputs`` (a string or a list of strings).

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Tokenize, then move the tensors to the device holding the module's weights.
        device = next(self.lstm.parameters()).device
        encoded = self.tokenizer(
            inputs, padding=True, truncation=True, max_length=self.max_length, return_tensors="pt"
        )
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device)

        # Contextual embeddings; the mask keeps ELECTRA from attending to padding.
        last_hidden_state = self.electra(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # Pack by true length so neither LSTM direction reads padded positions.
        lengths = attention_mask.sum(dim=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            last_hidden_state, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=last_hidden_state.size(1)
        )

        # Keep padded steps out of the max by filling them with the dtype's minimum.
        padding = attention_mask.unsqueeze(-1) == 0
        lstm_out = lstm_out.masked_fill(padding, torch.finfo(lstm_out.dtype).min)

        # (batch, seq, feat) -> (batch, feat, seq) for AdaptiveMaxPool1d, then -> (batch, feat).
        pooled = self.flatten(self.global_max_pooling(lstm_out.permute(0, 2, 1)))

        return self.dense(pooled)


In [None]:
# Demo: encode one sentence with the model defined above.
language_model = LanguageModel()
language_model.eval()  # disable dropout so repeated runs give identical outputs
input_string = "This is a sample input string."
with torch.no_grad():  # inference only; no autograd graph needed
    output_tensor = language_model(input_string)
output_tensor.shape