In [1]:
# Lightweight AI Text Detector
# Optimized for faster training while maintaining similar performance

import os
import re
import json
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter, defaultdict
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, precision_score, recall_score
from sklearn.feature_extraction.text import TfidfVectorizer
import spacy
import warnings
warnings.filterwarnings('ignore')

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define the output directory for Kaggle
KAGGLE_OUTPUT_DIR = "/kaggle/working/"
os.makedirs(KAGGLE_OUTPUT_DIR, exist_ok=True)

# Set random seeds for reproducibility
def set_seed(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

class LightStyleFeatureExtractor:
    def __init__(self, spacy_model="en_core_web_sm"):
        self.nlp = spacy.load(spacy_model, disable=["ner"])
        
        # Cache for processed documents to avoid redundant computation
        self.doc_cache = {}
        # Fixed feature dimension to ensure consistency
        self._feature_dim = None
        self._feature_keys = None

    def extract_features(self, text):
        # Check cache first
        if text in self.doc_cache:
            doc = self.doc_cache[text]
        else:
            doc = self.nlp(text)
            self.doc_cache[text] = doc
            
        # Basic text statistics
        features = {}

        # Lexical diversity
        tokens = [token.text.lower() for token in doc if not token.is_punct]
        types = set(tokens)
        features["type_token_ratio"] = len(types) / max(len(tokens), 1)

        # Sentence length stats
        sent_lengths = [len(sent) for sent in doc.sents]
        features["avg_sent_length"] = np.mean(sent_lengths) if sent_lengths else 0
        features["sent_length_var"] = np.var(sent_lengths) if len(sent_lengths) > 1 else 0

        # POS distributions (simplified - only count main POS tags)
        main_pos_tags = ["NOUN", "VERB", "ADJ", "ADV", "PRON", "DET", "ADP", "CONJ", "PART"]
        pos_counts = Counter([token.pos_ for token in doc])
        for pos in main_pos_tags:
            features[f"pos_{pos}"] = pos_counts.get(pos, 0) / max(len(doc), 1)

        # Function word ratio
        function_words = [token for token in doc if token.is_stop]
        features["function_word_ratio"] = len(function_words) / max(len(doc), 1)

        # Punctuation ratio
        punct_count = len([token for token in doc if token.is_punct])
        features["punct_ratio"] = punct_count / max(len(doc), 1)
        
        # Basic vocabulary richness metric
        if len(tokens) > 10:
            # Simplified Yule's K (vocabulary richness)
            token_freq = Counter(tokens)
            m1 = sum(token_freq.values())
            m2 = sum([freq ** 2 for freq in token_freq.values()])
            features["yules_k"] = (m2 - m1) / (m1 ** 2) if m1 > 0 else 0
        else:
            features["yules_k"] = 0

        # Additional text statistics
        features["avg_word_length"] = np.mean([len(token.text) for token in doc if not token.is_punct]) if tokens else 0
        features["unique_word_ratio"] = len(types) / max(len(tokens), 1)
        
        # Convert to numpy array with fixed order
        feature_keys = sorted(features.keys())
        feature_array = np.array([features[k] for k in feature_keys])
        
        # Store feature keys for consistency
        if self._feature_keys is None:
            self._feature_keys = feature_keys
            self._feature_dim = len(feature_keys)
        else:
            # Ensure same features are returned every time
            # If new feature keys are present, add zeros for missing ones
            if set(feature_keys) != set(self._feature_keys):
                fixed_features = {}
                for k in self._feature_keys:
                    if k in features:
                        fixed_features[k] = features[k]
                    else:
                        fixed_features[k] = 0.0
                
                feature_array = np.array([fixed_features[k] for k in self._feature_keys])
        
        return feature_array, feature_keys

    def get_feature_dim(self):
        # Get the dimension of feature vectors
        if self._feature_dim is not None:
            return self._feature_dim
            
        sample_text = "This is a sample text to determine feature dimension."
        features, feature_keys = self.extract_features(sample_text)
        self._feature_dim = len(features)
        self._feature_keys = feature_keys
        return self._feature_dim


class LightEmbeddingGenerator:
    def __init__(self, max_features=5000, output_dim=32):
        # Use TF-IDF instead of transformer models for much lighter computation
        self.vectorizer = TfidfVectorizer(
            max_features=max_features,
            ngram_range=(1, 2),  # Use unigrams and bigrams
            min_df=2,
            max_df=0.9
        )
        self.output_dim = output_dim
        
        # Dimensionality reduction layer (to replace transformers)
        self.projector = nn.Sequential(
            nn.Linear(max_features, output_dim),
            nn.ReLU()
        ).to(device)
        
        # For sentence splitting
        self.nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser'])
        self.nlp.add_pipe('sentencizer')
        
        # Cache
        self.embedding_cache = {}
        self.fitted = False
        
    def fit(self, texts):
        """Fit the TF-IDF vectorizer on a corpus of texts"""
        print("Fitting TF-IDF vectorizer...")
        self.vectorizer.fit(texts)
        self.fitted = True
        return self
        
    def get_embedding(self, text):
        """Get embedding for a single text"""
        # Check cache first
        if text in self.embedding_cache:
            return self.embedding_cache[text]
            
        if not self.fitted:
            raise ValueError("TF-IDF vectorizer not fitted. Call fit() first.")
            
        # Get TF-IDF vector
        X = self.vectorizer.transform([text]).toarray()
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        
        # Project to lower dimension
        with torch.no_grad():
            embedding = self.projector(X_tensor)
            normalized = F.normalize(embedding, p=2, dim=1)
            
        result = normalized.cpu().numpy()[0]
        self.embedding_cache[text] = result
        
        return result
        
    def get_sentence_embeddings(self, text, max_sentences=16):
        """Get embeddings for sentences in text"""
        try:
            # Use simpler sentence splitting
            doc = self.nlp(text)
            sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]
            
            # Limit number of sentences
            sentences = sentences[:max_sentences]
            
            # Get embeddings for sentences
            embeddings = []
            for sent in sentences:
                if sent:  # Skip empty sentences
                    embedding = self.get_embedding(sent)
                    embeddings.append(embedding)
                    
            # Handle empty text
            if not embeddings:
                return np.zeros((max_sentences, self.output_dim), dtype=np.float32)
                
            # Convert to numpy array
            embeddings = np.array(embeddings)
            
            # Pad if necessary
            if len(embeddings) < max_sentences:
                padding = np.zeros((max_sentences - len(embeddings), self.output_dim), dtype=np.float32)
                embeddings = np.vstack([embeddings, padding])
            
            return embeddings
            
        except Exception as e:
            print(f"Error in get_sentence_embeddings: {e}")
            # Return zero embeddings as fallback
            return np.zeros((max_sentences, self.output_dim), dtype=np.float32)
    
    def save(self, path):
        """Save model weights and config"""
        # Save vectorizer
        vectorizer_path = path + ".vectorizer"
        with open(vectorizer_path, 'wb') as f:
            import pickle
            pickle.dump(self.vectorizer, f)
        
        # Save projector
        projector_path = path + ".projector"
        torch.save(self.projector.state_dict(), projector_path)
        
        # Save config
        config_path = path + ".config"
        config = {
            'max_features': self.vectorizer.max_features,
            'output_dim': self.output_dim,
            'fitted': self.fitted
        }
        with open(config_path, 'w') as f:
            json.dump(config, f)
            
        return vectorizer_path, projector_path, config_path
        
    @classmethod
    def load(cls, path):
        """Load a saved model"""
        # Load config
        config_path = path + ".config"
        with open(config_path, 'r') as f:
            config = json.load(f)
            
        # Create instance
        instance = cls(
            max_features=config['max_features'],
            output_dim=config['output_dim']
        )
        
        # Load vectorizer
        vectorizer_path = path + ".vectorizer"
        with open(vectorizer_path, 'rb') as f:
            import pickle
            instance.vectorizer = pickle.load(f)
            
        # Load projector
        projector_path = path + ".projector"
        instance.projector.load_state_dict(torch.load(projector_path, map_location=device))
        
        instance.fitted = config['fitted']
        return instance


