# In [None]:
import os

import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import BigBirdModel, BigBirdTokenizer, Trainer, TrainingArguments

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# Folders holding the train / test CSV splits (absolute Windows paths)
train_dir = 'C:\\Users\\ericb\\Desktop\\Research\\542_Project\\train_test_data\\train\\'
test_dir = 'C:\\Users\\ericb\\Desktop\\Research\\542_Project\\train_test_data\\test\\'

def load_and_concatenate(directory):
    """Read every ``.csv`` file in *directory* and stack them into one DataFrame.

    Each file must provide ``cleaned_text`` and ``subject`` columns. A ``text``
    column equal to ``cleaned_text + ' ' + subject`` is added to every frame
    before concatenation; missing values in either column are treated as empty
    strings so that ``text`` never contains NaN.

    Args:
        directory: Path of the folder to scan (not recursive).

    Returns:
        A single DataFrame with a fresh ``0..n-1`` index. Files are read in
        sorted filename order so the row order is reproducible.

    Raises:
        ValueError: If the folder contains no ``.csv`` files.
        KeyError: If a file lacks the ``cleaned_text`` or ``subject`` column.
    """
    dataframes = []
    # os.listdir returns entries in arbitrary order; sort for determinism.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.csv'):
            continue
        df = pd.read_csv(os.path.join(directory, filename))
        # An empty CSV cell is read back as NaN, and NaN + str stays NaN,
        # which would later reach the tokenizer as a float instead of text.
        body = df['cleaned_text'].fillna('').astype(str)
        subject = df['subject'].fillna('').astype(str)
        df['text'] = body + ' ' + subject
        dataframes.append(df)

    if not dataframes:
        # pd.concat([]) also raises ValueError, but without naming the folder.
        raise ValueError(f"No .csv files found in {directory!r}")
    return pd.concat(dataframes, ignore_index=True)

# Tokenize the text
def tokenize_dataframe(df, tokenizer=None, max_length=4096):
    """Tokenize the ``text`` column of *df* into fixed-length PyTorch tensors.

    Args:
        df: DataFrame with a ``text`` column of strings.
        tokenizer: Optional callable tokenizer to reuse across calls. When
            omitted, the ``google/bigbird-roberta-base`` tokenizer is loaded.
        max_length: Length every sequence is truncated or padded to.

    Returns:
        Whatever the tokenizer returns for the list of texts when called with
        ``truncation=True``, ``padding='max_length'`` and
        ``return_tensors='pt'``.
    """
    if tokenizer is None:
        # Default kept for existing callers; pass a tokenizer explicitly to
        # avoid re-loading the vocabulary for every DataFrame.
        tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
    return tokenizer(
        df['text'].tolist(),
        max_length=max_length,
        truncation=True,
        padding='max_length',
        return_tensors='pt',
    )


# Read both splits from disk first, then tokenize each of them
train_data, test_data = (
    load_and_concatenate(folder) for folder in (train_dir, test_dir)
)
train_encodings, test_encodings = (
    tokenize_dataframe(frame) for frame in (train_data, test_data)
)

# Custom Dataset
class EmailDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with one label per example.

    Args:
        encodings: Mapping from field name (e.g. ``input_ids``) to a sequence
            or tensor whose first dimension indexes the examples.
        labels: Sequence of integer class labels, one per example.

    Raises:
        ValueError: If any encoding field has a different number of examples
            than ``labels``.
    """

    def __init__(self, encodings, labels):
        for key, values in encodings.items():
            if len(values) != len(labels):
                raise ValueError(
                    f"Field {key!r} has {len(values)} rows but there are "
                    f"{len(labels)} labels"
                )
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the fields of example *idx* plus its label under ``labels``."""
        item = {
            key: torch.as_tensor(values[idx])
            for key, values in self.encodings.items()
        }
        # Stored under ``labels`` (plural), the key name the Hugging Face
        # Trainer looks for; int64 is what cross-entropy expects as target.
        item['labels'] = torch.as_tensor(self.labels[idx], dtype=torch.long)
        return item

# Wrap each split's encodings together with its label column
train_labels = train_data['label'].tolist()
train_dataset = EmailDataset(train_encodings, train_labels)
test_labels = test_data['label'].tolist()
test_dataset = EmailDataset(test_encodings, test_labels)

# Model architecture
class SpamClassifier(nn.Module):
    """BigBird encoder followed by a two-layer BiLSTM and a linear head.

    Args:
        lstm_hidden_size: Hidden size of each LSTM direction.
        num_labels: Number of output classes.
    """

    def __init__(self, lstm_hidden_size=256, num_labels=2):
        super().__init__()
        self.bigbird = BigBirdModel.from_pretrained('google/bigbird-roberta-base')
        self.lstm = nn.LSTM(
            # Take the width from the encoder config instead of hard-coding it.
            input_size=self.bigbird.config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, hence the factor of 2.
        # Assuming binary classification: warranted or unwarranted spam
        self.classifier = nn.Linear(lstm_hidden_size * 2, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids, indexed ``[batch, position]``.
            attention_mask: Same layout as ``input_ids``; 1 for real tokens and
                0 for padding, or ``None`` when there is no padding. Padding is
                assumed to sit at the end of each row (right padding).
            labels: Optional class indices, one per row.

        Returns:
            The logits tensor ``[batch, num_labels]`` when ``labels`` is
            ``None``; otherwise a dict with ``loss`` (scalar cross-entropy)
            and ``logits``, which is the form the Hugging Face ``Trainer``
            consumes.
        """
        # BigBird encoder
        outputs = self.bigbird(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]

        # Number of real tokens per row; pack_padded_sequence needs a CPU
        # int64 tensor with strictly positive entries.
        if attention_mask is None:
            lengths = torch.full(
                (sequence_output.size(0),),
                sequence_output.size(1),
                dtype=torch.long,
            )
        else:
            lengths = attention_mask.sum(dim=1).clamp(min=1).long().cpu()

        # LSTM layer. Packing stops each sequence at its true length, so the
        # final states summarise the real tokens only. Reading the output at
        # position -1 instead would pick a padding step whenever a row is
        # shorter than the padded length.
        packed = nn.utils.rnn.pack_padded_sequence(
            sequence_output, lengths, batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)
        # hidden is ordered layer by layer, forward before backward; the last
        # two entries are therefore the top layer's forward/backward states.
        pooled = torch.cat((hidden[-2], hidden[-1]), dim=1)

        # Classifier
        logits = self.classifier(pooled)
        if labels is None:
            return logits
        loss = nn.functional.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}

# Build the classifier, then move its parameters onto the selected device
model = SpamClassifier()
model = model.to(device)

# Hyper-parameters and output locations handed to the Hugging Face Trainer
training_args = TrainingArguments(
    # Where checkpoints and logs are written
    output_dir='./results',
    logging_dir='./logs',
    # Schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Batch sizes per device
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Evaluate on the eval dataset after every epoch
    evaluation_strategy="epoch",
)

# Bind the model, the configuration and both data splits into a Trainer
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=test_dataset,
    train_dataset=train_dataset,
)

# Kick off fine-tuning
trainer.train()
