In [1]:
import os
import re
import json
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter, defaultdict
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_score, recall_score
from sklearn.feature_extraction.text import TfidfVectorizer
import spacy
import warnings
warnings.filterwarnings('ignore')

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define the output directory for Kaggle
KAGGLE_OUTPUT_DIR = "/kaggle/working/"
os.makedirs(KAGGLE_OUTPUT_DIR, exist_ok=True)

# Set random seeds for reproducibility
def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

class ImprovedStyleFeatureExtractor:
    def __init__(self, spacy_model="en_core_web_sm"):
        self.nlp = spacy.load(spacy_model, disable=["ner"])
        
        # Cache for processed documents to avoid redundant computation
        self.doc_cache = {}
        # Fixed feature dimension to ensure consistency
        self._feature_dim = None
        self._feature_keys = None

    def extract_features(self, text):
        # Check cache first
        if text in self.doc_cache:
            doc = self.doc_cache[text]
        else:
            doc = self.nlp(text)
            self.doc_cache[text] = doc
            
        # Basic text statistics
        features = {}

        # Lexical diversity
        tokens = [token.text.lower() for token in doc if not token.is_punct]
        types = set(tokens)
        features["type_token_ratio"] = len(types) / max(len(tokens), 1)

        # Sentence length stats
        sent_lengths = [len(sent) for sent in doc.sents]
        features["avg_sent_length"] = np.mean(sent_lengths) if sent_lengths else 0
        features["sent_length_var"] = np.var(sent_lengths) if len(sent_lengths) > 1 else 0
        
        # Sentence length variability (human writing tends to have more variance)
        if len(sent_lengths) > 1:
            features["sent_length_std"] = np.std(sent_lengths)
            features["sent_length_cv"] = features["sent_length_std"] / max(features["avg_sent_length"], 1)  # Coefficient of variation
        else:
            features["sent_length_std"] = 0
            features["sent_length_cv"] = 0

        # POS distributions (simplified - only count main POS tags)
        main_pos_tags = ["NOUN", "VERB", "ADJ", "ADV", "PRON", "DET", "ADP", "CONJ", "PART"]
        pos_counts = Counter([token.pos_ for token in doc])
        for pos in main_pos_tags:
            features[f"pos_{pos}"] = pos_counts.get(pos, 0) / max(len(doc), 1)

        # Function word ratio
        function_words = [token for token in doc if token.is_stop]
        features["function_word_ratio"] = len(function_words) / max(len(doc), 1)

        # Punctuation ratio
        punct_count = len([token for token in doc if token.is_punct])
        features["punct_ratio"] = punct_count / max(len(doc), 1)
        
        # Basic vocabulary richness metric
        if len(tokens) > 10:
            # Simplified Yule's K (vocabulary richness)
            token_freq = Counter(tokens)
            m1 = sum(token_freq.values())
            m2 = sum([freq ** 2 for freq in token_freq.values()])
            features["yules_k"] = (m2 - m1) / (m1 ** 2) if m1 > 0 else 0
        else:
            features["yules_k"] = 0

        # Additional text statistics
        features["avg_word_length"] = np.mean([len(token.text) for token in doc if not token.is_punct]) if tokens else 0
        features["unique_word_ratio"] = len(types) / max(len(tokens), 1)
        
        # NEW HUMAN-SPECIFIC FEATURES
        
        # Contraction ratio (humans use more contractions)
        contraction_pattern = re.compile(r"\b\w+['']\w+\b")
        contraction_count = len(contraction_pattern.findall(text))
        features["contraction_ratio"] = contraction_count / max(len(tokens), 1)
        
        # Question and exclamation marks (humans tend to use more)
        features["question_mark_ratio"] = text.count('?') / max(len(text), 1) * 100  # Multiply by 100 to make it more significant
        features["exclamation_ratio"] = text.count('!') / max(len(text), 1) * 100
        
        # Entropy calculation (human text tends to have higher entropy)
        if tokens:
            token_freq = Counter(tokens)
            total = sum(token_freq.values())
            # Add small epsilon to avoid log(0)
            entropy = -sum([(count/total) * np.log2(count/total + 1e-10) for count in token_freq.values()])
            features["token_entropy"] = entropy
        else:
            features["token_entropy"] = 0
            
        # Sentence complexity indicators
        if sent_lengths:
            # Average commas per sentence (complex sentences have more commas)
            comma_count = text.count(',')
            features["comma_per_sentence"] = comma_count / max(len(sent_lengths), 1)
            
            # Long sentence ratio (human writing often has some very long sentences)
            long_sents = len([s for s in sent_lengths if s > 20])
            features["long_sentence_ratio"] = long_sents / max(len(sent_lengths), 1)
        else:
            features["comma_per_sentence"] = 0
            features["long_sentence_ratio"] = 0
        
        # Convert to numpy array with fixed order
        feature_keys = sorted(features.keys())
        feature_array = np.array([features[k] for k in feature_keys])
        
        # Store feature keys for consistency
        if self._feature_keys is None:
            self._feature_keys = feature_keys
            self._feature_dim = len(feature_keys)
        else:
            # Ensure same features are returned every time
            # If new feature keys are present, add zeros for missing ones
            if set(feature_keys) != set(self._feature_keys):
                fixed_features = {}
                for k in self._feature_keys:
                    if k in features:
                        fixed_features[k] = features[k]
                    else:
                        fixed_features[k] = 0.0
                
                feature_array = np.array([fixed_features[k] for k in self._feature_keys])
        
        return feature_array, feature_keys

    def get_feature_dim(self):
        # Get the dimension of feature vectors
        if self._feature_dim is not None:
            return self._feature_dim
            
        sample_text = "This is a sample text to determine feature dimension."
        features, feature_keys = self.extract_features(sample_text)
        self._feature_dim = len(features)
        self._feature_keys = feature_keys
        return self._feature_dim


