# 1. Setup

## Set environment and device

In [None]:
# Mount to Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Define project folder
FOLDERNAME = 'Colab\ Notebooks/AmazonReviews'

%cd drive/MyDrive/$FOLDERNAME

Mounted at /content/drive
/content/drive/MyDrive/Colab Notebooks/AmazonReviews


In [None]:
# Define device
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


## Set random seed

In [None]:
torch.manual_seed(77)  # For reproducibility

<torch._C.Generator at 0x7914b41510f0>

# 2. Data Preparation

## Download dataset

### Set up Kaggle API key

In [None]:
import os
import json

# Path to the Kaggle API key JSON file
KAGGLE_API_KEY_PATH = '/content/drive/MyDrive/Colab Notebooks/credentials/kaggle/kaggle.json'


# Kaggle API checks for credentials in this order:
# - Environment variables (KAGGLE_USERNAME and KAGGLE_KEY)
# - Kaggle config file (Kaggle API expects it to be found at ~/.kaggle/kaggle.json)
def setup_kaggle():
    # This function reads a json file and store the API info in the environment variables.
    if os.path.exists(KAGGLE_API_KEY_PATH):
        try:
            with open(KAGGLE_API_KEY_PATH, 'r') as f:
                kaggle_api_key = json.load(f)

            os.environ['KAGGLE_USERNAME'] = kaggle_api_key['username']
            os.environ['KAGGLE_KEY'] = kaggle_api_key['key']

            print("Kaggle API setup completed successfully.")

        except Exception as e:
            print(f"Error setting up Kaggle credentials: {e}")

    else:
        print(f"Kaggle API key file not found at {KAGGLE_API_KEY_PATH}")

setup_kaggle()

Kaggle API setup completed successfully.


### Download Kaggle dataset using Kaggle API



In [None]:
from pathlib import Path
import kaggle

# Define the base directory for datasets
BASE_DATASET_DIR = '/content/drive/MyDrive/Colab Notebooks/raw datasets/kaggle'

def download_kaggle_dataset(dataset_name, dataset_folder=None):
    """
    Download and unzip a Kaggle dataset

    Parameters:
    dataset_name (str): Name of dataset (e.g., 'username/dataset-name')
    dataset_folder (str): Optional subfolder name. If None, uses last part of dataset_name
    """
    try:
        # If no specific folder name is provided, use the dataset name
        if dataset_folder is None:
            dataset_folder = dataset_name.split('/')[-1]

        # Create the full save directory path
        save_dir = os.path.join(BASE_DATASET_DIR, dataset_folder)

        # Create save directory if it doesn't exist
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        # Download the dataset
        kaggle.api.dataset_download_files(
            dataset_name,
            path=save_dir,
            unzip=True  # Set to False if you want to keep the zip file
        )

        print(f"Dataset downloaded successfully to {save_dir}")

        # List remaining files
        print("\nDownloaded files:")
        for file in save_path.glob('*'):
            print(f"- {file.name}")

    except Exception as e:
        print(f"Error downloading dataset: {e}")

download_kaggle_dataset('kritanjalijain/amazon-reviews')

Dataset URL: https://www.kaggle.com/datasets/kritanjalijain/amazon-reviews
Dataset downloaded successfully to /content/drive/MyDrive/Colab Notebooks/raw datasets/kaggle/amazon-reviews

Downloaded files:
- amazon-reviews.zip
- test.csv
- train.csv
- amazon_review_polarity_csv.tgz


## Explore dataset to find out what preprocessing steps are needed

In [None]:
import pandas as pd

train_df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/raw datasets/kaggle/amazon-reviews/train.csv', header=None, names=['label', 'title', 'review'])

train_df.head()

In [None]:
# Check missing value in title and review columns
print('title NaN: ', len(train_df[train_df.title.isnull()]))
print('review NaN: ', len(train_df[train_df.review.isnull()]))

In [None]:
# Combine review title and review columns to a new column 'full_review'
# Replace any NaN values in the title or review with an empty string
train_df['full_review'] = train_df.title.fillna('') + " " + train_df.review.fillna('')

