<a href="https://colab.research.google.com/github/berkyalcinkaya/cs145-project2-systems/blob/main/cs145_project2_systems_template_fa2025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Collaborators

1.   Berk Yalcinkaya
2.   Nick Allen


# Setup

In [1]:
import pandas as pd
import os
import uuid
import argparse
import time
import psutil
import heapq
import pyarrow as pa
import pyarrow.parquet as pq
import random
import string
import numpy as np
from typing import List, Optional
import shutil
import glob
import gc
from IPython.display import display
import tempfile
from pathlib import Path

# Section 0: Generate Test Data

This section has already been implemented for you.

In [2]:
import gc


def generate_songs_chunk(start, size, string_length=100):
    data = {
        "song_id": range(start, start + size),
        "title": [f"Song_{i}" for i in range(start, start + size)],
    }
    base_strings = generate_base_strings(size, string_length)
    for i in range(1, 11):
        data[f"extra_col_{i}"] = np.roll(base_strings, shift=i)
    return pd.DataFrame(data)


def generate_users_chunk(start, size, string_length=100):
    data = {
        "user_id": range(start, start + size),
        "age": [18 + ((start + i) % 60) for i in range(size)],
    }
    base_strings = generate_base_strings(size, string_length)
    for i in range(1, 11):
        data[f"extra_col_{i}"] = np.roll(base_strings, shift=i)
    return pd.DataFrame(data)


def generate_listens_chunk(start, size, num_users, num_songs, string_length=16):
    data = {
        "listen_id": range(start, start + size),
        "user_id": np.random.randint(0, num_users, size=size),
        "song_id": np.random.randint(0, num_songs, size=size),
    }
    base_strings = generate_base_strings(size, string_length)
    for i in range(1, 11):
        data[f"extra_col_{i}"] = np.roll(base_strings, shift=i)
    return pd.DataFrame(data)


def generate_base_strings(num_records, string_length):
    chars = np.array(list("ab"))
    random_indices = np.random.randint(0, len(chars), size=(num_records, string_length))
    char_array = chars[random_indices]
    return np.array(list(map("".join, char_array)))


def _write_parquet_streamed(
    filename,
    total_rows,
    make_chunk_fn,
    chunk_size=250_000,
    compression="snappy",
):
    """
    Stream DataFrame chunks to a single Parquet file with one ParquetWriter.
    - schema_df: optional small DataFrame to lock schema; if None we'll infer from the first chunk.
    """
    written = 0

    first_chunk = make_chunk_fn(0, min(chunk_size, total_rows))
    first_table = pa.Table.from_pandas(first_chunk, preserve_index=False)
    writer = pq.ParquetWriter(filename, first_table.schema, compression=compression)
    writer.write_table(first_table)

    written += len(first_chunk)
    del first_chunk
    gc.collect()

    while written < total_rows:
        take = min(chunk_size, total_rows - written)
        chunk_df = make_chunk_fn(written, take)
        writer.write_table(pa.Table.from_pandas(chunk_df, preserve_index=False))
        written += take
        del chunk_df
        gc.collect()

    writer.close()


def generate_test_data(target_size="100MB"):
    """
    Generate datasets with proper foreign key relationships.

    Target COMPRESSED Parquet file sizes on disk:
    100MB total compressed:
        - Songs: 10K rows → ~5MB (5% of total)
        - Users: 50K rows → ~20MB (20% of total)
        - Listens: 1M rows → ~75MB (75% of total)
    1GB total compressed:
        - Songs: 100K rows → ~50MB (5% of total)
        - Users: 500K rows → ~200MB (20% of total)
        - Listens: 10M rows → ~750MB (75% of total)

    Each table needs:
        - Primary key column(s)
        - 10 additional string columns of k characters each
        - For Users: add 'age' column (random 18-80)

    CRITICAL: Listens table must have valid foreign keys!
    Every song_id must exist in Songs
    Every user_id must exist in Users
    """

    assert target_size in ["100MB", "1GB", "10GB"]
    if target_size == "100MB":
        num_songs = 10_000
        num_users = 50_000
        num_listens = 1_000_000

        songs_chunk = 10_000
        users_chunk = 50_000
        listens_chunk = 1_000_000
    if target_size == "1GB":
        num_songs = 100_000
        num_users = 500_000
        num_listens = 10_000_000

        songs_chunk = 10_000
        users_chunk = 50_000
        listens_chunk = 1_000_000
    else: 
        num_songs = 1_000_000
        num_users = 5_000_000
        num_listens = 100_000_000

        songs_chunk = 10_000
        users_chunk = 50_000
        listens_chunk = 1_000_000

    print("Writing Songs")
    _write_parquet_streamed(
        filename=f"songs_{target_size}.parquet",
        total_rows=num_songs,
        make_chunk_fn=lambda start, size: generate_songs_chunk(start, size),
        chunk_size=songs_chunk,
    )

    print("Writing Users")
    _write_parquet_streamed(
        filename=f"users_{target_size}.parquet",
        total_rows=num_users,
        make_chunk_fn=lambda start, size: generate_users_chunk(start, size),
        chunk_size=users_chunk,
    )

    print("Writing Listens")
    _write_parquet_streamed(
        filename=f"listens_{target_size}.parquet",
        total_rows=num_listens,
        make_chunk_fn=lambda start, size: generate_listens_chunk(
            start, size, num_users, num_songs
        ),
        chunk_size=listens_chunk,
    )

    print("Done!")

In [3]:
random.seed(0)

generate_test_data('1GB')

Writing Songs
Writing Users
Writing Listens
Done!


# Section 1: Parquet-based Columnar Storage

Implement Parquet-based storage for the tables
- For simplicity, store all data for a table in a single Parquet file and use a single DataFrame object as a buffer

In [4]:
# see ed: https://edstem.org/us/courses/87394/discussion/7251811 for advice on writing to a parquet without loading existing into RAM
# a ColumnarDbFile is actually a directory with an arbitrary number of parquet files inside
# Append writes a new file with the next postfix
# Retrieve reads all parquet files and concatenates them together, done natively by pandas
class ColumnarDbFile:
    def __init__(self, table_name, file_dir='data', file_pfx=''):
        self.file_pfx = file_pfx
        self.table_name = table_name
        self.file_dir = file_dir
        #os.makedirs(self.file_dir, exist_ok=True)
        self.base_file_name = f"{self.file_dir}/{self.file_pfx}_{self.table_name}"
        os.makedirs(self.base_file_name, exist_ok=True)

    def build_table(self, data):
        """Build and save table data to Parquet."""
        data.to_parquet(f"{self.base_file_name}/{self.table_name}-0.parquet")
        return

    def retrieve_data(self, columns=None):
        """Create pd.DataFrame by reading from Parquet"""
        return pd.read_parquet(self.base_file_name, columns=columns)

    def append_data(self, data):
        """Append new data to Parquet"""
        # Use glob to count the number of parquet files in the directory
        data.to_parquet(self.get_new_parquet_file())
        return

    def _get_num_parquets(self):
        return len(glob.glob(f"{self.base_file_name}/*.parquet"))

    def table_metadata(self):
        """Return total rows and total byte size of the table without loading data."""
        parquet_files = glob.glob(f"{self.base_file_name}/*.parquet")

        total_rows = 0
        total_bytes = 0

        for file in parquet_files:
            pf = pq.ParquetFile(file)
            meta = pf.metadata

            total_rows += meta.num_rows
            total_bytes += meta.serialized_size  # includes footer + metadata

        return {
            "num_files": len(parquet_files),
            "total_rows": total_rows,
            "total_compressed_bytes": total_bytes,
        }

    def table_disk_usage(self):
        parquet_files = glob.glob(f"{self.base_file_name}/*.parquet")

        total_bytes = sum(os.path.getsize(f) for f in parquet_files)

        return {
            "num_files": len(parquet_files),
            "total_bytes": total_bytes
        }

    @staticmethod
    def fits_in_12GB(bytes_needed: int) -> bool:
        TWELVE_GB = 12 * 1024**3
        return bytes_needed <= TWELVE_GB

    @staticmethod
    def can_process_parquet(bytes_on_disk: int, compression_factor: int = 5) -> bool:
        """
        Returns True if a Parquet dataset of `bytes_on_disk` can be processed
        within 12 GB of RAM, after accounting for decompression expansion.
        """
        estimated_ram = bytes_on_disk * compression_factor
        TWELVE_GB = 12 * 1024**3
        return estimated_ram <= TWELVE_GB


In [5]:
print("Building tables...")
if os.path.exists('data'):
    shutil.rmtree('data')
tables = {
    'Songs': ColumnarDbFile("Songs", file_dir='data'),
    'Users': ColumnarDbFile("Users", file_dir='data'),
    'Listens': ColumnarDbFile("Listens", file_dir='data')
}

