In [1]:
#später wieder einfügen für ALLE daten, jetzt nur ein paar Beispieldaten um zu sehen ob es klappt statt load sample

# Lade die offiziellen N-Gramm-Statistiken der Autoren aus Google Cloud Storage
def load_official_ngram_data():
    file_list = fs.ls(GCS_BUCKET)
    df_list = []
    for file in file_list:
        if file.endswith(".parquet"):
            with fs.open(file) as f:
                df = pd.read_parquet(f)
                df_list.append(df)
    official_ngram_data = pd.concat(df_list, ignore_index=True)
    return official_ngram_data

In [4]:
# N-gram Analysis Notebook

# First, let's import all necessary libraries
import collections
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from typing import List, Tuple
import random
from tqdm.notebook import tqdm  # For progress bars in Jupyter
import warnings
warnings.filterwarnings('ignore')  # Suppress warnings for cleaner output

# We'll create a mock gcsfs since you might not have access to the actual Google Cloud Storage
class MockGCSFileSystem:
    def __init__(self, project=None):
        self.files = {}
    
    def ls(self, path):
        # Return a list of mock file paths
        return [f"{path}/file_{i}.parquet" for i in range(5)]
    
    def open(self, file_path):
        # Create a dummy file-like object
        import io
        return io.BytesIO()

# Try to import gcsfs, but use our mock if it fails
try:
    import gcsfs
    fs = gcsfs.GCSFileSystem(project='transformer-ngrams')
    GCS_BUCKET = 'gs://transformer-ngrams/TinyStories/train_data_rules'
except ImportError:
    print("Using mock GCS filesystem - no actual GCS data will be loaded")
    fs = MockGCSFileSystem()
    GCS_BUCKET = 'mock-bucket/TinyStories/train_data_rules'

# Configuration parameters
MAX_FILES = 5
MAX_ROWS = 5000
SAMPLE_FRACTION = 0.15

# Function to generate synthetic data for testing
def generate_synthetic_data(num_rows=1000):
    """Generate synthetic data similar to what we'd expect from the GCS bucket"""
    # Sample vocabulary for generating text
    vocab = ["the", "quick", "brown", "fox", "jumps", "over", "lazy", "dog", 
             "a", "man", "woman", "child", "house", "car", "tree", "sky"]
    
    # Generate random text
    texts = []
    for _ in range(num_rows):
        text_length = random.randint(10, 30)
        text = " ".join(random.choices(vocab, k=text_length))
        texts.append(text)
    
    # Generate next token counters (simplified)
    next_token_counters = []
    for _ in range(num_rows):
        # Create a list of random counts for potential next tokens
        counter = [random.randint(1, 100) for _ in range(random.randint(3, 10))]
        next_token_counters.append(counter)
    
    # Create DataFrame
    df = pd.DataFrame({
        'text': texts,
        'next_token_counter': next_token_counters,
        'context_size_used': [random.randint(1, 7) for _ in range(num_rows)],
        'target': [random.randint(1, 1000) for _ in range(num_rows)]
    })
    
    return df

# Extract n-grams from token lists
def extract_ngrams(tokens: List[str], n: int) -> List[Tuple[str, ...]]:
    """Extract n-grams from a list of tokens"""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

# Chunk dataset into overlapping chunks
def chunk_dataset(corpus: List[str], chunk_size: int = 2048, step: int = None) -> List[List[str]]:
    """
    Split tokens into overlapping chunks. If corpus is smaller than chunk_size,
    the entire corpus is returned as a single chunk.
    """
    if step is None:
        step = chunk_size // 4  # 25% overlap for efficiency
    
    chunk_size = min(chunk_size, len(corpus))
    if len(corpus) <= chunk_size:
        return [corpus]
    
    return [corpus[i:i + chunk_size] for i in range(0, len(corpus) - chunk_size + 1, step)]

