### Download the dataset, run a quick EDA, and select a subset of the data

In [None]:
import matplotlib.pyplot as plt
from datasets import load_dataset
import numpy as np # For calculating counts efficiently if needed
import random # To set a seed for reproducibility if needed

# --- 1. Load the Dataset ---
# Load MS MARCO v1.1. `ds` is indexed by split name below (`ds[split_name]`,
# `ds.keys()`). `load_dataset` is already imported at the top of this cell.
ds = load_dataset("microsoft/ms_marco", "v1.1")

# --- 2. Select Data Split ---
# We'll use the 'train' split for this example. You can change this to 'validation' or 'test' if needed.
split_name = 'train'
sample_size = 5000
random_seed = 42  # Seed for the shuffle below, so every run draws the same sample

if split_name not in ds:
    print(f"Error: Split '{split_name}' not found in the dataset. Available splits: {list(ds.keys())}")
else:
    print(f"Using dataset split: {split_name}")
    original_data_split = ds[split_name]
    num_total_queries = len(original_data_split)
    print(f"Total queries in '{split_name}' split: {num_total_queries}")

    if sample_size > num_total_queries:
        print(f"Warning: Sample size ({sample_size}) is larger than the number of queries in the split ({num_total_queries}). Using the entire split.")
        sampled_data_split = original_data_split
        actual_sample_size = num_total_queries
    else:
        # --- 3. Sample the Data ---
        print(f"Randomly sampling {sample_size} queries (seed={random_seed})...")
        # Seeded shuffle, then keep the first N rows -> a reproducible random sample.
        sampled_data_split = original_data_split.shuffle(seed=random_seed).select(range(sample_size))
        actual_sample_size = sample_size
        print(f"Selected {len(sampled_data_split)} queries for analysis.")

    # --- 4. Calculate Passage Counts (on the sample) ---
    print("Calculating passage counts for the sample...")
    try:
        # Read only the 'passages' column rather than decoding every field of
        # every row. Each entry is a dict whose 'passage_text' value is a list
        # with one item per passage, so its length is the passage count.
        passage_counts = [len(passages['passage_text'])
                          for passages in sampled_data_split['passages']]
    except (KeyError, TypeError) as e:
        # Only data-layout problems are handled here. Any other failure
        # (e.g. in plotting) propagates with a full traceback instead of
        # being reduced to a one-line message.
        print(f"Error accessing data structure: {e}. Please check the dataset format.")
    else:
        print(f"Calculated counts for {len(passage_counts)} sampled queries.")

        # --- 5. Plot the Histogram ---
        if not passage_counts:
            print("No queries in the sample; skipping histogram.")
        else:
            max_count = max(passage_counts, default=0)
            # Bin edges at -0.5, 0.5, ..., max_count + 0.5 so each bar is
            # centred on an integer passage count.
            bins = np.arange(max_count + 2) - 0.5

            plt.figure(figsize=(12, 6))  # Wide figure for readable x-axis labels
            plt.hist(passage_counts, bins=bins, edgecolor='black')
            plt.xlabel("Number of Passages Associated with a Query")
            plt.ylabel("Number of Queries (Frequency in Sample)")
            plt.title(f"Distribution of Passage Counts per Query (Sample of {actual_sample_size} from {split_name} split)")
            plt.xticks(np.arange(max_count + 1))  # One tick per integer count
            plt.grid(axis='y', alpha=0.75)
            plt.show()