size = "1GB"
songs_data = pd.read_parquet(f'songs_{size}.parquet')
users_data = pd.read_parquet(f'users_{size}.parquet')
listens_data = pd.read_parquet(f'listens_{size}.parquet')

tables['Songs'].build_table(songs_data)
tables['Users'].build_table(users_data)
tables['Listens'].build_table(listens_data)
print("Tables built successfully.")

Building tables...
Tables built successfully.


In [6]:
# retrieve data
tables['Songs'].retrieve_data(columns = ['song_id', 'title'])

Unnamed: 0,song_id,title
0,0,Song_0
1,1,Song_1
2,2,Song_2
3,3,Song_3
4,4,Song_4
...,...,...
99995,99995,Song_99995
99996,99996,Song_99996
99997,99997,Song_99997
99998,99998,Song_99998


In [7]:
tables['Listens'].retrieve_data(columns = ['listen_id', 'user_id', 'song_id'])

Unnamed: 0,listen_id,user_id,song_id
0,0,8822,42708
1,1,27918,88764
2,2,100379,46603
3,3,496553,36186
4,4,448685,94938
...,...,...,...
9999995,9999995,118644,8260
9999996,9999996,296462,23279
9999997,9999997,298289,77129
9999998,9999998,209754,88446


Analyze and report on:
- Space efficiency compared to row storage
  - e.g. Compare file sizes on disk: How much disk space does Parquet use vs. a row storage format like CSV?
- Compression ratios achieved with Parquet
  - e.g. Compare Parquet’s uncompressed encoded size (reported in its metadata) to its compressed on-disk size to compute compression ratios.
  - You could also report the memory expansion factor: how much larger the dataset becomes when loaded into a `pd.DataFrame` compared to the compressed file size.
- Read/write performance characteristics
  - e.g. Read performance: How long does it take to read all columns from Parquet vs. CSV?
  - e.g. Columnar advantage: How long does it take to read selective columns from Parquet vs. reading all columns?
  - e.g. Write performance: How long does it take to write data to Parquet vs. CSV?

In [8]:
def analyze(size="100MB"):
    """Analyze storage efficiency, compression, and read/write performance."""

    table_files = {
        "Songs": f"songs_{size}.parquet",
        "Users": f"users_{size}.parquet",
        "Listens": f"listens_{size}.parquet",
    }

    report_rows = []

    for table_name, parquet_file in table_files.items():
        parquet_path = Path(parquet_file)

        df = pd.read_parquet(parquet_path)
        mem_usage_bytes = df.memory_usage(deep=True).sum() # memory usage of the dataframe
        parquet_size_bytes = parquet_path.stat().st_size # size of the parquet file on disk

        parquet_file_obj = pq.ParquetFile(parquet_path)
        metadata = parquet_file_obj.metadata
        uncompressed_bytes = 0

        # iterate over all row groups and columns to get the total uncompressed size of the parquet file
        for rg_idx in range(metadata.num_row_groups):
            row_group = metadata.row_group(rg_idx)
            for col_idx in range(row_group.num_columns):
                column_meta = row_group.column(col_idx)
                if column_meta.total_uncompressed_size is not None:
                    uncompressed_bytes += column_meta.total_uncompressed_size

        # calculate compression ratio and memory expansion
        compression_ratio = (
            uncompressed_bytes / parquet_size_bytes
        )
        memory_expansion = (
            mem_usage_bytes / parquet_size_bytes
        )

        # test reading speed of parquet file vs csv, for all columns and selective columns
        # pick 1 less than the total number of columns to test reading selective columns
        subset_columns = list(df.columns)[0:len(df.columns)-1]

        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir_path = Path(tmpdir)

            csv_path = tmpdir_path / f"{parquet_path.stem}.csv"
            start = time.perf_counter()
            df.to_csv(csv_path, index=False)
            write_csv_time = time.perf_counter() - start
            csv_size_bytes = csv_path.stat().st_size

            parquet_tmp_path = tmpdir_path / f"{parquet_path.stem}.parquet"
            start = time.perf_counter()
            df.to_parquet(parquet_tmp_path, index=False)
            write_parquet_time = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_parquet(parquet_path)
            read_parquet_all = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_csv(csv_path)
            read_csv_all = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_parquet(parquet_path, columns=subset_columns)
            read_parquet_subset = time.perf_counter() - start

            start = time.perf_counter()
            _ = pd.read_csv(csv_path, usecols=subset_columns)
            read_csv_subset = time.perf_counter() - start

        size_saving_pct = (
            100.0 * (1 - parquet_size_bytes / csv_size_bytes)
        )

        # append the results to the report
        report_rows.append(
            {
                "table": table_name,
                "parquet_size_mb": parquet_size_bytes / (1024 ** 2),
                "csv_size_mb": csv_size_bytes / (1024 ** 2),
                "size_saving_pct": size_saving_pct,
                "compression_ratio": compression_ratio,
                "memory_expansion": memory_expansion,
                "read_parquet_all_s": read_parquet_all,
                "read_csv_all_s": read_csv_all,
                "read_parquet_subset_s": read_parquet_subset,
                "read_csv_subset_s": read_csv_subset,
                "write_parquet_s": write_parquet_time,
                "write_csv_s": write_csv_time,
            }
        )

        del df
        gc.collect()

    summary = pd.DataFrame(report_rows)
    print("Analysis Summary for Tables of Size " + size + " (sizes in MB, times in seconds):")
    return summary

In [9]:
display(analyze(size="1GB"))

Analysis Summary for Tables of Size 1GB (sizes in MB, times in seconds):


Unnamed: 0,table,parquet_size_mb,csv_size_mb,size_saving_pct,compression_ratio,memory_expansion,read_parquet_all_s,read_csv_all_s,read_parquet_subset_s,read_csv_subset_s,write_parquet_s,write_csv_s
0,Songs,42.678607,97.92129,56.415396,2.420447,3.478947,0.116536,1.014622,0.118276,0.918474,0.286086,1.379869
1,Users,203.417436,486.268065,58.167634,2.472125,3.530262,0.962916,5.373042,0.830418,4.827727,1.578626,6.692811
2,Listens,834.557252,1817.278606,54.076538,2.40917,7.702006,5.041973,30.441947,6.367491,26.831997,9.867773,45.131197


In [None]:
# display(analyze(size="1GB"))

# Section 2: Parse SQL Query

In this section, you should implement logic to parse the following SQL query:
```sql
    SELECT s.song_id, AVG(u.age) AS avg_age,
       COUNT(DISTINCT l.user_id) AS count_distinct_users,
    FROM Songs s
    JOIN Listens l ON s.song_id = l.song_id
    JOIN Users u ON l.user_id = u.user_id
    GROUP BY s.song_id, s.title
    ORDER BY COUNT(DISTINCT l.user_id) DESC, s.song_id;
```