# Check if full_review is combined correctly, not effected by NaN title or review
train_df[train_df.title.isnull() | train_df.review.isnull()].head()

### Label distribution

In [None]:
# Check label distribution (check if the dataset is balanced
pd.Series(train_df.label).value_counts()  # pandas.value_counts is deprecated

### Length distribution --> sequence length

In [None]:
# Check dataset length distribution (length of string in 'full_review' column)
# This can help to determine the sequence length for the model
chunk_size = 1000  # Process in chunks to avoid memory issues
word_lengths = []

for i in range(0, len(train_df), chunk_size):
    chunk = train_df['full_review'].iloc[i:i+chunk_size]
    # Process each review individually
    for review in chunk:
        # Quick and rough word count using split()
        word_lengths.append(len(str(review).split()))

# Convert list to pandas Series for easy statistics
word_lengths = pd.Series(word_lengths)


# Print basic statistics
print("Word count statistics:")
print(word_lengths.describe())

"""
Word count statistics:
count    3.600000e+06
mean     7.848268e+01
std      4.283283e+01
min      2.000000e+00
25%      4.200000e+01
50%      7.000000e+01
75%      1.080000e+02
max      2.570000e+02
"""

### Word frequency distribution --> corpus size, stop words

In [None]:
# Check word frequency distribution
# How many words show up in the dataset and how many times each word shows up
# This can help to determine the vocabulary (corpus) size for the model

# The Oxford English Dictionary: ~600k words, incl. both current and obsolete words
# Merriam-Webster’s Dictionary: ~470k words
# Studies suggest that most people use around 20k-35k words in their active vocabulary
from collections import Counter

chunk_size = 1000  # Process in chunks to avoid memory issues
word_freq = Counter()

for i in range(0, len(train_df), chunk_size):
    chunk = train_df['full_review'].iloc[i:i+chunk_size]
    for review in chunk:
        words = str(review).lower().split()
        word_freq.update(words)

# Print basic statistics
print(f"Total unique words: {len(word_freq)}")
print("\nMost common words:")
for word, count in word_freq.most_common(20):
    print(f"{word}: {count}")

# Look at frequency distribution
print("\nWord frequency distribution:")
freq_dist = Counter([count for count in word_freq.values()])
print("Words appearing:")
for freq in sorted(freq_dist.keys())[:5]:
    print(f"{freq} times: {freq_dist[freq]} words")

# Vocabulary filtering
MIN_FREQ = 8  # Only keep words appearing 8+ times
filtered_vocab = {word: count for word, count in word_freq.items()
                 if count >= MIN_FREQ}
print(f"\nFiltered vocabulary size: {len(filtered_vocab)}")
"""Filtered vocabulary size: 342294  # MIN_FREQ = 8"""

### Check for HTML tags, URLs and special characters

In [None]:
# Look for HTML tags and URLs
# Quick check for HTML tags and URLs in a random sample
import re  # Regular expression

train_df_sample = train_df.sample(n=6000, random_state=33).astype(str)

html_count = train_df_sample['full_review'].str.contains(r'<[^>]+>').sum()
url_count = train_df_sample['full_review'].str.contains(r'http[s]?://').sum()

print(f"Reviews with HTML tags: {html_count}")
print(f"Reviews with URLs: {url_count}")

# Find special symbols
# Get and count special characters
special_chars = ''
for review in train_df_sample['full_review']:
    chars = re.findall(r'[^a-zA-Z0-9\s]', str(review))
    special_chars += ''.join(chars)

print("\nMost common special characters:")
for char, count in Counter(special_chars).most_common(15):
    print(f"'{char}': {count}")

# Check for unusual punctuation
punct_patterns = re.findall(r'[^\w\s]{2,}', ' '.join(train_df_sample['full_review']))
print("\nUnusual punctuation patterns:")
punct_counter = Counter(punct_patterns)
for pattern, count in punct_counter.most_common(10):
    print(f"'{pattern}': {count}")