class ImprovedEmbeddingGenerator:
    def __init__(self, max_features=5000, output_dim=32):
        # Use TF-IDF with improved n-gram range to capture more patterns
        self.vectorizer = TfidfVectorizer(
            max_features=max_features,
            ngram_range=(1, 3),  # Use unigrams, bigrams and trigrams
            min_df=2,
            max_df=0.95,
            use_idf=True,
            sublinear_tf=True  # Apply sublinear tf scaling (1 + log(tf))
        )
        self.output_dim = output_dim
        
        # Enhanced dimensionality reduction layer
        self.projector = nn.Sequential(
            nn.Linear(max_features, output_dim * 2),
            nn.ReLU(),
            nn.Linear(output_dim * 2, output_dim),
            nn.ReLU()
        ).to(device)
        
        # For sentence splitting
        self.nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser'])
        self.nlp.add_pipe('sentencizer')
        
        # Cache
        self.embedding_cache = {}
        self.fitted = False
        
    def fit(self, texts):
        """Fit the TF-IDF vectorizer on a corpus of texts"""
        print("Fitting TF-IDF vectorizer...")
        self.vectorizer.fit(texts)
        self.fitted = True
        return self
        
    def get_embedding(self, text):
        """Get embedding for a single text"""
        # Check cache first
        if text in self.embedding_cache:
            return self.embedding_cache[text]
            
        if not self.fitted:
            raise ValueError("TF-IDF vectorizer not fitted. Call fit() first.")
            
        # Get TF-IDF vector
        X = self.vectorizer.transform([text]).toarray()
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        
        # Project to lower dimension
        with torch.no_grad():
            embedding = self.projector(X_tensor)
            normalized = F.normalize(embedding, p=2, dim=1)
            
        result = normalized.cpu().numpy()[0]
        self.embedding_cache[text] = result
        
        return result
        
    def get_sentence_embeddings(self, text, max_sentences=16):
        """Get embeddings for sentences in text"""
        try:
            # Use simpler sentence splitting
            doc = self.nlp(text)
            sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]
            
            # Limit number of sentences
            sentences = sentences[:max_sentences]
            
            # Get embeddings for sentences
            embeddings = []
            for sent in sentences:
                if sent:  # Skip empty sentences
                    embedding = self.get_embedding(sent)
                    embeddings.append(embedding)
                    
            # Handle empty text
            if not embeddings:
                return np.zeros((max_sentences, self.output_dim), dtype=np.float32)
                
            # Convert to numpy array
            embeddings = np.array(embeddings)
            
            # Pad if necessary
            if len(embeddings) < max_sentences:
                padding = np.zeros((max_sentences - len(embeddings), self.output_dim), dtype=np.float32)
                embeddings = np.vstack([embeddings, padding])
            
            return embeddings
            
        except Exception as e:
            print(f"Error in get_sentence_embeddings: {e}")
            # Return zero embeddings as fallback
            return np.zeros((max_sentences, self.output_dim), dtype=np.float32)
    
    def save(self, path):
        """Save model weights and config"""
        # Save vectorizer
        vectorizer_path = path + ".vectorizer"
        with open(vectorizer_path, 'wb') as f:
            import pickle
            pickle.dump(self.vectorizer, f)
        
        # Save projector
        projector_path = path + ".projector"
        torch.save(self.projector.state_dict(), projector_path)
        
        # Save config
        config_path = path + ".config"
        config = {
            'max_features': self.vectorizer.max_features,
            'output_dim': self.output_dim,
            'fitted': self.fitted
        }
        with open(config_path, 'w') as f:
            json.dump(config, f)
            
        return vectorizer_path, projector_path, config_path
        
    @classmethod
    def load(cls, path):
        """Load a saved model"""
        # Load config
        config_path = path + ".config"
        with open(config_path, 'r') as f:
            config = json.load(f)
            
        # Create instance
        instance = cls(
            max_features=config['max_features'],
            output_dim=config['output_dim']
        )
        
        # Load vectorizer
        vectorizer_path = path + ".vectorizer"
        with open(vectorizer_path, 'rb') as f:
            import pickle
            instance.vectorizer = pickle.load(f)
            
        # Load projector
        projector_path = path + ".projector"
        instance.projector.load_state_dict(torch.load(projector_path, map_location=device))
        
        instance.fitted = config['fitted']
        return instance


class ImprovedDetector(nn.Module):
    def __init__(self, embed_dim=32, style_feature_dim=20, hidden_dim=64):
        super().__init__()
        self.embed_dim = embed_dim
        self.style_feature_dim = style_feature_dim
        self.hidden_dim = hidden_dim

        # Bidirectional LSTM with attention for sentence embeddings
        self.sentence_lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Softmax(dim=1)
        )

        # Processing stylometric features
        self.style_encoder = nn.Sequential(
            nn.Linear(style_feature_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, hidden_dim // 2),
            nn.ReLU(),
        )

        # Full text embedding processor
        self.text_encoder = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Final classifier (with separate human detection path)
        self.human_detector = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, 1),
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, sentence_embeds, full_text_embed, style_features):
        batch_size = sentence_embeds.size(0)
        
        # Process sentence embeddings with LSTM
        lstm_out, (h_n, _) = self.sentence_lstm(sentence_embeds)
        
        # Apply attention to get weighted representation
        attention_weights = self.attention(lstm_out)
        sentence_repr_attn = torch.sum(lstm_out * attention_weights, dim=1)
        
        # Combine with final hidden state
        h_combined = torch.cat([h_n[0], h_n[1]], dim=1)
        sentence_repr = h_combined
        
        # Human detection path
        human_features = self.human_detector(sentence_repr_attn)

        # Process full-text embedding
        text_repr = self.text_encoder(full_text_embed)

        # Process style features
        style_repr = self.style_encoder(style_features)

        # Concatenate all features
        combined = torch.cat([sentence_repr, text_repr, style_repr], dim=1)
        
        # Final classification
        output = self.classifier(combined)
        
        # Return additional human_features for auxiliary loss if needed
        return output, human_features
        
    def save(self, path):
        """Save model weights and config"""
        save_dict = {
            'state_dict': self.state_dict(),
            'embed_dim': self.embed_dim,
            'style_feature_dim': self.style_feature_dim,
            'hidden_dim': self.hidden_dim
        }
        torch.save(save_dict, path)
        
    @classmethod
    def load(cls, path):
        """Load a saved model"""
        save_dict = torch.load(path, map_location=device)
        instance = cls(
            embed_dim=save_dict['embed_dim'],
            style_feature_dim=save_dict['style_feature_dim'],
            hidden_dim=save_dict['hidden_dim']
        )
        instance.load_state_dict(save_dict['state_dict'])
        return instance


# Find optimal threshold using validation data
def find_optimal_threshold(val_loader, model):
    """Find the optimal classification threshold using validation data"""
    model.eval()
    all_probs = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Finding optimal threshold"):
            try:
                output, _ = model(
                    batch["sentence_embeds"].to(device),
                    batch["full_text_embeds"].to(device),
                    batch["style_features"].to(device)
                )
                probs = torch.sigmoid(output).squeeze().cpu().numpy()
                
                # Handle scalar vs vector case
                if isinstance(probs, np.ndarray):
                    all_probs.extend(probs)
                else:
                    all_probs.append(probs)
                    
                all_labels.extend(batch["label"].cpu().numpy())
            except Exception as e:
                print(f"Error in validation batch: {e}")
                continue
    
    # Try different thresholds
    best_f1, best_threshold = 0, 0.5
    best_balanced_acc = 0
    
    for threshold in np.arange(0.3, 0.7, 0.01):
        preds = [1 if p > threshold else 0 for p in all_probs]
        
        # Get per-class metrics
        human_idx = [i for i, l in enumerate(all_labels) if l == 0]
        ai_idx = [i for i, l in enumerate(all_labels) if l == 1]
        
        if human_idx and ai_idx:
            human_preds = [preds[i] for i in human_idx]
            human_labels = [all_labels[i] for i in human_idx]
            human_acc = accuracy_score(human_labels, human_preds)
            
            ai_preds = [preds[i] for i in ai_idx]
            ai_labels = [all_labels[i] for i in ai_idx]
            ai_acc = accuracy_score(ai_labels, ai_preds)
            
            # Calculate balanced accuracy
            balanced_acc = (human_acc + ai_acc) / 2
            
            # Use F1 score as primary metric
            f1 = f1_score(all_labels, preds)
            
            # Favor thresholds that give more balanced performance
            if balanced_acc > best_balanced_acc:
                best_balanced_acc = balanced_acc
                best_threshold = threshold
                best_f1 = f1
    
    print(f"Optimal threshold: {best_threshold:.4f}, F1: {best_f1:.4f}, Balanced Acc: {best_balanced_acc:.4f}")
    return best_threshold