class LightDetector(nn.Module):
    def __init__(self, embed_dim=32, style_feature_dim=15, hidden_dim=64):
        super().__init__()
        self.embed_dim = embed_dim
        self.style_feature_dim = style_feature_dim
        self.hidden_dim = hidden_dim

        # Simple LSTM for sentence embeddings (much lighter than transformer)
        self.sentence_lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        # Processing stylometric features
        self.style_encoder = nn.Sequential(
            nn.Linear(style_feature_dim, hidden_dim // 2),
            nn.ReLU(),
        )

        # Full text embedding processor
        self.text_encoder = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim // 2),
            nn.ReLU(),
        )

        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, sentence_embeds, full_text_embed, style_features):
        batch_size = sentence_embeds.size(0)
        
        # Process sentence embeddings with LSTM
        lstm_out, (h_n, _) = self.sentence_lstm(sentence_embeds)
        
        # Get LSTM final hidden state and combine directions
        sentence_repr = torch.cat([h_n[0], h_n[1]], dim=1)

        # Process full-text embedding
        text_repr = self.text_encoder(full_text_embed)

        # Process style features
        style_repr = self.style_encoder(style_features)

        # Concatenate all features
        combined = torch.cat([sentence_repr, text_repr, style_repr], dim=1)
        output = self.classifier(combined)

        return output
        
    def save(self, path):
        """Save model weights and config"""
        save_dict = {
            'state_dict': self.state_dict(),
            'embed_dim': self.embed_dim,
            'style_feature_dim': self.style_feature_dim,
            'hidden_dim': self.hidden_dim
        }
        torch.save(save_dict, path)
        
    @classmethod
    def load(cls, path):
        """Load a saved model"""
        save_dict = torch.load(path, map_location=device)
        instance = cls(
            embed_dim=save_dict['embed_dim'],
            style_feature_dim=save_dict['style_feature_dim'],
            hidden_dim=save_dict['hidden_dim']
        )
        instance.load_state_dict(save_dict['state_dict'])
        return instance


