In [None]:
# GAN-based AI Text Detection
# Complete implementation with PyTorch

import os
import re
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, lr_scheduler
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
from tqdm.notebook import tqdm
import spacy
import random
from collections import Counter
import nltk
from nltk.tokenize import word_tokenize
from sklearn.manifold import TSNE
from transformers import AutoTokenizer, AutoModel
import warnings
import joblib

# Suppress warnings
# NOTE: 'ignore' with no category silences every warning for the whole process,
# including deprecation and numerical warnings raised by the libraries above.
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
# The same seed is fed to Python's `random`, NumPy and PyTorch (CPU + all GPUs).
# SEED is also reused further down for train_test_split and t-SNE.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True

# Check for GPU availability
# `device` is a module-level global; main() reads it directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Download required NLTK resources
# word_tokenize is called in TextPreprocessor.quality_filter (presumably the
# reason for 'punkt'). No NLTK POS-tagging call is visible in this file, so the
# tagger download may be unused here -- TODO confirm.
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

# Load spaCy model for NLP tasks
# `nlp` is a module-level global used by TextPreprocessor for sentence
# splitting, POS tags and punctuation flags. If the model package is missing,
# spacy.load raises OSError; it is then downloaded once and loaded again.
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    print("Downloading spaCy model...")
    spacy.cli.download('en_core_web_sm')
    nlp = spacy.load('en_core_web_sm')

# Central place for every tunable value used by the pipeline
class Config:
    """Hyperparameters and output locations for preprocessing, model and training.

    Creating an instance has one side effect: the directory named by
    ``model_save_path`` is created when it does not exist yet.
    """

    def __init__(self):
        # -- Data processing --
        self.max_length = 512        # tokenizer truncation / padding length
        self.masking_ratio = 0.15    # chance of masking an eligible content word
        self.min_text_length = 10    # minimum token count kept by the quality filter

        # -- Model architecture --
        self.base_model_name = "distilroberta-base"
        self.embedding_dim = 768                       # input width of the generator
        self.generator_hidden_dims = [512, 256, 128]
        self.discriminator_hidden_dims = [128, 64, 32]
        self.style_feature_dim = 32                    # width of the processed stylometric vector
        self.dropout = 0.2

        # -- Training --
        self.batch_size = 16
        self.num_epochs = 10
        self.gen_lr = 2e-4
        self.disc_lr = 1e-4
        self.disc_steps = 5          # discriminator batches per generator batch
        self.scheduler_patience = 2  # passed to ReduceLROnPlateau
        self.scheduler_factor = 0.5  # passed to ReduceLROnPlateau

        # -- Paths --
        save_dir = "models/"
        self.model_save_path = save_dir
        self.best_model_path = os.path.join(save_dir, "best_gan_detector.pt")
        self.stylometric_features_path = os.path.join(
            save_dir, "stylometric_processor.joblib"
        )

        # Make sure the output directory is there before anything is saved
        os.makedirs(save_dir, exist_ok=True)


# Module-wide default settings instance
config = Config()

# 1. Dataset Preprocessing