# Custom collate function for batches
def custom_collate_fn(batch):
    if len(batch) == 0:
        return {}
    
    elem = batch[0]
    result = {}
    
    for key in elem.keys():
        if key == 'text':  # Handle text field (non-tensor)
            result[key] = [b[key] for b in batch]
            continue
            
        values = [b[key] for b in batch]
        
        if isinstance(values[0], torch.Tensor):
            try:
                result[key] = torch.stack(values, 0)
            except RuntimeError as e:
                print(f"ERROR stacking tensors for key '{key}': {e}")
                first_tensor = values[0]
                result[key] = first_tensor.repeat(len(batch), *[1 for _ in range(first_tensor.dim())])
        else:
            result[key] = values
            
    return result


# Load model from checkpoint
def load_from_checkpoint(model, optimizer=None, scheduler=None, checkpoint_path=None):
    """Load model and training state from a checkpoint"""
    if checkpoint_path is None:
        # Find the latest checkpoint
        checkpoint_dir = os.path.join(KAGGLE_OUTPUT_DIR, "checkpoints")
        if not os.path.exists(checkpoint_dir):
            print("No checkpoints found.")
            return model, 0, 0.5
            
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('epoch_') and f.endswith('.pt')]
        if not checkpoints:
            print("No checkpoints found.")
            return model, 0, 0.5
            
        # Sort by epoch number
        checkpoints.sort(key=lambda x: int(x.split('_')[1]))
        checkpoint_path = os.path.join(checkpoint_dir, checkpoints[-1])
    
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Load model weights
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Optionally load optimizer and scheduler states
    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Return epoch and threshold for resuming
    start_epoch = checkpoint.get('epoch', 0)
    threshold = checkpoint.get('threshold', 0.5)
    
    print(f"Resuming from epoch {start_epoch} with threshold {threshold:.4f}")
    return model, start_epoch, threshold