# Custom collate function for batches
def custom_collate_fn(batch):
    if len(batch) == 0:
        return {}
    
    elem = batch[0]
    result = {}
    
    for key in elem.keys():
        if key == 'text':  # Handle text field (non-tensor)
            result[key] = [b[key] for b in batch]
            continue
            
        values = [b[key] for b in batch]
        
        if isinstance(values[0], torch.Tensor):
            try:
                result[key] = torch.stack(values, 0)
            except RuntimeError as e:
                print(f"ERROR stacking tensors for key '{key}': {e}")
                first_tensor = values[0]
                result[key] = first_tensor.repeat(len(batch), *[1 for _ in range(first_tensor.dim())])
        else:
            result[key] = values
            
    return result


def train_efficient(model, train_loader, val_loader, epochs=10, lr=0.001):
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=2, threshold=0.005, min_lr=1e-6, verbose=True
    )
    
    criterion = nn.BCEWithLogitsLoss()

    # Accumulate gradients for larger effective batch size
    accum_steps = 2  # Reduced from original to speed up training

    best_val_auc = 0
    best_model = None
    lr_history = []

    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        preds, labels = [], []

        current_lr = optimizer.param_groups[0]['lr']
        lr_history.append(current_lr)
        
        for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")):
            try:
                # Forward pass (without mixed precision to simplify)
                output = model(
                    batch["sentence_embeds"].to(device),
                    batch["full_text_embeds"].to(device),
                    batch["style_features"].to(device),
                )
                loss = criterion(output.squeeze(), batch["label"].to(device))
                loss = loss / accum_steps

                # Backward pass
                loss.backward()

                # Gradient accumulation
                if (i + 1) % accum_steps == 0 or (i + 1) == len(train_loader):
                    optimizer.step()
                    optimizer.zero_grad()

                running_loss += loss.item() * accum_steps
                sigmoid_preds = torch.sigmoid(output).squeeze().detach().cpu().numpy()
                preds.extend(sigmoid_preds)
                labels.extend(batch["label"].cpu().numpy())
            except Exception as e:
                print(f"Error in batch {i}: {e}")
                torch.cuda.empty_cache()
                optimizer.zero_grad()
                continue

        # Validation with less frequency to speed up training
        if (epoch + 1) % 2 == 0 or epoch == 0 or epoch == epochs - 1:
            model.eval()
            val_preds, val_labels = [], []

            with torch.no_grad():
                for batch in tqdm(val_loader, desc="Validation"):
                    try:
                        output = model(
                            batch["sentence_embeds"].to(device),
                            batch["full_text_embeds"].to(device),
                            batch["style_features"].to(device),
                        )
                        sigmoid_preds = torch.sigmoid(output).squeeze().cpu().numpy()
                        val_preds.extend(sigmoid_preds)
                        val_labels.extend(batch["label"].cpu().numpy())
                    except Exception as e:
                        print(f"Error in validation batch: {e}")
                        continue

            # Calculate metrics
            if len(val_preds) > 0 and len(val_labels) > 0:
                val_auc = roc_auc_score(val_labels, val_preds)
                val_preds_binary = [1 if p > 0.5 else 0 for p in val_preds]
                val_acc = accuracy_score(val_labels, val_preds_binary)
                val_f1 = f1_score(val_labels, val_preds_binary)

                # Update learning rate
                scheduler.step(val_auc)

                # Save best model
                if val_auc > best_val_auc:
                    best_val_auc = val_auc
                    best_model = copy.deepcopy(model.state_dict())
                    torch.save(best_model, os.path.join(KAGGLE_OUTPUT_DIR, "best_model.pt"))

                print(f"Epoch {epoch+1}/{epochs}:")
                print(f"  Learning Rate: {current_lr:.6f}")
                print(f"  Train Loss: {running_loss/max(1, len(train_loader)):.4f}")
                print(f"  Val AUC: {val_auc:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
            else:
                print(f"Epoch {epoch+1}/{epochs}: No valid predictions for validation")
    
    # Save learning rate history
    lr_history_path = os.path.join(KAGGLE_OUTPUT_DIR, "lr_history.json")
    with open(lr_history_path, 'w') as f:
        json.dump(lr_history, f)

    # Load best model
    if best_model is not None:
        model.load_state_dict(best_model)
    return model


class TextDataset(Dataset):
    def __init__(self, texts, labels, embedding_generator, style_extractor, max_sentences=16):
        self.texts = texts
        self.labels = labels
        self.embedding_generator = embedding_generator
        self.style_extractor = style_extractor
        self.max_sentences = max_sentences
        
        # Cache computations
        self.cache = {}
        self.feature_dim = style_extractor.get_feature_dim()
        
    def __len__(self):
        return len(self.texts)
        
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Check cache first
        if idx in self.cache:
            return self.cache[idx]
        
        try:    
            # Get sentence embeddings (with reduced max_sentences)
            sentence_embeds = self.embedding_generator.get_sentence_embeddings(
                text, max_sentences=self.max_sentences
            )
            
            # Get full text embedding
            full_text_embed = self.embedding_generator.get_embedding(text)
            
            # Get stylometric features
            style_features, _ = self.style_extractor.extract_features(text)
            
            # Verify dimensions
            if style_features.shape[0] != self.feature_dim:
                print(f"WARNING: Style feature dimension mismatch for idx {idx}")
                if style_features.shape[0] < self.feature_dim:
                    style_features = np.pad(style_features, (0, self.feature_dim - style_features.shape[0]))
                else:
                    style_features = style_features[:self.feature_dim]
            
            item = {
                "sentence_embeds": torch.tensor(sentence_embeds, dtype=torch.float32),
                "full_text_embeds": torch.tensor(full_text_embed, dtype=torch.float32),
                "style_features": torch.tensor(style_features, dtype=torch.float32),
                "label": torch.tensor(label, dtype=torch.float32),
                "text": text
            }
            
            self.cache[idx] = item
            return item
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            # Return a dummy item with zero tensors
            dummy_item = {
                "sentence_embeds": torch.zeros((self.max_sentences, self.embedding_generator.output_dim), dtype=torch.float32),
                "full_text_embeds": torch.zeros(self.embedding_generator.output_dim, dtype=torch.float32),
                "style_features": torch.zeros(self.feature_dim, dtype=torch.float32),
                "label": torch.tensor(label, dtype=torch.float32),
                "text": text
            }
            return dummy_item


class LightAITextDetector:
    def __init__(self, model_path=None, embedding_generator_path=None, feature_keys_path=None):
        self.style_extractor = None
        self.embedding_generator = None
        self.model = None
        self.feature_keys = None
        
        if model_path and embedding_generator_path and feature_keys_path:
            self.load(model_path, embedding_generator_path, feature_keys_path)
    
    def initialize_components(self, max_features=5000, embed_dim=32, hidden_dim=64):
        """Initialize all components of the detector"""
        print("Initializing LightStyleFeatureExtractor...")
        self.style_extractor = LightStyleFeatureExtractor()
        
        print("Initializing LightEmbeddingGenerator...")
        self.embedding_generator = LightEmbeddingGenerator(max_features=max_features, output_dim=embed_dim)
        
        # Get feature dimension
        style_feature_dim = self.style_extractor.get_feature_dim()
        print(f"Style feature dimension: {style_feature_dim}")
        
        print("Initializing LightDetector model...")
        self.model = LightDetector(
            embed_dim=embed_dim,
            style_feature_dim=style_feature_dim,
            hidden_dim=hidden_dim
        ).to(device)
        
        # Get feature keys
        _, self.feature_keys = self.style_extractor.extract_features("Sample text.")
    
    def load_data(self, csv_path, test_size=0.2, val_size=0.1, max_samples=None):
        """Load and preprocess data from CSV file"""
        print(f"Loading data from {csv_path}...")
        df = pd.read_csv(csv_path)
        
        # Handle different column names
        if 'text' not in df.columns or 'ai' not in df.columns:
            # Try to infer columns
            text_col = None
            label_col = None
            
            # Look for text column
            text_candidates = ['text', 'content', 'document', 'passage']
            for col in text_candidates:
                if col in df.columns:
                    text_col = col
                    break
            
            # Look for label column
            label_candidates = ['ai', 'label', 'is_ai', 'generated', 'is_generated']
            for col in label_candidates:
                if col in df.columns:
                    label_col = col
                    break
            
            if text_col is None or label_col is None:
                print(f"Could not identify text and label columns. Available columns: {df.columns}")
                raise ValueError("CSV must contain 'text' and 'ai' columns")
            
            # Rename columns
            df = df.rename(columns={text_col: 'text', label_col: 'ai'})
        
        # Ensure text column is string type
        df['text'] = df['text'].astype(str)
        
        # Filter out very short texts
        df = df[df['text'].str.len() > 20].reset_index(drop=True)
        
        # Limit samples if specified
        if max_samples and len(df) > max_samples:
            df = df.sample(max_samples, random_state=42).reset_index(drop=True)
        
        # Split into train, validation, and test sets
        train_df, temp_df = train_test_split(df, test_size=test_size+val_size, random_state=42)
        val_df, test_df = train_test_split(temp_df, test_size=test_size/(test_size+val_size), random_state=42)
        
        print(f"Data split: Train={len(train_df)}, Validation={len(val_df)}, Test={len(test_df)}")
        
        return train_df, val_df, test_df
    
    def create_datasets(self, train_df, val_df, test_df, batch_size=32, max_sentences=16, num_workers=0):
        """Create PyTorch datasets and dataloaders"""
        print("Creating datasets...")
        
        # Fit the TF-IDF vectorizer on training data
        self.embedding_generator.fit(train_df['text'].tolist())
        
        train_dataset = TextDataset(
            train_df['text'].tolist(),
            train_df['ai'].tolist(),
            self.embedding_generator,
            self.style_extractor,
            max_sentences=max_sentences
        )
        
        val_dataset = TextDataset(
            val_df['text'].tolist(),
            val_df['ai'].tolist(),
            self.embedding_generator,
            self.style_extractor,
            max_sentences=max_sentences
        )
        
        test_dataset = TextDataset(
            test_df['text'].tolist(),
            test_df['ai'].tolist(),
            self.embedding_generator,
            self.style_extractor,
            max_sentences=max_sentences
        )
        
        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=custom_collate_fn
        )
        
        return train_loader, val_loader, test_loader
    
    def train(self, train_loader, val_loader, epochs=10, lr=0.001):
        """Train the model"""
        print("Training model...")
        self.model = train_efficient(
            self.model,
            train_loader,
            val_loader,
            epochs=epochs,
            lr=lr
        )
    
    def evaluate(self, test_loader):
        """Evaluate the model on test data"""
        print("Evaluating model...")
        self.model.eval()
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Testing"):
                try:
                    output = self.model(
                        batch["sentence_embeds"].to(device),
                        batch["full_text_embeds"].to(device),
                        batch["style_features"].to(device),
                    )
                    # Apply sigmoid to get probabilities
                    sigmoid_preds = torch.sigmoid(output).squeeze().cpu().numpy()
                    all_preds.extend(sigmoid_preds)
                    all_labels.extend(batch["label"].cpu().numpy())
                except Exception as e:
                    print(f"Error in test batch: {e}")
                    continue
        
        # Calculate metrics
        if len(all_preds) > 0 and len(all_labels) > 0:
            all_preds_binary = [1 if p > 0.5 else 0 for p in all_preds]
            
            results = {
                "accuracy": accuracy_score(all_labels, all_preds_binary),
                "auc": roc_auc_score(all_labels, all_preds),
                "f1": f1_score(all_labels, all_preds_binary),
                "precision": precision_score(all_labels, all_preds_binary),
                "recall": recall_score(all_labels, all_preds_binary)
            }
            
            print("Test Results:")
            for metric, value in results.items():
                print(f"  {metric.capitalize()}: {value:.4f}")
                
            return results
        else:
            print("No valid predictions for testing")
            return None
    
    def predict_proba(self, text):
        """
        Predict the probability of text being AI-generated
        """
        # Ensure components are loaded
        if self.model is None or self.embedding_generator is None or self.style_extractor is None:
            raise ValueError("Model components not initialized. Load a trained model first.")
        
        self.model.eval()
        with torch.no_grad():
            # Get sentence embeddings
            sentence_embeds = self.embedding_generator.get_sentence_embeddings(text)
            sentence_embeds = torch.tensor(sentence_embeds, dtype=torch.float32).unsqueeze(0).to(device)
            
            # Get full text embedding
            full_text_embed = self.embedding_generator.get_embedding(text)
            full_text_embed = torch.tensor(full_text_embed, dtype=torch.float32).unsqueeze(0).to(device)
            
            # Get stylometric features
            style_features, _ = self.style_extractor.extract_features(text)
            style_features = torch.tensor(style_features, dtype=torch.float32).unsqueeze(0).to(device)
            
            # Get prediction
            logits = self.model(sentence_embeds, full_text_embed, style_features)
            # Apply sigmoid to get probability
            probability = torch.sigmoid(logits).item()
        
        return probability
    
    def save(self, output_dir=KAGGLE_OUTPUT_DIR):
        """Save all components necessary for inference"""
        if self.model is None or self.embedding_generator is None or self.style_extractor is None:
            raise ValueError("Model components not initialized.")
            
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        
        # Save model
        model_path = os.path.join(output_dir, "light_detector.pt")
        self.model.save(model_path)
        print(f"Model saved to {model_path}")
        
        # Save embedding generator
        embedding_generator_path = os.path.join(output_dir, "light_embedding_generator")
        self.embedding_generator.save(embedding_generator_path)
        print(f"Embedding generator saved to {embedding_generator_path}")
        
        # Save feature keys
        feature_keys_path = os.path.join(output_dir, "light_feature_keys.json")
        with open(feature_keys_path, 'w') as f:
            json.dump(self.feature_keys, f)
        print(f"Feature keys saved to {feature_keys_path}")
        
        return model_path, embedding_generator_path, feature_keys_path
    
    def load(self, model_path, embedding_generator_path, feature_keys_path):
        """Load all components necessary for inference"""
        # Load model
        print(f"Loading model from {model_path}...")
        self.model = LightDetector.load(model_path)
        self.model.to(device)
        
        # Load embedding generator
        print(f"Loading embedding generator from {embedding_generator_path}...")
        self.embedding_generator = LightEmbeddingGenerator.load(embedding_generator_path)
        
        # Load feature keys
        print(f"Loading feature keys from {feature_keys_path}...")
        with open(feature_keys_path, 'r') as f:
            self.feature_keys = json.load(f)
        
        # Initialize StyleFeatureExtractor
        print("Initializing LightStyleFeatureExtractor...")
        self.style_extractor = LightStyleFeatureExtractor()


