In [24]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np

from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

# Width of the sentence embeddings fed to the classifier heads; must match the
# output size of the sentence encoder (all-MiniLM-L6-v2) loaded further down.
EMBEDDING_DIM: int = 384
# Width of every hidden layer in the classifier heads.
HIDDEN_DIM: int = 64

# Function Definitions
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions of each sequence.

    Args:
        model_output: transformer output whose element 0 holds per-token hidden
            states (assumed (batch, seq, hidden) -- it is broadcast against the
            mask expanded to that size below).
        attention_mask: (batch, seq) mask, 1 for real tokens and 0 for padding.

    Returns:
        (batch, hidden) tensor holding the masked mean of each sequence.
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the hidden dimension so padded positions add zero.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * mask).sum(dim=1)
    # Clamp so a row with no real tokens divides by a tiny value instead of zero.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts

class SimpleClassifier(nn.Module):
    """Two-hidden-layer MLP head mapping a sentence embedding to logits.

    Layout: Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.
    The last layer has no activation, so outputs are raw logits.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, dropout_prob=0.7):
        super().__init__()
        # Attribute names become state_dict keys, so they stay fixed to keep
        # saved checkpoints loadable.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout2 = nn.Dropout(p=dropout_prob)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden_stages = (
            (self.fc1, self.dropout1),
            (self.fc2, self.dropout2),
        )
        for linear, drop in hidden_stages:
            x = drop(F.relu(linear(x)))
        return self.fc3(x)

class DeepClassifier(nn.Module):
    """Four-hidden-layer MLP head mapping a sentence embedding to logits.

    Each hidden stage is Linear -> ReLU -> Dropout; a final Linear layer
    (no activation) produces the logits.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, dropout_prob=0.7):
        super().__init__()
        # Attribute names become state_dict keys; keep them stable for checkpoints.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout2 = nn.Dropout(p=dropout_prob)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout3 = nn.Dropout(p=dropout_prob)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout4 = nn.Dropout(p=dropout_prob)
        self.fc5 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden_stages = (
            (self.fc1, self.dropout1),
            (self.fc2, self.dropout2),
            (self.fc3, self.dropout3),
            (self.fc4, self.dropout4),
        )
        for linear, drop in hidden_stages:
            x = drop(F.relu(linear(x)))
        return self.fc5(x)
    
class DeeperClassifier(nn.Module):
    """Five-hidden-layer MLP head mapping a sentence embedding to logits.

    Each hidden stage is Linear -> ReLU -> Dropout; a final Linear layer
    (no activation) produces the logits.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, dropout_prob=0.7):
        super().__init__()
        # Attribute names become state_dict keys; keep them stable for checkpoints.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout2 = nn.Dropout(p=dropout_prob)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout3 = nn.Dropout(p=dropout_prob)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout4 = nn.Dropout(p=dropout_prob)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout5 = nn.Dropout(p=dropout_prob)
        self.fc6 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden_stages = (
            (self.fc1, self.dropout1),
            (self.fc2, self.dropout2),
            (self.fc3, self.dropout3),
            (self.fc4, self.dropout4),
            (self.fc5, self.dropout5),
        )
        for linear, drop in hidden_stages:
            x = drop(F.relu(linear(x)))
        return self.fc6(x)


class ResidualBlock(nn.Module):
    """Two-layer residual unit computing ``relu(x + fc2(relu(fc1(x))))``.

    Input and output widths are both ``hidden_dim``, so the skip connection
    is added directly without a projection.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        branch = self.fc2(F.relu(self.fc1(x)))
        return F.relu(branch + x)