# Check for repeated punctuation
repeated_punct = re.findall(r'([!?\.]{2,})', ' '.join(train_df_sample['full_review']))
print("\nRepeated punctuation:")
repeat_counter = Counter(repeated_punct)
for pattern, count in repeat_counter.most_common(10):
        print(f"'{pattern}': {count}")


## Preprocess dataset before loading into Dataset class


In [None]:


# # Define patterns
# url_pattern = re.compile(r'http\S+|www\S+|https\S+')
# repeated_char_pattern = re.compile(r'(.)\1{3,}')  # e.g., 'loooove' -> 'love'
# repeated_word_pattern = re.compile(r'\b(\w+)( \1\b)+')  # e.g., 'very very' -> 'very'

# # Process a string
# def simple_preprocess(text):
#     # Convert to lowercase
#     text = str(text).lower()

#     # Remove URLs
#     text = url_pattern.sub('', text)
#     # Remove repeated characters
#     text = repeated_char_pattern.sub(r'\1', text)

#     text = repeated_word_pattern.sub(r'\1', text)

#     # Basic cleaning (no SpaCy)
#     text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
#     text = re.sub(r'\d+', '', text)      # Remove numbers
#     return ' '.join(text.split())

# # Test
# # text = "I loooove this product!!! Check it out at http://example.com"
# # clean_text = simple_preprocess(text)
# # print(f"Original: {text}")
# # print(f"Cleaned: {clean_text}")

# # Process multiple texts (DataFrame)
# from joblib import Parallel, delayed
# from tqdm import tqdm

# def process_multiple_texts_parallel(texts, n_jobs=-1):
#     """
#     Process texts in parallel using joblib
#     n_jobs=-1 means use all available cores
#     """
#     cleaned_texts = Parallel(n_jobs=n_jobs, backend="multiprocessing")(
#         delayed(simple_preprocess)(text) for text in tqdm(texts, desc="Preprocessing texts",
#             miniters=len(texts)//100)
#     )
#     return cleaned_texts

# # Process the reviews
# cleaned_reviews = process_multiple_texts_parallel(train_df['full_review'])
# train_df['cleaned_review'] = cleaned_reviews

# # Check the results
# print("\nExample of original vs cleaned review:")
# print("Original:", train_df['full_review'].iloc[0])
# print("Cleaned:", train_df['cleaned_review'].iloc[0])

# # Check a few more examples
# print("\nMore examples:")
# for i in range(5):
#     print(f"\nExample {i+1}:")
#     print("Original:", train_df['full_review'].iloc[i])
#     print("Cleaned:", train_df['cleaned_review'].iloc[i])

In [None]:
import pandas as pd
import numpy as np

# Load the data
train_df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/raw datasets/kaggle/amazon-reviews/train.csv',
                       header=None, names=['label', 'title', 'review'])
val_df = pd.read_csv('/content/drive/MyDrive/Colab Notebooks/raw datasets/kaggle/amazon-reviews/test.csv',
                     header=None, names=['label', 'title', 'review'])

# Set sample sizes
TRAIN_SIZE = 10_000
VAL_SIZE = 2_000

