In [89]:
import numpy as np
import torch
import torch.nn as nn
import math
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

# Parameters
BATCH_SIZE = 64
LEARNING_RATE = 1e-3
MAX_SEQ_LENGTH = 400

# Load IMDb dataset
dataset_imdb = load_dataset("stanfordnlp/imdb")

# Load tokenizer
tokenizer = torch.hub.load(
    "huggingface/pytorch-transformers", "tokenizer", "bert-base-uncased"
)

# Collate function for IMDb
def imdb_collate_fn(batch):
    sequences = []
    for item in batch:
        sequences.append(item["text"])

    # Tokenization and padding
    input_tensor = torch.LongTensor(
        tokenizer(sequences, padding=True, truncation=True, max_length=MAX_SEQ_LENGTH).input_ids
    )

    # Extract second-to-last token as the label
    last_tokens = []
    for seq in input_tensor:
        idx_last_token = (seq != tokenizer.pad_token_id).nonzero()[-2].item()
        last_tokens.append(seq[idx_last_token])

    labels_tensor = torch.LongTensor(last_tokens)
    return input_tensor, labels_tensor

# DataLoader
train_loader = DataLoader(
    dataset_imdb["train"], batch_size=BATCH_SIZE, shuffle=True, collate_fn=imdb_collate_fn
)
test_loader = DataLoader(
    dataset_imdb["test"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=imdb_collate_fn
)

# Positional Encoding
def get_positional_angles(pos, idx, d_model_size):
    angle_rates = 1 / np.power(10_000, (2 * (idx // 2)) / np.float32(d_model_size))
    return pos * angle_rates

def compute_positional_encoding(max_position, d_model_size):
    angle_rads = get_positional_angles(
        np.arange(max_position)[:, None], np.arange(d_model_size)[None, :], d_model_size
    )
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    return torch.FloatTensor(angle_rads[None, ...])

# MultiHeadAttention Layer
class MultiHeadedSelfAttention(nn.Module):
    def __init__(self, input_dim, d_model_size, num_heads):
        super().__init__()

        self.num_heads = num_heads
        self.d_k = d_model_size // num_heads

        self.query_linear = nn.Linear(input_dim, d_model_size)
        self.key_linear = nn.Linear(input_dim, d_model_size)
        self.value_linear = nn.Linear(input_dim, d_model_size)
        self.output_linear = nn.Linear(d_model_size, d_model_size)
        self.softmax_layer = nn.Softmax(dim=-1)

    def forward(self, x, attention_mask=None):
        batch_size, sequence_length, _ = x.size()

        # Compute Q, K, V
        Q = self.query_linear(x)
        K = self.key_linear(x)
        V = self.value_linear(x)

        # Reshape for multiple heads
        Q = Q.view(batch_size, sequence_length, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, sequence_length, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, sequence_length, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if attention_mask is not None:
            attention_mask = attention_mask.repeat(1, self.num_heads, 1, 1)
            scores = scores.masked_fill(attention_mask == 0, -1e9)

        attn_weights = self.softmax_layer(scores)
        output = torch.matmul(attn_weights, V)

        output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, -1)
        return self.output_linear(output)

# Transformer Layer
class TransformerBlock(nn.Module):
    def __init__(self, d_model_size, num_heads, dff_size, dropout_prob=0.1):
        super().__init__()

        self.attention = MultiHeadedSelfAttention(d_model_size, d_model_size, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model_size, dff_size),
            nn.ReLU(),
            nn.Linear(dff_size, d_model_size)
        )

        self.norm1 = nn.LayerNorm(d_model_size)
        self.norm2 = nn.LayerNorm(d_model_size)
        self.dropout1 = nn.Dropout(dropout_prob)
        self.dropout2 = nn.Dropout(dropout_prob)

    def forward(self, x, mask):
        attn_output = self.attention(x, mask)
        x = self.norm1(self.dropout1(attn_output) + x)

        ffn_output = self.ffn(x)
        x = self.norm2(self.dropout2(ffn_output) + x)
        return x

# Text Classification Model
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, dff_size, max_len):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model  # Here is the correct attribute
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dff_size = dff_size
        self.max_len = max_len

        self.embedding_layer = nn.Embedding(vocab_size, d_model)
        self.position_encoding = nn.Parameter(compute_positional_encoding(max_len, d_model))

        self.transformer_layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, dff_size) for _ in range(num_layers)]
        )

        self.final_linear = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        mask = (x != tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)
        seq_len = x.shape[1]

        # Use self.d_model instead of self.d_model_size
        x = self.embedding_layer(x) * math.sqrt(self.d_model)
        x += self.position_encoding[:, :seq_len]

        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, mask)

        x = x[:, -2]
        return self.final_linear(x)


# Initialize model and training components
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TextClassificationModel(
    vocab_size=len(tokenizer),
    d_model=64,
    num_layers=6,
    num_heads=4,
    dff_size=128,
    max_len=MAX_SEQ_LENGTH
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_function = nn.CrossEntropyLoss()

# Training loop with accuracy calculation
def compute_accuracy(model, dataloader):
    total, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)

            predictions = model(inputs)
            preds = torch.argmax(predictions, dim=-1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total

# Train and evaluate
for epoch in range(50):
    total_loss = 0
    model.train()

    for batch in train_loader:
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch + 1} | Loss: {total_loss:.4f}")

    model.eval()
    train_acc = compute_accuracy(model, train_loader)
    test_acc = compute_accuracy(model, test_loader)
    print(f"Train Accuracy: {train_acc:.3f} | Test Accuracy: {test_acc:.3f}")


Using cache found in /root/.cache/torch/hub/huggingface_pytorch-transformers_main


Epoch 1 | Loss: 1170.3133
Train Accuracy: 0.693 | Test Accuracy: 0.688
Epoch 2 | Loss: 671.6316
Train Accuracy: 0.732 | Test Accuracy: 0.718
Epoch 3 | Loss: 549.3916
Train Accuracy: 0.763 | Test Accuracy: 0.733
Epoch 4 | Loss: 465.8578
Train Accuracy: 0.808 | Test Accuracy: 0.753
Epoch 5 | Loss: 395.6816
Train Accuracy: 0.828 | Test Accuracy: 0.758
Epoch 6 | Loss: 340.9747
Train Accuracy: 0.836 | Test Accuracy: 0.768
Epoch 7 | Loss: 294.4058
Train Accuracy: 0.852 | Test Accuracy: 0.772
Epoch 8 | Loss: 259.9780
Train Accuracy: 0.880 | Test Accuracy: 0.787
Epoch 9 | Loss: 234.2719
Train Accuracy: 0.888 | Test Accuracy: 0.784
Epoch 10 | Loss: 206.1052
Train Accuracy: 0.901 | Test Accuracy: 0.788
Epoch 11 | Loss: 185.1544
Train Accuracy: 0.907 | Test Accuracy: 0.785
Epoch 12 | Loss: 157.9428
Train Accuracy: 0.930 | Test Accuracy: 0.787
Epoch 13 | Loss: 138.3546
Train Accuracy: 0.929 | Test Accuracy: 0.755
Epoch 14 | Loss: 117.4101
Train Accuracy: 0.954 | Test Accuracy: 0.774
Epoch 15 | Los