def train_improved(model, train_loader, val_loader, epochs=12, lr=0.001, start_epoch=0):
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    
    # Learning rate scheduler with early stopping
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=1, threshold=0.005, min_lr=1e-6, verbose=True
    )
    
    # Initialize with normal BCE loss
    main_criterion = nn.BCEWithLogitsLoss()
    aux_criterion = nn.BCEWithLogitsLoss()

    # Accumulate gradients for larger effective batch size
    accum_steps = 2  # Reduced from original to speed up training

    best_val_auc = 0
    best_model = None
    lr_history = []
    optimal_threshold = 0.5
    patience = 3  # Early stopping patience
    patience_counter = 0
    
    # Create checkpoint directory
    checkpoint_dir = os.path.join(KAGGLE_OUTPUT_DIR, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Create a log file for tracking metrics
    log_file = os.path.join(KAGGLE_OUTPUT_DIR, "training_log.csv")
    if start_epoch == 0 or not os.path.exists(log_file):
        with open(log_file, 'w') as f:
            f.write("epoch,train_loss,val_auc,val_acc,val_f1,human_acc,ai_acc,balanced_acc,threshold\n")

    for epoch in range(start_epoch, start_epoch + epochs):
        # Training
        model.train()
        running_loss = 0.0
        preds, labels = [], []

        current_lr = optimizer.param_groups[0]['lr']
        lr_history.append(current_lr)
        
        # Dynamic class weighting based on previous epoch performance
        if epoch > start_epoch and hasattr(model, 'training_stats'):
            human_acc = model.training_stats.get('human_acc', 0)
            ai_acc = model.training_stats.get('ai_acc', 0)
            
            # Adjust loss weights to favor improvement in human accuracy
            if human_acc < ai_acc:
                weight_ratio = min(ai_acc / max(human_acc, 0.01), 5.0)  # Limit max weight to 5x
                pos_weight = torch.tensor([weight_ratio]).to(device)
                main_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
                print(f"Adjusted loss weight to {weight_ratio:.2f}x for human class")
        
        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{start_epoch + epochs}")):
            try:
                # Forward pass
                output, human_features = model(
                    batch["sentence_embeds"].to(device),
                    batch["full_text_embeds"].to(device),
                    batch["style_features"].to(device),
                )
                
                # Prepare labels
                main_labels = batch["label"].to(device)
                human_labels = (1 - main_labels).float()  # Inverse of AI labels (1 for human, 0 for AI)
                
                # Calculate losses
                main_loss = main_criterion(output.squeeze(), main_labels)
                aux_loss = aux_criterion(human_features.squeeze(), human_labels)
                
                # Combined loss with auxiliary task
                loss = main_loss + 0.3 * aux_loss  # Weighted auxiliary loss
                loss = loss / accum_steps

                # Backward pass
                loss.backward()

                # Gradient accumulation
                if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader):
                    optimizer.step()
                    optimizer.zero_grad()

                running_loss += loss.item() * accum_steps
                sigmoid_preds = torch.sigmoid(output).squeeze().detach().cpu().numpy()
                
                # Handle scalar vs vector case
                if isinstance(sigmoid_preds, np.ndarray):
                    preds.extend(sigmoid_preds)
                else:
                    preds.append(sigmoid_preds)
                    
                labels.extend(main_labels.cpu().numpy())
            except Exception as e:
                print(f"Error in batch {i}: {e}")
                torch.cuda.empty_cache()
                optimizer.zero_grad()
                continue

        # Validation at each epoch for better monitoring
        model.eval()
        val_preds, val_labels = [], []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                try:
                    output, _ = model(
                        batch["sentence_embeds"].to(device),
                        batch["full_text_embeds"].to(device),
                        batch["style_features"].to(device),
                    )
                    sigmoid_preds = torch.sigmoid(output).squeeze().cpu().numpy()
                    
                    # Handle scalar vs vector case
                    if isinstance(sigmoid_preds, np.ndarray):
                        val_preds.extend(sigmoid_preds)
                    else:
                        val_preds.append(sigmoid_preds)
                        
                    val_labels.extend(batch["label"].cpu().numpy())
                except Exception as e:
                    print(f"Error in validation batch: {e}")
                    continue

        # Calculate metrics
        if len(val_preds) > 0 and len(val_labels) > 0:
            # Find optimal threshold every 2 epochs
            if (epoch - start_epoch) % 2 == 0 or epoch == (start_epoch + epochs - 1):
                optimal_threshold = find_optimal_threshold(val_loader, model)
            
            val_auc = roc_auc_score(val_labels, val_preds)
            val_preds_binary = [1 if p > optimal_threshold else 0 for p in val_preds]
            val_acc = accuracy_score(val_labels, val_preds_binary)
            val_f1 = f1_score(val_labels, val_preds_binary)
            
            # Calculate per-class metrics
            human_idx = [i for i, l in enumerate(val_labels) if l == 0]
            ai_idx = [i for i, l in enumerate(val_labels) if l == 1]
            
            human_acc = accuracy_score([val_labels[i] for i in human_idx], 
                                     [val_preds_binary[i] for i in human_idx]) if human_idx else 0
            ai_acc = accuracy_score([val_labels[i] for i in ai_idx], 
                                   [val_preds_binary[i] for i in ai_idx]) if ai_idx else 0
            
            balanced_acc = (human_acc + ai_acc) / 2
            
            # Store stats for dynamic weighting in next epoch
            model.training_stats = {
                'human_acc': human_acc,
                'ai_acc': ai_acc
            }

            # Update learning rate based on balanced accuracy
            scheduler.step(balanced_acc)

            # Save checkpoint for this epoch
            checkpoint_path = os.path.join(
                checkpoint_dir, 
                f"epoch_{epoch+1}_auc_{val_auc:.4f}_human_{human_acc:.4f}_ai_{ai_acc:.4f}.pt"
            )
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_auc': val_auc,
                'human_acc': human_acc,
                'ai_acc': ai_acc,
                'balanced_acc': balanced_acc,
                'threshold': optimal_threshold
            }, checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")
            
            # Save metrics to log file
            with open(log_file, 'a') as f:
                f.write(f"{epoch+1},{running_loss/max(1, len(train_loader)):.6f},{val_auc:.6f},"
                       f"{val_acc:.6f},{val_f1:.6f},{human_acc:.6f},{ai_acc:.6f},"
                       f"{balanced_acc:.6f},{optimal_threshold:.6f}\n")

            # Early stopping based on AUC
            if val_auc > best_val_auc:
                best_val_auc = val_auc
                best_model = copy.deepcopy(model.state_dict())
                torch.save(best_model, os.path.join(KAGGLE_OUTPUT_DIR, "best_model.pt"))
                patience_counter = 0
            else:
                patience_counter += 1
                
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

            print(f"Epoch {epoch+1}/{start_epoch + epochs}:")
            print(f"  Learning Rate: {current_lr:.6f}")
            print(f"  Train Loss: {running_loss/max(1, len(train_loader)):.4f}")
            print(f"  Val AUC: {val_auc:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
            print(f"  Human Acc: {human_acc:.4f}, AI Acc: {ai_acc:.4f}, Balanced: {balanced_acc:.4f}")
            print(f"  Using threshold: {optimal_threshold:.4f}")
        else:
            print(f"Epoch {epoch+1}/{start_epoch + epochs}: No valid predictions for validation")
    
    # Save learning rate history and optimal threshold
    lr_history_path = os.path.join(KAGGLE_OUTPUT_DIR, "lr_history.json")
    with open(lr_history_path, 'w') as f:
        json.dump(lr_history, f)
        
    threshold_path = os.path.join(KAGGLE_OUTPUT_DIR, "optimal_threshold.json")
    with open(threshold_path, 'w') as f:
        json.dump({"threshold": float(optimal_threshold)}, f)

    # Load best model
    if best_model is not None:
        model.load_state_dict(best_model)
    return model, optimal_threshold


class TextDataset(Dataset):
    def __init__(self, texts, labels, embedding_generator, style_extractor, max_sentences=16):
        self.texts = texts
        self.labels = labels
        self.embedding_generator = embedding_generator
        self.style_extractor = style_extractor
        self.max_sentences = max_sentences
        
        # Cache computations
        self.cache = {}
        self.feature_dim = style_extractor.get_feature_dim()
        
    def __len__(self):
        return len(self.texts)
        
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Check cache first
        if idx in self.cache:
            return self.cache[idx]
        
        try:    
            # Get sentence embeddings (with reduced max_sentences)
            sentence_embeds = self.embedding_generator.get_sentence_embeddings(
                text, max_sentences=self.max_sentences
            )
            
            # Get full text embedding
            full_text_embed = self.embedding_generator.get_embedding(text)
            
            # Get stylometric features
            style_features, _ = self.style_extractor.extract_features(text)
            
            # Verify dimensions
            if style_features.shape[0] != self.feature_dim:
                print(f"WARNING: Style feature dimension mismatch for idx {idx}")
                if style_features.shape[0] < self.feature_dim:
                    style_features = np.pad(style_features, (0, self.feature_dim - style_features.shape[0]))
                else:
                    style_features = style_features[:self.feature_dim]
            
            item = {
                "sentence_embeds": torch.tensor(sentence_embeds, dtype=torch.float32),
                "full_text_embeds": torch.tensor(full_text_embed, dtype=torch.float32),
                "style_features": torch.tensor(style_features, dtype=torch.float32),
                "label": torch.tensor(label, dtype=torch.float32),
                "text": text
            }
            
            self.cache[idx] = item
            return item
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            # Return a dummy item with zero tensors
            dummy_item = {
                "sentence_embeds": torch.zeros((self.max_sentences, self.embedding_generator.output_dim), dtype=torch.float32),
                "full_text_embeds": torch.zeros(self.embedding_generator.output_dim, dtype=torch.float32),
                "style_features": torch.zeros(self.feature_dim, dtype=torch.float32),
                "label": torch.tensor(label, dtype=torch.float32),
                "text": text
            }
            return dummy_item


def augment_human_samples(texts, labels, multiplier=1.5):
    """Augment human-written samples to address class imbalance"""
    human_texts = [text for text, label in zip(texts, labels) if label == 0]
    human_indices = [i for i, label in enumerate(labels) if label == 0]
    
    print(f"Found {len(human_texts)} human texts for augmentation")
    new_texts = []
    new_labels = []
    
    # Simple augmentation: paragraph shuffling and sentence removal
    # Should maintain human characteristics while creating variation
    for idx in human_indices[:int(len(human_indices) * multiplier)]:
        text = texts[idx]
        
        # Skip very short texts
        if len(text.split()) < 30:
            continue
            
        # Augmentation 1: Remove random sentences (10-20%)
        sentences = re.split(r'(?<=[.!?])\s+', text)
        if len(sentences) > 5:
            # Randomly remove 10-20% of sentences
            remove_count = max(1, int(len(sentences) * (0.1 + 0.1 * np.random.random())))
            remove_indices = np.random.choice(len(sentences), size=remove_count, replace=False)
            augmented = ' '.join([s for i, s in enumerate(sentences) if i not in remove_indices])
            new_texts.append(augmented)
            new_labels.append(0)  # Still human-written
            
    print(f"Created {len(new_texts)} augmented human samples")
    
    # Combine with original data
    augmented_texts = texts + new_texts
    augmented_labels = labels + new_labels
    
    return augmented_texts, augmented_labels


class ImprovedAITextDetector:
    def __init__(self, model_path=None, embedding_generator_path=None, feature_keys_path=None, threshold_path=None):
        self.style_extractor = None
        self.embedding_generator = None
        self.model = None
        self.feature_keys = None
        self.threshold = 0.5
        
        if model_path and embedding_generator_path and feature_keys_path:
            self.load(model_path, embedding_generator_path, feature_keys_path, threshold_path)
    
    def initialize_components(self, max_features=5000, embed_dim=32, hidden_dim=64):
        """Initialize all components of the detector"""
        print("Initializing ImprovedStyleFeatureExtractor...")
        self.style_extractor = ImprovedStyleFeatureExtractor()
        
        print("Initializing ImprovedEmbeddingGenerator...")
        self.embedding_generator = ImprovedEmbeddingGenerator(max_features=max_features, output_dim=embed_dim)
        
        # Get feature dimension
        style_feature_dim = self.style_extractor.get_feature_dim()
        print(f"Style feature dimension: {style_feature_dim}")
        
        print("Initializing ImprovedDetector model...")
        self.model = ImprovedDetector(
            embed_dim=embed_dim,
            style_feature_dim=style_feature_dim,
            hidden_dim=hidden_dim
        ).to(device)
        
        # Initialize training stats
        self.model.training_stats = {'human_acc': 0, 'ai_acc': 0}
        
        # Get feature keys
        _, self.feature_keys = self.style_extractor.extract_features("Sample text.")
    
    def load_data(self, csv_path, test_size=0.2, val_size=0.1, max_samples=None, augment_human=True):
        """Load and preprocess data from CSV file"""
        print(f"Loading data from {csv_path}...")
        df = pd.read_csv(csv_path)
        
        # Handle different column names
        if 'text' not in df.columns or 'ai' not in df.columns:
            # Try to infer columns
            text_col = None
            label_col = None
            
            # Look for text column
            text_candidates = ['text', 'content', 'document', 'passage']
            for col in text_candidates:
                if col in df.columns:
                    text_col = col
                    break
            
            # Look for label column
            label_candidates = ['ai', 'label', 'is_ai', 'generated', 'is_generated']
            for col in label_candidates:
                if col in df.columns:
                    label_col = col
                    break
            
            if text_col is None or label_col is None:
                print(f"Could not identify text and label columns. Available columns: {df.columns}")
                raise ValueError("CSV must contain 'text' and 'ai' columns")
            
            # Rename columns
            df = df.rename(columns={text_col: 'text', label_col: 'ai'})
        
        # Ensure text column is string type
        df['text'] = df['text'].astype(str)
        
        # Filter out very short texts
        df = df[df['text'].str.len() > 20].reset_index(drop=True)
        
        # Calculate class distribution
        ai_count = df['ai'].sum()
        human_count = len(df) - ai_count
        print(f"Class distribution: Human={human_count}, AI={ai_count}")
        
        # Limit samples if specified
        if max_samples and len(df) > max_samples:
            df = df.sample(max_samples, random_state=42).reset_index(drop=True)
        
        # Split into train, validation, and test sets
        train_df, temp_df = train_test_split(df, test_size=test_size+val_size, random_state=42, stratify=df['ai'])
        val_df, test_df = train_test_split(temp_df, test_size=test_size/(test_size+val_size), random_state=42, stratify=temp_df['ai'])
        
        # Augment human samples in training set if needed
        if augment_human:
            texts = train_df['text'].tolist()
            labels = train_df['ai'].tolist()
            
            augmented_texts, augmented_labels = augment_human_samples(texts, labels)
            
            train_df = pd.DataFrame({
                'text': augmented_texts,
                'ai': augmented_labels
            })
            
            print(f"Training set after augmentation: {len(train_df)} samples")
            
        print(f"Data split: Train={len(train_df)}, Validation={len(val_df)}, Test={len(test_df)}")
        
        return train_df, val_df, test_df
    
    def create_datasets(self, train_df, val_df, test_df, batch_size=32, max_sentences=16, num_workers=0):
        """Create PyTorch datasets and dataloaders"""
        print("Creating datasets...")
        
        # Fit the TF-IDF vectorizer on training data
        self.embedding_generator.fit(train_df['text'].tolist())
        
        train_dataset = TextDataset(
            train_df['text'].tolist(),
            train_df['ai'].tolist(),
            self.embedding_generator,
            self.style_extractor,
            max_sentences=max_sentences
        )
        
        val_dataset = TextDataset(
            val_df['text'].tolist(),
            val_df['ai'].tolist(),
            self.embedding_generator,
            self.style_extractor,
            max_sentences=max_sentences
        )
        
        test_dataset = TextDataset(
            test_df['text'].tolist(),
            test_df['ai'].tolist(),
            self.embedding_generator,
            self.style_extractor,
            max_sentences=max_sentences
        )
        
        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        return train_loader, val_loader, test_loader
    
    def train(self, train_loader, val_loader, epochs=12, lr=0.001, resume_training=True):
        """Train the model"""
        print("Training model...")
        
        # Check for checkpoints to resume from
        start_epoch = 0
        if resume_training:
            checkpoint_dir = os.path.join(KAGGLE_OUTPUT_DIR, "checkpoints")
            if os.path.exists(checkpoint_dir):
                checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('epoch_') and f.endswith('.pt')]
                if checkpoints:
                    # Find latest checkpoint
                    checkpoints.sort(key=lambda x: int(x.split('_')[1]))
                    latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[-1])
                    
                    # Setup optimizer and scheduler for loading
                    optimizer = optim.AdamW(self.model.parameters(), lr=lr)
                    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                        optimizer, mode="max", factor=0.5, patience=1
                    )
                    
                    # Load the checkpoint
                    self.model, start_epoch, self.threshold = load_from_checkpoint(
                        self.model, optimizer, scheduler, latest_checkpoint
                    )
        
        # Train model (continue from last checkpoint if available)
        if start_epoch < epochs:
            self.model, self.threshold = train_improved(
                self.model, 
                train_loader, 
                val_loader, 
                epochs=epochs,
                lr=lr,
                start_epoch=start_epoch
            )
        else:
            print(f"Training already completed for {epochs} epochs. No further training needed.")
        
        # Save the optimal threshold
        threshold_path = os.path.join(KAGGLE_OUTPUT_DIR, "optimal_threshold.json")
        with open(threshold_path, 'w') as f:
            json.dump({"threshold": float(self.threshold)}, f)
    
    def evaluate(self, test_loader):
        """Evaluate the model on test data"""
        print("Evaluating model...")
        self.model.eval()
        all_preds = []
        all_labels = []
        all_texts = []  # To store texts for error analysis
        
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Testing"):
                try:
                    output, _ = self.model(
                        batch["sentence_embeds"].to(device),
                        batch["full_text_embeds"].to(device),
                        batch["style_features"].to(device),
                    )
                    # Apply sigmoid to get probabilities
                    sigmoid_preds = torch.sigmoid(output).squeeze().cpu().numpy()
                    
                    # Handle scalar vs vector case
                    if isinstance(sigmoid_preds, np.ndarray):
                        all_preds.extend(sigmoid_preds)
                    else:
                        all_preds.append(sigmoid_preds)
                        
                    all_labels.extend(batch["label"].cpu().numpy())
                    all_texts.extend(batch["text"])
                except Exception as e:
                    print(f"Error in test batch: {e}")
                    continue
        
        # Calculate metrics
        if len(all_preds) > 0 and len(all_labels) > 0:
            # Apply optimal threshold
            all_preds_binary = [1 if p > self.threshold else 0 for p in all_preds]
            
            # Overall metrics
            results = {
                "accuracy": accuracy_score(all_labels, all_preds_binary),
                "auc": roc_auc_score(all_labels, all_preds),
                "f1": f1_score(all_labels, all_preds_binary),
                "precision": precision_score(all_labels, all_preds_binary),
                "recall": recall_score(all_labels, all_preds_binary)
            }
            
            # Calculate per-class metrics
            human_idx = [i for i, l in enumerate(all_labels) if l == 0]
            ai_idx = [i for i, l in enumerate(all_labels) if l == 1]
            
            human_preds = [all_preds_binary[i] for i in human_idx]
            human_labels = [all_labels[i] for i in human_idx]
            human_acc = accuracy_score(human_labels, human_preds) if human_idx else 0
            
            ai_preds = [all_preds_binary[i] for i in ai_idx]
            ai_labels = [all_labels[i] for i in ai_idx]
            ai_acc = accuracy_score(ai_labels, ai_preds) if ai_idx else 0
            
            results["human_accuracy"] = human_acc
            results["ai_accuracy"] = ai_acc
            results["balanced_accuracy"] = (human_acc + ai_acc) / 2
            
            print("Test Results:")
            for metric, value in results.items():
                print(f"  {metric.capitalize()}: {value:.4f}")
            
            # Error analysis
            misclassified_human = [(all_texts[i], all_preds[i]) for i in human_idx if all_preds_binary[i] != 0]
            misclassified_ai = [(all_texts[i], all_preds[i]) for i in ai_idx if all_preds_binary[i] != 1]
            
            print(f"\nMisclassified human texts: {len(misclassified_human)}/{len(human_idx)}")
            print(f"Misclassified AI texts: {len(misclassified_ai)}/{len(ai_idx)}")
            
            # Save error analysis
            error_analysis = {
                "misclassified_human_count": len(misclassified_human),
                "misclassified_ai_count": len(misclassified_ai),
                "human_total": len(human_idx),
                "ai_total": len(ai_idx),
            }
            
            # Add error examples (sample to avoid huge files)
            if misclassified_human:
                error_analysis["misclassified_human_examples"] = [
                    {"text": text[:500] + "...", "prob": float(prob)} 
                    for text, prob in misclassified_human[:10]  # Sample 10 examples
                ]
            
            if misclassified_ai:
                error_analysis["misclassified_ai_examples"] = [
                    {"text": text[:500] + "...", "prob": float(prob)} 
                    for text, prob in misclassified_ai[:10]  # Sample 10 examples
                ]
            
            # Save error analysis
            error_path = os.path.join(KAGGLE_OUTPUT_DIR, "error_analysis.json")
            with open(error_path, 'w') as f:
                json.dump(error_analysis, f, indent=2)
                
            return results
        else:
            print("No valid predictions for testing")
            return None
    
    def predict_proba(self, text):
        """
        Predict the probability of text being AI-generated
        """
        # Ensure components are loaded
        if self.model is None or self.embedding_generator is None or self.style_extractor is None:
            raise ValueError("Model components not initialized. Load a trained model first.")
        
        self.model.eval()
        with torch.no_grad():
            # Get sentence embeddings
            sentence_embeds = self.embedding_generator.get_sentence_embeddings(text)
            sentence_embeds = torch.tensor(sentence_embeds, dtype=torch.float32).unsqueeze(0).to(device)
            
            # Get full text embedding
            full_text_embed = self.embedding_generator.get_embedding(text)
            full_text_embed = torch.tensor(full_text_embed, dtype=torch.float32).unsqueeze(0).to(device)
            
            # Get stylometric features
            style_features, _ = self.style_extractor.extract_features(text)
            style_features = torch.tensor(style_features, dtype=torch.float32).unsqueeze(0).to(device)
            
            # Get prediction
            logits, _ = self.model(sentence_embeds, full_text_embed, style_features)
            # Apply sigmoid to get probability
            probability = torch.sigmoid(logits).item()
        
        return probability
    
    def predict(self, text):
        """
        Predict whether text is AI-generated or human-written,
        with confidence-based threshold adjustment
        """
        probability = self.predict_proba(text)
        
        # Apply threshold with confidence adjustment
        if 0.4 < probability < 0.6:
            # For borderline cases, be more conservative with human classification
            result = {
                "probability_ai_generated": probability,
                "prediction": "AI-generated" if probability > 0.52 else "Human-written",
                "confidence": max(probability, 1 - probability),
                "note": "Borderline case - classification uncertain"
            }
        else:
            result = {
                "probability_ai_generated": probability,
                "prediction": "AI-generated" if probability > self.threshold else "Human-written",
                "confidence": max(probability, 1 - probability)
            }
        
        return result
    
    def save(self, output_dir=KAGGLE_OUTPUT_DIR):
        """Save all components necessary for inference"""
        if self.model is None or self.embedding_generator is None or self.style_extractor is None:
            raise ValueError("Model components not initialized.")
            
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        
        # Save model
        model_path = os.path.join(output_dir, "improved_detector.pt")
        self.model.save(model_path)
        print(f"Model saved to {model_path}")
        
        # Save embedding generator
        embedding_generator_path = os.path.join(output_dir, "improved_embedding_generator")
        self.embedding_generator.save(embedding_generator_path)
        print(f"Embedding generator saved to {embedding_generator_path}")
        
        # Save feature keys
        feature_keys_path = os.path.join(output_dir, "improved_feature_keys.json")
        with open(feature_keys_path, 'w') as f:
            json.dump(self.feature_keys, f)
        print(f"Feature keys saved to {feature_keys_path}")
        
        # Save threshold
        threshold_path = os.path.join(output_dir, "optimal_threshold.json")
        with open(threshold_path, 'w') as f:
            json.dump({"threshold": float(self.threshold)}, f)
        print(f"Optimal threshold saved to {threshold_path}")
        
        return model_path, embedding_generator_path, feature_keys_path, threshold_path
    
    def load(self, model_path, embedding_generator_path, feature_keys_path, threshold_path=None):
        """Load all components necessary for inference"""
        # Load model
        print(f"Loading model from {model_path}...")
        self.model = ImprovedDetector.load(model_path)
        self.model.to(device)
        
        # Load embedding generator
        print(f"Loading embedding generator from {embedding_generator_path}...")
        self.embedding_generator = ImprovedEmbeddingGenerator.load(embedding_generator_path)
        
        # Load feature keys
        print(f"Loading feature keys from {feature_keys_path}...")
        with open(feature_keys_path, 'r') as f:
            self.feature_keys = json.load(f)
        
        # Load threshold if provided
        if threshold_path and os.path.exists(threshold_path):
            with open(threshold_path, 'r') as f:
                self.threshold = json.load(f).get("threshold", 0.5)
            print(f"Loaded optimal threshold: {self.threshold:.4f}")
        
        # Initialize StyleFeatureExtractor
        print("Initializing ImprovedStyleFeatureExtractor...")
        self.style_extractor = ImprovedStyleFeatureExtractor()


