## Import and Setup

In [86]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split

In [87]:
torch.manual_seed(42)
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

DATA_DIR = Path.cwd() / 'data'
MODEL_DIR = Path.cwd() / 'models'

## Define Helper Classes and Functions

In [112]:
class stats:
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.count = 0
        self.loss = 0
        self.correct = 0
        
    def update(self, loss, correct):
        self.count += 1
        self.loss += loss
        self.correct += correct
    
    def avg_loss(self):
        return self.loss / self.count
    
    def avg_acc(self):
        return self.correct / (self.count * self.batch_size)


class Config:
    def __init__(self, learning_rate, batch_size, vocab_size, embedding_dim, rnn_size, out_size, rnn_num_layers, bidirectional=False, batch_first=True, dropout=0.0, num_workers=0):
        self.lr = learning_rate
        self.dropout = dropout
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.rnn_size = rnn_size
        self.out_size = out_size
        self.rnn_num_layers = rnn_num_layers
        self.bidirectional = bidirectional
        self.batch_first = batch_first

## Setup Data Pipeline

In [89]:
class TextDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.x = self.df['text'].to_numpy()
        self.y = self.df['pro_nuclear'].to_numpy()
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        return self.x[index], self.y[index]


class MyCollate:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
    
    def __call__(self, batch):
        texts = [item[0] for item in batch]
        targets = [item[1] for item in batch]
        
        texts = self.tokenizer(
            texts,
            padding=True,
            # truncation=True,
            # max_length=128,
            return_tensors='pt',
        ).input_ids
        
        targets = torch.tensor(targets, dtype=torch.float32)
        
        return texts, targets


def get_loader(df, tokenizer, config=None, batch_size=8, num_workers=0, suffle=True, pin_memory=True):
    if config is not None: 
        batch_size = config.batch_size
        num_workers = config.num_workers
    
    dataset = TextDataset(df)
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=suffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=MyCollate(tokenizer=tokenizer),
    )
    
    return loader, dataset

## Create Model

In [88]:
class RNN(nn.Module):
    def __init__(self, config):
        super(RNN, self).__init__()
        
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
                
        self.lstm = nn.LSTM(
            config.embedding_dim,
            config.rnn_size,
            num_layers=config.rnn_num_layers,
            # bidirectional=config.bidirectional,
            batch_first=config.batch_first,
            dropout=config.dropout,
        )
        
        # if config.bidirectional:
        #     self.fc = nn.Linear(config.rnn_size * 2, config.out_size)
        # else:
        self.fc = nn.Linear(config.rnn_size, config.out_size)
        
        self.dropout = nn.Dropout(config.dropout)
        
    def forward(self, x):
        embedded = self.embedding(x)
        logits, hidden = self.lstm(embedded)
        logits = self.dropout(logits)
        logits = logits[:,-1,:]
        out = self.fc(logits)
        
        return out
        

## Create Trainer Class

In [None]:
class Trainer:
    def __init__(self, model, optimizer, criterion, train_loader, val_loader, batch_size):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.batch_size = batch_size
    
    def save_checkpoint(self, filename='checkpoint.pt', directory=MODEL_DIR):
        torch.save(self.model.state_dict(), directory / filename)
    
    def load_checkpoint(self, filename='checkpoint.pt', directory=MODEL_DIR):
        self.model.load_state_dict(torch.load(directory / filename))
    
    def train(self, num_epochs, patience=np.inf, leave_pbar=True):
        best_loss = np.inf
        best_epoch = 0
        best_model = None
        
        for epoch in range(num_epochs):
            tstats = stats(self.batch_size)
            vstats = stats(self.batch_size)
            
            ## TRAINING LOOP
            pbar_train = tqdm(self.train_loader, total=len(self.train_loader), leave=False)
            self.model.train()
            for idx, (text, targets) in enumerate(pbar_train):
                # Move data to device
                text = text.to(device)
                targets = targets.to(device)
                
                # Forward pass
                outputs = self.model(text)
                outputs = outputs.squeeze(1)
                loss = self.criterion(outputs, targets)
                
                # Backpropagation
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                # Update loss and accuracy
                predictions = torch.round(torch.sigmoid(outputs))
                tstats.update(loss.item(), (predictions == targets).sum().item())
                
                # Update progress bar
                desc = (
                    f"Epoch {epoch+1:02}/{num_epochs:02}"
                    f" - Train Loss: {tstats.avg_loss():.4f}"
                    f" - Train Acc: {tstats.avg_acc():.4f}"
                )
                pbar_train.set_description(desc=desc)
            
            ## VALIDATION LOOP  
            pbar_val = tqdm(self.val_loader, total=len(self.val_loader), leave=leave_pbar)
            self.model.eval()
            with torch.no_grad():
                for idx, (text, targets) in enumerate(pbar_val):
                    # Move data to device
                    text = text.to(device)
                    targets = targets.to(device)
                    
                    # Forward pass
                    outputs = self.model(text)
                    outputs = outputs.squeeze(1)
                    loss = self.criterion(outputs, targets)
                    
                    # Update loss and accuracy
                    predictions = torch.round(torch.sigmoid(outputs))
                    vstats.update(loss.item(), (predictions == targets).sum().item())
                    
                    # Update progress bar
                    desc = (
                        f"Epoch {epoch+1:02}/{num_epochs:02}"
                        f" - Train Loss: {tstats.avg_loss():.4f}"
                        f" - Train Acc: {tstats.avg_acc():.4f}"
                        f" - Val Loss: {vstats.avg_loss():.4f}"
                        f" - Val Acc: {vstats.avg_acc():.4f}"
                    )
                    pbar_val.set_description(desc=desc)
                          
            ## CHECKPOINTING AND EARLY STOPPING
            if vstats.avg_loss() < best_loss:
                best_loss = vstats.avg_loss()
                best_epoch = epoch
                best_model = self.model.state_dict()
                save_checkpoint(self.model, f"best_model_{best_epoch}.pt")
                
            if epoch - best_epoch > patience:
                print(f"Stopping early at epoch {epoch+1}")
                break
                
        return best_model
    

## Data Preprocessing

In [90]:
data_df = pd.read_csv(DATA_DIR / 'nuclear_energy_en' / 'processed_data.csv')
data_df["pro_nuclear"] = data_df["pro_nuclear"].astype(int)
# print(f"Removing {data_df['text'].isna().sum()} rows with NaN values")
data_df = data_df.dropna()
# Remove rows that begin with "##"
data_df = data_df[~data_df['text'].str.startswith('##')]

test_split = 0.2
val_split = 0.2

train_df, test_df = train_test_split(data_df, test_size=test_split, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=val_split/(1-test_split), random_state=42)

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=False)

# print(len(train_df), len(val_df), len(test_df))

## Hyperparameters and Train Model

In [None]:
config = Config(
    learning_rate=3e-3,
    dropout=0.65,
    batch_size=8,
    num_workers=0,
    vocab_size=tokenizer.vocab_size,
    embedding_dim=768,
    rnn_size=128,
    out_size=1,
    rnn_num_layers=2,
    bidirectional=False,
    batch_first=True,
)

In [137]:
train_loader, train_dataset = get_loader(train_df, tokenizer=tokenizer, config=config)
val_loader, val_dataset = get_loader(val_df, tokenizer=tokenizer, config=config)
test_loader, test_dataset = get_loader(test_df, tokenizer=tokenizer, config=config)

In [139]:
model = RNN(config).to(device)
# criterion = nn.CrossEntropyLoss(reduction="mean").to(device)
criterion = nn.BCEWithLogitsLoss(reduction="mean").to(device)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

In [None]:
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    batch_size=config.batch_size,
)

In [None]:
trainer.train(
    num_epochs=20,
    patience=10,
)

## Generate Predictions

In [150]:
#load model
prediction_model = RNN(config).to(device)
prediction_model.load_state_dict(torch.load(MODEL_DIR / 'checkpoint.pt'))
prediction_model.eval()

RNN(
  (embedding): Embedding(28996, 768)
  (lstm): LSTM(768, 128, num_layers=2, batch_first=True, dropout=0.65)
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (dropout): Dropout(p=0.65, inplace=False)
)

In [151]:
def predict(model, text):
    tokenized_text = tokenizer.encode(text, return_tensors="pt").to(device)
    
    out = model(tokenized_text)

    pred = torch.sigmoid(out)
    rounded_pred = torch.round(pred)

    return pred, rounded_pred

In [155]:
txt = "Nuclear energy is necessary to combat climate change."

pred, rounded_pred = predict(prediction_model, txt)

print(f"Prediction: {pred.item():.4f} - Pro Nuclear: {rounded_pred.item()}")

Prediction: 0.9334 - Pro Nuclear: 1.0


In [156]:
txt = "Nuclear waste is a huge problem."

pred, rounded_pred = predict(prediction_model, txt)

print(f"Prediction: {pred.item():.4f} - Pro Nuclear: {rounded_pred.item()}")

Prediction: 0.0823 - Pro Nuclear: 0.0