class ResidualClassifier(nn.Module):
    """MLP head: input projection, two residual blocks, then a linear output.

    Layout: Linear(embedding_dim -> hidden_dim) -> ReLU -> Dropout ->
    ResidualBlock -> ResidualBlock -> Linear(hidden_dim -> output_dim).
    Dropout is applied only once, right after the input projection.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, dropout_prob=0.7):
        super().__init__()
        # Attribute names become state_dict keys; keep them stable for checkpoints.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.dropout1 = nn.Dropout(p=dropout_prob)
        self.res_block1 = ResidualBlock(hidden_dim)
        self.res_block2 = ResidualBlock(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        h = self.dropout1(F.relu(self.fc1(x)))
        for block in (self.res_block1, self.res_block2):
            h = block(h)
        return self.fc2(h)

def encode_sentences_in_batches(sentences, tokenizer, model, device, batch_size=32):
    """Encode sentences into L2-normalised, mean-pooled embeddings.

    Sentences are tokenised (padded/truncated) and run through ``model`` in
    chunks of ``batch_size``; the forward pass runs under ``torch.no_grad()``.

    Args:
        sentences: sliceable sequence of strings.
        tokenizer: Hugging Face tokenizer callable.
        model: transformer whose output is pooled with ``mean_pooling``.
        device: device the tokenised batch is moved to.
        batch_size: number of sentences per forward pass.

    Returns:
        NumPy array with one row per input sentence (chunks stacked with
        ``np.vstack``).

    Raises:
        ValueError: from ``np.vstack`` when ``sentences`` is empty.
    """
    def _encode_chunk(chunk):
        encoded = tokenizer(chunk, padding=True, truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            output = model(**encoded)
        pooled = mean_pooling(output, encoded['attention_mask'])
        return F.normalize(pooled, p=2, dim=1).cpu().numpy()

    chunk_starts = range(0, len(sentences), batch_size)
    return np.vstack([_encode_chunk(sentences[s:s + batch_size]) for s in chunk_starts])

def train_epoch(classifier, optimizer, criterion, data_loader, device, scheduler=None):
    """Run one optimisation pass over ``data_loader`` and return the mean loss.

    Args:
        classifier: model producing one logit per sample.
        optimizer: optimizer over ``classifier``'s parameters.
        criterion: loss taking (flattened logits, float labels); assumed to use
            mean reduction, as the default ``nn.BCEWithLogitsLoss()`` does.
        data_loader: yields ``(embeddings, labels)`` batches.
        device: device each batch is moved to.
        scheduler: unused; accepted so existing call sites keep working. Step
            any LR scheduler from the caller.

    Returns:
        Mean per-sample training loss over the epoch. Each batch's loss is
        weighted by its size, so a short final batch does not skew the average.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    classifier.train()
    loss_sum = 0.0
    sample_count = 0
    for embeddings, labels in data_loader:
        embeddings = embeddings.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = classifier(embeddings)
        loss = criterion(logits.view(-1), labels)
        loss.backward()
        optimizer.step()

        # Undo the criterion's per-batch mean so batches are weighted by size.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        sample_count += batch_size

    if sample_count == 0:
        raise ValueError("data_loader yielded no samples")
    return loss_sum / sample_count

def evaluate_accuracy(classifier, data_loader, device):
    """Return the fraction of samples whose thresholded logit matches the label.

    A logit > 0 (sigmoid probability > 0.5) is predicted as class 1. Labels are
    cast with ``.byte()`` before comparison, so they are expected to be 0/1.
    Puts ``classifier`` in eval mode and does not restore the previous mode.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.
    """
    classifier.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            logits = classifier(batch_x.to(device)).view(-1)
            target = batch_y.to(device).byte()
            hits += (target == (logits > 0.0)).sum().item()
            seen += batch_y.size(0)
    return hits / seen