def train_detector(csv_path, batch_size=32, epochs=8, lr=0.001, max_features=5000, embed_dim=32, hidden_dim=64, max_sentences=16):
    """Function to train the detector"""
    # Initialize detector
    detector = LightAITextDetector()
    detector.initialize_components(max_features=max_features, embed_dim=embed_dim, hidden_dim=hidden_dim)
    
    # Load and preprocess data
    train_df, val_df, test_df = detector.load_data(csv_path)
    
    # Create datasets and dataloaders
    train_loader, val_loader, test_loader = detector.create_datasets(
        train_df, val_df, test_df, 
        batch_size=batch_size,
        max_sentences=max_sentences,
        num_workers=0  # Use 0 workers for Kaggle
    )
    
    # Train model
    detector.train(train_loader, val_loader, epochs=epochs, lr=lr)
    
    # Evaluate model
    results = detector.evaluate(test_loader)
    
    # Save model for future inference
    model_path, embedding_generator_path, feature_keys_path = detector.save()
    
    # Save results to a file
    if results:
        results_path = os.path.join(KAGGLE_OUTPUT_DIR, "results.json")
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)
        
        print(f"Results saved to {results_path}")
    
    print(f"Model saved to {model_path}")
    
    return detector


def predict_single_text(text, model_dir=KAGGLE_OUTPUT_DIR):
    """
    Predict the probability of a single text being AI-generated
    
    Args:
        text (str): Text to classify
        model_dir (str): Directory containing saved model files
        
    Returns:
        dict: Prediction results including probability, prediction label, and confidence
    """
    # Construct paths to saved files
    model_path = os.path.join(model_dir, "light_detector.pt")
    embedding_generator_path = os.path.join(model_dir, "light_embedding_generator")
    feature_keys_path = os.path.join(model_dir, "light_feature_keys.json")
    
    # Check if files exist
    if not os.path.exists(model_path) or not os.path.exists(feature_keys_path):
        raise FileNotFoundError(f"Model files not found. Please train the model first.")
    
    # Load detector
    detector = LightAITextDetector(model_path, embedding_generator_path, feature_keys_path)
    
    # Get prediction
    probability = detector.predict_proba(text)
    
    result = {
        "probability_ai_generated": probability,
        "prediction": "AI-generated" if probability > 0.5 else "Human-written",
        "confidence": max(probability, 1 - probability)
    }
    
    return result

