In [None]:
"""
Advanced Logical Sentence Detection with Adversarial Training
- Expanded dataset (300+ sentences)
- Multiple perturbation rounds
- Hard negative mining
- Ensemble classifier (neural + heuristic)
- Enhanced perturbations (synonyms, paraphrasing)
"""

import torch
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    get_linear_schedule_with_warmup
)
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import re
from collections import defaultdict

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this file relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch; when CUDA is
    available, all CUDA devices are seeded as well.

    Args:
        seed: Integer seed handed to each generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)

In [None]:
# ============================================================================
# STEP 1: Enhanced Logical Sentence Detection Heuristic
# ============================================================================

class LogicalSentenceDetector:
    """Heuristic detector for sentences that carry logical structure.

    Connectives are matched as whole words/phrases (regex word boundaries),
    not as raw substrings, so that e.g. 'as' does not fire inside "was",
    'if' inside "different", or 'when' inside "whenever".
    """

    def __init__(self):
        # Logical connectives and indicators
        self.conditional_words = [
            'if', 'then', 'when', 'whenever', 'unless',
            'provided that', 'as long as', 'in case', 'supposing',
            'assuming', 'given that', 'on condition that'
        ]
        self.causal_words = [
            'because', 'since', 'as', 'due to', 'owing to',
            'therefore', 'thus', 'hence', 'consequently', 'so',
            'for this reason', 'as a result', 'leads to', 'causes'
        ]
        self.contrast_words = [
            'although', 'though', 'however', 'but', 'yet',
            'nevertheless', 'nonetheless', 'despite', 'in spite of',
            'whereas', 'while', 'on the other hand', 'conversely'
        ]
        self.conjunction_words = ['and', 'or', 'nor', 'either', 'neither']

        self.all_logic_words = (
            self.conditional_words +
            self.causal_words +
            self.contrast_words +
            self.conjunction_words
        )

        # Compiled whole-word patterns, filled lazily so that connectives
        # appended to the lists after construction are still honoured.
        self._word_patterns = {}

    def _contains(self, phrase, sentence_lower):
        """Return True if `phrase` occurs in `sentence_lower` as a whole word/phrase."""
        pattern = self._word_patterns.get(phrase)
        if pattern is None:
            pattern = re.compile(rf'\b{re.escape(phrase)}\b')
            self._word_patterns[phrase] = pattern
        return pattern.search(sentence_lower) is not None

    def compute_logic_score(self, sentence):
        """Compute a confidence score for logical structure.

        The score is the sum of four capped components:
        connectives (up to 0.4), presence of ',' or ';' (0.3),
        number of ','/';' separators (up to 0.2) and an
        "if/because ..., ..." pattern bonus (0.1).

        Args:
            sentence: Sentence text.

        Returns:
            Float in [0, 1]; higher means more evidence of logical structure.
        """
        sentence_lower = sentence.lower()

        # Connectives: 0.2 per distinct connective found, capped at 0.4
        connective_count = sum(
            1 for word in self.all_logic_words
            if self._contains(word, sentence_lower)
        )
        score = min(0.4, connective_count * 0.2)

        # Clausal structure (0.3 points)
        if ',' in sentence or ';' in sentence:
            score += 0.3

        # Multiple clauses: 0.1 per separator, capped at 0.2
        clause_count = sentence.count(',') + sentence.count(';')
        score += min(0.2, clause_count * 0.1)

        # Typical "if X, Y" / "because X, Y" pattern (0.1 points);
        # \b keeps words such as "gif" from triggering the bonus.
        if re.search(r'\b(?:if|because)\b .+, .+', sentence_lower):
            score += 0.1

        return min(1.0, score)

    def is_logical(self, sentence, threshold=0.5):
        """Return True when the logic score of `sentence` is at least `threshold`."""
        return self.compute_logic_score(sentence) >= threshold

    def get_logic_type(self, sentence):
        """Identify the type of logical relationship.

        Categories are checked in priority order: conditional, causal,
        contrast, conjunction.

        Returns:
            One of 'conditional', 'causal', 'contrast', 'conjunction', 'none'.
        """
        sentence_lower = sentence.lower()

        categories = (
            ('conditional', self.conditional_words),
            ('causal', self.causal_words),
            ('contrast', self.contrast_words),
            ('conjunction', self.conjunction_words),
        )
        for label, words in categories:
            if any(self._contains(word, sentence_lower) for word in words):
                return label
        return 'none'



In [None]:
# ============================================================================
# STEP 2: Enhanced Adversarial Perturbation Generator
# ============================================================================