def prepare_data_loader(embeddings, labels, batch_size=16, shuffle=True):
    """Wrap embedding and label arrays in a ``DataLoader``.

    Args:
        embeddings: array-like of features, converted to a float32 tensor.
        labels: array-like of targets, converted to a float32 tensor.
        batch_size: samples per batch.
        shuffle: reshuffle every epoch. Defaults to True (training use); pass
            False for evaluation loaders so batch order is reproducible.

    Returns:
        DataLoader yielding ``(embeddings, labels)`` tensor pairs.
    """
    dataset = TensorDataset(
        torch.tensor(embeddings, dtype=torch.float32),
        torch.tensor(labels, dtype=torch.float32),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def evaluate_loss(classifier, data_loader, criterion, device):
    """Return the mean per-sample loss of ``classifier`` on ``data_loader``.

    Runs in eval mode under ``torch.no_grad()``. Each batch's loss is weighted
    by its size, so the result does not depend on how samples were grouped into
    batches (e.g. a shuffled loader with a short final batch).

    Args:
        classifier: model producing one logit per sample.
        data_loader: yields ``(embeddings, labels)`` batches.
        criterion: loss taking (flattened logits, float labels); assumed to use
            mean reduction, as the default ``nn.BCEWithLogitsLoss()`` does.
        device: device each batch is moved to.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    classifier.eval()
    loss_sum = 0.0
    sample_count = 0
    with torch.no_grad():
        for embeddings, labels in data_loader:
            embeddings = embeddings.to(device)
            labels = labels.to(device)
            logits = classifier(embeddings)
            batch_size = labels.size(0)
            loss_sum += criterion(logits.view(-1), labels).item() * batch_size
            sample_count += batch_size
    if sample_count == 0:
        raise ValueError("data_loader yielded no samples")
    return loss_sum / sample_count

# Set device: prefer Apple MPS, then CUDA, then CPU, so the notebook also runs
# on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Load data and split it 60/20/20 into train/validation/test with a fixed seed.
df = pd.read_csv('example_dataset_with_controls.csv')
train_data, temp_data = train_test_split(df, test_size=0.4, random_state=42)
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

# Prepare tokenizer and sentence-embedding model
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(device)

# Paths of the cached per-split embeddings and labels (.npy files)
train_embeddings_file = 'train_embeddings.npy'
val_embeddings_file = 'val_embeddings.npy'
test_embeddings_file = 'test_embeddings.npy'
train_labels_file = 'train_labels.npy'
val_labels_file = 'val_labels.npy'
test_labels_file = 'test_labels.npy'

# Encode sentences and save embeddings if not already saved
def save_embeddings_and_labels(embeddings_file, labels_file, data, tokenizer, model, device, batch_size=32):
    """Encode ``data['text']`` and persist embeddings and labels with ``np.save``.

    Args:
        embeddings_file: destination path for the embedding matrix.
        labels_file: destination path for the label vector.
        data: DataFrame providing ``text`` and ``isTransition`` columns.
        tokenizer, model, device, batch_size: forwarded to
            ``encode_sentences_in_batches``.

    Returns:
        ``(embeddings, labels)`` -- the same arrays that were written to disk.
    """
    texts = data['text'].tolist()
    # Read the labels before encoding so a missing column fails fast.
    labels = data['isTransition'].values
    embeddings = encode_sentences_in_batches(texts, tokenizer, model, device, batch_size)
    for path, array in ((embeddings_file, embeddings), (labels_file, labels)):
        np.save(path, array)
    return embeddings, labels

def load_embeddings_and_labels(embeddings_file, labels_file):
    """Load an ``(embeddings, labels)`` pair previously written with ``np.save``."""
    return np.load(embeddings_file), np.load(labels_file)

# Use each split's cached .npy files when both exist; otherwise encode and cache.
# NOTE(review): the cache is keyed only by file existence, so it is not
# invalidated if the CSV, the split, or the encoder changes -- delete the .npy
# files after any such change.
def _load_or_encode(split_label, embeddings_file, labels_file, split_data):
    """Return ``(embeddings, labels)`` for one split, encoding it if not cached."""
    is_cached = os.path.exists(embeddings_file) and os.path.exists(labels_file)
    if is_cached:
        print(f"Loading {split_label} data...")
        return load_embeddings_and_labels(embeddings_file, labels_file)
    print(f"Encoding and saving {split_label} data...")
    return save_embeddings_and_labels(embeddings_file, labels_file, split_data, tokenizer, model, device)


train_embeddings, train_labels = _load_or_encode('training', train_embeddings_file, train_labels_file, train_data)
val_embeddings, val_labels = _load_or_encode('validation', val_embeddings_file, val_labels_file, val_data)
test_embeddings, test_labels = _load_or_encode('test', test_embeddings_file, test_labels_file, test_data)

# Wrap each split in a DataLoader using prepare_data_loader's defaults.
train_loader, val_loader, test_loader = (
    prepare_data_loader(split_embeddings, split_labels)
    for split_embeddings, split_labels in (
        (train_embeddings, train_labels),
        (val_embeddings, val_labels),
        (test_embeddings, test_labels),
    )
)

# Initialize classifier, optimizer, and loss function

# User-facing model names mapped to their classifier classes.
MODEL_CLASSES = {
    "simple": SimpleClassifier,
    "deep": DeepClassifier,
    "deeper": DeeperClassifier,
    "residual": ResidualClassifier,
}
DROPOUT_PROB = 0.7
# Stop after this many consecutive epochs without a new best validation accuracy.
EARLY_STOPPING_PATIENCE = 50

# Select model architecture
model_type = input("Enter model type (simple, deep, deeper, residual): ").strip().lower()
if model_type not in MODEL_CLASSES:
    print("Invalid model type. Defaulting to simple model.")
    model_type = "simple"
classifier = MODEL_CLASSES[model_type](
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=1,
    dropout_prob=DROPOUT_PROB,
).to(device)

optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Learning-rate scheduler. ``scheduler`` is passed to train_epoch below, so it
# must be defined even when no scheduler is used. Alternatives:
#   scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)  # step(val_loss) per epoch
#   scheduler = StepLR(optimizer, step_size=10, gamma=0.1)       # step() per epoch
scheduler = None

# Model save path.
# NOTE(review): the file name says 768 dims but the embeddings are EMBEDDING_DIM
# wide, and the same file is shared by every model type -- resuming with a
# different model_type than the one that wrote it will fail in load_state_dict.
model_path = 'classifier_state_768_dims.pth'

# Best validation accuracy seen so far; drives checkpointing and early stopping.
max_accuracy = 0.0

# Check if a saved state exists and load it
if os.path.exists(model_path):
    print("Loading checkpoint...")
    # map_location lets a checkpoint written on one device (e.g. MPS) load on another.
    checkpoint = torch.load(model_path, map_location=device)
    classifier.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1  # Continue from next epoch
    # Restore the best score so a worse first epoch after resuming does not
    # overwrite the saved best checkpoint (older checkpoints lack this key).
    max_accuracy = checkpoint.get('best_val_accuracy', 0.0)
else:
    start_epoch = 0  # Start from scratch

# Training Loop. ``epochs`` is the final epoch index to train up to, so a
# resumed run continues from start_epoch rather than adding ``epochs`` more.
epochs = int(input("Enter the number of epochs to train for: "))
stopper_count = 0

train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(start_epoch, epochs):
    # Train for one epoch
    train_loss = train_epoch(classifier, optimizer, criterion, train_loader, device, scheduler)
    train_losses.append(train_loss)

    # Evaluate train and validation accuracy and loss
    train_accuracy = evaluate_accuracy(classifier, train_loader, device)
    val_loss = evaluate_loss(classifier, val_loader, criterion, device)
    val_accuracy = evaluate_accuracy(classifier, val_loader, device)
    train_accuracies.append(train_accuracy)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    # If a ReduceLROnPlateau scheduler is enabled above, step it here:
    # scheduler.step(val_loss)

    # Test accuracy is reported for monitoring only. It must not drive model
    # selection or early stopping, or the final test score is optimistically biased.
    test_accuracy = evaluate_accuracy(classifier, test_loader, device)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy*100:.2f}%, Test Accuracy: {test_accuracy*100:.2f}%")

    # Save the model state when validation accuracy reaches a new best
    if val_accuracy > max_accuracy:
        stopper_count = 0  # Reset stopping criteria
        max_accuracy = val_accuracy
        torch.save({
            'epoch': epoch,
            'model_state_dict': classifier.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_accuracy': max_accuracy,
        }, model_path)
    else:
        stopper_count += 1  # Increment stopping criteria

    # Early stopping: end training after EARLY_STOPPING_PATIENCE epochs without improvement.
    if stopper_count >= EARLY_STOPPING_PATIENCE:
        print("Early stopping criteria met. Stopping training.")
        break

# Plot training/validation loss (left panel) and accuracy (right panel) per epoch.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 6))

loss_ax.plot(train_losses, label='Training Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

acc_ax.plot(train_accuracies, label='Training Accuracy')
acc_ax.plot(val_accuracies, label='Validation Accuracy')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Training and Validation Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()

# Evaluate on the test set using the classifier's in-memory weights as they are
# at the end of the training loop -- not reloaded from the saved best checkpoint.
test_accuracy = evaluate_accuracy(classifier, test_loader, device)
print(f"Test Accuracy: {test_accuracy*100:.2f}%")


Using device: mps




Loading training data...
Loading validation data...
Loading test data...
Epoch 1/200, Loss: 0.6242, Validation Loss: 0.5795, Validation Accuracy: 70.87%, Test Accuracy: 71.11%
Epoch 2/200, Loss: 0.5687, Validation Loss: 0.5513, Validation Accuracy: 71.27%, Test Accuracy: 72.56%
Epoch 3/200, Loss: 0.5559, Validation Loss: 0.5458, Validation Accuracy: 72.55%, Test Accuracy: 73.71%
Epoch 4/200, Loss: 0.5383, Validation Loss: 0.5499, Validation Accuracy: 71.76%, Test Accuracy: 73.93%
Epoch 5/200, Loss: 0.5302, Validation Loss: 0.5418, Validation Accuracy: 71.71%, Test Accuracy: 73.45%
Epoch 6/200, Loss: 0.5230, Validation Loss: 0.5256, Validation Accuracy: 73.30%, Test Accuracy: 75.87%
Epoch 7/200, Loss: 0.5058, Validation Loss: 0.5224, Validation Accuracy: 74.05%, Test Accuracy: 75.61%
Epoch 8/200, Loss: 0.5038, Validation Loss: 0.5242, Validation Accuracy: 74.01%, Test Accuracy: 75.30%
Epoch 9/200, Loss: 0.4971, Validation Loss: 0.5219, Validation Accuracy: 73.65%, Test Accuracy: 76.05%


KeyboardInterrupt: 