<a href="https://colab.research.google.com/github/berkyalcinkaya/cs145-project2-systems/blob/main/cs145_project2_systems_template_fa2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab (Main)"/></a>

<a href="https://colab.research.google.com/github/berkyalcinkaya/cs145-project2-systems/blob/berk/cs145_project2_systems_template_fa2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab (berk)"/></a>

## Collaborators

1.   Berk Yalcinkaya
2.   Nick Allen


# Setup

In [1]:
import pandas as pd
import os
import uuid
import argparse
import time
import psutil
import heapq
import pyarrow as pa
import pyarrow.parquet as pq
import random
import string
import numpy as np
from typing import List, Optional, Callable, Dict, Union, Any, Tuple
import shutil
import glob
import gc
from IPython.display import display
import tempfile
from pathlib import Path
from functools import partial
import memory_profiler
import math

In [3]:
def clear_parquet_files():
    for file in glob.glob("*.parquet"):
        os.remove(file)
    return

clear_parquet_files()


# Section 0: Generate Test Data

This section has already been implemented for you.

In [2]:
def generate_songs_chunk(start, size, string_length=100):
    data = {
        "song_id": range(start, start + size),
        "title": [f"Song_{i}" for i in range(start, start + size)],
    }
    base_strings = generate_base_strings(size, string_length)
    for i in range(1, 11):
        data[f"extra_col_{i}"] = np.roll(base_strings, shift=i)
    return pd.DataFrame(data)


def generate_users_chunk(start, size, string_length=100):
    data = {
        "user_id": range(start, start + size),
        "age": [18 + ((start + i) % 60) for i in range(size)],
    }
    base_strings = generate_base_strings(size, string_length)
    for i in range(1, 11):
        data[f"extra_col_{i}"] = np.roll(base_strings, shift=i)
    return pd.DataFrame(data)


def generate_listens_chunk(start, size, num_users, num_songs, string_length=16):
    data = {
        "listen_id": range(start, start + size),
        "user_id": np.random.randint(0, num_users, size=size),
        "song_id": np.random.randint(0, num_songs, size=size),
    }
    base_strings = generate_base_strings(size, string_length)
    for i in range(1, 11):
        data[f"extra_col_{i}"] = np.roll(base_strings, shift=i)
    return pd.DataFrame(data)


def generate_base_strings(num_records, string_length):
    chars = np.array(list("ab"))
    random_indices = np.random.randint(0, len(chars), size=(num_records, string_length))
    char_array = chars[random_indices]
    return np.array(list(map("".join, char_array)))


def _write_parquet_streamed(
    filename,
    total_rows,
    make_chunk_fn,
    chunk_size=250_000,
    compression="snappy",
):
    """
    Stream DataFrame chunks to a single Parquet file with one ParquetWriter.
    - schema_df: optional small DataFrame to lock schema; if None we'll infer from the first chunk.
    """
    written = 0

    first_chunk = make_chunk_fn(0, min(chunk_size, total_rows))
    first_table = pa.Table.from_pandas(first_chunk, preserve_index=False)
    writer = pq.ParquetWriter(filename, first_table.schema, compression=compression)
    writer.write_table(first_table)

    written += len(first_chunk)
    del first_chunk
    gc.collect()

    while written < total_rows:
        take = min(chunk_size, total_rows - written)
        chunk_df = make_chunk_fn(written, take)
        writer.write_table(pa.Table.from_pandas(chunk_df, preserve_index=False))
        written += take
        del chunk_df
        gc.collect()

    writer.close()


def generate_test_data(target_size="100MB"):
    """
    Generate datasets with proper foreign key relationships.

    Target COMPRESSED Parquet file sizes on disk:
    100MB total compressed:
        - Songs: 10K rows → ~5MB (5% of total)
        - Users: 50K rows → ~20MB (20% of total)
        - Listens: 1M rows → ~75MB (75% of total)
    1GB total compressed:
        - Songs: 100K rows → ~50MB (5% of total)
        - Users: 500K rows → ~200MB (20% of total)
        - Listens: 10M rows → ~750MB (75% of total)

    Each table needs:
        - Primary key column(s)
        - 10 additional string columns of k characters each
        - For Users: add 'age' column (random 18-80)

    CRITICAL: Listens table must have valid foreign keys!
    Every song_id must exist in Songs
    Every user_id must exist in Users
    """

    assert target_size in ["100MB", "1GB", "10GB"]
    if target_size == "100MB":
        num_songs = 10_000
        num_users = 50_000
        num_listens = 1_000_000

        songs_chunk = 10_000
        users_chunk = 50_000
        listens_chunk = 1_000_000
    elif target_size == "1GB":
        num_songs = 100_000
        num_users = 500_000
        num_listens = 10_000_000

        songs_chunk = 10_000
        users_chunk = 50_000
        listens_chunk = 1_000_000
    else: 
        num_songs = 1_000_000
        num_users = 5_000_000
        num_listens = 100_000_000

        songs_chunk = 10_000
        users_chunk = 50_000
        listens_chunk = 1_000_000

    print("Writing Songs")
    _write_parquet_streamed(
        filename=f"songs_{target_size}.parquet",
        total_rows=num_songs,
        make_chunk_fn=lambda start, size: generate_songs_chunk(start, size),
        chunk_size=songs_chunk,
    )

    print("Writing Users")
    _write_parquet_streamed(
        filename=f"users_{target_size}.parquet",
        total_rows=num_users,
        make_chunk_fn=lambda start, size: generate_users_chunk(start, size),
        chunk_size=users_chunk,
    )

    print("Writing Listens")
    _write_parquet_streamed(
        filename=f"listens_{target_size}.parquet",
        total_rows=num_listens,
        make_chunk_fn=lambda start, size: generate_listens_chunk(
            start, size, num_users, num_songs
        ),
        chunk_size=listens_chunk,
    )

    print("Done!")

# Section 0b: Define Memory and Performance Benchmarking Functions
- Memory will be monitored using the memory_profiler function: %%memit above a cell monitors memory usage of entire cell, %memit monitors the memory usage of a single line
- CPU performance will be measured with a custom decorator defined below


In [3]:
%load_ext memory_profiler

In [4]:
def timer(func):
    """
    Decorator to measure and print the execution time of a function.

    Usage:
        @timer
        def my_function(...):
            ...

    When the decorated function is called, it will print the elapsed time in seconds with a descriptive message.

    Returns:
        The result of the wrapped function, after printing its runtime.
    """
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"Method '{func.__name__}' took {end_time - start_time:.4f} seconds.")
        return result
    return wrapper

In [6]:
random.seed(0)
if not os.path.exists("listens_100MB.parquet"):
    generate_test_data("100MB")
else:
    print("100MB data already generated")
if not os.path.exists("listens_1GB.parquet"):
    generate_test_data('1GB')
else:
    print("1GB data already generated")
if not os.path.exists("listens_10GB.parquet"):
    generate_test_data('10GB')
else:
    print("10GB data already generated")

100MB data already generated
1GB data already generated
10GB data already generated


# Section 1: Parquet-based Columnar Storage

Implement Parquet-based storage for the tables
- For simplicity, store all data for a table in a single Parquet file and use a single DataFrame object as a buffer

