# Text Simplification Dataset Analysis

This notebook analyzes a text simplification dataset with three complexity levels:
- **Elementary**: Simplified text for basic readers
- **Intermediate**: Medium complexity text
- **Advanced**: Complex/original text

The dataset contains parallel texts at different complexity levels, making it ideal for text simplification research.

In [None]:
!pip install textstat
!pip install rouge_score

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import re
from textstat import flesch_reading_ease, flesch_kincaid_grade, automated_readability_index
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
import ast
import warnings
warnings.filterwarnings('ignore')

import glob
import json
import os
import logging
from typing import Dict, List, Any
from dataclasses import dataclass, field
import time

import torch
from torch.utils.data import Dataset
from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    TrainingArguments, Trainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback
)
from datasets import Dataset as HFDataset, DatasetDict
import numpy as np
from rouge_score import rouge_scorer
from sklearn.model_selection import train_test_split


# Pick the accelerator once and report what was found
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f" Using device: {device}")
if use_cuda:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"   GPU: {torch.cuda.get_device_name()}")
    print(f" Memory: {gpu_props.total_memory / 1024**3:.1f} GB")

## 1. Data Loading and Initial Exploration

In [None]:
# Load the combined dataset; each row is treated as one article in the cells below
df = pd.read_csv('all_data.csv')

print(f"Dataset shape: {df.shape}")
print(f"\nColumns: {df.columns.tolist()}")
print("\nFirst few rows:")
df.head()

In [None]:
# Report missing values, column dtypes and summary statistics, in that order
for heading, report in (
    ("Missing values per column:", df.isnull().sum()),
    ("\nData types:", df.dtypes),
    ("\nBasic statistics:", df.describe()),
):
    print(heading)
    print(report)

## 2. Data Preprocessing and Text Parsing

In [None]:
def safe_eval(text):
    """Parse a cell that holds a string representation of a list of sentences.

    Args:
        text: raw cell value (usually a string such as "['Sentence one.', 'Sentence two.']").

    Returns:
        A list in every case, so callers can rely on len() and indexing:
        - [] for missing values (None / NaN);
        - the parsed items for list or tuple literals;
        - [s] when the literal is a single string s;
        - [str(text)] when the cell is not a Python literal, or is some other literal
          (number, dict, ...).
    """
    if pd.isna(text):
        return []
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a (safe) Python literal: treat the whole cell as one piece of text
        return [str(text)]
    if isinstance(parsed, (list, tuple)):
        return list(parsed)
    if isinstance(parsed, str):
        return [parsed]
    # Numbers/dicts/etc. would break the len()/indexing done on the parsed columns
    return [str(text)]

# The level columns hold string representations of lists; parse each into a real list.
# assign() returns a new frame, leaving the raw `df` untouched.
df_parsed = df.assign(**{
    level_col: df[level_col].apply(safe_eval)
    for level_col in ('Elementary', 'Intermediate', 'Advanced')
})

print("Sample parsed data:")
for label, article in df_parsed.head(2).iterrows():
    elem, inter, adv = article['Elementary'], article['Intermediate'], article['Advanced']

    print(f"\n--- Article {label+1} ---")
    print(f"Elementary sentences: {len(elem)}")
    print(f"Intermediate sentences: {len(inter)}")
    print(f"Advanced sentences: {len(adv)}")

    if elem:
        print(f"\nFirst Elementary sentence: {elem[0][:100]}...")
    if adv:
        print(f"First Advanced sentence: {adv[0][:100]}...")

In [None]:
# Flatten articles into one row per sentence position: the i-th sentence of each level is
# paired with the i-th sentence of the others (positional alignment is assumed).
sentence_data = []

for idx, row in df_parsed.iterrows():
    levels = (row['Elementary'], row['Intermediate'], row['Advanced'])

    for i in range(max(map(len, levels))):
        elem_sent, inter_sent, adv_sent = (seq[i] if i < len(seq) else None for seq in levels)

        # Keep the position only when both the elementary and the advanced sentence exist
        if not (elem_sent and adv_sent):
            continue

        sentence_data.append({
            'article_id': idx,
            'sentence_id': i,
            'elementary': elem_sent,
            'intermediate': inter_sent,
            'advanced': adv_sent
        })

# One row per aligned sentence position
sentences_df = pd.DataFrame(sentence_data)
print(f"Total sentence pairs: {len(sentences_df)}")
sentences_df.head()

## 3. Text Complexity Analysis

In [None]:
import nltk
import ssl

NLTK_RESOURCES = ('punkt', 'punkt_tab', 'stopwords')


def download_nltk_resources(resources):
    """Download each NLTK resource and return the names whose download failed.

    nltk.download signals failure by returning False (it does not raise by default).
    """
    return [name for name in resources if not nltk.download(name, quiet=True)]