def train_improved_detector(csv_path, batch_size=32, epochs=12, lr=0.001, max_features=5000, embed_dim=32, hidden_dim=64, max_sentences=16):
    """Function to train the improved detector"""
    # Initialize detector
    detector = ImprovedAITextDetector()
    detector.initialize_components(max_features=max_features, embed_dim=embed_dim, hidden_dim=hidden_dim)
    
    # Load and preprocess data
    train_df, val_df, test_df = detector.load_data(csv_path, augment_human=True)
    
    # Create datasets and dataloaders
    train_loader, val_loader, test_loader = detector.create_datasets(
        train_df, val_df, test_df, 
        batch_size=batch_size,
        max_sentences=max_sentences,
        num_workers=0  # Use 0 workers for Kaggle
    )
    
    # Train model with automatic checkpoint resumption
    detector.train(train_loader, val_loader, epochs=epochs, lr=lr, resume_training=True)
    
    # Evaluate model
    results = detector.evaluate(test_loader)
    
    # Save model for future inference
    model_path, embedding_generator_path, feature_keys_path, threshold_path = detector.save()
    
    # Save results to a file
    if results:
        results_path = os.path.join(KAGGLE_OUTPUT_DIR, "results.json")
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)
        
        print(f"Results saved to {results_path}")
    
    print(f"Model saved to {model_path}")
    
    return detector