# Sample with stratification
def stratified_sample(df, n, stratify_col='label'):
    df_sample = df.groupby(stratify_col).apply(
        lambda x: x.sample(n=n//len(df[stratify_col].unique()))
    ).reset_index(drop=True)
    return df_sample

# Take balanced samples
train_df = stratified_sample(train_df, TRAIN_SIZE)
val_df = stratified_sample(val_df, VAL_SIZE)

# Convert labels from 1/2 to 0/1
train_df['label'] = train_df['label'].map({1: 0, 2: 1})
val_df['label'] = val_df['label'].map({1: 0, 2: 1})
# Verify the conversion
print("Train labels unique values:", train_df['label'].unique())
print("Train labels value counts:\n", train_df['label'].value_counts())
print("\nValidation labels unique values:", val_df['label'].unique())
print("Validation labels value counts:\n", val_df['label'].value_counts())

# Combine review title and review columns to a new column 'full_review'
# Replace any NaN values in the title or review with an empty string
train_df['full_review'] = train_df.title.fillna('') + " " + train_df.review.fillna('')
val_df['full_review'] = val_df.title.fillna('') + " " + val_df.review.fillna('')

  df_sample = df.groupby(stratify_col).apply(


Train labels unique values: [0 1]
Train labels value counts:
 label
0    5000
1    5000
Name: count, dtype: int64

Validation labels unique values: [0 1]
Validation labels value counts:
 label
0    1000
1    1000
Name: count, dtype: int64


  df_sample = df.groupby(stratify_col).apply(


In [None]:
import os
from collections import Counter
from nltk.corpus import stopwords
import nltk
import re

# Define NLTK data path in Google Drive
nltk_data_path = '/content/drive/MyDrive/Colab Notebooks/nltk_data'

# Create the directory if it doesn't exist
if not os.path.exists(nltk_data_path):
    os.makedirs(nltk_data_path)

# Add the custom path to NLTK's data path
nltk.data.path.append(nltk_data_path)

# Download stopwords if they don't exist
if not os.path.exists(os.path.join(nltk_data_path, 'corpora/stopwords')):
    nltk.download('stopwords', download_dir=nltk_data_path)



NEGATION_WORDS = {'not', 'no', 'nor', 'never', 'none', 'neither', 'nowhere', 'nothing'}
STOP_WORDS = set(stopwords.words('english')) - NEGATION_WORDS

# Create vocabulary from training data (do this once before training)
def create_vocabulary(texts, min_freq=8):
    word_freq = Counter()
    for text in texts:
        words = str(text).lower().split()
        word_freq.update(words)

    # Filter vocabulary but keep negation words regardless of frequency
    return {word: count for word, count in word_freq.items()
            if (count >= min_freq and word not in STOP_WORDS) or word in NEGATION_WORDS}

# Create vocabulary from training data
vocabulary = create_vocabulary(train_df['full_review'], min_freq=8)


# Define patterns
url_pattern = re.compile(r'http\S+|www\S+|https\S+')
repeated_char_pattern = re.compile(r'(.)\1{3,}')  # e.g., 'loooove' -> 'love'
repeated_word_pattern = re.compile(r'\b(\w+)( \1\b)+')  # e.g., 'very very' -> 'very'

def transform(text):
    """
    Default preprocessing transform - matches simple_preprocess function
    """
    # Convert to lowercase
    text = str(text).lower()

    # Remove URLs
    text = url_pattern.sub('', text)
    # Remove repeated characters
    text = repeated_char_pattern.sub(r'\1', text)
    # Remove repeated words
    text = repeated_word_pattern.sub(r'\1', text)

    # Basic cleaning (no SpaCy)
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    text = re.sub(r'\d+', '', text)      # Remove numbers

    # Split into words and filter
    words = text.split()
    # Keep only words that are in our vocabulary (this removes both stopwords and infrequent words)
    filtered_words = [word for word in words if word in vocabulary]

    return ' '.join(filtered_words)

# Test transform function
# Test the transform function on a few examples
print("Testing transform function:")
for i in range(3):
    original = train_df['full_review'].iloc[i]
    transformed = transform(original)
    print(f"\nExample {i+1}:")
    print("Original:", original)
    print("Transformed:", transformed)


Testing transform function:

Example 1:
Original: doesn't work Don't buy , put a rib roast in and turned it on , heated up so we left ...came home to a cold pot with an uncooked and ruined $3o piece of meat....called customer service and were give the phone tag run around for 35 minutes and told there was nothing they could do about the cost of the meat which was Their fault...so if u buy one , good luck and don't bother with crappy c s dept !!
Transformed: doesnt work dont buy put turned left came home cold pot ruined piece customer service give phone tag run around minutes told nothing could cost meat u buy one good luck dont bother crappy c

Example 2:
Original: We sent this back We have tied this swaddle, the "swaddleme" and the Amazing Miracle Blanket.We have twins (boy and girl) and not only did this swaddle not keep them secure, but we awoke to find the swaddle part wrapped around their necks - not good.Do yourself a favour and get the miracle blanket, it is the best of the 3.
T

## Load datasets with transforms

In [None]:
from torch.utils.data import Dataset
import torch
from transformers import BertTokenizer  # or another tokenizer of your choice

class ReviewDataset(Dataset):
    def __init__(self, reviews, labels, max_length=300, transform=transform):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_length = max_length
        # Transform
        self.transform = transform

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        label = self.labels[idx]

        # Apply transform to the review text
        review = self.transform(review)

        # Tokenize and convert to tensor
        encoding = self.tokenizer(
            review,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Create dataset
train_dataset = ReviewDataset(
    reviews=train_df['full_review'],
    labels=train_df['label'].values
)

val_dataset = ReviewDataset(
    reviews=val_df['full_review'],
    labels=val_df['label'].values
)

# Test the dataset
print(f"Dataset size: {len(train_dataset)}")
sample = train_dataset[0]
print(f"Sample input shape: {sample['input_ids'].shape}")
print(f"Sample label: {sample['label']}")

Dataset size: 10000
Sample input shape: torch.Size([300])
Sample label: 0




In [None]:
for i in range(3):
    # Get original text
    original_text = train_dataset.reviews.iloc[i]

    # Get preprocessed text by applying transform
    preprocessed_text = train_dataset.transform(original_text)

    # Get tokenized tensor sample
    sample = train_dataset[i]

    print(f"\nSample {i+1}:")
    print(f"Original text: {original_text[:200]}...")
    print(f"Preprocessed text: {preprocessed_text[:200]}...")
    print(f"Label: {sample['label']}")

    # Decode the tokens back to words to see what BERT sees
    tokens = train_dataset.tokenizer.convert_ids_to_tokens(sample['input_ids'])
    print(f"First 30 tokens: {tokens[:30]}")
    print(f"Input tensor shape: {sample['input_ids'].shape}")
    print(f"Attention mask shape: {sample['attention_mask'].shape}")
    print("-" * 100)


Sample 1:
Original text: doesn't work Don't buy , put a rib roast in and turned it on , heated up so we left ...came home to a cold pot with an uncooked and ruined $3o piece of meat....called customer service and were give th...
Preprocessed text: doesnt work dont buy put turned left came home cold pot ruined piece customer service give phone tag run around minutes told nothing could cost meat u buy one good luck dont bother crappy c...
Label: 0
First 30 tokens: ['[CLS]', 'doesn', '##t', 'work', 'don', '##t', 'buy', 'put', 'turned', 'left', 'came', 'home', 'cold', 'pot', 'ruined', 'piece', 'customer', 'service', 'give', 'phone', 'tag', 'run', 'around', 'minutes', 'told', 'nothing', 'could', 'cost', 'meat', 'u']
Input tensor shape: torch.Size([300])
Attention mask shape: torch.Size([300])
----------------------------------------------------------------------------------------------------

Sample 2:
Original text: We sent this back We have tied this swaddle, the "swaddleme" and the Amaz

In [None]:
# Get a single sample
sample = train_dataset[0]

# Look at where the attention mask is active (1) vs padding (0)
print("Attention mask (first 50 positions):")
print(sample['attention_mask'][:50])

# Look at the actual token IDs
print("\nInput IDs (first 50 positions):")
print(sample['input_ids'][:50])

# Decode a few tokens
token_ids = sample['input_ids'][:10]
tokens = train_dataset.tokenizer.convert_ids_to_tokens(token_ids)
print("\nFirst 10 tokens:")
for id, token in zip(token_ids, tokens):
    print(f"ID: {id:5d} -> Token: {token}")

Attention mask (first 50 positions):
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0])

Input IDs (first 50 positions):
tensor([  101,  2987,  2102,  2147,  2123,  2102,  4965,  2404,  2357,  2187,
         2234,  2188,  3147,  8962,  9868,  3538,  8013,  2326,  2507,  3042,
         6415,  2448,  2105,  2781,  2409,  2498,  2071,  3465,  6240,  1057,
         4965,  2028,  2204,  6735,  2123,  2102,  8572, 10231,  7685,  1039,
          102,     0,     0,     0,     0,     0,     0,     0,     0,     0])

First 10 tokens:
ID:   101 -> Token: [CLS]
ID:  2987 -> Token: doesn
ID:  2102 -> Token: ##t
ID:  2147 -> Token: work
ID:  2123 -> Token: don
ID:  2102 -> Token: ##t
ID:  4965 -> Token: buy
ID:  2404 -> Token: put
ID:  2357 -> Token: turned
ID:  2187 -> Token: left


## Load dataloaders

In [None]:
# Create DataLoader
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=2
)

# Test the dataloader
for batch in train_loader:
    print(f"Batch input shape: {batch['input_ids'].shape}")
    print(f"Batch label shape: {batch['label'].shape}")
    break

Batch input shape: torch.Size([8, 300])
Batch label shape: torch.Size([8])


# 3. Model

## Transfer learning

In [None]:
from transformers import AutoModelForSequenceClassification

# Load pre-trained BERT model and configure for binary classification
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
    gradient_checkpointing=True  # Memory optimization
)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# 4. Training Setup