try:
    missing_resources = download_nltk_resources(NLTK_RESOURCES)
    unverified_context = getattr(ssl, '_create_unverified_context', None)
    if missing_resources and unverified_context is not None:
        # Some Python installs (e.g. python.org macOS builds) ship without a CA bundle.
        # Retry with certificate verification disabled ONLY for these downloads, then
        # restore it so later HTTPS traffic (e.g. Hugging Face model downloads) stays verified.
        default_context = ssl._create_default_https_context
        ssl._create_default_https_context = unverified_context
        try:
            missing_resources = download_nltk_resources(missing_resources)
        finally:
            ssl._create_default_https_context = default_context
    if missing_resources:
        print(f"Warning: NLTK download failed: {missing_resources}")
except Exception as e:
    print(f"Warning: NLTK download failed: {e}")


In [None]:
def simple_sentence_tokenize(text):
    """Split text into sentences on runs of '.', '!' or '?' without using NLTK.

    Returns the stripped, non-empty fragments; falsy input yields an empty list.
    """
    if not text:
        return []
    return [piece.strip() for piece in re.split(r'[.!?]+', text) if piece.strip()]

def simple_word_tokenize(text):
    """Return the lowercased runs of ASCII letters in text (no NLTK needed).

    Digits and punctuation are dropped, so "don't" becomes ['don', 't'].
    Falsy input yields an empty list.
    """
    return re.findall(r'\b[a-zA-Z]+\b', text.lower()) if text else []

def safe_textstat_score(text, metric_func):
    """Apply a textstat readability function, returning 0 when it cannot be computed.

    Args:
        text: text to score; None, NaN or strings shorter than 10 characters after
            stripping score 0 (some textstat functions fail on very short texts).
        metric_func: callable taking the text and returning a number.

    Returns:
        metric_func(text), or 0 if the text is unusable or metric_func raises.
        KeyboardInterrupt / SystemExit are no longer swallowed.
    """
    try:
        if not text or pd.isna(text) or len(text.strip()) < 10:
            return 0
        return metric_func(text)
    except Exception:
        return 0

def calculate_text_metrics(text):
    """Compute length and readability metrics for one piece of text.

    Args:
        text: the text to analyse (converted with str()); None, NaN or '' count as missing.

    Returns:
        dict with 'word_count', 'sentence_count', 'avg_word_length', 'flesch_score',
        'flesch_grade' and 'automated_readability'. Missing text gives all zeros, and each
        readability score is 0 when textstat cannot score the text (see safe_textstat_score).
    """
    if not text or pd.isna(text):
        return {
            'word_count': 0,
            'sentence_count': 0,
            'avg_word_length': 0,
            'flesch_score': 0,
            'flesch_grade': 0,
            'automated_readability': 0
        }

    text = str(text)

    # Regex tokenizers: no NLTK data required and they cannot raise on a str
    words = simple_word_tokenize(text)
    sentences = simple_sentence_tokenize(text)

    word_count = len(words)
    # Text with no sentence fragments (e.g. only punctuation) is counted as one sentence
    sentence_count = len(sentences) if sentences else 1
    avg_word_length = np.mean([len(word) for word in words]) if words else 0

    # safe_textstat_score already turns every textstat failure into 0, so the former
    # ImportError/Exception fallbacks here could never run and have been removed.
    flesch_score = safe_textstat_score(text, flesch_reading_ease)
    flesch_grade = safe_textstat_score(text, flesch_kincaid_grade)
    automated_readability = safe_textstat_score(text, automated_readability_index)

    return {
        'word_count': word_count,
        'sentence_count': sentence_count,
        'avg_word_length': round(avg_word_length, 2),
        'flesch_score': round(flesch_score, 2),
        'flesch_grade': round(flesch_grade, 2),
        'automated_readability': round(automated_readability, 2)
    }

# Calculate metrics for each complexity level
print("Calculating text complexity metrics...")
for level in ['elementary', 'intermediate', 'advanced']:
    print(f"Processing {level} level...")
    try:
        metrics = sentences_df[level].apply(calculate_text_metrics)

        # Keep sentences_df's index so each metrics row lines up with its sentence
        metrics_df = pd.DataFrame(metrics.tolist(), index=sentences_df.index)

        # Rows with no text at this level (intermediate sentences can be missing) would
        # otherwise add all-zero metrics to every mean and box plot; mark them missing.
        missing_text = sentences_df[level].isna()
        if missing_text.any():
            metrics_df = metrics_df.astype(float)
            metrics_df.loc[missing_text] = np.nan

        for col in metrics_df.columns:
            sentences_df[f'{level}_{col}'] = metrics_df[col]
        print(f"  ✓ {level} metrics calculated successfully")
    except Exception as e:
        print(f"  ✗ Error calculating {level} metrics: {e}")
        # Add empty columns if calculation fails
        for metric in ['word_count', 'sentence_count', 'avg_word_length', 'flesch_score', 'flesch_grade', 'automated_readability']:
            sentences_df[f'{level}_{metric}'] = 0

print("\nMetrics calculation complete!")
print(f"Dataset now has {sentences_df.shape[1]} columns")
sentences_df.head()