# ============================================================================
# Configuration
# ============================================================================

# Path to your dataset (CSV file with 'text' and 'ai' columns)
# 'ai' column should contain 0 for human-written, 1 for AI-generated
CSV_PATH = "/kaggle/input/combined-dataset-ai-human/combined_data.csv"

# Training parameters (optimized for faster training)
BATCH_SIZE = 128       # Larger batch size for faster epochs
EPOCHS = 25            # Reduced number of epochs
LEARNING_RATE = 0.001  # Slightly higher learning rate

# Model parameters (lighter model)
MAX_FEATURES = 5000    # Number of TF-IDF features
EMBED_DIM = 32         # Reduced embedding dimension
HIDDEN_DIM = 64        # Reduced hidden dimension
MAX_SENTENCES = 16     # Reduced from 32 to 16 maximum sentences

# Example text for testing
EXAMPLE_TEXT = """
This is a sample text to demonstrate the AI text detector.
It can classify whether this text was written by a human or generated by an AI.
The detector uses stylometric analysis and TF-IDF embeddings to make its determination.
"""

# Mode: "train", "predict", or "both"
MODE = "both"

# ============================================================================
# Main Execution
# ============================================================================

def main():
    global CSV_PATH
    # Print configuration
    print("=" * 80)
    print("Lightweight AI Text Detector - Configuration")
    print("=" * 80)
    print(f"CSV Path: {CSV_PATH}")
    print(f"Batch Size: {BATCH_SIZE}")
    print(f"Epochs: {EPOCHS}")
    print(f"Learning Rate: {LEARNING_RATE}")
    print(f"Max TF-IDF Features: {MAX_FEATURES}")
    print(f"Embedding Dimension: {EMBED_DIM}")
    print(f"Hidden Dimension: {HIDDEN_DIM}")
    print(f"Maximum Sentences: {MAX_SENTENCES}")
    print(f"Mode: {MODE}")
    print("=" * 80)
    
    # Install necessary packages if needed
    try:
        import spacy
    except ImportError:
        print("Installing required packages...")
        import subprocess
        subprocess.call(["pip", "install", "spacy"])
        subprocess.call(["python", "-m", "spacy", "download", "en_core_web_sm"])
        print("Packages installed.")
    
    # Check if model already exists
    model_path = os.path.join(KAGGLE_OUTPUT_DIR, "light_detector.pt")
    embedding_generator_path = os.path.join(KAGGLE_OUTPUT_DIR, "light_embedding_generator")
    feature_keys_path = os.path.join(KAGGLE_OUTPUT_DIR, "light_feature_keys.json")
    
    # Use auto-discovery to find CSV files if path not specified
    if CSV_PATH == "/kaggle/input/combined-dataset-ai-human/combined_data.csv":
        import glob
        csv_files = glob.glob("/kaggle/input/combined-dataset-ai-human/combined_data.csv")
        if csv_files:
            CSV_PATH = csv_files[0]
            print(f"Auto-discovered CSV file: {CSV_PATH}")
        else:
            print("No CSV files found in /kaggle/input/. Using a small synthetic dataset for demonstration.")
            # Create a small synthetic dataset
            synthetic_df = pd.DataFrame({
                'text': [
                    "This is a human written text. It has some style and personality.",
                    "The AI generated this text. It's somewhat different from human text.",
                    "Humans write with varying sentence lengths and unique expressions.",
                    "AI models produce text with patterns that can be detected by our system."
                ],
                'ai': [0, 1, 0, 1]
            })
            CSV_PATH = os.path.join(KAGGLE_OUTPUT_DIR, "synthetic_dataset.csv")
            synthetic_df.to_csv(CSV_PATH, index=False)
            print(f"Created synthetic dataset at {CSV_PATH}")
    
    detector = None
    
    # Training
    if MODE in ["train", "both"]:
        print("\nStarting training...")
        detector = train_detector(
            csv_path=CSV_PATH,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            lr=LEARNING_RATE,
            max_features=MAX_FEATURES,
            embed_dim=EMBED_DIM,
            hidden_dim=HIDDEN_DIM,
            max_sentences=MAX_SENTENCES
        )
    
    # Prediction
    if MODE in ["predict", "both"]:
        if all(os.path.exists(path) for path in [model_path, feature_keys_path]):
            print("\nPredicting on example text:")
            print("-" * 50)
            print(EXAMPLE_TEXT)
            print("-" * 50)
            
            result = predict_single_text(EXAMPLE_TEXT)
            
            print("\nPrediction Results:")
            print(f"Classification: {result['prediction']}")
            print(f"Probability of being AI-generated: {result['probability_ai_generated']:.4f}")
            print(f"Confidence: {result['confidence']:.4f}")
            
            # Example of how to use for custom text
            print("\nTo predict on your own text, use:")
            print("result = predict_single_text(your_text)")
        else:
            print("\nNo trained model found. Please run with MODE='train' first.")