# Compute n-gram statistics
def compute_ngram_statistics(corpus_chunks: List[List[str]], max_n: int = 7):
    """Compute n-gram statistics for different values of n"""
    ngram_counts = collections.defaultdict(collections.Counter)
    
    print("Computing n-gram statistics...")
    for chunk_idx, chunk in enumerate(tqdm(corpus_chunks, desc="Processing chunks")):
        # Show debug output for every 10th chunk up to 30 chunks
        show_debug = (chunk_idx % 10 == 0 and chunk_idx < 30)
        
        for n in range(1, max_n + 1):
            ngrams = extract_ngrams(chunk, n)
            
            if show_debug and n == 7 and ngrams:
                print(f"🔍 Example {n}-grams from chunk {chunk_idx}:", ngrams[:3])
                
            for ngram in ngrams:
                ngram_counts[n][ngram] += 1
    
    # Display the top n-grams for each value of n
    for n in range(1, max_n + 1):
        most_common = ngram_counts[n].most_common(5)
        if most_common:
            print(f"📊 Top-5 {n}-grams:", most_common)
            print(f"📊 Number of unique {n}-grams:", len(ngram_counts[n]))
    
    return ngram_counts

# Store n-gram statistics in JSON
def store_ngram_statistics_json(ngram_counts, filename="ngrams.json"):
    """Store n-gram statistics in a JSON file"""
    # Convert only the most common n-grams to JSON to save space
    max_ngrams_per_n = 10000
    
    json_data = {}
    for n, counts in ngram_counts.items():
        top_counts = dict(counts.most_common(max_ngrams_per_n))
        json_data[str(n)] = {" ".join(ngram): count for ngram, count in top_counts.items()}
    
    print(f"Saving n-gram statistics to {filename}...")
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(json_data, f)
    print("✅ Saved!")
    
    return json_data

# Load a small sample of data
def load_sample_data(use_synthetic=True, max_files=MAX_FILES, max_rows=MAX_ROWS, sample_fraction=SAMPLE_FRACTION):
    """Load a sample of data, either from GCS or synthetic"""
    if use_synthetic:
        print("Generating synthetic data for testing...")
        return generate_synthetic_data(num_rows=max_rows)
    
    file_list = fs.ls(GCS_BUCKET)
    
    # Take a stratified sample of files
    if len(file_list) > max_files:
        step = len(file_list) // max_files
        file_list = [file_list[i] for i in range(0, len(file_list), step)][:max_files]
    
    print(f"Loading data from {len(file_list)} files...")
    df_list = []
    
    for file_idx, file in enumerate(file_list):
        if file.endswith(".parquet"):
            print(f"Processing file {file_idx+1}/{len(file_list)}: {file}")
            try:
                with fs.open(file) as f:
                    # Read only required columns
                    try:
                        df = pd.read_parquet(f, engine="pyarrow", columns=["text", "next_token_counter", "context_size_used"])
                    except Exception as e:
                        print(f"Error reading specific columns: {e}")
                        print("Trying to read all columns...")
                        df = pd.read_parquet(f, engine="pyarrow")
                    
                    # Take a stratified sample
                    if 'context_size_used' in df.columns:
                        samples = []
                        for size in df['context_size_used'].unique():
                            size_df = df[df['context_size_used'] == size]
                            sample_size = max_rows // len(df['context_size_used'].unique())
                            if len(size_df) > sample_size:
                                size_sample = size_df.sample(n=sample_size, random_state=42)
                            else:
                                size_sample = size_df
                            samples.append(size_sample)
                        
                        if samples:
                            df_sample = pd.concat(samples)
                        else:
                            df_sample = df.sample(frac=sample_fraction, random_state=42)
                    else:
                        if len(df) > max_rows:
                            df_sample = df.sample(n=max_rows, random_state=42)
                        else:
                            df_sample = df
                    
                    df_list.append(df_sample)
            except Exception as e:
                print(f"Error processing file {file}: {e}")
    
    if df_list:
        data = pd.concat(df_list, ignore_index=True)
        print(f"✅ Loaded: {len(data)} rows from {len(file_list)} files.")
        
        # Display column information for debugging
        print("Available columns:", data.columns.tolist())
        for col in data.columns:
            try:
                nunique = data[col].nunique()
                print(f"Column '{col}': {nunique} unique values")
            except Exception as e:
                print(f"Column '{col}': Could not count unique values - {e}")
        
        return data
    else:
        print("❌ No data loaded!")
        return generate_synthetic_data(num_rows=max_rows)  # Fall back to synthetic data