class TextPreprocessor:
    """Clean, filter, mask and featurise raw texts for the detector.

    The instance owns a Hugging Face tokenizer (loaded from
    ``config.base_model_name``) and, once fitted, the mean/std statistics used
    to z-score the stylometric features. It relies on the module-level spaCy
    pipeline ``nlp`` for sentence splitting and POS tags.
    """

    # Length of the vector produced by ``extract_stylometric_features``.
    NUM_STYLE_FEATURES = 5

    def __init__(self, config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
        # Define content words POS tags to mask
        self.content_pos = {'NOUN', 'ADJ', 'PROPN'}
        # Mean/std of the stylometric features; set by fit_transform_stylometric
        self.feature_stats = None

    def normalize_text(self, text):
        """Lower-case *text*, collapse whitespace and drop unusual symbols.

        Returns an empty string for non-string or empty input.
        """
        if not isinstance(text, str) or len(text) == 0:
            return ""

        # Convert to lowercase and trim the ends
        text = text.lower().strip()
        # Collapse every run of whitespace into a single space
        text = re.sub(r'\s+', ' ', text)
        # Remove special characters while preserving punctuation and structure
        text = re.sub(r'[^\w\s\.,;:!?\'"-]', '', text)
        return text

    def quality_filter(self, text):
        """Return True when *text* is long enough and mostly alphabetic.

        A text is rejected when it is not a string, has fewer than
        ``config.min_text_length`` NLTK tokens, or when fewer than half of its
        characters are ASCII letters.
        """
        if not isinstance(text, str):
            return False

        # Check minimum length
        tokens = word_tokenize(text)
        if len(tokens) < self.config.min_text_length:
            return False

        # Check for abnormal character distribution
        char_counts = Counter(text.lower())
        alpha_ratio = sum(char_counts[c] for c in 'abcdefghijklmnopqrstuvwxyz') / len(text) if len(text) > 0 else 0
        if alpha_ratio < 0.5:  # Less than 50% alphabetic characters
            return False

        return True

    def extract_sentences(self, text):
        """Use spaCy to divide text into a list of sentence strings."""
        doc = nlp(text)
        return [sent.text for sent in doc.sents]

    def strategic_masking(self, text):
        """Randomly mask content words while preserving function words.

        Every non-punctuation token whose POS tag is in ``content_pos`` is
        replaced, with probability ``config.masking_ratio``, by the tokenizer's
        own mask token. Tokens are re-joined with single spaces.
        """
        # Use the tokenizer's mask token instead of a hard-coded BERT-style
        # "[MASK]": RoBERTa-family tokenizers use "<mask>", so the literal
        # "[MASK]" would be split into ordinary sub-word pieces. Fall back to
        # "[MASK]" only when the tokenizer does not define one.
        mask_token = getattr(getattr(self, 'tokenizer', None), 'mask_token', None) or "[MASK]"

        doc = nlp(text)
        masked_tokens = []

        for token in doc:
            # Punctuation is never masked; the random draw only happens for
            # eligible content words (short-circuit keeps RNG usage unchanged).
            if (
                not token.is_punct
                and token.pos_ in self.content_pos
                and random.random() < self.config.masking_ratio
            ):
                masked_tokens.append(mask_token)
            else:
                masked_tokens.append(token.text)

        return " ".join(masked_tokens)

    def extract_stylometric_features(self, text):
        """Return a 5-element stylometric vector for *text*.

        Order: type-token ratio, mean sentence length, std of sentence length,
        noun ratio, function-word ratio. Sentence lengths are spaCy token
        counts. Non-string, empty or sentence-less input yields zeros.
        """
        if not isinstance(text, str) or len(text) == 0:
            return np.zeros(self.NUM_STYLE_FEATURES)

        doc = nlp(text)
        sentences = list(doc.sents)

        # Skip very short texts
        if len(sentences) == 0:
            return np.zeros(self.NUM_STYLE_FEATURES)

        # Token and type counts (punctuation excluded)
        tokens = [token.text.lower() for token in doc if not token.is_punct]
        types = set(tokens)

        # Type-token ratio
        ttr = len(types) / len(tokens) if tokens else 0

        # Sentence length distribution
        sent_lengths = [len(sent) for sent in sentences]
        avg_sent_length = np.mean(sent_lengths) if sent_lengths else 0
        std_sent_length = np.std(sent_lengths) if len(sent_lengths) > 1 else 0

        # POS tag distributions (ratios are relative to all tokens in the doc)
        pos_counts = Counter([token.pos_ for token in doc])
        noun_ratio = pos_counts.get('NOUN', 0) / len(doc) if len(doc) > 0 else 0
        func_word_ratio = sum(pos_counts.get(pos, 0) for pos in ['DET', 'ADP', 'CONJ', 'CCONJ', 'SCONJ']) / len(doc) if len(doc) > 0 else 0

        # Create feature vector
        features = np.array([
            ttr,
            avg_sent_length,
            std_sent_length,
            noun_ratio,
            func_word_ratio
        ])

        return features

    def _stylometric_matrix(self, texts):
        """Stack the raw stylometric vectors of *texts* into an (n, 5) array."""
        rows = [
            self.extract_stylometric_features(text)
            for text in tqdm(texts, desc="Extracting stylometric features")
        ]
        # reshape keeps the result 2-D even when there are no rows
        return np.array(rows, dtype=float).reshape(-1, self.NUM_STYLE_FEATURES)

    def fit_transform_stylometric(self, texts):
        """Fit mean/std on *texts* and return their z-scored features.

        Raises:
            ValueError: if *texts* is empty (statistics would be NaN and would
                poison every later ``transform_stylometric`` call).
        """
        texts = list(texts)
        if not texts:
            raise ValueError("Cannot fit stylometric statistics on an empty collection of texts")

        features_array = self._stylometric_matrix(texts)

        # Calculate mean and std for normalization
        std = np.std(features_array, axis=0)
        self.feature_stats = {
            'mean': np.mean(features_array, axis=0),
            # Avoid division by zero for constant features
            'std': np.where(std == 0, 1, std)
        }

        # Normalize
        return (features_array - self.feature_stats['mean']) / self.feature_stats['std']

    def transform_stylometric(self, texts):
        """Z-score the features of *texts* using previously fitted statistics.

        Raises:
            ValueError: if ``fit_transform_stylometric`` has not been run.
        """
        if self.feature_stats is None:
            raise ValueError("Stylometric feature processor has not been fitted yet")

        features_array = self._stylometric_matrix(texts)

        # Normalize using fitted stats
        return (features_array - self.feature_stats['mean']) / self.feature_stats['std']

    def preprocess_dataset(self, df, fit_stylometric=True):
        """Normalise, filter, mask and featurise every row of *df*.

        The text column is ``'text'`` if present, else the first column; the
        label column is ``'label'`` if present, else the second column, else
        every label defaults to 0. After processing, the preprocessor itself is
        pickled with joblib to ``<model_save_path>/text_preprocessor.joblib``.

        Args:
            df: pandas DataFrame holding the raw texts (and optionally labels).
            fit_stylometric: fit new normalisation statistics when True,
                otherwise reuse the ones already stored on the instance.

        Returns:
            Tuple ``(texts, masked_texts, labels, stylometric_features)`` where
            ``labels`` is a NumPy array and the features are z-scored.

        Raises:
            ValueError: if fitting is requested but no text survives the
                quality filter, or if transforming before fitting.
        """
        print("Preprocessing dataset...")

        # Ensure text column exists
        text_col = 'text' if 'text' in df.columns else df.columns[0]
        if 'label' in df.columns:
            label_col = 'label'
        elif len(df.columns) > 1:
            label_col = df.columns[1]
        else:
            label_col = None

        # `is not None` so that a falsy column name (e.g. 0) still counts
        raw_labels = df[label_col] if label_col is not None else [0] * len(df)

        # Normalize and filter texts
        texts = []
        labels = []

        for text, label in tqdm(zip(df[text_col], raw_labels), total=len(df), desc="Cleaning texts"):
            normalized_text = self.normalize_text(text)
            if self.quality_filter(normalized_text):
                texts.append(normalized_text)
                labels.append(label)

        # Apply masking to texts
        masked_texts = [self.strategic_masking(text) for text in tqdm(texts, desc="Applying masking")]

        # Extract stylometric features
        if fit_stylometric:
            stylometric_features = self.fit_transform_stylometric(texts)
        else:
            stylometric_features = self.transform_stylometric(texts)

        # Save preprocessor for later use (predict_single_text loads this file)
        joblib.dump(self, os.path.join(self.config.model_save_path, "text_preprocessor.joblib"))

        return texts, masked_texts, np.array(labels), stylometric_features

# 2. Dataset Class

class TextDataset(Dataset):
    """Map-style dataset that tokenizes the masked texts on access.

    Args:
        texts: original (unmasked) texts; only used for ``__len__``.
        masked_texts: masked version of each text, fed to the tokenizer.
        labels: per-sample class labels, converted to ``torch.long``.
        stylometric_features: per-sample feature vectors, converted to float.
        tokenizer: callable returning ``input_ids`` / ``attention_mask`` with a
            leading batch dimension of 1 (called with ``return_tensors='pt'``).
        max_length: padding / truncation length passed to the tokenizer.
    """

    def __init__(self, texts, masked_texts, labels, stylometric_features, tokenizer, max_length):
        self.texts = texts
        self.masked_texts = masked_texts
        self.labels = labels
        self.stylometric_features = stylometric_features
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tensors for sample *idx* as a dict."""
        # Tokenize the masked text, padded/truncated to a fixed length
        encoding = self.tokenizer(
            self.masked_texts[idx],
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # Drop only the leading batch dimension. A bare squeeze() would also
        # collapse the sequence dimension when max_length == 1, producing 0-d
        # tensors that no longer batch like the rest.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'stylometric': torch.tensor(self.stylometric_features[idx], dtype=torch.float),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }

def prepare_dataloaders(texts, masked_texts, labels, stylometric_features, config):
    """Build stratified train / validation / test DataLoaders.

    The data is split 70 % / 15 % / 15 % in two stratified steps that both use
    the module-level ``SEED``. Only the training loader shuffles.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # train_test_split returns [a_train, a_test, b_train, b_test, ...], so the
    # even positions are the "train" halves and the odd ones the held-out halves.
    first_split = train_test_split(
        texts, masked_texts, labels, stylometric_features,
        test_size=0.3, stratify=labels, random_state=SEED
    )
    train_parts = first_split[0::2]
    held_out = first_split[1::2]

    # Halve the held-out 30 % into validation and test; index 2 holds the labels
    second_split = train_test_split(
        *held_out, test_size=0.5, stratify=held_out[2], random_state=SEED
    )
    val_parts = second_split[0::2]
    test_parts = second_split[1::2]

    # One tokenizer instance is shared by the three datasets
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)

    def build_dataset(parts):
        # parts = (texts, masked_texts, labels, stylometric_features)
        return TextDataset(*parts, tokenizer, config.max_length)

    train_loader = DataLoader(
        build_dataset(train_parts), batch_size=config.batch_size, shuffle=True
    )
    val_loader = DataLoader(build_dataset(val_parts), batch_size=config.batch_size)
    test_loader = DataLoader(build_dataset(test_parts), batch_size=config.batch_size)

    return train_loader, val_loader, test_loader