## Hyperparameters for training

In [None]:
LEARNING_RATE = 0.001
MIN_LEARNING_RATE = 1e-6
NUM_EPOCHS = 100

## Loss function, optimizer and learning rate scheduler

In [None]:
import torch.nn as nn
import torch.optim as optim
# Define loss function and optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Define learning rate scheduler
# Linear warmup + decay (recommended for BERT)
from transformers import get_linear_schedule_with_warmup

num_training_steps = len(train_loader) * NUM_EPOCHS
num_warmup_steps = num_training_steps // 10  # 10% of total steps

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)


## Metrics

In [None]:
# Define accuracy metrics for classification
def compute_correct(scores, targets):
    predictions = scores.max(axis=1)[1]
    correct = predictions.eq(targets).sum().item()

    return correct

# 5. Training Loop

In [None]:
# Define training function for one epoch
def train_one_epoch(model, train_loader, loss_function, optimizer):
    model.train()
    total_loss = 0
    num_correct = 0
    num_total_inputs = 0

    for batch in train_loader:
        # Move data to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        # Forward pass
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # Get the logits from the output
        loss = loss_function(logits, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        # num_correct += compute_correct(logits, targets)
        # num_total_inputs += targets.size(0)
        # total_loss += loss.item()

    # avg_loss = total_loss / len(train_loader)
    # accuracy = num_correct / num_total_inputs
    # return avg_loss, accuracy


# Define validation function for one epoch
def validate_one_epoch(model, val_loader, loss_function):
    model.eval()
    total_loss = 0
    num_correct = 0
    num_total_inputs = 0

    with torch.no_grad():
        for batch in val_loader:
            # Move data to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            targets = batch['label'].to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits  # Get the logits from the output
            loss = loss_function(logits, targets)

            # Calculate accuracy
            num_correct += compute_correct(logits, targets)
            num_total_inputs += targets.size(0)
            total_loss += loss.item()

    avg_loss = total_loss / len(val_loader)
    accuracy = num_correct / num_total_inputs
    return avg_loss, accuracy

# Define main training function
def train_model(model, train_loader, val_loader, loss_function, optimizer, scheduler, num_epochs, patience=5):
    # Log the loss and accuracy of each epoch
    training_history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    # Early stopping
    # Stop when the val_loss isn't improving after certain number of epochs (patience) and revert to the best model
    # This saves training time and prevents overfitting
    best_val_acc = 0.5
    best_model_state = None
    best_epoch = 0
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f'\n{"-" * 20}\nEpoch [{epoch+1}/{num_epochs}]')

        # Training
        train_one_epoch(model, train_loader, loss_function, optimizer)

        # Validation
        val_loss, val_acc = validate_one_epoch(model, val_loader, loss_function)

        # Update learning rate, one time in each epoch
        if scheduler is not None:
            scheduler.step()

        # Update loss and acc history
        # training_history['train_loss'].append(train_loss)
        # training_history['train_acc'].append(train_acc)
        training_history['val_loss'].append(val_loss)
        training_history['val_acc'].append(val_acc)

        # Print epoch summary
        print(
            # f'Train Loss: {train_loss:.6f}, Train Accuracy: {train_acc:.4f}\n'
            f'Val Loss: {val_loss:.6f}, Val Accuracy: {val_acc:.4f}'
        )

        # Early stopping logic
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_loss = val_loss
            best_model_state = model.state_dict()  # A dictionary containing all the model's parameters
            best_epoch = epoch
            patience_counter = 0  # Reset patience counter to zero
            # Save best model checkpoint
            # checkpoint = {
            #     'epoch': epoch,
            #     'model_state_dict': model.state_dict(),
            #     'optimizer_state_dict': optimizer.state_dict(),
            #     'scheduler_state_dict': scheduler.state_dict(),

            #     'best_val_acc': best_val_acc,
            #     'best_val_loss': best_val_loss,
            #     'patience_counter': patience_counter,

            #     'history': training_history,  # Save full history of loss and acc in each epoch
            # }
            # torch.save(checkpoint, 'best_model_checkpoint.pth')
            # print('Saved best model checkpoint.')
        else:
            patience_counter += 1
            print(f'Patience: [{patience_counter}/{patience}]')

        # Save checkpoint
        # if epoch % 20 == 0:
        #     checkpoint = {
        #         'epoch': epoch,
        #         'model_state_dict': model.state_dict(),
        #         'optimizer_state_dict': optimizer.state_dict(),
        #         'scheduler_state_dict': scheduler.state_dict(),

        #         'best_val_loss': best_val_loss,
        #         'patience_counter': patience_counter,

        #         'history': training_history,  # Save full history of loss and acc in each epoch
        #     }
        #     torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')
        #     print('Saved model checkpoint.')

        # Early stopping check (break training loop if patience reached)
        if patience_counter >= patience or val_acc > 0.95:
            print(f'Early stopping triggered at Epoch [{epoch+1}/{num_epochs}].')
            break  # Stop training

    # Revert to best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f'Restored best model state. Epoch [{best_epoch+1}/{num_epochs}]')