class AdversarialGenerator:
    """Generate adversarial examples through multiple perturbation strategies.

    Every strategy takes a sentence and returns a (possibly unchanged) new
    sentence. Words are always located with whole-word, case-insensitive
    matching, and a replacement inherits the capitalisation of the word it
    replaces so sentence-initial words stay capitalised.
    """

    # Verbs after which add_negation inserts "not" (tried in this order)
    _NEGATABLE_VERBS = (
        'is', 'are', 'was', 'were', 'will', 'would',
        'can', 'could', 'should', 'may', 'might'
    )
    # Verbs after which weaken_logic inserts a hedge (tried in this order)
    _HEDGEABLE_VERBS = ('is', 'are', 'will', 'would', 'can')
    # Contractions whose stem is not simply the text before "n't"
    _IRREGULAR_CONTRACTIONS = {'won': 'will', 'can': 'can', 'shan': 'shall'}
    # Names accepted by generate_perturbation (also the pool for 'random')
    _PERTURBATION_NAMES = (
        'add_negation', 'remove_negation', 'flip_connective',
        'remove_connective', 'swap_clauses', 'synonym_replacement',
        'paraphrase_clause', 'weaken_logic'
    )

    def __init__(self):
        self.negations = ['not', 'never', "n't"]

        # Synonym mappings for logical connectives
        self.connective_synonyms = {
            'if': ['when', 'whenever', 'in case', 'supposing'],
            'because': ['since', 'as', 'due to the fact that'],
            'although': ['though', 'even though', 'despite the fact that'],
            'therefore': ['thus', 'hence', 'consequently', 'as a result'],
            'however': ['nevertheless', 'nonetheless', 'yet', 'but'],
            'when': ['whenever', 'as', 'while'],
        }

        # Opposing connectives (for semantic flips)
        self.opposing_connectives = {
            'because': ['although', 'despite', 'even though'],
            'although': ['because', 'since'],
            'if': ['even if', 'unless'],
            'therefore': ['however', 'nevertheless'],
            'and': ['but', 'yet'],
        }

        # Common verb synonyms for paraphrasing
        self.verb_synonyms = {
            'studied': ['learned', 'reviewed', 'practiced'],
            'passed': ['succeeded', 'completed', 'aced'],
            'rains': ['pours', 'drizzles', 'showers'],
            'rises': ['comes up', 'appears', 'ascends'],
            'falls': ['drops', 'decreases', 'declines'],
            'grows': ['develops', 'expands', 'increases'],
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _match_case(matched_text, replacement):
        """Capitalise `replacement` if `matched_text` starts with an upper-case letter."""
        if matched_text[:1].isupper():
            return replacement[:1].upper() + replacement[1:]
        return replacement

    def _substitute_from_table(self, sentence, table):
        """Replace the first table key found as a whole word with a random alternative.

        Keys are tried in the table's order; only a real whole-word match
        stops the search, so a substring hit (e.g. 'if' in "shift") no longer
        prevents later keys from being tried.
        """
        for original, alternatives in table.items():
            pattern = re.compile(rf'\b{re.escape(original)}\b', re.IGNORECASE)
            if pattern.search(sentence) is None:
                continue
            replacement = random.choice(alternatives)
            return pattern.sub(
                lambda m: self._match_case(m.group(0), replacement),
                sentence,
                count=1
            )
        return sentence

    @staticmethod
    def _insert_after_verb(sentence, verbs, word):
        """Insert `word` after the first listed verb present; None if no verb matches."""
        for verb in verbs:
            updated, hits = re.subn(
                rf'\b({verb})\b\s+',
                # Re-emit the matched text so the verb keeps its original case
                lambda m: f'{m.group(1)} {word} ',
                sentence,
                count=1,
                flags=re.IGNORECASE
            )
            if hits:
                return updated
        return None

    # ------------------------------------------------------------------
    # Perturbation strategies
    # ------------------------------------------------------------------

    def add_negation(self, sentence):
        """Insert "not" after the first auxiliary verb, else after the first comma."""
        negated = self._insert_after_verb(sentence, self._NEGATABLE_VERBS, 'not')
        if negated is not None:
            return negated

        # Fallback: add "not" after comma
        head, comma, tail = sentence.partition(',')
        if comma:
            return head + ', not' + tail
        return sentence

    def remove_negation(self, sentence):
        """Remove existing negations ("not", "never" -> "always", "n't" contractions)."""
        sentence = re.sub(r'\bnot\b\s*', '', sentence, flags=re.IGNORECASE)
        sentence = re.sub(
            r'\bnever\b',
            lambda m: self._match_case(m.group(0), 'always'),
            sentence,
            flags=re.IGNORECASE
        )
        # "won't"/"can't"/"shan't" must be expanded before the generic rule,
        # which would otherwise leave the non-words "wo"/"ca"/"sha".
        sentence = re.sub(
            r"\b(won|can|shan)'t\b",
            lambda m: self._match_case(
                m.group(0), self._IRREGULAR_CONTRACTIONS[m.group(1).lower()]
            ),
            sentence,
            flags=re.IGNORECASE
        )
        sentence = re.sub(r"n't\b", '', sentence)
        return sentence

    def flip_connective(self, sentence):
        """Replace a logical connective with one of opposing meaning."""
        return self._substitute_from_table(sentence, self.opposing_connectives)

    def remove_connective(self, sentence):
        """Remove the first listed connective present (break logic)."""
        connectives = [
            'if', 'because', 'since', 'although', 'when', 'though',
            'therefore', 'thus', 'hence', 'so', 'but', 'however'
        ]
        for conn in connectives:
            # Also swallow the comma/whitespace that followed the connective
            stripped, hits = re.subn(
                rf'\b{conn}\b[,\s]*', '', sentence, count=1, flags=re.IGNORECASE
            )
            if hits:
                return stripped
        return sentence

    def swap_clauses(self, sentence):
        """Swap the clauses around the first comma.

        Sentence-final punctuation stays at the end of the result instead of
        ending up in front of the comma. Returns the sentence unchanged when
        there is no comma or one side of it is empty.
        """
        if ',' not in sentence:
            return sentence

        first, second = (part.strip() for part in sentence.split(',', 1))

        terminal = ''
        if second and second[-1] in '.!?':
            terminal = second[-1]
            second = second[:-1].rstrip()

        if not first or not second:
            return sentence

        # Clean up capitalization
        second = second[0].upper() + second[1:]
        first = first[0].lower() + first[1:]
        return f"{second}, {first}{terminal}"

    def synonym_replacement(self, sentence):
        """Replace a connective with a synonym (preserves logic)."""
        return self._substitute_from_table(sentence, self.connective_synonyms)

    def paraphrase_clause(self, sentence):
        """Paraphrase part of the sentence by swapping one known verb for a synonym."""
        return self._substitute_from_table(sentence, self.verb_synonyms)

    def weaken_logic(self, sentence):
        """Add a hedging word after the first auxiliary verb to weaken certainty."""
        hedges = ['probably', 'possibly', 'might', 'perhaps', 'usually']
        hedge = random.choice(hedges)

        hedged = self._insert_after_verb(sentence, self._HEDGEABLE_VERBS, hedge)
        return sentence if hedged is None else hedged

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    def generate_perturbation(self, sentence, perturbation_type='random'):
        """Generate a single adversarial perturbation.

        Args:
            sentence: Sentence to perturb.
            perturbation_type: Name of a strategy method, or 'random' to pick
                one uniformly. Unknown names leave the sentence unchanged.

        Returns:
            The perturbed sentence, or the input sentence if the strategy did
            not apply or failed.
        """
        if perturbation_type == 'random':
            perturbation_type = random.choice(self._PERTURBATION_NAMES)

        if perturbation_type not in self._PERTURBATION_NAMES:
            return sentence

        try:
            return getattr(self, perturbation_type)(sentence)
        except Exception:
            # Best effort: a failing strategy must not abort augmentation.
            # Exception (not a bare except) lets KeyboardInterrupt through.
            return sentence

    def generate_multi_perturbation(self, sentence, num_perturbations=2):
        """Apply several distinct perturbations sequentially.

        Args:
            sentence: Sentence to perturb.
            num_perturbations: Number of distinct strategies to chain; clamped
                to the range [0, 5].

        Returns:
            The sentence after all selected perturbations.
        """
        result = sentence
        perturbation_types = [
            'add_negation', 'flip_connective', 'remove_connective',
            'swap_clauses', 'weaken_logic'
        ]

        how_many = max(0, min(num_perturbations, len(perturbation_types)))
        selected_types = random.sample(perturbation_types, how_many)

        for ptype in selected_types:
            before = result
            result = self.generate_perturbation(result, ptype)
            # Compare with the state before THIS step (not the original
            # sentence) so a no-op step is detected after earlier successes.
            if result == before:  # If perturbation failed, try random
                result = self.generate_perturbation(result, 'random')

        return result


In [None]:
# ============================================================================
# STEP 3: Expanded Dataset Creation (300+ sentences)
# ============================================================================

def create_expanded_dataset():
    """Build the base (un-augmented) dataset of labelled sentences.

    The hand-written corpus below holds 130 logical and 140 non-logical
    sentences. Each one becomes a record with the keys 'sentence',
    'label' (1 = logical, 0 = non-logical), 'is_perturbed' (always False
    here) and 'source' (always 'original').

    Returns:
        List of record dicts: all logical sentences first, then all
        non-logical ones, each group in the order listed below.
    """

    logical_sentences = [
        # Conditional statements (if-then)
        "If it rains, the ground will be wet.",
        "If you heat water to 100°C, it boils.",
        "If the battery dies, the phone won't work.",
        "If you save money, you can buy what you want.",
        "When the sun sets, the temperature drops.",
        "When plants get sunlight, they grow faster.",
        "When you exercise regularly, your health improves.",
        "When the alarm rings, people evacuate the building.",
        "When prices rise, demand typically falls.",
        "When ice melts, it becomes water.",
        "If you touch fire, you get burned.",
        "If the door is locked, you cannot enter.",
        "If she calls, please let me know.",
        "When winter comes, birds migrate south.",
        "When the light turns red, cars must stop.",
        "If you don't water plants, they will die.",
        "When the temperature drops below zero, water freezes.",
        "If you study hard, you'll likely succeed.",
        "When the movie ends, the lights come on.",
        "If the wifi is down, we can't work online.",
        "When you mix blue and yellow, you get green.",
        "If the train is delayed, we'll miss our connection.",
        "When the battery is full, the charging stops.",
        "If you break the rules, there are consequences.",
        "When the season changes, fashion trends shift.",
        "If the economy grows, employment increases.",
        "When the sun shines, solar panels generate power.",
        "If you forget the password, you can't log in.",
        "When the timer beeps, the food is ready.",
        "If the signal is weak, calls drop frequently.",

        # Causal statements (because, since, therefore)
        "Because he studied hard, he passed the exam.",
        "Since the store was closed, we went home.",
        "Because the road was icy, traffic moved slowly.",
        "Since he missed the bus, he arrived late.",
        "Because the bridge collapsed, the road was closed.",
        "Because the weather was nice, we went to the park.",
        "Since the evidence was clear, the jury convicted him.",
        "Because she was sick, she stayed home from work.",
        "Since the project was urgent, they worked overtime.",
        "Because prices increased, sales declined.",
        "Since the restaurant was full, we waited outside.",
        "Because the water was contaminated, people got sick.",
        "Since he forgot his keys, he couldn't enter.",
        "Because the storm was severe, flights were cancelled.",
        "Since the deadline passed, submissions were closed.",
        "Because the team practiced daily, they won the championship.",
        "Since the equipment broke, production stopped.",
        "Because the film was popular, tickets sold out quickly.",
        "Since she had experience, she got the job.",
        "Because the road was blocked, we took a detour.",
        "Therefore, we must act quickly to solve this problem.",
        "Thus, the hypothesis was proven correct.",
        "Hence, the company decided to expand operations.",
        "Consequently, many people lost their jobs.",
        "As a result, the ecosystem was severely damaged.",
        "Therefore, further research is necessary.",
        "Thus, the treaty was signed by all parties.",
        "Hence, the policy was changed immediately.",
        "Consequently, sales increased by thirty percent.",
        "As a result, the building had to be demolished.",

        # Contrast/concession (although, despite, however)
        "Although she was tired, she finished the project.",
        "Although it was expensive, they bought the car.",
        "Although the task was difficult, she succeeded.",
        "Despite the rain, the game continued.",
        "Though he was young, he was very wise.",
        "Although they lost, they played their best.",
        "Despite the warning, he proceeded anyway.",
        "Though it was late, they kept working.",
        "Although the evidence was weak, they prosecuted.",
        "Despite the risks, she took the job.",
        "Although the road was long, they kept walking.",
        "Though the odds were against them, they won.",
        "Despite the cost, quality is worth it.",
        "Although it was crowded, we found seats.",
        "Though he apologized, she remained angry.",
        "Despite the delay, we arrived on time.",
        "Although the recipe was complex, she made it perfectly.",
        "Though the exam was hard, most students passed.",
        "Despite the competition, our product succeeded.",
        "Although he was injured, he finished the race.",
        "However, the results were not what we expected.",
        "Nevertheless, the plan moved forward.",
        "The weather was bad, but the event proceeded anyway.",
        "She was exhausted, yet she continued working.",
        "The task was daunting, however they persevered.",
        "It was expensive, but the quality justified the price.",
        "He was nervous, yet he delivered an excellent speech.",
        "The journey was long, nevertheless they enjoyed it.",
        "The situation was dire, but hope remained.",
        "The odds were slim, yet they took the chance.",

        # Complex logical chains
        "If you water the plants and give them sunlight, they will flourish.",
        "Because the temperature dropped and the roads were icy, schools closed.",
        "Although she studied hard and prepared well, she was still nervous.",
        "When the alarm sounds and smoke is detected, evacuate immediately.",
        "If you combine effort with strategy, success becomes likely.",
        "Since the data was analyzed and patterns emerged, conclusions were drawn.",
        "Although the plan was risky and resources were limited, they proceeded.",
        "When technology advances and costs decrease, adoption increases.",
        "If ingredients are fresh and preparation is careful, meals taste better.",
        "Because demand was high and supply was low, prices soared.",
        "When citizens vote and participate actively, democracy thrives.",
        "If you listen carefully and ask questions, you learn more.",
        "Since the evidence was overwhelming and witnesses testified, the verdict was guilty.",
        "Although the journey was difficult and setbacks occurred, they reached their goal.",
        "When interest rates fall and borrowing becomes cheaper, economies grow.",
        "If systems are tested and bugs are fixed, software becomes reliable.",
        "Because training was thorough and equipment was modern, performance improved.",
        "When communication is clear and expectations are set, teams succeed.",
        "If you save consistently and invest wisely, wealth accumulates.",
        "Since regulations were strict and enforcement was strong, compliance improved.",

        # Additional logical patterns
        "Unless you hurry, you will miss the train.",
        "Provided that you complete the assignment, you will pass.",
        "As long as you follow the rules, there won't be problems.",
        "In case of emergency, break the glass.",
        "Whenever she visits, she brings gifts.",
        "Since morning, the situation has deteriorated.",
        "Given that the facts support it, we should proceed.",
        "Assuming the weather holds, the picnic is on.",
        "On condition that you agree, we can move forward.",
        "Whereas some prefer coffee, others choose tea.",
        "While technology helps productivity, it also creates distractions.",
        "Either you adapt to change, or you get left behind.",
        "Neither rain nor snow will stop the delivery.",
        "Not only did she win, but she also set a record.",
        "Both the theory and the evidence support this conclusion.",
        "Just as the sun rises in the east, it sets in the west.",
        "The more you practice, the better you become.",
        "The harder you work, the luckier you get.",
        "No sooner had she left than the phone rang.",
        "Hardly had the game begun when it started raining.",
    ]

    non_logical_sentences = [
        # Simple descriptive sentences
        "The cat sleeps on the couch.",
        "She loves chocolate ice cream.",
        "The building is very tall.",
        "He drives a blue car.",
        "The movie was entertaining.",
        "They live in a small town.",
        "The coffee tastes bitter.",
        "She has three siblings.",
        "The book is on the table.",
        "He plays guitar every day.",
        "The flowers smell wonderful.",
        "She graduated last year.",
        "The museum opens at nine.",
        "He enjoys reading novels.",
        "The stars are bright tonight.",
        "She works as a teacher.",
        "The pizza was delicious.",
        "He speaks three languages.",
        "The concert starts soon.",
        "She painted the room yellow.",

        # Observations and statements
        "The sky is blue today.",
        "Birds are singing in the trees.",
        "The ocean looks calm.",
        "Mountains surround the valley.",
        "Children are playing in the park.",
        "The restaurant serves Italian food.",
        "Her dress is red and elegant.",
        "The clock shows three o'clock.",
        "Music fills the air.",
        "The garden needs watering.",
        "His smile is contagious.",
        "The painting depicts a landscape.",
        "The library has many books.",
        "Her voice is melodious.",
        "The car needs repairs.",
        "The sunset is beautiful.",
        "The room is spacious and bright.",
        "The puppy is adorable.",
        "The water is crystal clear.",
        "The cake looks appetizing.",

        # Actions and events
        "She walked to the store yesterday.",
        "They watched a movie last night.",
        "He cooked dinner for his family.",
        "The team celebrated their victory.",
        "She writes in her journal daily.",
        "They traveled across Europe.",
        "He runs five miles every morning.",
        "The kids built a sandcastle.",
        "She organized the files carefully.",
        "They planted flowers in the garden.",
        "He repaired the broken fence.",
        "The artist created a masterpiece.",
        "She baked cookies for the party.",
        "They explored the ancient ruins.",
        "He learned to play the piano.",
        "The children sang songs together.",
        "She knitted a warm scarf.",
        "They renovated the old house.",
        "He photographed the wildlife.",
        "The dancers performed gracefully.",

        # States and conditions
        "The water is cold.",
        "The fabric feels soft.",
        "The room smells fresh.",
        "The music sounds peaceful.",
        "The surface is smooth.",
        "The night is dark.",
        "The air feels humid.",
        "The bread tastes stale.",
        "The light is dim.",
        "The ground is uneven.",
        "The atmosphere is tense.",
        "The mood is cheerful.",
        "The texture is rough.",
        "The temperature is moderate.",
        "The pressure is intense.",
        "The pace is slow.",
        "The style is modern.",
        "The tone is friendly.",
        "The flavor is spicy.",
        "The color is vibrant.",

        # Preferences and opinions
        "I prefer tea over coffee.",
        "She thinks the movie is overrated.",
        "He believes in hard work.",
        "They enjoy outdoor activities.",
        "She appreciates good art.",
        "He values honesty.",
        "They admire courage.",
        "She finds mathematics interesting.",
        "He considers himself fortunate.",
        "They regard it as important.",
        "She loves classical music.",
        "He likes spicy food.",
        "They treasure old photographs.",
        "She cherishes her memories.",
        "He respects different opinions.",
        "They favor sustainable practices.",
        "She prefers quiet evenings.",
        "He enjoys intellectual discussions.",
        "They appreciate fine dining.",
        "She adores her grandchildren.",

        # Additional non-logical statements
        "The phone is ringing loudly.",
        "The traffic is heavy this morning.",
        "The leaves are changing colors.",
        "The baby is sleeping peacefully.",
        "The crowd cheered enthusiastically.",
        "The engine is making strange noises.",
        "The project deadline is approaching.",
        "The students are taking notes.",
        "The wind is blowing strongly.",
        "The fireplace is burning brightly.",
        "The snow is falling gently.",
        "The audience applauded warmly.",
        "The bread is baking in the oven.",
        "The river flows through the city.",
        "The clock is ticking quietly.",
        "The fog is rolling in.",
        "The champagne is chilling.",
        "The documents are being printed.",
        "The elevator is descending.",
        "The candles are flickering softly.",
        "His handwriting is illegible.",
        "The software is updating automatically.",
        "The news is spreading quickly.",
        "The battery is draining fast.",
        "The crowd is dispersing gradually.",
        "The ice is melting slowly.",
        "The prices are fluctuating daily.",
        "The negotiations are ongoing.",
        "The situation is improving steadily.",
        "The performance exceeded expectations.",
        "The product received positive reviews.",
        "The conference attracted many attendees.",
        "The curriculum includes various subjects.",
        "The menu offers vegetarian options.",
        "The neighborhood is quiet and safe.",
        "The festival features local artists.",
        "The collection showcases contemporary art.",
        "The documentary explores historical events.",
        "The workshop teaches practical skills.",
        "The campaign raised significant funds.",
    ]

    def as_records(sentences, label):
        # Wrap each raw sentence in the record layout used by the pipeline
        return [
            {'sentence': text, 'label': label, 'is_perturbed': False, 'source': 'original'}
            for text in sentences
        ]

    # Labels: 1 = logical, 0 = non-logical; logical records come first
    data = as_records(logical_sentences, 1) + as_records(non_logical_sentences, 0)

    print(f"Created base dataset: {len(logical_sentences)} logical, {len(non_logical_sentences)} non-logical")
    return data


In [None]:
# ============================================================================
# STEP 4: Multi-Round Adversarial Augmentation
# ============================================================================

def augment_with_multi_perturbations(data, generator, rounds=(1, 2, 3), aug_per_round=1.0):
    """Augment a dataset with several rounds of adversarial perturbations.

    For every entry in `rounds`, randomly chosen logical sentences
    (label == 1) are perturbed and the results that differ from their source
    sentence are appended as new label-0 records.

    Args:
        data: List of record dicts with at least 'sentence' and 'label'.
            It is not modified.
        generator: Object exposing
            ``generate_multi_perturbation(sentence, num_perturbations=...)``.
        rounds: Iterable of perturbation counts; one augmentation round is run
            per entry, chaining that many perturbations per sentence.
        aug_per_round: Attempts per round, as a fraction of the number of
            logical sentences in `data`.

    Returns:
        A new list: the original records followed by the perturbed records
        ('is_perturbed' True, 'source' set to 'perturb_<k>x').
    """

    augmented_data = list(data)
    logical_sentences = [d for d in data if d['label'] == 1]
    num_augmentations = int(len(logical_sentences) * aug_per_round)

    print(f"\nGenerating adversarial examples with {len(rounds)} perturbation rounds...")

    for num_perturbs in rounds:
        generated = 0

        for _ in range(num_augmentations):
            original = random.choice(logical_sentences)
            perturbed_sent = generator.generate_multi_perturbation(
                original['sentence'],
                num_perturbations=num_perturbs
            )

            # Only add if actually different from original
            if perturbed_sent == original['sentence']:
                continue

            augmented_data.append({
                'sentence': perturbed_sent,
                'label': 0,  # Perturbed sentences are treated as non-logical
                'is_perturbed': True,
                'source': f'perturb_{num_perturbs}x'
            })
            generated += 1

        # Report what was actually added; unchanged perturbations are dropped,
        # so this can be lower than the number of attempts.
        print(f"  Round {num_perturbs}: Generated {generated} examples "
              f"({num_augmentations} attempts)")

    return augmented_data


In [None]:
# ============================================================================
# STEP 5: Hard Negative Mining
# ============================================================================

class HardNegativeMiner:
    """Mine hard examples (wrong or low-confidence predictions) and augment them."""

    def __init__(self, model, tokenizer, device):
        """Store the classifier, its tokenizer and the device used for inference."""
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Result of the most recent find_hard_negatives() call
        self.hard_negatives = []

    def find_hard_negatives(self, data, threshold=0.3):
        """Find examples where the model is uncertain or wrong.

        Puts the model in eval mode and scores each record individually.

        Args:
            data: Iterable of record dicts with 'sentence' and 'label'.
            threshold: Uncertainty margin. A prediction whose confidence
                (softmax probability of the predicted class) is below
                ``1 - threshold`` counts as uncertain; the default 0.3 gives
                the 0.7 confidence cut-off.

        Returns:
            List of copies of the hard records, each extended with
            'confidence', 'predicted' and 'is_hard'. The same list is also
            stored in ``self.hard_negatives``.
        """
        self.model.eval()
        hard_examples = []
        confidence_floor = 1.0 - threshold

        with torch.no_grad():
            for item in tqdm(data, desc="Mining hard negatives"):
                inputs = self.tokenizer(
                    item['sentence'],
                    return_tensors='pt',
                    padding=True,
                    truncation=True,
                    max_length=128
                ).to(self.device)

                outputs = self.model(**inputs)
                probs = torch.softmax(outputs.logits, dim=1)
                pred = torch.argmax(probs, dim=1).item()
                confidence = probs[0][pred].item()

                # Hard negative if: wrong prediction OR low confidence
                is_wrong = (pred != item['label'])
                is_uncertain = (confidence < confidence_floor)

                if is_wrong or is_uncertain:
                    hard_examples.append({
                        **item,
                        'confidence': confidence,
                        'predicted': pred,
                        'is_hard': True
                    })

        self.hard_negatives = hard_examples
        print(f"Found {len(hard_examples)} hard negative examples")
        return hard_examples

    def augment_hard_negatives(self, hard_examples, generator, multiplier=2):
        """Generate more perturbations from hard examples.

        Args:
            hard_examples: Iterable of record dicts with a 'sentence' key.
            generator: Object exposing
                ``generate_multi_perturbation(sentence, num_perturbations=...)``.
            multiplier: Perturbation attempts per hard example.

        Returns:
            List of new label-0 records ('is_perturbed' True, 'source'
            'hard_negative_aug'); perturbations identical to their source
            sentence are dropped.
        """
        augmented = []

        for item in hard_examples:
            # Generate multiple perturbations of hard examples
            for _ in range(multiplier):
                perturbed = generator.generate_multi_perturbation(
                    item['sentence'],
                    num_perturbations=random.randint(1, 3)
                )

                if perturbed != item['sentence']:
                    augmented.append({
                        'sentence': perturbed,
                        'label': 0,
                        'is_perturbed': True,
                        'source': 'hard_negative_aug'
                    })

        print(f"Generated {len(augmented)} augmentations from hard negatives")
        return augmented


In [None]:
# ============================================================================
# STEP 6: Ensemble Classifier (Neural + Heuristic)
# ============================================================================

class EnsembleClassifier:
    """Weighted blend of the neural classifier and the heuristic detector."""

    def __init__(self, model, tokenizer, detector, device, neural_weight=0.7):
        """Keep the components; the heuristic gets the weight left over by `neural_weight`."""
        self.model, self.tokenizer = model, tokenizer
        self.detector = detector
        self.device = device
        self.neural_weight = neural_weight
        self.heuristic_weight = 1.0 - neural_weight

    def predict(self, sentence):
        """Score one sentence with both components and blend the two scores.

        Returns:
            Dict with 'prediction' (1 when the blended score is >= 0.5,
            else 0), 'ensemble_score', 'neural_score' (softmax probability
            of class 1) and 'heuristic_score'.
        """
        self.model.eval()

        # Neural side: probability the model assigns to class 1 ("logical")
        with torch.no_grad():
            encoded = self.tokenizer(
                sentence,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=128
            ).to(self.device)
            logits = self.model(**encoded).logits
            class_probs = torch.softmax(logits, dim=1)[0]
            neural_score = class_probs[1].item()

        # Heuristic side
        heuristic_score = self.detector.compute_logic_score(sentence)

        # Weighted combination of both scores
        ensemble_score = (self.neural_weight * neural_score +
                         self.heuristic_weight * heuristic_score)

        if ensemble_score >= 0.5:
            prediction = 1
        else:
            prediction = 0

        return {
            'prediction': prediction,
            'ensemble_score': ensemble_score,
            'neural_score': neural_score,
            'heuristic_score': heuristic_score
        }

    def evaluate_ensemble(self, dataloader):
        """Return the ensemble's accuracy over `dataloader` (0 if it yields no sentences).

        Each batch is read through its 'label' entry and, when present, its
        'sentence' entry; sentences are scored one at a time with predict().
        """
        hits = 0
        seen = 0

        for batch in tqdm(dataloader, desc="Ensemble evaluation"):
            sentences = batch['sentence'] if 'sentence' in batch else []
            labels = batch['label'].numpy()

            for position, text in enumerate(sentences):
                outcome = self.predict(text)
                seen += 1
                if outcome['prediction'] == labels[position]:
                    hits += 1

        if seen > 0:
            return hits / seen
        return 0


In [None]:
# ============================================================================
# STEP 7: PyTorch Dataset and DataLoader
# ============================================================================

class LogicDataset(Dataset):
    """Map-style dataset that tokenizes ``{'sentence', 'label'}`` records.

    Each element of ``data`` is indexed with the keys ``'sentence'`` and
    ``'label'``. Tokenization happens on access, padded/truncated to
    ``max_length``.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized example at ``idx``.

        Returns:
            dict with flattened ``input_ids`` and ``attention_mask`` tensors,
            a ``torch.long`` ``label`` tensor and the raw ``sentence``.
        """
        record = self.data[idx]
        text = record['sentence']

        encoded = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # The tokenizer is asked for 'pt' tensors; flatten them so each
        # example is one-dimensional
        sample = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['label'] = torch.tensor(record['label'], dtype=torch.long)
        # Keep the raw text alongside the tensors
        sample['sentence'] = text
        return sample


In [None]:
# ============================================================================
# STEP 8: Per-Epoch Training and Evaluation (used by the hard negative mining loop in main)
# ============================================================================

def train_epoch(model, dataloader, optimizer, scheduler, device):
    """Run one optimisation pass over ``dataloader``.

    For every batch: forward pass with labels, backward pass, gradient-norm
    clipping at 1.0, then one optimizer step and one scheduler step.

    Args:
        model: Called with ``input_ids``, ``attention_mask`` and ``labels``;
            its output must expose ``.loss`` and ``.logits``.
        dataloader: Yields batches with ``'input_ids'``, ``'attention_mask'``
            and ``'label'`` tensors.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple of (mean loss per batch, fraction of correctly classified
        examples).
    """
    model.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    for batch in tqdm(dataloader, desc="Training"):
        on_device = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'label')
        }
        targets = on_device['label']

        optimizer.zero_grad()

        result = model(
            input_ids=on_device['input_ids'],
            attention_mask=on_device['attention_mask'],
            labels=targets
        )
        batch_loss = result.loss

        batch_loss.backward()
        # Clip the gradient norm to 1.0 before the parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        running_loss += batch_loss.item()

        predicted = result.logits.argmax(dim=1)
        n_correct += (predicted == targets).sum().item()
        n_seen += targets.size(0)

    mean_loss = running_loss / len(dataloader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

def evaluate(model, dataloader, device):
    """Compute loss and accuracy over ``dataloader`` without gradient tracking.

    Args:
        model: Called with ``input_ids``, ``attention_mask`` and ``labels``;
            its output must expose ``.loss`` and ``.logits``.
        dataloader: Yields batches with ``'input_ids'``, ``'attention_mask'``
            and ``'label'`` tensors.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple of (mean loss per batch, fraction of correctly classified
        examples).
    """
    model.eval()
    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            on_device = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'label')
            }
            gold = on_device['label']

            result = model(
                input_ids=on_device['input_ids'],
                attention_mask=on_device['attention_mask'],
                labels=gold
            )

            loss_sum += result.loss.item()

            # Predicted class is the index of the largest logit
            predicted = result.logits.argmax(dim=1)
            hits += (predicted == gold).sum().item()
            seen += gold.size(0)

    mean_loss = loss_sum / len(dataloader)
    accuracy = hits / seen
    return mean_loss, accuracy