In [7]:
# see ed: https://edstem.org/us/courses/87394/discussion/7251811 for advice on writing to a parquet without loading existing into RAM
# a ColumnarDbFile is actually a directory with an arbitrary number of parquet files inside
# Append writes a new file with the next postfix
# Retrieve reads all parquet files and concatenates them together, done natively by pandas
class ColumnarDbFile:
    def __init__(self, table_name, file_dir='data', file_pfx=''):
        self.file_pfx = file_pfx
        self.table_name = table_name
        self.file_dir = file_dir
        #os.makedirs(self.file_dir, exist_ok=True)
        self.base_file_name = f"{self.file_dir}/{self.file_pfx}_{self.table_name}"
        os.makedirs(self.base_file_name, exist_ok=True)
        
        # Streaming state
        self._streaming = False
        self._stream_writer = None
        self._stream_file_path = None

    def build_table(self, data):
        """Build and save table data to Parquet."""
        assert self._get_num_parquets() == 0
        target_path = f"{self.base_file_name}/{self.table_name}-0.parquet"
        # If data is a string and is a valid file path, copy it
        if isinstance(data, str) and os.path.isfile(data):
            shutil.copy(data, target_path)
        elif isinstance(data, pd.DataFrame):
            data.to_parquet(target_path)
        else:
            raise ValueError("data must be a pandas DataFrame or a valid file path string")
        return

    def retrieve_data(self, columns=None, sample=None):
        """Create pd.DataFrame by reading from Parquet"""
        if sample is not None:
            return next(self.iter_pages(sample, columns=columns, as_pandas=True))
        else:
            return pd.read_parquet(self.base_file_name, columns=columns)

    def append_data(self, data):
        """Append new data to Parquet
        
        Behavior depends on streaming mode:
        - If streaming (start_stream() called): writes to a single parquet file via ParquetWriter
        - Otherwise: creates a new parquet file for each call
        """
        if self._streaming:
            # Convert DataFrame to PyArrow Table
            table = pa.Table.from_pandas(data, preserve_index=False)
            
            # Lazy writer creation: create on first append with schema
            if self._stream_writer is None:
                self._stream_writer = pq.ParquetWriter(self._stream_file_path, table.schema)
            
            # Write to stream
            self._stream_writer.write_table(table)
        else:
            # Original behavior: create new file
            data.to_parquet(self.get_new_parquet_file())
        return

    def get_new_parquet_file(self):
        '''return a path to a new file with name uniqueness'''
        return f"{self.base_file_name}/{self.table_name}-{self._get_num_parquets()}.parquet"

    def _get_num_parquets(self):
        return len(self.get_all_parquet_paths())

    def get_all_parquet_paths(self):
        return glob.glob(f"{self.base_file_name}/*.parquet")
    
    def start_stream(self):
        """Start streaming mode for efficient batch writes.
        
        After calling this, append_data() will write to a single parquet file
        using ParquetWriter (streaming) instead of creating separate files.
        Must call stop_stream() when done to properly close the writer.
        
        If called multiple times, closes any existing writer and starts a new stream.
        
        Can also be used as a context manager:
            with output_db:
                output_db.append_data(df1)
                output_db.append_data(df2)
            # Automatically stops streaming
        """
        # Close existing writer if streaming was already active
        if self._streaming and self._stream_writer is not None:
            self._stream_writer.close()
        
        # Initialize streaming state
        self._streaming = True
        self._stream_file_path = self.get_new_parquet_file()
        self._stream_writer = None  # Will be created lazily on first append_data()
    
    def stop_stream(self):
        """Stop streaming mode and close the ParquetWriter.
        
        Safe to call multiple times or if streaming was never started.
        """
        if self._stream_writer is not None:
            self._stream_writer.close()
            self._stream_writer = None
        
        self._streaming = False
        self._stream_file_path = None
    
    def __enter__(self):
        """Context manager entry: start streaming mode."""
        self.start_stream()
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: stop streaming mode."""
        self.stop_stream()
        return False  # Don't suppress exceptions
    
    def __del__(self):
        """Destructor: ensure stream is closed if not explicitly stopped."""
        # Safety net: close writer if streaming was left open
        if self._streaming and self._stream_writer is not None:
            try:
                self._stream_writer.close()
            except:
                pass  # Ignore errors during cleanup

    def table_metadata(self):
        """Return total rows and total byte size of the table without loading data."""
        parquet_files = glob.glob(f"{self.base_file_name}/*.parquet")

        total_rows = 0
        total_bytes = 0

        for file in parquet_files:
            pf = pq.ParquetFile(file)
            meta = pf.metadata

            total_rows += meta.num_rows
            total_bytes += meta.serialized_size  # includes footer + metadata

        return {
            "num_files": len(parquet_files),
            "total_rows": total_rows,
            "total_compressed_bytes": total_bytes,
        }

    def table_disk_usage(self):
        parquet_files = glob.glob(f"{self.base_file_name}/*.parquet")

        total_bytes = sum(os.path.getsize(f) for f in parquet_files)

        return {
            "num_files": len(parquet_files),
            "total_bytes": total_bytes
        }

    def iter_pages(self, rows_per_batch: int = 100_000, columns=None, as_pandas=True):
        for path in self.get_all_parquet_paths():        
            pf = pq.ParquetFile(path)
            for batch in pf.iter_batches(batch_size=rows_per_batch, columns=columns):
                yield batch.to_pandas() if as_pandas else batch
    
    def __eq__(self, other):
        # load both tables into memory
        df1 = pd.read_parquet(self.base_file_name)
        df2 = pd.read_parquet(other.base_file_name)

        # check for row-wise equality
        return df1.equals(df2)

    @staticmethod
    def fits_in_12GB(bytes_needed: int) -> bool:
        TWELVE_GB = 12 * 1024**3
        return bytes_needed <= TWELVE_GB

    @staticmethod
    def can_process_parquet(bytes_on_disk: int, compression_factor: int = 5) -> bool:
        """
        Returns True if a Parquet dataset of `bytes_on_disk` can be processed
        within 12 GB of RAM, after accounting for decompression expansion.
        """
        estimated_ram = bytes_on_disk * compression_factor
        TWELVE_GB = 12 * 1024**3
        return estimated_ram <= TWELVE_GB

In [8]:
%%memit
print("Building tables...")
if os.path.exists('data'):
    print("Removing existing data directory")
    shutil.rmtree('data')

sizes = ["100MB", "1GB", "10GB"]
tables = {}
for size in sizes:
    for table_name in ["Songs", "Users", "Listens"]:
        key = f"{table_name}_{size}"
        tables[key] = ColumnarDbFile(f"{table_name}_{size}", file_dir='data')
        parquet_path = f"{table_name.lower()}_{size}.parquet"
        assert os.path.exists(parquet_path)
        tables[key].build_table(parquet_path)

print("Tables built successfully.")

Building tables...
Removing existing data directory
Tables built successfully.
peak memory: 141.44 MiB, increment: 2.53 MiB


In [9]:
# retrieve data
tables['Songs_100MB'].retrieve_data(columns = ['song_id', 'title'])

Unnamed: 0,song_id,title
0,0,Song_0
1,1,Song_1
2,2,Song_2
3,3,Song_3
4,4,Song_4
...,...,...
9995,9995,Song_9995
9996,9996,Song_9996
9997,9997,Song_9997
9998,9998,Song_9998


In [10]:
tables['Listens_100MB'].retrieve_data(columns = ['listen_id', 'user_id', 'song_id'])

Unnamed: 0,listen_id,user_id,song_id
0,0,19936,7687
1,1,37756,9045
2,2,35676,3593
3,3,18861,2977
4,4,9826,4653
...,...,...,...
999995,999995,15502,4168
999996,999996,1562,1217
999997,999997,5838,2871
999998,999998,35276,1541


Analyze and report on:
- Space efficiency compared to row storage
  - e.g. Compare file sizes on disk: How much disk space does Parquet use vs. a row storage format like CSV?
- Compression ratios achieved with Parquet
  - e.g. Compare Parquet’s uncompressed encoded size (reported in its metadata) to its compressed on-disk size to compute compression ratios.
  - You could also report the memory expansion factor: how much larger the dataset becomes when loaded into a `pd.DataFrame` compared to the compressed file size.
- Read/write performance characteristics
  - e.g. Read performance: How long does it take to read all columns from Parquet vs. CSV?
  - e.g. Columnar advantage: How long does it take to read selective columns from Parquet vs. reading all columns?
  - e.g. Write performance: How long does it take to write data to Parquet vs. CSV?

In [11]:
@timer
def analyze(size="100MB"):
    """Analyze storage efficiency, compression, and read/write performance."""

    table_files = {
        "Songs": f"songs_{size}.parquet",
        "Users": f"users_{size}.parquet",
        "Listens": f"listens_{size}.parquet",
    }

    report_rows = []

    for table_name, parquet_file in table_files.items():
        parquet_path = Path(parquet_file)

        df = pd.read_parquet(parquet_path)
        mem_usage_bytes = df.memory_usage(deep=True).sum() # memory usage of the dataframe
        parquet_size_bytes = parquet_path.stat().st_size # size of the parquet file on disk

        parquet_file_obj = pq.ParquetFile(parquet_path)
        metadata = parquet_file_obj.metadata
        uncompressed_bytes = 0

        # iterate over all row groups and columns to get the total uncompressed size of the parquet file
        for rg_idx in range(metadata.num_row_groups):
            row_group = metadata.row_group(rg_idx)
            for col_idx in range(row_group.num_columns):
                column_meta = row_group.column(col_idx)
                if column_meta.total_uncompressed_size is not None:
                    uncompressed_bytes += column_meta.total_uncompressed_size

        # calculate compression ratio and memory expansion
        compression_ratio = (
            uncompressed_bytes / parquet_size_bytes
        )
        memory_expansion = (
            mem_usage_bytes / parquet_size_bytes
        )

        # test reading speed of parquet file vs csv, for all columns and selective columns
        # pick 1 less than the total number of columns to test reading selective columns
        subset_columns = list(df.columns)[0:len(df.columns)-1]

        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir_path = Path(tmpdir)

            csv_path = tmpdir_path / f"{parquet_path.stem}.csv"
            start = time.perf_counter()
            df.to_csv(csv_path, index=False)
            write_csv_time = time.perf_counter() - start
            csv_size_bytes = csv_path.stat().st_size

            parquet_tmp_path = tmpdir_path / f"{parquet_path.stem}.parquet"
            start = time.perf_counter()
            df.to_parquet(parquet_tmp_path, index=False)
            write_parquet_time = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_parquet(parquet_path)
            read_parquet_all = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_csv(csv_path)
            read_csv_all = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_parquet(parquet_path, columns=subset_columns)
            read_parquet_subset = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_csv(csv_path, usecols=subset_columns)
            read_csv_subset = time.perf_counter() - start

        size_saving_pct = (
            100.0 * (1 - parquet_size_bytes / csv_size_bytes)
        )

        # append the results to the report
        report_rows.append(
            {
                "table": table_name,
                "parquet_size_mb": parquet_size_bytes / (1024 ** 2),
                "csv_size_mb": csv_size_bytes / (1024 ** 2),
                "size_saving_pct": size_saving_pct,
                "compression_ratio": compression_ratio,
                "memory_expansion": memory_expansion,
                "read_parquet_all_s": read_parquet_all,
                "read_csv_all_s": read_csv_all,
                "read_parquet_subset_s": read_parquet_subset,
                "read_csv_subset_s": read_csv_subset,
                "write_parquet_s": write_parquet_time,
                "write_csv_s": write_csv_time,
            }
        )

        del df
        gc.collect()

    summary = pd.DataFrame(report_rows)
    print("Analysis Summary for Tables of Size " + size + " (sizes in MB, times in seconds):")
    return summary

In [12]:
display(analyze(size="100MB"))

Analysis Summary for Tables of Size 100MB (sizes in MB, times in seconds):
Method 'analyze' took 9.9829 seconds.


Unnamed: 0,table,parquet_size_mb,csv_size_mb,size_saving_pct,compression_ratio,memory_expansion,read_parquet_all_s,read_csv_all_s,read_parquet_subset_s,read_csv_subset_s,write_parquet_s,write_csv_s
0,Songs,4.271927,9.773173,56.289255,2.41591,3.47343,0.007974,0.073592,0.007648,0.062106,0.027642,0.097124
1,Users,20.347857,48.579238,58.114089,2.471382,3.529207,0.038628,0.340593,0.034948,0.308032,0.090254,0.443893
2,Listens,79.926873,178.866784,55.31486,2.43253,8.042059,0.256465,1.740107,0.293358,1.624935,0.574689,2.668126


Across all three tables, Parquet is markedly more space-efficient than row-oriented CSV: file sizes drop by ~55–58% (e.g., Listens 79.9 MB vs 178.9 MB), corresponding to compression ratios of about 2.4–2.5× relative to the uncompressed Parquet data. Loading into pandas expands data substantially beyond disk size (≈3.5× for Songs/Users and ≈8× for Listens), highlighting the in-memory cost of wide/large tables. Performance-wise, Parquet reads are much faster than CSV both for full scans and column subsets—about 7–9× speedup for full reads and ~5–9× for selective reads—while Parquet writes are also faster, roughly 3.5–5× quicker than writing CSV.

# Section 2: Parse SQL Query

In this section, you should implement logic to parse the following SQL query:
```sql
    SELECT s.song_id, AVG(u.age) AS avg_age,
       COUNT(DISTINCT l.user_id) AS count_distinct_users,
    FROM Songs s
    JOIN Listens l ON s.song_id = l.song_id
    JOIN Users u ON l.user_id = u.user_id
    GROUP BY s.song_id, s.title
    ORDER BY COUNT(DISTINCT l.user_id) DESC, s.song_id;
```

You should manually extract the components from the provided query (i.e. you don't need to implement a general SQL parser, just handle this specific query).

In [13]:
query = """SELECT s.song_id, AVG(u.age) AS avg_age,
COUNT(DISTINCT l.user_id)
FROM Songs s
JOIN Listens l ON s.song_id = l.song_id
JOIN Users u ON l.user_id = u.user_id
GROUP BY s.song_id, s.title
ORDER BY COUNT(DISTINCT l.user_id) DESC, s.song_id;
"""

In [15]:
import re
import re

def parse_tables(query):

    # pattern matches: "from songs s" or "join listens l"
    pattern = r"(from|join)\s+([a-z_]+)\s+([a-z])"

    matches = re.findall(pattern, query)

    tables = {}
    for _, table_name, alias in matches:
        tables[alias] = table_name

    return tables

def parse_joins(query):

    # 1) Get the base table from the FROM clause
    base_match = re.search(r"from\s+([a-z_]+)\s+([a-z])", query)
    if not base_match:
        raise ValueError("Could not find FROM clause")

    base_table_name = base_match.group(1)
    base_alias = base_match.group(2)
    base_table = (base_alias, base_table_name)

    # 2) Get each JOIN clause, in order
    # pattern matches:
    #   join listens l on s.song_id = l.song_id
    join_pattern = (
        r"join\s+([a-z_]+)\s+([a-z])\s+on\s+"
        r"([a-z])\.([a-z_]+)\s*=\s*([a-z])\.([a-z_]+)"
    )

    joins = []
    for m in re.finditer(join_pattern, query):
        joined_table_name = m.group(1)
        joined_alias = m.group(2)
        left_alias = m.group(3)
        left_col = m.group(4)
        right_alias = m.group(5)
        right_col = m.group(6)

        joins.append(
            {
                "joined_table_alias": joined_alias,
                "joined_table_name": joined_table_name,
                "left_alias": left_alias,
                "left_column": left_col,
                "right_alias": right_alias,
                "right_column": right_col,
            }
        )

    return {"base_table" : base_table, "Joins" : joins}


def parse_group_by(query):
    """
    Return GROUP BY columns as a list of (alias, column) tuples.
    Example: [('s', 'song_id'), ('s', 'title')]
    """
    q = query.lower()

    # Capture whatever is between GROUP BY and ORDER BY/semicolon/end
    match = re.search(r"group\s+by\s+(.+?)(order\s+by|;|$)", q, re.DOTALL)
    if not match:
        return []

    groupby_text = match.group(1).strip()

    columns = []
    for col in groupby_text.split(","):
        col = col.strip()

        # Expect pattern: alias.column
        alias, column = col.split(".")
        columns.append((alias, column))

    return columns

def parse_select_and_aggregations(query):
    """
    Build:
      aggregations: {agg_key: {...}}
      select: list of items that may refer to agg_key
    """
    q = query.lower()

    m = re.search(r"select\s+(.+?)\s+from", q, re.DOTALL)
    if not m:
        return [], {}

    select_text = m.group(1).strip()
    raw_items = [item.strip() for item in select_text.split(",") if item.strip()]

    select_list = []
    aggregations = {}
    agg_id = 1

    for idx, item in enumerate(raw_items, start=1):
        # AVG(...)
        if item.startswith("avg("):
            m_avg = re.match(
                r"avg\(\s*([a-z])\.([a-z_]+)\s*\)(\s+as\s+([a-z_]+))?",
                item
            )
            if not m_avg:
                raise ValueError(f"Could not parse AVG aggregation: {item}")
            alias_letter = m_avg.group(1)
            col_name = m_avg.group(2)
            out_alias = m_avg.group(4) if m_avg.group(4) else None

            aggregations[agg_id] = {
                "func": "avg",
                "source": (alias_letter, col_name),
                "distinct": False,
                "output_name": out_alias,
            }

            select_list.append(
                {
                    "kind": "aggregation",
                    "agg_key": agg_id,
                    "alias": out_alias,

                }
            )
            agg_id += 1

        # COUNT(DISTINCT ...)
        elif item.startswith("count("):
            m_cnt = re.match(
                r"count\(\s*distinct\s+([a-z])\.([a-z_]+)\s*\)(\s+as\s+([a-z_]+))?",
                item
            )
            if not m_cnt:
                raise ValueError(f"Could not parse COUNT aggregation: {item}")
            alias_letter = m_cnt.group(1)
            col_name = m_cnt.group(2)
            out_alias = m_cnt.group(4) if m_cnt.group(4) else None

            aggregations[agg_id] = {
                "func": "count",
                "source": (alias_letter, col_name),
                "distinct": True,
                "output_name": out_alias,
            }

            select_list.append(
                {
                    "kind": "aggregation",
                    "agg_key": agg_id,
                    "alias": out_alias,
                }
            )
            agg_id += 1

        # Plain column: alias.column
        else:
            alias_letter, col_name = item.split(".")
            select_list.append(
                {
                    "kind": "column",
                    "source": (alias_letter, col_name),
                    "alias": None,
                }
            )

    return select_list, aggregations


def parse_order_by(query, aggregations):
    """
    Build order_by list where entries can refer to aggregations via agg_key.
    """
    q = query.lower()

    m = re.search(r"order\s+by\s+(.+?)(;|$)", q, re.DOTALL)
    if not m:
        return []

    order_text = m.group(1).strip()
    raw_items = [item.strip() for item in order_text.split(",") if item.strip()]

    order_by = []

    for item in raw_items:
        direction = "asc"
        expr = item

        if expr.endswith(" desc"):
            direction = "desc"
            expr = expr[:-5].strip()
        elif expr.endswith(" asc"):
            direction = "asc"
            expr = expr[:-4].strip()

        # COUNT(DISTINCT ...) → match an aggregation
        if expr.startswith("count("):
            m_cnt = re.match(
                r"count\(\s*distinct\s+([a-z])\.([a-z_]+)\s*\)",
                expr
            )
            if not m_cnt:
                raise ValueError(f"Could not parse ORDER BY aggregation: {expr}")
            src = (m_cnt.group(1), m_cnt.group(2))

            agg_key = None
            for k, agg in aggregations.items():
                if (
                    agg["func"] == "count"
                    and agg["distinct"]
                    and agg["source"] == src
                ):
                    agg_key = k
                    break

            if agg_key is None:
                raise ValueError(f"No matching aggregation found for ORDER BY expr: {expr}")

            order_by.append(
                {
                    "kind": "aggregation",
                    "agg_key": agg_key,
                    "direction": direction,
                }
            )

        else:
            # assume plain column: alias.column
            alias_letter, col_name = expr.split(".")
            order_by.append(
                {
                    "kind": "column",
                    "source": (alias_letter, col_name),
                    "direction": direction,
                }
            )

    return order_by

@timer
def parse_sql(query):
    """
    YOUR TASK: Extract tables, joins, and aggregations
    """
    # Parse SQL string to identify:
    # - Tables involved
    # - Join conditions
    # - GROUP BY columns
    # - Aggregation functions
    # Your implementation here
    query = query.lower()
    output = {}

    output["tables"] = parse_tables(query)
    output["joins"] = parse_joins(query)
    output["GroupBy"] = parse_group_by(query)
    output["select"], output["aggregations"] = parse_select_and_aggregations(query)
    output["orderBy"] = parse_order_by(query, output["aggregations"])

    return output

In [16]:
output = parse_sql(query)
for key, value in output.items():
    print(f"{key}: {value}")

Method 'parse_sql' took 0.0024 seconds.
tables: {'s': 'songs', 'l': 'listens', 'u': 'users'}
joins: {'base_table': ('s', 'songs'), 'Joins': [{'joined_table_alias': 'l', 'joined_table_name': 'listens', 'left_alias': 's', 'left_column': 'song_id', 'right_alias': 'l', 'right_column': 'song_id'}, {'joined_table_alias': 'u', 'joined_table_name': 'users', 'left_alias': 'l', 'left_column': 'user_id', 'right_alias': 'u', 'right_column': 'user_id'}]}
GroupBy: [('s', 'song_id'), ('s', 'title')]
select: [{'kind': 'column', 'source': ('s', 'song_id'), 'alias': None}, {'kind': 'aggregation', 'agg_key': 1, 'alias': 'avg_age'}, {'kind': 'aggregation', 'agg_key': 2, 'alias': None}]
aggregations: {1: {'func': 'avg', 'source': ('u', 'age'), 'distinct': False, 'output_name': 'avg_age'}, 2: {'func': 'count', 'source': ('l', 'user_id'), 'distinct': True, 'output_name': None}}
orderBy: [{'kind': 'aggregation', 'agg_key': 2, 'direction': 'desc'}, {'kind': 'column', 'source': ('s', 'song_id'), 'direction': 'a

# Section 3: Implement Join Algorithms

In this section, you will implement the execution operators (*how* to join) and aggregation after joins.

**Reminder:** If you use temporary files or folders, you should clean them up either as part of your join logic, or after each run. Otherwise you might run into correctness issues!

In [17]:
import hashlib

def HASHVALUE(value, B):
    if isinstance(value, int):
        return hash(value) % B
    sha256 = hashlib.sha256()
    sha256.update(str(value).encode("utf-8"))
    return int(sha256.hexdigest(), 16) % B

In [83]:
def hash_partition(
    table: ColumnarDbFile,
    hash_keys: List[str],
    num_partitions: int,
    parquet_batch_size: int,
    hash_value_fn: Callable[[object, int], int],
    make_partition_path_fn: Callable[[int], str],
    columns: Optional[List[str]] = None,
):
    """
    Hash-partition `table` into `num_partitions` Parquet files.

    - `hash_keys` is a list of column names (one or more).
    - If len(hash_keys) > 1, we build a temporary concatenated column `_hash_key`
      and hash on that.
    - `hash_value_fn(key, num_partitions)` returns an int in [0, num_partitions).
    - `columns` are the columns to write into each partition.
      All `hash_keys` are automatically included in `columns`.
    """
    is_multi_col = len(hash_keys) > 1
    hash_col_name = "_hash_key" if is_multi_col else hash_keys[0]

    # Normalize and ensure hash_keys are included in the columns we read & write
    if columns:
        for col in hash_keys:
            if col not in columns:
                columns.append(col)

    writers: Dict[int, pq.ParquetWriter] = {}
    for batch_df in table.iter_pages(columns=columns, rows_per_batch=parquet_batch_size):
        # If multiple hash columns, build a temporary concatenated key column
        if is_multi_col:
            batch_df[hash_col_name] = (
                batch_df[hash_keys]
                .astype(str)
                .agg("|".join, axis=1)
            )

        # Compute partition id
        batch_df["_part"] = batch_df[hash_col_name].apply(
            lambda x: hash_value_fn(x, num_partitions)
        )

        if is_multi_col:
            batch_df = batch_df.drop(columns=hash_col_name)
        if columns:
            batch_df = batch_df[columns + ["_part"]]

        # Group rows by partition and write each group
        for part_id, part_df in batch_df.groupby("_part"):
            part_df = part_df.drop(columns=["_part"])

            part_table = pa.Table.from_pandas(part_df, preserve_index=False)

            writer = writers.get(part_id)
            if writer is None:
                part_path = make_partition_path_fn(part_id)
                writer = pq.ParquetWriter(part_path, part_table.schema)
                writers[part_id] = writer

            writer.write_table(part_table)

    for w in writers.values():
        w.close()

Implement `HashPartitionJoin`:
1. Hash partition both tables
2. Build hash table from smaller partition
3. Probe with larger partition
4. Return joined results

In [None]:
class FastHashPartitionJoin:
    def __init__(self, num_partitions=4, parquet_batch_size=100_000, use_streaming=False, time_it=True):
        self.num_partitions = num_partitions
        self.parquet_batch_size = parquet_batch_size
        self.use_streaming = use_streaming
        self.time_it = time_it
    
    @timer
    def join(self, table1: ColumnarDbFile, table2: ColumnarDbFile, join_key1, join_key2,
             temp_dir='temp', columns_table1=None, columns_table2=None):
        """
        Perform an optimized hash partition join between two ColumnarDbFile instances.

        Speed ups:
        - load smaller table into memory and build hash map with pandas groupby, larger table is probed in batches
        - vectorized join using numpy operations: see _vectorized_join method for more details
        """
        os.makedirs(temp_dir, exist_ok=True)

        # Partition both tables
        self._hash_partition(table1, join_key1, temp_dir, 'left', columns_table1)
        self._hash_partition(table2, join_key2, temp_dir, 'right', columns_table2)

        output = ColumnarDbFile(f"hpj_{table1.table_name}_{table2.table_name}")
        
        # Clean up any existing files in the output directory
        if os.path.exists(output.base_file_name):
            for file_path in output.get_all_parquet_paths():
                os.remove(file_path)
        
        if self.use_streaming:
            output.start_stream()

        for part_id in range(self.num_partitions):
            left_path = self._make_partition_path(temp_dir, "left", part_id)
            right_path = self._make_partition_path(temp_dir, "right", part_id)

            if not (os.path.exists(left_path) and os.path.exists(right_path)):
                continue
            
            # Process this partition with batched reading
            self._process_partition(
                left_path, right_path, join_key1, join_key2, output
            )

        if self.use_streaming:
            output.stop_stream()
        
        shutil.rmtree(temp_dir)
        return output

    def _process_partition(self, left_path, right_path, join_key1, join_key2, output):
        """
        Process a partition an individual partition from left and right
        Determine which side is smaller and build hash map from that side
        Probe with larger side
        """
        # Get metadata to determine which side is smaller
        left_size = pq.ParquetFile(left_path).metadata.num_rows
        right_size = pq.ParquetFile(right_path).metadata.num_rows
        
        if left_size <= right_size:
            # Build hash map from left, probe with right
            self._build_and_probe(left_path, right_path, join_key1, join_key2, 
                                  output, left_is_build=True)
        else:
            # Build hash map from right, probe with left
            self._build_and_probe(right_path, left_path, join_key2, join_key1, 
                                  output, left_is_build=False)

    def _build_and_probe(self, build_path, probe_path, build_key, probe_key, 
                         output, left_is_build):
        """
        Build hash map from build side and probe with probe side using batched reading.
        """
        # Build hash map from the smaller side (build side)
        hash_map = self._build_hash_map(build_path, build_key)
        
        # Probe with the larger side in batches
        probe_file = pq.ParquetFile(probe_path)
        build_df = pq.read_table(build_path).to_pandas()
        
        for probe_batch in probe_file.iter_batches(batch_size=self.parquet_batch_size):
            probe_df = probe_batch.to_pandas()
            
            # Vectorized join using numpy operations
            joined_df = self._vectorized_join(
                build_df, probe_df, hash_map, probe_key, left_is_build
            )
            
            if not joined_df.empty:
                output.append_data(joined_df)
            
            # Explicit memory cleanup
            del probe_df
            del joined_df
            gc.collect()
        
        del build_df
        del hash_map
        gc.collect()

    def _build_hash_map(self, file_path, key_column):
        """
        Build an optimized hash map using numpy arrays for better performance.
        Returns a dictionary mapping keys to numpy arrays of indices.
        """
        df = pq.read_table(file_path).to_pandas()
        
        # Group indices by key using pandas groupby (much faster than manual loop)
        grouped = df.reset_index().groupby(key_column)['index'].apply(np.array).to_dict()
        
        return grouped

    def _vectorized_join(self, build_df, probe_df, hash_map, probe_key, left_is_build):
        """
        Primary optimization using a vectorized join with vectorized join:
        1. Get probe keys and find indeces of matches in hash map  
        2. Establish a parrallel index for build and probe tables
        3. Build result from parallel indices using advanced pandas indexing
        """
        probe_keys = probe_df[probe_key].values
        
        build_indices = []
        probe_indices = []
        
        # Build index for build and probe tables
        for probe_idx, key in enumerate(probe_keys):
            if key in hash_map:
                build_idxs = hash_map[key]
                build_indices.extend(build_idxs)
                probe_indices.extend([probe_idx] * len(build_idxs))
        
        if not build_indices:
            return pd.DataFrame()
        
        build_indices = np.array(build_indices)
        probe_indices = np.array(probe_indices)
        
        # Build result using advanced indexing
        # advanced works as follows here:
        # build_df.iloc[build_indices] -> get rows from build_df where index is in build_indices
        # probe_df.iloc[probe_indices] -> get rows from probe_df where index is in probe_indices
        # these lists are parallel, meaning that the row in position i in build_indices is joined 
        # with the row in position i in probe_indices
        if left_is_build:
            left_result = build_df.iloc[build_indices].reset_index(drop=True)
            right_result = probe_df.iloc[probe_indices].reset_index(drop=True)
        else:
            left_result = probe_df.iloc[probe_indices].reset_index(drop=True)
            right_result = build_df.iloc[build_indices].reset_index(drop=True)
        
        # Drop duplicate columns from right side (keeping left)
        common_columns = set(left_result.columns) & set(right_result.columns)
        if common_columns:
            right_result = right_result.drop(columns=list(common_columns))

        result = pd.concat([left_result, right_result], axis=1)
        
        return result

    def _make_partition_path(self, output_dir, side, part_id):
        return f"{output_dir}/{side}_part{part_id}.parquet"

    @timer
    def _hash_partition(self, table: ColumnarDbFile, join_key, output_dir, side, columns=None):
        make_partition_path_fn = partial(self._make_partition_path, output_dir, side)
        hash_partition(table, [join_key], self.num_partitions, self.parquet_batch_size,
                       HASHVALUE, make_partition_path_fn, columns=columns)

In [86]:
%%memit

SIZE = "1GB" #["100MB", "1GB", "10GB"]
SAMPLE = 100
USE_STREAMING = True

songs_table = tables[f'Songs_{SIZE}']
listens_table = tables[f'Listens_{SIZE}']

# Select specific columns from each table
songs_cols = ['song_id', 'title']
listens_cols = ['listen_id', 'song_id', 'user_id']

# Create HashPartitionJoin instance
hpj1 = FastHashPartitionJoin(
    num_partitions=4, 
    parquet_batch_size=1000000,
    use_streaming=USE_STREAMING  
)

# Perform the join
result_songs_listens = hpj1.join(
    table1=songs_table,           
    table2=listens_table,         
    join_key1='song_id',          
    join_key2='song_id',          
    temp_dir='temp_songs_listens',
    columns_table1=songs_cols,    
    columns_table2=listens_cols   
)

result_df = result_songs_listens.retrieve_data(sample=None)



Method '_hash_partition' took 0.0606 seconds.
Method '_hash_partition' took 3.0541 seconds.
Method '_build_hash_map' took 0.2388 seconds.
Method '_build_hash_map' took 0.2324 seconds.
Method '_build_hash_map' took 0.2306 seconds.
Method '_build_hash_map' took 0.2346 seconds.
Method 'join' took 11.6230 seconds.
peak memory: 1486.02 MiB, increment: 1254.59 MiB


In [87]:
USE_STREAMING = True
SIZE = "100MB"
# Optional: Verify your implementation against pd.merge
def test_hash_partition_join_comprehensive():
    """
    Comprehensive test that validates both structure AND actual data values.
    This ensures the HPJ implementation is truly correct.
    """
    print("="*70)
    print("Comprehensive Hash Partition Join Test")
    print("="*70)
    
    all_tests_passed = True
    
    # Test: Songs JOIN Listens - FULL DATA VALIDATION
    print("\n" + "="*70)
    print("Test: Songs JOIN Listens ")
    print("="*70)
    
    songs_table = tables[f'Songs_{SIZE}']
    listens_table = tables[f'Listens_{SIZE}']
    
    songs_cols = ['song_id', 'title']
    listens_cols = ['listen_id', 'song_id', 'user_id']
    
    # Perform joins
    hpj1 = FastHashPartitionJoin(num_partitions=4, parquet_batch_size=100_000, use_streaming=USE_STREAMING)
    result_table1 = hpj1.join(
        songs_table, listens_table,
        join_key1='song_id', join_key2='song_id',
        temp_dir='temp_test_songs_listens_comp',
        columns_table1=songs_cols,
        columns_table2=listens_cols
    )
    
    hpj_result1 = result_table1.retrieve_data()
    
    # Get pd.merge result
    songs_df = songs_table.retrieve_data(columns=songs_cols)
    listens_df = listens_table.retrieve_data(columns=listens_cols)
    pd_result1 = pd.merge(songs_df, listens_df, on='song_id', how='inner')
    
    print(f"\nHPJ result shape: {hpj_result1.shape}")
    print(f"pd.merge result shape: {pd_result1.shape}")
    
    test1_passed = True
    
    # 1. Row count check
    if len(hpj_result1) != len(pd_result1):
        print(f"Row count mismatch -- HPJ: {len(hpj_result1)}, pd.merge: {len(pd_result1)}")
        test1_passed = False
        all_tests_passed = False
    else:
        print("Row counts match!")
    
    # 2. Column check
    hpj_cols = set(hpj_result1.columns)
    pd_cols = set(pd_result1.columns)
    if hpj_cols != pd_cols:
        print(f"Column mismatch -- HPJ: {hpj_cols}, pd.merge: {pd_cols}")
        test1_passed = False
        all_tests_passed = False
    else:
        print("Columns match!")
    
    if test1_passed:
        # 3. Sort both results for comparison
        sort_cols = ['song_id', 'listen_id'] if 'listen_id' in hpj_result1.columns else ['song_id']
        hpj_sorted = hpj_result1.sort_values(sort_cols).reset_index(drop=True)
        pd_sorted = pd_result1.sort_values(sort_cols).reset_index(drop=True)
        
        # 4. Check unique keys
        hpj_song_ids = set(hpj_result1['song_id'].unique())
        pd_song_ids = set(pd_result1['song_id'].unique())
        if hpj_song_ids != pd_song_ids:
            print(f"Unique song_ids differ!")
            test1_passed = False
            all_tests_passed = False
        else:
            print("Unique song_ids match!")
        
        # 5. FULL DATA VALUE COMPARISON - This is the critical check!
        print("\nPerforming full data value comparison...")
        data_matches = True
        
        # Compare each column
        for col in sorted(hpj_cols):
            hpj_col_data = hpj_sorted[col].values
            pd_col_data = pd_sorted[col].values
            
            # Use np.array_equal for exact comparison
            if not np.array_equal(hpj_col_data, pd_col_data):
                print(f"Column '{col}' data mismatch")
                
                # Find first mismatch
                mismatch_idx = np.where(hpj_col_data != pd_col_data)[0]
                if len(mismatch_idx) > 0:
                    idx = mismatch_idx[0]
                    print(f"  First mismatch at row {idx}:")
                    print(f"    HPJ: {hpj_col_data[idx]}")
                    print(f"    pd.merge: {pd_col_data[idx]}")
                    print(f"  Total mismatches: {len(mismatch_idx)}")
                
                data_matches = False
                break
        
        if data_matches:
            print("All data values match exactly!")
            print(f"Verified {len(hpj_sorted)} rows × {len(hpj_cols)} columns = {len(hpj_sorted) * len(hpj_cols)} values")
        else:
            print("✗ Data values do NOT match!")
            test1_passed = False
            all_tests_passed = False
        
        # 6. Check for duplicate rows (should be same in both)
        hpj_duplicates = hpj_sorted.duplicated().sum()
        pd_duplicates = pd_sorted.duplicated().sum()
        if hpj_duplicates != pd_duplicates:
            print(f"Duplicate row counts differ (HPJ: {hpj_duplicates}, pd.merge: {pd_duplicates})")
        else:
            print(f"Duplicate row counts match ({hpj_duplicates} duplicates)")
    
    if test1_passed:
        print("\n Test PASSED")
    else:
        print("\n Test FAILED!")
    
    # Summary
    print("\n" + "="*70)
    print("COMPREHENSIVE TEST SUMMARY")
    print("="*70)
    if all_tests_passed:
        print("✓ ALL TESTS PASSED: Hash Partition Join is CORRECT!")
        print("  - Row counts match")
        print("  - Column structure matches")
        print("  - Unique keys match")
        print("  - ALL DATA VALUES match exactly")
    else:
        print("✗ TESTS FAILED: Implementation has issues")
    print("="*70)
    
    return all_tests_passed

test_hash_partition_join_comprehensive()

Comprehensive Hash Partition Join Test

Test: Songs JOIN Listens 
Method '_hash_partition' took 0.0205 seconds.
Method '_hash_partition' took 0.3476 seconds.
Method '_build_hash_map' took 0.0250 seconds.
Method '_build_hash_map' took 0.0250 seconds.
Method '_build_hash_map' took 0.0271 seconds.
Method '_build_hash_map' took 0.0277 seconds.
Method 'join' took 2.0240 seconds.

HPJ result shape: (1000000, 4)
pd.merge result shape: (1000000, 4)
Row counts match!
Columns match!
Unique song_ids match!

Performing full data value comparison...
All data values match exactly!
Verified 1000000 rows × 4 columns = 4000000 values
Duplicate row counts match (0 duplicates)

 Test PASSED

COMPREHENSIVE TEST SUMMARY
✓ ALL TESTS PASSED: Hash Partition Join is CORRECT!
  - Row counts match
  - Column structure matches
  - Unique keys match
  - ALL DATA VALUES match exactly


True

Implement `SortMergeJoin`:
1. Sort both tables by join key
2. Merge sorted sequences
3. Handle duplicates

In [89]:
class ParquetGroupChunkIter:
    """
    Streams a Parquet file that is globally sorted by `join_key`,
    yielding chunks of consecutive rows with the same key.
    Each call to next_chunk() returns (table_chunk, key_value, is_last_for_key).
    Memory bounded by IN_BATCH_ROWS and GROUP_CHUNK_ROWS.
    """
    def __init__(self, file_path: str, columns, join_key: str,
                 in_batch_rows: int = 64_000, group_chunk_rows: int = 64_000):
        self._pf = pq.ParquetFile(file_path)
        self._cols = columns
        self._join_key = join_key
        self._key_idx = (self._cols.index(join_key) if self._cols is not None
                 else self._pf.schema_arrow.get_field_index(join_key))
        if self._key_idx == -1:
            raise ValueError(f"join key {join_key} not in schema")
        self._IN_BATCH_ROWS = in_batch_rows
        self._GROUP_CHUNK_ROWS = group_chunk_rows

        self._it = self._pf.iter_batches(columns=self._cols, batch_size=self._IN_BATCH_ROWS)
        self._batch = None
        self._i = 0
        self._cur_key = None
        self._eof = False

    def _next_batch(self):
        try:
            self._batch = next(self._it)
            self._i = 0
        except StopIteration:
            self._batch = None

    def _skip_nulls(self):
        while self._batch is not None:
            key_arr = self._batch.column(self._key_idx)
            n = self._batch.num_rows
            while self._i < n and key_arr[self._i].as_py() is None:
                self._i += 1
            if self._i < n:
                return
            self._next_batch()

    def next_chunk(self):
        if self._eof:
            return None, None, True

        if self._batch is None:
            self._next_batch()
            if self._batch is None:
                self._eof = True
                return None, None, True

        self._skip_nulls()
        if self._batch is None:
            self._eof = True
            return None, None, True

        if self._cur_key is None:
            self._cur_key = self._batch.column(self._key_idx)[self._i].as_py()

        parts = []
        collected = 0
        more_for_key = False

        while collected < self._GROUP_CHUNK_ROWS:
            if self._batch is None:
                break

            key_arr = self._batch.column(self._key_idx)
            n = self._batch.num_rows
            j = self._i

            # take rows while the key matches and we stay under the chunk cap
            while j < n and key_arr[j].as_py() == self._cur_key and \
                  (collected + (j - self._i)) < self._GROUP_CHUNK_ROWS:
                j += 1

            take = j - self._i
            if take > 0:
                parts.append(self._batch.slice(self._i, take))
                collected += take
                self._i = j

            if collected >= self._GROUP_CHUNK_ROWS:
                # capped chunk; if more rows with same key remain, mark continuation
                if self._i < n and key_arr[self._i].as_py() == self._cur_key:
                    more_for_key = True
                else:
                    # peek next batch start
                    save_batch, save_i = self._batch, self._i
                    self._next_batch()
                    if self._batch is not None:
                        k0 = self._batch.column(self._key_idx)[0].as_py()
                        more_for_key = (k0 == self._cur_key)
                    # restore for the caller to continue
                    self._batch, self._i = save_batch, save_i
                break

            if self._i >= n:
                # move to next batch and check if key continues
                self._next_batch()
                self._skip_nulls()
                if self._batch is None:
                    break
                if self._batch.column(self._key_idx)[self._i].as_py() != self._cur_key:
                    break
            else:
                # key changed within this batch
                break

        tbl = pa.Table.from_batches(parts) if parts else None

        # if no more rows remain for this key after this chunk, clear cur_key
        if not more_for_key:
            if self._batch is not None:
                key_arr = self._batch.column(self._key_idx)
                n = self._batch.num_rows
                while self._i < n and key_arr[self._i].as_py() == self._cur_key:
                    self._i += 1
                if self._i >= n:
                    self._next_batch()
            self._cur_key = None

        key_for_chunk = None
        if tbl is not None:
            key_for_chunk = tbl.column(self._key_idx).chunk(0)[0].as_py() \
                if tbl.num_rows > 0 else None

        return tbl, key_for_chunk, not more_for_key

In [88]:

BWAY_MERGE_FACTOR = 10

class SortMergeJoin:
    def __init__(
        self, bway_merge_factor: int = BWAY_MERGE_FACTOR, num_pages_per_split=100, verbose=False
    ):
        self.bway_merge_factor = bway_merge_factor
        self.num_pages_per_split = num_pages_per_split
        self.ascending = True
        self.tiebreak_key = None

        self.verbose = verbose

    @timer
    def _streaming_inner_join(self,
                            left_sorted: ColumnarDbFile,
                            right_sorted: ColumnarDbFile,
                            join_key1: str,
                            join_key2: str,
                            temp_dir: str,
                            columns_table1: Optional[List[str]],
                            columns_table2: Optional[List[str]],
                            in_batch_rows: int = 64_000,
                            group_chunk_rows: int = 64_000,
                            max_product_rows: int = 200_000) -> ColumnarDbFile:

        def ensure_key(cols, key):
            if cols is None:
                return None
            return cols if key in cols else cols + [key]

        # Compute input file paths
        left_path = os.path.join(left_sorted.base_file_name, f"{left_sorted.table_name}-0.parquet")
        right_path = os.path.join(right_sorted.base_file_name, f"{right_sorted.table_name}-0.parquet")

        # Load schemas to default to all columns if None
        l_pf = pq.ParquetFile(left_path)
        r_pf = pq.ParquetFile(right_path)
        l_schema = l_pf.schema_arrow
        r_schema = r_pf.schema_arrow

        if columns_table1 is None:
            columns_table1 = l_schema.names
        if columns_table2 is None:
            columns_table2 = r_schema.names
        columns_table1 = ensure_key(columns_table1, join_key1)
        columns_table2 = ensure_key(columns_table2, join_key2)

        # Output: keep left key; drop duplicate key from right
        left_out_cols = columns_table1
        right_out_cols = [c for c in columns_table2 if c != join_key2]

        # Prepare iterators
        L = ParquetGroupChunkIter(left_path, columns_table1, join_key1,
                                    in_batch_rows, group_chunk_rows)
        R = ParquetGroupChunkIter(right_path, columns_table2, join_key2,
                                    in_batch_rows, group_chunk_rows)

        # Output writer
        result_table = ColumnarDbFile("join_result", file_dir=temp_dir)
        out_path = os.path.join(result_table.base_file_name, f"{result_table.table_name}-0.parquet")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        writer = None

        def write_join_product(l_tbl: pa.Table, r_tbl: pa.Table):
            nonlocal writer
            if l_tbl is None or r_tbl is None:
                return
            l_df = l_tbl.select(left_out_cols).to_pandas()
            r_df = r_tbl.select(right_out_cols).to_pandas()
            m, n = len(l_df), len(r_df)
            if m == 0 or n == 0:
                return
            block_n = max(1, min(n, max_product_rows // max(1, m)))
            start = 0
            while start < n:
                end = min(n, start + block_n)
                r_block = r_df.iloc[start:end]
                out_df = l_df.assign(_cj=1).merge(r_block.assign(_cj=1), on="_cj").drop(columns=["_cj"])
                out_tbl = pa.Table.from_pandas(out_df, preserve_index=False)
                if writer is None:
                    writer = pq.ParquetWriter(out_path, schema=out_tbl.schema)
                writer.write_table(out_tbl)
                start = end

        # Drive the SMJ
        l_tbl, l_key, l_last = L.next_chunk()
        r_tbl, r_key, r_last = R.next_chunk()

        while l_tbl is not None and r_tbl is not None:
            if l_key is None:
                l_tbl, l_key, l_last = L.next_chunk()
                continue
            if r_key is None:
                r_tbl, r_key, r_last = R.next_chunk()
                continue

            if l_key < r_key:
                l_tbl, l_key, l_last = L.next_chunk()
            elif l_key > r_key:
                r_tbl, r_key, r_last = R.next_chunk()
                
            else:
                # l_key == r_key: join full groups. Cache right group (spill if large),
                # then stream left group chunk-by-chunk against the cached/spilled right.
                MAX_CACHED_GROUP_ROWS = 500_000  # tune as you like

                r_chunks: List[pa.Table] = []
                r_rows = 0
                r_spill_path = None
                r_writer = None

                def _open_r_spill(schema: pa.Schema):
                    nonlocal r_writer, r_spill_path
                    if r_writer is None:
                        r_spill_path = os.path.join(
                            temp_dir, f"_smj_rspill_{uuid.uuid4().hex}.parquet"
                        )
                        r_writer = pq.ParquetWriter(r_spill_path, schema=schema)

                def _cache_r_chunk(tbl: pa.Table):
                    nonlocal r_rows, r_writer
                    if tbl is None or tbl.num_rows == 0:
                        return
                    if r_writer is None and (r_rows + tbl.num_rows) <= MAX_CACHED_GROUP_ROWS:
                        r_chunks.append(tbl)
                    else:
                        if r_writer is None:
                            _open_r_spill(tbl.schema)
                            for t in r_chunks:
                                r_writer.write_table(t)
                            r_chunks.clear()
                        r_writer.write_table(tbl)
                    r_rows += tbl.num_rows

                # 1) Collect entire right group for this key
                _cache_r_chunk(r_tbl)
                while not r_last:
                    r_tbl, r_key2, r_last = R.next_chunk()
                    if r_tbl is None or r_key2 != l_key:
                        break
                    _cache_r_chunk(r_tbl)
                if r_writer is not None:
                    r_writer.close()

                def right_iter():
                    if r_spill_path is not None:
                        pf = pq.ParquetFile(r_spill_path)
                        for b in pf.iter_batches(batch_size=group_chunk_rows):
                            yield pa.Table.from_batches([b])
                    else:
                        for t in r_chunks:
                            yield t

                # 2) Stream left group; for each left chunk, replay right group
                while True:
                    for r_part in right_iter():
                        write_join_product(l_tbl, r_part)
                    if l_last:
                        break
                    l_tbl, l_key2, l_last = L.next_chunk()
                    if l_tbl is None or l_key2 != l_key:
                        break

                # 3) Advance right iterator to next key (we ended exactly at its boundary)
                r_tbl, r_key, r_last = R.next_chunk()
                # proceed to next loop iteration
                continue

        if writer is not None:
            writer.close()
        return result_table

    def _flush_run(
        self,
        dfs: List[pd.DataFrame],
        join_key: str,
        output_dir: str,
        side: str,
        run_idx: int,
        ascending: bool = True,
    ) -> str:

        df_run = pd.concat(dfs, ignore_index=True)

        # Sort by join_key, then tiebreak if set, tiebreak key is always sorted asceding
        if self.tiebreak_key and self.tiebreak_key in df_run.columns:
            df_run_sorted = df_run.sort_values(by=[join_key, self.tiebreak_key], ascending=[ascending, True]) 
        else:
            df_run_sorted = df_run.sort_values(by=join_key, ascending=ascending)

        run_file = os.path.join(output_dir, f"{side}_run_{run_idx}.parquet")
        df_run_sorted.to_parquet(run_file)

        dfs.clear()
        del df_run, df_run_sorted
        gc.collect()

        return run_file

    @timer
    def _external_sort(
        self,
        table: ColumnarDbFile,
        join_key: str,
        output_dir: str,
        side: str,
        columns: Optional[List[str]] = None,
        ascending: bool = True,
        tiebreak_key = None
    ) -> ColumnarDbFile:
        """
        Perform an external sort on a table based on the join key and return a sorted ColumnarDbFile.
        Use _bway_merge to merge sorted files
        """
        self.tiebreak_key = tiebreak_key

        # Get table size (on disk)
        disk_usage = table.table_disk_usage()
        total_bytes = disk_usage["total_bytes"]

        # Check if we can safely process in 12 GB RAM
        if table.can_process_parquet(total_bytes):

            # read data in and sort all in RAM
            df = table.retrieve_data(columns=columns)

            if self.tiebreak_key and self.tiebreak_key in df.columns:
                df_sorted = df.sort_values(by=[join_key, self.tiebreak_key], ascending=[ascending, True]).reset_index(drop=True)
            else:
                df_sorted = df.sort_values(by=join_key, ascending=ascending).reset_index(drop=True)

            # create paraquet in output dir for the table
            sorted_name = f"{side}_{table.table_name}_sorted"
            sorted_table = ColumnarDbFile(sorted_name, file_dir=output_dir)
            sorted_table.build_table(df_sorted)

            # clean unnecessary overhead and return table
            del df, df_sorted
            gc.collect()
            return sorted_table

        else:
            if self.verbose:
                print("sorting table ", table.table_name, "with ", total_bytes, "bytes using external sort")
                print("GBs : ", total_bytes / (1024 * 1024 * 1024))
            # Get list of parquet files in the table directory
            parquet_files = glob.glob(f"{table.base_file_name}/*.parquet")

            runs_path: List[str] = []
            run_idx = 0
            current_dfs: List[pd.DataFrame] = []
            current_row_groups = 0

            # loop through all the parquet files
            if self.verbose:
                print(f"Sorting {len(parquet_files)} files")
            for file in parquet_files:
                pf = pq.ParquetFile(file)

                # safe bounded unit of work for sorting
                num_row_groups = pf.metadata.num_row_groups

                for rg in range(num_row_groups):
                    if self.verbose:
                        print("reading row group ", rg)

                    # read a row group as a chunk
                    batch = pf.read_row_group(rg, columns=columns)
                    df_chunk = batch.to_pandas()
                    current_dfs.append(df_chunk)
                    current_row_groups += 1

                    # treating a row group as a page
                    # change to 2 to practice spliting runs
                    if current_row_groups > self.num_pages_per_split:
                        
                        if self.verbose:
                            print("flushing run ", run_idx)
                        # sort current run in ram
                        # save as a parquet file in the current run directory
                        run_file = self._flush_run(
                        current_dfs, join_key, output_dir, side, run_idx, ascending=ascending
                        )

                        # runs path is a list of the sorted parquet files
                        runs_path.append(run_file)
                        run_idx += 1
                        current_row_groups = 0
                        if self.verbose:
                            print(f"Flushed run {run_idx} at {run_file}")

            # flush remaining partial run, get sorted run_file as a df
            if current_dfs:
              run_file = self._flush_run(
                  current_dfs, join_key, output_dir, side, run_idx, ascending=ascending
              )
              runs_path.append(run_file)
            # output_dir has a list of sorted parquet files for the current table

            # Create the wrapper first so we write where it will read
            sorted_table = ColumnarDbFile(
                table_name=f"{side}_{table.table_name}_sorted",
                file_dir=output_dir,
            )

            # Write the final merged file inside that directory, matching ColumnarDbFile
            final_sorted_path = os.path.join(
                sorted_table.base_file_name, f"{sorted_table.table_name}-0.parquet"
            )
            if self.verbose:
                print("merging all runs into ", final_sorted_path)
            self._merge_all_runs(runs_path, final_sorted_path, join_key, ascending=ascending)

            return sorted_table

    @timer
    def _merge_all_runs(self, sorted_files: List[str], output_file: str, join_key: str, ascending: bool = True) :
        """
        Merge multiple sorted Parquet files into a single sorted Parquet file.
        """
        B = self.bway_merge_factor

        # copy that we will mutate
        runs = list(sorted_files)
        pass_idx = 0

        # while list of runs left to merge
        while len(runs) > 1:
          if self.verbose:
            print("merging pass ", pass_idx)
          next_runs = []

          # B - 1 input buffers +1 output buffer
          for i in range(0, len(runs), B - 1):
                batch = runs[i : i + (B - 1)]   # B-1 input buffers, still a list

                # choose an output path for this merged batch
                # on the final pass, we want the result at `output_file`
                if len(runs) <= B - 1 and len(next_runs) == 0:
                    # last pass, first (and only) merged run -> final output
                    merged_path = output_file
                else:
                    # intermediate pass: write to a temp run file
                    base_dir = os.path.dirname(output_file)
                    merged_path = os.path.join(
                        base_dir,
                        f"bway_pass{pass_idx}_run{len(next_runs)}.parquet",
                    )

                # b-way merge this batch into merged_path
                self._bway_merge(batch, merged_path, join_key, ascending=ascending)

                next_runs.append(merged_path)

          runs = next_runs
          pass_idx += 1

        # At this point, runs has exactly one file.
        final_run = runs[0]
        if final_run != output_file:
            # In case we didn't land exactly on output_file path
            if os.path.exists(output_file):
                os.remove(output_file)
            shutil.move(final_run, output_file)

        return output_file


    def _bway_merge(self, sorted_files: List[str], output_file: str, join_key: str, ascending: bool = True):
        """
        Streaming B-way merge of already-sorted Parquet 'runs' into a single
        sorted Parquet file at `output_file`, ordered by `join_key`.
        Memory use is bounded: keeps one small batch per input + an output buffer.
        """
        if not sorted_files:
            raise ValueError("No input files to merge.")

        # Tunables: keep these moderate to bound memory
        IN_BATCH_ROWS = 64_000         # rows read per input batch
        OUT_ROW_GROUP_ROWS = 128_000   # rows written per output row group

        # Determine common schema and key index from first file
        first_pf = pq.ParquetFile(sorted_files[0])
        schema = first_pf.schema_arrow
        key_index = schema.get_field_index(join_key)
        if key_index == -1:
            raise ValueError(f"join_key '{join_key}' not found in schema.")

        # Initialize streaming readers (preserve on-disk order)
        readers = [pq.ParquetFile(p).iter_batches(batch_size=IN_BATCH_ROWS) for p in sorted_files]

        # Per-input state
        states = []  # list of dicts: { 'iter': iterator, 'batch': RecordBatch|None, 'ridx': int }
        heap = []    # min-heap of ((is_null, key_value), src_idx)

        # Get tiebreak index if set
        tiebreak_index = None
        if self.tiebreak_key:
            tiebreak_index = schema.get_field_index(self.tiebreak_key)

        def _prime_state(it):
            try:
                return next(it)
            except StopIteration:
                return None

        def _push_heap_for_src(src_idx):
            st = states[src_idx]
            batch = st["batch"]
            ridx = st["ridx"]
            if batch is None or ridx >= batch.num_rows:
                return
            key_arr = batch.column(key_index)
            key_val = key_arr[ridx].as_py()
            
            # Get tiebreaker value if set
            tie_val = None
            if tiebreak_index is not None:
                tie_arr = batch.column(tiebreak_index)
                tie_val = tie_arr[ridx].as_py()
            
            # Compute heap key based on sort direction
            if self.ascending:
                heap_key = (key_val is None, key_val, tie_val)
            else:
                # DESC on primary key, ASC on tiebreaker
                heap_key = (key_val is not None, 
                        -key_val if isinstance(key_val, (int, float)) else key_val,
                        tie_val)
            
            heapq.heappush(heap, (heap_key, src_idx))

        # Prime batches and heap
        for it in readers:
            batch = _prime_state(it)
            states.append({"iter": it, "batch": batch, "ridx": 0})
        for i in range(len(states)):
            if states[i]["batch"] is not None and states[i]["batch"].num_rows > 0:
                _push_heap_for_src(i)

        # Open output writer
        writer = pq.ParquetWriter(output_file, schema=schema)
        out_batches: List[pa.RecordBatch] = []
        out_rows = 0

        # Core streaming merge loop
        while heap:
            _, src_idx = heapq.heappop(heap)
            st = states[src_idx]
            batch = st["batch"]
            ridx = st["ridx"]

            # Append a single-row slice to output buffer
            out_batches.append(batch.slice(ridx, 1))
            out_rows += 1

            # Advance this source
            st["ridx"] += 1
            if st["ridx"] >= batch.num_rows:
                st["batch"] = _prime_state(st["iter"])
                st["ridx"] = 0
            if st["batch"] is not None:
                _push_heap_for_src(src_idx)

            # Flush output buffer as a row group when big enough
            if out_rows >= OUT_ROW_GROUP_ROWS:
                writer.write_table(pa.Table.from_batches(out_batches))
                out_batches.clear()
                out_rows = 0

        # Final flush
        if out_rows > 0:
            writer.write_table(pa.Table.from_batches(out_batches))
        writer.close()

    @timer
    def join(
        self,
        table1: ColumnarDbFile,
        table2: ColumnarDbFile,
        join_key1: str,
        join_key2: str,
        temp_dir: str = "temp",
        columns_table1: Optional[List[str]] = None,
        columns_table2: Optional[List[str]] = None,
        ASC: bool = True,
    ) -> Optional[ColumnarDbFile]:
        """
        Perform a sort-merge join between two ColumnarDbFile instances and return a sorted ColumnarDbFile.
        """
        os.makedirs(temp_dir, exist_ok=True)
        self.ascending = ASC

        # Sort both tables externally
        sorted_table1 = self._external_sort(
            table1, join_key1, temp_dir, "left", columns_table1, ascending=ASC
        )
        sorted_table2 = self._external_sort(
            table2, join_key2, temp_dir, "right", columns_table2, ascending=ASC
        )

        result_table = self._streaming_inner_join(
            left_sorted=sorted_table1,
            right_sorted=sorted_table2,
            join_key1=join_key1,
            join_key2=join_key2,
            temp_dir=temp_dir,
            columns_table1=columns_table1,
            columns_table2=columns_table2,
        )

        return result_table

In [61]:
%%memit
SIZE = "1GB"
songs_table = tables[f'Songs_{SIZE}']
users_table = tables[f'Users_{SIZE}']
listens_table = tables[f'Listens_{SIZE}']

smj = SortMergeJoin(num_pages_per_split=1000, verbose=False)

# Example: join Songs with Listens on song_id
sorted_join_result = smj.join(
    songs_table,
    listens_table,
    join_key1="song_id",
    join_key2="song_id",
    temp_dir="temp_songs_listens",
    columns_table1= ["song_id", "title"],
    columns_table2= ["song_id", "user_id"]
)

display(sorted_join_result)


Method '_external_sort' took 0.0553 seconds.
Method '_external_sort' took 1.3795 seconds.
Method '_streaming_inner_join' took 91.9589 seconds.
Method 'join' took 93.4050 seconds.


<__main__.ColumnarDbFile at 0x362f2d8e0>

peak memory: 1144.31 MiB, increment: 1004.33 MiB


In [66]:
# Correctness test vs pandas on a manageable subset
SIZE = "100MB"

# Choose columns and sample sizes
JOIN_KEY_LEFT = "song_id"
JOIN_KEY_RIGHT = "song_id"
LEFT_COLS = ["song_id", "title"]
RIGHT_COLS = ["song_id", "user_id"]

N_LEFT = 200_000
N_RIGHT = 200_000

# 1) Create temporary, smaller ColumnarDbFiles from the big tables
TEST_DIR = "temp_smj_correctness"
if os.path.exists(TEST_DIR):
    shutil.rmtree(TEST_DIR)
os.makedirs(TEST_DIR, exist_ok=True)

left_df_full = tables[f"Songs_{SIZE}"].retrieve_data(columns=LEFT_COLS)
right_df_full = tables[f"Listens_{SIZE}"].retrieve_data(columns=RIGHT_COLS)

left_df = left_df_full.head(N_LEFT).reset_index(drop=True)
right_df = right_df_full.head(N_RIGHT).reset_index(drop=True)

left_tmp = ColumnarDbFile("LeftTest", file_dir=TEST_DIR)
right_tmp = ColumnarDbFile("RightTest", file_dir=TEST_DIR)
left_tmp.build_table(left_df)
right_tmp.build_table(right_df)

# 2) Run your streaming SMJ on the temp tables
smj = SortMergeJoin()
smj_out = smj.join(
    left_tmp,
    right_tmp,
    join_key1=JOIN_KEY_LEFT,
    join_key2=JOIN_KEY_RIGHT,
    temp_dir=os.path.join(TEST_DIR, "join_out"),
    columns_table1=LEFT_COLS,
    columns_table2=RIGHT_COLS,
)

# 3) Materialize results
# SMJ output keeps only the left join key (your implementation drops the right duplicate)
smj_df = smj_out.retrieve_data()

# Pandas baseline: merge then drop the duplicate right key to match SMJ schema
baseline_df = pd.merge(
    left_df,
    right_df,
    left_on=JOIN_KEY_LEFT,
    right_on=JOIN_KEY_RIGHT,
    how="inner",
)
baseline_df = baseline_df[LEFT_COLS + [c for c in RIGHT_COLS if c != JOIN_KEY_RIGHT]]

# 4) Normalize order and compare
order_cols = [JOIN_KEY_LEFT] + [c for c in LEFT_COLS if c != JOIN_KEY_LEFT] + [c for c in RIGHT_COLS if c != JOIN_KEY_RIGHT]
smj_df = smj_df.sort_values(by=order_cols).reset_index(drop=True)
baseline_df = baseline_df.sort_values(by=order_cols).reset_index(drop=True)

# Ensure identical dtypes for fair compare when feasible
for c in order_cols:
    if c in smj_df.columns and c in baseline_df.columns:
        try:
            baseline_df[c] = baseline_df[c].astype(smj_df[c].dtype)
        except Exception:
            pass

# 5) Assertions and quick diagnostics
same_shape = smj_df.shape == baseline_df.shape
same_rows = smj_df.equals(baseline_df)

print("Row counts - SMJ vs Pandas:", smj_df.shape[0], baseline_df.shape[0])
print("Column sets equal:", set(smj_df.columns) == set(baseline_df.columns))
print("Shapes equal:", same_shape)
print("Frames equal:", same_rows)

if not same_rows:
    # Show a small diff preview
    merged_chk = smj_df.merge(
        baseline_df, how="outer", indicator=True, on=order_cols
    )
    print("Mismatched samples:")
    display(merged_chk[merged_chk["_merge"] != "both"].head(10))

Method '_external_sort' took 0.0625 seconds.
Method '_external_sort' took 0.0451 seconds.
Method '_streaming_inner_join' took 8.8478 seconds.
Method 'join' took 8.9557 seconds.
Row counts - SMJ vs Pandas: 200000 200000
Column sets equal: True
Shapes equal: True
Frames equal: True


Implement GROUP BY after joins:
- Here you could use `pd.groupby` or do manual aggregation

In [90]:
# GROUP BY s.song_id, s.title
class HashGroupbyAverageAndDistinct():
    def __init__(self, num_partitions, parquet_batch_size, use_streaming=False):
        self.num_partitions = num_partitions
        self.parquet_batch_size = parquet_batch_size
        self.use_streaming = use_streaming
    
    def _make_partition_path(self, temp_dir, part_id):
        return os.path.join(temp_dir, f"group_part{part_id}.parquet")
    
    @timer
    def groupby_average_distinct(self,
                        table: ColumnarDbFile, 
                        groupby_cols: List[str],
                        average_col: str, 
                        average_col_name: str, 
                        distinct_col: str,
                        distinct_col_name: str,
                        select_cols: List[str], 
                        temp_dir='groupby_temp') -> ColumnarDbFile:
        """
        Perform:
            SELECT select_cols..., AVG(average_col) AS average_col_name, COUNT(DISTINCT distinct_col)
            FROM table
            GROUP BY groupby_cols...
        
        Hash partitioning on (concatenation of) groupby_cols, and then in-memory aggregation per partition

        Assumptions:
        - groupby_col is non-empty
        - select_cols is a subset of groupby_col
        - Per-partition hash table fits in memory
        
        Uses self.use_streaming to determine whether to use ParquetWriter streaming
        for efficient batch writes.
        """
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        
        os.makedirs(temp_dir, exist_ok=True)
        # hash on groupby columns (safe: same group always shares this)

        hash_partition(
            table=table,
            hash_keys=groupby_cols,  # hash on groupby columns (safe: same group always shares this)
            num_partitions=self.num_partitions,
            parquet_batch_size=self.parquet_batch_size,
            hash_value_fn=HASHVALUE,
            make_partition_path_fn=partial(self._make_partition_path, temp_dir),
            columns= list(set(groupby_cols + [average_col, distinct_col])),
        )

        output_db = ColumnarDbFile(f"{table.table_name}_groupby_avg")
        
        # Clean up any existing files in the output directory
        if os.path.exists(output_db.base_file_name):
            for file_path in output_db.get_all_parquet_paths():
                os.remove(file_path)
        
        # Start streaming if enabled
        if self.use_streaming:
            output_db.start_stream()
        for part_id in range(self.num_partitions):
            part_path = self._make_partition_path(temp_dir, part_id)
            if not os.path.exists(part_path):
                continue

            # In-memory hash table for this partition:
            # key: tuple of groupby_col values
            # value: Tuple of (sum of average_col, count of average_col, set of distinct distinct_col values)
            SUM_IDX = 0
            COUNT_IDX = 1
            DISTINCT_SET_IDX = 2
            agg_map: Dict[Any, Tuple[float, int, set]] = {}

            pf = pq.ParquetFile(part_path)
            for batch in pf.iter_batches(batch_size=self.parquet_batch_size):
                df = batch.to_pandas()

                grouped = (
                    df.groupby(groupby_cols)
                    .agg(
                        sum_avg=(average_col, "sum"),
                        cnt_avg=(average_col, "count"),  # SQL AVG ignores NULLs
                        distinct_set=(distinct_col, lambda s: set(s.dropna()))  # SQL ignores NULL in COUNT DISTINCT
                    )
                )

                for key_tuple, row in grouped.iterrows():
                    if not isinstance(key_tuple, tuple):
                        key_tuple = (key_tuple,)

                    state = agg_map.setdefault(key_tuple, [0.0, 0, set()])
                    state[SUM_IDX] += row["sum_avg"]
                    state[COUNT_IDX] += row["cnt_avg"]
                    state[DISTINCT_SET_IDX] |= row["distinct_set"]


            # Turn the per-partition hash table into a DataFrame and append
            if agg_map:
                # Pre-compute column index mapping to avoid repeated index() calls
                col_idx_map = {col: groupby_cols.index(col) for col in select_cols}
                
                # Build output efficiently using list comprehensions
                out_rows = []
                for key_tuple, state in agg_map.items():
                    # Ensure key_tuple is a tuple (handles single vs multi-column)
                    if not isinstance(key_tuple, tuple):
                        key_tuple = (key_tuple,)
                    
                    row_dict = {col: key_tuple[col_idx_map[col]] for col in select_cols}
                    row_dict[average_col_name] = state[SUM_IDX] / state[COUNT_IDX]
                    row_dict[distinct_col_name] = len(state[DISTINCT_SET_IDX])
                    out_rows.append(row_dict)

                if out_rows:
                    out_df = pd.DataFrame(out_rows)
                    output_db.append_data(out_df)

        # Stop streaming if it was enabled
        if self.use_streaming:
            output_db.stop_stream()
        shutil.rmtree(temp_dir)

        return output_db

In [92]:
%%memit
# test implementation

SIZE = "100MB" #["100MB", "1GB", "10GB"]
SAMPLE = 100
USE_STREAMING = False

listens_table = tables[f'Listens_{SIZE}']

groupby_average_distinct = HashGroupbyAverageAndDistinct(
    num_partitions=4,
    parquet_batch_size=10000000,
    use_streaming=USE_STREAMING
)

results = groupby_average_distinct.groupby_average_distinct(
    table=listens_table,
    groupby_cols=['song_id'],
    average_col='user_id',
    average_col_name='avg_user_id',
    distinct_col='user_id',
    distinct_col_name='distinct_user_id',
    select_cols=['song_id']
)

display(results.retrieve_data(sample=None))

Method 'groupby_average_distinct' took 0.6927 seconds.


Unnamed: 0,song_id,avg_user_id,distinct_user_id
0,0,23262.696078,101
1,4,25417.038835,103
2,8,22609.904762,105
3,12,24474.144330,97
4,16,22391.166667,83
...,...,...,...
9995,9983,23652.714286,98
9996,9987,25471.716981,106
9997,9991,24112.888889,99
9998,9995,26495.407407,108


peak memory: 1119.39 MiB, increment: 155.36 MiB


# Section 4: Query Planning & Optimization

In this section, you'll implement smart query planning using metadata analysis. The key idea is to **avoid loading data unnecessarily** by:
1. Analyzing Parquet metadata first (row counts, column names, file sizes)
2. Making intelligent decisions about join order and algorithm selection
3. Loading only the columns you actually need for the query

In [69]:
def analyze_metadata_before_loading(file_paths):
    
    metadata = {}
    for table_name, coldb in file_paths.items():
        # Use provided helpers for counts/sizes
        tm = coldb.table_metadata() # num_files, total_rows, total_compressed_bytes
        du = coldb.table_disk_usage() # num_files, total_bytes (on disk)

        # Get schema from the first parquet file
        base_dir = coldb.base_file_name
        parquet_files = sorted(glob.glob(f"{base_dir}/*.parquet"))
        columns = {}
        if parquet_files:
            pf = pq.ParquetFile(parquet_files[0])
            # Use Arrow schema to read field names and types
            columns = {field.name: str(field.type) for field in pf.schema_arrow}

        metadata[table_name] = {
            "num_files": tm["num_files"],
            "rows": int(tm["total_rows"]),
            "columns": columns,
            "bytes_on_disk": int(du["total_bytes"]),
            "total_compressed_bytes": int(tm["total_compressed_bytes"]),
            "can_process_in_12GB": ColumnarDbFile.can_process_parquet(int(du["total_bytes"]))
        }

    return metadata


def plan_query_execution(metadata, parsed_query, memory_bytes=12 * 1024**3, overhead_per_row=24):
    """
    Use parsed SQL + table metadata to:
      - pick columns to read (column pruning)
      - choose join order (follow SQL joins; build on smaller side)
      - choose algorithm per step (HPJ if build fits in memory, else SMJ)
    Returns a plan dict.
    """
    # 0) Helpers
    def meta_key_for(table_name_lower: str) -> str:
        #  metadata keys are capitalized: 'Songs','Users','Listens'
        return table_name_lower.capitalize()

    def _size_of_type(t: str) -> int:
        t = (t or "").lower()
        if "int64" in t or "float64" in t or "double" in t or "timestamp" in t: return 8
        if "int32" in t or "float32" in t: return 4
        if "bool" in t: return 1
        return 16  # fallback for strings

    def _estimate_build_bytes(rows: int, key_type: str, payload_bytes=0, overhead=48, load=1.3):
        per_row = _size_of_type(key_type) + payload_bytes + overhead
        return int(rows * per_row * load)

    def _size_of_type(t: str) -> int:
        t = (t or "").lower()
        if "int64" in t or "float64" in t or "double" in t or "timestamp" in t: return 8
        if "int32" in t or "float32" in t: return 4
        if "bool" in t: return 1
        return 16  # fallback for strings

    def _estimate_build_bytes(rows: int, key_type: str, payload_bytes=0, overhead=48, load=1.3):
        per_row = _size_of_type(key_type) + payload_bytes + overhead
        return int(rows * per_row * load)

    def _io_costs_hpjsmj(bytes_L, bytes_R, mem_bytes, page_size=64*1024*1024, Cr=1, Cw=1):
        # Pages
        P_L = math.ceil(bytes_L / page_size)
        P_R = math.ceil(bytes_R / page_size)
        # Buffers (RAM pages)
        B = max(2, mem_bytes // page_size)

        # HPJ cost (special case Cr=Cw=1): 3*(P(L)+P(R))
        hpj_ios = 3 * (P_L + P_R)

        # External BigSort cost per table (simplified):
        # 2*N * (1 + ceil(log_{B-1}(N))) with B buffers
        def bigsort_cost(P):
            base = max(2, B - 1)
            passes = max(1, math.ceil(math.log(max(1, P), base)))
            return 2 * P * (1 + passes)

        # SMJ = BigSort(L)+BigSort(R)+Merge(L,R)
        smj_ios = bigsort_cost(P_L) + bigsort_cost(P_R) + (P_L + P_R)  # merge read

        return hpj_ios, smj_ios

    def decide_join_algo(meta, left_tbl, right_tbl, left_key, right_key,
                    rows_left=None, rows_right=None,
                    mem_budget_bytes=12*1024**3, overhead=48,
                    left_already_sorted=False, right_already_sorted=False,
                    need_sorted_output=False):
        """
        Choose join algorithm based on:
        1. If smaller table fits in memory (5× expansion) --> HPJ
        2. Otherwise --> SMJ
        """
        # Sizes and rows
        L = left_tbl
        R = right_tbl
        rows_L = rows_left  if rows_left is not None else int(meta[L]["rows"])
        rows_R = rows_right if rows_right is not None else int(meta[R]["rows"])
        tL = meta[L]["columns"].get(left_key, "int64")
        tR = meta[R]["columns"].get(right_key, "int64")
        
        bytes_L = int(meta[L]["bytes_on_disk"])
        bytes_R = int(meta[R]["bytes_on_disk"])
        
        # Determine smaller table
        smaller_size = min(bytes_L, bytes_R)
        smaller_table = L if bytes_L <= bytes_R else R
        larger_table = R if smaller_table == L else L
        
        # Check if smaller table fits in memory for HPJ (5× expansion factor for pandas)
        estimated_ram_needed = smaller_size * 5
        fits_in_memory = estimated_ram_needed < mem_budget_bytes
    
        # If smaller table fits → HPJ is faster
        if fits_in_memory:
            algo = "HPJ"
            build_tbl = smaller_table
            build_side = "L" if smaller_table == L else "R"
            
            # Compute I/O costs for reporting
            build_L = _estimate_build_bytes(rows_L, tL, overhead=overhead)
            build_R = _estimate_build_bytes(rows_R, tR, overhead=overhead)
            
            P_L = math.ceil(bytes_L / (64*1024*1024))
            P_R = math.ceil(bytes_R / (64*1024*1024))
            hpj_ios, smj_ios = _io_costs_hpjsmj(bytes_L, bytes_R, mem_budget_bytes)
            smj_ios = None  # not computed since we chose HPJ
            
            return {
                "algorithm": algo,
                "build_side": build_side,
                "build_table": build_tbl,
                "hpj_ios": hpj_ios,
                "smj_ios": smj_ios,
                "hpj_build_bytes": {"L": build_L, "R": build_R},
                "estimated_ram_needed": estimated_ram_needed,
                "fits_in_memory": True,
            }
        
        # Else: SMJ for limited memory
        
        else:
            algo = "SMJ"
            
            # Compute I/O costs
            P_L = math.ceil(bytes_L / (64*1024*1024))
            P_R = math.ceil(bytes_R / (64*1024*1024))
            B = max(2, mem_budget_bytes // (64*1024*1024))
            
            # BigSort cost per table (accounts for already-sorted data)
            def bigsort_cost(P, already_sorted=False):
                if already_sorted:
                    return 0  # no sort needed, just read once
                base = max(2, B - 1)
                passes = max(1, math.ceil(math.log(max(1, P), base)))
                return 2 * P * (1 + passes)
            
            sort_cost_L = bigsort_cost(P_L, left_already_sorted)
            sort_cost_R = bigsort_cost(P_R, right_already_sorted)
            
            # SMJ total: sort both + merge (read both once)
            smj_ios = sort_cost_L + sort_cost_R + (P_L + P_R)
            
            # HPJ cost for comparison: 3*(P(L)+P(R)) when Cr=Cw=1
            hpj_ios = 3 * (P_L + P_R)
            
            # Optionally prefer SMJ if output needs to be sorted for GROUP BY
            if need_sorted_output:
                smj_ios *= 0.9  # slight bias toward SMJ
            
            build_L = _estimate_build_bytes(rows_L, tL, overhead=overhead)
            build_R = _estimate_build_bytes(rows_R, tR, overhead=overhead)
            
            return {
                "algorithm": algo,
                "build_side": None,  # SMJ has no build side
                "build_table": None,
                "hpj_ios": hpj_ios,
                "smj_ios": smj_ios,
                "hpj_build_bytes": {"L": build_L, "R": build_R},
                "estimated_ram_needed": estimated_ram_needed,
                "fits_in_memory": False,
                "left_already_sorted": left_already_sorted,
                "right_already_sorted": right_already_sorted,
            }
            
    def _ndv_guess(meta, tbl, col, coverage=0.9):
        """Estimate distinct values for a column in a base table."""
        if tbl == "Songs" and col == "song_id": return int(meta["Songs"]["rows"])
        if tbl == "Users" and col == "user_id": return int(meta["Users"]["rows"])
        if tbl == "Listens" and col == "song_id": return int(min(meta["Songs"]["rows"], meta["Listens"]["rows"]) * coverage)
        if tbl == "Listens" and col == "user_id": return int(min(meta["Users"]["rows"], meta["Listens"]["rows"]) * coverage)
        return int(meta[tbl]["rows"])

    def _rows_after_join(left_rows, right_rows, ndvL, ndvR, left_unique=False, right_unique=False):
        """Estimate output rows for an equi-join."""
        if left_unique and not right_unique:
            return right_rows  # PK-FK: output size = FK side
        if right_unique and not left_unique:
            return left_rows
        # Generic: rows_out ≈ (rows_L × rows_R) / max(ndv_L, ndv_R)
        denom = max(1, max(ndvL, ndvR))
        est = (left_rows * right_rows) // denom
        # Bound between max(rows) and rows_L × rows_R
        return max(min(est, left_rows * right_rows), max(left_rows, right_rows))

    # 1) Alias: table-name mappings
    alias_to_table_lower = parsed_query["tables"]
    alias_to_meta = {a: meta_key_for(t) for a, t in alias_to_table_lower.items()}

    # 2) Columns needed per alias
    cols_needed_by_alias: dict[str, set] = {a: set() for a in alias_to_table_lower.keys()}

    # 2a) Joins (include join keys)
    for j in parsed_query["joins"]["Joins"]:
        cols_needed_by_alias[j["left_alias"]].add(j["left_column"])
        cols_needed_by_alias[j["right_alias"]].add(j["right_column"])

    # 2b) GROUP BY columns
    for a, c in parsed_query.get("GroupBy", []):
        cols_needed_by_alias[a].add(c)

    # 2c) SELECT items and aggregations
    for item in parsed_query.get("select", []):
        if item["kind"] == "column":
            a, c = item["source"]
            cols_needed_by_alias[a].add(c)

    for agg in parsed_query.get("aggregations", {}).values():
        a, c = agg["source"]
        cols_needed_by_alias[a].add(c)

    # 2d) ORDER BY columns (aggregations already covered)
    for ob in parsed_query.get("orderBy", []):
        if ob["kind"] == "column":
            a, c = ob["source"]
            cols_needed_by_alias[a].add(c)

    # 2e) Convert to real table names
    columns_to_load = {}
    for a, cols in cols_needed_by_alias.items():
        tkey = alias_to_meta[a]
        # intersect with actual table columns for safety
        actual_cols = set(metadata[tkey]["columns"].keys())
        columns_to_load[tkey] = sorted(list(set(cols) & actual_cols)) if actual_cols else sorted(list(cols))

    # 3) Join order and algorithm selection
    left_alias, left_table_lower = parsed_query["joins"]["base_table"]
    left_table_key = meta_key_for(left_table_lower)

    # Track running estimate for the "current" intermediate result
    current_rows = int(metadata[left_table_key]["rows"])
    current_key_ndv = {}
    left_sorted_on = None  # track if current left side is sorted (key name or None)

    joins = parsed_query["joins"]["Joins"]
    join_plan = []

    for join in joins:
        right_table_key = meta_key_for(join["joined_table_name"])
        left_key  = join["left_column"]
        right_key = join["right_column"]

        # Check if sides are already sorted
        left_already_sorted = (left_sorted_on == left_key)
        right_already_sorted = False  # base tables assumed unsorted

        # --- Algorithm decision ---
        choice = decide_join_algo(
            metadata,
            left_table_key,
            right_table_key,
            left_key,
            right_key,
            rows_left=current_rows,
            mem_budget_bytes=12*1024**3,
            overhead=48,
            left_already_sorted=left_already_sorted,
            right_already_sorted=right_already_sorted,
            need_sorted_output=len(parsed_query.get("GroupBy", [])) > 0,
        )
        # Add join keys and table names to the choice dict
        choice["left_table"] = left_table_key
        choice["right_table"] = right_table_key
        choice["left_key"] = left_key
        choice["right_key"] = right_key


        join_plan.append(choice)
        print(f"Join {left_table_key}({current_rows} rows) ⋈ {right_table_key}: {choice['algorithm']}")

        # --- Update running estimate ---
        provider_alias = join["left_alias"] if join["left_alias"] != join["joined_table_alias"] else join["right_alias"]
        provider_tbl = meta_key_for(parsed_query["tables"][provider_alias])

        ndvL = current_key_ndv.get(left_key, _ndv_guess(metadata, provider_tbl, left_key))
        ndvR = _ndv_guess(metadata, right_table_key, right_key)

        left_unique  = (provider_tbl == "Songs" and left_key == "song_id") or \
                    (provider_tbl == "Users" and left_key == "user_id")
        right_unique = (right_table_key == "Songs" and right_key == "song_id") or \
                    (right_table_key == "Users" and right_key == "user_id")

        right_rows = int(metadata[right_table_key]["rows"])
        next_rows = _rows_after_join(current_rows, right_rows, ndvL, ndvR, left_unique, right_unique)

        current_rows = next_rows
        current_key_ndv[left_key]  = min(ndvL, current_rows)
        current_key_ndv[right_key] = min(ndvR, current_rows)

        # Track if output is sorted (SMJ produces sorted output on join key)
        if choice["algorithm"] == "SMJ":
            # SMJ output is sorted on both keys; track the one we'll use next
            left_sorted_on = left_key  # or right_key, depending on which side continues
        else:
            # HPJ output is unsorted
            left_sorted_on = None

        # Advance left_table_key
        if choice["algorithm"] == "HPJ":
            left_table_key = right_table_key if choice["build_table"] == left_table_key else left_table_key
        else:
            left_table_key = right_table_key if right_rows >= current_rows else left_table_key


    print(f"\nFinal estimated rows after all joins: {current_rows}")

    # 4) Assemble final plan dictionary
    plan = {
        "columns_to_load": columns_to_load,        # Dict[table_name, List[column]]
        "join_plan": join_plan,                    # List of join decisions from the loop
        "estimated_final_rows": current_rows,      # Estimated rows after all joins
        "group_by": parsed_query.get("GroupBy", []),
        "select": parsed_query.get("select", []),
        "aggregations": parsed_query.get("aggregations", {}),
        "order_by": parsed_query.get("orderBy", []),
        "memory_budget_bytes": memory_bytes,
    }

    return plan

In [58]:
def get_tables_dict_for_query_planner(tables, size="100MB"):
    '''
    Returns a new table dictionary to be size specific
    1. selects entries from original table dictionary that ends with size
    2. drops size prefix from dictionary key
    '''
    return {name.split("_")[0]: db for name,db in tables.items() if name.endswith(size)}

In [100]:
class QueryPlanner:
    def __init__(self, tables, query):
        self.tables = tables
        self.query = query
    
    def parse_query(self):
        self.parsed_query = parse_sql(self.query)
    
    def load_metadata(self): 
        self.metadata = analyze_metadata_before_loading(self.tables)
    
    def load_plan(self):
        self.plan = plan_query_execution(self.metadata, self.parsed_query)

        # columns_to_load: Dict[str, List[str]]
        # Maps table name to list of column names to read.
        # Only includes columns needed for joins, filters, GROUP BY, aggregations, and SELECT.
        # Example: {"Songs": ["song_id", "title"], "Listens": ["song_id", "user_id"], "Users": ["user_id", "age"]}
        self.columns_to_load = self.plan["columns_to_load"]

        # join_plan: List[Dict]
        # Ordered list of join steps. Each dict contains:
        #   - "algorithm": "HPJ" or "SMJ"
        #   - "build_side": "L" or "R" (for HPJ only, tells the smaller side, which side to build hash table on)
        #   - "build_table": table name like "Songs" (for HPJ only, the smaller side)
        #   - "hpj_ios", "smj_ios": estimated I/O costs for comparison
        #   - "hpj_build_bytes": {"L": bytes, "R": bytes} - memory needed for each side
        #   - "fits_in_memory": bool (whether build side fits in RAM)
        self.join_plan = self.plan["join_plan"]

        # group_by: List[Tuple[str, str]]
        # List of (alias, column_name) tuples for GROUP BY clause.
        # Example: [('s', 'song_id'), ('s', 'title')] for GROUP BY s.song_id, s.title
        self.group_by = self.plan["group_by"]

        # select: List[Dict]
        # List of SELECT items. Each dict contains:
        #   - "kind": "column" or "aggregation"
        #   - For "column": {"kind": "column", "source": (alias, col_name), "alias": output_name or None}
        #   - For "aggregation": {"kind": "aggregation", "agg_key": int, "alias": output_name or None}
        # Example: [
        #   {"kind": "column", "source": ('s', 'song_id'), "alias": None},
        #   {"kind": "aggregation", "agg_key": 1, "alias": "avg_age"}
        # ]
        self.select = self.plan["select"]

        # aggregations: Dict[int, Dict]
        # Maps aggregation ID (agg_key) to aggregation spec. Each spec contains:
        #   - "func": "avg", "count", "sum", etc.
        #   - "source": (alias, column_name) tuple - which column to aggregate
        #   - "distinct": bool - whether COUNT(DISTINCT ...) or not
        #   - "output_name": string or None - the AS alias for output
        # Example: {
        #   1: {"func": "avg", "source": ('u', 'age'), "distinct": False, "output_name": "avg_age"},
        #   2: {"func": "count", "source": ('l', 'user_id'), "distinct": True, "output_name": None}
        # }
        self.aggregations = self.plan["aggregations"]

        # order_by: List[Dict]
        # List of ORDER BY clauses. Each dict contains:
        #   - "kind": "column" or "aggregation"
        #   - For "column": {"kind": "column", "source": (alias, col_name), "direction": "asc"|"desc"}
        #   - For "aggregation": {"kind": "aggregation", "agg_key": int, "direction": "asc"|"desc"}
        self.order_by = self.plan["order_by"]

    def plan(self):
        self.parse_query()
        self.load_metadata()
        self.load_plan()


class QueryExecutor:
    def __init__(self, tables, num_partitions=8, output_dir="temp", 
                 planner=None, size="100MB", use_streaming=True, 
                 join_algo_override = None,
                 parquet_batch_size=1000000,
                 num_pages_per_split=100):
        self.tables = tables
        self.num_partitions = num_partitions
        self.use_streaming = use_streaming
        self.parquet_batch_size = parquet_batch_size
        self.output_dir = output_dir
        self.planner = planner or QueryPlanner(get_tables_dict_for_query_planner(tables), query)
        self.num_pages_per_split = num_pages_per_split
        self.join_algo = join_algo_override
        os.makedirs(self.output_dir, exist_ok=True)

        self.planner.plan()

    def execute_hardcoded_query(self):
        """
        Executes the following SQL query:

        SELECT s.song_id, AVG(u.age) AS avg_age,
        COUNT(DISTINCT l.user_id)
        FROM Songs s
        JOIN Listens l ON s.song_id = l.song_id
        JOIN Users u ON l.user_id = u.user_id
        GROUP BY s.song_id, s.title
        ORDER BY COUNT(DISTINCT l.user_id) DESC, s.song_id;
        """
        # Hardcoded
        columns = self.planner.columns_to_load

        
        join_order = ["Listens"] + [join_step["build_table"] for join_step in self.planner.join_plan]
        join_algorithm =  [
            (self.join_algo if self.join_algo else join_step["algorithm"])
            for join_step in self.planner.join_plan
        ]
        join_keys = [join_step["left_key"] for join_step in self.planner.join_plan]

        # do joins
        result = self.tables[f"{join_order[0]}_{SIZE}"]
        columns_table1 = columns[join_order[0]]
        for i in range(1, len(join_order)):
            table = self.tables[f"{join_order[i]}_{SIZE}"]
            if join_algorithm[i-1] == "HPJ":
                hpj = FastHashPartitionJoin(num_partitions=self.num_partitions, parquet_batch_size=self.parquet_batch_size, use_streaming=self.use_streaming)
                result = hpj.join(result, table, join_keys[i-1], join_keys[i-1], columns_table1=columns_table1, columns_table2=columns[join_order[i]])
            else:
                smj = SortMergeJoin(num_pages_per_split=self.num_pages_per_split)
                result = smj.join(result, table, join_keys[i-1], join_keys[i-1], columns_table1=columns_table1, columns_table2=columns[join_order[i]])
            columns_table1 = None # after first join, keep all columns from Left table

        # do group by 
        groupby_cols = ["song_id", "title"]
        avg_col = "age"
        avg_col_name = "avg_age"
        distinct_col = "user_id"
        distinct_col_name = "count_distinct_user_id"
        select_cols = ["song_id"]

        groupby_average_distinct = HashGroupbyAverageAndDistinct(
            num_partitions=self.num_partitions,
            parquet_batch_size=self.parquet_batch_size,
            use_streaming=self.use_streaming
        )

        result = groupby_average_distinct.groupby_average_distinct(
            result, 
            groupby_cols=groupby_cols,
            average_col=avg_col,
            average_col_name=avg_col_name,
            distinct_col=distinct_col,
            distinct_col_name=distinct_col_name,
            select_cols=select_cols
        )
        
        # sort by count distinct
        smj_sorter = SortMergeJoin()
        sorted_result = smj_sorter._external_sort(
            result,
            "count_distinct_user_id",
            self.output_dir,
            "result",
            ascending = False,
            tiebreak_key="song_id"
        )

        return sorted_result

# Section 5: Performance Benchmarking

In [105]:
def benchmark_query(executor, dataset_size):
    """Benchmark the query execution time and memory usage."""
    print(f"\nBenchmarking with {dataset_size} dataset...")
    start_mem = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    start_time = time.time()

    result = executor.execute_hardcoded_query()

    end_time = time.time()
    end_mem = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)

    print(f"Execution Time: {end_time - start_time:.2f} seconds")
    print(f"Memory Usage: {end_mem - start_mem:.2f} MB")
    return result

## 100MB Benchmark

In [109]:
%%memit
SIZE = "100MB"
executor_smj = QueryExecutor(tables, size=SIZE, join_algo_override="SMJ", 
                             output_dir="temp_smj_100mb")
output_smj_100mb = benchmark_query(executor_smj, SIZE)



Method 'parse_sql' took 0.0001 seconds.
Join Songs(10000 rows) ⋈ Listens: HPJ
Join Listens(1000000 rows) ⋈ Users: HPJ

Final estimated rows after all joins: 1000000

Benchmarking with 100MB dataset...
Method '_external_sort' took 0.1552 seconds.
Method '_external_sort' took 0.0714 seconds.
Method '_streaming_inner_join' took 9.1655 seconds.
Method 'join' took 9.3925 seconds.
Method '_external_sort' took 0.8786 seconds.
Method '_external_sort' took 0.0721 seconds.
Method '_streaming_inner_join' took 48.1490 seconds.
Method 'join' took 49.1001 seconds.
Method 'groupby_average_distinct' took 3.8671 seconds.
Method '_external_sort' took 0.0611 seconds.
Execution Time: 62.42 seconds
Memory Usage: 87.05 MB
peak memory: 965.44 MiB, increment: 149.02 MiB


In [116]:
%%memit
executor_hpj = QueryExecutor(tables, size=SIZE, join_algo_override="HPJ", 
                             output_dir="temp_hpj_100mb")
output_hpj_100mb = benchmark_query(executor_hpj, SIZE)



Method 'parse_sql' took 0.0002 seconds.
Join Songs(10000 rows) ⋈ Listens: HPJ
Join Listens(1000000 rows) ⋈ Users: HPJ

Final estimated rows after all joins: 1000000

Benchmarking with 100MB dataset...
Method '_hash_partition' took 0.3273 seconds.
Method '_hash_partition' took 0.0120 seconds.
Method '_build_hash_map' took 0.0135 seconds.
Method '_build_hash_map' took 0.0133 seconds.
Method '_build_hash_map' took 0.0131 seconds.
Method '_build_hash_map' took 0.0128 seconds.
Method '_build_hash_map' took 0.0130 seconds.
Method '_build_hash_map' took 0.0140 seconds.
Method '_build_hash_map' took 0.0133 seconds.
Method '_build_hash_map' took 0.0129 seconds.
Method 'join' took 2.0401 seconds.
Method '_hash_partition' took 0.3189 seconds.
Method '_hash_partition' took 0.0220 seconds.
Method '_build_hash_map' took 0.0596 seconds.
Method '_build_hash_map' took 0.0604 seconds.
Method '_build_hash_map' took 0.0600 seconds.
Method '_build_hash_map' took 0.0605 seconds.
Method '_build_hash_map' too

In [117]:
# ensure results are equal, utilizes custom implemented equals method on ColumnarDbFile
output_hpj_100mb == output_smj_100mb

True

## 1GB Benchmark

In [122]:
%%memit
SIZE = "1GB"
executor_smj = QueryExecutor(tables, size=SIZE, join_algo_override="SMJ", 
                             output_dir="temp_smj_1GB")
output_smj_1gb = benchmark_query(executor_smj, SIZE)

Method 'parse_sql' took 0.0001 seconds.
Join Songs(10000 rows) ⋈ Listens: HPJ
Join Listens(1000000 rows) ⋈ Users: HPJ

Final estimated rows after all joins: 1000000

Benchmarking with 1GB dataset...
Method '_external_sort' took 1.4152 seconds.
Method '_external_sort' took 0.0861 seconds.
Method '_streaming_inner_join' took 89.3497 seconds.
Method 'join' took 90.8516 seconds.
Method '_external_sort' took 7.7476 seconds.
Method '_external_sort' took 0.0875 seconds.
Method '_streaming_inner_join' took 510.1438 seconds.
Method 'join' took 517.9800 seconds.
Method 'groupby_average_distinct' took 43.7831 seconds.
Method '_external_sort' took 0.0842 seconds.
Execution Time: 652.70 seconds
Memory Usage: 1454.02 MB
peak memory: 2856.08 MiB, increment: 2659.83 MiB


In [120]:
%%memit
executor_hpj = QueryExecutor(tables, size=SIZE, join_algo_override="HPJ", 
                             output_dir="temp_hpj_1gb")
output_hpj_1gb = benchmark_query(executor_hpj, SIZE)

Method 'parse_sql' took 0.0002 seconds.
Join Songs(10000 rows) ⋈ Listens: HPJ
Join Listens(1000000 rows) ⋈ Users: HPJ

Final estimated rows after all joins: 1000000

Benchmarking with 1GB dataset...
Method '_hash_partition' took 2.8575 seconds.
Method '_hash_partition' took 0.0751 seconds.
Method '_build_hash_map' took 0.1243 seconds.
Method '_build_hash_map' took 0.1247 seconds.
Method '_build_hash_map' took 0.1243 seconds.
Method '_build_hash_map' took 0.1278 seconds.
Method '_build_hash_map' took 0.1259 seconds.
Method '_build_hash_map' took 0.1261 seconds.
Method '_build_hash_map' took 0.1256 seconds.
Method '_build_hash_map' took 0.1253 seconds.
Method 'join' took 12.3049 seconds.
Method '_hash_partition' took 3.4995 seconds.
Method '_hash_partition' took 0.1758 seconds.
Method '_build_hash_map' took 0.6315 seconds.
Method '_build_hash_map' took 0.6109 seconds.
Method '_build_hash_map' took 0.6154 seconds.
Method '_build_hash_map' took 0.6145 seconds.
Method '_build_hash_map' took

# 10 GB Challenge
Design considerations that allowed this our discussed below

In [121]:
%%memit
SIZE = "10GB"
executor_hpj = QueryExecutor(tables, size=SIZE, join_algo_override="HPJ", 
                             output_dir="temp_hpj_10GB")
output_hpj = benchmark_query(executor_hpj, SIZE)

Method 'parse_sql' took 0.0002 seconds.
Join Songs(10000 rows) ⋈ Listens: HPJ
Join Listens(1000000 rows) ⋈ Users: HPJ

Final estimated rows after all joins: 1000000

Benchmarking with 10GB dataset...
Method '_hash_partition' took 29.2806 seconds.
Method '_hash_partition' took 0.5365 seconds.
Method '_build_hash_map' took 1.2689 seconds.
Method '_build_hash_map' took 1.2185 seconds.
Method '_build_hash_map' took 1.2179 seconds.
Method '_build_hash_map' took 1.2235 seconds.
Method '_build_hash_map' took 1.2203 seconds.
Method '_build_hash_map' took 1.2251 seconds.
Method '_build_hash_map' took 1.2227 seconds.
Method '_build_hash_map' took 1.2263 seconds.
Method 'join' took 126.1223 seconds.
Method '_hash_partition' took 43.6144 seconds.
Method '_hash_partition' took 1.4022 seconds.
Method '_build_hash_map' took 6.2268 seconds.
Method '_build_hash_map' took 6.1643 seconds.
Method '_build_hash_map' took 6.2018 seconds.
Method '_build_hash_map' took 6.1267 seconds.
Method '_build_hash_map' 

## Performance Analysis

| DATASET SIZE | ALGORITHM | TOTAL TIME (S) | JOIN TIME (First + Second Join) (S) | PEAK MEMORY (GB) |
| ------------ | --------- | -------------- | ------------- | ---------------- |
| 100MB        | HPJ       | 7.71   | 2.04 + 2.39 = 4.43   |  0.93 |
| 100MB        | SMJ       | 62.42   | 9.39 + 49.1 = 58.49  | 1.01  |
| 1GB          | HPJ       | **66.83**  | 12.30 + 17.50 = 29.80  | 1.75   |
| 1GB          | SMJ       | 625.41   | 91.94 + 491.85 = 583.79    | 3.00    |
| **10GB**          | HPJ       | 918.81   | 29.28 + 126.12 = 155.4    | 2.33   |


Our HPJ algorithm wildly outperformed the benchmarks in terms of speed and displayed great memory efficiency. We implemented the following optimizations to achieve speed and memory performance:

- Partitioned Processing: Divides data into smaller partitions using hash-based distribution, enabling processing of datasets larger than available RAM by only loading one partition pair at a time into memory.
- Asymmetric Join Strategy: Identifies the smaller side of each partition to build the in-memory hash table, then streams the larger side in configurable batches (default 50k rows), reducing peak memory usage from O(2×partition_size) to O(smaller_partition + batch_size).
- Vectorized Operations for join: Replaces row-by-row Python loops with bulk NumPy/Pandas operations - using groupby for hash map construction (5-10x faster), array-based index collection, and single bulk .iloc[array] calls instead of thousands of individual .iloc[i] accesses.
- Streaming Output Mode: Optionally writes results to a single Parquet file using ParquetWriter instead of creating multiple files, reducing I/O overhead and file fragmentation for large result sets. See ColumnarDbFile.start_stream()
    - this provided minimal speedup (~1-2 seconds) but provides better disk usage

Within our SMJ algorithm, we identified several bottlenecks. The primary bottleneck in Sort-Merge Join (SMJ) is the external sorting phase, which incurs O(n log n) CPU cost and 2–3× I/O amplification per sort pass when data exceeds memory. For large tables sorting dominates runtime, especially with small memory budgets that force many merge passes. Key optimizations include: 
1) increasing buffer size: to reduce the number of external merge passes—each doubling of B can halve the pass count via B-way merging
2) parallel sorting: by partitioning input data across threads and sorting runs concurrently before a final merge
3) Hybrid approaches: For workloads with skewed join keys, we could have implemented that hash-partition only the heavy-hitter groups after sorting can prevent memory blowups during the merge without re-sorting the entire dataset.

**Design Choices to Allow for 10GB dataset processing**: The most important design consideration was to never load a full table into RAM. Instead we paginate the reading. We relied on loading datasets in batches (see ColumnarDBFile.iter_pages) or by calling iter_batches on intermediate parquet files. Keeping a batch size that is large but still within RAM constraints was crucial for performance. 

These strategies were implemented in external_sort to avoid fully sorting in memory. 

In our HPJ join implementation, we rely on keep a large number of partitions (~8) and only loading the smaller partition file into memory to build the hash table. The larger partition file is loaded in batches as discussed above. 

Finally, in our groupby implementation, it would be infeasible to call pd.groupby on a whole dataset (would require loading into memory). Thus we hash partition on the groupby key and call groupby on each resulting partition file, gaurenteeing that the same groupby key is in the same file. 

Additional Potential Optimizations:
- the GroupBy operation showed performance slower (see outputs above) than HPJ joins. this could be improved by implementing COUNT(DISTINCT) with HyperLogLog approximations