In [None]:
# Compare complexity metrics across levels
metrics_to_compare = ['word_count', 'avg_word_length', 'flesch_score', 'flesch_grade']

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.ravel()

for ax, metric in zip(axes, metrics_to_compare):
    present_levels = [
        level for level in ['elementary', 'intermediate', 'advanced']
        if f'{level}_{metric}' in sentences_df.columns
    ]
    if not present_levels:
        continue

    data_to_plot = [sentences_df[f'{level}_{metric}'].dropna() for level in present_levels]
    labels = [level.capitalize() for level in present_levels]
    metric_title = metric.replace("_", " ").title()

    ax.boxplot(data_to_plot)
    # Set tick labels explicitly: boxplot's `labels=` argument is deprecated since
    # Matplotlib 3.9 (renamed `tick_labels`); this form works on old and new versions.
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_title(f'{metric_title} by Complexity Level')
    ax.set_ylabel(metric_title)
    ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Summary statistics (NaN rows, e.g. missing intermediate text, are skipped by mean/std)
print("\nSummary Statistics by Complexity Level:")
for metric in metrics_to_compare:
    print(f"\n{metric.replace('_', ' ').title()}:")
    for level in ['elementary', 'intermediate', 'advanced']:
        col_name = f'{level}_{metric}'
        if col_name in sentences_df.columns:
            mean_val = sentences_df[col_name].mean()
            std_val = sentences_df[col_name].std()
            print(f"  {level.capitalize()}: {mean_val:.2f} (±{std_val:.2f})")

## 4. Text Simplification Examples

In [None]:
# Show examples of text simplification
def show_simplification_examples(df, n=5, random_state=42):
    """Print random examples of the same sentence at each complexity level.

    Args:
        df: sentence-level frame with 'elementary', 'intermediate' and 'advanced' text
            columns and, optionally, '<level>_word_count', '<level>_avg_word_length' and
            '<level>_flesch_score' metric columns.
        n: maximum number of examples to show (capped at the number of usable rows).
        random_state: seed for the row sample so the notebook shows the same examples on
            every run; pass None for a fresh random draw.
    """
    # Only rows where both ends of the simplification exist are useful examples
    complete_rows = df.dropna(subset=['elementary', 'advanced'])

    examples = complete_rows.sample(n=min(n, len(complete_rows)), random_state=random_state)

    for i, (idx, row) in enumerate(examples.iterrows()):
        print(f"\n{'='*60}")
        print(f"EXAMPLE {i+1}")
        print(f"{'='*60}")

        print("\n🔴 ADVANCED (Original):")
        print(f"{row['advanced']}")

        if pd.notna(row['intermediate']):
            print("\n🟡 INTERMEDIATE:")
            print(f"{row['intermediate']}")

        print("\n🟢 ELEMENTARY (Simplified):")
        print(f"{row['elementary']}")

        # Metrics are only printed when the metric columns have been computed
        print("\n📊 COMPLEXITY METRICS:")
        for level in ['advanced', 'elementary']:
            if f'{level}_word_count' in df.columns:
                wc = row[f'{level}_word_count']
                awl = row[f'{level}_avg_word_length']
                flesch = row[f'{level}_flesch_score']
                print(f"  {level.upper()}: {wc} words, {awl:.1f} avg word length, {flesch:.1f} Flesch score")

show_simplification_examples(df=sentences_df, n=3)

## 5. Vocabulary Analysis

In [None]:
def analyze_vocabulary(texts, level_name):
    """Count non-stop-word vocabulary in a collection of texts and print a summary.

    Args:
        texts: iterable of strings; missing values (None/NaN) are skipped.
        level_name: label used in the printed header.

    Returns:
        collections.Counter mapping lowercased alphabetic tokens (English stop words
        removed) to their frequencies.
    """
    stop_words = set(stopwords.words('english'))
    all_words = []

    for text in texts:
        if pd.notna(text):
            tokens = word_tokenize(text.lower())
            # Drop punctuation/numbers and stop words
            all_words.extend(tok for tok in tokens if tok.isalpha() and tok not in stop_words)

    word_freq = Counter(all_words)
    unique_words = len(word_freq)
    total_words = len(all_words)
    # A level with no usable words would otherwise raise ZeroDivisionError
    richness = unique_words / total_words if total_words else 0.0

    print(f"\n{level_name.upper()} VOCABULARY:")
    print(f"  Total words: {total_words:,}")
    print(f"  Unique words: {unique_words:,}")
    print(f"  Vocabulary richness: {richness:.3f}")
    print(f"  Top 10 words: {word_freq.most_common(10)}")

    return word_freq

# Vocabulary frequency Counter per level, analysed in order: elementary, intermediate, advanced
vocab_stats = {
    level: analyze_vocabulary(sentences_df[level], level)
    for level in ('elementary', 'intermediate', 'advanced')
}