# Run the main function
if __name__ == "__main__":
    main()

Using device: cuda
Lightweight AI Text Detector - Configuration
CSV Path: /kaggle/input/combined-dataset-ai-human/combined_data.csv
Batch Size: 128
Epochs: 25
Learning Rate: 0.001
Max TF-IDF Features: 5000
Embedding Dimension: 32
Hidden Dimension: 64
Maximum Sentences: 16
Mode: both
Auto-discovered CSV file: /kaggle/input/combined-dataset-ai-human/combined_data.csv

Starting training...
Initializing LightStyleFeatureExtractor...
Initializing LightEmbeddingGenerator...
Style feature dimension: 17
Initializing LightDetector model...
Loading data from /kaggle/input/combined-dataset-ai-human/combined_data.csv...
Data split: Train=32918, Validation=4702, Test=9406
Creating datasets...
Fitting TF-IDF vectorizer...
Training model...


Epoch 1/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 1/25:
  Learning Rate: 0.001000
  Train Loss: 0.5182
  Val AUC: 0.8795, Acc: 0.8086, F1: 0.8181


Epoch 2/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 2/25:
  Learning Rate: 0.001000
  Train Loss: 0.4325
  Val AUC: 0.9003, Acc: 0.8316, F1: 0.8295


Epoch 3/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 4/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 4/25:
  Learning Rate: 0.001000
  Train Loss: 0.4095
  Val AUC: 0.9031, Acc: 0.8290, F1: 0.8390


