In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from tqdm.notebook import tqdm
import numpy as np

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomAttention(nn.Module):
    """Self-attention layer with a selectable attention pattern and scoring function.

    The input is projected to queries, keys and values by three
    ``d_model -> d_model`` linear layers. ``attention_type`` then decides which
    positions attend to which:

    * ``'multihead'``    - passes the projected tensors to ``nn.MultiheadAttention``
      (which applies its own projections on top of them).
    * ``'global'``       - every position attends to every position.
    * ``'local'``        - position ``i`` attends to the window
      ``[i - window_size // 2, i + window_size // 2]`` clipped to the sequence.
    * ``'kernelized'``   - softmax over dot products of ``elu(.) + 1`` feature maps.
    * ``'group_query'``  - queries are split into ``n_heads`` groups along the
      *sequence* axis and every group attends to the full keys/values.
      NOTE(review): because softmax is applied per query row, this produces the
      same tensor as ``'global'``; it is not grouped-query attention in the
      shared-KV-heads sense.
    * ``'hierarchical'`` - the sequence is split into ``num_chunks`` contiguous
      chunks and attention is computed independently inside each chunk.

    ``alignment_fn`` (``'dot'``, ``'scaled_dot'`` or ``'additive'``) selects the
    score used by every type except ``'multihead'`` and ``'kernelized'``.

    Args:
        d_model: Feature size of the input and output.
        attention_type: One of the attention patterns listed above.
        alignment_fn: Name of the scoring function.
        n_heads: Head count for ``'multihead'``; group count for ``'group_query'``.
        window_size: Window width used by ``'local'``.
        num_chunks: Number of sequence chunks used by ``'hierarchical'``
            (defaults to 4, the previously hard-coded value).
    """

    def __init__(self, d_model, attention_type='multihead', alignment_fn='scaled_dot', n_heads=1, window_size=5, num_chunks=4):
        super().__init__()
        self.d_model = d_model
        self.attention_type = attention_type
        self.alignment_fn = alignment_fn
        self.n_heads = n_heads
        self.window_size = window_size
        self.num_chunks = num_chunks

        # Linear layers for query, key, value (used in most attention mechanisms)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        if attention_type == 'multihead':
            self.attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)

    def compute_alignment(self, query, key):
        """Return unnormalised attention scores between ``query`` and ``key``.

        Args:
            query: Tensor of shape ``(..., len_q, d_model)``.
            key: Tensor of shape ``(..., len_k, d_model)``.

        Returns:
            Tensor of shape ``(..., len_q, len_k)``.

        Raises:
            ValueError: If ``self.alignment_fn`` is not a known scoring function.
        """
        if self.alignment_fn == 'dot':
            return torch.matmul(query, key.transpose(-2, -1))
        elif self.alignment_fn == 'scaled_dot':
            return torch.matmul(query, key.transpose(-2, -1)) / (self.d_model ** 0.5)
        elif self.alignment_fn == 'additive':
            # Broadcast to (..., len_q, len_k, d_model), squash, then reduce the
            # feature axis with a plain sum (there is no learned scoring vector).
            score = torch.tanh(query.unsqueeze(-2) + key.unsqueeze(-3))
            return score.sum(dim=-1)
        else:
            raise ValueError(f"Unknown alignment function: {self.alignment_fn}")

    def _attend(self, query, key, value):
        """Score ``query`` against ``key``, softmax over keys, and mix ``value``."""
        weights = F.softmax(self.compute_alignment(query, key), dim=-1)
        return torch.matmul(weights, value)

    def forward(self, x):
        """Apply the configured self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``self.attention_type`` (or, for the types that use
                it, ``self.alignment_fn``) is not recognised.
        """
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        if self.attention_type == 'multihead':
            return self.attention(query, key, value)[0]

        elif self.attention_type == 'local':
            seq_len = x.size(1)
            half_window = self.window_size // 2
            outputs = []
            for i in range(seq_len):
                start = max(0, i - half_window)
                end = min(seq_len, i + half_window + 1)
                attended = self._attend(
                    query[:, i:i+1, :],
                    key[:, start:end, :],
                    value[:, start:end, :],
                )
                outputs.append(attended.squeeze(1))
            return torch.stack(outputs, dim=1)

        elif self.attention_type == 'global':
            return self._attend(query, key, value)

        elif self.attention_type == 'kernelized':
            # elu(.) + 1 keeps the feature maps strictly positive before scoring.
            query = F.elu(query) + 1
            key = F.elu(key) + 1
            alignment = torch.matmul(query, key.transpose(-2, -1))
            weights = F.softmax(alignment, dim=-1)
            return torch.matmul(weights, value)

        elif self.attention_type == 'group_query':
            # Groups are contiguous slices of the query *sequence*; each one
            # attends to the full key/value tensors.
            grouped_results = [
                self._attend(group, key, value)
                for group in query.chunk(self.n_heads, dim=1)
            ]
            return torch.cat(grouped_results, dim=1)

        elif self.attention_type == 'hierarchical':
            # The projections are position-wise, so slicing the tensors computed
            # above is equivalent to re-projecting each chunk of x and avoids
            # running the three linear layers a second time.
            chunk_outputs = [
                self._attend(q_chunk, k_chunk, v_chunk)
                for q_chunk, k_chunk, v_chunk in zip(
                    query.chunk(self.num_chunks, dim=1),
                    key.chunk(self.num_chunks, dim=1),
                    value.chunk(self.num_chunks, dim=1),
                )
            ]
            return torch.cat(chunk_outputs, dim=1)
        else:
            raise ValueError(f"Unknown attention type: {self.attention_type}")