# 3. Model Architecture

class BaseEmbeddingModel(nn.Module):
    """Frozen pretrained transformer used as a fixed feature extractor.

    All encoder parameters have ``requires_grad`` disabled and the encoder is
    kept in evaluation mode at all times, so the embeddings it produces do not
    change with the train/eval state of the surrounding model.
    """

    def __init__(self, model_name):
        super(BaseEmbeddingModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)

        # Freeze parameters to prevent fine-tuning
        for param in self.model.parameters():
            param.requires_grad = False

        # A frozen encoder should not apply dropout either
        self.model.eval()

    def train(self, mode=True):
        """Set the mode on this wrapper but keep the frozen encoder in eval.

        Without this override a ``model.train()`` call on the parent module
        would re-enable dropout inside the encoder even though its weights are
        never updated, making the extracted features noisy during training.
        """
        super().train(mode)
        self.model.eval()
        return self

    def forward(self, input_ids, attention_mask):
        """Return the hidden state of the first token for each sequence."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # First position of the last hidden layer (the CLS-style token)
        cls_embedding = outputs.last_hidden_state[:, 0, :]

        return cls_embedding

class Generator(nn.Module):
    """MLP that projects sentence embeddings to a compact feature vector.

    For every width in ``config.generator_hidden_dims`` a block of
    Linear -> LayerNorm -> LeakyReLU(0.2) -> Dropout is appended to
    ``self.layers``; a final Linear keeps the last hidden width.
    """

    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config

        stack = []
        in_features = config.embedding_dim
        for out_features in config.generator_hidden_dims:
            stack.extend([
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.LeakyReLU(0.2),
                nn.Dropout(config.dropout),
            ])
            in_features = out_features

        self.layers = nn.ModuleList(stack)
        self.output_layer = nn.Linear(in_features, config.generator_hidden_dims[-1])

    def forward(self, x):
        """Run *x* through every block, then through the output projection."""
        previous_input = None

        # self.layers is a flat list; each block occupies 4 consecutive slots
        for start in range(0, len(self.layers), 4):
            # Skip connection: add the input of the previous block to its
            # output, but only when the two shapes happen to agree.
            if previous_input is not None and previous_input.size() == x.size():
                x = x + previous_input

            # Remember what goes into this block for the next iteration
            previous_input = x

            # Linear, LayerNorm, LeakyReLU, Dropout
            for sublayer in self.layers[start:start + 4]:
                x = sublayer(x)

        return self.output_layer(x)

class StyleFeatureProcessor(nn.Module):
    """Two-layer MLP (Linear -> BatchNorm -> ReLU -> Linear) for stylometric vectors."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Project *x* from ``input_dim`` to ``output_dim`` features."""
        hidden = self.fc1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        return self.fc2(hidden)

