Related libraries

In [None]:
import os
import pandas as pd
import shutil
import re
import nltk
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer
from datasets import load_dataset
from typing import Tuple, Dict, List, Optional

# Download the required NLTK resources (quietly) before importing the
# corpus/stemmer helpers that depend on them
nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords
nltk.download('wordnet', quiet=True)
from nltk.stem import WordNetLemmatizer

# Configure matplotlib so that Chinese text renders correctly
plt.rcParams['font.sans-serif'] = ['SimHei']  # font used to display Chinese labels
plt.rcParams['axes.unicode_minus'] = False  # display the minus sign correctly

def clear_huggingface_cache():
    """
    Delete the default HuggingFace datasets cache directory if it exists.

    Only ``~/.cache/huggingface/datasets`` (expanded for the current user) is
    considered. Any exception raised while checking or deleting the directory
    is reported as a printed warning instead of being propagated.
    """
    try:
        target = os.path.expanduser("~/.cache/huggingface/datasets")
        if not os.path.exists(target):
            # Nothing cached yet, so there is nothing to remove
            return
        shutil.rmtree(target)
        print(f"HuggingFace cache directory cleared: {target}")
    except Exception as err:
        print(f"Warning clearing cache: {err}")

def clean_text(text: str) -> str:
    """
    Normalize raw text into lowercase, letters-only, single-spaced text.

    Steps, in order: replace HTML-like tags with a space, drop every character
    that is neither an ASCII letter nor whitespace, lowercase the result and
    collapse runs of whitespace.

    Args:
        text: Raw text. Missing values (``None``/NaN) are accepted; other
            non-string values are converted with ``str()``.

    Returns:
        str: The cleaned text, or "" when the input is missing.
    """
    if not isinstance(text, str):
        # Only non-string inputs can be missing values
        if pd.isna(text):
            return ""
        text = str(text)

    # Replace tags with a space rather than nothing, so that words separated
    # only by markup (e.g. "Great<br/>food") are not fused into one token;
    # the extra whitespace is collapsed below.
    text = re.sub(r'<.*?>', ' ', text)
    # Remove special characters and digits
    text = re.sub(r'[^a-zA-Z\s]', '', text)
    # Lowercase, then collapse whitespace (this also trims both ends)
    return ' '.join(text.lower().split())

# Lazily created, module-wide helpers so that per-text calls do not rebuild them
_STOP_WORDS_CACHE = None
_LEMMATIZER_CACHE = None


def _english_stop_words():
    """Return the English stopword set, building it only on first use."""
    global _STOP_WORDS_CACHE
    if _STOP_WORDS_CACHE is None:
        _STOP_WORDS_CACHE = frozenset(stopwords.words('english'))
    return _STOP_WORDS_CACHE


def _shared_lemmatizer():
    """Return a single WordNetLemmatizer instance, created only on first use."""
    global _LEMMATIZER_CACHE
    if _LEMMATIZER_CACHE is None:
        _LEMMATIZER_CACHE = WordNetLemmatizer()
    return _LEMMATIZER_CACHE


def preprocess_text(text: str, remove_stopwords: bool = True, lemmatize: bool = True) -> str:
    """
    Clean a text and optionally remove stopwords and lemmatize its words.

    Args:
        text: Original text
        remove_stopwords: Whether to remove English stopwords
        lemmatize: Whether to lemmatize the remaining words

    Returns:
        str: Preprocessed text ("" when nothing is left after basic cleaning)
    """
    # Basic cleaning
    text = clean_text(text)

    if not text:
        return ""

    # Tokenize on whitespace
    words = text.split()

    # Remove stopwords (the set is shared across calls instead of being
    # rebuilt for every text)
    if remove_stopwords:
        stop_words = _english_stop_words()
        words = [word for word in words if word not in stop_words]

    # Lemmatize (the lemmatizer instance is shared across calls)
    if lemmatize:
        lemmatizer = _shared_lemmatizer()
        words = [lemmatizer.lemmatize(word) for word in words]

    return ' '.join(words)

def get_tokenizer(model_name='bert-base-uncased'):
    """
    Load the BERT tokenizer for a pre-trained checkpoint.

    Args:
        model_name: Name handed to ``BertTokenizer.from_pretrained``

    Returns:
        The tokenizer object produced by ``BertTokenizer.from_pretrained``
    """
    tokenizer = BertTokenizer.from_pretrained(model_name)
    return tokenizer