class TransformerBlock(nn.Module):
    """Encoder block: self-attention followed by a two-layer feed-forward net.

    Each sub-layer's output goes through dropout, is added back to its input,
    and the sum is layer-normalised.
    """

    def __init__(self, d_model, n_heads, ff_hidden_dim, dropout, attention_type, alignment_fn):
        super().__init__()
        self.attention = CustomAttention(
            d_model,
            attention_type=attention_type,
            alignment_fn=alignment_fn,
            n_heads=n_heads,
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Position-wise MLP: expand to ff_hidden_dim, ReLU, project back to d_model.
        expand = nn.Linear(d_model, ff_hidden_dim)
        contract = nn.Linear(ff_hidden_dim, d_model)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the attention and feed-forward sub-layers on ``x`` and return the result."""
        attended = self.dropout(self.attention(x))
        hidden = self.norm1(x + attended)
        transformed = self.dropout(self.ff(hidden))
        return self.norm2(hidden + transformed)

class DocumentClassifier(nn.Module):
    """Token-level Transformer encoder with mean pooling and a linear output head.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / hidden size.
        n_heads: Forwarded to every ``TransformerBlock``.
        ff_hidden_dim: Hidden size of each block's feed-forward net.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        num_classes: Size of the output logit vector.
        max_seq_len: Longest sequence the learned positional table can cover.
        dropout: Dropout probability used in the blocks and before the head.
        attention_type: Forwarded to every ``TransformerBlock``.
        alignment_fn: Forwarded to every ``TransformerBlock``.
    """

    def __init__(self, vocab_size, d_model, n_heads, ff_hidden_dim, num_layers, num_classes, max_seq_len, dropout=0.1, attention_type='multihead', alignment_fn='scaled_dot'):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional table, initialised to zeros.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, ff_hidden_dim, dropout, attention_type, alignment_fn)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw logits.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        max_seq_len = self.pos_embedding.size(1)
        if seq_len > max_seq_len:
            # Without this guard the addition below fails with an opaque
            # broadcasting error (or silently broadcasts when max_seq_len == 1).
            raise ValueError(
                f"Input sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        x = self.embedding(x) + self.pos_embedding[:, :seq_len, :]
        for layer in self.layers:
            x = layer(x)
        # Average over the sequence axis to get one vector per document.
        x = x.mean(dim=1)
        x = self.dropout(x)
        return self.fc(x)



In [3]:


# Training Loop
def train(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids)``.
        train_loader: Sized iterable of dict batches with ``'input_ids'`` and
            ``'labels'`` tensors.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(train_loader)``.
    """
    model.train()
    total_loss = 0
    for batch in train_loader:
        # Only the tensors the model and loss actually consume are moved to the
        # device; the model is called with input_ids alone, so the batch's
        # attention mask is not transferred.
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Evaluation Loop
def evaluate(model, valid_loader, criterion, device):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Args:
        model: Module called as ``model(input_ids)``; its outputs are treated
            as per-label logits.
        valid_loader: Sized iterable of dict batches with ``'input_ids'`` and
            ``'labels'`` tensors.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(avg_valid_loss, metrics)`` where ``metrics`` is a dict with keys
        ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1_score'``.
        Precision, recall and F1 are micro-averaged over all labels.
    """
    model.eval()
    valid_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in valid_loader:
            # The model is called with input_ids only, so the batch's attention
            # mask is not transferred to the device.
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids)

            # Calculate loss
            loss = criterion(outputs, labels)
            valid_loss += loss.item()

            # Get predictions (threshold at 0.5 for multi-label classification)
            preds = torch.sigmoid(outputs) > 0.5
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    avg_valid_loss = valid_loss / len(valid_loader)

    # Concatenate predictions and labels
    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 keeps the score at 0 when nothing is predicted positive
    # (as happens in early epochs) without emitting the ill-defined-metric warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='micro', zero_division=0
    )

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1
    }

    return avg_valid_loss, metrics

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from nltk.corpus import reuters
from nltk import download
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import GPT2Tokenizer

# Fetch the NLTK resources requested by this notebook
download('reuters')
download('punkt')

# Read every Reuters document together with its list of category names
docs = reuters.fileids()
documents = list(map(reuters.raw, docs))
labels = list(map(reuters.categories, docs))

# Turn the category lists into a multi-hot indicator matrix
mlb = MultiLabelBinarizer()
labels_binarized = mlb.fit_transform(labels)

# Hold out 20% of the documents with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(
    documents,
    labels_binarized,
    test_size=0.2,
    random_state=42,
)

# GPT-2 tokenizer, with a '[PAD]' token registered as its padding token
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens(dict(pad_token='[PAD]'))

# Custom Dataset
class ReutersDataset(Dataset):
    """Map-style dataset that tokenizes one text per item and pairs it with its label.

    Each item is a dict with ``'input_ids'`` and ``'attention_mask'`` (the
    tokenizer's outputs with the leading axis squeezed away) and ``'labels'``
    (the label converted to a float32 tensor).
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text, target = self.texts[idx], self.labels[idx]

        # Truncate/pad to exactly max_length and request PyTorch tensors.
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        sample = {
            name: encoded[name].squeeze(0)
            for name in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.float32)
        return sample

# DataLoaders are created later, in the experiment cell, once batch_size is defined.


  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package reuters to
[nltk_data]     /home/gourishanker/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /home/gourishanker/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
import pandas

# Write the raw documents and their category lists to reuters.csv
pandas.DataFrame(data={"documents": documents, "labels": labels}).to_csv(path_or_buf="reuters.csv")

In [23]:
# Reload the CSV snapshot, using its first column as the index
df = pandas.read_csv(filepath_or_buffer="reuters.csv", index_col=0)

In [27]:
# Character count of each document. The vectorized string accessor yields NaN
# for a missing document instead of raising TypeError the way len() would.
df["len"] = df["documents"].str.len()

In [37]:
# Order the documents from shortest to longest
df = df.sort_values(by="len", ascending=True)

In [None]:
# Attention variants to sweep over (uncomment to include in the run)
attention_types = [
    #"kernelized",   # Kernelized Attention
    #"local",        # Local Attention
    #"global",       # Global Attention
    #"multihead",    # Multihead Attention
    "group_query",  # Group Query Attention

   # "hierarchical"  # Hierarchical Attention
]

# List of available alignment functions
alignment_functions = [
    "dot",          # Dot Product Alignment
   # "scaled_dot",   # Scaled Dot Product Alignment
    #"additive"      # Additive (Bahdanau) Alignment
]

# len(tokenizer) counts the base vocabulary plus every added special token
# (here the '[PAD]' token), so the embedding table always covers all ids.
vocab_size = len(tokenizer)
d_model = 128
n_heads = 8
ff_hidden_dim = 512
num_layers = 4
num_classes =len(mlb.classes_)
max_seq_len = 1024
batch_size = 32
# Fall back to the CPU when no CUDA device is present instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
criterion = nn.BCEWithLogitsLoss()


train_dataset = ReutersDataset(X_train, y_train, tokenizer)
test_dataset = ReutersDataset(X_test, y_test, tokenizer)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
valid_loader = DataLoader(test_dataset, batch_size, shuffle=False)

# Train one fresh model per (attention type, alignment function) pair
for attn_type in attention_types:
    for align_fn in alignment_functions:
        print(f"Testing with Attention: {attn_type}, Alignment: {align_fn}")
        model = DocumentClassifier(
            vocab_size,
            d_model,
            n_heads,
            ff_hidden_dim,
            num_layers,
            num_classes,
            max_seq_len,
            attention_type=attn_type,
            alignment_fn=align_fn
        )
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        # Best validation accuracy seen so far for this configuration
        best_accuracy = 0
        for epoch in range(10):
            train_loss = train(model, train_loader, criterion, optimizer, device)
            print(f"Epoch {epoch+1}, Training Loss: {train_loss:.4f}")

            # Evaluate on validation data
            avg_valid_loss, metrics = evaluate(model, valid_loader, criterion, device)
            print(f"Validation Loss: {avg_valid_loss:.4f}")
            print(f"Accuracy: {metrics['accuracy']:.4f}, Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, F1 Score: {metrics['f1_score']:.4f}")

            # Checkpoint only when validation accuracy strictly improves
            if best_accuracy < metrics['accuracy']:
                best_accuracy = metrics['accuracy']
                torch.save(model.state_dict(), f"best_model_{attn_type}_{align_fn}.pth")
                print("Model saved!")

Testing with Attention: group_query, Alignment: dot
Epoch 1, Training Loss: 0.2105


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Validation Loss: 0.0850
Accuracy: 0.0000, Precision: 0.0000, Recall: 0.0000, F1 Score: 0.0000
Epoch 2, Training Loss: 0.0682
Validation Loss: 0.0548
Accuracy: 0.2924, Precision: 0.9254, Recall: 0.2396, F1 Score: 0.3806
Model saved!
Epoch 3, Training Loss: 0.0513
Validation Loss: 0.0453
Accuracy: 0.2952, Precision: 0.9969, Recall: 0.2415, F1 Score: 0.3888
Model saved!
Epoch 4, Training Loss: 0.0435
Validation Loss: 0.0399
Accuracy: 0.5042, Precision: 0.9042, Recall: 0.4179, F1 Score: 0.5716
Model saved!
Epoch 5, Training Loss: 0.0396
Validation Loss: 0.0371
Accuracy: 0.5190, Precision: 0.9045, Recall: 0.4300, F1 Score: 0.5829
Model saved!
Epoch 6, Training Loss: 0.0370
Validation Loss: 0.0349
Accuracy: 0.5185, Precision: 0.9244, Recall: 0.4394, F1 Score: 0.5957
Epoch 7, Training Loss: 0.0347
Validation Loss: 0.0334
Accuracy: 0.5167, Precision: 0.9168, Recall: 0.4504, F1 Score: 0.6041
Epoch 8, Training Loss: 0.0324
Validation Loss: 0.0312
Accuracy: 0.5556, Precision: 0.9100, Recall: 0.50

: 

In [None]:
# Two random tensors with the same leading sizes (32, 128) but feature widths 128 and 16
x1 = torch.randn(32, 128, 128)
x2 = torch.randn(32, 128, 16)


In [32]:
# Split x1's last (feature) axis into 8 equal slices and show each slice's shape
for x in torch.chunk(x1, 8, dim=-1):
    print(x.shape)

torch.Size([32, 128, 16])
torch.Size([32, 128, 16])
torch.Size([32, 128, 16])
torch.Size([32, 128, 16])
torch.Size([32, 128, 16])
torch.Size([32, 128, 16])
torch.Size([32, 128, 16])
torch.Size([32, 128, 16])


In [26]:
# NOTE(review): with x1 of shape (32, 128, 128) and x2 of shape (32, 128, 16) as created
# earlier, x2.transpose(-2, -1) is (32, 16, 128). The batched matmul then has to contract
# x1's last axis (128) against an axis of size 16, which is the RuntimeError recorded for
# this cell. Presumably a 16-wide slice of x1 was intended on the left-hand side
# (e.g. `x1.chunk(8, dim=-1)[0]`) - confirm before relying on this cell.
x1 @ x2.transpose(-2, -1)

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [32, 128] but got: [32, 16].