Epoch 5/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 6/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 6/25:
  Learning Rate: 0.001000
  Train Loss: 0.3907
  Val AUC: 0.9092, Acc: 0.8365, F1: 0.8367


Epoch 7/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 8/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 8/25:
  Learning Rate: 0.001000
  Train Loss: 0.3858
  Val AUC: 0.9119, Acc: 0.8409, F1: 0.8465


Epoch 9/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 10/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 10/25:
  Learning Rate: 0.001000
  Train Loss: 0.3789
  Val AUC: 0.9136, Acc: 0.8352, F1: 0.8293


Epoch 11/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 12/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 12/25:
  Learning Rate: 0.001000
  Train Loss: 0.3731
  Val AUC: 0.9162, Acc: 0.8435, F1: 0.8503


Epoch 13/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 14/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 14/25:
  Learning Rate: 0.001000
  Train Loss: 0.3697
  Val AUC: 0.9178, Acc: 0.8450, F1: 0.8451


Epoch 15/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 16/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 16/25:
  Learning Rate: 0.001000
  Train Loss: 0.3686
  Val AUC: 0.9171, Acc: 0.8467, F1: 0.8479


Epoch 17/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 18/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 18/25:
  Learning Rate: 0.001000
  Train Loss: 0.3639
  Val AUC: 0.9209, Acc: 0.8469, F1: 0.8481