def predict_with_improved_detector(text, model_dir=KAGGLE_OUTPUT_DIR):
    """
    Predict whether a text is AI-generated or human-written
    using the improved detector
    
    Args:
        text (str): Text to classify
        model_dir (str): Directory containing saved model files
        
    Returns:
        dict: Prediction results including probability, prediction label, and confidence
    """
    # Construct paths to saved files
    model_path = os.path.join(model_dir, "improved_detector.pt")
    embedding_generator_path = os.path.join(model_dir, "improved_embedding_generator")
    feature_keys_path = os.path.join(model_dir, "improved_feature_keys.json")
    threshold_path = os.path.join(model_dir, "optimal_threshold.json")
    
    # Check if files exist
    if not os.path.exists(model_path) or not os.path.exists(feature_keys_path):
        raise FileNotFoundError(f"Model files not found. Please train the model first.")
    
    # Load detector
    detector = ImprovedAITextDetector(model_path, embedding_generator_path, feature_keys_path, threshold_path)
    
    # Get prediction
    result = detector.predict(text)
    
    return result

# ============================================================================
# Configuration
# ============================================================================

# Path to your dataset (CSV file with 'text' and 'ai' columns)
# 'ai' column should contain 0 for human-written, 1 for AI-generated
CSV_PATH = "/kaggle/input/combined-dataset-ai-human/combined_data.csv"