You should manually extract the components from the provided query (i.e. you don't need to implement a general SQL parser, just handle this specific query).

In [10]:
query = """SELECT s.song_id, AVG(u.age) AS avg_age,
COUNT(DISTINCT l.user_id)
FROM Songs s
JOIN Listens l ON s.song_id = l.song_id
JOIN Users u ON l.user_id = u.user_id
GROUP BY s.song_id, s.title
ORDER BY COUNT(DISTINCT l.user_id) DESC, s.song_id;
"""

In [11]:
import re
import re

def parse_tables(query):

    # pattern matches: "from songs s" or "join listens l"
    pattern = r"(from|join)\s+([a-z_]+)\s+([a-z])"

    matches = re.findall(pattern, query)

    tables = {}
    for _, table_name, alias in matches:
        tables[alias] = table_name

    return tables

def parse_joins(query):

    # 1) Get the base table from the FROM clause
    base_match = re.search(r"from\s+([a-z_]+)\s+([a-z])", query)
    if not base_match:
        raise ValueError("Could not find FROM clause")

    base_table_name = base_match.group(1)
    base_alias = base_match.group(2)
    base_table = (base_alias, base_table_name)

    # 2) Get each JOIN clause, in order
    # pattern matches:
    #   join listens l on s.song_id = l.song_id
    join_pattern = (
        r"join\s+([a-z_]+)\s+([a-z])\s+on\s+"
        r"([a-z])\.([a-z_]+)\s*=\s*([a-z])\.([a-z_]+)"
    )

    joins = []
    for m in re.finditer(join_pattern, query):
        joined_table_name = m.group(1)
        joined_alias = m.group(2)
        left_alias = m.group(3)
        left_col = m.group(4)
        right_alias = m.group(5)
        right_col = m.group(6)

        joins.append(
            {
                "joined_table_alias": joined_alias,
                "joined_table_name": joined_table_name,
                "left_alias": left_alias,
                "left_column": left_col,
                "right_alias": right_alias,
                "right_column": right_col,
            }
        )

    return {"base_table" : base_table, "Joins" : joins}


def parse_group_by(query):
    """
    Return GROUP BY columns as a list of (alias, column) tuples.
    Example: [('s', 'song_id'), ('s', 'title')]
    """
    q = query.lower()

    # Capture whatever is between GROUP BY and ORDER BY/semicolon/end
    match = re.search(r"group\s+by\s+(.+?)(order\s+by|;|$)", q, re.DOTALL)
    if not match:
        return []

    groupby_text = match.group(1).strip()

    columns = []
    for col in groupby_text.split(","):
        col = col.strip()

        # Expect pattern: alias.column
        alias, column = col.split(".")
        columns.append((alias, column))

    return columns

def parse_select_and_aggregations(query):
    """
    Build:
      aggregations: {agg_key: {...}}
      select: list of items that may refer to agg_key
    """
    q = query.lower()

    m = re.search(r"select\s+(.+?)\s+from", q, re.DOTALL)
    if not m:
        return [], {}

    select_text = m.group(1).strip()
    raw_items = [item.strip() for item in select_text.split(",") if item.strip()]

    select_list = []
    aggregations = {}
    agg_id = 1

    for idx, item in enumerate(raw_items, start=1):
        # AVG(...)
        if item.startswith("avg("):
            m_avg = re.match(
                r"avg\(\s*([a-z])\.([a-z_]+)\s*\)(\s+as\s+([a-z_]+))?",
                item
            )
            if not m_avg:
                raise ValueError(f"Could not parse AVG aggregation: {item}")
            alias_letter = m_avg.group(1)
            col_name = m_avg.group(2)
            out_alias = m_avg.group(4) if m_avg.group(4) else None

            aggregations[agg_id] = {
                "func": "avg",
                "source": (alias_letter, col_name),
                "distinct": False,
                "output_name": out_alias,
            }

            select_list.append(
                {
                    "kind": "aggregation",
                    "agg_key": agg_id,
                    "alias": out_alias,

                }
            )
            agg_id += 1

        # COUNT(DISTINCT ...)
        elif item.startswith("count("):
            m_cnt = re.match(
                r"count\(\s*distinct\s+([a-z])\.([a-z_]+)\s*\)(\s+as\s+([a-z_]+))?",
                item
            )
            if not m_cnt:
                raise ValueError(f"Could not parse COUNT aggregation: {item}")
            alias_letter = m_cnt.group(1)
            col_name = m_cnt.group(2)
            out_alias = m_cnt.group(4) if m_cnt.group(4) else None

            aggregations[agg_id] = {
                "func": "count",
                "source": (alias_letter, col_name),
                "distinct": True,
                "output_name": out_alias,
            }

            select_list.append(
                {
                    "kind": "aggregation",
                    "agg_key": agg_id,
                    "alias": out_alias,
                }
            )
            agg_id += 1

        # Plain column: alias.column
        else:
            alias_letter, col_name = item.split(".")
            select_list.append(
                {
                    "kind": "column",
                    "source": (alias_letter, col_name),
                    "alias": None,
                }
            )

    return select_list, aggregations


def parse_order_by(query, aggregations):
    """
    Build order_by list where entries can refer to aggregations via agg_key.
    """
    q = query.lower()

    m = re.search(r"order\s+by\s+(.+?)(;|$)", q, re.DOTALL)
    if not m:
        return []

    order_text = m.group(1).strip()
    raw_items = [item.strip() for item in order_text.split(",") if item.strip()]

    order_by = []

    for item in raw_items:
        direction = "asc"
        expr = item

        if expr.endswith(" desc"):
            direction = "desc"
            expr = expr[:-5].strip()
        elif expr.endswith(" asc"):
            direction = "asc"
            expr = expr[:-4].strip()

        # COUNT(DISTINCT ...) → match an aggregation
        if expr.startswith("count("):
            m_cnt = re.match(
                r"count\(\s*distinct\s+([a-z])\.([a-z_]+)\s*\)",
                expr
            )
            if not m_cnt:
                raise ValueError(f"Could not parse ORDER BY aggregation: {expr}")
            src = (m_cnt.group(1), m_cnt.group(2))

            agg_key = None
            for k, agg in aggregations.items():
                if (
                    agg["func"] == "count"
                    and agg["distinct"]
                    and agg["source"] == src
                ):
                    agg_key = k
                    break

            if agg_key is None:
                raise ValueError(f"No matching aggregation found for ORDER BY expr: {expr}")

            order_by.append(
                {
                    "kind": "aggregation",
                    "agg_key": agg_key,
                    "direction": direction,
                }
            )

        else:
            # assume plain column: alias.column
            alias_letter, col_name = expr.split(".")
            order_by.append(
                {
                    "kind": "column",
                    "source": (alias_letter, col_name),
                    "direction": direction,
                }
            )

    return order_by

def parse_sql(query):
    """
    YOUR TASK: Extract tables, joins, and aggregations
    """
    # Parse SQL string to identify:
    # - Tables involved
    # - Join conditions
    # - GROUP BY columns
    # - Aggregation functions
    # Your implementation here
    query = query.lower()
    output = {}

    output["tables"] = parse_tables(query)
    output["joins"] = parse_joins(query)
    output["GroupBy"] = parse_group_by(query)
    output["select"], output["aggregations"] = parse_select_and_aggregations(query)
    output["orderBy"] = parse_order_by(query, output["aggregations"])

    return output




In [23]:
parsed_sql = parse_sql(query)
for key, value in output.items():
    print(f"{key}: {value}")

tables: {'s': 'songs', 'l': 'listens', 'u': 'users'}
joins: {'base_table': ('s', 'songs'), 'Joins': [{'joined_table_alias': 'l', 'joined_table_name': 'listens', 'left_alias': 's', 'left_column': 'song_id', 'right_alias': 'l', 'right_column': 'song_id'}, {'joined_table_alias': 'u', 'joined_table_name': 'users', 'left_alias': 'l', 'left_column': 'user_id', 'right_alias': 'u', 'right_column': 'user_id'}]}
GroupBy: [('s', 'song_id'), ('s', 'title')]
select: [{'kind': 'column', 'source': ('s', 'song_id'), 'alias': None}, {'kind': 'aggregation', 'agg_key': 1, 'alias': 'avg_age'}, {'kind': 'aggregation', 'agg_key': 2, 'alias': None}]
aggregations: {1: {'func': 'avg', 'source': ('u', 'age'), 'distinct': False, 'output_name': 'avg_age'}, 2: {'func': 'count', 'source': ('l', 'user_id'), 'distinct': True, 'output_name': None}}
orderBy: [{'kind': 'aggregation', 'agg_key': 2, 'direction': 'desc'}, {'kind': 'column', 'source': ('s', 'song_id'), 'direction': 'asc'}]


# Section 3: Implement Join Algorithms

In this section, you will implement the execution operators (*how* to join) and aggregation after joins.

**Reminder:** If you use temporary files or folders, you should clean them up either as part of your join logic, or after each run. Otherwise you might run into correctness issues!

In [13]:
import hashlib

def HASHVALUE(value, B):
    if isinstance(value, int):
        return hash(value) % B
    sha256 = hashlib.sha256()
    sha256.update(str(value).encode("utf-8"))
    return int(sha256.hexdigest(), 16) % B

Implement `HashPartitionJoin`:
1. Hash partition both tables
2. Build hash table from smaller partition
3. Probe with larger partition
4. Return joined results

In [None]:
class HashPartitionJoin:
    def __init__(self, num_partitions=4, parquet_batch_size=1000):
        self.num_partitions = num_partitions
        self.parquet_batch_size = parquet_batch_size

    def join(self, table1: ColumnarDbFile, table2: ColumnarDbFile, join_key1, join_key2,
             temp_dir='temp', columns_table1=None, columns_table2=None):
        """
        Perform a hash partition join between two ColumnarDbFile instances.

        Parameters:
        - table1: Left table (ColumnarDbFile)
        - table2: Right table (ColumnarDbFile)
        - join_key1: Join key from table1
        - join_key2: Join key from table2
        - temp_dir: Directory to store temporary files
        - columns_table1: List of columns to select from table1
        - columns_table2: List of columns to select from table2

        Returns:
        - join_result_table: ColumnarDbFile instance containing the join results
        """
        os.makedirs(temp_dir, exist_ok=True)
        # Partition both tables
        partitions1 = self._hash_partition(table1, join_key1, temp_dir, 'left', columns_table1)
        partitions2 = self._hash_partition(table2, join_key2, temp_dir, 'right', columns_table2)

    def _hash_partition(self, table: ColumnarDbFile, join_key, output_dir, side, columns=None):
        # Find all parquet files in the directory
        parquet_files = glob.glob(f"{table.base_file_name}/*.parquet")
        if not parquet_files:
            raise ValueError(f"No parquet files found in {table.base_file_name}")
        
        writers: dict[int, pq.ParquetWriter] = {}
        
        # Process each parquet file in the directory
        for parquet_file_path in parquet_files:
            parquet_file = pq.ParquetFile(parquet_file_path)
            # Ensure join_key is included in columns for partitioning
            read_columns = columns
            if columns and join_key not in columns:
                read_columns = list(columns) + [join_key]
            
            for batch in parquet_file.iter_batches(batch_size=self.parquet_batch_size, columns=read_columns):
                batch_df = batch.to_pandas()
                
                # Add partition column based on join_key
                batch_df["_part"] = batch_df[join_key].apply(lambda x: HASHVALUE(x, self.num_partitions))
                
                # Filter to requested columns if specified (but keep _part for grouping)
                if columns:
                    # Select requested columns plus _part
                    batch_df = batch_df[columns + ["_part"]]

                # Group rows by partition id and write them out
                for part_id, part_df in batch_df.groupby("_part"):
                    # Drop helper column before writing
                    part_df = part_df.drop(columns=["_part"])

                    # Convert to Arrow Table
                    part_table = pa.Table.from_pandas(part_df, preserve_index=False)

                    # Lazily create writer for this partition
                    writer = writers.get(part_id)
                    if writer is None:
                        part_path = self._make_partition_path(output_dir, side, part_id)
                        writer = pq.ParquetWriter(part_path, part_table.schema)
                        writers[part_id] = writer

                    # Append this batch's rows for this partition as a new row group
                    writer.write_table(part_table)

        # Close all writers
        for w in writers.values():
            w.close()


In [None]:
# Optional: Verify your implementation against pd.merge
def test_hash_partition_join_comprehensive():
    """
    Comprehensive test that validates both structure AND actual data values.
    This ensures the HPJ implementation is truly correct.
    """
    print("="*70)
    print("Comprehensive Hash Partition Join Test")
    print("="*70)
    
    if not os.path.exists('data'):
        print("ERROR: 'data' directory does not exist. Please run the table building cell first.")
        return False
    
    tables = {
        'Songs': ColumnarDbFile("Songs", file_dir='data'),
        'Users': ColumnarDbFile("Users", file_dir='data'),
        'Listens': ColumnarDbFile("Listens", file_dir='data')
    }
    
    all_tests_passed = True
    
    # Test Case 1: Songs JOIN Listens - FULL DATA VALIDATION
    print("\n" + "="*70)
    print("Test Case 1: Songs JOIN Listens - Full Data Validation")
    print("="*70)
    
    songs_table = tables['Songs']
    listens_table = tables['Listens']
    
    songs_cols = ['song_id', 'title']
    listens_cols = ['listen_id', 'song_id', 'user_id']
    
    # Perform joins
    hpj1 = HashPartitionJoin(num_partitions=4, parquet_batch_size=1000)
    result_table1 = hpj1.join(
        songs_table, listens_table,
        join_key1='song_id', join_key2='song_id',
        temp_dir='temp_test_songs_listens_comp',
        columns_table1=songs_cols,
        columns_table2=listens_cols
    )
    
    hpj_result1 = result_table1.retrieve_data()
    
    # Get pd.merge result
    songs_df = songs_table.retrieve_data(columns=songs_cols)
    listens_df = listens_table.retrieve_data(columns=listens_cols)
    pd_result1 = pd.merge(songs_df, listens_df, on='song_id', how='inner')
    
    print(f"\nHPJ result shape: {hpj_result1.shape}")
    print(f"pd.merge result shape: {pd_result1.shape}")
    
    test1_passed = True
    
    # 1. Row count check
    if len(hpj_result1) != len(pd_result1):
        print(f"✗ Row count mismatch! HPJ: {len(hpj_result1)}, pd.merge: {len(pd_result1)}")
        test1_passed = False
        all_tests_passed = False
    else:
        print("✓ Row counts match!")
    
    # 2. Column check
    hpj_cols = set(hpj_result1.columns)
    pd_cols = set(pd_result1.columns)
    if hpj_cols != pd_cols:
        print(f"✗ Column mismatch! HPJ: {hpj_cols}, pd.merge: {pd_cols}")
        test1_passed = False
        all_tests_passed = False
    else:
        print("✓ Columns match!")
    
    if test1_passed:
        # 3. Sort both results for comparison
        sort_cols = ['song_id', 'listen_id'] if 'listen_id' in hpj_result1.columns else ['song_id']
        hpj_sorted = hpj_result1.sort_values(sort_cols).reset_index(drop=True)
        pd_sorted = pd_result1.sort_values(sort_cols).reset_index(drop=True)
        
        # 4. Check unique keys
        hpj_song_ids = set(hpj_result1['song_id'].unique())
        pd_song_ids = set(pd_result1['song_id'].unique())
        if hpj_song_ids != pd_song_ids:
            print(f"✗ Unique song_ids differ!")
            test1_passed = False
            all_tests_passed = False
        else:
            print("✓ Unique song_ids match!")
        
        # 5. FULL DATA VALUE COMPARISON - This is the critical check!
        print("\nPerforming full data value comparison...")
        data_matches = True
        
        # Compare each column
        for col in sorted(hpj_cols):
            hpj_col_data = hpj_sorted[col].values
            pd_col_data = pd_sorted[col].values
            
            # Use np.array_equal for exact comparison
            if not np.array_equal(hpj_col_data, pd_col_data):
                print(f"✗ Column '{col}' data mismatch!")
                
                # Find first mismatch
                mismatch_idx = np.where(hpj_col_data != pd_col_data)[0]
                if len(mismatch_idx) > 0:
                    idx = mismatch_idx[0]
                    print(f"  First mismatch at row {idx}:")
                    print(f"    HPJ: {hpj_col_data[idx]}")
                    print(f"    pd.merge: {pd_col_data[idx]}")
                    print(f"  Total mismatches: {len(mismatch_idx)}")
                
                data_matches = False
                break
        
        if data_matches:
            print("✓ All data values match exactly!")
            print(f"✓ Verified {len(hpj_sorted)} rows × {len(hpj_cols)} columns = {len(hpj_sorted) * len(hpj_cols)} values")
        else:
            print("✗ Data values do NOT match!")
            test1_passed = False
            all_tests_passed = False
        
        # 6. Check for duplicate rows (should be same in both)
        hpj_duplicates = hpj_sorted.duplicated().sum()
        pd_duplicates = pd_sorted.duplicated().sum()
        if hpj_duplicates != pd_duplicates:
            print(f"⚠ Duplicate row counts differ (HPJ: {hpj_duplicates}, pd.merge: {pd_duplicates})")
        else:
            print(f"✓ Duplicate row counts match ({hpj_duplicates} duplicates)")
    
    if test1_passed:
        print("\n✓ Test Case 1 PASSED - Full validation successful!")
    else:
        print("\n✗ Test Case 1 FAILED!")
    
    # Summary
    print("\n" + "="*70)
    print("COMPREHENSIVE TEST SUMMARY")
    print("="*70)
    if all_tests_passed:
        print("✓ ALL TESTS PASSED: Hash Partition Join is CORRECT!")
        print("  - Row counts match")
        print("  - Column structure matches")
        print("  - Unique keys match")
        print("  - ALL DATA VALUES match exactly")
    else:
        print("✗ TESTS FAILED: Implementation has issues")
    print("="*70)
    
    return all_tests_passed

test_hash_partition_join_comprehensive()


Implement `SortMergeJoin`:
1. Sort both tables by join key
2. Merge sorted sequences
3. Handle duplicates

In [14]:
import pyarrow as pa
import pyarrow.parquet as pq

class ParquetGroupChunkIter:
    """
    Streams a Parquet file that is globally sorted by `join_key`,
    yielding chunks of consecutive rows with the same key.
    Each call to next_chunk() returns (table_chunk, key_value, is_last_for_key).
    Memory bounded by IN_BATCH_ROWS and GROUP_CHUNK_ROWS.
    """
    def __init__(self, file_path: str, columns, join_key: str,
                 in_batch_rows: int = 64_000, group_chunk_rows: int = 64_000):
        self._pf = pq.ParquetFile(file_path)
        self._cols = columns
        self._join_key = join_key
        self._key_idx = (self._cols.index(join_key) if self._cols is not None
                 else self._pf.schema_arrow.get_field_index(join_key))
        if self._key_idx == -1:
            raise ValueError(f"join key {join_key} not in schema")
        self._IN_BATCH_ROWS = in_batch_rows
        self._GROUP_CHUNK_ROWS = group_chunk_rows

        self._it = self._pf.iter_batches(columns=self._cols, batch_size=self._IN_BATCH_ROWS)
        self._batch = None
        self._i = 0
        self._cur_key = None
        self._eof = False

    def _next_batch(self):
        try:
            self._batch = next(self._it)
            self._i = 0
        except StopIteration:
            self._batch = None

    def _skip_nulls(self):
        while self._batch is not None:
            key_arr = self._batch.column(self._key_idx)
            n = self._batch.num_rows
            while self._i < n and key_arr[self._i].as_py() is None:
                self._i += 1
            if self._i < n:
                return
            self._next_batch()

    def next_chunk(self):
        if self._eof:
            return None, None, True

        if self._batch is None:
            self._next_batch()
            if self._batch is None:
                self._eof = True
                return None, None, True

        self._skip_nulls()
        if self._batch is None:
            self._eof = True
            return None, None, True

        if self._cur_key is None:
            self._cur_key = self._batch.column(self._key_idx)[self._i].as_py()

        parts = []
        collected = 0
        more_for_key = False

        while collected < self._GROUP_CHUNK_ROWS:
            if self._batch is None:
                break

            key_arr = self._batch.column(self._key_idx)
            n = self._batch.num_rows
            j = self._i

            # take rows while the key matches and we stay under the chunk cap
            while j < n and key_arr[j].as_py() == self._cur_key and \
                  (collected + (j - self._i)) < self._GROUP_CHUNK_ROWS:
                j += 1

            take = j - self._i
            if take > 0:
                parts.append(self._batch.slice(self._i, take))
                collected += take
                self._i = j

            if collected >= self._GROUP_CHUNK_ROWS:
                # capped chunk; if more rows with same key remain, mark continuation
                if self._i < n and key_arr[self._i].as_py() == self._cur_key:
                    more_for_key = True
                else:
                    # peek next batch start
                    save_batch, save_i = self._batch, self._i
                    self._next_batch()
                    if self._batch is not None:
                        k0 = self._batch.column(self._key_idx)[0].as_py()
                        more_for_key = (k0 == self._cur_key)
                    # restore for the caller to continue
                    self._batch, self._i = save_batch, save_i
                break

            if self._i >= n:
                # move to next batch and check if key continues
                self._next_batch()
                self._skip_nulls()
                if self._batch is None:
                    break
                if self._batch.column(self._key_idx)[self._i].as_py() != self._cur_key:
                    break
            else:
                # key changed within this batch
                break

        tbl = pa.Table.from_batches(parts) if parts else None

        # if no more rows remain for this key after this chunk, clear cur_key
        if not more_for_key:
            if self._batch is not None:
                key_arr = self._batch.column(self._key_idx)
                n = self._batch.num_rows
                while self._i < n and key_arr[self._i].as_py() == self._cur_key:
                    self._i += 1
                if self._i >= n:
                    self._next_batch()
            self._cur_key = None

        key_for_chunk = None
        if tbl is not None:
            key_for_chunk = tbl.column(self._key_idx).chunk(0)[0].as_py() \
                if tbl.num_rows > 0 else None

        return tbl, key_for_chunk, not more_for_key

In [15]:

BWAY_MERGE_FACTOR = 10

class SortMergeJoin:
    def __init__(
        self, bway_merge_factor: int = BWAY_MERGE_FACTOR, num_pages_per_split=2
    ):
        self.bway_merge_factor = bway_merge_factor
        self.num_pages_per_split = num_pages_per_split

    def _streaming_inner_join(self,
                            left_sorted: ColumnarDbFile,
                            right_sorted: ColumnarDbFile,
                            join_key1: str,
                            join_key2: str,
                            temp_dir: str,
                            columns_table1: Optional[List[str]],
                            columns_table2: Optional[List[str]],
                            in_batch_rows: int = 64_000,
                            group_chunk_rows: int = 64_000,
                            max_product_rows: int = 200_000) -> ColumnarDbFile:

        def ensure_key(cols, key):
            if cols is None:
                return None
            return cols if key in cols else cols + [key]

        # Compute input file paths
        left_path = os.path.join(left_sorted.base_file_name, f"{left_sorted.table_name}-0.parquet")
        right_path = os.path.join(right_sorted.base_file_name, f"{right_sorted.table_name}-0.parquet")

        # Load schemas to default to all columns if None
        l_pf = pq.ParquetFile(left_path)
        r_pf = pq.ParquetFile(right_path)
        l_schema = l_pf.schema_arrow
        r_schema = r_pf.schema_arrow

        if columns_table1 is None:
            columns_table1 = l_schema.names
        if columns_table2 is None:
            columns_table2 = r_schema.names
        columns_table1 = ensure_key(columns_table1, join_key1)
        columns_table2 = ensure_key(columns_table2, join_key2)

        # Output: keep left key; drop duplicate key from right
        left_out_cols = columns_table1
        right_out_cols = [c for c in columns_table2 if c != join_key2]

        # Prepare iterators
        L = ParquetGroupChunkIter(left_path, columns_table1, join_key1,
                                    in_batch_rows, group_chunk_rows)
        R = ParquetGroupChunkIter(right_path, columns_table2, join_key2,
                                    in_batch_rows, group_chunk_rows)

        # Output writer
        result_table = ColumnarDbFile("join_result", file_dir=temp_dir)
        out_path = os.path.join(result_table.base_file_name, f"{result_table.table_name}-0.parquet")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        writer = None

        def write_join_product(l_tbl: pa.Table, r_tbl: pa.Table):
            nonlocal writer
            if l_tbl is None or r_tbl is None:
                return
            l_df = l_tbl.select(left_out_cols).to_pandas()
            r_df = r_tbl.select(right_out_cols).to_pandas()
            m, n = len(l_df), len(r_df)
            if m == 0 or n == 0:
                return
            block_n = max(1, min(n, max_product_rows // max(1, m)))
            start = 0
            while start < n:
                end = min(n, start + block_n)
                r_block = r_df.iloc[start:end]
                out_df = l_df.assign(_cj=1).merge(r_block.assign(_cj=1), on="_cj").drop(columns=["_cj"])
                out_tbl = pa.Table.from_pandas(out_df, preserve_index=False)
                if writer is None:
                    writer = pq.ParquetWriter(out_path, schema=out_tbl.schema)
                writer.write_table(out_tbl)
                start = end

        # Drive the SMJ
        l_tbl, l_key, l_last = L.next_chunk()
        r_tbl, r_key, r_last = R.next_chunk()

        while l_tbl is not None and r_tbl is not None:
            if l_key is None:
                l_tbl, l_key, l_last = L.next_chunk()
                continue
            if r_key is None:
                r_tbl, r_key, r_last = R.next_chunk()
                continue

            if l_key < r_key:
                l_tbl, l_key, l_last = L.next_chunk()
            elif l_key > r_key:
                r_tbl, r_key, r_last = R.next_chunk()
                
            else:
                # l_key == r_key: join full groups. Cache right group (spill if large),
                # then stream left group chunk-by-chunk against the cached/spilled right.
                import uuid

                MAX_CACHED_GROUP_ROWS = 500_000  # tune as you like

                r_chunks: List[pa.Table] = []
                r_rows = 0
                r_spill_path = None
                r_writer = None

                def _open_r_spill(schema: pa.Schema):
                    nonlocal r_writer, r_spill_path
                    if r_writer is None:
                        r_spill_path = os.path.join(
                            temp_dir, f"_smj_rspill_{uuid.uuid4().hex}.parquet"
                        )
                        r_writer = pq.ParquetWriter(r_spill_path, schema=schema)

                def _cache_r_chunk(tbl: pa.Table):
                    nonlocal r_rows, r_writer
                    if tbl is None or tbl.num_rows == 0:
                        return
                    if r_writer is None and (r_rows + tbl.num_rows) <= MAX_CACHED_GROUP_ROWS:
                        r_chunks.append(tbl)
                    else:
                        if r_writer is None:
                            _open_r_spill(tbl.schema)
                            for t in r_chunks:
                                r_writer.write_table(t)
                            r_chunks.clear()
                        r_writer.write_table(tbl)
                    r_rows += tbl.num_rows

                # 1) Collect entire right group for this key
                _cache_r_chunk(r_tbl)
                while not r_last:
                    r_tbl, r_key2, r_last = R.next_chunk()
                    if r_tbl is None or r_key2 != l_key:
                        break
                    _cache_r_chunk(r_tbl)
                if r_writer is not None:
                    r_writer.close()

                def right_iter():
                    if r_spill_path is not None:
                        pf = pq.ParquetFile(r_spill_path)
                        for b in pf.iter_batches(batch_size=group_chunk_rows):
                            yield pa.Table.from_batches([b])
                    else:
                        for t in r_chunks:
                            yield t

                # 2) Stream left group; for each left chunk, replay right group
                while True:
                    for r_part in right_iter():
                        write_join_product(l_tbl, r_part)
                    if l_last:
                        break
                    l_tbl, l_key2, l_last = L.next_chunk()
                    if l_tbl is None or l_key2 != l_key:
                        break

                # 3) Advance right iterator to next key (we ended exactly at its boundary)
                r_tbl, r_key, r_last = R.next_chunk()
                # proceed to next loop iteration
                continue

        if writer is not None:
            writer.close()
        return result_table

    def _flush_run(
        self,
        dfs: List[pd.DataFrame],
        join_key: str,
        output_dir: str,
        side: str,
        run_idx: int,
    ) -> str:

        df_run = pd.concat(dfs, ignore_index=True)
        df_run_sorted = df_run.sort_values(by=join_key)

        run_file = os.path.join(output_dir, f"{side}_run_{run_idx}.parquet")
        df_run_sorted.to_parquet(run_file)

        dfs.clear()
        del df_run, df_run_sorted
        gc.collect()

        return run_file


    def _external_sort(
        self,
        table: ColumnarDbFile,
        join_key: str,
        output_dir: str,
        side: str,
        columns: Optional[List[str]] = None,
    ) -> ColumnarDbFile:
        """
        Perform an external sort on a table based on the join key and return a sorted ColumnarDbFile.
        Use _bway_merge to merge sorted files
        """

        # Get table size (on disk)
        disk_usage = table.table_disk_usage()
        total_bytes = disk_usage["total_bytes"]

        # Check if we can safely process in 12 GB RAM
        if False: #table.can_process_parquet(total_bytes):

            # read data in and sort all in RAM
            df = table.retrieve_data(columns=columns)
            df_sorted = df.sort_values(by=join_key).reset_index(drop=True)

            # create paraquet in output dir for the table
            sorted_name = f"{side}_{table.table_name}_sorted"
            sorted_table = ColumnarDbFile(sorted_name, file_dir=output_dir)
            sorted_table.build_table(df_sorted)

            # clean unnecessary overhead and return table
            del df, df_sorted
            gc.collect()
            return sorted_table

        else:
            print("sorting table ", table.table_name, "with ", total_bytes, "bytes using external sort")
            print("GBs : ", total_bytes / (1024 * 1024 * 1024))
            # Get list of parquet files in the table directory
            parquet_files = glob.glob(f"{table.base_file_name}/*.parquet")

            runs_path: List[str] = []
            run_idx = 0
            current_dfs: List[pd.DataFrame] = []
            current_row_groups = 0

            # loop through all the parquet files
            print(f"Sorting {len(parquet_files)} files")
            for file in parquet_files:
                pf = pq.ParquetFile(file)

                # safe bounded unit of work for sorting
                num_row_groups = pf.metadata.num_row_groups

                for rg in range(num_row_groups):
                    print("reading row group ", rg)

                    # read a row group as a chunk
                    batch = pf.read_row_group(rg, columns=columns)
                    df_chunk = batch.to_pandas()
                    current_dfs.append(df_chunk)
                    current_row_groups += 1

                    # treating a row group as a page
                    # change to 2 to practice spliting runs
                    if current_row_groups > self.num_pages_per_split:

                        print("flushing run ", run_idx)
                        # sort current run in ram
                        # save as a parquet file in the current run directory
                        run_file = self._flush_run(
                        current_dfs, join_key, output_dir, side, run_idx
                        )

                        # runs path is a list of the sorted parquet files
                        runs_path.append(run_file)
                        run_idx += 1
                        current_row_groups = 0
                        print(f"Flushed run {run_idx} at {run_file}")

            # flush remaining partial run, get sorted run_file as a df
            if current_dfs:
              run_file = self._flush_run(
                  current_dfs, join_key, output_dir, side, run_idx
              )
              runs_path.append(run_file)
            # output_dir has a list of sorted parquet files for the current table

            # Create the wrapper first so we write where it will read
            sorted_table = ColumnarDbFile(
                table_name=f"{side}_{table.table_name}_sorted",
                file_dir=output_dir,
            )

            # Write the final merged file inside that directory, matching ColumnarDbFile
            final_sorted_path = os.path.join(
                sorted_table.base_file_name, f"{sorted_table.table_name}-0.parquet"
            )
            print("merging all runs into ", final_sorted_path)
            self._merge_all_runs(runs_path, final_sorted_path, join_key)

            return sorted_table

    def _merge_all_runs(self, sorted_files: List[str], output_file: str, join_key: str):
        """
        Merge multiple sorted Parquet files into a single sorted Parquet file.
        """
        B = self.bway_merge_factor

        # copy that we will mutate
        runs = list(sorted_files)
        pass_idx = 0

        # while list of runs left to merge
        while len(runs) > 1:
          print("merging pass ", pass_idx)
          next_runs = []

          # B - 1 input buffers +1 output buffer
          for i in range(0, len(runs), B - 1):
                batch = runs[i : i + (B - 1)]   # B-1 input buffers, still a list

                # choose an output path for this merged batch
                # on the final pass, we want the result at `output_file`
                if len(runs) <= B - 1 and len(next_runs) == 0:
                    # last pass, first (and only) merged run -> final output
                    merged_path = output_file
                else:
                    # intermediate pass: write to a temp run file
                    base_dir = os.path.dirname(output_file)
                    merged_path = os.path.join(
                        base_dir,
                        f"bway_pass{pass_idx}_run{len(next_runs)}.parquet",
                    )

                # b-way merge this batch into merged_path
                self._bway_merge(batch, merged_path, join_key)

                next_runs.append(merged_path)

          runs = next_runs
          pass_idx += 1

        # At this point, runs has exactly one file.
        final_run = runs[0]
        if final_run != output_file:
            # In case we didn't land exactly on output_file path
            if os.path.exists(output_file):
                os.remove(output_file)
            shutil.move(final_run, output_file)

        return output_file


    def _bway_merge(self, sorted_files: List[str], output_file: str, join_key: str):
        """
        Streaming B-way merge of already-sorted Parquet 'runs' into a single
        sorted Parquet file at `output_file`, ordered by `join_key`.
        Memory use is bounded: keeps one small batch per input + an output buffer.
        """
        import heapq
        import pyarrow as pa
        import pyarrow.parquet as pq

        if not sorted_files:
            raise ValueError("No input files to merge.")

        # Tunables: keep these moderate to bound memory
        IN_BATCH_ROWS = 64_000         # rows read per input batch
        OUT_ROW_GROUP_ROWS = 128_000   # rows written per output row group

        # Determine common schema and key index from first file
        first_pf = pq.ParquetFile(sorted_files[0])
        schema = first_pf.schema_arrow
        key_index = schema.get_field_index(join_key)
        if key_index == -1:
            raise ValueError(f"join_key '{join_key}' not found in schema.")

        # Initialize streaming readers (preserve on-disk order)
        readers = [pq.ParquetFile(p).iter_batches(batch_size=IN_BATCH_ROWS) for p in sorted_files]

        # Per-input state
        states = []  # list of dicts: { 'iter': iterator, 'batch': RecordBatch|None, 'ridx': int }
        heap = []    # min-heap of ((is_null, key_value), src_idx)

        def _prime_state(it):
            try:
                return next(it)
            except StopIteration:
                return None

        def _push_heap_for_src(src_idx):
            st = states[src_idx]
            batch = st["batch"]
            ridx = st["ridx"]
            if batch is None or ridx >= batch.num_rows:
                return
            key_arr = batch.column(key_index)
            key_val = key_arr[ridx].as_py()
            # Use (is_null, value) to make None sort after non-null deterministically
            heapq.heappush(heap, ((key_val is None, key_val), src_idx))

        # Prime batches and heap
        for it in readers:
            batch = _prime_state(it)
            states.append({"iter": it, "batch": batch, "ridx": 0})
        for i in range(len(states)):
            if states[i]["batch"] is not None and states[i]["batch"].num_rows > 0:
                _push_heap_for_src(i)

        # Open output writer
        writer = pq.ParquetWriter(output_file, schema=schema)
        out_batches: List[pa.RecordBatch] = []
        out_rows = 0

        # Core streaming merge loop
        while heap:
            _, src_idx = heapq.heappop(heap)
            st = states[src_idx]
            batch = st["batch"]
            ridx = st["ridx"]

            # Append a single-row slice to output buffer
            out_batches.append(batch.slice(ridx, 1))
            out_rows += 1

            # Advance this source
            st["ridx"] += 1
            if st["ridx"] >= batch.num_rows:
                st["batch"] = _prime_state(st["iter"])
                st["ridx"] = 0
            if st["batch"] is not None:
                _push_heap_for_src(src_idx)

            # Flush output buffer as a row group when big enough
            if out_rows >= OUT_ROW_GROUP_ROWS:
                writer.write_table(pa.Table.from_batches(out_batches))
                out_batches.clear()
                out_rows = 0

        # Final flush
        if out_rows > 0:
            writer.write_table(pa.Table.from_batches(out_batches))
        writer.close()

    def join(
        self,
        table1: ColumnarDbFile,
        table2: ColumnarDbFile,
        join_key1: str,
        join_key2: str,
        temp_dir: str = "temp",
        columns_table1: Optional[List[str]] = None,
        columns_table2: Optional[List[str]] = None,
    ) -> Optional[ColumnarDbFile]:
        """
        Perform a sort-merge join between two ColumnarDbFile instances and return a sorted ColumnarDbFile.
        """
        os.makedirs(temp_dir, exist_ok=True)

        # Sort both tables externally
        sorted_table1 = self._external_sort(
            table1, join_key1, temp_dir, "left", columns_table1
        )
        sorted_table2 = self._external_sort(
            table2, join_key2, temp_dir, "right", columns_table2
        )

        result_table = self._streaming_inner_join(
            left_sorted=sorted_table1,
            right_sorted=sorted_table2,
            join_key1=join_key1,
            join_key2=join_key2,
            temp_dir=temp_dir,
            columns_table1=columns_table1,
            columns_table2=columns_table2,
        )
        return result_table

In [None]:
songs_table = tables['Songs']
users_table = tables['Users']
listens_table = tables['Listens']

smj = SortMergeJoin()

# Example: join Songs with Listens on song_id
sorted_join_result = smj.join(
    songs_table,
    listens_table,
    join_key1="song_id",
    join_key2="song_id",
    temp_dir="temp_songs_listens",
    columns_table1= ["song_id", "title"],
    columns_table2= ["song_id", "user_id"]
)

display(sorted_join_result)


In [17]:

# Correctness test vs pandas on a manageable subset
import os, shutil
import pandas as pd

# Choose columns and sample sizes
JOIN_KEY_LEFT = "song_id"
JOIN_KEY_RIGHT = "song_id"
LEFT_COLS = ["song_id", "title"]
RIGHT_COLS = ["song_id", "user_id"]

N_LEFT = 200_000
N_RIGHT = 200_000

# 1) Create temporary, smaller ColumnarDbFiles from the big tables
TEST_DIR = "temp_smj_correctness"
if os.path.exists(TEST_DIR):
    shutil.rmtree(TEST_DIR)
os.makedirs(TEST_DIR, exist_ok=True)

left_df_full = tables["Songs"].retrieve_data(columns=LEFT_COLS)
right_df_full = tables["Listens"].retrieve_data(columns=RIGHT_COLS)

left_df = left_df_full.head(N_LEFT).reset_index(drop=True)
right_df = right_df_full.head(N_RIGHT).reset_index(drop=True)

left_tmp = ColumnarDbFile("LeftTest", file_dir=TEST_DIR)
right_tmp = ColumnarDbFile("RightTest", file_dir=TEST_DIR)
left_tmp.build_table(left_df)
right_tmp.build_table(right_df)

# 2) Run your streaming SMJ on the temp tables
smj = SortMergeJoin()
smj_out = smj.join(
    left_tmp,
    right_tmp,
    join_key1=JOIN_KEY_LEFT,
    join_key2=JOIN_KEY_RIGHT,
    temp_dir=os.path.join(TEST_DIR, "join_out"),
    columns_table1=LEFT_COLS,
    columns_table2=RIGHT_COLS,
)

# 3) Materialize results
# SMJ output keeps only the left join key (your implementation drops the right duplicate)
smj_df = smj_out.retrieve_data()

# Pandas baseline: merge then drop the duplicate right key to match SMJ schema
baseline_df = pd.merge(
    left_df,
    right_df,
    left_on=JOIN_KEY_LEFT,
    right_on=JOIN_KEY_RIGHT,
    how="inner",
)
baseline_df = baseline_df[LEFT_COLS + [c for c in RIGHT_COLS if c != JOIN_KEY_RIGHT]]

# 4) Normalize order and compare
order_cols = [JOIN_KEY_LEFT] + [c for c in LEFT_COLS if c != JOIN_KEY_LEFT] + [c for c in RIGHT_COLS if c != JOIN_KEY_RIGHT]
smj_df = smj_df.sort_values(by=order_cols).reset_index(drop=True)
baseline_df = baseline_df.sort_values(by=order_cols).reset_index(drop=True)

# Ensure identical dtypes for fair compare when feasible
for c in order_cols:
    if c in smj_df.columns and c in baseline_df.columns:
        try:
            baseline_df[c] = baseline_df[c].astype(smj_df[c].dtype)
        except Exception:
            pass

# 5) Assertions and quick diagnostics
same_shape = smj_df.shape == baseline_df.shape
same_rows = smj_df.equals(baseline_df)

print("Row counts - SMJ vs Pandas:", smj_df.shape[0], baseline_df.shape[0])
print("Column sets equal:", set(smj_df.columns) == set(baseline_df.columns))
print("Shapes equal:", same_shape)
print("Frames equal:", same_rows)

if not same_rows:
    # Show a small diff preview
    from itertools import islice
    merged_chk = smj_df.merge(
        baseline_df, how="outer", indicator=True, on=order_cols
    )
    print("Mismatched samples:")
    display(merged_chk[merged_chk["_merge"] != "both"].head(10))

sorting table  LeftTest with  1269458 bytes using external sort
GBs :  0.0011822748929262161
Sorting 1 files
reading row group  0
merging all runs into  temp_smj_correctness/join_out/_left_LeftTest_sorted/left_LeftTest_sorted-0.parquet
sorting table  RightTest with  2034101 bytes using external sort
GBs :  0.0018944041803479195
Sorting 1 files
reading row group  0
merging all runs into  temp_smj_correctness/join_out/_right_RightTest_sorted/right_RightTest_sorted-0.parquet
Row counts - SMJ vs Pandas: 200000 200000
Column sets equal: True
Shapes equal: True
Frames equal: True


Implement GROUP BY after joins:
- Here you could use `pd.groupby` or do manual aggregation

In [None]:
# Your implementation here

# Section 4: Query Planning & Optimization

In this section, you'll implement smart query planning using metadata analysis. The key idea is to **avoid loading data unnecessarily** by:
1. Analyzing Parquet metadata first (row counts, column names, file sizes)
2. Making intelligent decisions about join order and algorithm selection
3. Loading only the columns you actually need for the query

In [None]:
def analyze_metadata_before_loading(file_paths):
    """YOUR TASK: Get table statistics WITHOUT loading data
    
    Hints:
    - Use pq.ParquetFile() to access metadata
    - Extract: num_rows, column names, file sizes
    - DON'T use pd.read_parquet() here - that loads data!
    """
    # TODO: For each table ('songs', 'users', 'listens'):
    #   - Open the Parquet file (but don't load data)
    #   - Extract metadata like row count, columns, sizes
    #   - Store in a dictionary
    pass  # Your implementation here
    metadata = {}

    for table_name, coldb in tables.items():
        # Use provided helpers for counts/sizes
        tm = coldb.table_metadata()        # num_files, total_rows, total_compressed_bytes
        du = coldb.table_disk_usage()      # num_files, total_bytes (on disk)

        # Get schema from the first parquet file (metadata only)
        base_dir = coldb.base_file_name
        parquet_files = sorted(glob.glob(f"{base_dir}/*.parquet"))
        columns = {}
        if parquet_files:
            pf = pq.ParquetFile(parquet_files[0])
            # Use Arrow schema to read field names and types
            columns = {field.name: str(field.type) for field in pf.schema_arrow}

        metadata[table_name] = {
            "num_files": tm["num_files"],
            "rows": int(tm["total_rows"]),
            "columns": columns,  # name -> type
            "bytes_on_disk": int(du["total_bytes"]),
            "total_compressed_bytes": int(tm["total_compressed_bytes"]),
            "can_process_in_12GB": ColumnarDbFile.can_process_parquet(int(du["total_bytes"]))
        }

    return metadata

def estimate_smj_build_bytes
    # BigSort(R) + BigSort(S) + MergeJoin(R, S) + C_w * Out


def plan_query_execution(metadata, parsed_query, memory_bytes=12 * 1024**3, overhead_per_row=24):
    """YOUR TASK: Use metadata to make smart decisions

    Questions to answer:
    - Which table is smallest? Largest?
    - Will a hash table fit in memory?
    - Which columns does the query actually need?
    - What's the optimal join order?
    """
    # TODO: Based on metadata, decide:
    #   1. Join order (smallest first? or different strategy?)
    #   2. Algorithm choice (HPJ if fits in memory, else SMJ)
    #   3. Which columns to load for each table
    """
    Use parsed SQL + table metadata to:
      - pick columns to read (column pruning)
      - choose join order (follow SQL joins; build on smaller side)
      - choose algorithm per step (HPJ if build fits in memory, else SMJ)
    Returns a plan dict.
    """
    # 0) Helpers
    def meta_key_for(table_name_lower: str) -> str:
        # Our metadata keys are capitalized: 'Songs','Users','Listens'
        return table_name_lower.capitalize()

    def size_of_type(t: str) -> int:
        t = (t or "").lower()
        if "int64" in t or "timestamp" in t or "float64" in t or "double" in t:
            return 8
        if "int32" in t or "float32" in t:
            return 4
        if "bool" in t:
            return 1
        # strings/unknown: average placeholder
        return 16

    def estimate_build_bytes(rows: int, cols: list[str], column_types: dict[str,str]) -> int:
        if rows <= 0:
            return 0
        per_row = sum(size_of_type(column_types.get(c, "")) for c in cols) + overhead_per_row
        return int(rows * per_row)

    # 1) Alias: table-name mappings
    alias_to_table_lower = parsed_query["tables"]
    alias_to_meta = {a: meta_key_for(t) for a, t in alias_to_table_lower.items()}

    # 2) Columns needed per alias
    cols_needed_by_alias: dict[str, set] = {a: set() for a in alias_to_table_lower.keys()}

    # 2a) Joins (include join keys)
    for j in parsed_query["joins"]["Joins"]:
        cols_needed_by_alias[j["left_alias"]].add(j["left_column"])
        cols_needed_by_alias[j["right_alias"]].add(j["right_column"])

    # 2b) GROUP BY columns
    for a, c in parsed_query.get("GroupBy", []):
        cols_needed_by_alias[a].add(c)

    # 2c) SELECT items and aggregations
    for item in parsed_query.get("select", []):
        if item["kind"] == "column":
            a, c = item["source"]
            cols_needed_by_alias[a].add(c)

    for agg in parsed_query.get("aggregations", {}).values():
        a, c = agg["source"]
        cols_needed_by_alias[a].add(c)

    # 2d) ORDER BY columns (aggregations already covered)
    for ob in parsed_query.get("orderBy", []):
        if ob["kind"] == "column":
            a, c = ob["source"]
            cols_needed_by_alias[a].add(c)

    # 2e) Convert to real table names
    columns_to_load = {}
    for a, cols in cols_needed_by_alias.items():
        tkey = alias_to_meta[a]
        # intersect with actual table columns for safety
        actual_cols = set(metadata[tkey]["columns"].keys())
        columns_to_load[tkey] = sorted(list(set(cols) & actual_cols)) if actual_cols else sorted(list(cols))

    # 3) Plan join steps (follow parsed order; build smaller side each time)
    base_alias, base_table_lower = parsed_query["joins"]["base_table"]
    current_alias = base_alias
    current_tkey = alias_to_meta[current_alias]
    current_rows = int(metadata[current_tkey]["rows"])

    join_steps = []

    for j in parsed_query["joins"]["Joins"]:
        joined_alias = j["joined_table_alias"]
        joined_tkey = alias_to_meta[joined_alias]
        joined_rows = int(metadata[joined_tkey]["rows"])

        # Which side's key belongs to the current relation?
        if j["left_alias"] == joined_alias:
            joined_key = j["left_column"]
            current_key = j["right_column"]
        else:
            joined_key = j["right_column"]
            current_key = j["left_column"]

        # Determine which alias supplies the "current" key for this join
        provider_alias = j["left_alias"] if j["left_alias"] != joined_alias else j["right_alias"]
        current_types_tkey = alias_to_meta[provider_alias]

        # Estimate build memory for both choices
        current_build_cols = [current_key]
        joined_build_cols = columns_to_load.get(joined_tkey, []) or [joined_key]

        current_build_bytes = estimate_build_bytes(
            current_rows, current_build_cols, metadata[current_types_tkey]["columns"]
        )
        joined_build_bytes = estimate_build_bytes(
            joined_rows, joined_build_cols, metadata[joined_tkey]["columns"]
        )

        # Choose smaller build
        build_side = "current" if current_build_bytes <= joined_build_bytes else "joined"
        build_bytes = min(current_build_bytes, joined_build_bytes)
        algorithm = "HPJ" if build_bytes <= memory_bytes else "SMJ"

        # Record step (left=current relation, right=joined table)
        step = {
            "left_relation": current_tkey,                 # current so far
            "right_relation": joined_tkey,                 # the new table
            "left_on": current_key,
            "right_on": joined_key,
            "build_side": build_side,                      # 'current' or 'joined'
            "estimated_build_bytes": build_bytes,
            "algorithm": algorithm,
        }
        join_steps.append(step)

        # Update current relation stats: result rows ~ max of inputs for FK joins
        current_rows = max(current_rows, joined_rows)
        current_tkey = f"({current_tkey}⋈{joined_tkey})"   # label for readability

    plan = {
        "columns_to_load": columns_to_load,        # table-name -> [columns]
        "join_steps": join_steps,                  # ordered join execution plan
        "memory_budget_bytes": memory_bytes,
        "group_by": parsed_query.get("GroupBy", []),
        "select": parsed_query.get("select", []),
        "aggregations": parsed_query.get("aggregations", {}),
        "order_by": parsed_query.get("orderBy", []),
    }
    return plan


# After planning, load ONLY what you need:
# Example (you implement the actual logic):
# columns_needed = ['song_id', 'artist']  # From your planning
# df = pd.read_parquet('songs.parquet', columns=columns_needed)

In [41]:
joins = parsed_sql["joins"]["Joins"]
for join in joins:
    print(join)
print("\n\n")

metadata = analyze_metadata_before_loading(tables)
for key, value in metadata.items():
    print(key, value)
print("\n\n")

plan = plan_query_execution(metadata, parsed_sql)
for key, value in plan.items():
    print(key, value)


{'joined_table_alias': 'l', 'joined_table_name': 'listens', 'left_alias': 's', 'left_column': 'song_id', 'right_alias': 'l', 'right_column': 'song_id'}
{'joined_table_alias': 'u', 'joined_table_name': 'users', 'left_alias': 'l', 'left_column': 'user_id', 'right_alias': 'u', 'right_column': 'user_id'}



Songs {'num_files': 1, 'rows': 100000, 'columns': {'song_id': 'int64', 'title': 'string', 'extra_col_1': 'string', 'extra_col_2': 'string', 'extra_col_3': 'string', 'extra_col_4': 'string', 'extra_col_5': 'string', 'extra_col_6': 'string', 'extra_col_7': 'string', 'extra_col_8': 'string', 'extra_col_9': 'string', 'extra_col_10': 'string'}, 'bytes_on_disk': 43162300, 'total_compressed_bytes': 8782, 'can_process_in_12GB': True}
Users {'num_files': 1, 'rows': 500000, 'columns': {'user_id': 'int64', 'age': 'int64', 'extra_col_1': 'string', 'extra_col_2': 'string', 'extra_col_3': 'string', 'extra_col_4': 'string', 'extra_col_5': 'string', 'extra_col_6': 'string', 'extra_col_7': 'string', 'ex

In [None]:
class QueryPlanner:
    pass # Your implementation here


class QueryExecutor:
    def __init__(self, tables, num_partitions=8, output_dir="temp", planner=None):
        self.tables = tables
        self.num_partitions = num_partitions
        self.output_dir = output_dir
        self.planner = planner or QueryPlanner()
        os.makedirs(self.output_dir, exist_ok=True)

    def execute_hardcoded_query(self):
        """
        Executes the following SQL query:

        SELECT s.song_id, AVG(u.age) AS avg_age,
        COUNT(DISTINCT l.user_id)
        FROM Songs s
        JOIN Listens l ON s.song_id = l.song_id
        JOIN Users u ON l.user_id = u.user_id
        GROUP BY s.song_id, s.title
        ORDER BY COUNT(DISTINCT l.user_id) DESC, s.song_id;
        """

        # Your implementation here

# Section 5: Performance Benchmarking

In [None]:
def benchmark_query(executor, dataset_size):
    """Benchmark the query execution time and memory usage."""
    print(f"\nBenchmarking with {dataset_size} dataset...")
    start_mem = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    start_time = time.time()

    result = executor.execute_hardcoded_query()

    end_time = time.time()
    end_mem = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)

    print(f"Execution Time: {end_time - start_time:.2f} seconds")
    print(f"Memory Usage: {end_mem - start_mem:.2f} MB")
    return result

## 100MB Benchmark

In [None]:
# Your implementation here

## 1GB Benchmark

In [None]:
# Your implementation here

## Performance Analysis

In [None]:
# Your implementation here