In [1]:
from IPython.core.display import display, HTML
display(HTML("""
    <script>
        var kernel = IPython.notebook.kernel;
        kernel.execute('from IPython.core.display import clear_output');
    </script>
"""))

import IPython
IPython.get_ipython().config.TerminalInteractiveShell.rate_limit_window = 0.0

  from IPython.core.display import display, HTML


In [3]:
import pandas as pd
df = pd.read_csv("pred_data_preprocessed.csv")
df = df.drop("position",axis=1)
df = df.drop("Unnamed: 0",axis=1)
df

Unnamed: 0,committee_name,session_id,chairperson,speaker_name,conversation,contain_offensive_words,dicta_answer
0,אל על,64670,אברהם הירשזון,אברהם הירשזון,. אני כן אסתום לך את הפה.. .,0,0.0
1,אל על,64670,אברהם הירשזון,יצחק כהן,". לא, זאת הצעה לסדר. כל הדיון הזה מיותר. אני מ...",0,0.0
2,אל על,64670,אברהם הירשזון,רוחמה אברהם,". אם חברה הולכת להפרטה, אסור לה לפרסם תשקיף..",0,0.0
3,אל על,64670,אברהם הירשזון,אברהם הירשזון,. אני קורא אותך לסדר..,1,1.0
4,אל על,64670,אברהם הירשזון,יצחק כהן,". לא, זאת הצעה לסדר. כל הדיון הזה מיותר. אני מ...",0,0.0
...,...,...,...,...,...,...,...
127835,ועדת החוקה חוק ומשפט,2225681,שמחה רוטמן,קריאות,\n- - -\n,0,0.0
127836,ועדת החוקה חוק ומשפט,2225681,שמחה רוטמן,גלעד קריב,\nמין שאינו במינו.\n,0,0.0
127837,ועדת החוקה חוק ומשפט,2225681,שמחה רוטמן,שמחה רוטמן,\nההערה הזאת צרמה לי מאוד באוזן על המישור המקצ...,0,0.0
127838,ועדת החוקה חוק ומשפט,2225681,שמחה רוטמן,אלעזר שטרן,"\nאדוני, אני אשמח להגיב לדבר הזה. קודם כול, אנ...",0,0.0


## Model with Window

In [7]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import os
import json
import logging
from datetime import datetime
import sys
import torch.amp
import gc
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
from sklearn.metrics import roc_auc_score

class WindowDictaDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.windows = []
        self.labels = []
        self.session_ids = []
        
        # Group by session
        grouped = df.groupby('session_id')
        for session_id, session_df in grouped:
            session_convs = session_df['conversation'].tolist()
            session_labels = session_df['dicta_answer'].tolist()
            
            # Create windows of size 3 and predict the 4th (without seeing it)
            for i in range(0, len(session_convs) - 3):
                # Only use previous 3 conversations with their labels
                window = [
                    f"{session_convs[i]} [LABEL] {session_labels[i]}",
                    f"{session_convs[i+1]} [LABEL] {session_labels[i+1]}",
                    f"{session_convs[i+2]} [LABEL] {session_labels[i+2]}"
                ]
                # The label to predict (4th conversation's label)
                label = session_labels[i+3]
                self.windows.append(window)
                self.labels.append(label)
                self.session_ids.append(session_id)

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        window = self.windows[idx]
        label = self.labels[idx]
        
        # Concatenate only the three previous conversations
        concatenated_text = " [SEP] ".join(window)
        
        encoding = self.tokenizer.encode_plus(
            concatenated_text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(int(label), dtype=torch.long),
            'session_id': self.session_ids[idx]
        }
        
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        
    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

def setup_logging(base_dir='results_windowed'):
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    results_dir = os.path.join(base_dir, f'run_{timestamp}')
    os.makedirs(results_dir, exist_ok=True)
    
    log_file = os.path.join(results_dir, 'training.log')
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )
    
    plots_dir = os.path.join(results_dir, 'plots')
    models_dir = os.path.join(results_dir, 'models')
    metrics_dir = os.path.join(results_dir, 'metrics')
    
    os.makedirs(plots_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)
    
    return results_dir, plots_dir, models_dir, metrics_dir

def prepare_windowed_data(df):
    """Prepare the data by ensuring proper session handling and creating train/val/test splits"""
    # Get unique session IDs
    session_ids = df['session_id'].unique()
    
    # Split sessions into train/val/test
    train_sessions, temp_sessions = train_test_split(
        session_ids, test_size=0.2, random_state=42
    )
    val_sessions, test_sessions = train_test_split(
        temp_sessions, test_size=0.5, random_state=42
    )
    
    # Create dataframes for each split
    train_df = df[df['session_id'].isin(train_sessions)]
    val_df = df[df['session_id'].isin(val_sessions)]
    test_df = df[df['session_id'].isin(test_sessions)]
    
    return train_df, val_df, test_df