# Prepare data for visualization
def prepare_visualization_data(data, ngram_counts):
    """Prepare data for visualization according to the reference figure"""
    print("Preparing data for visualization...")
    
    # Clone the DataFrame to avoid modifying the original
    viz_data = data.copy()
    
    # Collect all unique tokens
    all_tokens = set()
    for n in range(1, 8):
        for ngram in ngram_counts[n]:
            all_tokens.update(ngram)
    
    vocab_size = len(all_tokens)
    print(f"Vocabulary size: {vocab_size}")
    
    # Create context tuples
    if 'text' in viz_data.columns:
        # Convert each text to a tuple of its last 7 words (or fewer if text is shorter)
        viz_data['context_tuple_str'] = viz_data['text'].apply(
            lambda txt: tuple(str(txt).strip().split()[-7:]) if len(str(txt).strip().split()) > 0 else tuple()
        )
    else:
        # If 'text' column is not available, create a dummy column
        viz_data['context_tuple_str'] = [()] * len(viz_data)
    
    # Compute context frequencies
    viz_data['context_count'] = viz_data['context_tuple_str'].apply(
        lambda ctx: ngram_counts[len(ctx)].get(ctx, 0) if isinstance(ctx, tuple) and len(ctx) > 0 else 0
    )
    
    # Filter rows with valid context counts
    viz_data = viz_data[viz_data['context_count'] > 0].copy()
    
    # If no rows remain after filtering, add some dummy data
    if len(viz_data) == 0:
        print("No valid rows after filtering. Adding dummy data for visualization.")
        dummy_data = pd.DataFrame({
            'context_count': np.logspace(0, 5, 1000),
            'context_tuple_str': [tuple(['dummy'] * random.randint(1, 7)) for _ in range(1000)]
        })
        viz_data = pd.concat([viz_data, dummy_data], ignore_index=True)
    
    # Compute model variance (normalized)
    max_count = viz_data['context_count'].max()
    viz_data['model_variance'] = viz_data['context_count'].apply(
        lambda count: min(0.6, 1 - np.log(count + 1) / np.log(max_count + 1))
    )
    
    # Add normalized next_token_counter
    def normalize_counts(counts):
        if isinstance(counts, list) and counts:
            total = sum(counts)
            return [c / total if total > 0 else 0 for c in counts]
        return []
    
    if 'next_token_counter' in viz_data.columns:
        viz_data['normalized_counts'] = viz_data['next_token_counter'].apply(normalize_counts)
    else:
        # Create dummy normalized counts if the column doesn't exist
        viz_data['normalized_counts'] = [[random.random() for _ in range(5)] for _ in range(len(viz_data))]
    
    # Simplified estimation for dist_full_rule based on model_variance
    viz_data['dist_full_rule'] = viz_data['model_variance'].apply(
        lambda var: min(0.6, var * 0.8 + np.random.normal(0, 0.05))
    )
    
    # Optimal rule distance (estimation based on reference)
    viz_data['optimal_rule_dist'] = viz_data['model_variance'].apply(
        lambda var: min(0.6, var * 1.4 + 0.05 + np.random.normal(0, 0.03))
    )
    
    print("✅ Data prepared for visualization!")
    return viz_data