In [None]:
# ============================================================================
# STEP 9: Main Training Pipeline with All Enhancements
# ============================================================================

def _print_banner(*lines, leading_blank=True):
    """Print ``lines`` framed above and below by a 70-character '=' rule."""
    rule = "=" * 70
    print(("\n" if leading_blank else "") + rule)
    for line in lines:
        print(line)
    print(rule)


def main():
    """Run the full training pipeline and return the trained components.

    Steps, in order: load DistilBERT and its tokenizer, build and augment the
    dataset, split it 80/20, train for ``num_epochs`` epochs with hard
    negative mining after every second epoch, reload the best checkpoint,
    wrap it in an ``EnsembleClassifier`` and print predictions for a fixed
    list of sample sentences.

    Side effects:
        Writes the best model weights to ``best_logic_classifier.pt`` in the
        current working directory and prints progress to stdout.

    Returns:
        Tuple ``(model, ensemble, detector, generator)``.
    """

    # Configuration
    batch_size = 16
    num_epochs = 8  # Increased for hard negative mining rounds
    learning_rate = 2e-5
    neural_weight = 0.7
    checkpoint_path = 'best_logic_classifier.pt'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    _print_banner(
        "ADVANCED LOGICAL SENTENCE CLASSIFIER WITH ADVERSARIAL TRAINING",
        leading_blank=False
    )
    print(f"\nUsing device: {device}")

    # Initialize components
    _print_banner("STEP 1: Initializing components")
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-uncased',
        num_labels=2
    ).to(device)

    detector = LogicalSentenceDetector()
    generator = AdversarialGenerator()

    # Create expanded dataset (300+ sentences)
    _print_banner("STEP 2: Creating expanded dataset")
    initial_data = create_expanded_dataset()
    print(f"✓ Total initial dataset size: {len(initial_data)}")

    # Multi-round adversarial augmentation
    _print_banner("STEP 3: Multi-round adversarial augmentation")
    augmented_data = augment_with_multi_perturbations(
        initial_data,
        generator,
        rounds=[1, 2, 3],  # 1x, 2x, and 3x perturbations
        aug_per_round=0.8
    )
    print(f"✓ Augmented dataset size: {len(augmented_data)}")

    # Show distribution
    logical_count = sum(1 for d in augmented_data if d['label'] == 1)
    non_logical_count = sum(1 for d in augmented_data if d['label'] == 0)
    perturbed_count = sum(1 for d in augmented_data if d.get('is_perturbed', False))
    print("\nDataset distribution:")
    print(f"  Logical: {logical_count}")
    print(f"  Non-logical: {non_logical_count}")
    print(f"  Perturbed: {perturbed_count}")

    # Split dataset
    # NOTE(review): the split happens after augmentation, so a perturbed
    # variant and its source sentence may land on opposite sides of the
    # split — confirm whether validation accuracy is inflated by this.
    train_data, val_data = train_test_split(augmented_data, test_size=0.2, random_state=42)
    print(f"\n✓ Training samples: {len(train_data)}")
    print(f"✓ Validation samples: {len(val_data)}")

    # Create dataloaders
    train_dataset = LogicDataset(train_data, tokenizer)
    val_dataset = LogicDataset(val_data, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Setup optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # Training loop with hard negative mining
    _print_banner("STEP 4: Training with hard negative mining")
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint; otherwise an all-zero validation run would leave nothing
    # (or a stale file from an earlier run) to reload after training.
    best_val_acc = float('-inf')
    hard_negative_miner = HardNegativeMiner(model, tokenizer, device)

    for epoch in range(num_epochs):
        _print_banner(f"EPOCH {epoch + 1}/{num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, device)
        val_loss, val_acc = evaluate(model, val_loader, device)

        print("\nResults:")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f"  ✓ New best model saved! (Val Acc: {val_acc:.4f})")

        # Hard negative mining every 2 epochs
        if (epoch + 1) % 2 == 0 and epoch < num_epochs - 1:
            print("\n  >> Running hard negative mining...")
            hard_examples = hard_negative_miner.find_hard_negatives(train_data, threshold=0.3)

            if len(hard_examples) > 10:
                # Generate augmentations from hard negatives
                hard_augmented = hard_negative_miner.augment_hard_negatives(
                    hard_examples,
                    generator,
                    multiplier=2
                )

                # Add to training data
                train_data.extend(hard_augmented)
                train_dataset = LogicDataset(train_data, tokenizer)
                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

                # The schedule was sized for the original loader length. With
                # more batches per epoch it would hit zero before training
                # ends and the final batches would run at LR 0. Rebuild it
                # over the steps that actually remain, decaying linearly from
                # the current learning rate (no second warmup).
                remaining_steps = len(train_loader) * (num_epochs - epoch - 1)
                for group in optimizer.param_groups:
                    # The scheduler takes its base LR from 'initial_lr'
                    group['initial_lr'] = group['lr']
                scheduler = get_linear_schedule_with_warmup(
                    optimizer,
                    num_warmup_steps=0,
                    num_training_steps=remaining_steps
                )

                print(f"  ✓ Added {len(hard_augmented)} hard negative augmentations")
                print(f"  ✓ New training set size: {len(train_data)}")

    _print_banner(
        "TRAINING COMPLETE!",
        f"Best validation accuracy: {best_val_acc:.4f}"
    )

    # Load best model (map_location keeps the tensors on the active device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # Create ensemble classifier
    _print_banner("STEP 5: Creating ensemble classifier")
    ensemble = EnsembleClassifier(
        model,
        tokenizer,
        detector,
        device,
        neural_weight=neural_weight
    )
    # Percentages are derived from the weight so the message cannot drift
    print(f"✓ Ensemble created ({neural_weight:.0%} neural, "
          f"{1 - neural_weight:.0%} heuristic)")

    # Test on sample sentences
    _print_banner("STEP 6: Testing on sample sentences")

    test_sentences = [
        ("If it rains, the ground will be wet.", "Original logical"),
        ("If it rains, the ground will not be wet.", "Negation perturbation"),
        ("Because he studied, he passed.", "Original causal"),
        ("He studied, he passed.", "Removed connective"),
        ("The cat is sleeping.", "Non-logical"),
        ("When the sun rises, birds start singing.", "Original conditional"),
        ("Although she was tired, she finished the project.", "Original contrast"),
        ("She was tired, she finished the project.", "Removed connective"),
        ("The sky is blue today.", "Non-logical"),
        ("Since the evidence was clear, the jury convicted him.", "Complex causal"),
    ]

    print("\nNeural Model Predictions:")
    print("-" * 70)
    model.eval()
    for sent, desc in test_sentences:
        inputs = tokenizer(
            sent,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=128
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            pred = torch.argmax(probs, dim=1).item()

        label = "LOGICAL" if pred == 1 else "NON-LOGICAL"
        # Confidence is the probability assigned to the predicted class
        confidence = probs[0][pred].item()
        print(f"\n[{desc}]")
        print(f"'{sent}'")
        print(f"→ {label} (confidence: {confidence:.3f})")

    print("\n" + "="*70)
    print("Ensemble Predictions (Neural + Heuristic):")
    print("-" * 70)
    for sent, desc in test_sentences:
        result = ensemble.predict(sent)
        label = "LOGICAL" if result['prediction'] == 1 else "NON-LOGICAL"

        print(f"\n[{desc}]")
        print(f"'{sent}'")
        print(f"→ {label}")
        print(f"   Ensemble: {result['ensemble_score']:.3f} | "
              f"Neural: {result['neural_score']:.3f} | "
              f"Heuristic: {result['heuristic_score']:.3f}")

    # Summary statistics
    _print_banner("FINAL STATISTICS")
    print(f"Total training samples: {len(train_data)}")
    print(f"Total validation samples: {len(val_data)}")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Model saved to: {checkpoint_path}")

    _print_banner("ENHANCEMENTS IMPLEMENTED:")
    print("✓ Expanded dataset (300+ base sentences)")
    print("✓ Multi-round perturbations (1x, 2x, 3x)")
    print("✓ Hard negative mining (every 2 epochs)")
    print("✓ Ensemble classifier (neural + heuristic)")
    print("✓ Enhanced perturbations (8 types including synonyms)")
    print("="*70)

    return model, ensemble, detector, generator

# Script entry point: run the full pipeline and bind the returned objects at
# module scope so they stay accessible after the run finishes.
if __name__ == "__main__":
    model, ensemble, detector, generator = main()
