# Text Sentiment Model

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW
import numpy as np
import h5py
import os
from tqdm import tqdm

## Dataset Class

In [None]:
class CustomDataset(Dataset):
    """In-memory dataset of pre-tokenized texts and integer labels from an HDF5 file.

    Every top-level group of the file is treated as one example whose ``text``
    and ``label`` are read from the group's HDF5 attributes. All examples are
    tokenized once, eagerly, in the constructor.
    """

    def __init__(self, filename, tokenizer, max_len):
        """Read and tokenize every example stored in ``filename``.

        Args:
            filename: Path of the HDF5 file to read.
            tokenizer: Tokenizer exposing ``encode_plus`` (e.g. a Hugging Face
                ``BertTokenizer``).
            max_len: Length every sequence is truncated / padded to.
        """
        self.tokenizer = tokenizer
        self.data = []
        self.labels = []
        self.max_len = max_len

        with h5py.File(filename, 'r') as hdf:
            for group in hdf.keys():
                attrs = hdf[group].attrs
                text = attrs['text']
                # h5py can hand string attributes back as bytes (e.g. for
                # fixed-length strings); the tokenizer needs str.
                if isinstance(text, bytes):
                    text = text.decode('utf-8')
                label = int(attrs['label'])

                encoding = self.tokenizer.encode_plus(
                    text,
                    add_special_tokens=True,
                    max_length=self.max_len,
                    truncation=True,
                    padding='max_length',
                    return_attention_mask=True,
                    return_tensors='pt',
                )

                # return_tensors='pt' adds a leading batch dimension of size 1.
                # Drop only that dimension: a bare squeeze() would also
                # collapse the sequence dimension when max_len == 1.
                self.data.append({
                    'input_ids': encoding['input_ids'].squeeze(0),
                    'attention_mask': encoding['attention_mask'].squeeze(0),
                })
                self.labels.append(label)

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, item):
        """Return the encoded example at ``item`` plus its label as a long tensor."""
        return {
            **self.data[item],
            'labels': torch.tensor(self.labels[item], dtype=torch.long)
        }
    
def stratified_split(dataset, test_size=0.2, val_size=0.1, random_seed=None):
    """Split ``dataset`` into train / validation / test subsets, class by class.

    Each class is shuffled and cut separately, so every subset keeps roughly
    the class proportions of the full dataset. Per-class subset sizes are
    truncated to whole examples; the remainder goes to the training subset.

    Args:
        dataset: Object exposing a ``labels`` sequence aligned with its items.
        test_size: Fraction of each class assigned to the test subset.
        val_size: Fraction of each class assigned to the validation subset.
        random_seed: Seed for the private random generator. ``None`` seeds it
            from fresh entropy, giving a different split on every call.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` as ``Subset`` objects.
        Note that validation comes before test in the returned tuple.

    Raises:
        ValueError: If a fraction is negative or the two fractions sum to
            more than 1 (which would produce overlapping or empty slices).
    """
    # Small tolerance so float sums such as 0.9 + 0.1 are never rejected.
    if test_size < 0 or val_size < 0 or test_size + val_size > 1 + 1e-9:
        raise ValueError(
            "test_size and val_size must be non-negative and sum to at most 1, "
            f"got test_size={test_size}, val_size={val_size}"
        )

    # A private generator yields the same split as seeding the global one
    # would, without clobbering NumPy's global random state for other code.
    rng = np.random.RandomState(random_seed)

    labels = np.asarray(dataset.labels)
    unique_classes, class_counts = np.unique(labels, return_counts=True)

    train_indices, test_indices, val_indices = [], [], []
    for cls, count in zip(unique_classes, class_counts):
        n_test = int(count * test_size)
        n_val = int(count * val_size)

        shuffled = rng.permutation(np.where(labels == cls)[0])

        # Plain Python ints keep the Subset index lists free of NumPy scalars.
        test_indices.extend(shuffled[:n_test].tolist())
        val_indices.extend(shuffled[n_test:n_test + n_val].tolist())
        train_indices.extend(shuffled[n_test + n_val:].tolist())

    # The per-class loop leaves indices grouped by class; mix them up.
    rng.shuffle(train_indices)
    rng.shuffle(test_indices)
    rng.shuffle(val_indices)

    train_dataset = Subset(dataset, train_indices)
    test_dataset = Subset(dataset, test_indices)
    val_dataset = Subset(dataset, val_indices)

    return train_dataset, val_dataset, test_dataset


## Model Class

