In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
from datasets import load_dataset



In [2]:
# SST-2 (binary sentence sentiment) from the GLUE benchmark.
# The "validation" split serves as this notebook's held-out evaluation set.
# NOTE(review): presumably because GLUE test-split labels are not public — confirm.
dataset = load_dataset("glue", "sst2")
train_dataset, test_dataset = dataset["train"], dataset["validation"]

Downloading readme:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Downloading data: 100%|████████████████████████████████████████████████████████████████| 3.11M/3.11M [00:02<00:00, 1.32MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████| 72.8k/72.8k [00:01<00:00, 50.6kB/s]
Downloading data: 100%|██████████████████████████████████████████████████████████████████| 148k/148k [00:01<00:00, 92.1kB/s]


Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [3]:
# WordPiece tokenizer matching the uncased BERT-base checkpoint (downloaded/cached on first use).
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [6]:
def preprocess_text(text, max_length=128):
    """Tokenize one text with the notebook-level BERT `tokenizer`.

    Adds special tokens, pads or truncates to exactly `max_length` tokens, and
    returns PyTorch tensors together with an attention mask.

    Args:
        text: the input string to encode.
        max_length: target sequence length. Defaults to 128, the previously
            hard-coded value, so existing calls behave the same.

    Returns:
        The tokenizer's encoding object (dict-like). It is expected to hold
        `input_ids` and `attention_mask` tensors, presumably shaped
        (1, max_length) for a single string. Confirm this against the
        tokenizer version in use.
    """
    # Calling the tokenizer directly is the recommended replacement for the
    # deprecated `encode_plus`; for a single string it produces the same encoding.
    return tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )


In [8]:
class TransformerEncoder(nn.Module):
    """Token embedding + sinusoidal positional encoding + a stack of encoder layers.

    Args:
        input_size: vocabulary size, i.e. the number of rows in the embedding table.
        num_layers: number of stacked TransformerEncoderLayer blocks.
        hidden_size: model width shared by the embedding and every layer.
        num_heads: attention heads per layer. nn.MultiheadAttention requires it
            to divide hidden_size.
        dropout: dropout probability passed to the positional encoding and layers.
    """

    def __init__(self, input_size, num_layers, hidden_size, num_heads, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.pos_encoding = PositionalEncoding(hidden_size, dropout)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(hidden_size, num_heads, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        """Encode a batch of token ids.

        Args:
            x: integer token ids laid out as (batch_size, seq_len).
            mask: optional attention mask, forwarded unchanged to every layer.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        x = self.embedding(x)  # (batch_size, seq_len, hidden_size)
        # PositionalEncoding indexes its table by dim 0 and broadcasts over dim 1,
        # i.e. it expects (seq_len, batch_size, hidden_size). Passing the batch-first
        # tensor directly would add position b's encoding to every token of sample b.
        # Transpose around the call so each token gets its own position's encoding.
        x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
        for layer in self.layers:
            x = layer(x, mask)
        return x


class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder block.

    The block applies self-attention and then a position-wise feed-forward
    network. Each sub-block is wrapped as `norm(x + dropout(sublayer(x)))`.

    Inputs and outputs are laid out as (batch_size, seq_len, hidden_size).
    Internally the tensor is permuted to (seq_len, batch_size, hidden_size),
    because nn.MultiheadAttention is constructed here without batch_first.
    """

    def __init__(self, hidden_size, num_heads, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout)
        # Feed-forward sub-block: expand to 4 * hidden_size, ReLU, project back.
        self.linear1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.linear2 = nn.Linear(4 * hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x, mask=None, key_padding_mask=None):
        """Apply the encoder block.

        Args:
            x: tensor of shape (batch_size, seq_len, hidden_size).
            mask: optional `attn_mask` for nn.MultiheadAttention. Per the PyTorch
                docs it is (seq_len, seq_len) or (batch_size * num_heads, seq_len,
                seq_len). A tokenizer `attention_mask` is NOT this kind of mask.
            key_padding_mask: optional (batch_size, seq_len) mask where True marks
                positions to ignore, e.g. `attention_mask == 0` from a tokenizer.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, hidden_size)
        attn_output, _ = self.self_attn(
            x, x, x, attn_mask=mask, key_padding_mask=key_padding_mask
        )
        x = self.norm1(x + self.dropout(attn_output))
        # Keep the FFN output separate so the residual adds it to the block input.
        # Overwriting x here would make the "residual" add the FFN output to itself.
        ff_output = self.linear2(self.dropout(nn.functional.relu(self.linear1(x))))
        x = self.norm2(x + self.dropout(ff_output))
        return x.permute(1, 0, 2)  # (batch_size, seq_len, hidden_size)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    Even feature columns hold sin(pos * freq) and odd columns hold
    cos(pos * freq). The table is precomputed for `max_len` positions and
    stored as a non-trainable buffer of shape (max_len, 1, d_model).

    Expected input layout: (seq_len, batch_size, d_model). The table is sliced
    along dim 0 by x.size(0) and broadcast across dim 1.

    Args:
        d_model: feature dimension (odd values are supported).
        dropout: dropout probability applied after the addition.
        max_len: number of positions precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so the
        # cosine half takes only the first d_model // 2 frequencies. For even
        # d_model this is all of them, the same as before.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings for the first x.size(0) positions, then apply dropout."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


In [9]:
# Seed torch's global RNG so the randomly initialised weights are reproducible
# under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

input_size = len(tokenizer)  # tokenizer vocabulary size -> number of embedding rows
hidden_size = 128            # model width; must be divisible by num_heads
num_layers = 3               # stacked encoder layers
num_heads = 4                # attention heads per layer
dropout = 0.1                # dropout probability used throughout the encoder

model = TransformerEncoder(input_size, num_layers, hidden_size, num_heads, dropout)


In [10]:
# Display the module hierarchy (layer types and sizes) as this cell's output.
model

TransformerEncoder(
  (embedding): Embedding(30522, 128)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (layers): ModuleList(
    (0-2): 3 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (linear1): Linear(in_features=128, out_features=512, bias=True)
      (linear2): Linear(in_features=512, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
  )
)