Epoch 19/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 20/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 20/25:
  Learning Rate: 0.001000
  Train Loss: 0.3647
  Val AUC: 0.9196, Acc: 0.8473, F1: 0.8464


Epoch 21/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 22/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 22/25:
  Learning Rate: 0.001000
  Train Loss: 0.3555
  Val AUC: 0.9212, Acc: 0.8479, F1: 0.8477


Epoch 23/25:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch 24/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 24/25:
  Learning Rate: 0.001000
  Train Loss: 0.3508
  Val AUC: 0.9238, Acc: 0.8473, F1: 0.8436


Epoch 25/25:   0%|          | 0/258 [00:00<?, ?it/s]

Validation:   0%|          | 0/37 [00:00<?, ?it/s]

Epoch 25/25:
  Learning Rate: 0.000500
  Train Loss: 0.3485
  Val AUC: 0.9239, Acc: 0.8541, F1: 0.8556
Evaluating model...


Testing:   0%|          | 0/74 [00:00<?, ?it/s]

Test Results:
  Accuracy: 0.8448
  Auc: 0.9190
  F1: 0.8460
  Precision: 0.8359
  Recall: 0.8563
Model saved to /kaggle/working/light_detector.pt
Embedding generator saved to /kaggle/working/light_embedding_generator
Feature keys saved to /kaggle/working/light_feature_keys.json
Results saved to /kaggle/working/results.json
Model saved to /kaggle/working/light_detector.pt

Predicting on example text:
--------------------------------------------------

This is a sample text to demonstrate the AI text detector.
It can classify whether this text was written by a human or generated by an AI.
The detector uses stylometric analysis and TF-IDF embeddings to make its determination.

--------------------------------------------------
Loading model from /kaggle/working/light_detector.pt...
Loading embedding generator from /kaggle/working/light_embedding_generator...
Loading feature keys from /kaggle/working/light_feature_keys.json...
Initializing LightStyleFeatureExtractor...

Prediction Results: