## Import and Setup

In [86]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset

from sklearn.model_selection import train_test_split

In [87]:
torch.manual_seed(42)
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

DATA_DIR = Path.cwd() / 'content'
MODEL_DIR = Path.cwd() / 'models'

## Create Model

In [88]:
class RNN(nn.Module):
    def __init__(self, config):
        super(RNN, self).__init__()
        
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
                
        self.lstm = nn.LSTM(
            config.embedding_dim,
            config.rnn_size,
            num_layers=config.rnn_num_layers,
            # bidirectional=config.bidirectional,
            batch_first=config.batch_first,
            dropout=config.dropout,
        )
        
        # if config.bidirectional:
        #     self.fc = nn.Linear(config.rnn_size * 2, config.out_size)
        # else:
        self.fc = nn.Linear(config.rnn_size, config.out_size)
        
        self.dropout = nn.Dropout(config.dropout)
        
    def forward(self, x):
        embedded = self.embedding(x)
        logits, hidden = self.lstm(embedded)
        logits = self.dropout(logits)
        logits = logits[:,-1,:]
        out = self.fc(logits)
        
        return out
        

In [112]:
def save_checkpoint(model_to_save, filename='checkpoint.pt'):
    torch.save(model_to_save.state_dict(), MODEL_DIR / filename)

## Create DataLoader

In [89]:
class TextDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.text = self.df['text'].to_numpy()
        self.targets = self.df['pro_nuclear'].to_numpy()
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        return self.text[index], self.targets[index]
    
class MyCollate:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
    
    def __call__(self, batch):
        texts = [item[0] for item in batch]
        targets = [item[1] for item in batch]
        
        texts = self.tokenizer(
            texts,
            padding=True,
            # truncation=True,
            # max_length=128,
            return_tensors='pt',
        ).input_ids
        
        targets = torch.tensor(targets, dtype=torch.float32)
        
        return texts, targets
    
def get_loader(df, tokenizer, batch_size=32, suffle=True, num_workers=0, pin_memory=True):
    
    dataset = TextDataset(df)
    
    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=suffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=MyCollate(tokenizer=tokenizer),
    )
    
    return loader, dataset

## Data Preprocessing

In [90]:
data_df = pd.read_csv(DATA_DIR / 'nuclear_energy_en' / 'processed_data.csv')
data_df["pro_nuclear"] = data_df["pro_nuclear"].astype(int)
# print(f"Removing {data_df['text'].isna().sum()} rows with NaN values")
data_df = data_df.dropna()
# Remove rows that begin with "##"
data_df = data_df[~data_df['text'].str.startswith('##')]


test_split = 0.2
val_split = 0.2

train_df, test_df = train_test_split(data_df, test_size=test_split, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=val_split/(1-test_split), random_state=42)

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', use_fast=False)

# print(len(train_df), len(val_df), len(test_df))

## Train Model

In [136]:
learning_rate = 3e-3
dropout = 0.65
batch_size = 8
num_workers = 0

num_epochs = 20
patience = 10

In [137]:
train_loader, train_dataset = get_loader(train_df, tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers)
val_loader, val_dataset = get_loader(val_df, tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers)
test_loader, test_dataset = get_loader(test_df, tokenizer=tokenizer, batch_size=batch_size, num_workers=num_workers)

In [138]:
class Config:
    def __init__(self, tokenizer):
        self.vocab_size = tokenizer.vocab_size
        self.embedding_dim = 768
        self.rnn_size = 128
        self.rnn_num_layers = 2
        self.bidirectional = False
        self.batch_first = True
        self.dropout = dropout
        self.out_size = 1
        
config = Config(tokenizer)

In [139]:
model = RNN(config).to(device)
# criterion = nn.CrossEntropyLoss(reduction="mean").to(device)
criterion = nn.BCEWithLogitsLoss(reduction="mean").to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [140]:
best_loss = None
counter = 0

for epoch in range(num_epochs):
    ## TRAINING LOOP
    train_count = 0
    train_loss = 0
    avg_train_loss = 0
    train_correct = 0
    avg_train_acc = 0
    pbar_train = tqdm(train_loader, total=len(train_loader), leave=False)
    model.train()
    for idx, (text, targets) in enumerate(pbar_train):
        # Move data to device
        
        text = text.to(device)
        targets = targets.to(device)
        
        # Forward pass
        outputs = model(text)
        outputs = outputs.squeeze()
        loss = criterion(outputs, targets)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Calculate loss
        train_count += 1
        train_loss += loss.item()
        avg_train_loss = train_loss/train_count
        
        # Calculate accuracy
        predictions = torch.round(torch.sigmoid(outputs))
        train_correct += (predictions == targets).sum().item()
        avg_train_acc = train_correct/(train_count * batch_size)
        
        # Update progress bar
        desc = (
            f"Epoch {epoch+1}/{num_epochs}"
            f" - Train Loss: {avg_train_loss:.4f}"
            f" - Train Acc: {avg_train_acc:.4f}"
        )
        pbar_train.set_description(desc=desc)
        
        
    ## VALIDATION LOOP   
    val_count = 0
    val_loss = 0
    avg_val_loss = 0
    val_correct = 0
    avg_val_acc = 0
    pbar_val = tqdm(val_loader, total=len(val_loader), leave=True)
    model.eval()
    with torch.no_grad():
        for idx, (text, targets) in enumerate(pbar_val):
            # Move data to device
            text = text.to(device)
            targets = targets.to(device)
            
            # Forward pass
            outputs = model(text)
            outputs = outputs.squeeze()
            loss = criterion(outputs, targets)
            
            # Calculate loss
            val_count += 1
            val_loss += loss.item()
            avg_val_loss = val_loss/val_count
            
            # Calculate accuracy
            predictions = torch.round(torch.sigmoid(outputs))
            val_correct += (predictions == targets).sum().item()
            avg_val_acc = val_correct/(val_count * batch_size)
            
            # Update progress bar
            desc = (
                f"Epoch {epoch+1}/{num_epochs}"
                f" - Train Loss: {avg_train_loss:.4f}"
                f" - Train Acc: {avg_train_acc:.4f}"
                f" - Val Loss: {avg_val_loss:.4f}"
                f" - Val Acc: {avg_val_acc:.4f}"
            )
            pbar_val.set_description(desc=desc)
        
        
    ## CHECKPOINTING AND EARLY STOPPING
    if best_loss is None:   # i.e. first epoch
        best_loss = avg_val_loss
        save_checkpoint(model, filename=f'checkpoint.pt')
        
    elif avg_val_loss > best_loss:  # i.e. loss increased
        counter += 1
        if counter >= patience:
            print(f"Validation loss has not decreased in {patience} epochs. Stopping training.")
            break
        
    else:   # avg_val_loss < best_loss i.e. loss decreased
        best_loss = avg_val_loss
        save_checkpoint(model, filename=f'checkpoint.pt')
        counter = 0

Epoch 1/20 - Train Loss: 0.7007 - Train Acc: 0.4457 - Val Loss: 0.7001 - Val Acc: 0.4844: 100%|██████████| 8/8 [00:00<00:00, 141.24it/s]
Epoch 2/20 - Train Loss: 0.6976 - Train Acc: 0.5217 - Val Loss: 0.7312 - Val Acc: 0.4375: 100%|██████████| 8/8 [00:00<00:00, 133.58it/s]
Epoch 3/20 - Train Loss: 0.7003 - Train Acc: 0.5380 - Val Loss: 0.6653 - Val Acc: 0.5625: 100%|██████████| 8/8 [00:00<00:00, 141.15it/s]
Epoch 4/20 - Train Loss: 0.6323 - Train Acc: 0.5761 - Val Loss: 0.6533 - Val Acc: 0.5156: 100%|██████████| 8/8 [00:00<00:00, 140.93it/s]
Epoch 5/20 - Train Loss: 0.5756 - Train Acc: 0.6685 - Val Loss: 0.5932 - Val Acc: 0.7031: 100%|██████████| 8/8 [00:00<00:00, 138.79it/s]
Epoch 6/20 - Train Loss: 0.5621 - Train Acc: 0.7337 - Val Loss: 0.6401 - Val Acc: 0.5781: 100%|██████████| 8/8 [00:00<00:00, 139.63it/s]
Epoch 7/20 - Train Loss: 0.5684 - Train Acc: 0.6793 - Val Loss: 0.5915 - Val Acc: 0.7656: 100%|██████████| 8/8 [00:00<00:00, 118.23it/s]
Epoch 8/20 - Train Loss: 0.3733 - Train A

Validation loss has not decreased in 10 epochs. Stopping training.





## Generate Predictions

In [150]:
#load model
prediction_model = RNN(config).to(device)
prediction_model.load_state_dict(torch.load(MODEL_DIR / 'checkpoint.pt'))
prediction_model.eval()

RNN(
  (embedding): Embedding(28996, 768)
  (lstm): LSTM(768, 128, num_layers=2, batch_first=True, dropout=0.65)
  (fc): Linear(in_features=128, out_features=1, bias=True)
  (dropout): Dropout(p=0.65, inplace=False)
)

In [151]:
def predict(model, text):
    tokenized_text = tokenizer.encode(text, return_tensors="pt").to(device)
    
    out = model(tokenized_text)

    pred = torch.sigmoid(out)
    rounded_pred = torch.round(pred)

    return pred, rounded_pred

In [155]:
txt = "Nuclear energy is necessary to combat climate change."

pred, rounded_pred = predict(prediction_model, txt)

print(f"Prediction: {pred.item():.4f} - Pro Nuclear: {rounded_pred.item()}")

Prediction: 0.9334 - Pro Nuclear: 1.0


In [156]:
txt = "Nuclear waste is a huge problem."

pred, rounded_pred = predict(prediction_model, txt)

print(f"Prediction: {pred.item():.4f} - Pro Nuclear: {rounded_pred.item()}")

Prediction: 0.0823 - Pro Nuclear: 0.0