class SelfAttention(nn.Module):
    """Scaled dot-product self-attention applied to a single feature vector.

    NOTE: the input is treated as a sequence of length one, so the softmax is
    taken over a single score and always equals 1; the output is therefore the
    value projection of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        # sqrt(d) scaling factor, stored as a plain tensor attribute
        self.scale = torch.tensor(input_dim, dtype=torch.float32).sqrt()

    def forward(self, x):
        """Attend over *x* of shape (batch_size, input_dim); same shape out."""
        # Insert a sequence axis of length 1: (batch_size, 1, input_dim)
        sequence = x.unsqueeze(1)

        queries = self.query(sequence)
        keys = self.key(sequence)
        values = self.value(sequence)

        # (batch_size, 1, 1) similarity scores, normalised over the last axis
        weights = F.softmax(
            torch.matmul(queries, keys.transpose(-2, -1)) / self.scale, dim=-1
        )

        # Weighted sum of the values, then drop the sequence axis again
        return torch.matmul(weights, values).squeeze(1)

class Discriminator(nn.Module):
    """Classifier over generator features combined with stylometric features.

    Generator features go through ``SelfAttention``, the 5 stylometric values
    through ``StyleFeatureProcessor``; both results are concatenated and fed to
    a stack of Linear -> BatchNorm -> LeakyReLU(0.2) -> Dropout blocks, followed
    by a single-unit Linear layer and a sigmoid.
    """

    def __init__(self, config):
        super(Discriminator, self).__init__()
        self.config = config

        feature_dim = config.generator_hidden_dims[-1]

        # Attention over the generator output
        self.attention = SelfAttention(feature_dim)

        # Stylometric branch: 5 raw features -> style_feature_dim
        self.style_processor = StyleFeatureProcessor(
            5,
            config.style_feature_dim,
            config.style_feature_dim
        )

        # Hidden stack operating on the concatenated representation
        stack = []
        width = feature_dim + config.style_feature_dim
        for hidden_dim in config.discriminator_hidden_dims:
            stack += [
                nn.Linear(width, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(config.dropout),
            ]
            width = hidden_dim
        self.layers = nn.ModuleList(stack)

        # Single logit, squashed by a sigmoid in forward()
        self.output_layer = nn.Linear(width, 1)

    def forward(self, gen_features, stylometric_features):
        """Return a (batch, 1) tensor of probabilities in [0, 1]."""
        attended = self.attention(gen_features)
        style = self.style_processor(stylometric_features)

        # Join both branches along the feature axis
        x = torch.cat([attended, style], dim=1)

        # The flat ModuleList is applied strictly in order
        for layer in self.layers:
            x = layer(x)

        return torch.sigmoid(self.output_layer(x))

class GANTextDetector(nn.Module):
    """Full detector: frozen encoder -> generator -> discriminator."""

    def __init__(self, config):
        super(GANTextDetector, self).__init__()
        self.config = config

        # Sub-modules, created in pipeline order
        self.embedding_model = BaseEmbeddingModel(config.base_model_name)
        self.generator = Generator(config)
        self.discriminator = Discriminator(config)

    def forward(self, input_ids, attention_mask, stylometric_features):
        """Return ``(predictions, gen_features)`` for a tokenized batch."""
        gen_features = self.generator(
            self.embedding_model(input_ids, attention_mask)
        )
        predictions = self.discriminator(gen_features, stylometric_features)
        return predictions, gen_features

    def get_embeddings(self, input_ids, attention_mask, stylometric_features):
        """Return generator features without tracking gradients.

        ``stylometric_features`` is accepted but not used by this method.
        """
        with torch.no_grad():
            sentence_embeddings = self.embedding_model(input_ids, attention_mask)
            gen_features = self.generator(sentence_embeddings)

        return gen_features

# 4. Training Functions

def train_discriminator(model, batch, criterion, optimizer, device):
    """Run one optimisation step on *batch* against the true labels.

    Returns:
        The scalar loss of this step as a Python float.
    """
    # Move the model inputs to the target device
    token_ids, token_mask, style = (
        batch[key].to(device)
        for key in ('input_ids', 'attention_mask', 'stylometric')
    )
    # Targets as a float column so they line up with the model's output
    targets = batch['label'].to(device).float().unsqueeze(1)

    optimizer.zero_grad()

    probabilities, _ = model(token_ids, token_mask, style)
    step_loss = criterion(probabilities, targets)

    step_loss.backward()
    optimizer.step()

    return step_loss.item()

def train_generator(model, batch, criterion, optimizer, device):
    """Run one adversarial optimisation step for the generator on *batch*.

    The labels are inverted so that the generator is rewarded when the
    discriminator gets the sample wrong. Discriminator parameters are excluded
    from gradient computation for the duration of the step and are always put
    back into their previous state, even if the step raises.

    Returns:
        The scalar adversarial loss of this step as a Python float.
    """
    # Get batch data
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    stylometric = batch['stylometric'].to(device)
    labels = batch['label'].to(device).float().unsqueeze(1)

    # Invert labels for adversarial training
    inverted_labels = 1 - labels

    # Zero gradients
    optimizer.zero_grad()

    # Remember the current flags so that parameters which were already frozen
    # are not accidentally unfrozen afterwards.
    disc_params = list(model.discriminator.parameters())
    previous_flags = [param.requires_grad for param in disc_params]
    for param in disc_params:
        param.requires_grad = False

    try:
        # Forward pass (gradients only reach the generator)
        predictions, _ = model(input_ids, attention_mask, stylometric)

        # Calculate adversarial loss (fool the discriminator)
        loss = criterion(predictions, inverted_labels)

        # Backward pass and generator update
        loss.backward()
        optimizer.step()
    finally:
        # Re-enable discriminator gradients no matter how the step ended;
        # otherwise a failed step would leave the discriminator frozen.
        for param, flag in zip(disc_params, previous_flags):
            param.requires_grad = flag

    return loss.item()

def validate(model, val_loader, criterion, device):
    """Score *model* on every batch of *val_loader* without gradient tracking.

    The model is switched to evaluation mode and left there.

    Returns:
        Tuple ``(mean_loss, accuracy, auc, f1)``; accuracy and F1 use a 0.5
        decision threshold, AUC uses the raw probabilities.
    """
    model.eval()
    running_loss = 0
    scores = []
    targets = []

    with torch.no_grad():
        for batch in val_loader:
            batch_targets = batch['label'].to(device).float().unsqueeze(1)
            batch_scores, _ = model(
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
                batch['stylometric'].to(device),
            )

            running_loss += criterion(batch_scores, batch_targets).item()

            # Keep everything on the CPU for the sklearn metrics below
            scores.extend(batch_scores.cpu().numpy())
            targets.extend(batch_targets.cpu().numpy())

    # Average of the per-batch losses
    mean_loss = running_loss / len(val_loader)

    scores = np.array(scores).flatten()
    targets = np.array(targets).flatten()
    hard_predictions = (scores > 0.5).astype(int)

    accuracy = accuracy_score(targets, hard_predictions)
    auc = roc_auc_score(targets, scores)
    f1 = f1_score(targets, hard_predictions)

    return mean_loss, accuracy, auc, f1

def _plot_training_progress(train_losses, val_metrics, save_dir):
    """Plot per-epoch losses and validation metrics, save as PNG and show."""
    plt.figure(figsize=(15, 5))

    # Left panel: training losses
    plt.subplot(1, 2, 1)
    for key, legend in (('disc', 'Discriminator'), ('gen', 'Generator')):
        plt.plot(train_losses[key], label=legend)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Losses')
    plt.legend()

    # Right panel: validation metrics
    plt.subplot(1, 2, 2)
    for key, legend in (('loss', 'Loss'), ('acc', 'Accuracy'), ('auc', 'AUC'), ('f1', 'F1')):
        plt.plot(val_metrics[key], label=legend)
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.title('Validation Metrics')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_progress.png'))
    plt.show()


def train_gan_detector(model, train_loader, val_loader, config, device):
    """Train the detector by alternating discriminator and generator steps.

    Within an epoch each batch is used for exactly one kind of step:
    ``config.disc_steps`` discriminator batches are followed by one generator
    batch. After every epoch the model is validated, both learning-rate
    schedulers are stepped on the validation loss, and a checkpoint is written
    to ``config.best_model_path`` whenever the validation AUC improves.

    Returns:
        Tuple ``(model, train_losses, val_metrics)`` where the two dicts hold
        one entry per epoch. The returned model carries the weights of the last
        epoch, not necessarily those of the best checkpoint.
    """
    criterion = nn.BCELoss()

    # One optimiser per sub-network
    discriminator_optimizer = Adam(
        [p for p in model.discriminator.parameters() if p.requires_grad],
        lr=config.disc_lr
    )
    generator_optimizer = Adam(model.generator.parameters(), lr=config.gen_lr)

    # Both schedulers share the same plateau settings
    plateau_kwargs = dict(
        mode='min',
        factor=config.scheduler_factor,
        patience=config.scheduler_patience,
    )
    disc_scheduler = lr_scheduler.ReduceLROnPlateau(discriminator_optimizer, **plateau_kwargs)
    gen_scheduler = lr_scheduler.ReduceLROnPlateau(generator_optimizer, **plateau_kwargs)

    best_val_auc = 0
    train_losses = {'disc': [], 'gen': []}
    val_metrics = {'loss': [], 'acc': [], 'auc': [], 'f1': []}

    # Length of one discriminator/generator cycle, in batches
    cycle_length = config.disc_steps + 1

    for epoch in range(config.num_epochs):
        model.train()
        disc_total, disc_count = 0, 0
        gen_total, gen_count = 0, 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")

        for step, batch in enumerate(progress_bar):
            # The last batch of every cycle belongs to the generator
            generator_turn = step % cycle_length == config.disc_steps

            if not generator_turn:
                disc_loss = train_discriminator(
                    model, batch, criterion, discriminator_optimizer, device
                )
                disc_total += disc_loss
                disc_count += 1

                # Current D loss next to the running G average
                progress_bar.set_postfix({
                    'D_loss': f"{disc_loss:.4f}",
                    'G_loss': f"{gen_total/max(1, gen_count):.4f}"
                })
            else:
                gen_loss = train_generator(
                    model, batch, criterion, generator_optimizer, device
                )
                gen_total += gen_loss
                gen_count += 1

                # Running D average next to the current G loss
                progress_bar.set_postfix({
                    'D_loss': f"{disc_total/max(1, disc_count):.4f}",
                    'G_loss': f"{gen_loss:.4f}"
                })

        # Epoch averages; max(1, ...) guards against a step type never running
        avg_disc_loss = disc_total / max(1, disc_count)
        avg_gen_loss = gen_total / max(1, gen_count)
        train_losses['disc'].append(avg_disc_loss)
        train_losses['gen'].append(avg_gen_loss)

        val_loss, val_acc, val_auc, val_f1 = validate(
            model, val_loader, criterion, device
        )
        for key, value in (('loss', val_loss), ('acc', val_acc), ('auc', val_auc), ('f1', val_f1)):
            val_metrics[key].append(value)

        # Both learning rates react to the same validation loss
        disc_scheduler.step(val_loss)
        gen_scheduler.step(val_loss)

        print(f"Epoch {epoch+1}/{config.num_epochs}:")
        print(f"  Train - Disc Loss: {avg_disc_loss:.4f}, Gen Loss: {avg_gen_loss:.4f}")
        print(f"  Val - Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, AUC: {val_auc:.4f}, F1: {val_f1:.4f}")

        # Checkpoint on strict AUC improvement
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            print(f"  New best model with AUC: {val_auc:.4f}")

            torch.save({
                'model_state_dict': model.state_dict(),
                'config': config,
                'epoch': epoch,
                'val_auc': val_auc,
                'val_acc': val_acc,
                'val_f1': val_f1
            }, config.best_model_path)

    _plot_training_progress(train_losses, val_metrics, config.model_save_path)

    return model, train_losses, val_metrics

# 5. Evaluation Functions

def evaluate_model(model, test_loader, device):
    """Evaluate *model* on the test set, print the metrics and plot embeddings.

    Returns:
        Tuple ``(accuracy, auc, f1, embeddings, predictions, labels)`` where
        the last three are NumPy arrays covering the whole test set.
    """
    model.eval()
    scores = []
    targets = []
    features = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            batch_targets = batch['label'].to(device).float().unsqueeze(1)
            batch_scores, batch_features = model(
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
                batch['stylometric'].to(device),
            )

            # Collect everything on the CPU, one row per sample
            scores.extend(batch_scores.cpu().numpy())
            targets.extend(batch_targets.cpu().numpy())
            features.extend(batch_features.cpu().numpy())

    all_preds = np.array(scores).flatten()
    all_labels = np.array(targets).flatten()
    all_embeddings = np.array(features)

    # Hard decisions at the 0.5 threshold for accuracy and F1
    hard_predictions = (all_preds > 0.5).astype(int)
    accuracy = accuracy_score(all_labels, hard_predictions)
    auc = roc_auc_score(all_labels, all_preds)
    f1 = f1_score(all_labels, hard_predictions)

    print(f"Test Results:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  AUC-ROC: {auc:.4f}")
    print(f"  F1 Score: {f1:.4f}")

    # 2-D projection of the generator features, coloured by label
    visualize_embeddings(all_embeddings, all_labels)

    return accuracy, auc, f1, all_embeddings, all_preds, all_labels

def visualize_embeddings(embeddings, labels):
    """Project *embeddings* to 2-D with t-SNE and scatter-plot them by label.

    Label 0 is drawn as blue circles ("Human"), label 1 as red crosses ("AI").
    The figure is saved as ``embedding_visualization.png`` in the current
    working directory and then shown.

    Args:
        embeddings: array-like of shape (n_samples, n_features).
        labels: array-like of length n_samples with values 0 / 1.
    """
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    n_samples = len(embeddings)

    # t-SNE needs at least two points to define neighbourhoods
    if n_samples < 2:
        print("Skipping t-SNE visualization: at least 2 samples are required.")
        return

    # scikit-learn requires perplexity < n_samples. Keep the library default
    # of 30 whenever possible and only shrink it for small sample counts.
    perplexity = min(30.0, float(n_samples - 1))

    # Reduce dimensionality with t-SNE
    tsne = TSNE(n_components=2, random_state=SEED, perplexity=perplexity)
    reduced_embeddings = tsne.fit_transform(embeddings)

    # Plot embeddings
    plt.figure(figsize=(10, 8))
    for label, color, marker, name in zip(
        [0, 1],
        ['blue', 'red'],
        ['o', 'x'],
        ['Human', 'AI']
    ):
        mask = labels == label
        plt.scatter(
            reduced_embeddings[mask, 0],
            reduced_embeddings[mask, 1],
            c=color,
            marker=marker,
            label=name,
            alpha=0.7
        )

    plt.legend()
    plt.title('t-SNE Visualization of Text Embeddings')
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.savefig('embedding_visualization.png')
    plt.show()

# 6. Single Text Prediction Function

def predict_single_text(text, model_path=None, device=None):
    """
    Predict whether a single text is AI-generated or human-written

    The text goes through the same steps as the training data: normalisation,
    random content-word masking and stylometric feature extraction using the
    saved preprocessor. Because the masking is random, repeated calls on the
    same text can return slightly different probabilities.

    Args:
        text (str): The text to classify
        model_path (str, optional): Path to the model. Defaults to None (uses default path).
        device (torch.device, optional): Device to run inference on. Defaults to None
            (CUDA when available, otherwise CPU).

    Returns:
        float: Probability that the text is AI-generated
        str: Classification ("AI-generated" or "Human-written")

    Raises:
        FileNotFoundError: if the checkpoint or the saved preprocessor is missing.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # A default Config supplies the standard paths; it is replaced below by
    # the config stored in the checkpoint when there is one.
    config = Config()
    if model_path is None:
        model_path = config.best_model_path

    # Check if model exists
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at {model_path}. Please train the model first.")

    # Load preprocessor
    preprocessor_path = os.path.join(config.model_save_path, "text_preprocessor.joblib")
    if not os.path.exists(preprocessor_path):
        raise FileNotFoundError("Text preprocessor not found. Please train the model first.")

    # SECURITY NOTE: joblib.load and torch.load(weights_only=False) both
    # unpickle arbitrary Python objects. Only load files produced by this
    # pipeline, never artefacts from an untrusted source.
    preprocessor = joblib.load(preprocessor_path)

    # The checkpoint stores a pickled Config object next to the weights, which
    # the weights-only loader (the default since PyTorch 2.6) refuses to load.
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only argument
        checkpoint = torch.load(model_path, map_location=device)

    if 'config' in checkpoint:
        config = checkpoint['config']

    # Initialize model
    model = GANTextDetector(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Preprocess text
    normalized_text = preprocessor.normalize_text(text)
    masked_text = preprocessor.strategic_masking(normalized_text)
    stylometric_features = preprocessor.transform_stylometric([normalized_text])[0]

    # Tokenize
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name)
    encoding = tokenizer(
        masked_text,
        truncation=True,
        max_length=config.max_length,
        padding='max_length',
        return_tensors='pt'
    )

    # Move to device (the stylometric vector gets a batch dimension of 1)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    stylometric = torch.tensor(stylometric_features, dtype=torch.float).unsqueeze(0).to(device)

    # Make prediction
    with torch.no_grad():
        prediction, _ = model(input_ids, attention_mask, stylometric)
        probability = prediction.item()

    # Classify
    classification = "AI-generated" if probability > 0.5 else "Human-written"

    return probability, classification