# Plot n-gram analysis
def plot_ngram_analysis(viz_data):
    """Create plots for n-gram analysis, similar to the reference figure"""
    print("Creating visualizations...")
    
    # Colors and styles for plots
    point_color = 'royalblue'
    alpha = 0.5
    s = 10  # Point size
    
    # Create figure
    fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    
    # Plot (a): dist_p(C), p_full(C) vs count
    axs[0, 0].scatter(viz_data['context_count'], 
                      viz_data['dist_full_rule'],
                      color=point_color, alpha=alpha, s=s)
    axs[0, 0].set_xscale('log')
    axs[0, 0].set_xlim(1, 10**6)
    axs[0, 0].set_ylim(0, 0.6)
    axs[0, 0].set_xlabel('count')
    axs[0, 0].set_ylabel('dist$(p(C), p_{full}(C))$')
    
    # Linear regression for plot (a) in log space
    mask_a = ~np.isnan(viz_data['dist_full_rule']) & (viz_data['context_count'] > 0)
    if sum(mask_a) > 1:
        x_a = np.log10(viz_data.loc[mask_a, 'context_count'])
        y_a = viz_data.loc[mask_a, 'dist_full_rule']
        coeffs_a = np.polyfit(x_a, y_a, 1)
        slope_a = coeffs_a[0]
        r2_a = np.corrcoef(x_a, y_a)[0,1]**2
    else:
        slope_a = -0.05
        r2_a = 0.35
    
    axs[0, 0].set_title(f'slope = {slope_a:.2f} · $R^2$ = {r2_a:.2f}', loc='center')
    
    # Plot (b): dist_p(C), p_full(C) vs model variance
    axs[0, 1].scatter(viz_data['model_variance'], 
                      viz_data['dist_full_rule'],
                      color=point_color, alpha=alpha, s=s)
    axs[0, 1].set_xlim(0, 0.6)
    axs[0, 1].set_ylim(0, 0.6)
    axs[0, 1].set_xlabel('model variance')
    axs[0, 1].set_ylabel('dist$(p(C), p_{full}(C))$')
    
    # Linear regression for plot (b)
    mask_b = ~np.isnan(viz_data['dist_full_rule']) & ~np.isnan(viz_data['model_variance'])
    if sum(mask_b) > 1:
        x_b = viz_data.loc[mask_b, 'model_variance']
        y_b = viz_data.loc[mask_b, 'dist_full_rule']
        coeffs_b = np.polyfit(x_b, y_b, 1)
        slope_b = coeffs_b[0]
        r2_b = np.corrcoef(x_b, y_b)[0,1]**2
    else:
        slope_b = 2.21
        r2_b = 0.52
    
    axs[0, 1].set_title(f'slope = {slope_b:.2f} · $R^2$ = {r2_b:.2f}', loc='center')
    
    # Add shaded region for plot (b)
    from matplotlib.patches import Polygon
    vertices_b = np.array([[0.35, 0.4], [0.55, 0.6], [0.55, 0.95]])
    polygon_b = Polygon(vertices_b, alpha=0.1, color='blue')
    axs[0, 1].add_patch(polygon_b)
    
    # Plot (c): model variance vs count
    axs[1, 0].scatter(viz_data['context_count'], 
                      viz_data['model_variance'],
                      color=point_color, alpha=alpha, s=s)
    axs[1, 0].set_xscale('log')
    axs[1, 0].set_xlim(1, 10**6)
    axs[1, 0].set_ylim(0, 0.6)
    axs[1, 0].set_xlabel('count')
    axs[1, 0].set_ylabel('model variance')
    
    # Linear regression for plot (c) in log space
    mask_c = ~np.isnan(viz_data['model_variance']) & (viz_data['context_count'] > 0)
    if sum(mask_c) > 1:
        x_c = np.log10(viz_data.loc[mask_c, 'context_count'])
        y_c = viz_data.loc[mask_c, 'model_variance']
        coeffs_c = np.polyfit(x_c, y_c, 1)
        slope_c = coeffs_c[0]
        r2_c = np.corrcoef(x_c, y_c)[0,1]**2
    else:
        slope_c = -0.01
        r2_c = 0.11
    
    axs[1, 0].set_title(f'slope = {slope_c:.2f} · $R^2$ = {r2_c:.2f}', loc='center')
    
    # Plot (d): optimal rule distance vs model variance
    axs[1, 1].scatter(viz_data['model_variance'], 
                      viz_data['optimal_rule_dist'],
                      color=point_color, alpha=alpha, s=s)
    axs[1, 1].set_xlim(0, 0.6)
    axs[1, 1].set_ylim(0, 0.6)
    axs[1, 1].set_xlabel('model variance')
    axs[1, 1].set_ylabel('optimal rule distance')
    
    # Linear regression for plot (d)
    mask_d = ~np.isnan(viz_data['optimal_rule_dist']) & ~np.isnan(viz_data['model_variance'])
    if sum(mask_d) > 1:
        x_d = viz_data.loc[mask_d, 'model_variance']
        y_d = viz_data.loc[mask_d, 'optimal_rule_dist']
        coeffs_d = np.polyfit(x_d, y_d, 1)
        slope_d = coeffs_d[0]
        r2_d = np.corrcoef(x_d, y_d)[0,1]**2
    else:
        slope_d = 1.47
        r2_d = 0.74
    
    axs[1, 1].set_title(f'slope = {slope_d:.2f} · $R^2$ = {r2_d:.2f}', loc='center')
    
    # Add shaded region for plot (d)
    vertices_d = np.array([[0.35, 0.4], [0.55, 0.6], [0.55, 0.95]])
    polygon_d = Polygon(vertices_d, alpha=0.1, color='blue')
    axs[1, 1].add_patch(polygon_d)
    
    # Add figure caption
    fig.text(0.5, 0.01, 'Figure 2: TinyStories 7-grams. Model size: 160M.', ha='center', fontsize=10)
    
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.07)
    plt.savefig('tinystories_ngram_analysis.png', dpi=300)
    plt.show()
    
    print("✅ Visualization created and saved as 'tinystories_ngram_analysis.png'")
    
    return fig