def load_yelp_dataset():
    """
    Load the ``yelp_review_full`` dataset from Hugging Face as DataFrames.

    The local HuggingFace datasets cache is cleared first, then the ``train``
    and ``test`` splits are each converted to a DataFrame with ``text`` and
    ``label`` columns. Labels are shifted by one so that 0-4 becomes 1-5 stars.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: (train_df, test_df)

    Raises:
        Exception: Any error raised while loading or converting the dataset is
            printed and then re-raised unchanged.
    """
    print("=== Yelp Dataset Loader ===")
    print("Loading yelp_review_full from HuggingFace...")
    clear_huggingface_cache()

    try:
        raw = load_dataset("yelp_review_full", cache_dir=None)

        frames = []
        for split_name in ('train', 'test'):
            split = raw[split_name]
            frame = pd.DataFrame({'text': split['text'], 'label': split['label']})
            # Convert labels from 0-4 to 1-5 stars
            frame['label'] = frame['label'] + 1
            frames.append(frame)
        train_df, test_df = frames

        print(f"Loaded {len(train_df)} train samples and {len(test_df)} test samples")
        distribution = train_df['label'].value_counts().sort_index().to_dict()
        print(f"Train label distribution: {distribution}")

        return train_df, test_df
    except Exception as err:
        print(f"Failed to load from HuggingFace: {err}")
        raise

def _preprocess_text_column(frame, transform, batch_size, split_name):
    """Return the transformed ``text`` column of *frame*, optionally computed in batches."""
    total = len(frame)
    if not (batch_size and batch_size > 0) or total == 0:
        # Single pass. Also used for an empty frame so the column keeps its
        # dtype: assigning an empty list would not yield a string column and
        # the later ``.str`` access would fail.
        return frame['text'].apply(transform)

    cleaned = []
    for start in range(0, total, batch_size):
        stop = start + batch_size
        chunk = frame['text'].iloc[start:stop].apply(transform)
        cleaned.extend(chunk.tolist())
        print(f"  Processed {min(stop, total)}/{total} {split_name} samples")
    return cleaned