def evaluate_model(model, dataloader, device, plots_dir=None, phase='test'):
    model.eval()
    predictions = []
    probabilities = []
    actual_labels = []
    session_predictions = {}  # Track predictions by session
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f'Evaluating {phase} set'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels']
            session_ids = batch['session_id']

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            
            # Get probabilities and predictions
            probs = torch.softmax(outputs.logits, dim=1)
            _, preds = torch.max(probs, dim=1)
            
            predictions.extend(preds.cpu().tolist())
            probabilities.extend(probs[:, 1].cpu().tolist())  # Probability of positive class
            actual_labels.extend(labels.cpu().tolist())
            
            # Track predictions by session
            for sid, pred, actual, prob in zip(session_ids, preds.cpu().tolist(), labels.cpu().tolist(), probs[:, 1].cpu().tolist()):
                if sid not in session_predictions:
                    session_predictions[sid] = {'pred': [], 'actual': [], 'prob': []}
                session_predictions[sid]['pred'].append(pred)
                session_predictions[sid]['actual'].append(actual)
                session_predictions[sid]['prob'].append(prob)
    
    # Calculate overall metrics
    metrics = calculate_metrics(actual_labels, predictions, probabilities, session_predictions)
    
    # Plot confusion matrix
    if plots_dir:
        plot_confusion_matrix(metrics['confusion_matrix'], 
                            os.path.join(plots_dir, f'confusion_matrix_{phase}.png'))
        
        # Plot ROC curve
        plot_roc_curve(actual_labels, probabilities, 
                     os.path.join(plots_dir, f'roc_curve_{phase}.png'))
        
        # Plot session performance
        plot_session_performance(session_predictions, 
                               os.path.join(plots_dir, f'session_performance_{phase}.png'))
    
    return metrics

def calculate_metrics(actual_labels, predictions, probabilities, session_predictions):
    """Calculate comprehensive metrics including per-session analysis and ROC AUC"""
    # Overall metrics
    accuracy = float(accuracy_score(actual_labels, predictions))
    precision = float(precision_score(actual_labels, predictions, average='weighted'))
    recall = float(recall_score(actual_labels, predictions, average='weighted'))
    f1 = float(f1_score(actual_labels, predictions, average='weighted'))
    cm = confusion_matrix(actual_labels, predictions).tolist()  # Convert to list
    
    # ROC AUC
    try:
        roc_auc = float(roc_auc_score(actual_labels, probabilities))
    except ValueError:
        # This can happen if there's only one class in y_true
        roc_auc = None
    
    # Per-session metrics
    session_accuracies = []
    for session in session_predictions.values():
        if session['actual']:  # Check if session has predictions
            session_acc = float(accuracy_score(session['actual'], session['pred']))
            session_accuracies.append(session_acc)
    
    # Convert numpy types to Python native types
    session_metrics = {
        'mean_accuracy': float(np.mean(session_accuracies)),
        'std_accuracy': float(np.std(session_accuracies)),
        'min_accuracy': float(np.min(session_accuracies)),
        'max_accuracy': float(np.max(session_accuracies))
    }
    
    # Compile metrics
    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'roc_auc': roc_auc,
        'confusion_matrix': cm,
        'classification_report': classification_report(actual_labels, predictions),
        'session_metrics': session_metrics
    }
    
    return metrics

def plot_roc_curve(actual_labels, probabilities, save_path):
    """Plot ROC curve"""
    try:
        fpr, tpr, thresholds = roc_curve(actual_labels, probabilities)
        
        plt.figure(figsize=(10, 8))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(save_path)
        plt.close()
    except Exception as e:
        logging.error(f"Could not plot ROC curve: {e}")
        
def plot_confusion_matrix(cm, save_path):
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(save_path)
    plt.close()

def plot_session_performance(session_predictions, save_path):
    """Plot distribution of per-session accuracies"""
    accuracies = []
    for session in session_predictions.values():
        if session['actual']:  # Check if session has predictions
            acc = accuracy_score(session['actual'], session['pred'])
            accuracies.append(acc)
    
    plt.figure(figsize=(10, 6))
    plt.hist(accuracies, bins=20, edgecolor='black')
    plt.title('Distribution of Session Accuracies')
    plt.xlabel('Accuracy')
    plt.ylabel('Number of Sessions')
    plt.savefig(save_path)
    plt.close()