In [None]:
# Compare vocabulary overlap between levels
def vocabulary_overlap(vocab1, vocab2, name1, name2):
    """Print and return how two vocabularies overlap.

    Args:
        vocab1, vocab2: mappings (e.g. Counter) whose keys are the words of each level.
        name1, name2: labels used in the printed report.

    Returns:
        (shared words, words only in vocab1, words only in vocab2) as sets.
    """
    words1 = set(vocab1.keys())
    words2 = set(vocab2.keys())

    overlap = words1 & words2
    only_in_1 = words1 - words2
    only_in_2 = words2 - words1
    union_size = len(words1 | words2)
    # Two empty vocabularies would otherwise raise ZeroDivisionError
    shared_ratio = len(overlap) / union_size if union_size else 0.0

    print(f"\nVOCABULARY OVERLAP: {name1.upper()} vs {name2.upper()}")
    print(f"  Shared words: {len(overlap)} ({shared_ratio:.1%})")
    print(f"  Only in {name1}: {len(only_in_1)}")
    print(f"  Only in {name2}: {len(only_in_2)}")

    return overlap, only_in_1, only_in_2

def _top_by_frequency(words, freq, k=20):
    """Return up to k of `words`, most frequent in `freq` first, ties broken alphabetically.

    Sorting makes the printed sample reproducible; iterating a set of strings gives a
    different order on each interpreter run because of string hash randomization.
    """
    return sorted(words, key=lambda w: (-freq[w], w))[:k]


# Compare elementary vs advanced
if 'elementary' in vocab_stats and 'advanced' in vocab_stats:
    overlap_ea, only_elem, only_adv = vocabulary_overlap(
        vocab_stats['elementary'],
        vocab_stats['advanced'],
        'elementary',
        'advanced'
    )

    print(f"\nWords unique to ADVANCED (sample): {_top_by_frequency(only_adv, vocab_stats['advanced'])}")
    print(f"Words unique to ELEMENTARY (sample): {_top_by_frequency(only_elem, vocab_stats['elementary'])}")

## 6. Dataset Preparation for Machine Learning

In [None]:
# Prepare (source, target) pairs for three simplification directions:
# advanced→elementary, advanced→intermediate and intermediate→elementary.
def _task_pairs(source_col, target_col):
    """Return non-null (source, target) rows of sentences_df tagged '<source>_to_<target>'."""
    pairs = sentences_df[[source_col, target_col]].dropna()
    pairs.columns = ['source', 'target']
    pairs['task'] = f'{source_col}_to_{target_col}'
    return pairs


adv_to_elem = _task_pairs('advanced', 'elementary')
adv_to_inter = _task_pairs('advanced', 'intermediate')
inter_to_elem = _task_pairs('intermediate', 'elementary')

# Stack all tasks into one frame with a fresh index
ml_dataset = pd.concat([adv_to_elem, adv_to_inter, inter_to_elem], ignore_index=True)

print("Machine Learning Dataset Summary:")
print(f"Total pairs: {len(ml_dataset)}")
print("\nBy task:")
print(ml_dataset['task'].value_counts())

# Persist the combined pairs
ml_dataset.to_csv('text_simplification_ml_dataset.csv', index=False)
print("\nDataset saved as 'text_simplification_ml_dataset.csv'")

ml_dataset.head()

In [None]:
# Create train/validation/test splits
def create_splits(df, test_size=0.2, val_size=0.1, random_state=42):
    """Split df into stratified (by 'task') train, validation and test frames.

    Both test_size and val_size are fractions of the whole frame; val_size is rescaled to
    the train+val remainder before the second split.

    Returns:
        (train, val, test) DataFrames.
    """
    train_val, test = train_test_split(
        df, test_size=test_size, random_state=random_state, stratify=df['task']
    )

    relative_val_size = val_size / (1 - test_size)
    train, val = train_test_split(
        train_val, test_size=relative_val_size, random_state=random_state,
        stratify=train_val['task']
    )
    return train, val, test

# Create splits for each task (roughly 70/10/20 train/val/test) and save them as CSV
splits = {}