In [None]:
class SentimentClassifier(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained BERT sequence classifier."""

    def __init__(self, n_classes):
        """Load ``bert-base-uncased`` with ``num_labels`` set to ``n_classes``."""
        super().__init__()
        self.bert = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=n_classes,
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Forward the inputs to the wrapped model and return its output unchanged."""
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

## Train/Eval functions

In [None]:
def train(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called with ``input_ids``, ``attention_mask`` and
            ``labels`` whose output exposes a ``.loss`` attribute.
        dataloader: Sized iterable of dict batches holding tensors under
            those three keys.
        optimizer: Optimizer that updates the model's parameters.
        device: Device every tensor of a batch is moved to.

    Returns:
        The sum of the batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` contains no batches.
    """
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, total=len(dataloader), desc="Training"):
        batch = {k: v.to(device) for k, v in batch.items()}

        model.zero_grad()
        # The loss is taken from the model output (labels are passed in), so
        # no separate criterion is needed here.
        outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'])
        loss = outputs.loss

        total_loss += loss.item()
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    avg_loss = total_loss / len(dataloader)
    return avg_loss


def evaluate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Module called with ``input_ids``, ``attention_mask`` and
            ``labels`` whose output exposes ``.loss`` and ``.logits``.
        dataloader: Sized iterable of dict batches holding tensors under
            those three keys.
        device: Device every tensor of a batch is moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-example average loss and the
        fraction of examples whose arg-max logit equals the label.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no examples.
    """
    model.eval()
    total_loss, total_correct, total_examples = 0.0, 0, 0

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'])

            labels = batch['labels']
            batch_examples = labels.size(0)

            # Assumes outputs.loss is the mean over the batch (the default for
            # Hugging Face classification heads). Weighting it by the batch
            # size keeps a smaller final batch from being over-represented in
            # the average.
            total_loss += outputs.loss.item() * batch_examples

            preds = torch.argmax(outputs.logits, dim=1)
            total_correct += (preds == labels).sum().item()
            total_examples += batch_examples

    avg_loss = total_loss / total_examples
    accuracy = total_correct / total_examples
    return avg_loss, accuracy

## Loader/Saver

In [None]:
def save_checkpoint(model, optimizer, epoch, file_path="model_checkpoints"):
    """Write model and optimizer state to ``<file_path>/model_epoch_<epoch>.pth``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Value stored under ``'epoch'`` and used in the file name.
        file_path: Directory for the checkpoint; created if missing.
    """
    # exist_ok=True creates the directory without a separate existence check,
    # which could race with another process creating it in between.
    os.makedirs(file_path, exist_ok=True)

    checkpoint_path = os.path.join(file_path, f"model_epoch_{epoch}.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")

def load_checkpoint(model, optimizer, file_path, map_location="cpu"):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']``.
        file_path: Path of a file holding a dict with the keys ``'epoch'``,
            ``'model_state_dict'`` and ``'optimizer_state_dict'``.
        map_location: Passed to ``torch.load``. The ``"cpu"`` default lets a
            checkpoint written on a GPU be opened on a machine without one;
            ``load_state_dict`` then copies the values into the existing
            parameters on whatever device they live on.

    Returns:
        The value stored under ``'epoch'`` in the checkpoint.
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(file_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Model loaded from {file_path}, epoch {epoch}")
    return epoch

## Init

In [None]:
# Settings for this cell: data location, padded sequence length, label count.
filename = "../text_labels.h5"
max_len = 128
n_classes = 3

# Tokenizer for the same 'bert-base-uncased' checkpoint the classifier loads.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Reads the HDF5 file and tokenizes every example up front.
dataset = CustomDataset(filename, tokenizer, max_len)

model = SentimentClassifier(n_classes)

In [None]:
# Fraction of each class held out for testing and for validation; whatever
# remains is used for training.
test_size = 0.1
val_size = 0.1

# A fixed seed makes the split identical in every session. Without it, a model
# reloaded from a checkpoint after a kernel restart could be tested on examples
# it had already been trained on.
split_seed = 42

# Split the dataset
train_dataset, val_dataset, test_dataset = stratified_split(
    dataset,
    test_size=test_size,
    val_size=val_size,
    random_seed=split_seed,
)

In [None]:
batch_size = 64

# Only the training loader reshuffles its examples; validation and test keep
# the DataLoader default of no shuffling.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


## Execution

In [None]:
epochs = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-6)
# Total number of optimizer steps over all epochs (not used further in this cell).
total_steps = len(train_loader) * epochs

# Load checkpoint (ignore if no checkpoint exists)
checkpoint_file = "model_checkpoints/model_epoch_XX.pth"

if os.path.isfile(checkpoint_file):
    # The checkpoint stores the 0-based index of the epoch that had just
    # finished when it was written, so training resumes with the next one
    # instead of repeating (and overwriting) the saved epoch.
    starting_epoch = load_checkpoint(model, optimizer, checkpoint_file) + 1
else:
    starting_epoch = 0


# Training and validation loop
# The loss histories only cover the epochs run in this session.
train_losses = []
val_losses = []

for epoch in range(starting_epoch, epochs):
    print(f'Epoch {epoch+1}/{epochs}')
    print('-' * 10)

    train_loss = train(model, train_loader, optimizer, device)
    print(f'Training loss: {train_loss}')
    train_losses.append(train_loss)

    val_loss, val_accuracy = evaluate(model, val_loader, device)
    print(f'Validation loss: {val_loss}, Accuracy: {val_accuracy}')
    val_losses.append(val_loss)

    # One checkpoint per epoch, named after the 0-based epoch index.
    save_checkpoint(model, optimizer, epoch, file_path="model_checkpoints")



## Plotting

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

plt.figure(figsize=(10, 6))

# One point per recorded epoch, numbered from 1 on the x axis.
n = len(train_losses)
epoch_numbers = range(1, n + 1)
for loss_history, curve_label in (
    (train_losses, 'Training Loss'),
    (val_losses, 'Validation Loss'),
):
    plt.plot(epoch_numbers, loss_history, label=curve_label)

plt.title('Training and Validation Losses')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# Epoch numbers are whole numbers, so restrict the x ticks to integers.
x_axis = plt.gca().xaxis
x_axis.set_major_locator(MaxNLocator(integer=True))

plt.show()


## Testing

In [None]:
# Restore the checkpoint written for epoch index 0 and score it on the
# held-out test split.
checkpoint_file = "model_checkpoints/model_epoch_0.pth"
load_checkpoint(model, optimizer, checkpoint_file)

test_metrics = evaluate(model, test_loader, device)
test_loss, test_accuracy = test_metrics
print(f'Test loss: {test_loss}, Test Accuracy: {test_accuracy}')