# Training parameters (optimized for faster training)
BATCH_SIZE = 96        # Slightly larger batch size
EPOCHS = 12            # More epochs for better convergence
LEARNING_RATE = 0.001  # Standard learning rate

# Model parameters (still light but with more features)
MAX_FEATURES = 5000    # Number of TF-IDF features
EMBED_DIM = 32         # Same embedding dimension
HIDDEN_DIM = 64        # Same hidden dimension
MAX_SENTENCES = 16     # Same maximum sentences

# Example text for testing
EXAMPLE_TEXT = """
This is a sample text to demonstrate the improved AI text detector.
It can classify whether this text was written by a human or generated by an AI.
The detector uses enhanced stylometric analysis and optimized embeddings to make its determination.
The improved model adds human-specific linguistic features and better training processes.
"""

# Mode: "train", "predict", or "both"
MODE = "both"

# ============================================================================
# Main Execution
# ============================================================================

def main():
    global CSV_PATH
    # Print configuration
    print("=" * 80)
    print("Improved Lightweight AI Text Detector - Configuration")
    print("=" * 80)
    print(f"CSV Path: {CSV_PATH}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Epochs: {EPOCHS}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Max TF-IDF Features: {MAX_FEATURES}")
    print(f"Embedding Dimension: {EMBED_DIM}")
    print(f"Hidden Dimension: {HIDDEN_DIM}")
    print(f"Maximum Sentences: {MAX_SENTENCES}")
    print(f"Mode: {MODE}")
    print("=" * 80)
    
    # Install necessary packages if needed
    try:
        import spacy
    except ImportError:
        print("Installing required packages...")
        import subprocess
        subprocess.call(["pip", "install", "spacy"])
        subprocess.call(["python", "-m", "spacy", "download", "en_core_web_sm"])
        print("Packages installed.")
    
    # Check if model already exists
    model_path = os.path.join(KAGGLE_OUTPUT_DIR, "improved_detector.pt")
    embedding_generator_path = os.path.join(KAGGLE_OUTPUT_DIR, "improved_embedding_generator")
    feature_keys_path = os.path.join(KAGGLE_OUTPUT_DIR, "improved_feature_keys.json")
    threshold_path = os.path.join(KAGGLE_OUTPUT_DIR, "optimal_threshold.json")
    
    # Use auto-discovery to find CSV files if path not specified
    if CSV_PATH == "/kaggle/input/combined-dataset-ai-human/combined_data.csv":
        import glob
        csv_files = glob.glob("/kaggle/input/combined-dataset-ai-human/combined_data.csv")
        if csv_files:
            CSV_PATH = csv_files[0]
            print(f"Auto-discovered CSV file: {CSV_PATH}")
        else:
            print("No CSV files found in /kaggle/input/. Using a small synthetic dataset for demonstration.")
            # Create a small synthetic dataset
            synthetic_df = pd.DataFrame({
                'text': [
                    "This is a human written text. It has some style and personality.",
                    "The AI generated this text. It's somewhat different from human text.",
                    "Humans write with varying sentence lengths and unique expressions.",
                    "AI models produce text with patterns that can be detected by our system."
                ],
                'ai': [0, 1, 0, 1]
            })
            CSV_PATH = os.path.join(KAGGLE_OUTPUT_DIR, "synthetic_dataset.csv")
            synthetic_df.to_csv(CSV_PATH, index=False)
            print(f"Created synthetic dataset at {CSV_PATH}")
    
    detector = None
    
    # Training
    if MODE in ["train", "both"]:
        print("\nStarting training...")
        detector = train_improved_detector(
            csv_path=CSV_PATH,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            lr=LEARNING_RATE,
            max_features=MAX_FEATURES,
            embed_dim=EMBED_DIM,
            hidden_dim=HIDDEN_DIM,
            max_sentences=MAX_SENTENCES
        )
    
    # Prediction
    if MODE in ["predict", "both"]:
        if all(os.path.exists(path) for path in [model_path, feature_keys_path]):
            print("\nPredicting on example text:")
            print("-" * 50)
            print(EXAMPLE_TEXT)
            print("-" * 50)
            
            result = predict_with_improved_detector(EXAMPLE_TEXT)
            
            print("\nPrediction Results:")
            print(f"Classification: {result['prediction']}")
            print(f"Probability of being AI-generated: {result['probability_ai_generated']:.4f}")
            print(f"Confidence: {result['confidence']:.4f}")
            if 'note' in result:
                print(f"Note: {result['note']}")
            
            # Example of how to use for custom text
            print("\nTo predict on your own text, use:")
            print("result = predict_with_improved_detector(your_text)")
        else:
            print("\nNo trained model found. Please run with MODE='train' first.")

# Run the main function
if __name__ == "__main__":
    main()

Using device: cuda
Improved Lightweight AI Text Detector - Configuration
CSV Path: /kaggle/input/combined-dataset-ai-human/combined_data.csv
Batch Size: 96
Epochs: 12
Learning Rate: 0.001
Max TF-IDF Features: 5000
Embedding Dimension: 32
Hidden Dimension: 64
Maximum Sentences: 16
Mode: both
Auto-discovered CSV file: /kaggle/input/combined-dataset-ai-human/combined_data.csv

Starting training...
Initializing ImprovedStyleFeatureExtractor...
Initializing ImprovedEmbeddingGenerator...
Style feature dimension: 25
Initializing ImprovedDetector model...
Loading data from /kaggle/input/combined-dataset-ai-human/combined_data.csv...
Class distribution: Human=23513, AI=23513
Found 16459 human texts for augmentation
Created 16051 augmented human samples
Training set after augmentation: 48969 samples
Data split: Train=48969, Validation=4702, Test=9406
Creating datasets...
Fitting TF-IDF vectorizer...
Training model...


Epoch 1/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Finding optimal threshold:   0%|          | 0/49 [00:00<?, ?it/s]

Optimal threshold: 0.3500, F1: 0.8277, Balanced Acc: 0.8269
Saved checkpoint to /kaggle/working/checkpoints/epoch_1_auc_0.8981_human_0.8222_ai_0.8316.pt
Epoch 1/12:
  Learning Rate: 0.001000
  Train Loss: 0.6175
  Val AUC: 0.8981, Acc: 0.8269, F1: 0.8277
  Human Acc: 0.8222, AI Acc: 0.8316, Balanced: 0.8269
  Using threshold: 0.3500
Adjusted loss weight to 1.01x for human class


