# In [None]:
"""
TEXT DATASET PREPROCESSING - WITH SMART TEXT CLEANING (NO EMOJI PRINTS)
"""

import html
import os
import re
import unicodedata

import pandas as pd
from sklearn.model_selection import train_test_split

# ---------------------------------------------------------------------------
# Text-cleaning helpers used by process_text_datasets(). They are defined at
# module level so the regexes are compiled once and shared by every processor.
# ---------------------------------------------------------------------------

try:
    # Optional dependency, resolved once: a failed import is not cached in
    # sys.modules, so retrying it for every row would rescan sys.path each time.
    import contractions as _contractions
except ImportError:
    _contractions = None

# Curly single quotes (U+2018, U+2019) and backticks -> straight apostrophe.
# NFKC normalisation leaves the curly forms untouched. Without this mapping,
# "don<U+2019>t" skipped the contraction rules and the stray-character filter
# later split it into "don t".
_APOSTROPHE_TABLE = str.maketrans({"\u2018": "'", "\u2019": "'", "`": "'"})

# Regex fallback used when `contractions` is unavailable or raises. The two
# irregular forms must run before the generic "n't" rule.
_CONTRACTION_RULES = [
    (re.compile(r"\bwon't\b", re.IGNORECASE), "will not"),
    (re.compile(r"\bcan't\b", re.IGNORECASE), "can not"),
    (re.compile(r"n't\b", re.IGNORECASE), " not"),
    (re.compile(r"'re\b", re.IGNORECASE), " are"),
    (re.compile(r"'ve\b", re.IGNORECASE), " have"),
    (re.compile(r"'ll\b", re.IGNORECASE), " will"),
    (re.compile(r"'d\b", re.IGNORECASE), " would"),
    (re.compile(r"'m\b", re.IGNORECASE), " am"),
    # NOTE: this also rewrites possessives ("john's" -> "john is").
    (re.compile(r"'s\b", re.IGNORECASE), " is"),
]

_URL_RE = re.compile(r"http\S+|www\S+|https\S+")
_EMAIL_RE = re.compile(r"\S+@\S+")
_HASHTAG_RE = re.compile(r"#(\w+)")
_REPEATED_PUNCTUATION = (
    (re.compile(r"!{2,}"), "!"),
    (re.compile(r"\?{2,}"), "?"),
    (re.compile(r"\.{2,}"), "."),
)
# Anything other than word chars, whitespace and . , ! ? - ' " becomes a space.
_STRAY_CHARS_RE = re.compile(r"[^\w\s\.\,\!\?\-\'\"]+")


def _expand_contractions(text):
    """Expand English contractions, preferring the optional `contractions` package.

    Falls back to the regex rules in ``_CONTRACTION_RULES`` when the package
    is not installed or raises on this input.
    """
    if _contractions is not None:
        try:
            return _contractions.fix(text)
        except Exception:
            pass  # best effort: use the regex fallback below
    for pattern, replacement in _CONTRACTION_RULES:
        text = pattern.sub(replacement, text)
    return text


def _smart_text_cleaning(text):
    """Normalise one social-media text for classification.

    Steps, in order:
    1. HTML-unescape and NFKC-normalise.
    2. Map curly apostrophes and backticks to "'".
    3. Expand contractions.
    4. Lower-case and strip.
    5. Remove URLs and e-mail addresses.
    6. Turn "#tag" into "tag".
    7. Collapse repeated ! ? . into one mark.
    8. Replace stray symbols with spaces.
    9. Collapse whitespace.

    Returns "" for non-string input (e.g. NaN).
    """
    if not isinstance(text, str):
        return ""

    text = unicodedata.normalize("NFKC", html.unescape(text))
    text = text.translate(_APOSTROPHE_TABLE)
    text = _expand_contractions(text)
    text = text.lower().strip()

    text = _URL_RE.sub("", text)
    text = _EMAIL_RE.sub("", text)
    text = _HASHTAG_RE.sub(r"\1", text)
    for pattern, replacement in _REPEATED_PUNCTUATION:
        text = pattern.sub(replacement, text)
    text = _STRAY_CHARS_RE.sub(" ", text)

    return " ".join(text.split())


def _safe_read_csv(csv_file, nrows=None):
    """Read a CSV, trying utf-8, utf-8-sig and latin-1 in turn.

    Malformed lines are skipped. Returns an empty DataFrame (and prints a
    message) when every encoding fails.
    """
    for encoding in ("utf-8", "utf-8-sig", "latin-1"):
        try:
            return pd.read_csv(csv_file, encoding=encoding, on_bad_lines='skip',
                               low_memory=False, nrows=nrows)
        except Exception:
            continue
    print(f"Could not read: {csv_file}")
    return pd.DataFrame()


def _safe_int_convert(value):
    """Return ``int(float(value))``, or 0 when the value is NaN, None or non-numeric."""
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return 0


def _cell_text(row, column):
    """Return ``row[column]`` as a string, or "" when the column is missing or NaN."""
    value = row.get(column, "")
    return str(value) if pd.notna(value) else ""


def _collect_rows(df, text_column, header_tokens, label_fn, source_dataset,
                  source_file, extra_fn=None):
    """Turn the rows of a raw dataset frame into cleaned record dicts.

    A row is dropped in any of these cases:
    - its text is empty or NaN;
    - its text equals one of ``header_tokens`` case-insensitively (header
      lines repeated inside the file);
    - its text cleans to "";
    - ``label_fn`` returns None;
    - any step raises (per-row best effort).

    Args:
        df: Raw DataFrame read from a source CSV.
        text_column: Column holding the raw text.
        header_tokens: Lower-case strings that identify stray header rows.
        label_fn: ``row -> int | None``. Returning None skips the row.
        source_dataset: Provenance value stored on every record.
        source_file: Provenance value stored on every record.
        extra_fn: Optional ``row -> dict`` of additional columns.

    Returns:
        A list of dicts in row order, with keys text, label, source_dataset,
        source_file, plus any keys returned by ``extra_fn``.
    """
    records = []
    for _, row in df.iterrows():
        try:
            text = _cell_text(row, text_column)
            if not text.strip() or text.lower() in header_tokens:
                continue
            text = _smart_text_cleaning(text)
            if not text.strip():
                continue
            label = label_fn(row)
            if label is None:
                continue
            record = {
                'text': text,
                'label': label,
                'source_dataset': source_dataset,
                'source_file': source_file,
            }
            if extra_fn is not None:
                record.update(extra_fn(row))
        except Exception:
            continue
        records.append(record)
    return records


def process_text_datasets():
    """Clean, merge and split the text-only hate/toxic speech datasets.

    Each source folder found under ``datasets/`` is processed as follows:
    - Sources: hate_speech_curated, hate_speech_and_offensive_language,
      suspicious_communication_on_social_platforms and
      jigsaw-toxic-comment-classification-challenge.
    - Its raw CSVs are read, every text is normalised with
      ``_smart_text_cleaning``, and each row is given an integer ``label``.
    - The result is deduplicated on text and written to
      ``datasets/processed_data/<name>_cleaned.csv``.
    - Missing sources are skipped.

    The cleaned sources are then:
    - concatenated and tagged with ``dataset_source``;
    - deduplicated on ``text`` and saved as ``ALL_TEXT_DATASETS_COMBINED.csv``;
    - split 70/15/15 (stratified on ``label``, seed 42) into
      ``processed_data/splits/{train,val,test}/text_<split>.csv``.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    base_dir = r"datasets"
    processed_dir = os.path.join(base_dir, "processed_data")
    os.makedirs(processed_dir, exist_ok=True)

    # dataset name -> {'samples': rows kept, 'hate_samples': sum of labels}
    all_datasets_processed = {}
    # dataset name -> deduplicated DataFrame. Combining in memory avoids
    # re-reading the CSVs, whose default NA parsing turned cleaned texts such
    # as "nan", "null" or "n/a" into missing values.
    cleaned_frames = {}

    def read_if_present(path, display_name):
        """Read ``path`` if it exists; return None when it is missing or empty."""
        if not os.path.exists(path):
            return None
        print(f"  Reading: {display_name}")
        df = _safe_read_csv(path)
        if df.empty:
            return None
        print(f"    Columns: {list(df.columns)}")
        return df

    def store(dataset_name, records):
        """Deduplicate records on text, save <name>_cleaned.csv and record stats."""
        if not records:
            return
        result_df = pd.DataFrame(records).drop_duplicates(subset=['text'])
        output_path = os.path.join(processed_dir, f"{dataset_name}_cleaned.csv")
        result_df.to_csv(output_path, index=False)
        all_datasets_processed[dataset_name] = {
            'samples': len(result_df),
            'hate_samples': result_df['label'].sum(),
        }
        cleaned_frames[dataset_name] = result_df

    # ---------- dataset processors ----------

    def process_hate_speech_curated():
        print("\nPROCESSING: hate_speech_curated")
        dataset_path = os.path.join(base_dir, "hate_speech_curated")
        all_data = []
        sources = (
            ("HateSpeechDataset.csv", ('content', 'label', 'content_int')),
            ("HateSpeechDatasetBalanced.csv", ('content', 'label')),
        )
        for file_name, header_tokens in sources:
            df = read_if_present(os.path.join(dataset_path, file_name), file_name)
            if df is None:
                continue
            records = _collect_rows(
                df, 'Content', header_tokens,
                label_fn=lambda row: _safe_int_convert(row.get('Label', 0)),
                source_dataset='hate_speech_curated',
                source_file=file_name,
            )
            all_data.extend(records)
            print(f"    Processed {len(records)} valid samples from {file_name}")
        store('hate_speech_curated', all_data)
        return all_data

    def process_hate_speech_offensive():
        print("\nPROCESSING: hate_speech_offensive")
        dataset_path = os.path.join(base_dir, "hate_speech_and_offensive_language")
        all_data = []
        df = read_if_present(os.path.join(dataset_path, "labeled_data.csv"), "labeled_data.csv")
        if df is not None:
            # Classes 0 and 1 map to label 1; anything else (default 2) maps to 0.
            # NOTE(review): presumably 0=hate, 1=offensive, 2=neither per the
            # dataset card. Confirm.
            all_data = _collect_rows(
                df, 'tweet', ('tweet', 'class'),
                label_fn=lambda row: 1 if _safe_int_convert(row.get('class', 2)) in (0, 1) else 0,
                source_dataset='hate_speech_offensive',
                source_file='labeled_data.csv',
                extra_fn=lambda row: {'original_class': _safe_int_convert(row.get('class', 2))},
            )
            print(f"    Processed {len(all_data)} samples from labeled_data.csv")
        store('hate_speech_offensive', all_data)
        return all_data

    def process_suspicious_comm():
        print("\nPROCESSING: suspicious_comm")
        dataset_path = os.path.join(base_dir, "suspicious_communication_on_social_platforms")
        file_name = "Suspicious Communication on Social Platforms.csv"
        all_data = []
        df = read_if_present(os.path.join(dataset_path, file_name), file_name)
        if df is not None:
            all_data = _collect_rows(
                df, 'comments', ('comments', 'tagging'),
                label_fn=lambda row: _safe_int_convert(row.get('tagging', 0)),
                source_dataset='suspicious_comm',
                source_file=file_name,
            )
            print(f"    Processed {len(all_data)} samples from {file_name}")
        store('suspicious_comm', all_data)
        return all_data

    def process_jigsaw_toxic():
        print("\nPROCESSING: jigsaw_toxic")
        dataset_path = os.path.join(base_dir, "jigsaw-toxic-comment-classification-challenge")
        toxic_columns = ('toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate')
        header_tokens = ('comment_text', 'id')
        all_data = []

        def any_toxic(row):
            # 1 if any toxicity column equals 1. Values of -1, missing values
            # and unparsable values all count as "not toxic".
            return int(any(_safe_int_convert(row.get(c, 0)) == 1 for c in toxic_columns))

        train_df = read_if_present(os.path.join(dataset_path, "train.csv"), "train.csv")
        if train_df is not None:
            records = _collect_rows(train_df, 'comment_text', header_tokens, any_toxic,
                                    'jigsaw_toxic', 'train.csv')
            all_data.extend(records)
            toxic_count = sum(r['label'] for r in records)
            print(f"    Processed {len(records)} samples from train.csv ({toxic_count} toxic)")

        test_file = os.path.join(dataset_path, "test.csv")
        test_labels_file = os.path.join(dataset_path, "test_labels.csv")
        if os.path.exists(test_file) and os.path.exists(test_labels_file):
            print("  Reading: test.csv with test_labels.csv")
            test_df = _safe_read_csv(test_file)
            test_labels_df = _safe_read_csv(test_labels_file)
            if not test_df.empty and not test_labels_df.empty:
                merged_df = pd.merge(test_df, test_labels_df, on='id', how='inner')

                def scored_label(row):
                    # Rows whose 'toxic' value is -1 carry no usable label; skip them.
                    if row.get('toxic', 0) == -1:
                        return None
                    return any_toxic(row)

                records = _collect_rows(merged_df, 'comment_text', header_tokens, scored_label,
                                        'jigsaw_toxic', 'test.csv')
                all_data.extend(records)
                toxic_count = sum(r['label'] for r in records)
                print(f"    Processed {len(records)} samples from test data ({toxic_count} toxic)")

        store('jigsaw_toxic', all_data)
        return all_data

    # ---------- run all processors ----------
    print("PROCESSING TEXT DATASETS WITH SMART CLEANING...")

    process_hate_speech_curated()
    process_hate_speech_offensive()
    process_suspicious_comm()
    process_jigsaw_toxic()

    if not cleaned_frames:
        print("No datasets were successfully processed.")
        return

    # ---------- combined dataset + stratified splits ----------
    print("\nCREATING COMBINED DATASET...")

    all_dfs = []
    for dataset_name, frame in cleaned_frames.items():
        frame = frame.assign(dataset_source=dataset_name)
        all_dfs.append(frame)
        print(f"  Added {dataset_name}: {len(frame)} samples")

    combined_df = pd.concat(all_dfs, ignore_index=True).drop_duplicates(subset=['text'])
    combined_path = os.path.join(processed_dir, "ALL_TEXT_DATASETS_COMBINED.csv")
    combined_df.to_csv(combined_path, index=False)

    # 70% train, then the remaining 30% halved into val/test (15% each).
    train_df, temp_df = train_test_split(
        combined_df, test_size=0.3, stratify=combined_df['label'], random_state=42
    )
    val_df, test_df = train_test_split(
        temp_df, test_size=0.5, stratify=temp_df['label'], random_state=42
    )

    splits_dir = os.path.join(processed_dir, "splits")
    for split_name, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
        split_dir = os.path.join(splits_dir, split_name)
        os.makedirs(split_dir, exist_ok=True)
        split_df.to_csv(os.path.join(split_dir, f"text_{split_name}.csv"), index=False)

    print("TEXT PROCESSING COMPLETED.")

if __name__ == "__main__":
    # Run the text-dataset preprocessing pipeline when this file/cell is executed directly.
    process_text_datasets()





"""
INTERPRETATION OF DATASET SEPARABILITY

The processed text datasets contain multiple sources of hate, toxic, offensive,
and suspicious communication. After cleaning and normalization, certain
datasets exhibit clearer separability between classes (hate vs. non-hate)
than others.

In particular:
- Some datasets contain highly explicit toxic or hateful language, making
  class boundaries sharper and easier for models to learn.
- Other datasets include more subtle or context-dependent expressions,
  producing greater overlap between the "hate" and "non-hate" categories.

Overall, the combined dataset reflects a mix of clearly separable samples
and samples with strong class ambiguity. This variety closely resembles
real-world social-media text, where explicit and subtle forms of harmful
communication coexist.
"""


# In [None]:
"""
MEMOTION DATASET PREPROCESSING - WITH FIXED REPORT AND DOCUMENTATION
"""

import os
import pandas as pd
import re
from PIL import Image
import glob

def process_memotion_dataset():
    """Build the cleaned multimodal Memotion 7K dataset and an analysis report.

    Processing:
    - Labels come from ``datasets/memotion_dataset_7k/labels.csv``, or
      ``labels.xlsx`` as a fallback.
    - Each row's ``image_name`` is matched against the image files found under
      ``datasets/memotion_dataset_7k/images``.
    - Each meme text is cleaned.
    - The ``offensive`` category is mapped to a binary label by
      ``get_offensive_label``.

    Outputs:
    - ``datasets/processed_data/memotion_7k_multimodal.csv``
    - ``datasets/processed_data/memotion_analysis_report.txt``

    Returns:
        The deduplicated DataFrame. Returns None when the dataset folder or
        labels file is missing, when required label columns are absent, or
        when no row survives cleaning.
    """
    base_dir = r"datasets"
    memotion_dir = os.path.join(base_dir, "memotion_dataset_7k")
    processed_dir = os.path.join(base_dir, "processed_data")
    os.makedirs(processed_dir, exist_ok=True)

    print("PROCESSING MEMOTION 7K DATASET...")

    if not os.path.exists(memotion_dir):
        print(f"ERROR: Memotion dataset not found: {memotion_dir}")
        return None

    def smart_text_cleaning(text):
        """Lower-case, drop URLs/e-mails, and collapse repeated punctuation and whitespace.

        Returns "" for non-string input (e.g. NaN).
        """
        if not isinstance(text, str):
            return ""

        text = text.lower().strip()
        text = ' '.join(text.split())

        # Remove URLs and emails (usually noise).
        text = re.sub(r'http\S+|www\S+|https\S+', '', text)
        text = re.sub(r'\S+@\S+', '', text)

        # Collapse runs of ! ? . to a single mark.
        text = re.sub(r'!+', '!', text)
        text = re.sub(r'\?+', '?', text)
        text = re.sub(r'\.+', '.', text)

        # URL/e-mail removal can leave double spaces behind.
        return ' '.join(text.split())

    def load_labels():
        """Load the label table from labels.csv, or labels.xlsx as a fallback; None if neither exists."""
        labels_file = os.path.join(memotion_dir, "labels.csv")
        if os.path.exists(labels_file):
            print("Loading labels from: labels.csv")
            df = pd.read_csv(labels_file)
        else:
            labels_file = os.path.join(memotion_dir, "labels.xlsx")
            if not os.path.exists(labels_file):
                print("ERROR: No labels file found")
                return None
            print("Loading labels from: labels.xlsx")
            df = pd.read_excel(labels_file)

        print(f"   Loaded {len(df)} rows")
        print(f"   Columns: {list(df.columns)}")
        return df

    def find_image_files():
        """Map image file name -> full path for every image under memotion_dir/images.

        Extensions are matched case-insensitively. Plain glob patterns such as
        '*.jpg' miss files like 'image_12.JPG' on case-sensitive file systems.
        Hidden files and directories are skipped, as glob does.
        """
        images_dir = os.path.join(memotion_dir, "images")
        if not os.path.exists(images_dir):
            print(f"ERROR: Images directory not found: {images_dir}")
            return {}

        supported_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
        image_files = {}
        for root, dirs, files in os.walk(images_dir):
            dirs[:] = sorted(d for d in dirs if not d.startswith('.'))
            for filename in sorted(files):
                if filename.startswith('.'):
                    continue
                if os.path.splitext(filename)[1].lower() in supported_exts:
                    # On a duplicate basename in another subfolder, the last one wins.
                    image_files[filename] = os.path.join(root, filename)

        print(f"Found {len(image_files)} image files")
        return image_files

    def get_offensive_label(offensive_str):
        """
        Convert offensive string to binary label

        Memotion offensive categories:
        - 'not_offensive' → 0 (Not Hate) - No offensive content
        - 'slight' → 0 (Not Hate) - Mild/borderline offensive, not severe enough for hate speech
        - 'offensive' → 1 (Hate) - Clearly offensive content
        - 'very_offensive' → 1 (Hate) - Highly offensive content
        - 'hateful_offensive' → 1 (Hate) - Hate speech specifically targeting groups

        Non-string and unknown values map to 0.
        """
        if not isinstance(offensive_str, str):
            return 0
        offensive_str = offensive_str.lower().strip()
        return 1 if offensive_str in ('offensive', 'very_offensive', 'hateful_offensive') else 0

    def cell(row, column):
        """Return row[column], or None when the column is absent or the value is NaN."""
        value = row.get(column)
        return value if pd.notna(value) else None

    # Main processing
    labels_df = load_labels()
    if labels_df is None:
        return None

    # Fail fast with a clear message. Otherwise every row would raise KeyError
    # inside the loop, be swallowed, and end in a generic "No valid samples".
    text_columns = [c for c in ('text_corrected', 'text_ocr') if c in labels_df.columns]
    if 'offensive' not in labels_df.columns or not text_columns:
        print("ERROR: labels file needs an 'offensive' column and "
              "'text_corrected' and/or 'text_ocr'")
        return None

    image_files = find_image_files()

    metadata_columns = ('humour', 'sarcasm', 'motivational', 'overall_sentiment')
    all_data = []
    processed_count = 0
    image_found_count = 0
    hate_count = 0
    error_count = 0

    print(f"PROCESSING {len(labels_df)} SAMPLES...")

    for _, row in labels_df.iterrows():
        try:
            # Prefer the human-corrected text; fall back to the raw OCR text.
            raw_text = cell(row, 'text_corrected')
            if raw_text is None:
                raw_text = cell(row, 'text_ocr')
            if raw_text is None:
                continue
            text = str(raw_text)

            # Skip header rows repeated inside the file, and empty strings.
            if text == '' or text.lower() in ('text_corrected', 'text_ocr', 'image_name'):
                continue

            text = smart_text_cleaning(text)
            if len(text) < 5:  # also drops texts that cleaned to ""
                continue

            # The offensive category is the hate-speech signal; missing -> not offensive.
            offensive_value = cell(row, 'offensive')
            offensive_str = offensive_value if offensive_value is not None else "not_offensive"
            label = get_offensive_label(offensive_str)

            image_value = cell(row, 'image_name')
            image_name = str(image_value) if image_value is not None else ""
            image_path = None
            if image_name and image_name in image_files:
                image_path = image_files[image_name]
                try:
                    with Image.open(image_path) as img:
                        img.verify()  # integrity check
                    image_found_count += 1
                except Exception:  # narrowed from a bare except so Ctrl-C still works
                    image_path = None
                    print(f"    WARNING: Corrupted image: {image_name}")

            record = {
                'text': text,
                'label': label,
                'image_path': image_path or "",
                'image_name': image_name,
                'source_dataset': 'memotion_7k',
                'offensive_category': offensive_str,
            }
            # Optional annotation columns: absent column or NaN -> "".
            for column in metadata_columns:
                value = cell(row, column)
                record[column] = value if value is not None else ""
            record['has_image'] = 1 if image_path else 0
        except Exception:
            error_count += 1
            continue

        all_data.append(record)
        processed_count += 1
        if label == 1:
            hate_count += 1

        if processed_count % 1000 == 0:
            print(f"  Processed {processed_count}/{len(labels_df)} samples... ({hate_count} hate so far)")

    if error_count:
        print(f"  WARNING: {error_count} rows skipped because of processing errors")

    if not all_data:
        print("ERROR: No valid samples were processed!")
        return None

    memotion_df = pd.DataFrame(all_data)

    # Remove duplicates based on text and image.
    initial_count = len(memotion_df)
    memotion_df = memotion_df.drop_duplicates(subset=['text', 'image_name'])
    final_count = len(memotion_df)

    output_path = os.path.join(processed_dir, "memotion_7k_multimodal.csv")
    memotion_df.to_csv(output_path, index=False)

    # Statistics. total_samples >= 1 here, so the ratios below are safe.
    total_samples = len(memotion_df)
    total_hate = memotion_df['label'].sum()
    images_available = memotion_df['has_image'].sum()

    # (category, count, percent, HATE/NOT HATE), shared by the console output and the report.
    category_rows = [
        (category, count, (count / total_samples) * 100,
         "HATE" if get_offensive_label(category) == 1 else "NOT HATE")
        for category, count in memotion_df['offensive_category'].value_counts().items()
    ]

    print(f"""
MEMOTION PROCESSING COMPLETED!
=================================

DATASET SUMMARY:
- Total samples: {total_samples:,}
- Hate speech samples: {total_hate:,}
- Normal samples: {total_samples - total_hate:,}
- Hate ratio: {(total_hate/total_samples)*100:.1f}%
- Images available: {images_available:,} ({images_available/total_samples*100:.1f}%)

OFFENSIVE CATEGORY DISTRIBUTION:""")

    for category, count, percentage, hate_label in category_rows:
        print(f"  - {category}: {count:,} ({percentage:.1f}%) [{hate_label}]")

    print(f"""
OUTPUT:
- Dataset: {output_path}

FILES PROCESSED:
- Labels: {len(labels_df)} rows
- Image files found: {len(image_files)} files
- Images matched and verified: {image_found_count}
- Samples kept (after deduplication): {total_samples}
- Duplicates removed: {initial_count - final_count:,}
""")

    # Save the detailed analysis report.
    report_path = os.path.join(processed_dir, "memotion_analysis_report.txt")
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(f"""Memotion 7K Dataset Analysis Report
=====================================

Dataset Location: {memotion_dir}
Processing Date: {pd.Timestamp.now()}

FINAL STATISTICS:
- Total samples processed: {total_samples}
- Hate speech samples: {total_hate}
- Normal samples: {total_samples - total_hate}
- Hate speech ratio: {(total_hate/total_samples)*100:.1f}%
- Images available: {images_available} ({images_available/total_samples*100:.1f}%)

OFFENSIVE CATEGORY BREAKDOWN:
""")
        for category, count, percentage, hate_label in category_rows:
            f.write(f"- {category}: {count} samples ({percentage:.1f}%) [{hate_label}]\n")

        f.write(f"""
PROCESSING DETAILS:
- Original label rows: {len(labels_df)}
- Image files found: {len(image_files)}
- Rows kept before deduplication: {processed_count}
- Rows skipped due to errors: {error_count}
- Images successfully matched: {image_found_count}
- Duplicates removed: {initial_count - final_count}

LABELING LOGIC EXPLANATION:
- HATE (1): 'offensive', 'very_offensive', 'hateful_offensive'
  * Clear hate speech, offensive content targeting groups

- NOT HATE (0): 'not_offensive', 'slight'
  * 'not_offensive': No offensive content
  * 'slight': Mild/borderline offensive (insults, subtle prejudice) but not severe hate speech

CATEGORY DEFINITIONS:
- not_offensive: No offensive content whatsoever
- slight: Mild offensive content, insults, borderline comments
- offensive: Clearly offensive content, profanity, strong insults
- very_offensive: Highly offensive content, severe language
- hateful_offensive: Hate speech specifically targeting racial, religious, or other groups

ADDITIONAL FEATURES:
- humour: Categorical humor classification
- sarcasm: Categorical sarcasm classification
- motivational: Motivational content flag
- overall_sentiment: Overall sentiment classification

OUTPUT FILES:
- Main dataset: {output_path}
- This report: {report_path}
""")

    print(f"Detailed analysis saved to: {report_path}")
    return memotion_df

if __name__ == "__main__":
    # Run the Memotion preprocessing pipeline when this file/cell is executed directly.
    process_memotion_dataset()




"""
MEMOTION DATASET PREPROCESSING - WITH FIXED REPORT AND DOCUMENTATION

DATA PROCESSING PIPELINE:
- Loads and validates Memotion 7K dataset with corrected label mappings
- Implements smart text cleaning that preserves hate speech signals
- Converts offensive categories to unified binary labels
- Matches text with corresponding images and validates image integrity
- Removes duplicates and generates comprehensive analysis reports

SMART TEXT CLEANING:
- Lowercases text and collapses whitespace
- Removes URLs and emails
- Collapses runs of repeated '!', '?' and '.' into a single mark, so the punctuation
  cue is kept while its repetition count (intensity) is discarded

LABEL CONVERSION LOGIC:
- HATE (1): 'offensive', 'very_offensive', 'hateful_offensive'
- NOT HATE (0): 'not_offensive', 'slight' (mild/borderline content)

OVERALL INTERPRETATION:
This preprocessing system transforms the multimodal Memotion dataset into a clean,
standardized format for hate speech detection. It keeps each text linked to its
image (when the image is found and passes validation) and ensures data quality through
image validation, deduplication, and reporting. No class rebalancing is performed:
the hate/non-hate ratio follows the source 'offensive' labels, so any imbalance must be
handled downstream. Each sample retains traceable metadata for multimodal model training.
"""

# In [None]:
# train_novel_multimodal_t4_lora.py
# CLIP vision + consistency + hard-mining + LoRA (text default, vision optional)
# Direct 3-class labeling (Abusive, Offensive, Non-abusive) at dataset load time
# Adds: text/image projection to shared space, multimodal gate, image-only aux loss, text-modality dropout

# Optional installs:
# !pip install -U torch torchvision transformers scikit-learn pandas numpy pillow tqdm seaborn matplotlib peft

import os, gc, json, math, random, warnings, sys, subprocess
from dataclasses import dataclass
from typing import Optional, List, Tuple
from pathlib import Path

# Suppress every warning of category UserWarning for the rest of the process.
warnings.filterwarnings("ignore", category=UserWarning)
# NOTE(review): presumably disables HuggingFace `tokenizers` internal parallelism to
# avoid its fork-related warning/deadlock once DataLoader workers start. Confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ============== Config ==============
@dataclass
class CFG:
    """Central configuration for the multimodal (text + Memotion image) trainer.

    A single module-level instance ``cfg`` (created right below) is read directly
    by the datasets, model, loss helpers and training/evaluation functions.
    """
    # Paths: CSVs, splits and reports live under <base_dir>/<processed_dir_name>
    base_dir: str = "datasets"
    processed_dir_name: str = "processed_data"
    memotion_csv: str = "memotion_7k_multimodal.csv"   # file inside the processed dir
    splits_dirname: str = "splits"                     # text split CSVs live under <processed>/<splits_dirname>

    # Text CSVs (relative to the splits dir)
    text_train_csv: str = "train/text_train.csv"
    text_val_csv: str = "val/text_val.csv"
    text_test_csv: str = "test/text_test.csv"

    # Random subsampling for val/test (tracking during training); None = use every row
    seed: int = 42
    sample_val_text_n: Optional[int] = 7_000
    sample_test_text_n: Optional[int] = 5_000
    sample_val_memotion_n: Optional[int] = None
    sample_test_memotion_n: Optional[int] = None

    # Model (text + vision)
    text_model_name: str = "nreimers/MiniLMv2-L6-H384-distilled-from-RoBERTa-Large"
    use_clip_vision: bool = True
    clip_vision_model: str = "openai/clip-vit-base-patch32"
    image_backbone_fallback: str = "mobilenet_v2"   # used when CLIP is off; any other value -> resnet18
    freeze_text_encoder: bool = True
    freeze_vision_encoder: bool = True
    gradient_checkpointing: bool = True
    use_torch_compile: bool = False                 # NOTE(review): not referenced in this section

    # Tokenization / image
    max_len: int = 96                               # token truncation length
    image_size: int = 224                           # fallback resize size and zero-placeholder size
    pre_tokenize_text: bool = False                 # default for datasets' pre_tokenize argument

    # Training
    epochs: int = 30
    batch_size: int = 32
    grad_accum_steps: int = 1                       # micro-batches per optimizer update (scheduler length)
    lr: float = 1e-4          # bumped for LoRA
    weight_decay: float = 1e-2
    warmup_ratio: float = 0.06                      # fraction of total updates spent in linear warm-up
    label_smoothing: float = 0.05                   # eps for the one-hot hs3 targets and the ab2 CE
    grad_clip: float = 1.0
    early_stop_patience: int = 4
    mixed_precision: bool = True                    # fp16 autocast + GradScaler; datasets also emit half tensors on CUDA

    # Architecture
    hidden: int = 256
    dropout: float = 0.25
    ab2_mode: str = "ce"                            # "ce" -> CrossEntropyLoss, anything else -> BCEWithLogitsLoss
    add_meta_dim: int = 2                           # number of meta features ([sarcasm, humour])
    use_multi_task: bool = True
    proj_dim: int = 256  # shared projection dimension for text/image features

    # Loss tweaks
    lambda_consistency: float = 0.2                 # scale of the hs3/ab2 consistency loss
    image_only_loss_w: float = 0.5   # weight of the auxiliary loss on fusion-only logits (image rows)
    text_dropout_prob: float = 0.25  # intended chance to drop text per step (not used in this section)

    # Hard-mining curriculum (TEXT)
    hard_mining: bool = True
    pool_size: int = 5_000                          # candidates scored per epoch
    train_text_per_epoch: int = 5_000               # text examples selected per epoch
    hard_frac: float = 0.6                          # share of the selection taken from the hardest pool items

    # Memotion images per epoch
    images_per_epoch: int = 200

    # Balance
    use_class_weights: bool = True

    # LoRA config
    use_lora_text: bool = True
    use_lora_vision: bool = False      # enable if you want CLIP LoRA too
    lora_r_text: int = 8
    lora_alpha_text: int = 16
    lora_dropout_text: float = 0.1
    lora_target_text: Optional[List[str]] = None  # defaults to ["query","value"]

    lora_r_vision: int = 4
    lora_alpha_vision: int = 8
    lora_dropout_vision: float = 0.05
    lora_target_vision: Optional[List[str]] = None  # defaults to ["q_proj","v_proj"]

    # Reporting
    reports_dirname: str = "reports_novel"          # created under the processed dir
    checkpoint_name: str = "best_novel.pt"
    drive_save_path: Optional[str] = None
    final_full_eval: bool = False

# Global config instance read throughout the script.
cfg = CFG()

# ============== Device / AMP / Repro ==============
import torch
import torch.nn as nn
import torch.nn.functional as F

# Release cached GPU memory and Python garbage left over from earlier notebook cells.
torch.cuda.empty_cache(); gc.collect()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
    print("VRAM:", torch.cuda.get_device_properties(0).total_memory/1024**3, "GB")
    # Allow TF32 math in matmuls and cuDNN convolutions (faster, slightly less precise)
    # and let cuDNN autotune kernels; this gives up bit-exact reproducibility for speed.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# Cap intra-op CPU threads at 2.
torch.set_num_threads(2)

from torch.cuda.amp import GradScaler, autocast
# Loss scaler for fp16 training. NOTE(review): without CUDA, torch.cuda.amp.GradScaler
# disables itself (with a warning) even though cfg.mixed_precision is True.
scaler = GradScaler(enabled=cfg.mixed_precision)

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU plus every CUDA device)."""
    import random as _py_random
    import numpy as _np

    for seeder in (_py_random.seed, _np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed Python/NumPy/torch RNGs once at import time so later draws (e.g. model weight init) are reproducible.
set_seed(cfg.seed)

# ============== Imports ==============
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModel, AutoConfig, CLIPVisionModel, CLIPImageProcessor
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

# PEFT (LoRA)
try:
    from peft import LoraConfig, get_peft_model, TaskType
except ImportError:
    print("Installing 'peft' for LoRA...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "peft"])
    from peft import LoraConfig, get_peft_model, TaskType

# ============== JSON-safe ==============
def to_serializable(o):
    """Recursively convert ``o`` into JSON-serializable builtin types.

    - dict -> dict (values converted recursively; NumPy-scalar keys become builtins,
      since ``json.dumps`` only accepts str/int/float/bool/None keys)
    - list / tuple / set -> list
    - np.ndarray -> nested lists
    - NumPy integer / floating / bool scalars -> int / float / bool
    - torch.Tensor -> nested lists (detached, moved to CPU)
    - pathlib.Path -> str
    Anything else is returned unchanged.
    """
    import numpy as np, torch

    def _key(k):
        if isinstance(k, np.bool_):
            return bool(k)
        if isinstance(k, np.integer):
            return int(k)
        if isinstance(k, np.floating):
            return float(k)
        return k

    if isinstance(o, dict):
        return {_key(k): to_serializable(v) for k, v in o.items()}
    if isinstance(o, (list, tuple, set)):
        return [to_serializable(v) for v in o]
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, np.bool_):
        # np.bool_ is neither np.integer nor a Python bool; json.dumps rejects it as-is.
        return bool(o)
    if isinstance(o, (np.integer,)):
        return int(o)
    if isinstance(o, (np.floating,)):
        return float(o)
    if isinstance(o, torch.Tensor):
        return o.detach().cpu().tolist()
    if isinstance(o, Path):
        return str(o)
    return o

# ============== Label helpers ==============
def map_text_labels(row: pd.Series):
    """Extract ``(hs3, ab2)`` labels from a text-dataset row.

    hs3: ``original_class`` parsed as an int; kept only if it is 0, 1 or 2.
    ab2: ``label`` parsed as an int; 1 if it equals 1, otherwise 0.
    A missing, empty or unparsable value yields -100, matching the ``ignore_index``
    used by the losses.
    """
    hs3 = -100
    if "original_class" in row and str(row["original_class"]).strip() != "":
        try:
            oc = int(float(row["original_class"]))
            hs3 = oc if oc in (0, 1, 2) else -100
        except (TypeError, ValueError, OverflowError):
            # e.g. "abc" / NaN -> ValueError, inf -> OverflowError
            hs3 = -100
    ab2 = -100
    if "label" in row and str(row["label"]).strip() != "":
        try:
            v = int(float(row["label"]))
            ab2 = 1 if v == 1 else 0
        except (TypeError, ValueError, OverflowError):
            ab2 = -100
    return hs3, ab2

def map_memotion_labels(off_cat: str):
    """Map a Memotion ``offensive_category`` value to ``(hs3, ab2)`` labels.

    hs3 (3-class): 0 = Abusive ("hateful_offensive"), 1 = Offensive ("offensive",
    "very_offensive"), 2 = Non-abusive ("slight", "not_offensive").
    ab2 (binary): 1 for the three offensive categories, 0 for "slight"/"not_offensive".

    Case, surrounding whitespace, and space/hyphen separators are normalised
    ("Very Offensive" -> "very_offensive"). Empty or unrecognised categories map to
    ``(-100, -100)`` so they are ignored by the losses/metrics. Without this, they
    would be counted as non-abusive (and then fused into hs3 = 2).
    """
    s = "_".join(str(off_cat).strip().lower().replace("-", " ").replace("_", " ").split())
    if s == "hateful_offensive":
        return 0, 1
    if s in ("offensive", "very_offensive"):
        return 1, 1
    if s in ("slight", "not_offensive"):
        return 2, 0
    return -100, -100

# Values that explicitly mean "no" for an annotation column.
_NEGATED_LABEL_VALUES = {"", "none", "nan", "no", "false", "0", "not", "non"}
# Prefixes of explicitly negated label values, e.g. "not_sarcastic", "not funny", "no_humour".
_NEGATED_LABEL_PREFIXES = ("not_", "not ", "not-", "no_", "no ", "no-", "non_", "non ", "non-")
# Positive sarcasm values. "general" / "twisted_meaning" / "very_twisted" are presumably
# Memotion 7K's sarcasm grades -- verify against the processed CSV.
_SARCASM_POSITIVE = {"sarcasm", "sarcastic", "yes", "true", "1",
                     "general", "twisted_meaning", "very_twisted"}


def _is_negated_label(s: str) -> bool:
    """Return True if the lower-cased, stripped label string ``s`` is an explicit negative."""
    return s in _NEGATED_LABEL_VALUES or s.startswith(_NEGATED_LABEL_PREFIXES)


def _numeric_flag(s: str):
    """Return 1 (non-zero) / 0 (zero) if ``s`` parses as a number, else None."""
    try:
        return 1 if float(s) != 0 else 0
    except ValueError:
        return None


def parse_sarcasm(val) -> int:
    """Binarise a sarcasm annotation: 1 = sarcastic, 0 = not sarcastic or missing.

    Negations are checked first: a plain substring test for "sarcas" would
    otherwise map "not_sarcastic" to 1. Numeric values map non-zero -> 1.
    """
    s = str(val).strip().lower()
    if _is_negated_label(s):
        return 0
    if s in _SARCASM_POSITIVE or "sarcas" in s or "twisted" in s:
        return 1
    num = _numeric_flag(s)
    return num if num is not None else 0


def parse_humour(val) -> int:
    """Binarise a humour annotation: 1 = funny/hilarious/humorous, 0 = negated or missing.

    Negations ("not_funny", "no_humour", "none", "", ...) are checked before keywords.
    Numeric values map non-zero -> 1.
    """
    s = str(val).strip().lower()
    if _is_negated_label(s):
        return 0
    if any(k in s for k in ("funny", "hilar", "humor", "humour")):
        return 1
    num = _numeric_flag(s)
    return num if num is not None else 0

def one_hot_smooth(y: int, num_classes: int = 3, eps: float = 0.0) -> torch.Tensor:
    """Return a float32 label-smoothed one-hot vector of length ``num_classes``.

    Class ``y`` receives ``1 - eps`` and each other class ``eps / (num_classes - 1)``.
    A negative ``y`` (such as the -100 "unknown" label) produces an all-zero vector.
    """
    off_value = eps / (num_classes - 1)
    if y < 0:
        return torch.zeros(num_classes, dtype=torch.float32)
    smoothed = torch.full((num_classes,), off_value, dtype=torch.float32)
    smoothed[y] = 1.0 - eps
    return smoothed

class SoftCrossEntropyLoss(nn.Module):
    """Cross-entropy accepting soft (probability) targets or hard class indices.

    * 2-D ``target`` of shape (batch, num_classes): rows summing to 0 are ignored
      (``one_hot_smooth`` produces such rows for unknown labels). The loss is
      ``-(w * target * log_softmax(input)).sum(-1)`` averaged over the valid rows,
      with ``w`` the optional per-class ``weight`` (all ones if not given). This
      matches PyTorch's semantics for probability targets.
    * 1-D ``target``: delegates to ``F.cross_entropy`` with ``weight`` and ``ignore_index``.

    Args:
        ignore_index: hard label to ignore (1-D targets only).
        weight: optional per-class weights of shape (num_classes,). Stored as a buffer
            so it follows ``.to(device)``.
    """

    def __init__(self, ignore_index=-100, weight: Optional[torch.Tensor] = None):
        super().__init__()
        self.ignore_index = ignore_index
        self.register_buffer("weight", weight)

    def forward(self, input, target):
        if target.dim() != 2:
            return F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index)
        target = target.to(input.dtype)
        valid = target.sum(dim=-1) > 0
        if not valid.any():
            return input.new_tensor(0.0)
        log_probs = F.log_softmax(input[valid], dim=-1)
        per_class = target[valid] * log_probs
        if self.weight is not None:
            per_class = per_class * self.weight.to(per_class.dtype)
        return -per_class.sum(dim=-1).mean()

def compute_weights_from_counts(counts: np.ndarray):
    """Inverse-frequency class weights normalised to mean 1 over the observed classes.

    Args:
        counts: per-class sample counts (array-like).
    Returns:
        float32 tensor of weights, or None if every count is zero.

    Classes with a zero count get a neutral weight of 1.0. Giving them 1/1e-8 would
    dominate the mean and push every observed class's weight towards 0. Returning 0
    would make CE produce NaN if such a class ever appears.
    """
    counts = np.asarray(counts, dtype=np.float64)
    total = counts.sum()
    if total == 0:
        return None
    present = counts > 0
    w = np.ones_like(counts)
    inv_freq = total / counts[present]               # == 1 / freq
    w[present] = inv_freq / (inv_freq.mean() + 1e-8)
    return torch.tensor(w, dtype=torch.float32)

def build_losses(ab2_mode="ce", hs3_class_weights=None, ab2_class_weights=None, device="cpu", label_smoothing=0.0):
    """Create the loss functions for the hs3 and ab2 heads.

    Returns a dict:
      "hs3": ``SoftCrossEntropyLoss`` weighted by ``hs3_class_weights`` when given.
             ``label_smoothing`` is not applied here because the hs3 targets are
             expected to come pre-smoothed from ``one_hot_smooth``.
      "ab2": ``CrossEntropyLoss(ignore_index=-100, weight=ab2_class_weights,
             label_smoothing=...)`` when ``ab2_mode == "ce"`` (falls back to no label
             smoothing on torch versions without that argument); otherwise
             ``BCEWithLogitsLoss(pos_weight=ab2_class_weights[1:2])``.
    Non-tensor weight arguments are treated as "no weights".

    NOTE(review): the BCE branch expects one logit and float targets, while the
    model's ab2 heads emit 2 logits and labels may be -100. It looks unusable as is.
    """
    def _ce(weight):
        try: return nn.CrossEntropyLoss(ignore_index=-100, weight=weight, label_smoothing=label_smoothing)
        except TypeError: return nn.CrossEntropyLoss(ignore_index=-100, weight=weight)
    hs3_w = hs3_class_weights.to(device) if isinstance(hs3_class_weights, torch.Tensor) else None
    ab2_w = ab2_class_weights.to(device) if isinstance(ab2_class_weights, torch.Tensor) else None
    crit = {"hs3": SoftCrossEntropyLoss(ignore_index=-100, weight=hs3_w)}
    crit["ab2"] = _ce(ab2_w) if ab2_mode == "ce" else nn.BCEWithLogitsLoss(pos_weight=ab2_w[1:2] if ab2_w is not None else None)
    return crit

def build_scheduler(optimizer, train_loader_len, cfg):
    """Linear warm-up then cosine decay to 0, stepped once per optimizer update.

    Returns ``(scheduler, total_updates)`` where
    ``total_updates = epochs * max(1, ceil(train_loader_len / grad_accum_steps))``.
    Warm-up lasts ``max(1, int(warmup_ratio * total_updates))`` updates.
    """
    steps_per_epoch = max(1, math.ceil(train_loader_len / cfg.grad_accum_steps))
    total_updates = cfg.epochs * steps_per_epoch
    warmup = max(1, int(cfg.warmup_ratio * total_updates))
    decay_span = float(max(1, total_updates - warmup))

    def lr_lambda(step):
        if step < warmup:
            return step / float(max(1, warmup))
        frac = min(1.0, max(0.0, (step - warmup) / decay_span))
        return 0.5 * (math.cos(math.pi * frac) + 1.0)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    return scheduler, total_updates

# ============== Tokenizers / processors ==============
# Module-level preprocessing objects shared by the datasets and collate_fn.
# from_pretrained loads from the HF hub or the local cache when this cell runs.
text_tokenizer = AutoTokenizer.from_pretrained(cfg.text_model_name, use_fast=True)
# CLIP image preprocessing loaded from the vision checkpoint; None when CLIP is disabled.
clip_processor = CLIPImageProcessor.from_pretrained(cfg.clip_vision_model) if cfg.use_clip_vision else None
# Fallback preprocessing for the torchvision backbone: square resize + ImageNet mean/std normalization.
img_tfm_fallback = transforms.Compose([
    transforms.Resize((cfg.image_size, cfg.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# ============== Dataset (Memotion path FIX + Fused labels at load time) ==============
class OptimizedProcessedCSVSet(Dataset):
    """Dataset over a processed text CSV or the Memotion multimodal CSV.

    Each item is a dict with: raw ``text``; unpadded ``input_ids`` / ``attention_mask``
    (padding happens in ``collate_fn``); an ``image`` tensor (zeros if none is
    available) and ``has_image`` flag; a 2-element ``meta`` vector [sarcasm, humour]
    (Memotion only, zeros otherwise); the fused 3-class ``hs3_label`` (0=Abusive,
    1=Offensive, 2=Non-abusive, -100=unknown); its label-smoothed one-hot
    ``hs3_target``; and the binary ``ab2_label``.

    Args:
        csv_path: CSV file to load.
        is_memotion: read Memotion columns (text, image_path, offensive_category,
            sarcasm, humour) instead of text columns (text, original_class, label).
        split: for Memotion, "train"/"val"/"test" selects a fixed 80/10/10 partition
            (seed 123); any other value keeps all rows.
        sample_n: if set and smaller than the dataset, subsample ``sample_n`` rows
            (seeded with ``cfg.seed``).
        pre_tokenize: tokenize all texts up front; defaults to ``cfg.pre_tokenize_text``.
    """

    # Rows whose text has fewer characters than this are dropped at load time.
    MIN_TEXT_LEN = 3

    def __init__(self, csv_path: Path, is_memotion=False, split: Optional[str]=None,
                 sample_n: Optional[int]=None, pre_tokenize: Optional[bool]=None):
        self.is_memotion = is_memotion
        self.split = split
        self.pre_tokenize = cfg.pre_tokenize_text if pre_tokenize is None else pre_tokenize
        # Memotion images are looked up by file name only, under this directory.
        self.img_root = Path(cfg.base_dir) / cfg.processed_dir_name / "memotion_dataset_7k" / "images"

        hdr = pd.read_csv(csv_path, nrows=0).columns
        if is_memotion:
            want = ["text", "image_path", "offensive_category", "sarcasm", "humour"]
        else:
            want = ["text", "original_class", "label"]
        usecols = [c for c in want if c in hdr]
        # Read only the wanted columns that exist; if none exist, fall back to all columns.
        df = pd.read_csv(csv_path, usecols=usecols or None).fillna("")

        if is_memotion and split in {"train", "val", "test"}:
            # Fixed seed (independent of cfg.seed) keeps the Memotion partitions stable across runs.
            rng = np.random.RandomState(123)
            idx = np.arange(len(df)); rng.shuffle(idx)
            n = len(idx); n_train = int(0.8 * n); n_val = int(0.1 * n)
            part = {"train": idx[:n_train], "val": idx[n_train:n_train + n_val], "test": idx[n_train + n_val:]}[split]
            df = df.iloc[part].reset_index(drop=True)

        if "text" in df.columns and self.MIN_TEXT_LEN > 0:
            df = df[df["text"].astype(str).str.len() >= self.MIN_TEXT_LEN]

        if sample_n is not None and len(df) > sample_n:
            df = df.sample(n=int(sample_n), random_state=cfg.seed, replace=False).reset_index(drop=True)

        self.df = df.reset_index(drop=True)

        # Fused labels (0=Abusive, 1=Offensive, 2=Non-abusive). When hs3 is unknown, fall
        # back to ab2: ab2 == 0 -> Non-abusive (2), ab2 == 1 -> Offensive (1), else -100.
        self.labels = []
        for _, row in self.df.iterrows():
            if is_memotion:
                hs3, ab2 = map_memotion_labels(row.get("offensive_category", ""))
            else:
                hs3, ab2 = map_text_labels(row)
            fused_hs3 = hs3 if hs3 != -100 else (2 if ab2 == 0 else (1 if ab2 == 1 else -100))
            self.labels.append((fused_hs3, ab2))

        # Emit fp16 image/meta/target tensors when training with CUDA mixed precision.
        self.use_half = torch.cuda.is_available() and cfg.mixed_precision

        self.enc_input_ids, self.enc_attention_mask = None, None
        if self.pre_tokenize and len(self.df) > 0 and "text" in self.df.columns:
            # Strip exactly like __getitem__ does so pre-tokenized and on-the-fly ids agree.
            texts = self.df["text"].astype(str).str.strip().tolist()
            enc = text_tokenizer(texts, truncation=True, max_length=cfg.max_len, padding=False)
            self.enc_input_ids, self.enc_attention_mask = enc["input_ids"], enc["attention_mask"]

    def __len__(self): return len(self.df)

    def _load_image(self, row):
        """Load and preprocess the row's Memotion image.

        Returns ``(tensor, True)`` on success, or ``(None, False)`` if the file is
        missing or unreadable.
        """
        fname = Path(str(row.get("image_path", "")).strip()).name
        if not fname:
            return None, False
        p = self.img_root / fname
        if not p.exists():
            return None, False
        try:
            # The context manager closes the file handle; a bare Image.open can leave it open.
            with Image.open(p) as im:
                pil = im.convert("RGB")
            if cfg.use_clip_vision:
                return clip_processor(images=pil, return_tensors="pt")["pixel_values"][0], True
            return img_tfm_fallback(pil), True
        except Exception:
            # Best-effort: a corrupt/unreadable image falls back to the zero placeholder.
            return None, False

    def __getitem__(self, i):
        row = self.df.iloc[i]
        text = str(row.get("text", "")).strip()

        if self.enc_input_ids is not None:
            input_ids, attention_mask = self.enc_input_ids[i], self.enc_attention_mask[i]
        else:
            enc = text_tokenizer(text, truncation=True, max_length=cfg.max_len, padding=False)
            input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]

        img, has_image = self._load_image(row) if self.is_memotion else (None, False)
        if img is None:
            img = torch.zeros(3, cfg.image_size, cfg.image_size, dtype=torch.float32)
        if self.use_half: img = img.half()

        sarcasm = parse_sarcasm(row.get("sarcasm", "")) if self.is_memotion and "sarcasm" in row else 0
        humour  = parse_humour(row.get("humour", ""))   if self.is_memotion and "humour" in row else 0
        meta = torch.tensor([sarcasm, humour], dtype=torch.float16 if self.use_half else torch.float32)

        hs3, ab2 = self.labels[i]
        hs3_target = one_hot_smooth(int(hs3), num_classes=3, eps=cfg.label_smoothing)
        if self.use_half: hs3_target = hs3_target.half()

        return {
            "text": text,
            "input_ids": input_ids, "attention_mask": attention_mask,
            "image": img, "has_image": torch.tensor(has_image, dtype=torch.bool),
            "meta": meta,
            "hs3_label": torch.tensor(hs3, dtype=torch.long),
            "hs3_target": hs3_target,
            "ab2_label": torch.tensor(ab2, dtype=torch.long),
        }

# ============== Model (LoRA + mean pooling + projections + multimodal gate) ==============
class OptimizedMultiModalGated(nn.Module):
    """Gated text/image classifier with 3-class (hs3) and binary (ab2) heads.

    Architecture:
    - Text: HF encoder (optionally LoRA-adapted), mean-pooled over the attention mask.
    - Image: CLIP vision tower (``pooler_output``) or a torchvision MobileNetV2/ResNet18
      fallback.
    - Both modalities are projected to ``cfg.proj_dim``.
    - A fusion MLP consumes [text, image, meta]; a text-only MLP consumes [text, meta].
    - A sigmoid gate, forced to 0 on rows without an image, mixes the fusion and
      text-only logits.
    """
    def __init__(self, text_model_name, freeze_text=True, use_clip=True, clip_model_name=None,
                 fallback_backbone="mobilenet_v2", freeze_vision=True,
                 use_multi_task=True, ab2_mode="ce", hidden=256, dropout=0.25, add_meta_dim=2,
                 gradient_checkpointing=True):
        super().__init__()
        self.use_multi_task = use_multi_task
        self.ab2_mode = ab2_mode
        self.add_meta_dim = add_meta_dim
        self.use_clip = use_clip
        self.proj_dim = cfg.proj_dim

        # Text encoder. NOTE(review): ``add_pooling_layer`` is usually a model-constructor
        # kwarg, not a config attribute, so this check may never fire and a pooler may still
        # be built. It is unused either way: features come from mean pooling below.
        txt_cfg = AutoConfig.from_pretrained(text_model_name)
        if hasattr(txt_cfg, "add_pooling_layer"):
            txt_cfg.add_pooling_layer = False
        self.text_model = AutoModel.from_pretrained(text_model_name, config=txt_cfg)
        if gradient_checkpointing and hasattr(self.text_model, "gradient_checkpointing_enable"):
            try:
                self.text_model.gradient_checkpointing_enable()
                # With a frozen base (or LoRA, which freezes the base too), the embedding output
                # does not require grad. Reentrant checkpointing then yields no gradients for the
                # trainable LoRA weights inside checkpointed layers. Making the embedding output
                # require grad restores them (the hook survives the PEFT wrapping below).
                if hasattr(self.text_model, "enable_input_require_grads"):
                    self.text_model.enable_input_require_grads()
            except Exception as e:
                print(f"Gradient checkpointing not enabled for the text model: {e}")
        tdim = self.text_model.config.hidden_size
        if freeze_text:
            for p in self.text_model.parameters(): p.requires_grad = False

        # LoRA on text
        if getattr(cfg, "use_lora_text", False):
            targets = cfg.lora_target_text or ["query", "value"]
            lora_text_cfg = LoraConfig(
                r=cfg.lora_r_text, lora_alpha=cfg.lora_alpha_text, lora_dropout=cfg.lora_dropout_text,
                target_modules=targets, bias="none", task_type=TaskType.FEATURE_EXTRACTION
            )
            self.text_model = get_peft_model(self.text_model, lora_text_cfg)

        # Vision encoder
        if use_clip:
            self.vision = CLIPVisionModel.from_pretrained(clip_model_name or "openai/clip-vit-base-patch32")
            idim = self.vision.config.hidden_size
            if freeze_vision:
                for p in self.vision.parameters(): p.requires_grad = False
            if getattr(cfg, "use_lora_vision", False):
                vtargets = cfg.lora_target_vision or ["q_proj", "v_proj"]
                lora_vis_cfg = LoraConfig(
                    r=cfg.lora_r_vision, lora_alpha=cfg.lora_alpha_vision, lora_dropout=cfg.lora_dropout_vision,
                    target_modules=vtargets, bias="none", task_type=TaskType.FEATURE_EXTRACTION
                )
                self.vision = get_peft_model(self.vision, lora_vis_cfg)
            self.use_clip = True
        else:
            if fallback_backbone.lower() == "mobilenet_v2":
                try:
                    from torchvision.models import MobileNet_V2_Weights
                    im = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
                except Exception:
                    im = models.mobilenet_v2(pretrained=True)
                idim = 1280; im.classifier = nn.Identity()
            else:
                try:
                    from torchvision.models import ResNet18_Weights
                    backbone = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
                except Exception:
                    backbone = models.resnet18(pretrained=True)
                idim = 512; im = nn.Sequential(*list(backbone.children())[:-1], nn.Flatten(1))
            if freeze_vision:
                for p in im.parameters(): p.requires_grad = False
            self.vision = im
            self.use_clip = False

        # Projections to the shared space
        self.t_proj = nn.Sequential(
            nn.Linear(tdim, self.proj_dim), nn.LayerNorm(self.proj_dim),
            nn.GELU(), nn.Dropout(dropout)
        )
        self.i_proj = nn.Sequential(
            nn.Linear(idim, self.proj_dim), nn.LayerNorm(self.proj_dim),
            nn.GELU(), nn.Dropout(dropout)
        )

        # Backbones consume the projected features (+ meta)
        fusion_in = self.proj_dim + self.proj_dim + (add_meta_dim if add_meta_dim > 0 else 0)
        text_in   = self.proj_dim + (add_meta_dim if add_meta_dim > 0 else 0)

        self.fusion_backbone = nn.Sequential(
            nn.Linear(fusion_in, hidden), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(hidden, hidden//2), nn.GELU(), nn.Dropout(dropout),
        )
        self.text_backbone = nn.Sequential(
            nn.Linear(text_in, hidden), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(hidden, hidden//2), nn.GELU(), nn.Dropout(dropout),
        )

        # Gate sees both modalities (+meta)
        gate_in = self.proj_dim + self.proj_dim + add_meta_dim
        self.gate_proj = nn.Sequential(
            nn.Linear(gate_in, 64), nn.ReLU(),
            nn.Linear(64, 1), nn.Sigmoid()
        )

        # Heads
        self.hs3_fusion_head = nn.Linear(hidden//2, 3)
        self.hs3_text_head   = nn.Linear(hidden//2, 3)
        self.ab2_fusion_head = nn.Linear(hidden//2, 2)
        self.ab2_text_head   = nn.Linear(hidden//2, 2)

    def _mean_pool(self, last_hidden_state, attention_mask):
        """Average token states over positions where ``attention_mask`` is non-zero."""
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)
        return summed / denom

    def forward(self, input_ids, attention_mask, images, has_image_mask, meta=None):
        """Compute gated hs3/ab2 logits.

        Args:
            input_ids, attention_mask: token tensors for the text encoder.
            images: batch of pixel tensors indexed along dim 0 (presumably (B, 3, H, W)).
            has_image_mask: bool tensor with B elements. Only these rows go through the
                vision encoder; other rows get zero image features and a gate of 0.
            meta: optional (B, add_meta_dim) features, cast to the text feature dtype.
        Returns:
            dict with the gated ``hs3_logits`` / ``ab2_logits``, the detached
            ``gate_weight``, and the per-branch logits (``*_f_only`` fusion,
            ``*_t_only`` text).
        """
        # Text
        out_text = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        t_feat = self._mean_pool(out_text.last_hidden_state, attention_mask)

        if (meta is not None) and (meta.dtype != t_feat.dtype):
            meta = meta.to(t_feat.dtype)

        # Vision: run the encoder only on rows that actually have an image and scatter the
        # raw features back into a zero-filled (B, idim) matrix.
        B = images.size(0)
        img_mask_flat = has_image_mask.view(-1)
        raw_dim = self.i_proj[0].in_features
        if img_mask_flat.any():
            idx = torch.nonzero(img_mask_flat, as_tuple=False).squeeze(1)
            images_sub = images[idx].to(memory_format=torch.channels_last)
            if self.use_clip:
                i_sub = self.vision(pixel_values=images_sub).pooler_output
            else:
                i_sub = self.vision(images_sub)
            i_sub = i_sub.to(images_sub.dtype)
            raw_i_feat = images.new_zeros((B, raw_dim), dtype=images_sub.dtype)
            raw_i_feat[idx] = i_sub
        else:
            raw_i_feat = images.new_zeros((B, raw_dim), dtype=images.dtype)

        # Project both modalities
        t_vec = self.t_proj(t_feat)
        i_vec = self.i_proj(raw_i_feat)
        img_mask = img_mask_flat.float().unsqueeze(1)

        # Inputs
        if self.add_meta_dim and meta is not None:
            t_in = torch.cat([t_vec, meta], dim=1)
            f_in = torch.cat([t_vec, i_vec, meta], dim=1)
        else:
            t_in = t_vec
            f_in = torch.cat([t_vec, i_vec], dim=1)
        g_in = f_in  # the gate sees the same features as the fusion backbone

        t_repr = self.text_backbone(t_in)
        f_repr = self.fusion_backbone(f_in)

        # Rows without an image fall back entirely to the text-only branch.
        gate_weight = self.gate_proj(g_in) * img_mask

        hs3_f = self.hs3_fusion_head(f_repr)
        hs3_t = self.hs3_text_head(t_repr)
        hs3_logits = gate_weight * hs3_f + (1 - gate_weight) * hs3_t

        ab2_f = self.ab2_fusion_head(f_repr)
        ab2_t = self.ab2_text_head(t_repr)
        ab2_logits = gate_weight * ab2_f + (1 - gate_weight) * ab2_t

        return {
            "hs3_logits": hs3_logits,
            "ab2_logits": ab2_logits,
            "gate_weight": gate_weight.detach(),
            "hs3_f_only": hs3_f, "ab2_f_only": ab2_f,
            "hs3_t_only": hs3_t, "ab2_t_only": ab2_t
        }

# ============== Collate / utils ==============
def collate_fn(batch):
    """Pad token sequences per batch and stack the per-sample tensors.

    ``None`` samples are dropped; returns ``None`` when nothing remains. Token ids
    and masks are padded with the module-level ``text_tokenizer`` to the longest
    sequence in the batch, rounded up to a multiple of 8.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return None
    token_features = [
        {"input_ids": item["input_ids"], "attention_mask": item["attention_mask"]}
        for item in samples
    ]
    padded = text_tokenizer.pad(token_features, padding=True, pad_to_multiple_of=8, return_tensors="pt")

    collated = {
        "text": [item["text"] for item in samples],
        "input_ids": padded["input_ids"],
        "attention_mask": padded["attention_mask"],
    }
    for key in ("image", "has_image", "meta", "hs3_label", "hs3_target", "ab2_label"):
        collated[key] = torch.stack([item[key] for item in samples])
    return collated

def clear_memory():
    """Empty the CUDA caching allocator and wait for pending kernels (if CUDA exists), then run gc."""
    if not torch.cuda.is_available():
        gc.collect()
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gc.collect()

def print_gpu_memory():
    """Print allocated and reserved CUDA memory in GiB; prints nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated() / gib
    reserved = torch.cuda.memory_reserved() / gib
    print(f"GPU Mem - Alloc {allocated:.2f} GB | Reserved {reserved:.2f} GB")

# ============== Consistency loss ==============
def hierarchical_consistency_loss(hs3_logits, ab2_logits):
    """Symmetric KL between the ab2 head and the binary distribution implied by hs3.

    The hs3 probabilities are collapsed to [non-abusive, abusive] =
    [p(class 2), p(class 0) + p(class 1)], matching the ab2 label order. Both
    distributions are floored at 1e-8 before taking logs. Returns
    KL(implied || ab2) + KL(ab2 || implied), each with ``batchmean`` reduction.
    """
    hs3_probs = F.softmax(hs3_logits, dim=-1)
    non_abusive = hs3_probs[:, 2]
    abusive = hs3_probs[:, 0] + hs3_probs[:, 1]
    implied = torch.stack([non_abusive, abusive], dim=1).clamp_min(1e-8)
    predicted = F.softmax(ab2_logits, dim=-1).clamp_min(1e-8)
    kl_implied_vs_pred = F.kl_div(predicted.log(), implied, reduction="batchmean")
    kl_pred_vs_implied = F.kl_div(implied.log(), predicted, reduction="batchmean")
    return kl_implied_vs_pred + kl_pred_vs_implied

# ============== Evaluation ==============
@torch.no_grad()
def evaluate(loader, model, crit, split="val"):
    """Evaluate ``model`` on ``loader`` and print/return loss plus hs3/ab2 metrics.

    The loss mirrors the training objective and is averaged over batches. It is the sum of:
    - soft CE on hs3;
    - CE on ab2;
    - ``cfg.lambda_consistency`` * consistency loss;
    - ``cfg.image_only_loss_w`` * (hs3 + ab2 loss of the fusion-only heads on rows
      with an image).
    Rows labelled -100 are excluded from the metrics.

    Returns:
        {"loss": float, "hs3": metrics or None, "ab2": metrics or None}. Each metrics dict
        holds acc / f1_macro / precision_macro / recall_macro and the raw ``y`` / ``p`` arrays.
    """
    model.eval()
    # torch.cuda.amp.autocast is CUDA-only; on CPU it would only warn and disable itself.
    amp_enabled = cfg.mixed_precision and device == "cuda"
    agg = {"loss": 0.0, "n": 0}
    hs3_y, hs3_p = [], []
    ab2_y, ab2_p = [], []

    def ab2_loss_fn(logits, labels):
        # CrossEntropyLoss returns NaN (0/0) when every target equals ignore_index; such
        # batches contribute 0 instead of turning the running loss into NaN.
        if not (labels != -100).any():
            return logits.new_zeros(())
        return crit["ab2"](logits, labels)

    for batch in tqdm(loader, desc=f"Evaluating {split}", leave=False):
        if batch is None: continue
        ids = batch["input_ids"].to(device, non_blocking=True)
        attn = batch["attention_mask"].to(device, non_blocking=True)
        imgs = batch["image"].to(device, non_blocking=True)
        has_img = batch["has_image"].to(device, non_blocking=True)
        meta = batch["meta"].to(device, non_blocking=True) if cfg.add_meta_dim > 0 else None
        y_hs3 = batch["hs3_label"].to(device, non_blocking=True)
        y_ab2 = batch["ab2_label"].to(device, non_blocking=True)
        hs3_target = batch["hs3_target"].to(device, non_blocking=True)

        with autocast(enabled=amp_enabled, dtype=torch.float16):
            out = model(ids, attn, imgs, has_img, meta)
            hs3_loss = crit["hs3"](out["hs3_logits"], hs3_target)
            ab2_loss = ab2_loss_fn(out["ab2_logits"], y_ab2)
            cons_loss = hierarchical_consistency_loss(out["hs3_logits"], out["ab2_logits"]) * cfg.lambda_consistency
            # image-only aux loss on the fusion heads, restricted to rows with an image
            img_only_loss = imgs.new_tensor(0.0)
            img_mask = has_img.view(-1)
            if cfg.image_only_loss_w > 0 and img_mask.any():
                img_only_hs3 = crit["hs3"](out["hs3_f_only"][img_mask], hs3_target[img_mask])
                img_only_ab2 = ab2_loss_fn(out["ab2_f_only"][img_mask], y_ab2[img_mask])
                img_only_loss = cfg.image_only_loss_w * (img_only_hs3 + img_only_ab2)

            loss = hs3_loss + ab2_loss + cons_loss + img_only_loss

        agg["loss"] += float(loss.item()); agg["n"] += 1

        m_hs3 = y_hs3 != -100
        if m_hs3.any():
            hs3_y.append(y_hs3[m_hs3].cpu().numpy())
            hs3_p.append(out["hs3_logits"][m_hs3].argmax(-1).cpu().numpy())

        m_ab2 = y_ab2 != -100
        if m_ab2.any():
            ab2_y.append(y_ab2[m_ab2].cpu().numpy())
            ab2_p.append(out["ab2_logits"][m_ab2].argmax(-1).cpu().numpy())

    def agg_metrics(Y, P):
        if len(Y) == 0: return None
        y = np.concatenate(Y); p = np.concatenate(P)
        return dict(
            acc=accuracy_score(y, p),
            f1_macro=f1_score(y, p, average="macro", zero_division=0),
            precision_macro=precision_score(y, p, average="macro", zero_division=0),
            recall_macro=recall_score(y, p, average="macro", zero_division=0),
            y=y, p=p
        )

    hs3_m = agg_metrics(hs3_y, hs3_p)
    ab2_m = agg_metrics(ab2_y, ab2_p)
    print(f"[{split}] loss={agg['loss']/max(1,agg['n']):.4f}")
    if hs3_m: print(f"[{split}] 3-Class: acc={hs3_m['acc']:.4f} f1={hs3_m['f1_macro']:.4f}")
    if ab2_m: print(f"[{split}] AB2: acc={ab2_m['acc']:.4f} f1={ab2_m['f1_macro']:.4f}")
    return {"loss": agg["loss"]/max(1,agg["n"]), "hs3": hs3_m, "ab2": ab2_m}

# ============== Hard-mining (text) + Memotion sampling helpers ==============
@torch.no_grad()
def score_difficulty_text(model, dataset: Dataset, indices: List[int], batch_size=256) -> Tuple[List[int], np.ndarray]:
    """Score each example in ``indices`` by the model's current loss (higher = harder).

    Difficulty = CE(hs3) + CE(ab2) per example; a term is 0 when that label is -100.
    Returns ``(indices, diffs)`` with ``diffs[j]`` belonging to ``indices[j]`` (the
    loader does not shuffle). The model's train/eval mode is restored on exit, so
    calling this mid-training does not leave dropout/LoRA-dropout disabled.
    """
    was_training = model.training
    model.eval()
    # torch.cuda.amp.autocast is CUDA-only; on CPU it would only warn and disable itself.
    amp_enabled = cfg.mixed_precision and device == "cuda"
    pool_ds = Subset(dataset, indices)
    loader = DataLoader(pool_ds, batch_size=batch_size, shuffle=False, num_workers=2,
                        pin_memory=(device == "cuda"), collate_fn=collate_fn)
    diffs = np.zeros(len(indices), dtype=np.float32)
    off = 0
    try:
        for batch in tqdm(loader, desc="Scoring hardness", leave=False):
            ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            imgs = batch["image"].to(device, non_blocking=True)
            has_img = batch["has_image"].to(device, non_blocking=True)
            meta = batch["meta"].to(device, non_blocking=True) if cfg.add_meta_dim > 0 else None
            y_hs3 = batch["hs3_label"].to(device, non_blocking=True)
            y_ab2 = batch["ab2_label"].to(device, non_blocking=True)
            with autocast(enabled=amp_enabled, dtype=torch.float16):
                out = model(ids, attn, imgs, has_img, meta)
                hs3_log = out["hs3_logits"]; ab2_log = out["ab2_logits"]
                hs3_ce = torch.zeros(ids.size(0), device=device)
                m_h = (y_hs3 != -100)
                if m_h.any():
                    hs3_ce[m_h] = F.cross_entropy(hs3_log[m_h], y_hs3[m_h], reduction="none").float()
                ab2_ce = torch.zeros(ids.size(0), device=device)
                m_a = (y_ab2 != -100)
                if m_a.any():
                    ab2_ce[m_a] = F.cross_entropy(ab2_log[m_a], y_ab2[m_a], reduction="none").float()
                diff = (hs3_ce + ab2_ce).detach().float().cpu().numpy()
            diffs[off:off + len(diff)] = diff
            off += len(diff)
    finally:
        model.train(was_training)
    return indices, diffs

def select_hard_examples(model, full_text_ds: Dataset, pool_size: int, select_n: int, hard_frac: float, epoch: int) -> List[int]:
    """Pick ``select_n`` training indices: the hardest share plus a uniform random fill.

    Steps:
    1. Draw a random pool of ``pool_size`` indices (seed ``cfg.seed + 1009*epoch``) and
       score it with ``score_difficulty_text``.
    2. Keep the ``int(select_n * hard_frac)`` highest-loss indices, capped at the pool size.
    3. Fill the rest uniformly without replacement from all other indices
       (seed ``cfg.seed + 2027*epoch``).

    Hard indices come first in the returned list. ``select_n`` is capped at the dataset
    size and ``hard_frac`` is clamped to [0, 1], so oversized requests no longer raise.
    """
    N = len(full_text_ds)
    select_n = max(0, min(int(select_n), N))
    if select_n == 0:
        return []
    pool_size = max(0, min(int(pool_size), N))
    hard_frac = min(1.0, max(0.0, float(hard_frac)))

    rng_pool = np.random.RandomState(cfg.seed + 1009 * epoch)
    pool_idx = rng_pool.choice(N, size=pool_size, replace=False).tolist()
    k_hard = min(int(select_n * hard_frac), len(pool_idx))
    hard_idx: List[int] = []
    if k_hard > 0:
        _, diffs = score_difficulty_text(model, full_text_ds, pool_idx, batch_size=256)
        hard_idx = [pool_idx[i] for i in np.argsort(-diffs)[:k_hard]]

    # Random fill from every index not already chosen. A boolean mask keeps this O(N);
    # the candidates stay in ascending order, so the seeded draw is reproducible.
    available = np.ones(N, dtype=bool)
    available[np.asarray(hard_idx, dtype=np.intp)] = False
    remaining_candidates = np.flatnonzero(available)
    remaining = select_n - len(hard_idx)
    rng_rest = np.random.RandomState(cfg.seed + 2027 * epoch)
    rand_rel = rng_rest.choice(len(remaining_candidates), size=remaining, replace=False)
    rand_idx = remaining_candidates[rand_rel].tolist()
    return hard_idx + rand_idx

def get_memotion_valid_indices(ds_memo: OptimizedProcessedCSVSet) -> List[int]:
    """Return the ``ds_memo.df`` index labels whose image file exists on disk.

    Only the base name of each ``image_path`` is used, looked up under
    <base_dir>/<processed_dir_name>/memotion_dataset_7k/images. Rows with an empty
    path, or a frame without an ``image_path`` column, are never included.
    """
    images_dir = Path(cfg.base_dir) / cfg.processed_dir_name / "memotion_dataset_7k" / "images"
    frame = ds_memo.df
    raw_paths = frame["image_path"] if "image_path" in frame.columns else [""] * len(frame)
    found = []
    for label, raw in zip(frame.index, raw_paths):
        name = Path(str(raw).strip()).name
        if name and (images_dir / name).exists():
            found.append(label)
    return found

def sample_memotion_indices(valid_idx: List[int], n: int, epoch: int, seed: Optional[int] = None) -> List[int]:
    """Draw up to ``n`` distinct entries of ``valid_idx`` for the given epoch.

    The draw is deterministic per (seed, epoch): the RNG is seeded with
    ``seed + 707 * epoch``, where ``seed`` defaults to ``cfg.seed``. ``n`` is clamped
    to ``[0, len(valid_idx)]``, so an empty pool or a non-positive ``n`` returns ``[]``
    instead of raising.
    """
    base_seed = cfg.seed if seed is None else seed
    n = max(0, min(int(n), len(valid_idx)))
    if n == 0:
        return []
    rng = np.random.RandomState(base_seed + 707 * epoch)
    return rng.choice(valid_idx, size=n, replace=False).tolist()

# ============== Train ==============
def train_novel():
    """Train the multimodal model and write evaluation reports.

    Pipeline:
      1. Verify the Memotion CSV and the text train/val/test split CSVs exist.
      2. Build datasets/DataLoaders. Val/test loaders are fixed; the training subset
         (hard-mined or sampled text + sampled Memotion rows whose images exist) is
         rebuilt every epoch.
      3. Build the model, AdamW (fused when supported), losses (class-weighted when
         ``cfg.use_class_weights``) and the LR scheduler.
      4. Train with AMP, gradient accumulation and clipping. Checkpoints are selected
         on the mean validation macro-F1 of the available heads (negative validation
         loss when no F1 is available), with early stopping.
      5. Reload the best checkpoint, evaluate on the test split, save JSON results and
         confusion matrices, and optionally copy reports to ``cfg.drive_save_path``.

    Returns:
        dict with ``best_val_score``, ``final_test_metrics``, ``training_history`` and
        ``config``. When ``cfg.final_full_eval`` is set it also contains
        ``full_val_metrics`` and ``full_test_metrics``.
    """
    clear_memory()
    processed_dir = Path(cfg.base_dir) / cfg.processed_dir_name
    splits_dir = processed_dir / cfg.splits_dirname
    reports_dir = processed_dir / cfg.reports_dirname
    reports_dir.mkdir(parents=True, exist_ok=True)

    memotion_path = processed_dir / cfg.memotion_csv
    assert memotion_path.exists(), f"Missing memotion CSV: {memotion_path}"
    train_text_csv = splits_dir / cfg.text_train_csv
    val_text_csv   = splits_dir / cfg.text_val_csv
    test_text_csv  = splits_dir / cfg.text_test_csv
    for p in [train_text_csv, val_text_csv, test_text_csv]:
        assert p.exists(), f"Missing: {p}"

    print("Paths verified")

    # Datasets. The *_full train sets are the sources for the per-epoch subsets below.
    ds_train_text_full = OptimizedProcessedCSVSet(train_text_csv, is_memotion=False, sample_n=None, pre_tokenize=False)
    ds_val_text   = OptimizedProcessedCSVSet(val_text_csv,   is_memotion=False, sample_n=cfg.sample_val_text_n, pre_tokenize=True)
    ds_test_text  = OptimizedProcessedCSVSet(test_text_csv,  is_memotion=False, sample_n=cfg.sample_test_text_n, pre_tokenize=True)

    ds_val_memo   = OptimizedProcessedCSVSet(memotion_path, is_memotion=True, split="val",  sample_n=cfg.sample_val_memotion_n, pre_tokenize=True)
    ds_test_memo  = OptimizedProcessedCSVSet(memotion_path, is_memotion=True, split="test", sample_n=cfg.sample_test_memotion_n, pre_tokenize=True)
    ds_train_memo_full = OptimizedProcessedCSVSet(memotion_path, is_memotion=True, split="train", sample_n=None, pre_tokenize=True)

    memotion_valid_idx = get_memotion_valid_indices(ds_train_memo_full)
    if len(memotion_valid_idx) == 0:
        print("No valid Memotion images found under memotion_dataset_7k/images. Training will use zero-image placeholders.")

    # Initial subsets: cfg.train_text_per_epoch text rows + cfg.images_per_epoch Memotion rows.
    text_init_size = min(cfg.train_text_per_epoch, len(ds_train_text_full))
    init_text_idx = np.random.RandomState(cfg.seed).choice(len(ds_train_text_full), size=text_init_size, replace=False)
    ds_train_text_epoch = Subset(ds_train_text_full, init_text_idx.tolist())

    memo_init_idx = sample_memotion_indices(memotion_valid_idx, cfg.images_per_epoch, epoch=0)
    ds_train_memo_epoch = Subset(ds_train_memo_full, memo_init_idx)

    train_ds = ConcatDataset([ds_train_text_epoch, ds_train_memo_epoch])
    val_ds   = ConcatDataset([ds_val_text, ds_val_memo])
    test_ds  = ConcatDataset([ds_test_text, ds_test_memo])

    # Dataloaders (num_workers >= 2, so persistent_workers/prefetch_factor are valid).
    cpu_ct = os.cpu_count() or 2
    num_workers = min(8, max(2, cpu_ct // 2))
    dl_kwargs = dict(num_workers=num_workers, collate_fn=collate_fn, pin_memory=(device=="cuda"), persistent_workers=True, prefetch_factor=4)
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, **dl_kwargs)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, **dl_kwargs)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False, **dl_kwargs)

    print(f"Datasets — Train:{len(train_ds)} (text {len(ds_train_text_epoch)} + memo {len(ds_train_memo_epoch)}) | Val:{len(val_ds)} | Test:{len(test_ds)}")

    # Model
    model = OptimizedMultiModalGated(
        text_model_name=cfg.text_model_name,
        freeze_text=cfg.freeze_text_encoder,
        use_clip=cfg.use_clip_vision,
        clip_model_name=cfg.clip_vision_model,
        fallback_backbone=cfg.image_backbone_fallback,
        freeze_vision=cfg.freeze_vision_encoder,
        use_multi_task=cfg.use_multi_task,
        ab2_mode=cfg.ab2_mode,
        hidden=cfg.hidden,
        dropout=cfg.dropout,
        add_meta_dim=cfg.add_meta_dim,
        gradient_checkpointing=cfg.gradient_checkpointing
    ).to(device)

    if device == "cuda":
        model = model.to(memory_format=torch.channels_last)

    if cfg.use_torch_compile and hasattr(torch, "compile"):
        try:
            model = torch.compile(model)
            print("torch.compile enabled")
        except Exception as e:
            print(f"torch.compile skipped: {e}")

    # Optimizer (fused AdamW when this torch version supports the kwarg).
    params_to_train = [p for p in model.parameters() if p.requires_grad]
    try:
        optimizer = torch.optim.AdamW(params_to_train, lr=cfg.lr, weight_decay=cfg.weight_decay, fused=True)
    except TypeError:
        optimizer = torch.optim.AdamW(params_to_train, lr=cfg.lr, weight_decay=cfg.weight_decay)

    def counts_from_concat(text_subset: Dataset, memo_subset: Dataset, label_pos, n_classes):
        """Per-class label counts over the current text + Memotion training subsets."""
        counts = np.zeros(n_classes, dtype=np.int64)
        for i in range(len(text_subset)):
            idx = text_subset.indices[i] if isinstance(text_subset, Subset) else i
            y = ds_train_text_full.labels[idx][label_pos]
            if y != -100: counts[y] += 1
        for i in range(len(memo_subset)):
            idx = memo_subset.indices[i] if isinstance(memo_subset, Subset) else i
            y_h, y_a = ds_train_memo_full.labels[idx]
            if label_pos == 0 and y_h != -100: counts[y_h] += 1
            if label_pos == 1 and y_a != -100: counts[y_a] += 1
        return counts

    def build_epoch_losses(text_subset: Dataset, memo_subset: Dataset):
        """Build the loss dict, class-weighted from the given subsets when enabled."""
        hs3_w = ab2_w = None
        if cfg.use_class_weights:
            hs3_w = compute_weights_from_counts(counts_from_concat(text_subset, memo_subset, label_pos=0, n_classes=3))
            ab2_w = compute_weights_from_counts(counts_from_concat(text_subset, memo_subset, label_pos=1, n_classes=2))
        return build_losses(cfg.ab2_mode, hs3_w, ab2_w, device=device, label_smoothing=cfg.label_smoothing)

    crit = build_epoch_losses(ds_train_text_epoch, ds_train_memo_epoch)
    scheduler, total_updates = build_scheduler(optimizer, len(train_loader), cfg)

    print("Training started")
    print(f"Steps/epoch: {len(train_loader)} | Total steps: {total_updates}")
    print_gpu_memory()

    # -inf (not -1.0): when no F1 is available the score is -val_loss, which is usually
    # below -1, and a -1.0 start would then never save a checkpoint.
    best_score = float("-inf")
    patience = 0
    history = []

    def unwrap_for_save(m):
        """Strip torch.compile (`_orig_mod`) and DataParallel (`module`) wrappers."""
        if hasattr(m, "_orig_mod"): m = m._orig_mod
        if hasattr(m, "module"): m = m.module
        return m

    for epoch in range(1, cfg.epochs+1):
        # Rebuild the training subsets each epoch.
        if cfg.hard_mining:
            print(f"\nHard-mining epoch {epoch}: pool {cfg.pool_size} -> select {cfg.train_text_per_epoch} (hard_frac={cfg.hard_frac})")
            with torch.no_grad():
                sel_text_idx = select_hard_examples(model, ds_train_text_full, cfg.pool_size, cfg.train_text_per_epoch, cfg.hard_frac, epoch)
            ds_train_text_epoch = Subset(ds_train_text_full, sel_text_idx)

        sel_memo_idx = sample_memotion_indices(memotion_valid_idx, cfg.images_per_epoch, epoch)
        ds_train_memo_epoch = Subset(ds_train_memo_full, sel_memo_idx)

        train_ds = ConcatDataset([ds_train_text_epoch, ds_train_memo_epoch])
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, **dl_kwargs)

        if cfg.use_class_weights:
            crit = build_epoch_losses(ds_train_text_epoch, ds_train_memo_epoch)

        print(f"Epoch {epoch} train sizes — text: {len(ds_train_text_epoch)}, memotion: {len(ds_train_memo_epoch)}")

        model.train()
        running_loss = 0.0
        total_batches = 0

        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}/{cfg.epochs}")
        optimizer.zero_grad(set_to_none=True)
        for it, batch in pbar:
            if batch is None: continue
            ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            imgs = batch["image"].to(device, non_blocking=True)
            has_img = batch["has_image"].to(device, non_blocking=True)
            meta = batch["meta"].to(device, non_blocking=True) if cfg.add_meta_dim > 0 else None
            y_ab2 = batch["ab2_label"].to(device, non_blocking=True)
            hs3_target = batch["hs3_target"].to(device, non_blocking=True)

            # Modality dropout for text: with this probability the whole batch's text is masked.
            if cfg.text_dropout_prob > 0 and random.random() < cfg.text_dropout_prob:
                ids = ids.clone(); attn = torch.zeros_like(attn)
                ids[:] = text_tokenizer.pad_token_id

            with autocast(enabled=cfg.mixed_precision, dtype=torch.float16):
                out = model(ids, attn, imgs, has_img, meta)
                hs3_loss = crit["hs3"](out["hs3_logits"], hs3_target)
                ab2_loss = crit["ab2"](out["ab2_logits"], y_ab2)
                cons_loss = hierarchical_consistency_loss(out["hs3_logits"], out["ab2_logits"]) * cfg.lambda_consistency

                # Image-only aux loss (cheap; reuses the fusion-only heads, no extra forward).
                img_only_loss = imgs.new_tensor(0.0)
                img_mask = has_img.view(-1)
                if cfg.image_only_loss_w > 0 and img_mask.any():
                    img_only_hs3 = crit["hs3"](out["hs3_f_only"][img_mask], hs3_target[img_mask])
                    img_only_ab2 = crit["ab2"](out["ab2_f_only"][img_mask], y_ab2[img_mask])
                    img_only_loss = cfg.image_only_loss_w * (img_only_hs3 + img_only_ab2)

                loss = (hs3_loss + ab2_loss + cons_loss + img_only_loss) / cfg.grad_accum_steps

            if scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step on every grad_accum_steps-th batch and on the final batch of the epoch.
            step_now = ((it + 1) % cfg.grad_accum_steps == 0) or (it + 1 == len(train_loader))
            if step_now:
                if cfg.grad_clip:
                    if scaler.is_enabled(): scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(params_to_train, cfg.grad_clip)
                if scaler.is_enabled():
                    scaler.step(optimizer); scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

            running_loss += float(loss.item()) * cfg.grad_accum_steps
            total_batches += 1
            pbar.set_postfix({'loss': f'{running_loss/total_batches:.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})

            del out, loss, hs3_loss, ab2_loss, cons_loss, img_only_loss

        avg_train_loss = running_loss / max(1, total_batches)
        val_metrics = evaluate(val_loader, model, crit, split="val")

        f1s = []
        if val_metrics["hs3"] is not None: f1s.append(val_metrics["hs3"]["f1_macro"])
        if val_metrics["ab2"] is not None: f1s.append(val_metrics["ab2"]["f1_macro"])
        current_score = float(np.mean(f1s)) if f1s else -val_metrics["loss"]

        history.append({'epoch': epoch, 'train_loss': avg_train_loss, 'val_score': current_score,
                        'val_hs3_f1': val_metrics["hs3"]["f1_macro"] if val_metrics["hs3"] else None,
                        'val_ab2_f1': val_metrics["ab2"]["f1_macro"] if val_metrics["ab2"] else None})

        print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Val Score={current_score:.4f}, Best={best_score:.4f}")
        print_gpu_memory()

        base_model = unwrap_for_save(model)
        ckpt = {
            "model": base_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "scaler": scaler.state_dict() if scaler.is_enabled() else None,
            "epoch": epoch,
            "best_score": max(best_score, current_score),
            "cfg": cfg.__dict__,
            "history": history
        }

        if current_score > best_score:
            best_score = current_score
            patience = 0
            torch.save(ckpt, processed_dir / cfg.checkpoint_name)
            if cfg.drive_save_path:
                Path(cfg.drive_save_path).mkdir(parents=True, exist_ok=True)
                torch.save(ckpt, Path(cfg.drive_save_path) / cfg.checkpoint_name)
            print("New best model saved!")
        else:
            patience += 1
            print(f"⏳ No improvement. Patience {patience}/{cfg.early_stop_patience}")
            if patience >= cfg.early_stop_patience:
                print("Early stopping.")
                break

        clear_memory()

    # Final evaluation
    print("\n" + "="*50)
    print("FINAL EVALUATION (tracking splits)")
    print("="*50)

    def _clean_state_dict_keys(sd):
        """Drop `_orig_mod.` / `module.` key prefixes left by compile/DataParallel."""
        out = {}
        for k, v in sd.items():
            if k.startswith("_orig_mod."): k = k[len("_orig_mod."):]
            if k.startswith("module."):    k = k[len("module."):]
            out[k] = v
        return out

    ckpt_path = processed_dir / cfg.checkpoint_name
    if ckpt_path.exists():
        checkpoint = torch.load(ckpt_path, map_location=device)
        sd = _clean_state_dict_keys(checkpoint["model"])
        # Keys are un-prefixed, so load into the unwrapped module. A torch.compile
        # wrapper expects "_orig_mod."-prefixed keys and would fail with strict=True.
        # The wrapper shares its parameters with the unwrapped module.
        unwrap_for_save(model).load_state_dict(sd, strict=True)
        print(f"Loaded best model from epoch {checkpoint['epoch']} | score {checkpoint['best_score']:.4f}")

    test_metrics = evaluate(test_loader, model, crit, split="test")
    results = {"best_val_score": best_score, "final_test_metrics": test_metrics, "training_history": history, "config": cfg.__dict__}

    results_path = reports_dir / "training_results_novel.json"
    with open(results_path, "w") as f:
        json.dump(to_serializable(results), f, indent=2)

    if cfg.final_full_eval:
        print("\n===== FULL SPLIT EVALUATION (val/test full) =====")
        ds_val_text_full  = OptimizedProcessedCSVSet(splits_dir / cfg.text_val_csv,  is_memotion=False, sample_n=None, pre_tokenize=True)
        ds_test_text_full = OptimizedProcessedCSVSet(splits_dir / cfg.text_test_csv, is_memotion=False, sample_n=None, pre_tokenize=True)
        val_ds_full  = ConcatDataset([ds_val_text_full,  ds_val_memo])
        test_ds_full = ConcatDataset([ds_test_text_full, ds_test_memo])

        # Same loader settings as the tracking loaders.
        val_loader_full = DataLoader(val_ds_full, batch_size=cfg.batch_size, shuffle=False, **dl_kwargs)
        test_loader_full = DataLoader(test_ds_full, batch_size=cfg.batch_size, shuffle=False, **dl_kwargs)
        val_full = evaluate(val_loader_full, model, crit, split="val_full200k")
        test_full = evaluate(test_loader_full, model, crit, split="test_full200k")

        # Keep the full-split metrics in the returned and saved results.
        results["full_val_metrics"] = val_full
        results["full_test_metrics"] = test_full
        with open(results_path, "w") as f:
            json.dump(to_serializable(results), f, indent=2)

    def save_confusion_matrix(y, p, labels, name):
        """Save a confusion matrix as PNG heatmap and CSV under reports_dir."""
        cm = confusion_matrix(y, p, labels=list(range(len(labels))))
        plt.figure(figsize=(6,4))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=labels, yticklabels=labels)
        plt.title(f"{name} Confusion Matrix"); plt.ylabel("True"); plt.xlabel("Pred")
        plt.tight_layout(); plt.savefig(reports_dir / f"{name}_cm.png", dpi=150); plt.close()
        pd.DataFrame(cm, index=labels, columns=labels).to_csv(reports_dir / f"{name}_cm.csv")

    if test_metrics["hs3"] is not None:
        y, p = test_metrics["hs3"]["y"], test_metrics["hs3"]["p"]
        save_confusion_matrix(y, p, ["Abusive","Offensive","Non-abusive"], "HS3_test")
    if test_metrics["ab2"] is not None:
        y, p = test_metrics["ab2"]["y"], test_metrics["ab2"]["p"]
        save_confusion_matrix(y, p, ["Non-abusive","Abusive"], "AB2_test")

    if cfg.drive_save_path:
        import shutil
        drive_results = Path(cfg.drive_save_path) / cfg.reports_dirname
        shutil.copytree(reports_dir, drive_results, dirs_exist_ok=True)
        print(f"Copied reports to {drive_results}")

    print("\nTraining complete!")
    print(f"Best Val Score: {best_score:.4f} | Reports: {reports_dir}")
    return results

# ============== Run ==============
# Entry point: runs training plus final evaluation (train_novel) and keeps the
# returned results dict in the module-level name `results`.
if __name__ == "__main__":
    print("Starting novel multimodal training (CLIP + consistency + hard-mining + LoRA, direct 3-class) …")
    results = train_novel()



"""
INTERPRETATION OF ARCHITECTURE DESIGN — MULTIMODAL HATE DETECTION WITH LoRA

The novel multimodal architecture combines text and vision encoders with
adaptive gating and LoRA fine-tuning to handle the complex separability
patterns in hateful meme detection.

Key architectural interpretations:

- **Dual Encoder Strategy**: Uses MiniLM for text and CLIP-ViT for images,
  recognizing that hate in memes often requires understanding both modalities
  simultaneously. Text captures explicit slurs while vision contextualizes
  cultural references and visual stereotypes.

- **Shared Projection Space**: Projects both text and image features into
  a common 256D space, enabling the model to learn cross-modal relationships
  where offensive meaning emerges from text-image combinations rather than
  either modality alone.

- **Adaptive Gating Mechanism**: The gate weight dynamically balances
  modality contributions per sample. For text-heavy memes, it relies more
  on textual analysis; for image-dominant offensive content, it shifts
  focus to visual features. This handles cases where offensive intent is
  only apparent when both modalities are considered together.

- **LoRA Fine-tuning**: Applies Low-Rank Adaptation to the text encoder
  (and optionally vision) to efficiently adapt pre-trained models without
  full fine-tuning. This preserves linguistic capabilities while specializing
  for hate detection patterns, particularly useful for detecting coded
  language and subtle offensive constructs.

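- **Auxiliary Image-Only Loss**: For samples that have an image, an extra
  weighted loss is applied to the fusion-only heads, pushing the image
  pathway to recognize offensive content on its own. This matters for memes
  where the image alone conveys hateful meaning while the text appears
  neutral. (If the vision encoder is frozen, only the layers trained on top
  of it receive this signal.)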

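- **Text Modality Dropout**: During training, with probability
  `text_dropout_prob` per batch, the text of the whole batch is masked
  (token ids replaced by padding, attention mask zeroed). This prevents
  over-reliance on textual cues and forces the model to use the image,
  which is especially important for culturally contextual memes where
  visual elements carry the offensive weight.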

- **Hierarchical Consistency**: A consistency term (weighted by
  `lambda_consistency`) keeps the 3-class (Abusive/Offensive/Non-abusive)
  and binary (Abusive vs. Non-abusive) predictions logically coherent.
  It reflects the label hierarchy: the binary abusive class covers both the
  Abusive and the Offensive 3-class labels.

Overall interpretation:
This architecture acknowledges that hateful meme detection operates in
a continuum from clearly separable explicit content to highly ambiguous
borderline cases. The multimodal gating, projection strategy, and auxiliary
objectives create a robust system that can handle both obvious hate speech
and subtle, context-dependent offensive content that requires joint
text-image understanding.
"""

In [None]:
# ds_multimodal_novel_lora_full_pooled.py
# End-to-end: Train (LoRA + hard-mining + image-only aux loss + text dropout) + Direct 3-Class Evaluation
# Keys:
# - Fused 3-class labels at dataset load time (0=Abusive, 1=Offensive, 2=Non-abusive)
# - Mean pooling text encoder (no pooler warning)
# - Projection layers for text/image; gate uses both modalities (not text-only)
# - Cheap image-only auxiliary loss (no extra forward) + optional text-modality dropout
# - Direct 3-class eval; PR/ECE; gate analysis; CSV export

# Optional installs:
# !pip install -U torch torchvision transformers scikit-learn pandas numpy pillow tqdm seaborn matplotlib peft

import os, gc, json, math, random, warnings, sys, subprocess
from dataclasses import dataclass
from typing import Optional, List, Tuple
from pathlib import Path

# Quiet UserWarnings and Hugging Face advisory messages, and turn off the HF
# tokenizers library's internal parallelism.
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"  # quiet HF warnings

# ===========================
# USER TOGGLES
# ===========================
# Run-mode switches; they are read by code later in this cell (not shown here).
RUN_TRAIN = False          # True to train; False to only evaluate using saved checkpoint
RUN_FUSED_EVAL = True      # Run final direct 3-class eval + insights
TEACHER_N = 10000           # None = FULL test; set int (e.g., 5000) for a quick run
EXPORT_SAMPLE = None       # None = full export; or int for sample CSV
BATCH_EVAL = 128           # Batch size for eval/export

# Paths
# NOTE(review): these duplicate CFG.base_dir / CFG.processed_dir_name /
# CFG.checkpoint_name ("datasets", "processed_data", "best_novel.pt"); keep in sync.
BASE_DIR = "datasets/processed_data"
SPLITS_DIR = f"{BASE_DIR}/splits"
CKPT_PATH = f"{BASE_DIR}/best_novel.pt"   # Training saves here by default
OUT_DIR   = f"{BASE_DIR}/novel_eval_insights"
# Side effect at import: the output directory is created immediately.
os.makedirs(OUT_DIR, exist_ok=True)

# ===========================
# Config
# ===========================
@dataclass
class CFG:
    """Central configuration for data paths, model, training, LoRA and reporting.

    A single module-level instance, ``cfg``, is created right after the class and
    read by the rest of the script (datasets, tokenizers, AMP scaler, losses, scheduler).
    """
    # Paths
    base_dir: str = "datasets"
    processed_dir_name: str = "processed_data"
    memotion_csv: str = "memotion_7k_multimodal.csv"
    splits_dirname: str = "splits"

    # Text CSVs
    text_train_csv: str = "train/text_train.csv"
    text_val_csv: str = "val/text_val.csv"
    text_test_csv: str = "test/text_test.csv"

    # Random subsampling for val/test (tracking during training)
    # `seed` also seeds set_seed() and DataFrame.sample(); a sample_* value of None keeps the whole split.
    seed: int = 42
    sample_val_text_n: Optional[int] = 7_000
    sample_test_text_n: Optional[int] = 5_000
    sample_val_memotion_n: Optional[int] = None
    sample_test_memotion_n: Optional[int] = None

    # Model (text + vision)
    # text_model_name / clip_vision_model are loaded with from_pretrained at import time.
    text_model_name: str = "nreimers/MiniLMv2-L6-H384-distilled-from-RoBERTa-Large"
    use_clip_vision: bool = True
    clip_vision_model: str = "openai/clip-vit-base-patch32"
    image_backbone_fallback: str = "mobilenet_v2"
    freeze_text_encoder: bool = True
    freeze_vision_encoder: bool = True
    gradient_checkpointing: bool = True
    use_torch_compile: bool = False

    # Tokenization / image
    # image_size sets the square resize used by the torchvision fallback transform.
    max_len: int = 96
    image_size: int = 224
    # Default for OptimizedProcessedCSVSet when its pre_tokenize argument is None.
    pre_tokenize_text: bool = False

    # Training
    # mixed_precision enables the GradScaler (and the dataset's use_half flag on CUDA).
    epochs: int = 30
    batch_size: int = 32
    grad_accum_steps: int = 1
    lr: float = 1e-4
    weight_decay: float = 1e-2
    warmup_ratio: float = 0.06
    label_smoothing: float = 0.05
    grad_clip: float = 1.0
    early_stop_patience: int = 4
    mixed_precision: bool = True

    # Architecture
    hidden: int = 256
    dropout: float = 0.25
    ab2_mode: str = "ce"
    add_meta_dim: int = 2
    use_multi_task: bool = True
    proj_dim: int = 256   # projection dimension for text/image before fusion

    # Loss tweaks
    lambda_consistency: float = 0.2
    image_only_loss_w: float = 0.5   # weight for image-only aux loss (0.0 to disable)
    text_dropout_prob: float = 0.25  # probability to drop text (forces image reliance)

    # Hard-mining curriculum (TEXT)
    hard_mining: bool = True
    pool_size: int = 5_000
    train_text_per_epoch: int = 5_000
    hard_frac: float = 0.6

    # Memotion images per epoch
    images_per_epoch: int = 200

    # Balance
    use_class_weights: bool = True

    # LoRA config
    # NOTE(review): the LoRA fields are consumed by model code outside this view; the
    # "default" target lists below are the author's notes — confirm in that code.
    use_lora_text: bool = True
    use_lora_vision: bool = False      # enable if you want CLIP LoRA too
    lora_r_text: int = 8
    lora_alpha_text: int = 16
    lora_dropout_text: float = 0.1
    lora_target_text: Optional[List[str]] = None  # ["query","value"] default

    lora_r_vision: int = 4
    lora_alpha_vision: int = 8
    lora_dropout_vision: float = 0.05
    lora_target_vision: Optional[List[str]] = None  # ["q_proj","v_proj"] default

    # Text pooling
    text_pooling: str = "mean"   # "mean" or "cls"

    # Reporting
    reports_dirname: str = "reports_novel"
    checkpoint_name: str = "best_novel.pt"
    drive_save_path: Optional[str] = None
    final_full_eval: bool = False

# Global configuration instance shared by the whole script.
cfg = CFG()

# ===========================
# Device / AMP / Repro
# ===========================
import torch
import torch.nn as nn
import torch.nn.functional as F

# Release cached CUDA memory and run the garbage collector before setting up.
torch.cuda.empty_cache(); gc.collect()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
    print("VRAM:", torch.cuda.get_device_properties(0).total_memory/1024**3, "GB")
    # Favor speed over bit-exact reproducibility: allow TF32 matmul/conv kernels
    # and let cuDNN benchmark and pick algorithms (non-deterministic).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# Cap intra-op CPU threads at 2 (hard-coded).
torch.set_num_threads(2)
# NOTE(review): newer PyTorch releases deprecate torch.cuda.amp in favour of
# torch.amp; verify against the installed version.
from torch.cuda.amp import GradScaler, autocast
# Single process-wide gradient scaler, enabled according to cfg.mixed_precision.
scaler = GradScaler(enabled=cfg.mixed_precision)

def set_seed(seed: int):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU and all CUDA devices)."""
    import random as _py_random
    import numpy as _np

    _py_random.seed(seed)
    _np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# Seed every RNG once at import time with the configured seed.
set_seed(cfg.seed)

# ===========================
# Imports
# ===========================
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModel, AutoConfig, CLIPVisionModel, CLIPImageProcessor
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score, accuracy_score
from sklearn.metrics import precision_recall_curve, average_precision_score
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

# PEFT (LoRA)
# Import PEFT (used for LoRA). If it is missing, install/upgrade it into the
# running interpreter with pip and retry the import.
# NOTE(review): this needs network access and modifies the Python environment
# as a side effect of running the cell.
try:
    from peft import LoraConfig, get_peft_model, TaskType
except ImportError:
    print("Installing 'peft' for LoRA...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "peft"])
    from peft import LoraConfig, get_peft_model, TaskType

# ===========================
# JSON-safe
# ===========================
def to_serializable(o):
    """Recursively convert ``o`` into JSON-friendly built-in Python types.

    - dict: values converted recursively; NumPy scalar keys become Python scalars
      (``json`` rejects e.g. ``np.int64`` keys).
    - list/tuple/set: converted to lists.
    - ``np.ndarray`` and ``torch.Tensor``: converted via ``tolist()``.
    - Any NumPy scalar (``np.integer``, ``np.floating``, ``np.bool_``, ...): ``.item()``.
    - ``Path``: ``str``.
    Anything else is returned unchanged.
    """
    import numpy as np, torch
    if isinstance(o, dict):
        return {(k.item() if isinstance(k, np.generic) else k): to_serializable(v)
                for k, v in o.items()}
    if isinstance(o, (list, tuple, set)):
        return [to_serializable(v) for v in o]
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, np.generic):
        # Covers np.integer/np.floating (as before) and np.bool_, which json cannot encode.
        return o.item()
    if isinstance(o, torch.Tensor):
        return o.detach().cpu().tolist()
    if isinstance(o, Path):
        return str(o)
    return o

# ===========================
# Label helpers
# ===========================
def map_text_labels(row: pd.Series):
    """Extract ``(hs3, ab2)`` labels from a text-split CSV row.

    hs3 is ``original_class`` when it parses to 0, 1 or 2. ab2 is 1 when ``label``
    parses to 1 and 0 for any other number. A missing, blank or unparsable cell
    yields -100 (the ignore index) for that label.
    """
    def _as_int(key):
        # None for missing/blank/unparsable cells. Only parse failures are caught:
        # a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        if key not in row:
            return None
        value = row[key]
        if str(value).strip() == "":
            return None
        try:
            return int(float(value))
        except (TypeError, ValueError, OverflowError):  # e.g. "abc", "nan", "inf"
            return None

    oc = _as_int("original_class")
    hs3 = oc if oc in (0, 1, 2) else -100
    v = _as_int("label")
    ab2 = -100 if v is None else (1 if v == 1 else 0)
    return hs3, ab2

def map_memotion_labels(off_cat: str):
    """Map a Memotion ``offensive_category`` value to ``(hs3, ab2)`` labels.

    hs3: 0=Abusive (hateful_offensive), 1=Offensive (offensive, very_offensive),
         2=Non-abusive (slight, not_offensive).
    ab2: 1 for the Abusive/Offensive categories, 0 for the Non-abusive ones.

    Matching is case-insensitive and ignores surrounding whitespace. Unknown or
    empty categories return ``(-100, -100)`` so both heads ignore the row. Giving
    them ab2=0 would make the dataset fuse hs3=-100/ab2=0 into the Non-abusive class.
    """
    table = {
        "hateful_offensive": (0, 1),
        "offensive": (1, 1),
        "very_offensive": (1, 1),
        "slight": (2, 0),
        "not_offensive": (2, 0),
    }
    return table.get(str(off_cat).strip().lower(), (-100, -100))

def parse_sarcasm(val) -> int:
    """Return 1 if a sarcasm annotation denotes sarcasm, else 0.

    Negative values (empty, "none", "nan", "no", "false", "0", or anything starting
    with "not"/"non"/"no_"/"no ") map to 0. This check must come first: a plain
    substring test for "sarcas" would treat "not_sarcastic" as sarcastic.
    Positive values are the explicit yes-forms, anything containing "sarcas", and
    Memotion's graded sarcasm levels ("general", "twisted_meaning", "very_twisted").

    NOTE(review): if this feeds model meta-features, checkpoints trained with the
    previous (substring-only) mapping should be retrained.
    """
    s = str(val).strip().lower()
    if s in {"", "none", "nan", "no", "false", "0"} or s.startswith(("not", "non", "no_", "no ")):
        return 0
    positives = {"sarcasm", "sarcastic", "yes", "true", "1",
                 "general", "twisted_meaning", "very_twisted"}
    return 1 if s in positives or "sarcas" in s else 0

def parse_humour(val) -> int:
    """Binary humour flag: 0 for explicit "no humour" values, else 1 if a humour keyword appears."""
    text = str(val).strip().lower()
    negative_values = {"", "none", "not_funny", "not funny", "no_humour", "no_humor"}
    if text in negative_values:
        return 0
    keywords = ("funny", "hilar", "humor", "humour", "very_funny")
    return int(any(word in text for word in keywords))

def one_hot_smooth(y: int, num_classes: int = 3, eps: float = 0.0) -> torch.Tensor:
    """Build a label-smoothed one-hot float32 vector of length ``num_classes``.

    The true class gets ``1 - eps`` and the others share ``eps`` equally. A negative
    ``y`` (e.g. the -100 ignore index) returns an all-zero vector, which
    SoftCrossEntropyLoss treats as "ignore". With a single class no smoothing is
    possible, so the vector is ``[1.0]``; the original code divided by zero here.
    ``y >= num_classes`` raises IndexError.
    """
    if num_classes > 1:
        off_value, on_value = eps / (num_classes - 1), 1.0 - eps
    else:
        off_value, on_value = 0.0, 1.0
    vec = torch.full((num_classes,), off_value, dtype=torch.float32)
    if y >= 0: vec[y] = on_value
    else: vec[:] = 0.0
    return vec

class SoftCrossEntropyLoss(nn.Module):
    """Cross-entropy accepting either hard class indices or soft (probability) targets.

    * Hard targets (1-D): ``F.cross_entropy`` with ``ignore_index`` (and ``weight``).
    * Soft targets (2-D, one row per sample): rows summing to 0 are ignored. The loss
      is ``-(t * log_softmax(x)).sum(-1)`` averaged over the valid rows.

    Args:
        ignore_index: class index ignored for hard targets.
        weight: optional per-class weights of shape ``[C]`` (registered as a buffer).
            Hard targets use PyTorch's weighted mean. For soft targets the total
            ``-(w * t * log p)`` is divided by ``(w * t).sum()``, which equals the
            hard-target weighted mean for one-hot rows.

    If no sample in the batch is valid, both paths return a zero tensor. Without this,
    ``F.cross_entropy`` would average over zero elements and return NaN.
    """
    def __init__(self, ignore_index=-100, weight=None):
        super().__init__()
        self.ignore_index = ignore_index
        self.register_buffer("weight", weight)

    def forward(self, input, target):
        if target.dim() == 2:
            target = target.to(input.dtype)
            valid = target.sum(dim=-1) > 0
            if not valid.any():
                return input.new_tensor(0.0)
            log_probs = F.log_softmax(input[valid], dim=-1)
            t = target[valid]
            if self.weight is None:
                return -(t * log_probs).sum(dim=-1).mean()
            w = self.weight.to(device=log_probs.device, dtype=log_probs.dtype)
            wt = t * w
            return -(wt * log_probs).sum() / wt.sum().clamp_min(1e-12)
        if not (target != self.ignore_index).any():
            return input.new_tensor(0.0)
        return F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index)

def compute_weights_from_counts(counts: np.ndarray):
    """Inverse-frequency class weights, normalized to mean ~1 over the classes present.

    Args:
        counts: per-class sample counts (array-like of non-negative numbers).

    Returns:
        float32 tensor of shape ``[len(counts)]``, or None if every count is zero.
        Classes with zero count get weight 1.0. Clipping their frequency to 1e-8, as
        the original code did, produced a ~1e8 weight that, after mean-normalization,
        pushed every present class to ~1e-8 and effectively zeroed the loss. A class
        absent from the counted subset has no targets there, so its weight does not
        affect that subset's loss.
    """
    counts = np.asarray(counts, dtype=np.float64)
    total = counts.sum()
    if total == 0: return None
    w = np.ones_like(counts)
    present = counts > 0
    inv_freq = 1.0 / (counts[present] / total)
    w[present] = inv_freq / (inv_freq.mean() + 1e-8)
    return torch.tensor(w, dtype=torch.float32)

def build_losses(ab2_mode="ce", hs3_class_weights=None, ab2_class_weights=None, device="cpu", label_smoothing=0.0):
    """Build the loss dict ``{"hs3": ..., "ab2": ...}``.

    Args:
        ab2_mode: "ce" for CrossEntropy on 2-class logits; any other value uses
            BCEWithLogits with ``pos_weight = ab2_class_weights[1]``.
        hs3_class_weights / ab2_class_weights: optional per-class weight tensors
            (moved to ``device``); non-tensors are treated as "no weights".
        device: device the weights are moved to.
        label_smoothing: applied to the ab2 CrossEntropy when this torch version
            supports it. The hs3 loss takes soft targets and is not smoothed here.
    """
    def _ce(weight):
        try: return nn.CrossEntropyLoss(ignore_index=-100, weight=weight, label_smoothing=label_smoothing)
        except TypeError: return nn.CrossEntropyLoss(ignore_index=-100, weight=weight)

    def _soft_ce(weight):
        # hs3 weights are moved to the device and must actually be used, not dropped.
        # Fall back to the unweighted loss if SoftCrossEntropyLoss has no `weight`
        # argument (same TypeError pattern as _ce).
        if weight is None:
            return SoftCrossEntropyLoss(ignore_index=-100)
        try: return SoftCrossEntropyLoss(ignore_index=-100, weight=weight)
        except TypeError: return SoftCrossEntropyLoss(ignore_index=-100)

    hs3_w = hs3_class_weights.to(device) if isinstance(hs3_class_weights, torch.Tensor) else None
    ab2_w = ab2_class_weights.to(device) if isinstance(ab2_class_weights, torch.Tensor) else None
    crit = {"hs3": _soft_ce(hs3_w)}
    crit["ab2"] = _ce(ab2_w) if ab2_mode == "ce" else nn.BCEWithLogitsLoss(pos_weight=ab2_w[1:2] if ab2_w is not None else None)
    return crit

def build_scheduler(optimizer, train_loader_len, cfg):
    """Linear warmup then cosine decay to 0, stepped once per optimizer update.

    Returns:
        (LambdaLR scheduler, total number of optimizer updates over cfg.epochs).
    """
    steps_per_epoch = max(math.ceil(train_loader_len / cfg.grad_accum_steps), 1)
    total_steps = cfg.epochs * steps_per_epoch
    warmup_steps = max(int(cfg.warmup_ratio * total_steps), 1)
    decay_span = float(max(total_steps - warmup_steps, 1))

    def schedule(step):
        if step >= warmup_steps:
            frac = (step - warmup_steps) / decay_span
            frac = max(0.0, min(frac, 1.0))
            return 0.5 * (1.0 + math.cos(math.pi * frac))
        return float(step) / float(warmup_steps)

    lr_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
    return lr_sched, total_steps

# ===========================
# Tokenizers / processors
# ===========================
# Module-level preprocessing objects, created once at import time.
# Fast tokenizer for the configured text encoder.
text_tokenizer = AutoTokenizer.from_pretrained(cfg.text_model_name, use_fast=True)
# CLIP image processor only when CLIP vision is enabled; otherwise None.
clip_processor = CLIPImageProcessor.from_pretrained(cfg.clip_vision_model) if cfg.use_clip_vision else None
# Torchvision transform: square resize to cfg.image_size, tensor conversion and
# normalization with the standard ImageNet mean/std. Presumably used for the
# non-CLIP (fallback backbone) path; confirm where it is consumed.
img_tfm_fallback = transforms.Compose([
    transforms.Resize((cfg.image_size, cfg.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# ===========================
# Dataset (fused 3-class labels)
# ===========================
class OptimizedProcessedCSVSet(Dataset):
    """CSV-backed dataset yielding tokens, an optional meme image, meta features and fused labels.

    Text CSVs are read with the columns ``text``, ``original_class`` and ``label``
    (whichever exist). The Memotion CSV is read with ``text``, ``image_path``,
    ``offensive_category``, ``sarcasm`` and ``humour``.

    Each row gets a label pair ``(hs3, ab2)``:
      * hs3 is the fused 3-class label (0=Abusive, 1=Offensive, 2=Non-abusive).
      * ab2 is the binary label.
      * -100 means "no label".
      * When hs3 is missing but ab2 is present, hs3 is derived from it
        (ab2 == 0 -> 2, otherwise -> 1).

    Args:
        csv_path: processed CSV to load.
        is_memotion: treat the CSV as the Memotion meme dataset (images + meta).
        split: Memotion only. One of "train"/"val"/"test" selects a fixed-seed
            (123) 80/10/10 split of the shuffled rows. Other values keep all rows.
        sample_n: if set and smaller than the dataset, subsample that many rows
            (seeded with ``cfg.seed``).
        pre_tokenize: tokenize every text up front. Defaults to ``cfg.pre_tokenize_text``.
        min_text_len: drop rows whose raw text has fewer characters than this.
            Default 3; 0 disables the filter.
    """

    def __init__(self, csv_path: Path, is_memotion=False, split: Optional[str]=None,
                 sample_n: Optional[int]=None, pre_tokenize: Optional[bool]=None,
                 min_text_len: int = 3):
        self.is_memotion = is_memotion
        self.split = split
        self.pre_tokenize = cfg.pre_tokenize_text if pre_tokenize is None else pre_tokenize
        # Image directory does not change during the dataset's lifetime; resolve it once.
        self.img_root = Path(cfg.base_dir) / cfg.processed_dir_name / "memotion_dataset_7k" / "images"

        # Read only the header first so we request just the columns that exist.
        hdr = pd.read_csv(csv_path, nrows=0).columns
        if is_memotion:
            want = ["text", "image_path", "offensive_category", "sarcasm", "humour"]
        else:
            want = ["text", "original_class", "label"]
        usecols = [c for c in want if c in hdr]
        df = pd.read_csv(csv_path, usecols=usecols or None).fillna("")

        # Memotion has no split files: use one fixed-seed permutation so the
        # train/val/test instances built from the same CSV are disjoint.
        if is_memotion and split in {"train", "val", "test"}:
            rng = np.random.RandomState(123)
            idx = np.arange(len(df)); rng.shuffle(idx)
            n = len(idx); n_train = int(0.8 * n); n_val = int(0.1 * n)
            part = {"train": idx[:n_train], "val": idx[n_train:n_train + n_val], "test": idx[n_train + n_val:]}[split]
            df = df.iloc[part].reset_index(drop=True)

        # Drop (near-)empty texts; length is measured on the raw, unstripped string.
        if "text" in df.columns and min_text_len > 0:
            df = df[df["text"].astype(str).str.len() >= min_text_len]

        if sample_n is not None and len(df) > sample_n:
            df = df.sample(n=int(sample_n), random_state=cfg.seed, replace=False).reset_index(drop=True)

        self.df = df.reset_index(drop=True)

        # Fused labels (0=Abusive, 1=Offensive, 2=Non-abusive); -100 = unlabeled.
        self.labels = []
        for _, row in self.df.iterrows():
            if is_memotion:
                hs3, ab2 = map_memotion_labels(row.get("offensive_category", ""))
            else:
                hs3, ab2 = map_text_labels(row)
            fused_hs3 = hs3
            if fused_hs3 == -100 and ab2 != -100:
                # Derive hs3 from the binary label: non-abusive -> 2, abusive -> 1 (Offensive).
                fused_hs3 = 2 if ab2 == 0 else 1
            self.labels.append((fused_hs3, ab2))

        self.use_half = torch.cuda.is_available() and cfg.mixed_precision

        self.enc_input_ids, self.enc_attention_mask = None, None
        if self.pre_tokenize and len(self.df) > 0 and "text" in self.df.columns:
            # Strip exactly like __getitem__ so pre-tokenized and on-the-fly
            # tokenization yield the same ids for the same row.
            texts = self.df["text"].astype(str).str.strip().tolist()
            enc = text_tokenizer(texts, truncation=True, max_length=cfg.max_len, padding=False)
            self.enc_input_ids, self.enc_attention_mask = enc["input_ids"], enc["attention_mask"]

    def __len__(self): return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        text = str(row.get("text", "")).strip()

        if self.enc_input_ids is not None:
            input_ids = self.enc_input_ids[i]; attention_mask = self.enc_attention_mask[i]
        else:
            enc = text_tokenizer(text, truncation=True, max_length=cfg.max_len, padding=False)
            input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]

        img = None; has_image = False
        if self.is_memotion:
            fname = Path(str(row.get("image_path", "")).strip()).name
            if fname:
                p = self.img_root / fname
                if p.exists():
                    try:
                        # Context manager releases the file handle once pixels are converted.
                        with Image.open(p) as raw:
                            pil = raw.convert("RGB")
                        if cfg.use_clip_vision:
                            img = clip_processor(images=pil, return_tensors="pt")["pixel_values"][0]
                        else:
                            img = img_tfm_fallback(pil)
                        has_image = True
                    except Exception:
                        # Unreadable/corrupt image: fall back to a text-only sample.
                        img, has_image = None, False

        if img is None:
            # Placeholder so every sample has an image tensor; has_image marks it invalid.
            img = torch.zeros(3, cfg.image_size, cfg.image_size, dtype=torch.float32)
        if self.use_half: img = img.half()

        sarcasm = parse_sarcasm(row.get("sarcasm", "")) if self.is_memotion and "sarcasm" in row else 0
        humour  = parse_humour(row.get("humour", ""))   if self.is_memotion and "humour" in row else 0
        meta = torch.tensor([sarcasm, humour], dtype=torch.float16 if self.use_half else torch.float32)

        hs3, ab2 = self.labels[i]
        hs3_target = one_hot_smooth(int(hs3), num_classes=3, eps=cfg.label_smoothing)
        if self.use_half: hs3_target = hs3_target.half()

        return {
            "text": text,
            "input_ids": input_ids, "attention_mask": attention_mask,
            "image": img, "has_image": torch.tensor(has_image, dtype=torch.bool),
            "meta": meta,
            "hs3_label": torch.tensor(hs3, dtype=torch.long),
            "hs3_target": hs3_target,
            "ab2_label": torch.tensor(ab2, dtype=torch.long),
        }

# ===========================
# Model with LoRA, mean pooling, projections, multimodal gate
# ===========================
class OptimizedMultiModalGated(nn.Module):
    """Gated text/image classifier with a 3-class (hs3) and a binary (ab2) head.

    Encoders:
      * Text: a Hugging Face ``AutoModel``, optionally LoRA-adapted. It is pooled
        with a masked mean, or the first token when ``text_pooling`` != "mean".
      * Images: a CLIP vision tower (``pooler_output``), or a torchvision
        MobileNetV2 / ResNet18 fallback.

    Both features are projected to ``proj_dim``. A text branch (text [+ meta])
    and a fusion branch (text + image [+ meta]) each produce logits. A learned
    sigmoid gate mixes the two per sample. The gate is multiplied by the
    has-image mask, so samples without an image use only the text branch.

    Args:
        text_model_name: Hugging Face model id for the text encoder.
        freeze_text: set ``requires_grad=False`` on all text-encoder parameters.
        use_clip: use CLIP vision (else the torchvision fallback).
        clip_model_name: CLIP checkpoint. None -> "openai/clip-vit-base-patch32".
        fallback_backbone: "mobilenet_v2", or anything else for ResNet18.
        freeze_vision: freeze the vision encoder parameters.
        use_multi_task, ab2_mode: stored on the module (not used in forward here).
        hidden: width of the first backbone layer (heads use ``hidden // 2``).
        dropout: dropout used in projections and backbones.
        add_meta_dim: number of extra meta features concatenated to the inputs.
        gradient_checkpointing: try to enable checkpointing on the text encoder.
        cfg_loaded: config object for LoRA/pooling/proj_dim options. Defaults to
            the global ``cfg``.
    """

    def __init__(self, text_model_name, freeze_text=True, use_clip=True, clip_model_name=None,
                 fallback_backbone="mobilenet_v2", freeze_vision=True,
                 use_multi_task=True, ab2_mode="ce", hidden=256, dropout=0.25, add_meta_dim=2,
                 gradient_checkpointing=True, cfg_loaded=None):
        super().__init__()
        self.use_multi_task = use_multi_task
        self.ab2_mode = ab2_mode
        self.add_meta_dim = add_meta_dim
        self.use_clip = use_clip
        self.cfg_loaded = cfg_loaded if cfg_loaded is not None else cfg

        # Text encoder (remove pooler; enable mean pooling)
        txt_cfg = AutoConfig.from_pretrained(text_model_name)
        if hasattr(txt_cfg, "add_pooling_layer"):
            txt_cfg.add_pooling_layer = False
        self.text_model = AutoModel.from_pretrained(text_model_name, config=txt_cfg)
        ckpt_enabled = False
        if gradient_checkpointing and hasattr(self.text_model, "gradient_checkpointing_enable"):
            try:
                self.text_model.gradient_checkpointing_enable()
                ckpt_enabled = True
            except Exception as exc:
                # Best effort: some backbones do not support checkpointing.
                print(f"Gradient checkpointing not enabled: {exc}")
        tdim = self.text_model.config.hidden_size
        if freeze_text:
            for p in self.text_model.parameters(): p.requires_grad = False

        # LoRA on text
        if getattr(self.cfg_loaded, "use_lora_text", False):
            targets = getattr(self.cfg_loaded, "lora_target_text", None) or ["query", "value"]
            lora_text_cfg = LoraConfig(
                r=getattr(self.cfg_loaded, "lora_r_text", 8),
                lora_alpha=getattr(self.cfg_loaded, "lora_alpha_text", 16),
                lora_dropout=getattr(self.cfg_loaded, "lora_dropout_text", 0.1),
                target_modules=targets,
                bias="none",
                task_type=TaskType.FEATURE_EXTRACTION
            )
            self.text_model = get_peft_model(self.text_model, lora_text_cfg)
            # The LoRA base is frozen, so the embedding output does not require grad.
            # With reentrant gradient checkpointing, the checkpointed layers then
            # produce no gradients for the adapters inside them. Make the embedding
            # output require grad so the adapters actually train.
            if ckpt_enabled and hasattr(self.text_model, "enable_input_require_grads"):
                self.text_model.enable_input_require_grads()

        # Vision encoder
        if use_clip:
            self.vision = CLIPVisionModel.from_pretrained(clip_model_name or "openai/clip-vit-base-patch32")
            idim = self.vision.config.hidden_size
            if freeze_vision:
                for p in self.vision.parameters(): p.requires_grad = False
            if getattr(self.cfg_loaded, "use_lora_vision", False):
                vtargets = getattr(self.cfg_loaded, "lora_target_vision", None) or ["q_proj", "v_proj"]
                lora_vis_cfg = LoraConfig(
                    r=getattr(self.cfg_loaded, "lora_r_vision", 4),
                    lora_alpha=getattr(self.cfg_loaded, "lora_alpha_vision", 8),
                    lora_dropout=getattr(self.cfg_loaded, "lora_dropout_vision", 0.05),
                    target_modules=vtargets,
                    bias="none",
                    task_type=TaskType.FEATURE_EXTRACTION
                )
                self.vision = get_peft_model(self.vision, lora_vis_cfg)
            self.use_clip = True
        else:
            if fallback_backbone.lower() == "mobilenet_v2":
                try:
                    from torchvision.models import MobileNet_V2_Weights
                    im = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
                except Exception:
                    im = models.mobilenet_v2(pretrained=True)
                idim = 1280; im.classifier = nn.Identity()
            else:
                try:
                    from torchvision.models import ResNet18_Weights
                    backbone = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
                except Exception:
                    backbone = models.resnet18(pretrained=True)
                idim = 512; im = nn.Sequential(*list(backbone.children())[:-1], nn.Flatten(1))
            if freeze_vision:
                for p in im.parameters(): p.requires_grad = False
            self.vision = im
            self.use_clip = False

        self.image_feat_dim = idim
        proj_dim = getattr(self.cfg_loaded, "proj_dim", 256)

        # Projection layers to align scales
        self.t_proj = nn.Sequential(
            nn.Linear(tdim, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.i_proj = nn.Sequential(
            nn.Linear(idim, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Backbones
        fusion_in = proj_dim + proj_dim + (add_meta_dim if add_meta_dim > 0 else 0)
        text_in   = proj_dim + (add_meta_dim if add_meta_dim > 0 else 0)

        self.fusion_backbone = nn.Sequential(
            nn.Linear(fusion_in, hidden), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Dropout(dropout),
        )
        self.text_backbone = nn.Sequential(
            nn.Linear(text_in, hidden), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Dropout(dropout),
        )

        # Gate over both modalities (+meta)
        self.gate_proj = nn.Sequential(
            nn.Linear(proj_dim + proj_dim + add_meta_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # Heads
        self.hs3_fusion_head = nn.Linear(hidden // 2, 3)
        self.hs3_text_head   = nn.Linear(hidden // 2, 3)
        self.ab2_fusion_head = nn.Linear(hidden // 2, 2)
        self.ab2_text_head   = nn.Linear(hidden // 2, 2)

    def _mean_pool(self, last_hidden_state, attention_mask):
        """Average token states over positions where ``attention_mask`` is non-zero.

        The denominator is clamped so a fully masked row yields zeros, not NaN.
        """
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)
        return summed / denom

    def forward(self, input_ids, attention_mask, images, has_image_mask, meta=None):
        """Return fused and per-branch logits for a batch.

        Returns a dict with ``hs3_logits``/``ab2_logits`` (gated mix),
        ``gate_weight`` (detached, [B, 1]), and the branch-only logits
        ``hs3_f_only``, ``ab2_f_only``, ``hs3_t_only``, ``ab2_t_only``.
        """
        # Text
        out_text = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        t_last = out_text.last_hidden_state
        if getattr(self.cfg_loaded, "text_pooling", "mean") == "mean":
            t_feat = self._mean_pool(t_last, attention_mask)
        else:
            t_feat = t_last[:, 0, :]
        t_proj = self.t_proj(t_feat)

        # The layers were sized for add_meta_dim extra features. If the caller
        # passes none, use zeros (the value text-only samples carry) rather than
        # failing with a shape mismatch.
        if self.add_meta_dim > 0 and meta is None:
            meta = t_proj.new_zeros((t_proj.size(0), self.add_meta_dim))
        # Meta dtype align
        if (meta is not None) and (meta.dtype != t_proj.dtype):
            meta = meta.to(t_proj.dtype)

        # Vision. Inputs and features are cast to the parameters' dtype so the
        # forward also works outside autocast when the dataset produced
        # half-precision images. Under autocast, ops still run in reduced precision.
        B = images.size(0)
        img_mask_flat = has_image_mask.view(-1)
        feat_dtype = self.i_proj[0].weight.dtype
        if img_mask_flat.any():
            idx = torch.nonzero(img_mask_flat, as_tuple=False).squeeze(1)
            vis_param = next(self.vision.parameters(), None)
            vis_dtype = vis_param.dtype if vis_param is not None else images.dtype
            images_sub = images[idx].to(dtype=vis_dtype, memory_format=torch.channels_last)
            if self.use_clip:
                v_out = self.vision(pixel_values=images_sub)
                i_sub = v_out.pooler_output
            else:
                i_sub = self.vision(images_sub)
            i_feat = images.new_zeros((B, self.image_feat_dim), dtype=feat_dtype)
            i_feat[idx] = i_sub.to(feat_dtype)
        else:
            i_feat = images.new_zeros((B, self.image_feat_dim), dtype=feat_dtype)
        i_proj = self.i_proj(i_feat)
        img_mask = img_mask_flat.float().unsqueeze(1)

        # Inputs
        if self.add_meta_dim and meta is not None:
            t_in = torch.cat([t_proj, meta], dim=1)
            f_in = torch.cat([t_proj, i_proj, meta], dim=1)
        else:
            t_in = t_proj
            f_in = torch.cat([t_proj, i_proj], dim=1)
        gate_in = f_in

        # Reprs
        t_repr = self.text_backbone(t_in)
        f_repr = self.fusion_backbone(f_in)

        # Gate is zeroed for samples without an image -> pure text-branch logits.
        gate_weight = self.gate_proj(gate_in) * img_mask  # [B,1]

        # Heads
        hs3_f = self.hs3_fusion_head(f_repr)
        hs3_t = self.hs3_text_head(t_repr)
        hs3_logits = gate_weight * hs3_f + (1 - gate_weight) * hs3_t

        ab2_f = self.ab2_fusion_head(f_repr)
        ab2_t = self.ab2_text_head(t_repr)
        ab2_logits = gate_weight * ab2_f + (1 - gate_weight) * ab2_t

        return {
            "hs3_logits": hs3_logits,
            "ab2_logits": ab2_logits,
            "gate_weight": gate_weight.detach(),
            # expose branch-only logits (for cheap image-only loss, ablations)
            "hs3_f_only": hs3_f,
            "ab2_f_only": ab2_f,
            "hs3_t_only": hs3_t,
            "ab2_t_only": ab2_t,
        }

# ===========================
# Collate / utils
# ===========================
def collate_fn(batch):
    """Merge dataset items into one padded batch, or return None if no item remains.

    ``None`` items are discarded. Token ids and attention masks are padded with
    ``text_tokenizer.pad`` to the longest sequence, rounded up to a multiple of 8.
    The remaining tensor fields are stacked along a new leading dimension, and
    the raw texts are kept as a list.
    """
    items = [item for item in batch if item is not None]
    if not items:
        return None
    padded = text_tokenizer.pad(
        [{"input_ids": item["input_ids"], "attention_mask": item["attention_mask"]} for item in items],
        padding=True, pad_to_multiple_of=8, return_tensors="pt",
    )
    stacked = {
        key: torch.stack([item[key] for item in items])
        for key in ("image", "has_image", "meta", "hs3_label", "hs3_target", "ab2_label")
    }
    return {
        "text": [item["text"] for item in items],
        "input_ids": padded["input_ids"],
        "attention_mask": padded["attention_mask"],
        **stacked,
    }

def clear_memory():
    """Free Python garbage, then return PyTorch's unused cached CUDA memory.

    ``gc.collect()`` must run first: tensors kept alive only by reference cycles
    are freed by the collector, and only after that can
    ``torch.cuda.empty_cache()`` release their blocks. Collecting afterwards
    leaves that memory in PyTorch's caching allocator.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # let pending kernels finish before releasing cache
        torch.cuda.empty_cache()

def print_gpu_memory():
    """Print allocated and reserved CUDA memory in GiB; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated() / gib
    reserved = torch.cuda.memory_reserved() / gib
    print(f"GPU Mem - Alloc {allocated:.2f} GB | Reserved {reserved:.2f} GB")

# ===========================
# Consistency loss
# ===========================
def hierarchical_consistency_loss(hs3_logits, ab2_logits):
    """Symmetric KL divergence between the ab2 head and the ab2 view implied by hs3.

    The hs3 probabilities are collapsed to a binary distribution:
    [P(Non-abusive), P(Abusive) + P(Offensive)], i.e. [p2, p0 + p1]. That
    distribution is compared with softmax(ab2_logits) using KL in both
    directions (``batchmean`` reduction). Probabilities are clamped at 1e-8
    before taking logs.
    """
    eps = 1e-8
    hs3_probs = F.softmax(hs3_logits, dim=-1)
    non_abusive = hs3_probs[:, 2]
    abusive = hs3_probs[:, 0] + hs3_probs[:, 1]
    implied_ab2 = torch.stack([non_abusive, abusive], dim=1).clamp_min(eps)
    direct_ab2 = F.softmax(ab2_logits, dim=-1).clamp_min(eps)
    kl_implied_vs_direct = F.kl_div(direct_ab2.log(), implied_ab2, reduction="batchmean")
    kl_direct_vs_implied = F.kl_div(implied_ab2.log(), direct_ab2, reduction="batchmean")
    return kl_implied_vs_direct + kl_direct_vs_implied

# ===========================
# Hard-mining and Memotion sampling
# ===========================
@torch.no_grad()
def score_difficulty_text(model, dataset: Dataset, indices: List[int], batch_size=256):
    """Score how hard each example in ``indices`` is for ``model``.

    Difficulty is the per-example cross-entropy of ``hs3_logits`` against
    ``hs3_label`` plus that of ``ab2_logits`` against ``ab2_label``. A head whose
    label is -100 contributes 0. Scoring runs without gradients, in eval mode,
    in ``indices`` order. The model's previous train/eval mode is restored
    afterwards, so calling this has no lasting side effect on the model.

    Returns:
        ``(indices, diffs)``: the given ``indices`` and a float32 numpy array of
        difficulties aligned with them.
    """
    was_training = model.training
    model.eval()
    pool_ds = Subset(dataset, indices)
    loader = DataLoader(pool_ds, batch_size=batch_size, shuffle=False, num_workers=2,
                        pin_memory=(device == "cuda"), collate_fn=collate_fn)
    diffs = np.zeros(len(indices), dtype=np.float32); off = 0
    try:
        for batch in tqdm(loader, desc="Scoring hardness", leave=False):
            ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            imgs = batch["image"].to(device, non_blocking=True)
            has_img = batch["has_image"].to(device, non_blocking=True)
            meta = batch["meta"].to(device, non_blocking=True) if cfg.add_meta_dim > 0 else None
            y_hs3 = batch["hs3_label"].to(device, non_blocking=True)
            y_ab2 = batch["ab2_label"].to(device, non_blocking=True)
            with autocast(enabled=cfg.mixed_precision, dtype=torch.float16):
                out = model(ids, attn, imgs, has_img, meta)
                hs3_log = out["hs3_logits"]; ab2_log = out["ab2_logits"]
                # Per-example CE, computed only where the label is present.
                hs3_ce = torch.zeros(ids.size(0), device=device)
                m_h = (y_hs3 != -100)
                if m_h.any():
                    hs3_ce[m_h] = F.cross_entropy(hs3_log[m_h], y_hs3[m_h], reduction="none")
                ab2_ce = torch.zeros(ids.size(0), device=device)
                m_a = (y_ab2 != -100)
                if m_a.any():
                    ab2_ce[m_a] = F.cross_entropy(ab2_log[m_a], y_ab2[m_a], reduction="none")
                diff = (hs3_ce + ab2_ce).detach().float().cpu().numpy()
            diffs[off:off + len(diff)] = diff; off += len(diff)
    finally:
        # Restore the caller's mode instead of silently leaving the model in eval().
        model.train(was_training)
    return indices, diffs

def select_hard_examples(model, full_text_ds: Dataset, pool_size: int, select_n: int, hard_frac: float, epoch: int):
    """Pick training indices: the hardest examples of a random pool, topped up at random.

    Steps:
      1. Draw a random pool of ``pool_size`` indices (seeded by ``cfg.seed`` and ``epoch``).
      2. Score the pool with ``score_difficulty_text``.
      3. Keep the ``int(select_n * hard_frac)`` hardest (capped by the pool size).
      4. Fill the rest up to ``select_n`` with random indices from the dataset
         that were not already chosen.

    ``select_n`` is capped at the dataset size, and ``hard_frac`` is effectively
    clamped to [0, 1]. For in-range arguments the selection is identical to the
    previous implementation.

    Returns:
        List of selected dataset indices: hard ones first, then random ones.
    """
    N = len(full_text_ds)
    pool_size = min(pool_size, N)
    select_n = min(select_n, N)
    rng_pool = np.random.RandomState(cfg.seed + 1009 * epoch)
    pool_idx = rng_pool.choice(N, size=pool_size, replace=False).tolist()
    _, diffs = score_difficulty_text(model, full_text_ds, pool_idx, batch_size=256)

    k_hard = max(0, min(select_n, int(select_n * hard_frac)))
    hard_idx_rel = np.argsort(-diffs)[:k_hard]
    hard_idx = [pool_idx[i] for i in hard_idx_rel]

    # Build the exclusion set once (was rebuilt for every candidate: O(N * k)).
    hard_set = set(hard_idx)
    remaining_candidates = [i for i in range(N) if i not in hard_set]
    # Use len(hard_idx), not k_hard: a pool smaller than k_hard yields fewer hard picks.
    remaining = max(0, min(select_n - len(hard_idx), len(remaining_candidates)))
    rng_rest = np.random.RandomState(cfg.seed + 2027 * epoch)
    rand_rel = rng_rest.choice(len(remaining_candidates), size=remaining, replace=False)
    rand_idx = [remaining_candidates[i] for i in rand_rel]
    return hard_idx + rand_idx

def get_memotion_valid_indices(ds_memo: OptimizedProcessedCSVSet):
    """Return the row labels of ``ds_memo.df`` whose image file exists on disk.

    The file name is the basename of the stripped ``image_path`` value, looked up
    in ``<base_dir>/<processed_dir_name>/memotion_dataset_7k/images``. Rows with
    an empty name are skipped.
    """
    image_dir = Path(cfg.base_dir) / cfg.processed_dir_name / "memotion_dataset_7k" / "images"

    def _has_image(record):
        name = Path(str(record.get("image_path", "")).strip()).name
        return bool(name) and (image_dir / name).exists()

    return [row_id for row_id, record in ds_memo.df.iterrows() if _has_image(record)]

def sample_memotion_indices(valid_idx: List[int], n: int, epoch: int):
    """Sample up to ``n`` distinct entries of ``valid_idx``.

    Sampling is without replacement and seeded per epoch
    (``cfg.seed + 707 * epoch``), so each epoch gets a reproducible draw.
    The result is a plain Python list.
    """
    take = n if n < len(valid_idx) else len(valid_idx)
    generator = np.random.RandomState(cfg.seed + 707 * epoch)
    return generator.choice(valid_idx, size=take, replace=False).tolist()

# ===========================
# Training
# ===========================
@torch.no_grad()
def _validation_loss(model, loader, crit):
    """Average hs3 + ab2 loss of ``model`` over ``loader``; None if nothing was scored.

    Mirrors the supervised part of the training objective (same ``crit``, same
    autocast settings) and leaves out the consistency and image-only auxiliary
    terms. Batches whose loss is not finite (e.g. a batch in which every label
    is ignored) are skipped. The model's train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    total, n_scored = 0.0, 0
    try:
        for batch in tqdm(loader, desc="Validation", leave=False):
            if batch is None: continue
            ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            imgs = batch["image"].to(device, non_blocking=True)
            has_img = batch["has_image"].to(device, non_blocking=True)
            meta = batch["meta"].to(device, non_blocking=True) if cfg.add_meta_dim > 0 else None
            y_ab2 = batch["ab2_label"].to(device, non_blocking=True)
            hs3_target = batch["hs3_target"].to(device, non_blocking=True)
            with autocast(enabled=cfg.mixed_precision, dtype=torch.float16):
                out = model(ids, attn, imgs, has_img, meta)
                loss = crit["hs3"](out["hs3_logits"], hs3_target) + crit["ab2"](out["ab2_logits"], y_ab2)
            value = float(loss.item())
            if math.isfinite(value):
                total += value
                n_scored += 1
    finally:
        model.train(was_training)
    return total / n_scored if n_scored else None


def train_novel():
    """Train the gated multimodal model on the text splits plus Memotion images.

    Each epoch:
      1. Re-draws the text subset (hard-mined when ``cfg.hard_mining``) and the
         Memotion image subset.
      2. Optionally recomputes class weights.
      3. Runs one pass with gradient accumulation and optional AMP.
      4. Scores the model on the validation loader.

    The score is the negated validation loss. The negated training loss is used
    only if the validation set yields no finite loss. This score drives best-
    checkpoint saving (``processed_dir / cfg.checkpoint_name``) and early stopping.

    Returns:
        True when training finishes.
    """
    clear_memory()
    processed_dir = Path(cfg.base_dir) / cfg.processed_dir_name
    splits_dir = processed_dir / cfg.splits_dirname
    reports_dir = processed_dir / cfg.reports_dirname
    reports_dir.mkdir(parents=True, exist_ok=True)

    memotion_path = processed_dir / cfg.memotion_csv
    assert memotion_path.exists(), f"Missing memotion CSV: {memotion_path}"
    train_text_csv = splits_dir / cfg.text_train_csv
    val_text_csv   = splits_dir / cfg.text_val_csv
    test_text_csv  = splits_dir / cfg.text_test_csv
    for p in [train_text_csv, val_text_csv, test_text_csv]:
        assert p.exists(), f"Missing: {p}"

    # Datasets
    ds_train_text_full = OptimizedProcessedCSVSet(train_text_csv, is_memotion=False, sample_n=None, pre_tokenize=False)
    ds_val_text   = OptimizedProcessedCSVSet(val_text_csv,   is_memotion=False, sample_n=cfg.sample_val_text_n, pre_tokenize=True)
    ds_test_text  = OptimizedProcessedCSVSet(test_text_csv,  is_memotion=False, sample_n=cfg.sample_test_text_n, pre_tokenize=True)

    ds_val_memo   = OptimizedProcessedCSVSet(memotion_path, is_memotion=True, split="val",  sample_n=cfg.sample_val_memotion_n, pre_tokenize=True)
    ds_test_memo  = OptimizedProcessedCSVSet(memotion_path, is_memotion=True, split="test", sample_n=cfg.sample_test_memotion_n, pre_tokenize=True)
    ds_train_memo_full = OptimizedProcessedCSVSet(memotion_path, is_memotion=True, split="train", sample_n=None, pre_tokenize=True)

    # Scan the image directory once; reused for every epoch's sampling.
    memotion_valid_idx = get_memotion_valid_indices(ds_train_memo_full)

    # Initial subsets
    text_init_size = min(cfg.train_text_per_epoch, len(ds_train_text_full))
    init_text_idx = np.random.RandomState(cfg.seed).choice(len(ds_train_text_full), size=text_init_size, replace=False)
    ds_train_text_epoch = Subset(ds_train_text_full, init_text_idx.tolist())
    memo_init_idx = sample_memotion_indices(memotion_valid_idx, cfg.images_per_epoch, epoch=0)
    ds_train_memo_epoch = Subset(ds_train_memo_full, memo_init_idx)

    train_ds = ConcatDataset([ds_train_text_epoch, ds_train_memo_epoch])
    val_ds   = ConcatDataset([ds_val_text, ds_val_memo])
    test_ds  = ConcatDataset([ds_test_text, ds_test_memo])

    # Dataloaders
    cpu_ct = os.cpu_count() or 2
    num_workers = min(8, max(2, cpu_ct // 2))
    dl_kwargs = dict(num_workers=num_workers, collate_fn=collate_fn, pin_memory=(device=="cuda"), persistent_workers=True, prefetch_factor=4)
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, **dl_kwargs)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, **dl_kwargs)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False, **dl_kwargs)

    # Model
    model = OptimizedMultiModalGated(
        text_model_name=cfg.text_model_name,
        freeze_text=cfg.freeze_text_encoder,
        use_clip=cfg.use_clip_vision,
        clip_model_name=cfg.clip_vision_model,
        fallback_backbone=cfg.image_backbone_fallback,
        freeze_vision=cfg.freeze_vision_encoder,
        use_multi_task=cfg.use_multi_task,
        ab2_mode=cfg.ab2_mode,
        hidden=cfg.hidden,
        dropout=cfg.dropout,
        add_meta_dim=cfg.add_meta_dim,
        gradient_checkpointing=cfg.gradient_checkpointing
    ).to(device)
    if device == "cuda":
        model = model.to(memory_format=torch.channels_last)

    if cfg.use_torch_compile and hasattr(torch, "compile"):
        try:
            model = torch.compile(model); print("torch.compile enabled")
        except Exception as e:
            print(f"torch.compile skipped: {e}")

    # Optimizer
    params_to_train = [p for p in model.parameters() if p.requires_grad]
    try:
        optimizer = torch.optim.AdamW(params_to_train, lr=cfg.lr, weight_decay=cfg.weight_decay, fused=True)
    except TypeError:
        optimizer = torch.optim.AdamW(params_to_train, lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Class weights
    def counts_from_concat(text_subset: Dataset, memo_subset: Dataset, label_pos, n_classes):
        """Count labels (position 0 = hs3, 1 = ab2) over the text + Memotion subsets."""
        counts = np.zeros(n_classes, dtype=np.int64)
        for i in range(len(text_subset)):
            idx = text_subset.indices[i] if isinstance(text_subset, Subset) else i
            y = ds_train_text_full.labels[idx][label_pos]
            if y != -100: counts[y] += 1
        for i in range(len(memo_subset)):
            idx = memo_subset.indices[i] if isinstance(memo_subset, Subset) else i
            y_h, y_a = ds_train_memo_full.labels[idx]
            if label_pos == 0 and y_h != -100: counts[y_h] += 1
            if label_pos == 1 and y_a != -100: counts[y_a] += 1
        return counts

    hs3_w = ab2_w = None
    if cfg.use_class_weights:
        hs3_counts = counts_from_concat(ds_train_text_epoch, ds_train_memo_epoch, label_pos=0, n_classes=3)
        ab2_counts = counts_from_concat(ds_train_text_epoch, ds_train_memo_epoch, label_pos=1, n_classes=2)
        hs3_w = compute_weights_from_counts(hs3_counts); ab2_w = compute_weights_from_counts(ab2_counts)
    crit = build_losses(cfg.ab2_mode, hs3_w, ab2_w, device=device, label_smoothing=cfg.label_smoothing)
    scheduler, total_updates = build_scheduler(optimizer, len(train_loader), cfg)

    print("Training started")
    print(f"Steps/epoch: {len(train_loader)}")
    print_gpu_memory()

    best_score = -1.0; patience = 0; history = []

    def unwrap_for_save(m):
        """Strip torch.compile / (D)DP wrappers so the saved state_dict has plain keys."""
        if hasattr(m, "_orig_mod"): m = m._orig_mod
        if hasattr(m, "module"): m = m.module
        return m

    for epoch in range(1, cfg.epochs+1):
        if cfg.hard_mining:
            print(f"\nHard-mining epoch {epoch}: pool {cfg.pool_size} -> select {cfg.train_text_per_epoch} (hard_frac={cfg.hard_frac})")
            with torch.no_grad():
                sel_text_idx = select_hard_examples(model, ds_train_text_full, cfg.pool_size, cfg.train_text_per_epoch, cfg.hard_frac, epoch)
            ds_train_text_epoch = Subset(ds_train_text_full, sel_text_idx)

        sel_memo_idx = sample_memotion_indices(memotion_valid_idx, cfg.images_per_epoch, epoch)
        ds_train_memo_epoch = Subset(ds_train_memo_full, sel_memo_idx)

        train_ds = ConcatDataset([ds_train_text_epoch, ds_train_memo_epoch])
        train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, **dl_kwargs)

        if cfg.use_class_weights:
            hs3_counts = counts_from_concat(ds_train_text_epoch, ds_train_memo_epoch, label_pos=0, n_classes=3)
            ab2_counts = counts_from_concat(ds_train_text_epoch, ds_train_memo_epoch, label_pos=1, n_classes=2)
            hs3_w = compute_weights_from_counts(hs3_counts); ab2_w = compute_weights_from_counts(ab2_counts)
            crit = build_losses(cfg.ab2_mode, hs3_w, ab2_w, device=device, label_smoothing=cfg.label_smoothing)

        print(f"Epoch {epoch} train sizes — text: {len(ds_train_text_epoch)}, memotion: {len(ds_train_memo_epoch)}")

        model.train(); running_loss = 0.0; total_batches = 0
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch}/{cfg.epochs}")
        optimizer.zero_grad(set_to_none=True)
        for it, batch in pbar:
            if batch is None: continue
            ids = batch["input_ids"].to(device, non_blocking=True)
            attn = batch["attention_mask"].to(device, non_blocking=True)
            imgs = batch["image"].to(device, non_blocking=True)
            has_img = batch["has_image"].to(device, non_blocking=True)
            meta = batch["meta"].to(device, non_blocking=True) if cfg.add_meta_dim > 0 else None
            y_ab2 = batch["ab2_label"].to(device, non_blocking=True)
            hs3_target = batch["hs3_target"].to(device, non_blocking=True)

            # Modality dropout for text (forces image reliance sometimes)
            if cfg.text_dropout_prob > 0 and random.random() < cfg.text_dropout_prob:
                ids = ids.clone()
                ids[:] = text_tokenizer.pad_token_id
                attn = torch.zeros_like(attn)

            with autocast(enabled=cfg.mixed_precision, dtype=torch.float16):
                out = model(ids, attn, imgs, has_img, meta)
                hs3_loss = crit["hs3"](out["hs3_logits"], hs3_target)
                ab2_loss = crit["ab2"](out["ab2_logits"], y_ab2)
                cons_loss = hierarchical_consistency_loss(out["hs3_logits"], out["ab2_logits"]) * cfg.lambda_consistency

                # Cheap image-only aux loss (use fusion-only logits on samples with images)
                img_mask = has_img.view(-1)
                img_only_loss = imgs.new_tensor(0.0)
                if cfg.image_only_loss_w > 0 and img_mask.any():
                    img_hs3 = crit["hs3"](out["hs3_f_only"][img_mask], hs3_target[img_mask])
                    img_ab2 = crit["ab2"](out["ab2_f_only"][img_mask], y_ab2[img_mask])
                    img_only_loss = cfg.image_only_loss_w * (img_hs3 + img_ab2)

                loss = (hs3_loss + ab2_loss + cons_loss + img_only_loss) / cfg.grad_accum_steps

            if scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step every grad_accum_steps batches and on the last batch of the epoch.
            step_now = ((it + 1) % cfg.grad_accum_steps == 0) or (it + 1 == len(train_loader))
            if step_now:
                if cfg.grad_clip:
                    if scaler.is_enabled(): scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(params_to_train, cfg.grad_clip)
                if scaler.is_enabled():
                    scaler.step(optimizer); scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()

            running_loss += float(loss.item()) * cfg.grad_accum_steps
            total_batches += 1
            pbar.set_postfix({'loss': f'{running_loss/total_batches:.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})
            del out, loss, hs3_loss, ab2_loss, cons_loss, img_only_loss

        avg_train_loss = running_loss / max(1, total_batches)
        # Model selection / early stopping on held-out data. The training loss is
        # only a fallback when the validation set yields no usable loss.
        val_loss = _validation_loss(model, val_loader, crit) if len(val_loader) > 0 else None
        current_score = -(val_loss if val_loss is not None else avg_train_loss)
        history.append({'epoch': epoch, 'train_loss': avg_train_loss, 'val_loss': val_loss, 'val_score': current_score})
        val_txt = f"{val_loss:.4f}" if val_loss is not None else "n/a"
        print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f}, Val Loss={val_txt}, Val Score={current_score:.4f}, Best={best_score:.4f}")
        print_gpu_memory()

        base_model = unwrap_for_save(model)
        ckpt = {
            "model": base_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "scaler": scaler.state_dict() if scaler.is_enabled() else None,
            "epoch": epoch,
            "best_score": max(best_score, current_score),
            "cfg": cfg.__dict__,
            "history": history
        }

        if current_score > best_score:
            best_score = current_score; patience = 0
            torch.save(ckpt, processed_dir / cfg.checkpoint_name)
            print("New best model saved!")
        else:
            patience += 1
            print(f"No improvement. Patience {patience}/{cfg.early_stop_patience}")
            if patience >= cfg.early_stop_patience:
                print("Early stopping.")
                break

        clear_memory()

    print("\nTraining complete!")
    print(f"Best Val Score: {best_score:.4f} | Reports dir: {reports_dir}")
    return True

# ===========================
# Fused inference helpers (direct 3-class)
# ===========================
def _clean_state_dict_keys(sd):
    out = {}
    for k, v in sd.items():
        if k.startswith("_orig_mod."): k = k[len("_orig_mod."):]
        if k.startswith("module."):    k = k[len("module."):]
        out[k] = v
    return out

def fused_label_names():
    """Return a fresh list of class names indexed by fused hs3 label.

    Index 0 = "Abusive", 1 = "Offensive", 2 = "Non-abusive".
    """
    names = ("Abusive", "Offensive", "Non-abusive")
    return list(names)

# ===========================
# Eval + Insights (direct 3-class)
# ===========================
def load_model_and_processors_lora(ckpt_path, device="cpu"):
    """Rebuild ``OptimizedMultiModalGated`` from a checkpoint and load its weights.

    The checkpoint must contain ``"model"`` (a state_dict) and ``"cfg"`` (a dict
    of config attributes; train_novel stores ``cfg.__dict__``). Missing config
    fields fall back to the defaults below.

    Returns:
        ``(model, tokenizer, clip_proc, cfg_loaded)``:
          * the model in eval mode on ``device``;
          * the text tokenizer;
          * the CLIP image processor, or None when CLIP vision is disabled;
          * the config as an attribute namespace.

    Raises:
        AssertionError: if ``ckpt_path`` does not exist or the checkpoint lacks ``"cfg"``.
    """
    import types  # local: only used to wrap the stored cfg dict

    assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
    ckpt = torch.load(ckpt_path, map_location=device)
    assert "cfg" in ckpt, "Checkpoint missing cfg."

    # Attribute-style view of the stored config dict.
    cfg_loaded = types.SimpleNamespace(**ckpt["cfg"])

    use_clip = getattr(cfg_loaded, "use_clip_vision", True)
    # Resolve the CLIP checkpoint once, with the same default the model uses,
    # so the image processor and the vision tower always refer to the same model.
    clip_name = getattr(cfg_loaded, "clip_vision_model", None) or "openai/clip-vit-base-patch32"

    tokenizer = AutoTokenizer.from_pretrained(cfg_loaded.text_model_name, use_fast=True)
    clip_proc = CLIPImageProcessor.from_pretrained(clip_name) if use_clip else None

    model = OptimizedMultiModalGated(
        text_model_name=cfg_loaded.text_model_name,
        freeze_text=getattr(cfg_loaded, "freeze_text_encoder", True),
        use_clip=use_clip,
        clip_model_name=clip_name,
        fallback_backbone=getattr(cfg_loaded, "image_backbone_fallback", "mobilenet_v2"),
        freeze_vision=getattr(cfg_loaded, "freeze_vision_encoder", True),
        use_multi_task=getattr(cfg_loaded, "use_multi_task", True),
        ab2_mode=getattr(cfg_loaded, "ab2_mode", "ce"),
        hidden=getattr(cfg_loaded, "hidden", 256),
        dropout=getattr(cfg_loaded, "dropout", 0.25),
        add_meta_dim=getattr(cfg_loaded, "add_meta_dim", 2),
        gradient_checkpointing=False,  # inference only
        cfg_loaded=cfg_loaded
    ).to(device)

    # Keys may carry torch.compile / DataParallel prefixes from training.
    sd = _clean_state_dict_keys(ckpt["model"])
    model.load_state_dict(sd, strict=True)
    if device == "cuda":
        model = model.to(memory_format=torch.channels_last)
    model.eval()
    return model, tokenizer, clip_proc, cfg_loaded

def build_test_loader_for_fused(splits_dir, batch_size, sample_n=None):
    """Build a non-shuffled DataLoader over the text test split plus the Memotion test split.

    ``sample_n`` subsamples only the text split; the Memotion split size comes from
    ``cfg.sample_test_memotion_n``. Both datasets are built with ``pre_tokenize=True``.
    Memory pinning is enabled when the global ``device`` is ``"cuda"``.
    """
    text_split_csv = Path(splits_dir) / "test/text_test.csv"
    memo_csv = Path(cfg.base_dir) / cfg.processed_dir_name / cfg.memotion_csv
    parts = [
        OptimizedProcessedCSVSet(text_split_csv, is_memotion=False, sample_n=sample_n, pre_tokenize=True),
        OptimizedProcessedCSVSet(memo_csv, is_memotion=True, split="test",
                                 sample_n=cfg.sample_test_memotion_n, pre_tokenize=True),
    ]
    return DataLoader(
        ConcatDataset(parts),
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        collate_fn=collate_fn,
        pin_memory=(device == "cuda"),
    )

def evaluate_fused_on_loader(model, loader, amp=True):
    """Evaluate the direct 3-class (``hs3``) head over ``loader`` and report metrics.

    Samples whose ``hs3_label`` equals the ignore index ``-100`` are skipped. Prints a
    classification report over all three classes and saves a confusion-matrix heatmap
    to ``OUT_DIR/fused_cm.png``.

    Args:
        model: model returning a dict containing ``"hs3_logits"``.
        loader: batches with ``input_ids``, ``attention_mask``, ``image``, ``has_image``,
            ``hs3_label`` and optionally ``meta``.
        amp: cast images to fp16 and run the forward pass under CUDA autocast.

    Returns:
        ``{"y": labels, "p": predictions, "probs": class probabilities}`` as numpy
        arrays, or ``None`` when no labelled sample was seen.
    """
    model.eval()
    all_y, all_p, all_probs = [], [], []
    names = fused_label_names()
    # Explicit class ids keep the report valid even if a class is absent from the
    # evaluated subset (target_names alone would then mismatch the inferred labels).
    class_ids = list(range(len(names)))

    with torch.inference_mode():
        for batch in tqdm(loader, desc="Fused eval (direct 3-class)"):
            ids = batch["input_ids"].to(device); attn = batch["attention_mask"].to(device)
            imgs = batch["image"].to(device); has_img = batch["has_image"].to(device)
            meta = batch["meta"].to(device) if "meta" in batch else None
            y_hs3 = batch["hs3_label"].to(device)

            if amp: imgs = imgs.half()
            with torch.cuda.amp.autocast(enabled=amp, dtype=torch.float16):
                out = model(ids, attn, imgs, has_img, meta=meta)
                logits = out["hs3_logits"]
                probs  = torch.softmax(logits, dim=-1)
                pred   = probs.argmax(dim=1)

            y_np = y_hs3.detach().cpu().numpy()
            mask = (y_np != -100)  # -100 marks samples without a fused label

            if np.any(mask):
                all_y.append(y_np[mask])
                all_p.append(pred.detach().cpu().numpy()[mask])
                all_probs.append(probs.detach().cpu().numpy()[mask])

    if not all_y:
        print("No valid fused labels found in eval.")
        return None

    y = np.concatenate(all_y)
    p = np.concatenate(all_p)
    probs = np.concatenate(all_probs)

    print("\nFused report (direct 3-class head):")
    print(classification_report(y, p, labels=class_ids, target_names=names, zero_division=0))

    cm = confusion_matrix(y, p, labels=class_ids)
    plt.figure(figsize=(6,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=names, yticklabels=names)
    plt.title("Fused Confusion Matrix (Direct)"); plt.ylabel("True"); plt.xlabel("Pred")
    plt.tight_layout(); plt.savefig(os.path.join(OUT_DIR, "fused_cm.png"), dpi=150); plt.close()

    return {"y": y, "p": p, "probs": probs}

def gate_analysis(loader, model, amp=True):
    """Collect the model's gate weight for every sample that has an image.

    Runs the model over ``loader`` in inference mode, optionally with fp16 images and
    CUDA autocast (``amp``). It keeps ``out["gate_weight"]`` only where ``has_image`` is
    set and prints the count, mean and median when any were collected.

    Returns:
        1-D numpy array of gate weights (empty when no batch had an image).
    """
    model.eval()
    collected = []
    with torch.inference_mode():
        for batch in tqdm(loader, desc="Gate analysis"):
            input_ids = batch["input_ids"].to(device)
            attention = batch["attention_mask"].to(device)
            images = batch["image"].to(device)
            has_image = batch["has_image"].to(device)
            meta = batch["meta"].to(device) if "meta" in batch else None
            if amp:
                images = images.half()
            with torch.cuda.amp.autocast(enabled=amp, dtype=torch.float16):
                outputs = model(input_ids, attention, images, has_image, meta=meta)
            weights = outputs["gate_weight"].detach().cpu().numpy().reshape(-1)
            image_flags = has_image.cpu().numpy().reshape(-1)
            # extend() with the numpy slice keeps numpy scalars (same dtype as before)
            collected.extend(weights[image_flags == 1])
    gate_vals = np.array(collected)
    if len(gate_vals):
        print(f"Image-sample count: {len(gate_vals)} | gate mean={gate_vals.mean():.3f} | median={np.median(gate_vals):.3f}")
    return gate_vals

def precision_recall_thresholds(y_true_fused, fused_probs, target_precision=0.90):
    """Report average precision and a high-precision threshold for the binary abusive task.

    Classes 0 (Abusive) and 1 (Offensive) form the positive class and class 2
    (Non-abusive) the negative class; the positive score is ``P(0) + P(1)``.
    Prints the average precision and, when one exists, the lowest decision threshold
    whose precision reaches ``target_precision``, together with its precision and recall.

    Args:
        y_true_fused: fused labels in {0, 1, 2}.
        fused_probs: ``(N, 3)`` class probabilities.
        target_precision: precision the reported threshold must reach.

    Returns:
        ``(ap, precision, recall, thresholds)``, where the last three are exactly as
        returned by ``precision_recall_curve``.
    """
    y_true_fused = np.asarray(y_true_fused)
    fused_probs = np.asarray(fused_probs)
    y_ab = (y_true_fused != 2).astype(int)
    p_ab = fused_probs[:, 0] + fused_probs[:, 1]
    ap = average_precision_score(y_ab, p_ab)
    print(f"Average precision (abusive): {ap:.4f}")
    prec, rec, thr = precision_recall_curve(y_ab, p_ab)
    # prec/rec have len(thr) + 1 entries: prec[i], rec[i] are measured at thr[i], and the
    # final (P=1, R=0) point has no threshold -- so only iterate over real thresholds.
    # Thresholds are increasing and recall is non-increasing, so the first hit is the
    # qualifying threshold with the highest recall.
    best = None
    for i, th in enumerate(thr):
        if prec[i] >= target_precision:
            best = (th, prec[i], rec[i])
            break
    if best is not None:
        th, bp, br = best
        print(f"Threshold @P>={target_precision}: th={th:.3f}, P={bp:.3f}, R={br:.3f}")
    else:
        print(f"No threshold reaches P>={target_precision}")
    return ap, prec, rec, thr

def expected_calibration_error(y_true_abusive, y_prob, n_bins=15):
    """Expected calibration error (ECE) of a binary positive-class probability.

    Samples are grouped into ``n_bins`` equal-width bins over [0, 1] by ``y_prob``. For
    each non-empty bin, the gap between the mean predicted probability and the observed
    fraction of positives is weighted by the bin's share of samples, and the gaps are summed.

    Args:
        y_true_abusive: binary labels (1 = positive/abusive, 0 = negative).
        y_prob: predicted probability of the positive class, same length.
        n_bins: number of equal-width bins.

    Returns:
        ECE as a float in [0, 1] (0.0 for empty input).
    """
    y_true = np.asarray(y_true_abusive, dtype=float).reshape(-1)
    prob = np.asarray(y_prob, dtype=float).reshape(-1)
    if prob.size == 0:
        return 0.0
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize maps prob == 1.0 past the last edge; clip it into the top bin
    # instead of silently dropping it.
    bin_ids = np.clip(np.digitize(prob, bins) - 1, 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if np.any(in_bin):
            conf = prob[in_bin].mean()
            # Calibration of P(positive): compare with the empirical positive rate,
            # not with the accuracy of a 0.5-thresholded prediction.
            frac_pos = y_true[in_bin].mean()
            ece += abs(frac_pos - conf) * in_bin.sum() / prob.size
    return float(ece)

# ===========================
# Main flow
# ===========================
# 1) Train (optional)
# RUN_TRAIN toggles whether train_novel() runs before evaluation.
if RUN_TRAIN:
    print("Training LoRA + hard-mined model (with image-only aux loss + text dropout)...")
    train_novel()
else:
    print("Skipping training (RUN_TRAIN=False)")

# 2) Load model for eval
# Rebuilds the model from CKPT_PATH and binds the module-level names
# model / tok / clip_proc / cfg_loaded.
print("Loading LoRA model for direct 3-class evaluation...")
model, tok, clip_proc, cfg_loaded = load_model_and_processors_lora(CKPT_PATH, device=device)

# 3) Build test loader for fused eval (full or sample)
# TEACHER_N is forwarded as sample_n, which subsamples only the text test split.
# The loader is built even when RUN_FUSED_EVAL is False.
test_loader = build_test_loader_for_fused(SPLITS_DIR, batch_size=BATCH_EVAL, sample_n=TEACHER_N)

# 4) Run direct 3-class eval
# AMP (fp16 images + autocast) is enabled only when CUDA is available.
# NOTE(review): gate_analysis() is defined above but not invoked in this flow.
if RUN_FUSED_EVAL:
    fused_res = evaluate_fused_on_loader(model, test_loader, amp=torch.cuda.is_available())
    if fused_res is None:
        print("No fused results computed.")
    else:
        y = fused_res["y"]; p = fused_res["p"]; probs = fused_res["probs"]
        names = fused_label_names()  # NOTE(review): not used below

        # Insights: Precision/Recall thresholds for abusive (class 0 or 1)
        # Binary view: positive = Abusive or Offensive, score = P(class 0) + P(class 1).
        # precision_recall_thresholds() rebuilds the same binary view internally;
        # y_ab / p_ab here feed the ECE computation.
        y_ab = (y != 2).astype(int)
        p_ab = probs[:,0] + probs[:,1]
        ap, prec, rec, thr = precision_recall_thresholds(y, probs, target_precision=0.90)
        ece = expected_calibration_error(y_ab, p_ab, n_bins=15)
        print(f"ECE (abusive): {ece:.3f}")

        # Plot PR curve
        # Saved to OUT_DIR/abusive_pr_curve.png.
        plt.figure(figsize=(5,4))
        plt.plot(rec, prec, label=f'PR (AP={ap:.3f})')
        plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title("Abusive PR Curve (Direct 3-class)")
        plt.legend(); plt.tight_layout()
        plt.savefig(os.path.join(OUT_DIR, "abusive_pr_curve.png"), dpi=150); plt.close()

        # Export eval predictions table (with probs)
        # One row per labelled eval sample: true label, predicted label, 3 class probabilities.
        df_pred = pd.DataFrame({
            "y_true": y,
            "y_pred": p,
            "p_abusive": probs[:,0],
            "p_offensive": probs[:,1],
            "p_non_abusive": probs[:,2]
        })
        df_pred.to_csv(os.path.join(OUT_DIR, "fused_predictions_eval.csv"), index=False)
        print(f"Saved fused_predictions_eval.csv to {OUT_DIR}")

print("\nAll done")



"""
INTERPRETATION OF EVALUATION PIPELINE — DIRECT 3-CLASS MULTIMODAL ASSESSMENT

The end-to-end evaluation pipeline provides comprehensive insights into multimodal
hate detection performance through direct 3-class classification and advanced
diagnostic metrics.

Key evaluation interpretations:

- **Direct 3-Class Framework**: Evaluates models on the fused label set
  (0=Abusive, 1=Offensive, 2=Non-abusive) with a single 3-class head, rather than on
  separate binary/ternary tasks. This reflects real-world deployment, where a system
  must distinguish abusive content, merely offensive content, and acceptable material
  in one pass.

- **Gate Analysis**: Measures the multimodal gate's behavior across image-bearing
  samples. In `OptimizedMultiModalGated.forward`, the gate blends the logits of the
  fused (text + image) head with those of the text-only head. High gate values
  therefore indicate reliance on the fused branch, and low values indicate reliance on
  text alone; the gate does not isolate the visual signal by itself. `gate_analysis`
  is provided for this, but the main flow above does not call it.

- **Precision-Recall Tradeoffs**: Treats Abusive and Offensive together as the
  positive ("abusive") class and analyzes detection at a high target precision
  (90% by default). This matters in applications where false positives are costly.
  The threshold analysis reports the lowest decision threshold that reaches the
  target precision, together with the recall achieved there, as a candidate
  operating point.

- **Calibration Assessment**: Computes the Expected Calibration Error (ECE) of the
  binary abusive probability, P(Abusive) + P(Offensive). It measures how well predicted
  probabilities match observed frequencies. A high ECE indicates overconfident or
  underconfident predictions, which matters for risk assessment and content
  moderation decisions.

- **Comprehensive Export**: Generates prediction tables with per-class
  probabilities, enabling detailed error analysis and understanding of
  model uncertainty patterns across the offensive content spectrum.

Overall interpretation:
This evaluation framework moves beyond simple accuracy metrics to provide
deployment-ready insights about model behavior, calibration quality, and
optimal operating thresholds. The direct 3-class approach combined with
modality analysis and probability calibration makes the system suitable
for real-world content moderation where nuanced distinctions between
abusive, offensive, and acceptable content are essential.
"""

In [None]:
# user_test_fused_lora.py
# Direct 3-class inference: Abusive / Offensive / Non-abusive
# OCR: EasyOCR (default) + PaddleOCR support with safe CPU install and robust parsing
# Keys:
# - Normalize OCR text (NFKC, strip zero-width/NBSP, expand contractions, lowercase, strip punctuation/symbols, cap repeats, expand slang/leetspeak)
# - If OCR text exists -> feed OCR text directly; else use user-typed text
# - Optionally ignore image when OCR text is used (avoid gate overshadow)
# - Mean-pooled text encoder; projection for text/image; gate over both modalities
# - Slightly larger inference max_len to reduce truncation

# Optional installs:
# !pip install -U transformers torchvision pandas numpy pillow peft gradio easyocr

import os, sys, gc, subprocess, warnings, re, unicodedata
warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from PIL import Image

# Contractions (auto-install if missing)
try:
    import contractions
except ImportError:
    print("Installing 'contractions'...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "contractions"])
    import contractions

# PEFT (LoRA)
try:
    from peft import LoraConfig, get_peft_model, TaskType
except ImportError:
    print("Installing 'peft'...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "peft"])
    from peft import LoraConfig, get_peft_model, TaskType

from transformers import AutoTokenizer, AutoModel, AutoConfig, CLIPVisionModel, CLIPImageProcessor

# ======================
# Paths/config
# ======================
# Checkpoint path (override with the CKPT_PATH env var). It is read with torch.load,
# which unpickles the file -- only point it at trusted checkpoints.
ckpt_path = os.environ.get("CKPT_PATH", "datasets/processed_data/best_novel.pt")  # set to your trained checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

# Demo defaults
# TEMP: default softmax temperature used by user_predict when none is passed.
TEMP = 1.0
CLASS_NAMES = ["Abusive", "Offensive", "Non-abusive"]  # 0,1,2

# OCR defaults
OCR_ENGINE = "easyocr"   # 'easyocr' or 'paddle'
# Minimum per-fragment OCR confidence; fragments below it are dropped.
OCR_MIN_CONF = 0.50
# normalize_text truncates OCR text to this many characters.
OCR_MAX_CHARS = 300

# Robust OCR normalization settings
# True -> normalize_text lowercases OCR text (casing is NOT kept). NOTE(review): the
# training-side smart_text_cleaning also appears to lowercase, so this presumably
# matches training -- confirm before switching to False for a cased encoder.
OCR_LOWER = True
# Tokenizer max_length at inference = max(cfg_loaded.max_len (default 96), this value).
INFER_MAX_LEN = int(os.environ.get("INFER_MAX_LEN", "128"))  # slightly > training max_len
FORCE_TEXT_ONLY_WHEN_OCR = True  # ignore image when OCR text available (to avoid gate overshadow)

# ======================
# Model class (mean pooling + projections + multimodal gate + LoRA)
# ======================
class OptimizedMultiModalGated(nn.Module):
    """Gated text + image classifier with a 3-class (hs3) and a 2-class (ab2) head.

    Encoders:
    - Text uses a Hugging Face ``AutoModel``. It is mean-pooled over the attention mask,
      or the first token is used when ``cfg_loaded.text_pooling != "mean"``.
    - Images use a CLIP vision tower (``pooler_output``) or a torchvision MobileNetV2 /
      ResNet18 fallback.

    Both features are projected to ``proj_dim``, optionally concatenated with
    ``add_meta_dim`` meta features, and passed through a text-only backbone and a fusion
    backbone. A sigmoid gate, forced to 0 for samples without an image, blends the
    fusion-head and text-head logits.

    Hyper-parameters stored on ``cfg_loaded`` take precedence; the matching constructor
    arguments are used when ``cfg_loaded`` is ``None`` or lacks the attribute. The
    module layout (and therefore the state-dict keys) does not depend on which source
    provided a value.
    """

    def __init__(self, text_model_name, freeze_text=True, use_clip=True, clip_model_name=None,
                 fallback_backbone="mobilenet_v2", freeze_vision=True,
                 use_multi_task=True, ab2_mode="ce", hidden=256, dropout=0.25, add_meta_dim=2,
                 gradient_checkpointing=False, cfg_loaded=None):
        super().__init__()

        def opt(name, default):
            # cfg_loaded wins when it defines the attribute; getattr(None, ...) -> default.
            return getattr(cfg_loaded, name, default)

        hidden = opt("hidden", hidden)
        dropout = opt("dropout", dropout)
        meta_dim = opt("add_meta_dim", add_meta_dim)
        proj_dim = opt("proj_dim", 256)
        # Width of the meta block concatenated to the features (0 disables it).
        meta_in = meta_dim if (meta_dim and meta_dim > 0) else 0

        self.use_multi_task = use_multi_task
        self.ab2_mode = ab2_mode
        self.add_meta_dim = meta_dim
        self._meta_in = meta_in
        self.use_clip = use_clip
        self.cfg_loaded = cfg_loaded

        # Text encoder (no pooler; pooling happens in forward)
        tcfg = AutoConfig.from_pretrained(text_model_name)
        if hasattr(tcfg, "add_pooling_layer"):
            tcfg.add_pooling_layer = False
        self.text_model = AutoModel.from_pretrained(text_model_name, config=tcfg)
        if gradient_checkpointing and hasattr(self.text_model, "gradient_checkpointing_enable"):
            try:
                self.text_model.gradient_checkpointing_enable()
            except Exception as e:  # best effort: not every backbone supports it
                print("Gradient checkpointing not enabled:", e)
        tdim = self.text_model.config.hidden_size
        if opt("freeze_text_encoder", freeze_text):
            for p in self.text_model.parameters():
                p.requires_grad = False

        # LoRA adapters on the text encoder
        if opt("use_lora_text", False):
            targets = opt("lora_target_text", None) or ["query", "value"]
            self.text_model = get_peft_model(
                self.text_model,
                LoraConfig(
                    r=opt("lora_r_text", 8),
                    lora_alpha=opt("lora_alpha_text", 16),
                    lora_dropout=opt("lora_dropout_text", 0.1),
                    target_modules=targets,
                    bias="none",
                    task_type=TaskType.FEATURE_EXTRACTION,
                ),
            )

        # Vision encoder: CLIP tower, or a torchvision ImageNet backbone as fallback
        if use_clip:
            self.vision = CLIPVisionModel.from_pretrained(clip_model_name or "openai/clip-vit-base-patch32")
            idim = self.vision.config.hidden_size
            if opt("freeze_vision_encoder", freeze_vision):
                for p in self.vision.parameters():
                    p.requires_grad = False
            if opt("use_lora_vision", False):
                vtargets = opt("lora_target_vision", None) or ["q_proj", "v_proj"]
                self.vision = get_peft_model(
                    self.vision,
                    LoraConfig(
                        r=opt("lora_r_vision", 4),
                        lora_alpha=opt("lora_alpha_vision", 8),
                        lora_dropout=opt("lora_dropout_vision", 0.05),
                        target_modules=vtargets,
                        bias="none",
                        task_type=TaskType.FEATURE_EXTRACTION,
                    ),
                )
            self.use_clip = True
        else:
            from torchvision import models
            if fallback_backbone.lower() == "mobilenet_v2":
                try:
                    from torchvision.models import MobileNet_V2_Weights
                    im = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
                except Exception:
                    # older torchvision without the weights enums
                    im = models.mobilenet_v2(pretrained=True)
                idim = 1280
                im.classifier = nn.Identity()
            else:
                try:
                    from torchvision.models import ResNet18_Weights
                    backbone = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
                except Exception:
                    backbone = models.resnet18(pretrained=True)
                idim = 512
                # drop the final fc layer and flatten the pooled features
                im = nn.Sequential(*list(backbone.children())[:-1], nn.Flatten(1))
            if opt("freeze_vision_encoder", freeze_vision):
                for p in im.parameters():
                    p.requires_grad = False
            self.vision = im
            self.use_clip = False

        self.image_feat_dim = idim

        # Projections to a shared width so both modalities have comparable scale
        self.t_proj = nn.Sequential(
            nn.Linear(tdim, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.i_proj = nn.Sequential(
            nn.Linear(idim, proj_dim),
            nn.LayerNorm(proj_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # Backbones
        fusion_in = 2 * proj_dim + meta_in
        text_in = proj_dim + meta_in

        self.fusion_backbone = nn.Sequential(
            nn.Linear(fusion_in, hidden), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Dropout(dropout),
        )
        self.text_backbone = nn.Sequential(
            nn.Linear(text_in, hidden), nn.ReLU(inplace=True), nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Dropout(dropout),
        )

        # Gate over both modalities (+meta), output in (0, 1)
        self.gate_proj = nn.Sequential(
            nn.Linear(fusion_in, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        # Heads
        self.hs3_fusion_head = nn.Linear(hidden // 2, 3)
        self.hs3_text_head = nn.Linear(hidden // 2, 3)
        self.ab2_fusion_head = nn.Linear(hidden // 2, 2)
        self.ab2_text_head = nn.Linear(hidden // 2, 2)

    def _mean_pool(self, last_hidden_state, attention_mask):
        """Average token states over positions where ``attention_mask`` is non-zero."""
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        summed = (last_hidden_state * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-6)  # avoid division by zero for empty masks
        return summed / denom

    def forward(self, input_ids, attention_mask, images, has_image_mask, meta=None):
        """Compute gated hs3/ab2 logits plus the per-branch logits and the gate weight.

        ``meta`` may be ``None`` even when the model was built with meta features; zeros
        are used in that case so the layer input widths still match.
        """
        # Text
        out_text = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        t_last = out_text.last_hidden_state
        pooling = getattr(self.cfg_loaded, "text_pooling", "mean")
        t_feat = self._mean_pool(t_last, attention_mask) if pooling == "mean" else t_last[:, 0, :]
        t_proj = self.t_proj(t_feat)

        if (meta is not None) and (meta.dtype != t_proj.dtype):
            meta = meta.to(t_proj.dtype)

        # Vision (dtype-safe; only computed for samples that have an image)
        B = images.size(0)
        img_mask_flat = has_image_mask.view(-1)
        if img_mask_flat.any():
            idx = torch.nonzero(img_mask_flat, as_tuple=False).squeeze(1)
            images_sub = images[idx].to(memory_format=torch.channels_last)
            if self.use_clip:
                v_out = self.vision(pixel_values=images_sub)
                i_sub = v_out.pooler_output
            else:
                i_sub = self.vision(images_sub)
            i_sub = i_sub.to(images_sub.dtype)
            i_feat = images.new_zeros((B, self.image_feat_dim), dtype=images_sub.dtype)
            i_feat[idx] = i_sub
        else:
            i_feat = images.new_zeros((B, self.image_feat_dim), dtype=images.dtype)
        i_proj = self.i_proj(i_feat)

        # Inputs
        if self._meta_in > 0:
            if meta is None:
                # Layers were sized for meta features; default the missing flags to zero.
                meta = t_proj.new_zeros((t_proj.size(0), self._meta_in))
            t_in = torch.cat([t_proj, meta], dim=1)
            f_in = torch.cat([t_proj, i_proj, meta], dim=1)
        else:
            t_in = t_proj
            f_in = torch.cat([t_proj, i_proj], dim=1)

        # Backbones
        t_repr = self.text_backbone(t_in)
        f_repr = self.fusion_backbone(f_in)

        # Gate (forced to zero when the sample has no image)
        gate_weight = self.gate_proj(f_in) * img_mask_flat.float().unsqueeze(1)  # [B,1]

        # Heads: gate-weighted blend of fusion and text-only logits
        hs3_f = self.hs3_fusion_head(f_repr)
        hs3_t = self.hs3_text_head(t_repr)
        hs3_logits = gate_weight * hs3_f + (1 - gate_weight) * hs3_t

        ab2_f = self.ab2_fusion_head(f_repr)
        ab2_t = self.ab2_text_head(t_repr)
        ab2_logits = gate_weight * ab2_f + (1 - gate_weight) * ab2_t

        return {
            "hs3_logits": hs3_logits,
            "ab2_logits": ab2_logits,
            "gate_weight": gate_weight.detach(),
            "hs3_f_only": hs3_f,
            "ab2_f_only": ab2_f,
            "hs3_t_only": hs3_t,
            "ab2_t_only": ab2_t,
        }

# ======================
# Loader + predictor helpers
# ======================
def _clean_keys(sd):
    out = {}
    for k, v in sd.items():
        if k.startswith("_orig_mod."): k = k[10:]
        if k.startswith("module."):    k = k[7:]
        out[k] = v
    return out

def load_model_and_processors_lora(ckpt_path, device="cpu"):
    """Rebuild the gated multimodal model from ``ckpt_path`` for inference.

    The checkpoint must be a dict holding the training config under ``"cfg"`` and the
    weights under ``"model"``. A config without ``text_pooling`` gets ``"mean"``.

    Returns:
        ``(model, tok, clip_proc, cfg_loaded)``:
        - ``model``: in eval mode on ``device``.
        - ``tok``: its fast tokenizer.
        - ``clip_proc``: the CLIP image processor, or ``None`` when ``use_clip_vision`` is false.
        - ``cfg_loaded``: the config exposed as attributes.
    """
    assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    assert "cfg" in ckpt, "Checkpoint missing cfg."

    # Attribute-style view of the saved config dict.
    class C: pass
    cfg_loaded = C()
    for name, value in ckpt["cfg"].items():
        setattr(cfg_loaded, name, value)
    if not hasattr(cfg_loaded, "text_pooling"):
        cfg_loaded.text_pooling = "mean"

    def cfg_get(name, default):
        return getattr(cfg_loaded, name, default)

    tok = AutoTokenizer.from_pretrained(cfg_loaded.text_model_name, use_fast=True)
    if cfg_get("use_clip_vision", True):
        clip_proc = CLIPImageProcessor.from_pretrained(cfg_loaded.clip_vision_model)
    else:
        clip_proc = None

    model_kwargs = {
        "text_model_name": cfg_loaded.text_model_name,
        "freeze_text": cfg_get("freeze_text_encoder", True),
        "use_clip": cfg_get("use_clip_vision", True),
        "clip_model_name": cfg_get("clip_vision_model", "openai/clip-vit-base-patch32"),
        "fallback_backbone": cfg_get("image_backbone_fallback", "mobilenet_v2"),
        "freeze_vision": cfg_get("freeze_vision_encoder", True),
        "use_multi_task": cfg_get("use_multi_task", True),
        "ab2_mode": cfg_get("ab2_mode", "ce"),
        "hidden": cfg_get("hidden", 256),
        "dropout": cfg_get("dropout", 0.25),
        "add_meta_dim": cfg_get("add_meta_dim", 2),
        "gradient_checkpointing": False,
        "cfg_loaded": cfg_loaded,
    }
    model = OptimizedMultiModalGated(**model_kwargs).to(device)

    # Strip wrapper prefixes from the saved keys, then require an exact key match.
    model.load_state_dict(_clean_keys(ckpt["model"]), strict=True)
    if device == "cuda":
        model = model.to(memory_format=torch.channels_last)
    model.eval()
    return model, tok, clip_proc, cfg_loaded

def get_image_tensor(image, clip_proc, image_size=224):
    """Convert ``image`` into a normalised ``(3, H, W)`` float tensor for the vision encoder.

    Args:
        image: one of the following. A path that does not exist, or an array that cannot
            be converted, is treated as "no image".
            - a file path (``str`` or ``os.PathLike``);
            - a ``PIL.Image.Image``;
            - anything ``Image.fromarray`` accepts (e.g. a numpy array);
            - ``None``.
        clip_proc: CLIP image processor; when ``None``, a torchvision resize to
            ``image_size`` plus ImageNet mean/std normalisation is used instead.
        image_size: side length of the zero placeholder and of the torchvision resize.

    Returns:
        ``(tensor, pil)``, where ``pil`` is the RGB image that was used. When there is no
        image, ``pil`` is ``None`` and ``tensor`` is all zeros with shape
        ``(3, image_size, image_size)``.
    """
    pil = None
    if image is None:
        pass
    elif isinstance(image, (str, os.PathLike)):
        if os.path.exists(image):
            # Close the file handle; convert() returns an independent, fully loaded copy.
            with Image.open(image) as src:
                pil = src.convert("RGB")
    elif isinstance(image, Image.Image):
        pil = image.convert("RGB")
    else:
        try:
            pil = Image.fromarray(image).convert("RGB")
        except Exception:
            pil = None

    if pil is None:
        img_tensor = torch.zeros(3, image_size, image_size, dtype=torch.float32)
        return img_tensor, None

    if clip_proc is not None:
        img_tensor = clip_proc(images=pil, return_tensors="pt")["pixel_values"][0]
    else:
        from torchvision import transforms
        tfm = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        ])
        img_tensor = tfm(pil)
    return img_tensor, pil

def softmax_temp(logits, T=1.0):
    """Softmax of ``logits / T`` over the last dimension.

    A missing (``None``) or non-positive ``T`` falls back to 1.0, i.e. a plain softmax.
    """
    temperature = 1.0 if (T is None or T <= 0) else float(T)
    return torch.softmax(logits / temperature, dim=-1)

# ======================
# OCR helpers + normalization (EasyOCR + PaddleOCR, safe fallback)
# ======================
# Lazily created OCR reader and the backend it belongs to ("easyocr" or "paddle");
# both are assigned by ensure_ocr().
_OCR_READER = None
_OCR_BACKEND = None
# Zero-width space / non-joiner / joiner (U+200B-U+200D) and U+FEFF (BOM / zero-width
# no-break space); normalize_text() deletes these.
_ZWS_RE = re.compile(r'[\u200B-\u200D\uFEFF]')  # zero-width chars

# --- Aggressive OCR cleanup helpers (drop-in) ---
# Collapse long character repeats: "sooooo" -> "soo"
# Matches any character repeated 3+ times in a row, digits included ("1000" -> "100").
_REPEAT_RE = re.compile(r"(.)\1{2,}")

# Simple leetspeak deobfuscation map (used inside tokens that mix letters+digits)
# Digits without an entry here (2, 6, 9) are left unchanged by _deobfuscate_leet_token.
_LEET_MAP = {
    "0": "o",
    "1": "i",
    "3": "e",
    "4": "a",
    "5": "s",
    "7": "t",
    "8": "ate",  # helps "h8" -> "hate" and "l8r" -> "later"
}

# Common texting slang/abbreviations to standard English (lowercase keys)
# Looked up by _expand_slang on whitespace-separated tokens. normalize_text calls it
# after optional lowercasing and punctuation stripping, so keys are lowercase and
# apostrophe-free. Multi-word values are split back into separate tokens. Lookups are
# exact and unconditional: short keys such as "u", "r", "w" and "k" are expanded
# wherever they appear as standalone tokens.
_SLANG_MAP = {
    # single words
    "u": "you", "ur": "your", "r": "are", "ya": "you", "yall": "you all",
    "im": "i am", "ima": "i am", "ama": "ask me anything",
    "idk": "i do not know", "imo": "in my opinion", "imho": "in my humble opinion",
    "btw": "by the way", "brb": "be right back", "gtg": "got to go", "ttyl": "talk to you later",
    "tbh": "to be honest", "smh": "shaking my head", "rn": "right now",
    "omg": "oh my god", "irl": "in real life",
    "pls": "please", "plz": "please",
    "thx": "thanks", "thanx": "thanks", "tnx": "thanks", "ty": "thank you", "tysm": "thank you so much",
    "bc": "because", "cuz": "because", "coz": "because",
    "wanna": "want to", "gonna": "going to", "gotta": "got to",
    "kinda": "kind of", "sorta": "sort of", "lemme": "let me", "gimme": "give me",
    "aint": "is not", "cant": "cannot", "wont": "will not",
    "shouldnt": "should not", "couldnt": "could not", "wouldnt": "would not",
    "dont": "do not", "doesnt": "does not", "didnt": "did not",
    "isnt": "is not", "arent": "are not", "wasnt": "was not", "werent": "were not",
    "havent": "have not", "hasnt": "has not", "hadnt": "had not", "mustnt": "must not", "neednt": "need not",
    "ok": "okay", "kk": "okay", "k": "okay",
    "lol": "laughing", "lmao": "laughing", "lmfao": "laughing", "rofl": "laughing", "xd": "laughing",
    "nah": "no", "yup": "yes", "yeah": "yes", "nope": "no", "bruh": "bro",
    "w": "with",  # common "w" -> "with"
    "luv": "love",
    # numeric slang
    # (checked before leet deobfuscation, so these win over _LEET_MAP)
    "b4": "before", "l8r": "later", "gr8": "great",
    "2day": "today", "2moro": "tomorrow", "2mrw": "tomorrow", "tmrw": "tomorrow", "tmr": "tomorrow",
    "4u": "for you", "4ya": "for you", "4you": "for you",
    # common obfuscations of profanity (normalize to base form)
    "fuk": "fuck", "fck": "fuck", "fcuk": "fuck", "phuck": "fuck", "fucc": "fuck",
    "wtf": "what the fuck", "wth": "what the hell",
}

def _strip_punct_and_symbols(s: str) -> str:
    # Remove all Unicode punctuation (P*) and symbol (S*) categories -> replace with space
    return "".join(ch if not unicodedata.category(ch).startswith(("P", "S")) else " " for ch in s)

def _reduce_repeats(s: str) -> str:
    """Cap every run of 3+ identical characters at two ("loooove" -> "loove").

    Uses ``_REPEAT_RE``, which matches any character, so digit runs are shortened too
    ("1000" -> "100").
    """
    return _REPEAT_RE.sub(lambda run: run.group(1) * 2, s)

def _deobfuscate_leet_token(tok: str) -> str:
    """Replace leetspeak digits via ``_LEET_MAP`` in a token that has digits but is not numeric.

    Example: "h8" -> "hate". Tokens with no digit, or made only of digits, are returned
    unchanged; digits without a mapping are kept.
    """
    has_digit = any(c.isdigit() for c in tok)
    if not has_digit or tok.isdigit():
        return tok
    return "".join(_LEET_MAP.get(c, c) if c.isdigit() else c for c in tok)

def _expand_slang(s: str) -> str:
    """Expand texting slang token by token using ``_SLANG_MAP``.

    A token is looked up as-is first. Only if it is not found is leetspeak deobfuscation
    applied, and the result is looked up again (e.g. "u" -> "you", "h8" -> "hate").
    Tokens are re-joined with single spaces.
    """
    words = []
    for tok in s.split():
        if tok not in _SLANG_MAP:
            tok = _deobfuscate_leet_token(tok)
        words.extend(_SLANG_MAP[tok].split() if tok in _SLANG_MAP else [tok])
    return " ".join(words)

def _install_paddle_cpu():
    """Ensure ``paddle`` is importable, pip-installing ``paddlepaddle==2.6.1`` (CPU) if needed.

    Returns True when paddle can be imported, and False when installation or the
    follow-up import failed (the error is printed). Only the paddle runtime is handled
    here; ``paddleocr`` itself is not installed by this helper.
    """
    try:
        import paddle  # noqa
    except Exception:
        pass
    else:
        return True

    try:
        print("Installing paddlepaddle (CPU) ...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "paddlepaddle==2.6.1"])
        import paddle  # noqa
    except Exception as e:
        print("PaddlePaddle install failed:", e)
        return False
    return True

# Engine most recently *requested* from ensure_ocr(). After a Paddle -> EasyOCR fallback
# it differs from _OCR_BACKEND, which records the engine actually in use.
_OCR_REQUESTED = None

def ensure_ocr(engine="easyocr"):
    """
    Return a cached ``(reader, backend)`` pair for the requested OCR ``engine``.

    ``engine`` is ``"easyocr"`` or ``"paddle"``. Paddle is forced to CPU to avoid CUDA
    mismatches, and falls back to EasyOCR when paddlepaddle cannot be installed or
    PaddleOCR fails to initialise. ``backend`` names the engine actually used.

    The reader is reused when ``engine`` matches either the active backend or the
    engine that produced it. A failed Paddle request is therefore not retried (pip
    install + reader construction) on every call.
    """
    global _OCR_READER, _OCR_BACKEND, _OCR_REQUESTED
    if _OCR_READER is not None and engine in (_OCR_BACKEND, _OCR_REQUESTED):
        return _OCR_READER, _OCR_BACKEND

    requested = engine
    if engine == "paddle":
        ok = _install_paddle_cpu()
        if not ok:
            print("Falling back to EasyOCR due to Paddle install failure.")
            engine = "easyocr"

    if engine == "paddle":
        try:
            import paddle
            paddle.device.set_device("cpu")
            from paddleocr import PaddleOCR
            _OCR_READER = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
            _OCR_BACKEND = "paddle"
            _OCR_REQUESTED = requested
            return _OCR_READER, _OCR_BACKEND
        except Exception as e:
            print("PaddleOCR init failed, falling back to EasyOCR:", e)
            engine = "easyocr"

    try:
        import easyocr
    except ImportError:
        print("Installing easyocr...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "easyocr"])
        import easyocr
    _OCR_READER = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
    _OCR_BACKEND = "easyocr"
    _OCR_REQUESTED = requested
    return _OCR_READER, _OCR_BACKEND

def normalize_text(s: str, lower: bool = OCR_LOWER, max_chars: int = OCR_MAX_CHARS) -> str:
    """Aggressively normalise OCR text before tokenisation.

    Pipeline, in order:
    1. NFKC normalisation, NBSP -> space, zero-width characters removed.
    2. Typographic quotes, dashes and ellipsis folded to ASCII.
    3. Whitespace squeezed.
    4. English contractions expanded (best effort).
    5. Optional lowercasing.
    6. Punctuation and symbols replaced by spaces.
    7. Runs of 3+ repeated characters capped at two.
    8. Slang and leetspeak expanded.
    9. Whitespace squeezed again.
    10. Truncated to ``max_chars`` characters.

    Falsy input returns ``""``.
    """
    if not s:
        return ""

    # Unicode cleanup and ASCII folding of typographic punctuation
    text = _ZWS_RE.sub("", unicodedata.normalize("NFKC", s).replace("\u00A0", " "))
    for fancy, plain in (("“", '"'), ("”", '"'), ("‘", "'"), ("’", "'"),
                         ("—", "-"), ("–", "-"), ("‐", "-"), ("…", "...")):
        text = text.replace(fancy, plain)

    # Squeeze whitespace first: helps the contractions library
    text = " ".join(text.split())

    # Contractions must be expanded while apostrophes are still present
    try:
        text = contractions.fix(text)
    except Exception:
        pass  # best effort: keep the text unchanged if expansion fails

    if lower:
        text = text.lower()

    text = _expand_slang(_reduce_repeats(_strip_punct_and_symbols(text)))
    text = " ".join(text.split())

    if len(text) > max_chars:
        text = text[:max_chars]
    return text

def _parse_paddle_result(res, min_conf=0.0):
    """
    Robustly parse PaddleOCR results into (text, conf) list.
    Supports both:
      [[ [points], (text, conf) ], ...] or [ [ [ [points], (text, conf) ], ... ] ]
    """
    items = []
    if not isinstance(res, list) or len(res) == 0:
        return items
    # If single-image list wrapper, unwrap
    if len(res) == 1 and isinstance(res[0], list) and len(res[0]) > 0 and isinstance(res[0][0], list):
        res = res[0]
    for entry in res:
        try:
            if not isinstance(entry, (list, tuple)) or len(entry) < 2:
                continue
            info = entry[1]
            if isinstance(info, (list, tuple)) and len(info) >= 2:
                txt, conf = info[0], float(info[1])
                if txt and str(txt).strip() and conf >= min_conf:
                    items.append((str(txt).strip(), conf))
        except Exception:
            continue
    return items

def ocr_extract_text(pil_image, engine=OCR_ENGINE, min_conf=OCR_MIN_CONF) -> str:
    """Run OCR on ``pil_image`` and join the confident text fragments with single spaces.

    The reader comes from ``ensure_ocr(engine)`` (which has already imported the backend
    module). Blank fragments and fragments with confidence below ``min_conf`` are dropped.
    Returns ``""`` for a ``None`` image, or when OCR raises (the error is printed).
    """
    if pil_image is None:
        return ""
    reader, backend = ensure_ocr(engine)
    try:
        pixels = np.array(pil_image)
        if backend == "paddle":
            pairs = _parse_paddle_result(reader.ocr(pixels, cls=True), min_conf=min_conf)
            fragments = [txt for txt, _conf in pairs]
        else:
            fragments = [
                txt.strip()
                for _box, txt, conf in reader.readtext(pixels)
                if float(conf) >= min_conf and txt and txt.strip()
            ]
    except Exception as e:
        print("OCR failed:", e)
        return ""
    return " ".join(fragments)

# ======================
# Load model once
# ======================
# Load the model only if any of these names is not yet bound in this session.
# NOTE(review): in the notebook, the earlier evaluation cell already binds model / tok /
# clip_proc / cfg_loaded (from CKPT_PATH). In that case this cell reuses that model --
# presumably built with the earlier cell's class definition -- instead of loading
# `ckpt_path` here. Delete those names (or restart the kernel) to force a reload.
if any(n not in globals() for n in ["model", "tok", "clip_proc", "cfg_loaded"]):
    model, tok, clip_proc, cfg_loaded = load_model_and_processors_lora(ckpt_path, device=device)
    print("Loaded model")

# ======================
# Inference
# ======================
@torch.inference_mode()
def user_predict(
    text: str,
    image=None,
    sarcasm=False,
    humour=False,
    temperature: float = None,
    use_ocr: bool = True,
    ocr_engine: str = OCR_ENGINE,
    ocr_min_conf: float = OCR_MIN_CONF,
    ignore_image_when_ocr: bool = FORCE_TEXT_ONLY_WHEN_OCR
):
    """Classify one (text, optional image) input with the loaded multimodal model.

    If OCR is enabled, an image is given, and OCR yields non-empty normalised
    text, that text REPLACES the user text as model input. Otherwise the
    stripped user text is used. With ``ignore_image_when_ocr``, the image
    branch is switched off (``has_img=False``) whenever OCR text is used.

    Args:
        text: User-supplied text; ``None`` is treated as "".
        image: Anything ``get_image_tensor`` accepts, or ``None``.
        sarcasm, humour: Flags sent to the model as a 1x2 meta tensor.
        temperature: Softmax temperature; ``None`` uses ``TEMP``.
        use_ocr, ocr_engine, ocr_min_conf: OCR switch and settings.
        ignore_image_when_ocr: Drop the image branch when OCR text is used.

    Returns:
        dict with these keys:
            ``final_label``, ``final_probs`` (per class), ``abusive_probability``
            (P(Abusive) + P(Offensive)), ``gate_weight_image`` (mean gate),
            ``used_image``, ``temperature``, ``ocr_text_raw``,
            ``ocr_text_norm``, ``final_text``, ``ocr_engine``,
            ``ocr_min_conf`` and ``ignored_image_due_to_ocr``.
    """
    temperature = float(TEMP if temperature is None else temperature)
    # Use fp16 only when the model's target device is CUDA. Gating this on
    # torch.cuda.is_available() alone would feed fp16 tensors to a CPU-resident
    # model whenever a GPU exists but `device` is the CPU.
    use_half = torch.device(device).type == "cuda"

    # Prepare image tensor and OCR text
    img_tensor, pil = get_image_tensor(image, clip_proc, image_size=getattr(cfg_loaded, "image_size", 224))
    if use_half:
        img_tensor = img_tensor.half()
    imgs = img_tensor.unsqueeze(0).to(device)
    # An all-zero image tensor is treated as "no image".
    has_img = torch.tensor([img_tensor.abs().sum().item() > 0], dtype=torch.bool, device=device)

    raw_ocr_text = ""
    if use_ocr and pil is not None:
        raw_ocr_text = ocr_extract_text(pil, engine=ocr_engine, min_conf=ocr_min_conf)
    ocr_text = normalize_text(raw_ocr_text, lower=OCR_LOWER, max_chars=OCR_MAX_CHARS)
    ocr_used = bool(ocr_text.strip())

    # If OCR produced text, feed it directly; else use user text
    user_text = (text or "").strip()
    final_text = ocr_text if ocr_used else user_text

    # Optionally ignore image when OCR text is used (avoid the image branch dominating the gate)
    ignored_image = bool(ignore_image_when_ocr and ocr_used)
    if ignored_image:
        has_img = torch.tensor([False], dtype=torch.bool, device=device)
        # imgs is left as-is; presumably the model skips it when has_img=False.

    # Tokenize (slightly larger max_len at inference)
    used_max_len = max(getattr(cfg_loaded, "max_len", 96), INFER_MAX_LEN)
    enc = tok(final_text, truncation=True, max_length=used_max_len, padding=True, return_tensors="pt")
    ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)

    meta_dtype = torch.float16 if use_half else torch.float32
    meta = torch.tensor([[int(bool(sarcasm)), int(bool(humour))]], dtype=meta_dtype, device=device)

    with torch.cuda.amp.autocast(enabled=use_half, dtype=torch.float16):
        out = model(ids, attn, imgs, has_img, meta)
        probs_t = softmax_temp(out["hs3_logits"], T=temperature).detach().cpu()[0]
        gate = float(out["gate_weight"].mean().detach().cpu().item() if out["gate_weight"].numel() else 0.0)

    probs = probs_t.numpy()
    pred = CLASS_NAMES[int(np.argmax(probs))]
    final_probs = {k: float(v) for k, v in zip(CLASS_NAMES, probs)}

    # Binary abusive score, looked up by class name so it does not depend on
    # the order of CLASS_NAMES.
    p_abusive = final_probs["Abusive"] + final_probs["Offensive"]

    return {
        "final_label": pred,
        "final_probs": final_probs,
        "abusive_probability": p_abusive,
        "gate_weight_image": gate,  # ~0 = text-dominant, ~1 = image-dominant
        "used_image": bool(has_img.item()),
        "temperature": temperature,
        "ocr_text_raw": raw_ocr_text,
        "ocr_text_norm": ocr_text,
        "final_text": final_text,
        "ocr_engine": _OCR_BACKEND or ocr_engine,
        "ocr_min_conf": ocr_min_conf,
        "ignored_image_due_to_ocr": ignored_image,
    }

# Quick CLI-like test
# Smoke test: one text-only prediction. Any failure is printed, not raised,
# so the rest of the notebook keeps running.
try:
    print(user_predict("I hate you", image=None, sarcasm=False, humour=False))
except Exception as err:
    print("Quick test failed:", err)

# ======================
# Batch CSV inference (optional)
# CSV columns supported: text, image_path, sarcasm, humour
# ======================
def predict_csv(csv_path, out_path=None, use_ocr=True, ocr_engine=OCR_ENGINE, ocr_min_conf=OCR_MIN_CONF, ignore_image_when_ocr=FORCE_TEXT_ONLY_WHEN_OCR):
    """Run ``user_predict`` on every row of a CSV and collect the results.

    Supported columns (all optional): ``text``, ``image_path``, ``sarcasm``
    and ``humour``. Missing or NaN text becomes "". An image path is used only
    if it is non-blank and exists on disk. The flag columns are parsed as
    booleans: missing/NaN and false-like strings ("0", "false", "no", ...)
    are False.

    Args:
        csv_path: Input CSV path.
        out_path: If given, the predictions are also written there as CSV.
        use_ocr, ocr_engine, ocr_min_conf, ignore_image_when_ocr:
            Forwarded to ``user_predict``.

    Returns:
        DataFrame with one row of predictions per input row.
    """
    import pandas as pd

    false_strings = {"", "0", "0.0", "false", "f", "no", "n", "none", "nan"}

    def _is_missing(value):
        # None, NaN, pd.NA and NaT all count as "no value".
        return value is None or (
            pd.api.types.is_scalar(value) and not isinstance(value, str) and pd.isna(value)
        )

    def _as_text(value):
        # Avoid str(NaN) == "nan" leaking into the model input.
        return "" if _is_missing(value) else str(value)

    def _as_flag(value):
        # bool(NaN) and bool("False") are both True, so parse explicitly.
        if _is_missing(value):
            return False
        if isinstance(value, str):
            return value.strip().lower() not in false_strings
        return bool(value)

    df = pd.read_csv(csv_path)
    outs = []
    for _, row in df.iterrows():
        text = _as_text(row.get("text", ""))
        img_raw = _as_text(row.get("image_path", None))
        img_p = img_raw if (img_raw.strip() != "" and os.path.exists(img_raw)) else None
        sarcasm = _as_flag(row.get("sarcasm", False))
        humour = _as_flag(row.get("humour", False))
        res = user_predict(text, image=img_p, sarcasm=sarcasm, humour=humour,
                           use_ocr=use_ocr, ocr_engine=ocr_engine, ocr_min_conf=ocr_min_conf,
                           ignore_image_when_ocr=ignore_image_when_ocr)
        outs.append({
            "text": text,
            "image_path": img_p,
            "sarcasm": sarcasm,
            "humour": humour,
            "pred_label": res["final_label"],
            "p_abusive": res["abusive_probability"],
            "p_Abusive": res["final_probs"]["Abusive"],
            "p_Offensive": res["final_probs"]["Offensive"],
            "p_Non_abusive": res["final_probs"]["Non-abusive"],
            "gate_weight_image": res["gate_weight_image"],
            "used_image": res["used_image"],
            "ocr_text_raw": res["ocr_text_raw"],
            "ocr_text_norm": res["ocr_text_norm"],
            "final_text": res["final_text"],
            "ignored_image_due_to_ocr": res["ignored_image_due_to_ocr"],
        })
    out_df = pd.DataFrame(outs)
    if out_path:
        out_df.to_csv(out_path, index=False)
        print(f"Saved predictions to {out_path}")
    return out_df

# ======================
# Gradio UI (optional)
# ======================
def gr_fn(text, image, sarcasm, humour, temperature, use_ocr, ocr_min_conf, ocr_engine, ignore_image_when_ocr):
    """Gradio callback: run ``user_predict`` and unpack its result for the UI.

    Returns a 6-tuple matching the Interface outputs:
        1. probability dict (three classes plus the binary abusive score);
        2. final label;
        3. raw OCR text;
        4. normalised OCR text;
        5. the text actually sent to the model;
        6. a one-line details string.
    """
    res = user_predict(
        text, image=image, sarcasm=sarcasm, humour=humour, temperature=temperature,
        use_ocr=use_ocr, ocr_engine=ocr_engine, ocr_min_conf=ocr_min_conf,
        ignore_image_when_ocr=ignore_image_when_ocr
    )
    probs_display = {
        "Abusive": res["final_probs"]["Abusive"],
        "Offensive": res["final_probs"]["Offensive"],
        "Non-abusive": res["final_probs"]["Non-abusive"],
        "Abusive (binary)": res["abusive_probability"],
    }
    extra = (
        f"Gate(image)={res['gate_weight_image']:.2f} | Used image={res['used_image']} "
        f"| Temp={res['temperature']} | OCR={res['ocr_engine']}(min_conf={res['ocr_min_conf']}) "
        f"| Ignored image due to OCR={res['ignored_image_due_to_ocr']}"
    )
    return probs_display, res["final_label"], res["ocr_text_raw"], res["ocr_text_norm"], res["final_text"], extra


try:
    import gradio as gr
except Exception as e:  # broad: any import-time failure simply disables the UI
    print("Gradio not available:", e)
else:
    # Import succeeded; failures past this point are build/launch problems,
    # reported separately so they are not mistaken for a missing package.
    try:
        demo = gr.Interface(
            fn=gr_fn,
            inputs=[
                gr.Textbox(label="Text", lines=3, placeholder="Type a comment..."),
                gr.Image(label="Optional image", type="numpy"),
                gr.Checkbox(label="Sarcasm flag", value=False),
                gr.Checkbox(label="Humour flag", value=False),
                gr.Slider(0.5, 2.0, value=TEMP, step=0.1, label="Temperature"),
                gr.Checkbox(label="Use OCR", value=True),
                gr.Slider(0.0, 1.0, value=OCR_MIN_CONF, step=0.05, label="OCR min confidence"),
                gr.Radio(choices=["easyocr","paddle"], value="easyocr", label="OCR engine"),
                gr.Checkbox(label="Ignore image when OCR text is used", value=True),
            ],
            outputs=[
                gr.Label(num_top_classes=4, label="Probabilities"),
                gr.Textbox(label="Final label"),
                gr.Textbox(label="OCR text (raw)"),
                gr.Textbox(label="OCR text (normalized)"),
                gr.Textbox(label="Final text sent to model"),
                gr.Markdown(label="Details"),
            ],
            title="Abusive/Offensive Detection (LoRA, direct 3-class) + OCR (Paddle-safe + Easy fallback)"
        )
        # Launches right away with a public share link; comment this line out
        # to only build the interface without serving it.
        demo.launch(share=True)
        print("Gradio app launched with share=True.")
    except Exception as e:
        print("Gradio app failed to build or launch:", e)



"""
INTERPRETATION OF DEPLOYMENT PIPELINE — ROBUST MULTIMODAL HATE DETECTION WITH OCR

The user inference pipeline wraps the trained multimodal model for interactive
and batch use. Before tokenization, it uses OCR to extract text embedded in
images and normalizes that text.

Key deployment interpretations:

- **OCR-Driven Modality Selection**: When OCR extracts any text from the image,
  that text replaces the user-provided text entirely; the two are not combined.
  The assumption is that meme text usually carries the offensive content. The
  system can optionally ignore visual features when OCR text is present, so
  that the image branch does not dominate the modality gate.

- **Aggressive Text Normalization Pipeline**: Implements multi-stage text cleaning
  that handles common obfuscation techniques used in hateful content:
  - Unicode normalization and zero-width character removal
  - Contraction expansion and punctuation standardization  
  - Character repeat reduction ("sooooo" → "soo")
  - Leetspeak deobfuscation ("h8" → "hate", "l8r" → "later")
  - Texting slang expansion ("u" → "you", "ur" → "your")
  This normalization is crucial for detecting coded hate speech and evasive language.

- **Dual OCR Engine Support**: Provides fallback between PaddleOCR (higher accuracy)
  and EasyOCR (easier installation), with automatic CPU fallback to handle
  deployment environment constraints and CUDA compatibility issues.

- **Modality Gate Transparency**: Exposes the gate weight that shows how much
  the model relies on visual vs. textual features, providing interpretability
  for why particular predictions were made.

- **Temperature-Controlled Confidence**: Allows adjustment of prediction confidence
  through temperature scaling, enabling calibration for different risk tolerance
  levels in production deployment.

- **Batch Processing Capability**: Supports CSV-based batch inference for
  large-scale content moderation workflows, with comprehensive output including
  OCR results and modality usage.

Overall interpretation:
This deployment system addresses the practical challenges of real-world hate
detection: evasive text patterns, multimodal content, and deployment constraints.
The OCR step captures text rendered inside images, subject to OCR accuracy and
the confidence threshold, and normalizes it. The inference options support both
interactive moderation interfaces and automated batch processing. Exposing
modality usage and allowing temperature-based confidence adjustment make the
system easier to audit. It still needs calibration and validation before use
in high-stakes content moderation.
"""

In [None]:
# IPython shell magic: quietly installs the optional packages used further below
# (VADER sentiment, word clouds, XGBoost). The EDA cell checks each import and
# skips the related step when a package is unavailable.
!pip install vaderSentiment wordcloud xgboost -q

In [None]:
import pandas as pd
import os

# Define the folder structure
base_folder = "datasets/processed_data/splits"
subfolders = ["train", "test", "val"]
file_names = {
    "train": "text_train.csv",
    "test": "text_test.csv",
    "val": "text_val.csv"
}

# List to store all dataframes
dataframes = []

for folder in subfolders:
    file_path = os.path.join(base_folder, folder, file_names[folder])

    # Missing split files are reported and skipped
    if os.path.exists(file_path):
        # Read the CSV file
        df = pd.read_csv(file_path)

        # Prune columns - keep only 'text' and 'label'. Downstream EDA reads
        # exactly these two columns, so a missing one is worth flagging here.
        columns_to_keep = [col for col in ['text', 'label'] if col in df.columns]
        missing_cols = [col for col in ['text', 'label'] if col not in df.columns]
        if missing_cols:
            print(f"Warning: {file_path} is missing column(s) {missing_cols}")

        # .copy() makes the pruned frame independent of `df`, so adding the
        # 'source' column does not write into a slice (SettingWithCopyWarning).
        df_pruned = df[columns_to_keep].copy()

        # Add source column to track which split it came from
        df_pruned['source'] = folder

        dataframes.append(df_pruned)
        print(f"Processed {file_path}: {len(df_pruned)} rows")
    else:
        print(f"Warning: File not found - {file_path}")

# Combine all dataframes
if dataframes:
    combined_df = pd.concat(dataframes, ignore_index=True)

    # Display info about the combined dataset
    print(f"\nCombined dataset shape: {combined_df.shape}")
    print("\nRows from each source:")
    print(combined_df['source'].value_counts())

    # Save the combined dataframe to CSV
    combined_df.to_csv('combined_text_data.csv', index=False)
    print("\nCombined CSV saved as 'combined_text_data.csv'")

    # Display first few rows
    print("\nFirst few rows of combined data:")
    print(combined_df.head())

else:
    print("No data was processed. Please check your file paths.")





"""
INTERPRETATION OF DATA CONSOLIDATION PIPELINE — DATASET AGGREGATION AND STRUCTURAL ANALYSIS

The data consolidation script performs systematic aggregation of training splits
while maintaining provenance tracking and implementing selective column retention.

Key structural interpretations:

- **Multi-Split Integration**: Combines train, validation, and test partitions
  into a unified dataset, enabling analysis of the data distribution across the
  entire experimental pipeline. Combining alone does not detect leakage. It makes
  it possible to check for texts duplicated across splits, or for split
  imbalances, that could affect model generalization.

- **Selective Feature Retention**: Prunes all columns except 'text' and 'label',
  focusing the dataset on core predictive features while eliminating metadata
  and auxiliary variables that might complicate analysis or introduce bias.

- **Provenance Tracking**: Adds a 'source' column to preserve split membership,
  allowing researchers to trace predictions back to original data partitions
  and analyze performance variations across training/validation/test sets.

- **Data Integrity Verification**: Checks that each split file exists and logs
  the row count of every processed file, so missing files are reported (and
  skipped) rather than silently ignored. Corrupted or malformed CSVs are not
  detected here; they surface as errors from `pd.read_csv`.

- **Unified Analysis Foundation**: Creates a consolidated dataset suitable for:
  - Exploratory data analysis across all splits
  - Label distribution comparison between partitions
  - Text quality assessment and preprocessing validation
  - Cross-split model error analysis

Overall interpretation:
This consolidation represents a crucial preprocessing step that transforms
disjointed experimental splits into an analytically coherent dataset. The
deliberate column pruning and source tracking balance analytical utility
with data integrity, creating a foundation for robust model evaluation
and dataset quality assessment across the entire machine learning pipeline.
The approach demonstrates thoughtful data management practices that support
reproducible research and comprehensive model diagnostics.
"""


In [None]:
import os, re, gc, string, warnings, json, time
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
matplotlib.use('Agg')  # lighter backend
import matplotlib.pyplot as plt

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer, ENGLISH_STOP_WORDS
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import (
    roc_auc_score, average_precision_score, f1_score, balanced_accuracy_score,
    roc_curve, precision_recall_curve, confusion_matrix
)
from sklearn.calibration import calibration_curve # Moved from sklearn.metrics
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from scipy.stats import mannwhitneyu



# Import optional
try:
    from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
    VADER_AVAILABLE = True
except Exception:
    VADER_AVAILABLE = False

try:
    from wordcloud import WordCloud
    WORDCLOUD_AVAILABLE = True
except Exception:
    WORDCLOUD_AVAILABLE = False

try:
    from xgboost import XGBClassifier
    XGB_AVAILABLE = True
except Exception:
    XGB_AVAILABLE = False

# Silence all Python warnings for cleaner notebook output, and disable pandas'
# chained-assignment (SettingWithCopy) warning.
# NOTE(review): this also hides sklearn/xgboost convergence and deprecation warnings.
warnings.filterwarnings("ignore")
pd.set_option('mode.chained_assignment', None)

# ----------------------------
# CONFIG — T4 RAM OPTIMIZED
# ----------------------------
# Two-colour palette used by every plot: index 0 = Non-Abusive, index 1 = Abusive.
BINARY_PALETTE = ["#1a9850", "#d73027"]  # [Non-Abusive, Abusive]
sns.set_theme(style="whitegrid", palette=BINARY_PALETTE)
plt.rcParams.update({
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "font.size": 10,
    "axes.prop_cycle": plt.cycler(color=BINARY_PALETTE),
})

# Run configuration:
#   CSV_PATH / TEXT_COL / LABEL_COL - input file and the columns read from it
#   OUTPUT_DIR         - directory every figure is saved to
#   SVD_COMPONENTS     - TruncatedSVD dimensions used for the modeling features
#   MIN_DF / MAX_DF    - TF-IDF document-frequency cut-offs
#   SENTIMENT_SAMPLE   - rows scored by VADER in total (half per label)
#   WORDCLOUD_SAMPLE   - rows sampled for the word cloud
#   MASK_TERMS / MASK_KEEP - whether plotted terms are masked, and how many
#                        characters stay visible at each end
#   ALPHA0             - total prior mass spread over the vocabulary in the log-odds z-scores
#   RANDOM_STATE       - seed for sampling, the train/test split and the models
CFG = {
    "CSV_PATH": "combined_text_data.csv",
    "TEXT_COL": "text",
    "LABEL_COL": "label",
    "OUTPUT_DIR": "figures_full",
    "SVD_COMPONENTS": 200,
    "MIN_DF": 50,
    "MAX_DF": 0.85,
    "SENTIMENT_SAMPLE": 80_000,
    "WORDCLOUD_SAMPLE": 150_000,
    "MASK_TERMS": True,
    "MASK_KEEP": 1,
    "ALPHA0": 10.0,
    "RANDOM_STATE": 42
}

os.makedirs(CFG["OUTPUT_DIR"], exist_ok=True)
np.random.seed(CFG["RANDOM_STATE"])

# ----------------------------
# UTILS
# ----------------------------
def normalize_label_series(s):
    """Map a raw label column to int8 binary labels (1 = abusive, 0 = non-abusive).

    Resolution order:
      1. Numeric columns whose non-null values are all 0/1 are used as-is;
         missing values become 0.
      2. Otherwise values are stripped, lower-cased, and looked up in a fixed
         vocabulary of abusive / non-abusive names.
      3. If some non-null values are not in that vocabulary and the column has
         exactly two distinct values, the unknown value is inferred:
         - it takes the opposite of the known value; or
         - when neither value is known, the alphabetically first becomes 0.

    Anything still unresolved (missing labels, or unknown values when there
    are more than two classes) becomes 0.
    """
    mapping = {
        "abusive":1,"toxic":1,"hate":1,"hateful":1,"offensive":1,"insult":1,
        "non-abusive":0,"non_abusive":0,"non abusive":0,"neutral":0,"benign":0,"clean":0
    }
    if pd.api.types.is_numeric_dtype(s) and set(pd.unique(pd.to_numeric(s, errors="coerce").dropna())).issubset({0,1}):
        # fillna first: astype("int8") cannot represent NaN and would raise.
        return s.fillna(0).astype("int8")

    vals = s.astype("string").str.strip().str.lower()
    mapped = vals.map(mapping)

    # Infer only when a present value is unrecognised. Missing labels alone must
    # not trigger this: re-mapping two already-known classes alphabetically
    # would invert them ("abusive" -> 0, "non-abusive" -> 1).
    unknown = mapped.isna() & vals.notna()
    if unknown.any():
        cats = sorted(vals.dropna().unique())
        if len(cats) == 2:
            known = {c: mapping[c] for c in cats if c in mapping}
            if not known:
                mapped = vals.map({cats[0]: 0, cats[1]: 1})
            elif len(known) == 1:
                (known_cat, known_val), = known.items()
                mapped = vals.map({c: (known_val if c == known_cat else 1 - known_val) for c in cats})
    return mapped.fillna(0).astype("int8")

def mask_token(tok, keep=1):
    """Replace the interior of *tok* with '*', keeping *keep* characters at each end.

    Tokens too short to show both ends (``len(tok) <= 2*keep``) are masked
    entirely. ``keep`` is clamped to >= 0. ``keep=0`` masks every character;
    the tail is sliced as ``tok[len - keep:]`` because ``tok[-0:]`` would return
    the whole token.
    """
    keep = max(0, int(keep))
    n = len(tok)
    if n <= 2 * keep:
        return "*" * n
    return tok[:keep] + "*" * (n - 2 * keep) + tok[n - keep:]

def maybe_mask_terms(terms, mask=True, keep=1):
    """Mask each word of each term (for display), or return *terms* unchanged.

    Within a term that contains a space, every whitespace-separated piece is
    handled separately. A piece is passed through ``mask_token`` only if it is
    longer than 3 characters and contains a letter. Masked terms are returned
    as a new list.
    """
    if not mask:
        return terms

    def _display(piece):
        needs_mask = len(piece) > 3 and any(ch.isalpha() for ch in piece)
        return mask_token(piece, keep) if needs_mask else piece

    return [
        " ".join(_display(piece) for piece in (term.split() if " " in term else [term]))
        for term in terms
    ]

# Matches runs of characters from several emoji / pictograph / symbol Unicode
# ranges. Because of the trailing "+", consecutive emoji form ONE match, so
# findall() counts emoji runs rather than individual emoji.
EMOJI_RE = re.compile("[" "\U0001F1E6-\U0001F1FF" "\U0001F300-\U0001F5FF" "\U0001F600-\U0001F64F"
                      "\U0001F680-\U0001F6FF" "\U0001F700-\U0001F77F" "\U0001F780-\U0001F7FF"
                      "\U0001F800-\U0001F8FF" "\U0001F900-\U0001F9FF" "\U0001FA00-\U0001FAFF"
                      "\u2600-\u26FF" "\u2700-\u27BF" "]+")

# sklearn's English stop words minus the negations, which are kept as tokens
# (presumably because negation changes meaning). Built from a set, so element
# order is arbitrary.
STOPWORDS = list(set(ENGLISH_STOP_WORDS) - {"no", "nor", "not", "never"})
# Tokens: #hashtags, @mentions, and ASCII-letter words with an optional
# apostrophe suffix ("don't"). Digits never form tokens. The last alternative
# is already covered by the one before it.
TOK_RE = re.compile(r"#\w+|@\w+|[A-Za-z]+(?:'[A-Za-z]+)?|[A-Za-z]+")

def custom_tokenizer(text):
    """Split *text* into hashtag, @mention and word tokens using ``TOK_RE``.

    Used as the ``tokenizer`` of the sklearn vectorizers in this cell.
    """
    tokens = TOK_RE.findall(text)
    return tokens

def clean_for_vectorizer(txt):
    """Normalise raw text for the vectorizers.

    Steps, in order:
        1. URLs become " URL ".
        2. @mentions become " @user ".
        3. Whitespace runs collapse to one space.
        4. The result is stripped and lower-cased.

    Non-string input is converted with ``str()`` first.
    """
    cleaned = str(txt)
    substitutions = (
        (r"http\S+", " URL "),
        (r"@\w+", " @user "),
        (r"\s+", " "),
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip().lower()

def stratified_sample(df, label_col="label", per_class=50_000, seed=42):
    """Draw up to *per_class* rows per label value, then shuffle the result.

    Classes smaller than *per_class* are taken whole. The result has a fresh
    0..n-1 index. If *df* yields no groups (e.g. it is empty), *df* itself is
    returned unchanged.
    """
    samples = [
        group.sample(min(per_class, len(group)), random_state=seed)
        for _, group in df.groupby(label_col)
    ]
    if not samples:
        return df
    shuffled = pd.concat(samples).sample(frac=1, random_state=seed)
    return shuffled.reset_index(drop=True)

def timer(msg):
    """Print a section banner for *msg* and run a garbage-collection pass.

    Despite the name, nothing is timed: this only marks progress and frees
    memory between heavy stages.
    """
    banner = f"\n--- {msg} ---"
    print(banner)
    gc.collect()

# ----------------------------
# LOAD DATA — ALL ROWS
# ----------------------------
# Read only the text and label columns. Text is coerced to pandas "string",
# missing text becomes "", and surrounding whitespace is stripped.
print("Loading FULL dataset...")
df = pd.read_csv(CFG["CSV_PATH"], usecols=[CFG["TEXT_COL"], CFG["LABEL_COL"]])
df[CFG["TEXT_COL"]] = df[CFG["TEXT_COL"]].astype("string").fillna("").str.strip()
# Normalised 0/1 labels are always written to a column literally named "label",
# which the rest of this cell uses regardless of CFG["LABEL_COL"].
df["label"] = normalize_label_series(df[CFG["LABEL_COL"]])
print(f"Loaded {len(df):,} rows")

# ----------------------------
# BASIC METRICS
# ----------------------------
# Per-text length features. word_count counts \w+ runs, so digits and
# underscores also count as words.
df["char_count"] = df[CFG["TEXT_COL"]].str.len().astype(np.int32)
df["word_count"] = df[CFG["TEXT_COL"]].str.count(r"\b\w+\b").astype(np.int32)

label_counts = df["label"].value_counts().sort_index()
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# NOTE(review): the count plot uses a class-BALANCED sample (each class capped at
# the smaller class size, max 100k), so its bars are (near-)equal by construction.
# The pie chart beside it uses the true full-data label counts.
sample_plot = stratified_sample(df, "label", per_class=min(100_000, df["label"].value_counts().min()))
sns.countplot(data=sample_plot, x="label", ax=axes[0])
# Tick/pie labels assume both label values 0 and 1 are present.
axes[0].set_xticklabels(["Non-Abusive", "Abusive"])
axes[1].pie(label_counts.values, labels=["Non-Abusive", "Abusive"], colors=BINARY_PALETTE, autopct='%1.1f%%')
plt.tight_layout()
plt.savefig(f"{CFG['OUTPUT_DIR']}/01_label_distribution.png", bbox_inches="tight")
plt.close()

# Length histograms (with KDE) per label on up to 50k rows per class.
sample_plot = stratified_sample(df[["char_count","word_count","label"]], "label", per_class=50_000)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for col, ax in zip(["char_count", "word_count"], axes):
    sns.histplot(data=sample_plot, x=col, hue="label", kde=True, ax=ax, alpha=0.6)
    ax.set_title(col.replace("_"," ").title())
plt.tight_layout()
plt.savefig(f"{CFG['OUTPUT_DIR']}/02_text_length.png", bbox_inches="tight")
plt.close()

# ----------------------------
# AGGRESSION FEATURES
# ----------------------------
timer("Aggression cues")
t = df[CFG["TEXT_COL"]]

# Surface "aggression" cues per text:
#   - exclamation marks;
#   - upper-case share of ASCII letters;
#   - runs of 3+ identical characters ("sooo");
#   - share of ALL-CAPS tokens among \w+ tokens;
#   - punctuation characters per character.
df["exclamation_count"] = t.str.count("!").astype(np.int16)
df["upper_letter_count"] = t.str.count(r"[A-Z]").astype(np.int32)
df["total_letter_count"] = t.str.count(r"[A-Za-z]").astype(np.int32)
df["upper_letter_ratio"] = (df["upper_letter_count"] / df["total_letter_count"].clip(lower=1)).fillna(0).astype(np.float32)
df["elongation_count"] = t.str.count(r"(.)\1{2,}").astype(np.int16)
df["all_caps_token_ratio"] = t.str.count(r"\b[A-Z]{2,}\b").div(t.str.count(r"\b\w+\b").clip(lower=1)).fillna(0.0).astype(np.float32)
df["punct_ratio"] = t.str.count(rf"[{re.escape(string.punctuation)}]").div(t.str.len().clip(lower=1)).fillna(0.0).astype(np.float32)

# Emoji runs (EMOJI_RE matches whole runs) are counted per row for every
# dataset size. Filling this column with per-label means would turn it into a
# direct encoding of the label, which leaks the target into the models
# trained later in this cell.
df["emoji_count"] = t.str.count(EMOJI_RE.pattern).astype(np.int16)

plot_df = stratified_sample(
    df[["upper_letter_ratio","exclamation_count","elongation_count","emoji_count",
        "all_caps_token_ratio","punct_ratio","label"]],
    "label", per_class=50_000
)

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
cols = ["upper_letter_ratio","exclamation_count","elongation_count","emoji_count","all_caps_token_ratio","punct_ratio"]
for i, col in enumerate(cols):
    ax = axes[i//3, i%3]
    sns.boxplot(data=plot_df, x="label", y=col, ax=ax)
    ax.set_title(col.replace("_"," ").title())
    ax.set_xticklabels(["Non-Abusive","Abusive"])
plt.tight_layout()
plt.savefig(f"{CFG['OUTPUT_DIR']}/03_aggression_cues.png", bbox_inches="tight")
plt.close()

# ----------------------------
# N-GRAMS — FULL DATA
# ----------------------------
timer("N-gram modeling (FULL data, RAM-optimized)")
ngram_df = df.copy()
ngram_df["clean"] = ngram_df[CFG["TEXT_COL"]].apply(clean_for_vectorizer)

# Unigram+bigram TF-IDF over all rows (EDA only). Because a custom
# preprocessor is supplied, sklearn does not apply lowercase/strip_accents;
# clean_for_vectorizer already lower-cases.
tfidf = TfidfVectorizer(
    tokenizer=custom_tokenizer, preprocessor=lambda x: x, lowercase=True,
    stop_words=STOPWORDS, ngram_range=(1,2), min_df=CFG["MIN_DF"], max_df=CFG["MAX_DF"], strip_accents="unicode"
)
X_tfidf = tfidf.fit_transform(ngram_df["clean"])
feats = np.array(tfidf.get_feature_names_out())
y = ngram_df["label"].values

# Mean TF-IDF weight of each term per class.
m1 = np.asarray(X_tfidf[y==1].mean(axis=0)).ravel()
m0 = np.asarray(X_tfidf[y==0].mean(axis=0)).ravel()

# Raw counts over the SAME vocabulary. ngram_range and stop_words must mirror
# the TF-IDF analyzer:
#   - with the default (1, 1) range, every bigram in `feats` would get a zero count;
#   - without the stop list, bigrams would span the stop words TF-IDF removed.
cv = CountVectorizer(
    vocabulary=feats, tokenizer=custom_tokenizer, preprocessor=lambda x: x,
    stop_words=STOPWORDS, ngram_range=(1, 2),
)
Xc = cv.transform(ngram_df["clean"])
c1 = np.asarray(Xc[y==1].sum(axis=0)).ravel()
c0 = np.asarray(Xc[y==0].sum(axis=0)).ravel()

# Total tokens per class and a background prior that spreads ALPHA0 over the
# vocabulary in proportion to each term's overall frequency.
n1, n0 = c1.sum(), c0.sum()
bg = c1 + c0
alpha = CFG["ALPHA0"] * (bg / max(1, bg.sum()))

def log_odds_z(ca, na, cb, nb, a):
    """Z-scored log-odds ratio of each term between corpora A and B.

    Args:
        ca, cb: Per-term counts in A and B (arrays of equal length).
        na, nb: Total counts in A and B.
        a: Per-term prior pseudo-counts; ``a.sum()`` is the total prior mass.

    Returns:
        Array of z-scores; positive values mean the term leans towards A.
        A 1e-12 epsilon guards against division by zero.
    """
    prior_total = a.sum()
    eps = 1e-12

    def _log_odds(counts, total):
        smoothed = counts + a
        return np.log(smoothed / (total + prior_total - smoothed + eps))

    delta = _log_odds(ca, na) - _log_odds(cb, nb)
    variance = 1.0 / (ca + a + eps) + 1.0 / (cb + a + eps)
    return delta / np.sqrt(variance)

# Log-odds z-score of every vocabulary term: positive = abusive-leaning,
# negative = non-abusive-leaning.
z = log_odds_z(c1, n1, c0, n0, alpha)

# Document frequency of each term per class, and "abusive purity": the share of
# the term's documents that are abusive (0 when the term never occurs).
df1 = (Xc[y==1] > 0).sum(axis=0).A1
df0 = (Xc[y==0] > 0).sum(axis=0).A1
pur_ab = df1 / np.maximum(1, df1 + df0)

def top_terms(direction="abusive", k=20, purity_thresh=0.65):
    """Return the *k* most discriminative terms for one class.

    A term qualifies when its document purity for the requested class is at
    least *purity_thresh* and its mean TF-IDF weight is higher in that class.
    Qualifying terms are ranked by log-odds z: most positive first for
    "abusive", most negative first for any other *direction*.

    Returns:
        (terms, idx), where *idx* indexes into the filtered (qualifying) terms.
    """
    if direction == "abusive":
        keep = (pur_ab >= purity_thresh) & (m1 > m0)
        order = np.argsort(z[keep])[-k:][::-1]
    else:
        keep = (1 - pur_ab >= purity_thresh) & (m0 > m1)
        order = np.argsort(-z[keep])[-k:][::-1]
    return feats[keep][order], order

# Top-20 discriminative terms for each class (purity >= 0.65).
top_ab_terms, _ = top_terms("abusive", k=20)
top_na_terms, _ = top_terms("non", k=20)

fig, axes = plt.subplots(1, 2, figsize=(16, 9))
for ax, terms, title, color in zip(
    axes,
    [top_ab_terms, top_na_terms],
    ["Abusive discriminators", "Non-Abusive discriminators"],
    [BINARY_PALETTE[1], BINARY_PALETTE[0]]
):
    # Look up each term's z-score by its position in the full vocabulary.
    z_vals = [z[np.where(feats == t)[0][0]] for t in terms]
    # Non-abusive terms have negative z; flip the sign so both panels' bars point right.
    if "Non-Abusive" in title:
        z_vals = [-v for v in z_vals]
    ax.barh(range(len(terms)), z_vals, color=color)
    ax.set_yticks(range(len(terms)))
    # Term labels are masked for display when CFG["MASK_TERMS"] is set.
    ax.set_yticklabels(maybe_mask_terms(terms, CFG["MASK_TERMS"], CFG["MASK_KEEP"]))
    ax.invert_yaxis()
    ax.set_title(title + " (log-odds z)")
plt.tight_layout()
plt.savefig(f"{CFG['OUTPUT_DIR']}/05_discriminative_terms_logodds.png", bbox_inches="tight")
plt.close()

# Bigram novelty
# Among bigrams (min_df=30, max_df=0.9, stop words removed) that occur in at
# least one abusive text, the fraction that never occurs in a non-abusive text.
# binary=True makes the column sums document counts.
cv2 = CountVectorizer(ngram_range=(2,2), tokenizer=custom_tokenizer, preprocessor=lambda x: x,
                      stop_words=STOPWORDS, min_df=30, max_df=0.9, strip_accents="unicode", binary=True)
X2 = cv2.fit_transform(ngram_df["clean"])
v2 = np.array(cv2.get_feature_names_out())
seen_ab = np.asarray(X2[y==1].sum(axis=0)).ravel() > 0
seen_na = np.asarray(X2[y==0].sum(axis=0)).ravel() > 0
novel_ab_rate = (seen_ab & ~seen_na).sum() / max(1, seen_ab.sum())
print(f"Bigram novelty rate: {novel_ab_rate:.2%}")

# ----------------------------
# SENTIMENT (sampled)
# ----------------------------
# VADER compound polarity on a label-stratified sample (up to
# SENTIMENT_SAMPLE // 2 rows per class); skipped if vaderSentiment is missing.
if VADER_AVAILABLE:
    timer("VADER sentiment")
    sent_sample = stratified_sample(df[[CFG["TEXT_COL"], "label"]], "label", per_class=CFG["SENTIMENT_SAMPLE"]//2)
    analyzer = SentimentIntensityAnalyzer()
    sent_sample["vader_compound"] = sent_sample[CFG["TEXT_COL"]].apply(lambda x: analyzer.polarity_scores(str(x))["compound"])
    plt.figure(figsize=(8, 5))
    sns.boxplot(data=sent_sample, x="label", y="vader_compound")
    plt.xticks([0,1], ["Non-Abusive","Abusive"])
    plt.title("VADER Compound Sentiment")
    plt.savefig(f"{CFG['OUTPUT_DIR']}/04_sentiment.png", bbox_inches="tight")
    plt.close()

# ----------------------------
# CORRELATION
# ----------------------------
# Pairwise correlation (pandas default: Pearson) of the engineered
# length/aggression features over all rows.
corr_cols = ["char_count","word_count","upper_letter_ratio","exclamation_count","elongation_count","all_caps_token_ratio","punct_ratio"]
if "emoji_count" in df.columns:
    corr_cols.append("emoji_count")
plt.figure(figsize=(10, 8))
sns.heatmap(df[corr_cols].corr(), annot=True, cmap="coolwarm", center=0, fmt=".2f")
plt.title("Feature Correlation Matrix (FULL)")
plt.tight_layout()
plt.savefig(f"{CFG['OUTPUT_DIR']}/07_correlation.png", bbox_inches="tight")
plt.close()

# ----------------------------
# WORDCLOUD
# ----------------------------
# Word cloud from the 2000 most frequent tokens (min_df=10) of up to
# WORDCLOUD_SAMPLE random rows. Both labels are mixed together. This step is
# best-effort: any failure is printed and the step skipped.
if WORDCLOUD_AVAILABLE:
    try:
        wc_df = ngram_df.sample(min(CFG["WORDCLOUD_SAMPLE"], len(ngram_df)), random_state=CFG["RANDOM_STATE"])
        cv_wc = CountVectorizer(tokenizer=custom_tokenizer, preprocessor=lambda x: x,
                                stop_words=STOPWORDS, max_features=2000, min_df=10, strip_accents="unicode")
        Xc_wc = cv_wc.fit_transform(wc_df["clean"])
        vocab_wc = cv_wc.get_feature_names_out()
        freqs = np.asarray(Xc_wc.sum(axis=0)).ravel()
        word_freq = dict(zip(vocab_wc, freqs))
        wc = WordCloud(width=800, height=400, background_color="white", max_words=200, colormap="RdYlGn_r")
        wc.generate_from_frequencies(word_freq)
        plt.figure(figsize=(12, 6)); plt.imshow(wc, interpolation="bilinear"); plt.axis("off"); plt.title("Word Cloud (Sampled)")
        plt.savefig(f"{CFG['OUTPUT_DIR']}/08_wordcloud.png", bbox_inches="tight", dpi=300)
        plt.close()
    except Exception as e:
        print(f"WordCloud skipped: {e}")

# ----------------------------
# MODELING — FULL DATA
# ----------------------------
timer("Modeling (FULL data)")
model_df = df.copy()
model_df["clean"] = model_df[CFG["TEXT_COL"]].apply(clean_for_vectorizer)

# 1) Split FIRST. Groups are case/whitespace-normalised texts, so duplicate
#    texts never straddle train and test. Every fitted transform below sees
#    training rows only.
groups = df[CFG["TEXT_COL"]].str.lower().str.replace(r"\s+", " ", regex=True).str.strip()
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=CFG["RANDOM_STATE"])
train_idx, test_idx = next(gss.split(model_df, model_df["label"], groups=groups))

# 2) Text features. The EDA vectorizer `tfidf` was fit on all rows (test
#    included), so an unfitted clone with identical settings is fit on the
#    training rows. This keeps test statistics out of the vocabulary/IDF, the
#    SVD basis and the scaler.
from sklearn.base import clone

text_vec = clone(tfidf)
X_text_tr = text_vec.fit_transform(model_df["clean"].iloc[train_idx])
X_text_te = text_vec.transform(model_df["clean"].iloc[test_idx])
del ngram_df, Xc, cv, tfidf, text_vec
gc.collect()

svd = TruncatedSVD(n_components=CFG["SVD_COMPONENTS"], random_state=CFG["RANDOM_STATE"])
scaler = StandardScaler()
X_svd_tr = scaler.fit_transform(svd.fit_transform(X_text_tr))
X_svd_te = scaler.transform(svd.transform(X_text_te))
del X_text_tr, X_text_te, svd, scaler
gc.collect()

# 3) Hand-crafted cues. emoji_count is used only if it varies within at least
#    one class: a per-class constant carries no information except the label itself.
num_cols = ["upper_letter_ratio", "exclamation_count", "elongation_count", "all_caps_token_ratio", "punct_ratio"]
if "emoji_count" in df.columns and df.groupby("label")["emoji_count"].nunique().max() > 1:
    num_cols.append("emoji_count")
X_num = model_df[num_cols].fillna(0).to_numpy().astype(np.float32)
y_all = model_df["label"].to_numpy()

# Column order: SVD components first, then num_cols. full_feature_names
# relies on this order.
X_tr = np.hstack([X_svd_tr, X_num[train_idx]])
X_te = np.hstack([X_svd_te, X_num[test_idx]])
y_tr, y_te = y_all[train_idx], y_all[test_idx]
del X_svd_tr, X_svd_te, X_num, y_all, model_df, groups
gc.collect()

def report_model(name, y_true, y_prob, y_pred):
    """Print ROC-AUC, PR-AUC (average precision), F1 and balanced accuracy on one line.

    Args:
        name: Label printed before the metrics.
        y_true: True 0/1 labels.
        y_prob: Positive-class scores, used for ROC-AUC and PR-AUC.
        y_pred: Hard 0/1 predictions, used for F1 and balanced accuracy.
    """
    roc_auc = roc_auc_score(y_true, y_prob)
    pr_auc = average_precision_score(y_true, y_prob)
    f1 = f1_score(y_true, y_pred)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    print(f"{name}: ROC-AUC={roc_auc:.3f} | PR-AUC={pr_auc:.3f} | F1={f1:.3f} | BalAcc={bal_acc:.3f}")

# Logistic Regression
# Class-balanced liblinear LR; hard predictions threshold P(abusive) at 0.5.
lr = LogisticRegression(max_iter=1000, class_weight="balanced", solver="liblinear", random_state=CFG["RANDOM_STATE"])
lr.fit(X_tr, y_tr)
lr_prob = lr.predict_proba(X_te)[:,1]
lr_pred = (lr_prob >= 0.5).astype(int)
report_model("LogReg", y_te, lr_prob, lr_pred)

# Random Forest
# 300 depth-limited trees, re-balanced per bootstrap sample.
rf = RandomForestClassifier(
    n_estimators=300, max_depth=15, min_samples_leaf=3, n_jobs=-1,
    class_weight="balanced_subsample", random_state=CFG["RANDOM_STATE"]
)
rf.fit(X_tr, y_tr)
rf_prob = rf.predict_proba(X_te)[:,1]
rf_pred = (rf_prob >= 0.5).astype(int)
report_model("RandomForest", y_te, rf_prob, rf_pred)

# XGBoost
# Only if xgboost imported. scale_pos_weight = #neg / #pos re-weights the
# positive class; this divides by zero if the training split has no positives.
# The test set is passed as eval_set purely for AUC monitoring: there is no
# early stopping, so it does not influence the fitted trees.
if XGB_AVAILABLE:
    pos, neg = (y_tr == 1).sum(), (y_tr == 0).sum()
    xgb = XGBClassifier(
        n_estimators=500, max_depth=7, learning_rate=0.1,
        subsample=0.8, colsample_bytree=0.8,
        tree_method="hist", objective="binary:logistic", eval_metric="auc",
        scale_pos_weight=neg / pos, random_state=CFG["RANDOM_STATE"], n_jobs=-1
    )
    xgb.fit(X_tr, y_tr, eval_set=[(X_te, y_te)], verbose=0)
    xgb_prob = xgb.predict_proba(X_te)[:,1]
    xgb_pred = (xgb_prob >= 0.5).astype(int)
    report_model("XGBoost", y_te, xgb_prob, xgb_pred)

# ----------------------------
# EXTRA MODEL PLOTS
# ----------------------------
def plot_roc_pr(y_true, y_proba, model_name, save_path):
    """Save side-by-side ROC and precision-recall curves to *save_path*.

    The legends report ROC-AUC and PR-AUC (average precision). The ROC panel
    also shows the chance diagonal.
    """
    _, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(12, 5))

    fpr, tpr, _ = roc_curve(y_true, y_proba)
    roc_auc = roc_auc_score(y_true, y_proba)
    ax_roc.plot(fpr, tpr, color=BINARY_PALETTE[1], lw=2, label=f'ROC-AUC = {roc_auc:.3f}')
    ax_roc.plot([0,1],[0,1], '--', color='gray')
    ax_roc.set(xlabel='FPR', ylabel='TPR', title=f'{model_name} - ROC')
    ax_roc.legend()

    precision, recall, _ = precision_recall_curve(y_true, y_proba)
    pr_auc = average_precision_score(y_true, y_proba)
    ax_pr.plot(recall, precision, color=BINARY_PALETTE[0], lw=2, label=f'PR-AUC = {pr_auc:.3f}')
    ax_pr.set(xlabel='Recall', ylabel='Precision', title=f'{model_name} - PR')
    ax_pr.legend()

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def plot_calibration(y_true, y_proba, model_name, save_path):
    """Save a 10-bin reliability diagram (observed vs. predicted positive rate).

    Points on the dashed diagonal indicate perfect calibration.
    """
    frac_positive, mean_predicted = calibration_curve(y_true, y_proba, n_bins=10)
    diagonal = [0, 1]
    plt.figure(figsize=(6,5))
    plt.plot(mean_predicted, frac_positive, 'o-', color=BINARY_PALETTE[1])
    plt.plot(diagonal, diagonal, '--', color='gray')
    plt.xlabel("Mean Predicted Prob")
    plt.ylabel("Fraction Positive")
    plt.title(f"{model_name} - Calibration")
    plt.grid(True, linestyle=':', alpha=0.7)
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def plot_confusion_matrix(y_true, y_pred, model_name, save_path):
    """Save a 2x2 confusion-matrix heatmap (rows = true, columns = predicted).

    ``labels=[0, 1]`` pins the matrix to both classes in a fixed order, so it
    always matches the two tick labels. Without it, sklearn returns a 1x1 matrix
    when only one class appears in both ``y_true`` and ``y_pred``.
    """
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', cmap="RdYlGn_r",
                xticklabels=["Non-Abusive", "Abusive"],
                yticklabels=["Non-Abusive", "Abusive"])
    plt.title(f"{model_name} - Confusion Matrix")
    plt.ylabel("True"); plt.xlabel("Predicted")
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

def plot_feature_importance(model, feature_names, model_name, save_path, top_k=20):
    """Save a horizontal bar chart of a fitted model's strongest features.

    Models exposing ``coef_`` are ranked by absolute coefficient. Models
    exposing ``feature_importances_`` are ranked by that attribute. Models
    exposing neither are skipped without plotting.

    Args:
        model: Fitted estimator.
        feature_names: Names indexed like the model's input columns.
        model_name: Name shown in the title.
        save_path: Destination file for the figure.
        top_k: Maximum number of features to show.
    """
    if hasattr(model, 'coef_'):
        # atleast_2d accepts both (n_features,) and (n_rows, n_features)
        # coefficient arrays; row 0 is used, as before.
        imp = np.abs(np.atleast_2d(model.coef_)[0])
        color = BINARY_PALETTE[1]
        xlabel = "|Coefficient|"
    elif hasattr(model, 'feature_importances_'):
        imp = np.asarray(model.feature_importances_)
        color = BINARY_PALETTE[0]
        xlabel = "Importance"
    else:
        return

    # Sort descending, then truncate. Unlike ``argsort(imp)[-top_k:]`` this
    # selects nothing (not everything) when top_k == 0.
    top_idx = np.argsort(imp)[::-1][:top_k]
    if len(top_idx) == 0:
        return
    values = imp[top_idx]
    names = [feature_names[i] for i in top_idx]
    # Only names containing letters are passed through the term masker.
    masked = [
        maybe_mask_terms([n], CFG["MASK_TERMS"], CFG["MASK_KEEP"])[0]
        if any(c.isalpha() for c in n) else n
        for n in names
    ]

    n_bars = len(top_idx)
    plt.figure(figsize=(8, max(2.0, n_bars * 0.35)))
    plt.barh(range(n_bars), values, color=color)
    plt.yticks(range(n_bars), masked)
    plt.gca().invert_yaxis()  # strongest feature on top
    # Label from the attribute actually used, not guessed from the model name.
    plt.xlabel(xlabel)
    plt.title(f"{model_name} - Top {n_bars} Features")
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.close()

# Feature names for importance plots: SVD component names followed by the
# numeric feature columns. NOTE(review): assumes this matches the column
# order of the model input matrix built earlier; confirm against that code.
svd_names = [f"SVD_{i}" for i in range(CFG["SVD_COMPONENTS"])]
full_feature_names = svd_names + num_cols


def _save_model_plots(model_name, model, proba, pred, file_names):
    """Write ROC/PR, calibration, confusion and importance plots for one model.

    ``file_names`` holds the four output file names in that order. Each is
    written under ``CFG['OUTPUT_DIR']``.
    """
    roc_file, calib_file, cm_file, importance_file = file_names
    plot_roc_pr(y_te, proba, model_name, f"{CFG['OUTPUT_DIR']}/{roc_file}")
    plot_calibration(y_te, proba, model_name, f"{CFG['OUTPUT_DIR']}/{calib_file}")
    plot_confusion_matrix(y_te, pred, model_name, f"{CFG['OUTPUT_DIR']}/{cm_file}")
    plot_feature_importance(model, full_feature_names, model_name,
                            f"{CFG['OUTPUT_DIR']}/{importance_file}")


_save_model_plots(
    "Logistic Regression", lr, lr_prob, lr_pred,
    ("10_lr_roc_pr.png", "11_lr_calibration.png",
     "12_lr_confusion.png", "13_lr_feature_importance.png"),
)
_save_model_plots(
    "Random Forest", rf, rf_prob, rf_pred,
    ("14_rf_roc_pr.png", "15_rf_calibration.png",
     "16_rf_confusion.png", "17_rf_feature_importance.png"),
)
if XGB_AVAILABLE:
    _save_model_plots(
        "XGBoost", xgb, xgb_prob, xgb_pred,
        ("18_xgb_roc_pr.png", "19_xgb_calibration.png",
         "20_xgb_confusion.png", "21_xgb_feature_importance.png"),
    )

print("All model plots saved!")

# ----------------------------
# FINAL SUMMARY
# ----------------------------
# Headline statistics for the abusive/non-abusive task, persisted as JSON.
_mask_args = (CFG["MASK_TERMS"], CFG["MASK_KEEP"])
_abusive_rows = df["label"] == 1
_non_abusive_rows = df["label"] == 0
summary = dict(
    total_samples=len(df),
    abusive_ratio=float(_abusive_rows.mean()),
    bigram_novelty_rate=float(novel_ab_rate),
    mean_upper_ratio_abusive=float(df.loc[_abusive_rows, "upper_letter_ratio"].mean()),
    mean_upper_ratio_non=float(df.loc[_non_abusive_rows, "upper_letter_ratio"].mean()),
    top_abusive_terms_masked=maybe_mask_terms(list(top_ab_terms[:10]), *_mask_args),
    top_non_abusive_terms_masked=maybe_mask_terms(list(top_na_terms[:10]), *_mask_args),
)
_summary_path = f"{CFG['OUTPUT_DIR']}/eda_summary_full.json"
with open(_summary_path, "w") as fh:
    json.dump(summary, fh, indent=2)

_divider = "=" * 60
print("\n" + _divider)
print("FULL ANALYSIS COMPLETE — All rows used, all plots saved")
print(f"Output: {CFG['OUTPUT_DIR']}/")
print("Non-Abusive = #1a9850 | Abusive = #d73027")
print(_divider)




"""
INTERPRETATION OF COMPREHENSIVE EDA PIPELINE — FULL-SCALE TEXT ANALYSIS FRAMEWORK

The exploratory data analysis pipeline implements a systematic, production-grade
framework for understanding abusive content patterns through multi-modal
feature engineering and statistical modeling.

Key analytical interpretations:

- **Multi-Dimensional Feature Engineering**: Extracts linguistic, structural,
  and behavioral cues including capitalization patterns, punctuation intensity,
  character elongation, and emoji usage. These features capture both explicit
  and subtle aggression markers that distinguish abusive from non-abusive content.

- **Statistical Discriminative Analysis**: Employs log-odds ratio with z-scoring
  to identify terms with the strongest class separation power, revealing both
  overt hate speech vocabulary and more nuanced offensive language patterns.

- **RAM-Optimized Scalability**: Implements strategic sampling, sparse matrix
  operations, and memory-efficient data types to handle massive text corpora
  while maintaining analytical rigor across the entire dataset.

- **Multi-Model Benchmarking**: Compares Logistic Regression, Random Forest,
  and (when the xgboost package is available) XGBoost to evaluate how
  different learning paradigms perform at abusive content detection,
  providing insights into feature interactions and classification boundaries.

- **Comprehensive Model Diagnostics**: Extends beyond basic accuracy metrics
  to include ROC/PR curves, calibration analysis, confusion matrices, and
  feature importance, offering a complete picture of model behavior and
  deployment readiness.

- **Privacy-Preserving Visualization**: Optionally masks sensitive terms
  (controlled by CFG["MASK_TERMS"] and CFG["MASK_KEEP"]) in feature-importance
  plots and in the saved JSON summary, balancing research utility and
  transparency with ethical considerations.

Overall interpretation:
This EDA framework represents a sophisticated approach to understanding
abusive content that moves beyond simple keyword counting to capture
the complex linguistic and behavioral patterns characteristic of online
harm. The systematic feature engineering and multi-model evaluation
provide both immediate insights into dataset characteristics and
foundational understanding for developing robust detection systems.
The implementation demonstrates careful consideration of computational
constraints while maintaining analytical depth, making it suitable for
both research and production applications in content moderation.
"""

In [None]:
# Memotion Dataset 7k — Local/Jupyter-Compatible EDA
# Toggles and limits for the optional sections further down this cell.
RUN_TEXT_TSNE = False          # run the TF-IDF -> SVD -> t-SNE text projection
RUN_IMAGE_EMBEDDINGS = False   # flag checked by the image-embeddings section
SAMPLE_IMAGE_EMB_COUNT = 500   # NOTE(review): not referenced anywhere in this cell
TFIDF_MAX_FEATURES = 5000      # vocabulary cap for the t-SNE TF-IDF vectorizer

# Install dependencies (safe in Jupyter)
# `!pip` is IPython shell-escape syntax, so this cell only runs under
# Jupyter/IPython, not as a plain Python script.
try:
    import wordcloud
except ImportError:
    !pip install -q wordcloud

import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from wordcloud import WordCloud
import re
from collections import Counter
import json
import warnings
warnings.filterwarnings("ignore")
sns.set(style="whitegrid")

# Define local paths (you can update DATA_DIR to your actual folder)
DATA_DIR = Path("datasets/memotion_dataset_7k")  # CHANGE THIS IF NEEDED
# Explicit raise instead of `assert`: asserts are stripped under `python -O`,
# which would let a bad path fail later with a far less helpful error.
if not DATA_DIR.is_dir():
    raise FileNotFoundError(
        f"Dataset folder '{DATA_DIR}' not found. Please set DATA_DIR correctly."
    )

# Use consistent output directory for every artifact written by this cell.
OUT_DIR = Path("figure_multimodal")
OUT_DIR.mkdir(parents=True, exist_ok=True)
print("Outputs will be saved to:", OUT_DIR.resolve())

# ------------------------------------------------------------------
# LOAD DATA
# ------------------------------------------------------------------
def _read_csv_any_encoding(path):
    """Read a CSV as UTF-8, retrying with Latin-1 if it is not valid UTF-8."""
    try:
        return pd.read_csv(path, encoding='utf-8', low_memory=False)
    except UnicodeDecodeError:
        return pd.read_csv(path, encoding='latin1', low_memory=False)


csvs = {p.name: p for p in DATA_DIR.glob('*.csv')}
print("Found CSVs:", list(csvs.keys()))
if not csvs:
    # Fail with a clear message instead of an IndexError on dfs[0] below.
    raise FileNotFoundError(f"No CSV files found in '{DATA_DIR}'.")

# Merge reference.csv and labels.csv if both exist
if 'reference.csv' in csvs and 'labels.csv' in csvs:
    # Same encoding fallback as the concatenate-all path below.
    ref = _read_csv_any_encoding(csvs['reference.csv'])
    lab = _read_csv_any_encoding(csvs['labels.csv'])
    print("reference.csv shape:", ref.shape)
    print("labels.csv shape:", lab.shape)

    common_cols = set(ref.columns).intersection(set(lab.columns))
    print("Common columns:", common_cols)

    # Join key: image_name if shared, otherwise the first shared ID-like column.
    possible_keys = ['image_name', 'image_url', 'imageid', 'image_id', 'img', 'img_name']
    key = next((k for k in possible_keys if k in ref.columns and k in lab.columns), None)
    if key:
        df_all = ref.merge(lab, on=key, how='outer', suffixes=('_ref', '_lab'))
    else:
        print("No common key found. Concatenating (may cause duplication).")
        df_all = pd.concat([ref, lab], ignore_index=True)
else:
    # Load all CSVs and concatenate
    dfs = [_read_csv_any_encoding(p) for p in csvs.values()]
    df_all = pd.concat(dfs, ignore_index=True) if len(dfs) > 1 else dfs[0]

print("Combined dataframe shape:", df_all.shape)
display(df_all.head(3))

# ------------------------------------------------------------------
# COLUMN MAPPING & BASIC STATS
# ------------------------------------------------------------------
def find_col_by_keywords(keywords, available_cols):
    """Return the column in ``available_cols`` that best matches ``keywords``.

    Matching is case-insensitive and runs in two passes. Both passes honour
    the priority order of ``keywords``:

    1. exact match of a keyword against a column name;
    2. substring match (keyword contained in a column name).

    Args:
        keywords: Candidate names, most preferred first.
        available_cols: Column labels to search. Non-string labels are
            compared through ``str()``.

    Returns:
        The original column label, or ``None`` if nothing matches.
    """
    # str() so integer / non-string column labels don't break .lower().
    lowered = [(str(c).lower(), c) for c in available_cols]
    lower_map = {}
    for low, col in lowered:
        lower_map.setdefault(low, col)  # first column wins on case collisions
    wanted = [str(k).lower() for k in keywords]

    for k in wanted:
        if k in lower_map:
            return lower_map[k]
    # Keyword-major loop so a more preferred keyword beats column order.
    for k in wanted:
        for low, col in lowered:
            if k in low:
                return col
    return None

cols = df_all.columns.tolist()

# Target field -> candidate source column names, most preferred first.
_COLUMN_KEYWORDS = {
    'image_name': ['image_name', 'imagename', 'image', 'imageurl', 'image_url'],
    'text_ocr': ['text_ocr', 'text', 'text_corrected', 'ocr', 'caption'],
    'text_corrected': ['text_corrected', 'textclean', 'clean_text'],
    'sentiment': ['overall_sentiment', 'sentiment', 'label', 'labels'],
    'humour': ['humour', 'humor'],
    'sarcasm': ['sarcasm'],
    'offensive': ['offensive'],
    'motivational': ['motivational', 'motiv'],
}
col_map = {
    field: find_col_by_keywords(candidates, cols)
    for field, candidates in _COLUMN_KEYWORDS.items()
}

print("\nGuessed column mapping:")
print("\n".join(f" - {field}: {found}" for field, found in col_map.items()))

# ------------------------------------------------------------------
# LABEL DISTRIBUTIONS
# ------------------------------------------------------------------
def _plot_label_counts(column, heading, title, figsize, file_name):
    """Print raw value counts of ``df_all[column]`` and save/show a count plot."""
    print(heading)
    display(df_all[column].value_counts(dropna=False))
    plt.figure(figsize=figsize)
    sns.countplot(y=df_all[column], order=df_all[column].value_counts().index)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(OUT_DIR / file_name, bbox_inches="tight")
    plt.show()


_sentiment_col = col_map['sentiment']
if _sentiment_col and _sentiment_col in df_all.columns:
    _plot_label_counts(_sentiment_col, "\nSentiment distribution:",
                       'Sentiment Distribution', (6, 4), "01_sentiment_dist.png")

for e in ['humour', 'sarcasm', 'offensive', 'motivational']:
    c = col_map.get(e)
    if not (c and c in df_all.columns):
        continue
    pretty = e.capitalize()
    _plot_label_counts(c, f"\n{pretty} distribution:", f"{pretty} Distribution",
                       (6, 3), f"02_{e}_dist.png")

# ------------------------------------------------------------------
# TEXT PROCESSING
# ------------------------------------------------------------------
# Minimal English stopword list; used alone when scikit-learn is unavailable
# and merged into scikit-learn's list otherwise.
FALLBACK_STOPWORDS = {
    'a', 'an', 'the', 'and', 'or', 'is', 'are', 'to', 'of', 'in', 'on', 'for',
    'with', 'that', 'this', 'it', 'as', 'at', 'by', 'be', 'from', 'was', 'were',
    'has', 'have', 'i', 'you', 'we', 'they', 'he', 'she', 'not', 'but', 'so',
    'if', 'then',
}

try:
    from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
except Exception:
    STOPWORDS = FALLBACK_STOPWORDS
else:
    STOPWORDS = FALLBACK_STOPWORDS | set(ENGLISH_STOP_WORDS)

def simple_tokenize(text, stopwords=None):
    """Lowercase ``text``, drop URLs/punctuation, and split it into content tokens.

    Args:
        text: Input string. Non-strings and blank strings yield ``[]``.
        stopwords: Optional collection of tokens to drop. Defaults to the
            module-level ``STOPWORDS`` (looked up at call time).

    Returns:
        Tokens built from ``[a-z0-9']`` characters with length >= 2, with
        ``http...`` URLs removed, quote-style apostrophes stripped from token
        edges, and stopwords filtered out.
    """
    if not isinstance(text, str) or not text.strip():
        return []
    stop = STOPWORDS if stopwords is None else stopwords
    s = text.lower()
    s = re.sub(r'http\S+', ' ', s)
    s = re.sub(r"[^a-z0-9']", ' ', s)
    tokens = []
    for raw in s.split():
        # Keep inner apostrophes ("don't") but strip quote marks around words,
        # so "'hello'" counts as "hello" and quoted stopwords get filtered.
        t = raw.strip("'")
        if len(t) > 1 and t not in stop:
            tokens.append(t)
    return tokens

text_col = col_map.get('text_ocr') or col_map.get('text_corrected')
if text_col and text_col in df_all.columns:
    # Prefer the corrected-text column when it exists; otherwise use text_col.
    corrected_col = col_map.get('text_corrected')
    source_col = corrected_col if corrected_col and corrected_col in df_all.columns else text_col
    df_all['clean_text'] = df_all[source_col].fillna('').astype(str)

    # Trim, lowercase, and collapse whitespace runs to single spaces.
    df_all['clean_text'] = df_all['clean_text'].apply(lambda s: re.sub(r'\s+', ' ', s.strip().lower()))

    # Tokenize each text exactly once. The same token lists feed unigram
    # counts, bigram counts and the token-length feature below.
    token_lists = [simple_tokenize(t) for t in df_all['clean_text']]

    # Token stats
    tokens = [tok for toks in token_lists for tok in toks]
    top_tokens = Counter(tokens).most_common(60)
    print("\nTop tokens:", top_tokens[:20])

    # Bigrams: adjacent token pairs within each text (never across texts).
    bigrams = Counter()
    for toks in token_lists:
        bigrams.update(zip(toks, toks[1:]))
    top_bigrams = [(" ".join(k), v) for k, v in bigrams.most_common(30)]
    print("\nTop bigrams:", top_bigrams[:20])

    # WordCloud
    all_text = " ".join(t for t in df_all['clean_text'] if t)
    if all_text:
        try:
            wc = WordCloud(width=900, height=400, background_color='white', collocations=False).generate(all_text)
        except ValueError as err:
            # WordCloud raises when no words survive its own filtering.
            print("Skipping word cloud:", err)
        else:
            plt.figure(figsize=(12, 5))
            plt.imshow(wc, interpolation='bilinear')
            plt.axis('off')
            plt.title('WordCloud — All Meme Texts')
            plt.savefig(OUT_DIR / "03_wordcloud.png", bbox_inches="tight", dpi=300)
            plt.show()

    # Text length features
    df_all['text_len_chars'] = df_all['clean_text'].str.len()
    df_all['text_len_words'] = [len(toks) for toks in token_lists]

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    sns.histplot(df_all['text_len_chars'].dropna(), bins=50)
    plt.title('Text Length (Characters)')
    plt.subplot(1, 2, 2)
    sns.histplot(df_all['text_len_words'].dropna(), bins=30)
    plt.title('Text Length (Tokens)')
    plt.tight_layout()
    plt.savefig(OUT_DIR / "04_text_length.png", bbox_inches="tight")
    plt.show()
else:
    print("No usable text column found.")

# ------------------------------------------------------------------
# IMAGE ANALYSIS
# ------------------------------------------------------------------
# Guess image directory
# Image file extensions recognised by this section (compared case-insensitively).
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}

# Guess image directory: the first candidate holding more than 10 images.
image_dir = None
for candidate in [DATA_DIR / 'images', DATA_DIR]:
    if candidate.is_dir():
        # Count every supported suffix in any case (e.g. .JPG, .jpeg), the
        # same rule img_files uses below.
        n_imgs = sum(1 for p in candidate.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)
        if n_imgs > 10:
            image_dir = candidate
            break

print("Image directory:", image_dir)

if image_dir:
    img_files = [p for p in image_dir.rglob('*') if p.suffix.lower() in IMAGE_SUFFIXES]
    print(f"Found {len(img_files)} image files.")

    # Name -> path index. A sample missing from the top level is found with
    # one lookup instead of a recursive glob, and glob metacharacters in file
    # names cannot misfire.
    path_by_name = {}
    for p in img_files:
        path_by_name.setdefault(p.name, p)

    # Sample images
    name_col = col_map['image_name']
    if name_col and name_col in df_all.columns:
        names = df_all[name_col].dropna()
        # Size the sample from the non-null names. min(12, len(df_all)) can
        # exceed that population and make Series.sample raise ValueError.
        n_show = min(12, len(names))
        if n_show:
            sample_fns = names.sample(n_show, random_state=42).tolist()
            plt.figure(figsize=(12, 8))
            for i, fn in enumerate(sample_fns):
                plt.subplot(3, 4, i + 1)
                try:
                    p = image_dir / fn
                    if not p.exists():
                        p = path_by_name.get(Path(str(fn)).name, p)
                    # Context manager closes the underlying file once converted.
                    with Image.open(p) as im:
                        img = im.convert('RGB')
                    plt.imshow(img)
                    plt.axis('off')
                    plt.title(str(fn)[:30], fontsize=8)
                except Exception:
                    plt.text(0.5, 0.5, "missing", ha='center', va='center')
                    plt.axis('off')
                    plt.title(str(fn)[:30], fontsize=8)
            plt.suptitle('Sample Memes')
            plt.tight_layout()
            plt.savefig(OUT_DIR / "05_sample_images.png", bbox_inches="tight", dpi=150)
            plt.show()

    # Image size stats (sample 1000)
    sizes = []
    for p in img_files[:1000]:
        try:
            with Image.open(p) as im:
                sizes.append(im.size)
        except Exception:
            # Unreadable/corrupt file. Unlike a bare `except:`, this does not
            # swallow KeyboardInterrupt/SystemExit.
            continue
    if sizes:
        widths = [w for w, _ in sizes]
        heights = [h for _, h in sizes]
        ratios = [w / h for w, h in sizes if h > 0]
        plt.figure(figsize=(12, 4))
        plt.subplot(1, 3, 1)
        sns.histplot(widths, bins=30)
        plt.title('Width')
        plt.subplot(1, 3, 2)
        sns.histplot(heights, bins=30)
        plt.title('Height')
        plt.subplot(1, 3, 3)
        sns.histplot(ratios, bins=30)
        plt.title('Aspect Ratio (W/H)')
        plt.tight_layout()
        plt.savefig(OUT_DIR / "06_image_stats.png", bbox_inches="tight")
        plt.show()
else:
    print("No image directory found.")

# ------------------------------------------------------------------
# CORRELATIONS
# ------------------------------------------------------------------
_sentiment_col = col_map['sentiment']

if 'text_len_words' in df_all.columns and _sentiment_col and _sentiment_col in df_all.columns:
    plt.figure(figsize=(6, 4))
    sns.boxplot(x=df_all[_sentiment_col], y=df_all['text_len_words'])
    plt.title('Text Length vs Sentiment')
    plt.tight_layout()
    plt.savefig(OUT_DIR / "07_text_len_vs_sentiment.png", bbox_inches="tight")
    plt.show()

if image_dir and col_map['image_name'] in df_all.columns:
    fn_to_area = {}
    for p in img_files[:2000]:
        try:
            with Image.open(p) as im:
                fn_to_area[p.name] = im.width * im.height
        except Exception:
            # Unreadable/corrupt file. Unlike a bare `except:`, this does not
            # swallow KeyboardInterrupt/SystemExit.
            continue
    if fn_to_area:
        # Images that were not measured (only the first 2000 files are
        # scanned) or are missing get 0, as before, in the saved column.
        df_all['img_area'] = df_all[col_map['image_name']].map(fn_to_area).fillna(0)
        if _sentiment_col and _sentiment_col in df_all.columns:
            # A log axis cannot show 0, and the 0 placeholders are not real
            # areas, so plot only rows with a measured area.
            measured = df_all[df_all['img_area'] > 0]
            if not measured.empty:
                plt.figure(figsize=(6, 4))
                sns.boxplot(x=measured[_sentiment_col], y=measured['img_area'])
                plt.yscale('log')
                plt.title('Image Area vs Sentiment (log scale)')
                plt.tight_layout()
                plt.savefig(OUT_DIR / "08_img_area_vs_sentiment.png", bbox_inches="tight")
                plt.show()

# ------------------------------------------------------------------
# OPTIONAL: TEXT t-SNE (disabled by default)
# ------------------------------------------------------------------
if RUN_TEXT_TSNE and 'clean_text' in df_all.columns:
    print("Running t-SNE (may take several minutes)...")
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.decomposition import TruncatedSVD
    from sklearn.manifold import TSNE

    vect = TfidfVectorizer(max_features=TFIDF_MAX_FEATURES, ngram_range=(1, 2), stop_words='english')
    X_tfidf = vect.fit_transform(df_all['clean_text'].fillna(''))
    # Cap the SVD rank at the vocabulary size. TruncatedSVD's default
    # randomized solver rejects n_components > n_features.
    svd = TruncatedSVD(n_components=min(50, X_tfidf.shape[1]), random_state=42)
    X_svd = svd.fit_transform(X_tfidf)

    tsne_kwargs = dict(n_components=2, random_state=42, perplexity=30)
    try:
        # scikit-learn >= 1.5 names the iteration cap `max_iter`. `n_iter` was
        # deprecated in 1.5 and scheduled for removal in 1.7.
        tsne = TSNE(max_iter=1000, **tsne_kwargs)
    except TypeError:
        # Older scikit-learn only accepts `n_iter`.
        tsne = TSNE(n_iter=1000, **tsne_kwargs)
    X_tsne = tsne.fit_transform(X_svd)
    df_all['text_tsne_x'] = X_tsne[:, 0]
    df_all['text_tsne_y'] = X_tsne[:, 1]

    plt.figure(figsize=(8, 6))
    if col_map['sentiment'] in df_all.columns:
        sns.scatterplot(data=df_all, x='text_tsne_x', y='text_tsne_y', hue=col_map['sentiment'], alpha=0.6)
    else:
        plt.scatter(df_all['text_tsne_x'], df_all['text_tsne_y'], s=6)
    plt.title('t-SNE of Meme Texts')
    plt.savefig(OUT_DIR / "09_text_tsne.png", bbox_inches="tight")
    plt.show()

# ------------------------------------------------------------------
# OPTIONAL: IMAGE EMBEDDINGS (disabled by default)
# ------------------------------------------------------------------
if RUN_IMAGE_EMBEDDINGS and image_dir:
    # Placeholder: no embedding pipeline exists in this cell, so enabling the
    # flag only reports that the step is skipped. The previous message said
    # "skipping unless enabled", yet it was only printed when enabled.
    print("Image embeddings are not implemented in this notebook; skipping.")

# ------------------------------------------------------------------
# SAVE OUTPUTS
# ------------------------------------------------------------------
df_all.head(2000).to_csv(OUT_DIR / 'df_sample_2000.csv', index=False)

# `top_tokens` only exists when the text-processing section found a text
# column. Check for it explicitly instead of letting a broad `except` hide
# the resulting NameError.
_top_tokens = globals().get('top_tokens')
if _top_tokens is None:
    print("Could not save tokens: no token statistics were computed.")
else:
    try:
        with open(OUT_DIR / 'top_tokens.json', 'w', encoding='utf-8') as f:
            json.dump({'top_tokens': _top_tokens[:200]}, f, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print("Could not save tokens:", e)

print(f"\nEDA complete. All outputs saved to: {OUT_DIR.resolve()}")



"""
MEMOTION DATASET 7K — EXPLORATORY DATA ANALYSIS (EDA)
-----------------------------------------------------

DATA LOADING & PREPARATION:
- Merges reference.csv and labels.csv (when both exist) on image_name, or on another shared ID column; otherwise concatenates the CSVs
- Automatically maps columns using keyword matching for text, labels, and metadata
- Handles encoding issues and provides comprehensive data shape reporting

TEXT ANALYSIS:
- Tokenizes text with a simple regex tokenizer (lowercasing, URL removal, stopword filtering)
- Generates word frequency distributions and bigram analysis
- Creates word clouds to visualize prominent text themes
- Analyzes text length distributions in characters and tokens

VISUAL CONTENT ANALYSIS:
- Locates image directories and validates file accessibility
- Displays sample memes in grid layout for qualitative inspection
- Computes image dimensions and aspect ratio statistics
- Correlates visual features with textual sentiment labels

MULTIMODAL CORRELATIONS:
- Examines relationships between text length and sentiment categories
- Analyzes image area vs sentiment patterns using logarithmic scaling
- Provides optional t-SNE visualization for text embedding clustering
- Reserves a RUN_IMAGE_EMBEDDINGS flag for image embeddings; no embedding code is implemented yet, so enabling it only prints a notice

OUTPUT & REPORTING:
- Generates comprehensive visualization suite (9+ plot types)
- Saves processed data samples for further analysis
- Exports token frequency data as JSON for external use
- Creates self-contained output directory with all results

OVERALL INTERPRETATION:
This EDA pipeline provides a complete multimodal analysis of the Memotion 7K dataset,
revealing patterns in text characteristics, visual properties, and their relationships
with sentiment labels. The systematic approach enables data quality validation, feature
understanding, and informs subsequent model design decisions for meme classification tasks.
"""