def train_model(train_dataloader, val_dataloader, model, optimizer, device, 
                num_epochs=3, results_dir=None, gradient_accumulation_steps=4):
    model.train()
    training_stats = []
    early_stopping = EarlyStopping(patience=3)
    scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else torch.amp.GradScaler('cpu')
    
    # Create scheduler with warmup
    num_training_steps = len(train_dataloader) * num_epochs
    num_warmup_steps = num_training_steps // 10
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=num_warmup_steps, 
        num_training_steps=num_training_steps
    )
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs} [Train]')
        
        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Only use autocast when CUDA is available
            if torch.cuda.is_available():
                with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu'):
                    outputs = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels
                    )
                    loss = outputs.loss / gradient_accumulation_steps
            else:
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss / gradient_accumulation_steps

            if torch.cuda.is_available():
                scaler.scale(loss).backward()
            else:
                loss.backward()
            
            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                if torch.cuda.is_available():
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()
                    
                scheduler.step()
                optimizer.zero_grad()
            
            total_loss += loss.item() * gradient_accumulation_steps
            progress_bar.set_postfix({'train_loss': f'{loss.item() * gradient_accumulation_steps:.4f}'})

        # Validation phase with tqdm
        model.eval()
        val_loss = 0
        val_progress_bar = tqdm(val_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs} [Validation]')
        
        with torch.no_grad():
            for batch in val_progress_bar:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                batch_val_loss = outputs.loss.item()
                val_loss += batch_val_loss
                
                # Update validation progress bar with current loss
                val_progress_bar.set_postfix({'val_loss': f'{batch_val_loss:.4f}'})
        
        val_loss = val_loss / len(val_dataloader)
        
        # Save stats
        epoch_stats = {
            'epoch': epoch + 1,
            'train_loss': total_loss / len(train_dataloader),
            'val_loss': val_loss
        }
        training_stats.append(epoch_stats)
        
        # Save checkpoint if it's the best model
        if epoch == 0 or val_loss < min(s['val_loss'] for s in training_stats[:-1]):
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_loss': val_loss,
            }, os.path.join(results_dir, 'models', 'best_model.pt'))
        
        # Early stopping check
        early_stopping(val_loss)
        if early_stopping.early_stop:
            logging.info(f'Early stopping triggered after epoch {epoch + 1}')
            break
        
        logging.info(f'Epoch {epoch + 1}: Train Loss = {epoch_stats["train_loss"]:.4f}, '
                    f'Val Loss = {epoch_stats["val_loss"]:.4f}')
    
    return training_stats

def predict_next_conversation(model, tokenizer, previous_three_conversations, device):
    """Predict the label for the next conversation based on the previous three conversations"""
    model.eval()
    
    concatenated_text = " [SEP] ".join(previous_three_conversations)
    
    encoding = tokenizer.encode_plus(
        concatenated_text,
        add_special_tokens=True,
        max_length=256,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs.logits, dim=1)
    
    return preds.item()