def preprocess_data(train_df: pd.DataFrame, test_df: pd.DataFrame, 
                   remove_stopwords: bool = True, lemmatize: bool = True, 
                   batch_size: Optional[int] = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Preprocess the ``text`` column of the train and test DataFrames.

    Each text is passed through ``preprocess_text``; a ``text_length`` column
    (character count of the processed text) is added, and rows whose processed
    text is empty are dropped. The inputs are copied, not modified.

    Args:
        train_df: Training DataFrame with a ``text`` column
        test_df: Test DataFrame with a ``text`` column
        remove_stopwords: Whether to remove stopwords
        lemmatize: Whether to lemmatize words
        batch_size: Number of rows processed per batch, with progress printed
            after each batch (None or a non-positive value processes each
            DataFrame in one pass)

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: Preprocessed train and test
        DataFrames with a fresh 0..n-1 index
    """
    print("Preprocessing text data...")

    def transform(raw_text):
        return preprocess_text(raw_text, remove_stopwords, lemmatize)

    if batch_size and batch_size > 0:
        print(f"Processing in batches of {batch_size}...")

    processed = []
    for split_name, source in (('train', train_df), ('test', test_df)):
        # Work on a copy so the caller's DataFrame is left untouched
        frame = source.copy()
        frame['text'] = _preprocess_text_column(frame, transform, batch_size, split_name)
        processed.append(frame)

    results = []
    for frame in processed:
        # Add the text-length feature, then drop rows whose text became empty
        frame['text_length'] = frame['text'].str.len()
        results.append(frame[frame['text'].str.len() > 0].reset_index(drop=True))

    train_out, test_out = results
    print(f"Preprocessed: {len(train_out)} train, {len(test_out)} test samples")
    return train_out, test_out

def analyze_data(df: pd.DataFrame, data_name: str = "Dataset") -> Dict:
    """
    Print and return summary statistics for a labelled text DataFrame.

    Args:
        df: DataFrame with ``text`` and ``label`` columns. A ``text_length``
            column is used when present; otherwise character lengths are
            computed from ``text``.
        data_name: Name of the dataset, printed in the report header

    Returns:
        Dict: ``total_samples``, ``label_counts`` (label -> count),
        ``text_length`` (min/max/avg/median) and ``vocabulary``
        (total_words/avg_words_per_sample/unique_words)
    """
    print(f"\n=== {data_name} Analysis ===")

    # Basic statistics
    total_samples = len(df)
    print(f"Total samples: {total_samples}")

    # Label distribution (the loop body never runs for an empty dataset, so
    # the division by total_samples is safe)
    label_counts = df['label'].value_counts().sort_index()
    print(f"\nLabel distribution:")
    for label, count in label_counts.items():
        percentage = (count / total_samples) * 100
        print(f"  {label} stars: {count} ({percentage:.1f}%)")

    # Text length statistics; fall back to computing lengths when the
    # precomputed feature column is absent
    if 'text_length' in df.columns:
        lengths = df['text_length']
    else:
        lengths = df['text'].str.len()
    min_length = lengths.min()
    max_length = lengths.max()
    avg_length = lengths.mean()
    median_length = lengths.median()

    print(f"\nText length statistics:")
    print(f"  Minimum: {min_length} characters")
    print(f"  Maximum: {max_length} characters")
    print(f"  Average: {avg_length:.2f} characters")
    print(f"  Median: {median_length} characters")

    # Vocabulary statistics: tokenize every text exactly once, accumulating
    # the word total and the distinct-word set in the same pass
    total_words = 0
    unique_words = set()
    for text in df['text']:
        if not isinstance(text, str):
            continue  # missing values contribute no words
        tokens = text.split()
        total_words += len(tokens)
        unique_words.update(tokens)

    avg_words = total_words / total_samples if total_samples else 0.0
    unique_words_count = len(unique_words)

    print(f"\nVocabulary statistics:")
    print(f"  Total words: {total_words}")
    print(f"  Average words per sample: {avg_words:.2f}")
    print(f"  Unique words: {unique_words_count}")

    # Return the analysis results
    return {
        'total_samples': total_samples,
        'label_counts': label_counts.to_dict(),
        'text_length': {
            'min': min_length,
            'max': max_length,
            'avg': avg_length,
            'median': median_length
        },
        'vocabulary': {
            'total_words': total_words,
            'avg_words_per_sample': avg_words,
            'unique_words': unique_words_count
        }
    }

def _finalize_plot(title, xlabel, ylabel, file_path):
    """Title and label the current figure, save it to *file_path*, then close it."""
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.savefig(file_path, dpi=100)
    plt.close()


def visualize_data(df: pd.DataFrame, output_dir: str = "analysis_results", sample_size: int = 100000) -> None:
    """
    Save three PNG charts describing the dataset into ``output_dir``.

    The charts are the label distribution (computed on the full DataFrame),
    the text-length histogram and a text-length-per-label box plot (both
    computed on at most ``sample_size`` rows).

    Args:
        df: DataFrame with ``label`` and ``text_length`` columns
        output_dir: Directory to save visualizations (created if missing)
        sample_size: Maximum number of samples to use for visualization
    """
    os.makedirs(output_dir, exist_ok=True)

    # Down-sample large datasets (fixed seed) to reduce memory use
    plot_df = df
    if len(df) > sample_size:
        plot_df = df.sample(n=sample_size, random_state=42)
        print(f"Using sample of {sample_size} out of {len(df)} samples for visualization")

    # 1. Label distribution - uses the full data so the counts are exact
    plt.figure(figsize=(8, 6))
    counts_per_label = df['label'].value_counts().sort_index()
    bar_colors = sns.color_palette('viridis', len(counts_per_label))
    counts_per_label.plot(kind='bar', color=bar_colors)
    _finalize_plot(
        'Label Distribution (Star Ratings)',
        'Star Rating',
        'Count',
        os.path.join(output_dir, 'label_distribution.png'),
    )

    # 2. Text length distribution
    plt.figure(figsize=(10, 6))
    sns.histplot(plot_df['text_length'], bins=50, kde=True, color='blue')
    _finalize_plot(
        'Text Length Distribution',
        'Text Length (characters)',
        'Count',
        os.path.join(output_dir, 'text_length_distribution.png'),
    )

    # 3. Text length versus label
    plt.figure(figsize=(10, 6))
    sns.boxplot(x='label', y='text_length', data=plot_df, palette='viridis')
    # Cap the y-axis at the 95th percentile so outliers do not hurt readability
    upper_limit = plot_df['text_length'].quantile(0.95)
    plt.ylim(0, upper_limit)
    _finalize_plot(
        'Text Length vs Star Rating',
        'Star Rating',
        'Text Length (characters)',
        os.path.join(output_dir, 'text_length_vs_rating.png'),
    )

    print(f"Visualizations saved to {output_dir}")

def test_preprocessing():
    """
    Run the full pipeline once: load, preprocess, analyze and visualize.

    Any exception is caught, reported with its traceback and not re-raised;
    the closing banner is printed in either case.
    """
    print("=== Testing Yelp Dataset Loader ===")

    try:
        # Load dataset
        raw_train, raw_test = load_yelp_dataset()
        print(f"✓ Dataset loaded successfully: {len(raw_train)} train, {len(raw_test)} test samples")
        print(f"  Columns: {list(raw_train.columns)}")

        # Advanced preprocessing, processed in batches to reduce memory use
        clean_train, clean_test = preprocess_data(
            raw_train,
            raw_test,
            remove_stopwords=True,
            lemmatize=True,
            batch_size=50000,
        )
        print(f"✓ Data preprocessed successfully: {len(clean_train)} train, {len(clean_test)} test samples")

        # Analyze data
        print("\n=== Analyzing Preprocessed Data ===")
        for frame, title in ((clean_train, "Training Set"), (clean_test, "Test Set")):
            analyze_data(frame, title)

        # Visualize data; sampling keeps memory use down
        print("\n=== Generating Visualizations ===")
        stacked = pd.concat([clean_train, clean_test], keys=['train', 'test'])
        # Move the train/test key out of the index into a column named 'set'
        combined_df = stacked.reset_index(level=0).rename(columns={'level_0': 'set'})
        visualize_data(combined_df, sample_size=100000)

        print(f"\n✓ All tests completed successfully!")

    except Exception as err:
        print(f"✗ Test failed: {err}")
        import traceback
        traceback.print_exc()

    print("\n=== Testing Complete ===")

# Run the end-to-end pipeline check only when this file is executed as a script
if __name__ == "__main__":
    test_preprocessing()