In [None]:
# Install the third-party packages imported in the next cell (-q keeps pip output quiet).
!pip install -q datasets
!pip install -q tiktoken

In [None]:
from datasets import load_dataset
# Load the dataset by its identifier; `ds` is indexed by key (e.g. 'test') further down.
ds = load_dataset("asahi417/multi-domain-document-classification")

import torch
from torch.utils.data import Dataset , DataLoader

import tiktoken
# 'gpt2' encoding: `n_vocab` is read for the vocabulary size and `encode` turns text into token ids.
tokenizer = tiktoken.get_encoding('gpt2')

In [None]:
# Arguments
max_length = 95  # token ids kept per document: longer texts are truncated, shorter ones padded
num_outputs = 4  # width of the classifier's output layer; presumably the number of label classes - confirm
vocab_size = tokenizer.n_vocab  # vocabulary size reported by the tokenizer; sizes the token embedding
emd_dim = 16  # embedding width shared by the token and positional embeddings

In [None]:
# Show the loaded dataset object to inspect what it contains.
print(ds)

In [None]:
# Raw texts and their labels, both read from the 'test' entry of the dataset.
# NOTE(review): only ds['test'] is used as the data pool here - confirm this is intended.
X = ds['test']['text']
y = ds['test']['label']

In [None]:
# Peek at the first ten texts and their labels.
X[:10],y[:10]

In [None]:
from sklearn.model_selection import train_test_split
# Hold out 30% of the examples for evaluation; stratify on the labels and fix
# random_state so the partition is the same on every run.
X_train,X_test,y_train,y_test = train_test_split(
    X,y,test_size=0.3,random_state=1,stratify=y
)

In [None]:
# Number of examples in each partition.
len(X_train),len(X_test)

In [None]:
# Define Data Class
class dataset(Dataset):
    """Pre-tokenized, fixed-length text classification dataset.

    Each text is encoded once at construction time, truncated to
    ``max_length`` token ids and right-padded with ``pad_token_id`` so that
    every feature tensor has exactly ``max_length`` entries.

    Args:
        X: Indexable/iterable collection of raw text strings.
        y: Integer class labels, one per entry of ``X``.
        tokenizer: Object exposing ``encode(text)`` that returns a list of
            integer token ids.
        max_length: Number of token ids in every feature tensor.
        pad_token_id: Id appended to sequences shorter than ``max_length``.
            Defaults to 0, the value that used to be hard-coded.
            NOTE(review): id 0 is presumably an ordinary vocabulary token
            rather than a dedicated padding id for most tokenizers - confirm
            and pass a more suitable id if one exists.

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of items.
    """

    def __init__(self, X, y, tokenizer, max_length, pad_token_id=0):
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad_token_id = pad_token_id
        self.labels = torch.tensor(y)
        # Tokenize everything up front so __getitem__ is a plain lookup.
        self.features = [self._encode(text) for text in X]

    def _encode(self, text):
        """Return ``text`` as a 1-D long tensor of exactly ``max_length`` ids."""
        tokens = list(self.tokenizer.encode(text))[:self.max_length]
        # A non-positive multiplier yields an empty list, so full-length
        # sequences are left untouched.
        tokens += [self.pad_token_id] * (self.max_length - len(tokens))
        # Explicit dtype keeps the tensor integral even when it is empty.
        return torch.tensor(tokens, dtype=torch.long)

    def __getitem__(self, index):
        """Return the ``(token_ids, label)`` pair stored at ``index``."""
        return self.features[index], self.labels[index]

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)
# Build both partitions with the shared tokenizer and sequence length;
# tokenization happens inside the constructor.
train_ds = dataset(X_train,y_train,tokenizer,max_length)
test_ds = dataset(X_test,y_test,tokenizer,max_length)

In [None]:
# Data Loader
# Seed PyTorch's RNG first, since the loader below shuffles.
torch.manual_seed(123)
train_loader = DataLoader(
    dataset = train_ds,
    shuffle = True,  # new sample order every epoch
    batch_size = 100,
    num_workers = 0,  # load batches in the main process
    drop_last = True,  # skip a final batch that has fewer than batch_size samples
    )

In [None]:
class MultiHeadAttention(torch.nn.Module):
    """Masked multi-head self-attention with a final output projection.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of ``d_out // num_heads`` features each,
    and combined with scaled dot-product attention. An upper-triangular mask
    (diagonal offset 1) blocks every position from attending to later ones.

    Args:
        d_in: Feature size of the incoming tokens.
        d_out: Total feature size of the output (all heads together).
        context_length: Side length of the stored attention mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads; must divide ``d_out``.
        qkv_bias: Whether the query/key/value projections carry a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_head "

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Layers are created in this exact order so that seeded
        # initialisation produces the same weights.
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = torch.nn.Linear(d_out, d_out)
        self.dropout = torch.nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def _split_heads(self, projected):
        """Reshape (batch, tokens, d_out) into (batch, heads, tokens, head_dim)."""
        batch_size, seq_len, _ = projected.shape
        per_head = projected.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape (batch, tokens, d_in).

        Returns a tensor of shape (batch, tokens, d_out).
        """
        batch_size, seq_len, _ = x.shape

        q = self._split_heads(self.W_query(x))
        k = self._split_heads(self.W_key(x))
        v = self._split_heads(self.W_value(x))

        # Raw similarity of every query with every key, per head.
        scores = q @ k.transpose(2, 3)

        # Crop the stored mask to the current sequence length and hide
        # later positions before normalising.
        hidden = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(hidden, -torch.inf)

        scale = self.head_dim ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # Merge the heads back into a single feature axis.
        merged = (weights @ v).transpose(1, 2).contiguous()
        merged = merged.view(batch_size, seq_len, self.d_out)
        return self.out_proj(merged)