# 7. Main Execution

def main(data_path, config):
    """Run the full pipeline: load, preprocess, train, evaluate, sample-predict.

    Args:
        data_path (str): Path to a ``.csv`` or ``.tsv`` file (extension is
            matched case-insensitively).
        config (Config): Settings used for preprocessing, model and training.

    Returns:
        The trained ``GANTextDetector`` (weights of the last epoch).

    Raises:
        ValueError: if *data_path* has an unsupported extension.
    """
    # Load data
    lowered_path = data_path.lower()
    if lowered_path.endswith('.csv'):
        df = pd.read_csv(data_path)
    elif lowered_path.endswith('.tsv'):
        df = pd.read_csv(data_path, sep='\t')
    else:
        raise ValueError("Unsupported file format. Please provide a CSV or TSV file.")

    # Initialize preprocessor
    preprocessor = TextPreprocessor(config)

    # Preprocess data
    texts, masked_texts, labels, stylometric_features = preprocessor.preprocess_dataset(df)

    # Prepare dataloaders
    train_loader, val_loader, test_loader = prepare_dataloaders(
        texts, masked_texts, labels, stylometric_features, config
    )

    # Initialize model on the module-level device
    model = GANTextDetector(config)
    model.to(device)

    # Train model
    model, train_losses, val_metrics = train_gan_detector(
        model, train_loader, val_loader, config, device
    )

    # Evaluate model (separate names so the full label array is not shadowed)
    accuracy, auc, f1, embeddings, test_preds, test_labels = evaluate_model(
        model, test_loader, device
    )

    # Test single text prediction against the checkpoint this run just wrote,
    # on the same device that was used for training.
    sample_text = "This is a sample text to demonstrate the prediction function."
    probability, classification = predict_single_text(
        sample_text, model_path=config.best_model_path, device=device
    )
    print(f"\nSample text prediction:")
    print(f"  Text: \"{sample_text}\"")
    print(f"  Probability of being AI-generated: {probability:.4f}")
    print(f"  Classification: {classification}")

    return model




In [None]:
# Driver cell: build a config, override a few training settings, run the full
# pipeline on the Kaggle dataset and classify one example text.
config = Config()

# Larger batches and more epochs than the Config defaults (16 and 10).
# The two learning rates below repeat the Config defaults.
config.batch_size = 128
config.num_epochs = 50
config.gen_lr = 2e-4
config.disc_lr = 1e-4

# Trains, evaluates and writes the best checkpoint plus the preprocessor.
main('/kaggle/input/combined-dataset-ai-human/combined_data.csv', config)

# Uses the default checkpoint path because no model_path is given.
text = "This is an example text that I want to classify as AI or human."
probability, classification = predict_single_text(text)
# NOTE: `probability` is always the probability of the text being AI-generated,
# so for a "Human-written" verdict the printed percentage is below 50 %.
print(f"The text is classified as {classification} with {probability:.2%} probability.")