# Run the training loop
train_model(model, train_loader, val_loader, loss_function, optimizer, scheduler=scheduler, num_epochs=NUM_EPOCHS, patience=10)


--------------------
Epoch [1/100]
Val Loss: 0.699817, Val Accuracy: 0.4965
Patience: [1/10]

--------------------
Epoch [2/100]
Val Loss: 0.673638, Val Accuracy: 0.6305

--------------------
Epoch [3/100]
Val Loss: 0.611086, Val Accuracy: 0.7155

--------------------
Epoch [4/100]
Val Loss: 0.453030, Val Accuracy: 0.8390

--------------------
Epoch [5/100]
Val Loss: 0.334264, Val Accuracy: 0.8830

--------------------
Epoch [6/100]
Val Loss: 0.295683, Val Accuracy: 0.8805
Patience: [1/10]

--------------------
Epoch [7/100]
Val Loss: 0.277447, Val Accuracy: 0.8855

--------------------
Epoch [8/100]
Val Loss: 0.266331, Val Accuracy: 0.8925

--------------------
Epoch [9/100]
Val Loss: 0.260924, Val Accuracy: 0.8945

--------------------
Epoch [10/100]
Val Loss: 0.256662, Val Accuracy: 0.8975

--------------------
Epoch [11/100]
Val Loss: 0.249780, Val Accuracy: 0.8995

--------------------
Epoch [12/100]
Val Loss: 0.258981, Val Accuracy: 0.9005

--------------------
Epoch [13/100]
Va

KeyboardInterrupt: 