def main():
    """
    Main training pipeline for the windowed conversation classification model
    """
    # Setup logging and directories
    results_dir, plots_dir, models_dir, metrics_dir = setup_logging()
    logging.info("Starting training process...")
    
    # Load your data (IMPORTANT: Uncomment and modify to load your actual dataset)
    # df = pd.read_csv('your_conversation_data.csv')
    # Ensure your DataFrame has these columns:
    # - 'session_id': Unique identifier for conversation sessions
    # - 'conversation': The text of each conversation turn
    # - 'dicta_answer': Binary label (0 or 1) for classification
    
   
    logging.info(f"Loaded dataset with {len(df)} samples")
    
    # Prepare windowed data
    train_df, val_df, test_df = prepare_windowed_data(df)
    
    # Initialize tokenizer and model
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-multilingual-cased',
        num_labels=2  # Binary classification
    )
    
    # Create datasets
    train_dataset = WindowDictaDataset(train_df, tokenizer)
    val_dataset = WindowDictaDataset(val_df, tokenizer)
    test_dataset = WindowDictaDataset(test_df, tokenizer)
    
    logging.info(f"Created windowed datasets:")
    logging.info(f"Training windows: {len(train_dataset)}")
    logging.info(f"Validation windows: {len(val_dataset)}")
    logging.info(f"Test windows: {len(test_dataset)}")
    
    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=32,  # Reduced batch size due to longer sequences
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=64,
        num_workers=4,
        pin_memory=True
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=64,
        num_workers=4,
        pin_memory=True
    )
    
    # Setup device and model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    # Enable gradient checkpointing for memory efficiency
    model.gradient_checkpointing_enable()
    
    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=2e-5,
        weight_decay=0.01
    )
    
    # Train the model
    logging.info("Starting training...")
    training_stats = train_model(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        model=model,
        optimizer=optimizer,
        device=device,
        num_epochs=5,
        results_dir=results_dir,
        gradient_accumulation_steps=4
    )
    
    # Plot training curves
    plt.figure(figsize=(10, 6))
    epochs = [stat['epoch'] for stat in training_stats]
    train_losses = [stat['train_loss'] for stat in training_stats]
    val_losses = [stat['val_loss'] for stat in training_stats]
    
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join(plots_dir, 'training_curves.png'))
    plt.close()
    
    # Load best model for evaluation
    best_model_path = os.path.join(results_dir, 'models', 'best_model.pt')
    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        logging.info(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")
    
    # Evaluate on validation set
    logging.info("\nEvaluating on validation set...")
    val_metrics = evaluate_model(model, val_dataloader, device, plots_dir, 'validation')
    
    # Log validation metrics
    logging.info("\nValidation Metrics:")
    logging.info(f"Accuracy: {val_metrics['accuracy']:.4f}")
    logging.info(f"Precision: {val_metrics['precision']:.4f}")
    logging.info(f"Recall: {val_metrics['recall']:.4f}")
    logging.info(f"F1 Score: {val_metrics['f1']:.4f}")
    logging.info(f"ROC AUC: {val_metrics['roc_auc']:.4f}")
    logging.info("\nPer-session metrics:")
    logging.info(f"Mean accuracy: {val_metrics['session_metrics']['mean_accuracy']:.4f}")
    logging.info(f"Std accuracy: {val_metrics['session_metrics']['std_accuracy']:.4f}")
    
    # Save validation metrics
    with open(os.path.join(metrics_dir, 'validation_metrics.json'), 'w') as f:
        json.dump(val_metrics, f, indent=4)
    
    # Evaluate on test set
    logging.info("\nEvaluating on test set...")
    test_metrics = evaluate_model(model, test_dataloader, device, plots_dir, 'test')
    
    # Log test metrics
    logging.info("\nTest Metrics:")
    logging.info(f"Accuracy: {test_metrics['accuracy']:.4f}")
    logging.info(f"Precision: {test_metrics['precision']:.4f}")
    logging.info(f"Recall: {test_metrics['recall']:.4f}")
    logging.info(f"F1 Score: {test_metrics['f1']:.4f}")
    logging.info(f"ROC AUC: {test_metrics['roc_auc']:.4f}")
    logging.info("\nPer-session metrics:")
    logging.info(f"Mean accuracy: {test_metrics['session_metrics']['mean_accuracy']:.4f}")
    logging.info(f"Std accuracy: {test_metrics['session_metrics']['std_accuracy']:.4f}")
    
    # Save test metrics
    with open(os.path.join(metrics_dir, 'test_metrics.json'), 'w') as f:
        json.dump(test_metrics, f, indent=4)
    
    # Save model and tokenizer
    model_save_path = os.path.join(models_dir, 'final_model')
    tokenizer_save_path = os.path.join(models_dir, 'tokenizer')
    
    model.save_pretrained(model_save_path)
    tokenizer.save_pretrained(tokenizer_save_path)
    
    # Save configuration
    config = {
        'model_name': 'bert-base-multilingual-cased',
        'max_length': 256,
        'num_labels': 2,
        'window_size': 3,
        'batch_size': 32,
        'gradient_accumulation_steps': 4,
        'learning_rate': 2e-5,
        'weight_decay': 0.01,
        'num_epochs': 5,
        'early_stopping_patience': 3,
        'train_windows': len(train_dataset),
        'val_windows': len(val_dataset),
        'test_windows': len(test_dataset)
    }
    
    with open(os.path.join(models_dir, 'model_config.json'), 'w') as f:
        json.dump(config, f, indent=4)
    
    logging.info("\nTraining and evaluation completed!")
    logging.info(f"Model saved to: {model_save_path}")
    logging.info(f"Tokenizer saved to: {tokenizer_save_path}")
    
    return model, tokenizer, {
        'train': training_stats,
        'validation': val_metrics,
        'test': test_metrics
    }

if __name__ == "__main__":
    main()

2025-02-05 14:54:53,822 - INFO - Starting training process...
2025-02-05 14:54:53,822 - INFO - Loaded dataset with 127840 samples


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


2025-02-05 14:54:55,488 - INFO - Created windowed datasets:
2025-02-05 14:54:55,489 - INFO - Training windows: 96016
2025-02-05 14:54:55,489 - INFO - Validation windows: 11473
2025-02-05 14:54:55,490 - INFO - Test windows: 15758
2025-02-05 14:54:55,599 - INFO - Starting training...


Epoch 1/5 [Train]: 100%|██████████| 3001/3001 [03:28<00:00, 14.40it/s, train_loss=0.3118]
Epoch 1/5 [Validation]: 100%|██████████| 180/180 [00:29<00:00,  6.04it/s, val_loss=0.0338]


2025-02-05 14:59:00,025 - INFO - Epoch 1: Train Loss = 0.3765, Val Loss = 0.2780


Epoch 2/5 [Train]: 100%|██████████| 3001/3001 [03:30<00:00, 14.23it/s, train_loss=0.5271]
Epoch 2/5 [Validation]: 100%|██████████| 180/180 [00:29<00:00,  6.04it/s, val_loss=0.0366]


2025-02-05 15:03:09,399 - INFO - Epoch 2: Train Loss = 0.3151, Val Loss = 0.2648


Epoch 3/5 [Train]: 100%|██████████| 3001/3001 [03:30<00:00, 14.25it/s, train_loss=0.5240]
Epoch 3/5 [Validation]: 100%|██████████| 180/180 [00:29<00:00,  6.06it/s, val_loss=0.0159]

2025-02-05 15:07:09,688 - INFO - Epoch 3: Train Loss = 0.3016, Val Loss = 0.2655



Epoch 4/5 [Train]: 100%|██████████| 3001/3001 [03:31<00:00, 14.21it/s, train_loss=0.1497]
Epoch 4/5 [Validation]: 100%|██████████| 180/180 [00:29<00:00,  6.05it/s, val_loss=0.0144]

2025-02-05 15:11:10,617 - INFO - Epoch 4: Train Loss = 0.2878, Val Loss = 0.2673



Epoch 5/5 [Train]: 100%|██████████| 3001/3001 [03:31<00:00, 14.22it/s, train_loss=0.0621]
Epoch 5/5 [Validation]: 100%|██████████| 180/180 [00:29<00:00,  6.05it/s, val_loss=0.0097]

2025-02-05 15:15:11,414 - INFO - Early stopping triggered after epoch 5



  checkpoint = torch.load(best_model_path)


2025-02-05 15:15:39,116 - INFO - Loaded best model from epoch 2
2025-02-05 15:15:39,116 - INFO - 
Evaluating on validation set...


Evaluating validation set: 100%|██████████| 180/180 [00:29<00:00,  6.16it/s]


2025-02-05 15:16:14,082 - INFO - 
Validation Metrics:
2025-02-05 15:16:14,083 - INFO - Accuracy: 0.8965
2025-02-05 15:16:14,083 - INFO - Precision: 0.8875
2025-02-05 15:16:14,083 - INFO - Recall: 0.8965
2025-02-05 15:16:14,084 - INFO - F1 Score: 0.8805
2025-02-05 15:16:14,084 - INFO - ROC AUC: 0.8831
2025-02-05 15:16:14,084 - INFO - 
Per-session metrics:
2025-02-05 15:16:14,084 - INFO - Mean accuracy: 0.8965
2025-02-05 15:16:14,085 - INFO - Std accuracy: 0.3047
2025-02-05 15:16:14,098 - INFO - 
Evaluating on test set...


Evaluating test set: 100%|██████████| 247/247 [00:40<00:00,  6.15it/s]


2025-02-05 15:17:01,958 - INFO - 
Test Metrics:
2025-02-05 15:17:01,959 - INFO - Accuracy: 0.8471
2025-02-05 15:17:01,960 - INFO - Precision: 0.8370
2025-02-05 15:17:01,960 - INFO - Recall: 0.8471
2025-02-05 15:17:01,960 - INFO - F1 Score: 0.8282
2025-02-05 15:17:01,961 - INFO - ROC AUC: 0.8704
2025-02-05 15:17:01,961 - INFO - 
Per-session metrics:
2025-02-05 15:17:01,961 - INFO - Mean accuracy: 0.8471
2025-02-05 15:17:01,961 - INFO - Std accuracy: 0.3599
2025-02-05 15:17:04,059 - INFO - 
Training and evaluation completed!
2025-02-05 15:17:04,059 - INFO - Model saved to: results_windowed/run_20250205_145453/models/final_model
2025-02-05 15:17:04,060 - INFO - Tokenizer saved to: results_windowed/run_20250205_145453/models/tokenizer