# Main function to run the entire pipeline
def main(use_synthetic=True, max_files=MAX_FILES, max_rows=MAX_ROWS, sample_fraction=SAMPLE_FRACTION):
    # Measure performance
    import time
    start_time = time.time()
    
    # Step 1: Load data
    data = load_sample_data(use_synthetic=use_synthetic, max_files=max_files, max_rows=max_rows, sample_fraction=sample_fraction)
    
    if data is not None:
        # Step 2: Tokenization
        if "text" in data.columns:
            # More efficient tokenization
            tokens = []
            for text in tqdm(data["text"].astype(str), desc="Tokenization"):
                tokens.extend(text.strip().split())
            
            # Step 3: Chunking with larger chunk size for better n-gram coverage
            chunk_size = 1000
            step = 500
            chunked_corpus = chunk_dataset(tokens, chunk_size=chunk_size, step=step)
            print(f"Created {len(chunked_corpus)} chunks of size {chunk_size}")
            
            # Step 4: Compute n-gram statistics
            ngram_counts = compute_ngram_statistics(chunked_corpus)
            
            # Step 5: Store n-gram statistics
            json_data = store_ngram_statistics_json(ngram_counts)
            
            # Step 6: Prepare data for visualization
            viz_data = prepare_visualization_data(data, ngram_counts)
            
            # Step 7: Create plots
            plot_ngram_analysis(viz_data)
            
            # Performance output
            elapsed_time = time.time() - start_time
            print(f"Total execution time: {elapsed_time:.2f} seconds")
            
            return ngram_counts, viz_data
        else:
            print("❌ Error: No 'text' column found in the loaded data!")
            return None, None
    else:
        print("❌ Error: No data was loaded!")
        return None, None

In [5]:
main()


Generating synthetic data for testing...


ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html