Epoch 2/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Saved checkpoint to /kaggle/working/checkpoints/epoch_2_auc_0.9025_human_0.8596_ai_0.7924.pt
Epoch 2/12:
  Learning Rate: 0.001000
  Train Loss: 0.5665
  Val AUC: 0.9025, Acc: 0.8260, F1: 0.8200
  Human Acc: 0.8596, AI Acc: 0.7924, Balanced: 0.8260
  Using threshold: 0.3500


Epoch 3/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Finding optimal threshold:   0%|          | 0/49 [00:00<?, ?it/s]

Optimal threshold: 0.3300, F1: 0.8400, Balanced Acc: 0.8386
Saved checkpoint to /kaggle/working/checkpoints/epoch_3_auc_0.9076_human_0.8299_ai_0.8473.pt
Epoch 3/12:
  Learning Rate: 0.001000
  Train Loss: 0.5535
  Val AUC: 0.9076, Acc: 0.8386, F1: 0.8400
  Human Acc: 0.8299, AI Acc: 0.8473, Balanced: 0.8386
  Using threshold: 0.3300
Adjusted loss weight to 1.02x for human class


Epoch 4/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Saved checkpoint to /kaggle/working/checkpoints/epoch_4_auc_0.9098_human_0.8277_ai_0.8537.pt
Epoch 4/12:
  Learning Rate: 0.001000
  Train Loss: 0.5447
  Val AUC: 0.9098, Acc: 0.8407, F1: 0.8427
  Human Acc: 0.8277, AI Acc: 0.8537, Balanced: 0.8407
  Using threshold: 0.3300
Adjusted loss weight to 1.03x for human class


Epoch 5/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Finding optimal threshold:   0%|          | 0/49 [00:00<?, ?it/s]

Optimal threshold: 0.3000, F1: 0.8369, Balanced Acc: 0.8386
Saved checkpoint to /kaggle/working/checkpoints/epoch_5_auc_0.9117_human_0.8486_ai_0.8286.pt
Epoch 5/12:
  Learning Rate: 0.001000
  Train Loss: 0.5406
  Val AUC: 0.9117, Acc: 0.8386, F1: 0.8369
  Human Acc: 0.8486, AI Acc: 0.8286, Balanced: 0.8386
  Using threshold: 0.3000


Epoch 6/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Saved checkpoint to /kaggle/working/checkpoints/epoch_6_auc_0.9122_human_0.7852_ai_0.8958.pt
Epoch 6/12:
  Learning Rate: 0.000500
  Train Loss: 0.5355
  Val AUC: 0.9122, Acc: 0.8405, F1: 0.8489
  Human Acc: 0.7852, AI Acc: 0.8958, Balanced: 0.8405
  Using threshold: 0.3000
Adjusted loss weight to 1.14x for human class


Epoch 7/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Finding optimal threshold:   0%|          | 0/49 [00:00<?, ?it/s]

Optimal threshold: 0.3000, F1: 0.8454, Balanced Acc: 0.8420
Saved checkpoint to /kaggle/working/checkpoints/epoch_7_auc_0.9143_human_0.8197_ai_0.8643.pt
Epoch 7/12:
  Learning Rate: 0.000500
  Train Loss: 0.5522
  Val AUC: 0.9143, Acc: 0.8420, F1: 0.8454
  Human Acc: 0.8197, AI Acc: 0.8643, Balanced: 0.8420
  Using threshold: 0.3000
Adjusted loss weight to 1.05x for human class


Epoch 8/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Saved checkpoint to /kaggle/working/checkpoints/epoch_8_auc_0.9148_human_0.8222_ai_0.8592.pt
Epoch 8/12:
  Learning Rate: 0.000250
  Train Loss: 0.5309
  Val AUC: 0.9148, Acc: 0.8407, F1: 0.8436
  Human Acc: 0.8222, AI Acc: 0.8592, Balanced: 0.8407
  Using threshold: 0.3000
Adjusted loss weight to 1.05x for human class


Epoch 9/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Finding optimal threshold:   0%|          | 0/49 [00:00<?, ?it/s]

Optimal threshold: 0.3100, F1: 0.8471, Balanced Acc: 0.8447
Saved checkpoint to /kaggle/working/checkpoints/epoch_9_auc_0.9163_human_0.8294_ai_0.8601.pt
Epoch 9/12:
  Learning Rate: 0.000250
  Train Loss: 0.5292
  Val AUC: 0.9163, Acc: 0.8447, F1: 0.8471
  Human Acc: 0.8294, AI Acc: 0.8601, Balanced: 0.8447
  Using threshold: 0.3100
Adjusted loss weight to 1.04x for human class


Epoch 10/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Saved checkpoint to /kaggle/working/checkpoints/epoch_10_auc_0.9165_human_0.8201_ai_0.8652.pt
Epoch 10/12:
  Learning Rate: 0.000250
  Train Loss: 0.5248
  Val AUC: 0.9165, Acc: 0.8426, F1: 0.8461
  Human Acc: 0.8201, AI Acc: 0.8652, Balanced: 0.8426
  Using threshold: 0.3100
Adjusted loss weight to 1.05x for human class


Epoch 11/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Finding optimal threshold:   0%|          | 0/49 [00:00<?, ?it/s]

Optimal threshold: 0.3000, F1: 0.8453, Balanced Acc: 0.8426
Saved checkpoint to /kaggle/working/checkpoints/epoch_11_auc_0.9169_human_0.8252_ai_0.8601.pt
Epoch 11/12:
  Learning Rate: 0.000250
  Train Loss: 0.5289
  Val AUC: 0.9169, Acc: 0.8426, F1: 0.8453
  Human Acc: 0.8252, AI Acc: 0.8601, Balanced: 0.8426
  Using threshold: 0.3000
Adjusted loss weight to 1.04x for human class


Epoch 12/12:   0%|          | 0/511 [00:00<?, ?it/s]

Validation:   0%|          | 0/49 [00:00<?, ?it/s]

Finding optimal threshold:   0%|          | 0/49 [00:00<?, ?it/s]

Optimal threshold: 0.3300, F1: 0.8458, Balanced Acc: 0.8452
Saved checkpoint to /kaggle/working/checkpoints/epoch_12_auc_0.9177_human_0.8409_ai_0.8494.pt
Epoch 12/12:
  Learning Rate: 0.000125
  Train Loss: 0.5245
  Val AUC: 0.9177, Acc: 0.8452, F1: 0.8458
  Human Acc: 0.8409, AI Acc: 0.8494, Balanced: 0.8452
  Using threshold: 0.3300
Evaluating model...


Testing:   0%|          | 0/98 [00:00<?, ?it/s]

Test Results:
  Accuracy: 0.8370
  Auc: 0.9142
  F1: 0.8386
  Precision: 0.8303
  Recall: 0.8471
  Human_accuracy: 0.8269
  Ai_accuracy: 0.8471
  Balanced_accuracy: 0.8370

Misclassified human texts: 814/4703
Misclassified AI texts: 719/4703
Model saved to /kaggle/working/improved_detector.pt
Embedding generator saved to /kaggle/working/improved_embedding_generator
Feature keys saved to /kaggle/working/improved_feature_keys.json
Optimal threshold saved to /kaggle/working/optimal_threshold.json
Results saved to /kaggle/working/results.json
Model saved to /kaggle/working/improved_detector.pt

Predicting on example text:
--------------------------------------------------

This is a sample text to demonstrate the improved AI text detector.
It can classify whether this text was written by a human or generated by an AI.
The detector uses enhanced stylometric analysis and optimized embeddings to make its determination.
The improved model adds human-specific linguistic features and better trai