for task in ml_dataset['task'].unique():
    task_data = ml_dataset[ml_dataset['task'] == task].copy()

    # Tasks with fewer than 10 pairs are not split and no files are written for them
    if len(task_data) < 10:
        continue

    train_data, test_data = train_test_split(task_data, test_size=0.2, random_state=42)

    if len(train_data) >= 10:
        # 0.125 of the 80% training share is 10% of the task's pairs
        train_data, val_data = train_test_split(train_data, test_size=0.125, random_state=42)
    else:
        val_data = train_data.sample(n=min(2, len(train_data) // 2), random_state=42)
        train_data = train_data.drop(val_data.index)

    splits[task] = {'train': train_data, 'val': val_data, 'test': test_data}

    print(f"\n{task.upper()} SPLITS:")
    for label, part in (("Train", train_data), ("Validation", val_data), ("Test", test_data)):
        print(f"  {label}: {len(part)} samples")

    for suffix, part in (("train", train_data), ("val", val_data), ("test", test_data)):
        part.to_csv(f'{task}_{suffix}.csv', index=False)

print("\nDataset splits saved successfully!")

In [None]:
# Final dataset summary: counts at each stage, then per-level averages
print("📊 FINAL DATASET SUMMARY")
print("=" * 50)
for caption, count in (
    ("📄 Total articles", len(df_parsed)),
    ("📝 Total sentence pairs", len(sentences_df)),
    ("🤖 ML dataset pairs", len(ml_dataset)),
):
    print(f"{caption}: {count}")

print("\n📈 COMPLEXITY ANALYSIS:")
for level in ('elementary', 'intermediate', 'advanced'):
    flesch_col = f'{level}_flesch_score'
    if flesch_col not in sentences_df.columns:
        continue
    avg_flesch = sentences_df[flesch_col].mean()
    avg_words = sentences_df[f'{level}_word_count'].mean()
    print(f"  {level.capitalize()}: {avg_flesch:.1f} Flesch score, {avg_words:.1f} avg words")



In [None]:

# Filter the pairs down to ones worth training on
original_count = len(ml_dataset)
print(f"Original dataset: {original_count} pairs")

ml_dataset_clean = (
    ml_dataset
    .dropna(subset=['source', 'target'])
    # Drop empty / very short texts
    .loc[lambda d: (d['source'].str.len() > 10) & (d['target'].str.len() > 10)]
    # Drop identical source-target pairs (nothing to learn, wastes training time)
    .loc[lambda d: d['source'] != d['target']]
    # Drop extremely long texts (memory pressure during tokenization/training)
    .loc[lambda d: (d['source'].str.len() < 1000) & (d['target'].str.len() < 1000)]
    # Own the data so later column assignments do not act on a slice of ml_dataset
    .copy()
)

print(f"Filtered dataset: {len(ml_dataset_clean)} pairs")
if original_count:
    print(f"Kept: {len(ml_dataset_clean)/original_count:.1%} of original data")

In [None]:


# T5 expects: "simplify: [complex text]" → "[simple text]"
# assign() returns a new frame, so this is safe even if ml_dataset_clean is a slice
ml_dataset_clean = ml_dataset_clean.assign(
    input_text='simplify: ' + ml_dataset_clean['source'],
    target_text=ml_dataset_clean['target'],
)

print("✅ T5 format created")
print("Input format: 'simplify: [complex text]'")
print("Target format: '[simplified text]'")

# Show examples; number them by position (the index labels have gaps after filtering)
print("\n📋 TRAINING FORMAT EXAMPLES:")
for example_no, (_, row) in enumerate(ml_dataset_clean.head(3).iterrows(), start=1):
    print(f"\nExample {example_no}:")
    print(f"Input:  {row['input_text'][:100]}...")
    print(f"Target: {row['target_text'][:100]}...")

In [None]:


# Create output directory
os.makedirs('training_ready', exist_ok=True)

# Split each task separately (80/10/10) so every task appears in train, val and test
all_train, all_val, all_test = [], [], []

for task in ml_dataset_clean['task'].unique():
    task_data = ml_dataset_clean[ml_dataset_clean['task'] == task].copy()
    print(f"{task}: {len(task_data)} pairs")

    if len(task_data) >= 10:
        train, temp = train_test_split(task_data, test_size=0.2, random_state=42)
        val, test = train_test_split(temp, test_size=0.5, random_state=42)

        all_train.append(train)
        all_val.append(val)
        all_test.append(test)
    else:
        # Too few samples to split: everything goes to training
        all_train.append(task_data)

# Combine all tasks (all_val and all_test are filled together, so both are empty or neither)
train_df = pd.concat(all_train, ignore_index=True)
if all_val:
    val_df = pd.concat(all_val, ignore_index=True)
    test_df = pd.concat(all_test, ignore_index=True)
else:
    # No task was large enough to split. Hold out two disjoint seeded samples and remove
    # them from training so no validation/test row is also trained on.
    n_holdout = min(50, len(train_df) // 10)
    holdout = train_df.sample(n=2 * n_holdout, random_state=42)
    val_df = holdout.iloc[:n_holdout].reset_index(drop=True)
    test_df = holdout.iloc[n_holdout:].reset_index(drop=True)
    train_df = train_df.drop(holdout.index).reset_index(drop=True)

print("\n📊 FINAL SPLITS:")
print(f"Train: {len(train_df)} pairs")
print(f"Val: {len(val_df)} pairs")
print(f"Test: {len(test_df)} pairs")

# Save as JSONL (one {"input_text", "target_text"} object per line)
train_df[['input_text', 'target_text']].to_json('training_ready/train.jsonl', orient='records', lines=True)
val_df[['input_text', 'target_text']].to_json('training_ready/val.jsonl', orient='records', lines=True)
test_df[['input_text', 'target_text']].to_json('training_ready/test.jsonl', orient='records', lines=True)

# Also save as CSV for inspection
train_df.to_csv('training_ready/train.csv', index=False)
val_df.to_csv('training_ready/val.csv', index=False)
test_df.to_csv('training_ready/test.csv', index=False)

print("\n✅ TRAINING FILES SAVED!")
print("📁 Location: training_ready/")
print("📄 Files: train.jsonl, val.jsonl, test.jsonl")
print("🚀 READY TO TRAIN T5 MODEL!")

print("\n🎯 NEXT STEPS:")
print("1. Install: pip install transformers datasets torch")
print("2. Use training_ready/train.jsonl for training")
print("3. Use training_ready/val.jsonl for validation")
print("4. Start with t5-small model")


In [None]:
@dataclass
class ModelConfig:
    """Paths and hyperparameters for fine-tuning T5 on text simplification.

    Field order defines the positional constructor signature; append new fields at the end.
    """

    # --- model and tokenization ---
    model_name: str = "t5-small"
    max_input_length: int = 512    # source tokens kept (longer inputs are truncated)
    max_target_length: int = 256   # target tokens kept (longer targets are truncated)
    output_dir: str = "t5-simplification-model"

    # --- optimization schedule ---
    batch_size: int = 2            # kept small for stability
    num_epochs: int = 3
    learning_rate: float = 5e-5
    warmup_steps: int = 500
    eval_steps: int = 200
    save_steps: int = 200

class RobustSimplificationDataset(Dataset):
    """Map-style dataset of (input_text, target_text) pairs read from a JSONL file.

    Every example is tokenized with padding='max_length' and truncation, so all items share
    the same shapes. Padded positions in `labels` are set to -100, the ignore index used by
    Hugging Face seq2seq models when computing the loss, so padding is not trained on.
    """

    def __init__(self, data_file: str, tokenizer, config: "ModelConfig"):
        """
        Args:
            data_file: path to a JSONL file with 'input_text' and 'target_text' per line.
            tokenizer: callable Hugging Face-style tokenizer returning input_ids and
                attention_mask tensors when called with return_tensors='pt'.
            config: object providing max_input_length and max_target_length.
        """
        self.tokenizer = tokenizer
        self.config = config
        self.data = self._load_and_validate_data(data_file)

    def _encode(self, input_text: str, target_text: str):
        """Tokenize one pair into fixed-length 1-D tensors (batch dimension removed)."""
        input_encoding = self.tokenizer(
            input_text,
            max_length=self.config.max_input_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        target_encoding = self.tokenizer(
            target_text,
            max_length=self.config.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        labels = target_encoding.input_ids.squeeze(0)
        # Padding must not contribute to the loss: replace padded label positions with -100
        labels = labels.masked_fill(target_encoding.attention_mask.squeeze(0) == 0, -100)

        return {
            'input_ids': input_encoding.input_ids.squeeze(0),
            'attention_mask': input_encoding.attention_mask.squeeze(0),
            'labels': labels
        }

    def _load_and_validate_data(self, data_file: str):
        """Read the JSONL file and keep only examples that encode to the expected shapes.

        Skips blank lines, invalid JSON, records without 'input_text'/'target_text',
        inputs shorter than 10 characters, targets shorter than 5 characters, and texts the
        tokenizer fails on.
        """
        valid_data = []

        with open(data_file, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if not line.strip():
                    continue
                try:
                    item = json.loads(line)
                except ValueError as e:
                    print(f"⚠️ Skipping line {i}: {e}")
                    continue

                if not (isinstance(item, dict) and 'input_text' in item and 'target_text' in item):
                    continue

                input_text = str(item['input_text']).strip()
                target_text = str(item['target_text']).strip()

                if len(input_text) < 10 or len(target_text) < 5:
                    continue

                # Encode exactly as __getitem__ will, so any example kept here is usable
                try:
                    encoded = self._encode(input_text, target_text)
                except Exception as token_error:
                    print(f"⚠️ Tokenization failed for sample {i}: {token_error}")
                    continue

                if (encoded['input_ids'].shape == (self.config.max_input_length,) and
                        encoded['labels'].shape == (self.config.max_target_length,)):
                    valid_data.append({
                        'input_text': input_text,
                        'target_text': target_text
                    })

        print(f"📊 Loaded {len(valid_data)} validated examples from {data_file}")
        return valid_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenized example at idx as fixed-length tensors."""
        item = self.data[idx]
        encoded = self._encode(item['input_text'], item['target_text'])

        # Final shape validation
        assert encoded['input_ids'].shape == torch.Size([self.config.max_input_length])
        assert encoded['attention_mask'].shape == torch.Size([self.config.max_input_length])
        assert encoded['labels'].shape == torch.Size([self.config.max_target_length])

        return encoded

# Initialize configuration first: the tokenizer/model and dataset cells below read `config`.
# Datasets are created in a later cell, after the tokenizer has been loaded; building them
# here would reference `tokenizer` and `config` before either exists.
config = ModelConfig()
print("⚙️ Training Configuration (Fixed):")
print(f"   Model: {config.model_name}")
print(f"   Batch size: {config.batch_size} (reduced for stability)")
print(f"   Epochs: {config.num_epochs}")
print(f"   Learning rate: {config.learning_rate}")
print(f"   Output directory: {config.output_dir}")

In [None]:
def compute_metrics(eval_pred, tokenizer) -> Dict[str, float]:
    """Compute mean ROUGE-1/2/L F1 between decoded predictions and references.

    Args:
        eval_pred: (predictions, labels) pair from the Trainer. `predictions` may be token
            ids, 3-D logits (vocabulary on the last axis), or a tuple whose first element
            is either of those.
        tokenizer: tokenizer used to decode token ids back to text.

    Returns:
        {'rouge1', 'rouge2', 'rougeL'}: F-measures averaged over the examples.
    """
    predictions, labels = eval_pred

    # The Trainer can hand back a tuple of model outputs; the first entry holds predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    # 3-D predictions are logits: take the most likely token per position. This is a
    # teacher-forced approximation, not free-running generation.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # -100 marks ignored/padded positions; decode them as padding
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    rouge1_scores = []
    rouge2_scores = []
    rougeL_scores = []

    for pred, label in zip(decoded_preds, decoded_labels):
        scores = scorer.score(label, pred)
        rouge1_scores.append(scores['rouge1'].fmeasure)
        rouge2_scores.append(scores['rouge2'].fmeasure)
        rougeL_scores.append(scores['rougeL'].fmeasure)

    return {
        'rouge1': np.mean(rouge1_scores),
        'rouge2': np.mean(rouge2_scores),
        'rougeL': np.mean(rougeL_scores)
    }

print("✅ Evaluation metrics function ready!")


In [None]:
# Load the pretrained tokenizer and model named in config, and place the model on `device`
print("📥 Loading T5 tokenizer and model...")

tokenizer = T5Tokenizer.from_pretrained(config.model_name)
model = T5ForConditionalGeneration.from_pretrained(config.model_name).to(device)

num_params = sum(param.numel() for param in model.parameters())
print(f"✅ Model loaded: {config.model_name}")
print(f"   Parameters: {num_params:,} ({num_params/1e6:.1f}M)")
print(f"   Device: {next(model.parameters()).device}")

# Round-trip a sample prompt through the tokenizer as a sanity check
test_input = "simplify: This is a complex sentence that needs simplification."
test_tokens = tokenizer(test_input, return_tensors="pt")
round_trip = tokenizer.decode(test_tokens['input_ids'][0], skip_special_tokens=True)
print("\n🔤 Tokenizer test:")
print(f"   Input: {test_input}")
print(f"   Tokens: {test_tokens['input_ids'].shape}")
print(f"   Decoded: {round_trip}")


In [None]:
# Create datasets (needs `tokenizer` and `config`, so it runs after the model-loading cell)
print("📊 Creating training datasets...")

train_dataset = RobustSimplificationDataset("training_ready/train.jsonl", tokenizer, config)
val_dataset = RobustSimplificationDataset("training_ready/val.jsonl", tokenizer, config)
test_dataset = RobustSimplificationDataset("training_ready/test.jsonl", tokenizer, config)

print(f"\n📈 Dataset sizes:")
print(f"   Train: {len(train_dataset)} examples")
print(f"   Val: {len(val_dataset)} examples")
print(f"   Test: {len(test_dataset)} examples")

# Fail with a clear message instead of an IndexError when nothing survived validation
if len(train_dataset) == 0:
    raise ValueError("No valid training examples in training_ready/train.jsonl; re-run the export cell above.")

# Test dataset loading
sample = train_dataset[0]
print(f"\n🔍 Sample data shapes:")
print(f"   Input IDs: {sample['input_ids'].shape}")
print(f"   Attention mask: {sample['attention_mask'].shape}")
print(f"   Labels: {sample['labels'].shape}")


In [None]:
# Ultra-simple data collator that handles shapes explicitly
class UltraSimpleCollator:
    """Stack already-padded features into a batch and validate the batch shapes.

    No padding happens here: every feature must already have fixed-length 'input_ids',
    'attention_mask' and 'labels' tensors.
    """

    def __init__(self, tokenizer, config=None):
        """
        Args:
            tokenizer: kept for parity with Hugging Face collators; not used for batching.
            config: object with max_input_length / max_target_length. When None, the
                notebook-level `config` is looked up at call time (original behaviour).
        """
        self.tokenizer = tokenizer
        self.config = config

    def _expected_lengths(self):
        """Return (max_input_length, max_target_length) from the explicit or global config."""
        cfg = self.config if self.config is not None else config
        return cfg.max_input_length, cfg.max_target_length

    def __call__(self, features):
        """Handle batching with explicit shape checking."""
        input_ids_list = [feature['input_ids'] for feature in features]
        attention_mask_list = [feature['attention_mask'] for feature in features]
        labels_list = [feature['labels'] for feature in features]

        try:
            input_ids = torch.stack(input_ids_list, dim=0)
            attention_mask = torch.stack(attention_mask_list, dim=0)
            labels = torch.stack(labels_list, dim=0)

            batch_size = len(features)
            max_input_length, max_target_length = self._expected_lengths()
            expected_input_shape = (batch_size, max_input_length)
            expected_labels_shape = (batch_size, max_target_length)

            assert input_ids.shape == expected_input_shape, f"Input shape mismatch: {input_ids.shape} vs {expected_input_shape}"
            assert attention_mask.shape == expected_input_shape, f"Attention mask shape mismatch: {attention_mask.shape} vs {expected_input_shape}"
            assert labels.shape == expected_labels_shape, f"Labels shape mismatch: {labels.shape} vs {expected_labels_shape}"

            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels
            }

        except Exception as e:
            # Dump per-feature shapes to make the offending example easy to find, then re-raise
            print(f" Collation error: {e}")
            print(f"Feature shapes:")
            for i, feature in enumerate(features):
                print(f"  Feature {i}: input={feature['input_ids'].shape}, mask={feature['attention_mask'].shape}, labels={feature['labels'].shape}")
            raise

# Batch features with the explicit-shape collator defined above
data_collator = UltraSimpleCollator(tokenizer=tokenizer)
print(" Ultra-simple data collator created")


In [None]:
# Create the Trainer. Training arguments, the metric adapter and logits reduction are all
# defined here so this cell does not depend on variables from deleted cells.
def build_training_args(with_eval: bool) -> TrainingArguments:
    """Create TrainingArguments from the notebook `config`.

    with_eval=True evaluates and saves every config.eval_steps / config.save_steps steps and
    keeps the best checkpoint by validation loss (EarlyStoppingCallback requires this).
    with_eval=False disables evaluation and saves once per epoch (fallback run).
    Uses `eval_strategy`, the name in transformers >= 4.41 (formerly `evaluation_strategy`).
    """
    shared = dict(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps,
        logging_steps=50,
        save_total_limit=2,
        report_to="none",
    )
    if not with_eval:
        return TrainingArguments(eval_strategy="no", save_strategy="epoch", **shared)
    return TrainingArguments(
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        save_strategy="steps",
        # load_best_model_at_end requires save_steps to be a multiple of eval_steps
        save_steps=config.save_steps,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        **shared,
    )


def compute_metrics_bound(eval_pred):
    """Adapter: the Trainer calls compute_metrics with only the EvalPrediction."""
    return compute_metrics(eval_pred, tokenizer)


def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted token ids before the Trainer accumulates them.

    Accumulating full logits (one score per vocabulary entry per position) for the whole
    validation set can exhaust memory; compute_metrics only needs token ids.
    """
    # For T5 the Trainer passes a tuple of outputs; the first element is the LM logits
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


training_args = build_training_args(with_eval=True)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_bound,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

# Resume from the most recent checkpoint in output_dir, if any
checkpoints = glob.glob(f"{config.output_dir}/checkpoint-*")

print("🚀 Starting training with all fixes...")

start_time = time.time()

try:
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('-')[-1]))
        print(f"Resuming from: {latest_checkpoint}")
        training_result = trainer.train(resume_from_checkpoint=latest_checkpoint)
    else:
        print("Starting fresh training")
        training_result = trainer.train()

    training_time = time.time() - start_time
    print(f"\n🎉 Training completed successfully!")
    print(f"Training time: {training_time/60:.1f} minutes")
    print(f"Final loss: {training_result.training_loss:.4f}")

    # Save the model
    trainer.save_model()
    tokenizer.save_pretrained(config.output_dir)
    print(f" Model saved to: {config.output_dir}")
except Exception as e:
    print(f" Training failed again: {e}")

    # Fallback: train without evaluation. Build fresh arguments rather than mutating the
    # existing ones, whose load_best_model_at_end=True assumes evaluation is running.
    print(" Trying training without evaluation...")

    trainer_no_eval = Trainer(
        model=model,
        args=build_training_args(with_eval=False),
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    try:
        training_result = trainer_no_eval.train()
        print("✅ Training without evaluation successful!")
        trainer_no_eval.save_model()
    except Exception as final_error:
        print(f"❌ Final attempt failed: {final_error}")

In [None]:
# Start training
# NOTE(review): this repeats the run from the previous cell. trainer.train() is called
# again without resume_from_checkpoint, so the optimizer and LR schedule start over while
# the weights are whatever `model` currently holds — re-running continues fine-tuning.
print("🚀 Starting T5 text simplification training...")
print("=" * 50)

start_time = time.time()

try:
    training_result = trainer.train()

    training_time = time.time() - start_time

    print("\n" + "=" * 50)
    print("🎉 Training completed successfully!")
    print(f"⏱️ Training time: {training_time/60:.1f} minutes")
    print(f"📊 Final training loss: {training_result.training_loss:.4f}")

except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    print("💾 Saving current state...")
    trainer.save_model()

except Exception as e:
    print(f"\n❌ Training failed with error: {e}")
    print("💾 Attempting to save model state...")
    try:
        trainer.save_model()
        print("✅ Model state saved")
    except Exception as save_error:
        # Narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate
        print("❌ Could not save model state")
        print(f"   Reason: {save_error}")