In [None]:
class NeuralNetwork(torch.nn.Module):
    """Text classifier: embeddings -> self-attention -> mean pooling -> MLP.

    Token ids are embedded and summed with learned positional embeddings,
    passed through one ``MultiHeadAttention`` layer, averaged over the token
    axis and fed to a small feed-forward head that produces the logits.

    Args:
        num_inputs: Number of rows in the positional embedding, i.e. the
            longest sequence the model accepts.
        num_outputs: Number of logits produced per example.
        vocab_size: Number of rows in the token embedding.
        emd_dim: Embedding width (also the attention input/output width).
        num_heads: Number of attention heads.
        dropout: Dropout probability handed to the attention layer.
        context_length: Mask size handed to the attention layer.
    """

    def __init__(self, num_inputs, num_outputs, vocab_size, emd_dim, num_heads, dropout, context_length):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.vocab_size = vocab_size
        self.emd_dim = emd_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.context_length = context_length

        # Embedding
        self.tok_emb = torch.nn.Embedding(self.vocab_size, self.emd_dim)
        self.pos_emb = torch.nn.Embedding(self.num_inputs, self.emd_dim)

        # Custom Multi-Head Attention Layer
        self.attention = MultiHeadAttention(
            d_in=self.emd_dim,
            d_out=self.emd_dim,
            context_length=self.context_length,
            dropout=self.dropout,
            num_heads=self.num_heads
        )

        self.layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(self.emd_dim, 30),
            torch.nn.ReLU(),

            # 2nd hidden layer
            torch.nn.Linear(30, 20),
            torch.nn.ReLU(),

            # Output layer
            torch.nn.Linear(20, self.num_outputs)
        )

    def forward(self, x):
        """Return logits of shape (batch, num_outputs) for token ids ``x``.

        Args:
            x: Integer tensor of shape (batch, tokens) with
                ``tokens <= num_inputs``.

        Raises:
            ValueError: If the sequence is longer than ``num_inputs``.
        """
        seq_len = x.shape[1]
        if seq_len > self.num_inputs:
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum of {self.num_inputs}"
            )

        tok_embeds = self.tok_emb(x)
        # Index positions by the actual sequence length so that batches
        # shorter than num_inputs broadcast correctly against tok_embeds.
        positions = torch.arange(seq_len, device=x.device)
        pos_embeds = self.pos_emb(positions)
        x = tok_embeds + pos_embeds

        # Apply self-attention mechanism
        x = self.attention(x)

        # Average over the token axis -> (batch, emd_dim)
        x = x.mean(dim=1)

        logits = self.layers(x)
        return logits


In [None]:
# Training
torch.manual_seed(123)  # fix the RNG before the model's weights are created
import torch.nn.functional as F

# Initialize model
model = NeuralNetwork(
    num_inputs=max_length,       # rows in the positional embedding
    num_outputs=num_outputs,     # logits produced per example
    vocab_size=vocab_size,       # rows in the token embedding
    emd_dim=emd_dim,             # embedding width
    num_heads=4,                 # must divide emd_dim (asserted in MultiHeadAttention)
    dropout=0.1,                 # dropout probability on the attention weights
    context_length=max_length    # side length of the attention mask
)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Training loop
# NOTE: `features` and `labels` stay bound at notebook scope after this loop finishes.
num_epochs = 200
for epoch in range(num_epochs):
    model.train()  # training mode, so dropout is active
    for batch, (features, labels) in enumerate(train_loader):
        # Forward pass
        logits = model(features)

        # Loss computation
        loss = F.cross_entropy(logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

        # Print loss and progress
        # Only every 10th epoch is logged, but in such an epoch every batch prints a line.
        if (epoch+1)%10==0:
            print(f"Epoch {epoch + 1:03d}/{num_epochs:03d}"
                  f" | Batch {batch + 1:03d}/{len(train_loader):03d}"
                  f" | Loss: {loss.item():.4f}")


In [None]:
# Prediction Accuracy
def compute_accuracy(model,test_loader):
    """Return the fraction of examples in ``test_loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode (and left in it) and run without
    gradient tracking. The predicted class is the arg-max over dimension 1 of
    the logits.

    Args:
        model: Callable module mapping a feature batch to logits of shape
            (batch, num_classes).
        test_loader: Iterable yielding ``(features, labels)`` batches.

    Returns:
        Accuracy as a Python float in [0, 1], or NaN if the loader yields no
        examples.
    """
    correct = 0
    total_examples = 0
    model.eval()
    with torch.no_grad():
        # Use the batch produced by *this* loader; do not rely on any
        # same-named variables that may exist at notebook scope.
        for features, labels in test_loader:
            logits = model(features)
            predictions = torch.argmax(logits, dim=1)
            correct += (predictions == labels).sum().item()
            total_examples += len(labels)
    if total_examples == 0:
        # Nothing was scored; avoid dividing by zero.
        return float("nan")
    return correct / total_examples

In [None]:
torch.manual_seed(123)
# Evaluation loader: fixed order, and drop_last is not set, so the final partial batch is kept.
test_loader = DataLoader(
    dataset = test_ds,
    shuffle = False,
    batch_size = 500,
    num_workers = 0,  # load batches in the main process
    )

In [None]:
# Accuracy on the held-out partition.
accuracy_test = compute_accuracy(model,test_loader)
accuracy_test

In [None]:
# Accuracy on the training data, for comparison with the held-out figure.
# train_loader was built with drop_last=True, so a trailing partial batch is not scored.
accuracy_train = compute_accuracy(model,train_loader)
accuracy_train