In [None]:
# if the following command generates an error, you probably didn't enable 
# the cluster security option "Allow API access to all Google Cloud services"
# under Manage Security → Project Access when setting up the cluster
!gcloud dataproc clusters list --region us-central1

In [None]:
# Python-side GraphFrames package; the GraphFrames jars on the Spark classpath
# come from the cluster initialization script (see the `ls` check further down)
!pip install -q graphframes

In [None]:
import pyspark
import sys
from collections import Counter, OrderedDict, defaultdict
import itertools
from itertools import islice, count, groupby
import pandas as pd
import os
import re
from operator import itemgetter
import nltk
from nltk.stem.porter import *
from nltk.corpus import stopwords
from time import time
from pathlib import Path
import pickle
import pandas as pd
from google.cloud import storage

import hashlib
def _hash(s):
    return hashlib.blake2b(bytes(s, encoding='utf8'), digest_size=5).hexdigest()

# Fetch the NLTK stopword corpus backing the `stopwords` import above
nltk.download('stopwords')

from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf, SparkFiles
from pyspark.sql import SQLContext
from graphframes import *

In [None]:
# if nothing prints here you forgot to include the initialization script when starting the cluster
!ls -l /usr/lib/spark/jars/graph*

In [None]:
# Display the active SparkSession. `spark` is not created in this notebook;
# presumably it is predefined by the Dataproc PySpark kernel.
spark

In [None]:
bucket_name = 'ir_3_207472234'
# NOTE(review): not referenced elsewhere in this notebook; doc_id_to_pos.npy
# lives under this prefix (see ID2POS_GCS)
relative_path = "meta_data"
# NOTE(review): GCS client, not referenced elsewhere in this notebook
client = storage.Client()

In [None]:
# Every preprocessed multistream parquet file in the bucket
paths = "gs://ir_3_207472234/multistream*_preprocessed.parquet"
parquetFile = spark.read.parquet(paths)

# RDD of Row(title=..., id=...) feeding the title-store build below
doc_title_pairs = parquetFile.select(["title", "id"]).rdd


In [None]:
# Peek at a few (title, id) Rows
doc_title_pairs.take(5)

In [None]:
# Total number of rows across the matched parquet files
parquetFile.count()

In [None]:
from pyspark import SparkFiles
from pyspark.sql import functions as F
from pyspark.storagelevel import StorageLevel
import builtins
import numpy as np
import os, struct

# -----------------------------
# Config
# -----------------------------
GCS_BASE = "gs://ir_3_207472234/metadata"  # outputs: per-partition shards + final title store
# NOTE: the id -> pos mapping lives under "meta_data", not "metadata"
ID2POS_GCS = "gs://ir_3_207472234/meta_data/doc_id_to_pos.npy"

# Number of sorted partitions, i.e. the number of shard files written per output.
# Upper bound 1023: the shards are merged later with a two-level `gsutil compose`
# (batches of 32, then one compose of the intermediates). A single compose request
# accepts at most 32 sources, so at most 32*32 = 1024 sources can be merged, and
# the offsets side adds one extra end-offset object.
NUM_PARTS = 256
assert 1 <= NUM_PARTS <= 1023, "NUM_PARTS too large for the two-level compose used below"

# Bind `sc` explicitly from the session instead of relying on a kernel-predefined
# global, so the notebook also runs where only `spark` is provided.
sc = spark.sparkContext

# Make doc_id_to_pos.npy available on every executor
sc.addFile(ID2POS_GCS)

# -----------------------------
# Map to (pos, title_bytes)
# - load id_to_pos lazily on the workers (in the Python worker processes, not the JVM) via a cached global
# -----------------------------
_id2pos = None  # process-local cache for the memory-mapped id -> pos array

def get_id2pos():
    """Return the doc_id -> position array, loading it on first use.

    Reads the local copy of doc_id_to_pos.npy that was distributed with
    sc.addFile(). It is opened with mmap_mode="r", so pages are read on demand
    instead of loading the whole array eagerly. The result is cached in the
    module-level ``_id2pos`` for subsequent calls in the same process.
    """
    global _id2pos
    if _id2pos is not None:
        return _id2pos
    local_npy = SparkFiles.get("doc_id_to_pos.npy")
    _id2pos = np.load(local_npy, mmap_mode="r")
    return _id2pos

# Driver-side handle, used only to choose the sentinel below. It is loaded
# directly rather than through get_id2pos() so that the `_id2pos` cache stays
# None on the driver. PySpark serializes notebook functions with cloudpickle,
# which captures the current values of the globals they reference. A populated
# cache would therefore be serialized (as a full array) into every task instead
# of being loaded lazily on the workers.
id2pos = np.load(SparkFiles.get("doc_id_to_pos.npy"), mmap_mode="r")
INVALID_POS = np.uint32(2**32 - 1) if id2pos.dtype == np.uint32 else -1

def to_pos_and_bytes(row):
    """Map a Row(title=..., id=...) to ``(pos, utf8_title_bytes)``.

    ``pos`` is the doc's slot from doc_id_to_pos.npy. A None title is stored as
    an empty string. Docs the mapping does not cover return None, so that the
    caller's ``filter`` drops them. Not covered means either the id falls outside
    the array, or the id maps to INVALID_POS.
    """
    doc_id = int(row["id"])
    title = row["title"]
    if title is None:
        title = ""  # keep empty titles

    id2pos = get_id2pos()
    # Bounds-check explicitly. Otherwise a negative id silently wraps around
    # (numpy negative indexing) and an id past the end raises IndexError.
    if not 0 <= doc_id < len(id2pos):
        return None
    pos = id2pos[doc_id]

    # skip missing doc_id if your mapping uses sentinel
    if pos == INVALID_POS:
        return None

    return (int(pos), title.encode("utf-8"))

pos_bytes = (
    doc_title_pairs
    .map(to_pos_and_bytes)
    .filter(lambda x: x is not None)  # drop docs that have no position
)

# -----------------------------
# Ensure order by pos (required for contiguous blob indexing)
# sortBy gives a total order across partitions: partition i's positions all
# precede partition i+1's. The per-partition base offsets below rely on that.
# -----------------------------
from pyspark.rdd import portable_hash  # NOTE(review): unused


sorted_rdd = pos_bytes.sortBy(lambda x: x[0], numPartitions=NUM_PARTS)

# Materialize cache
# Persisted so both later passes (summaries, then shard writing) read the same
# sorted partitions instead of recomputing the sort.
sorted_rdd = sorted_rdd.persist(StorageLevel.DISK_ONLY)
_ = sorted_rdd.count()

# -----------------------------
# Pass A: partition summaries (byte totals + doc counts)
# -----------------------------
def part_summary(iterable):
    """Summarize one partition of ``(pos, title_bytes)`` pairs.

    Yields exactly one tuple ``(n_docs, total_bytes, last_pos)``. ``last_pos``
    is the pos of the final pair seen, or -1 for an empty partition; it is
    kept for sanity checks.
    """
    doc_count, byte_count, final_pos = 0, 0, None
    for pos, payload in iterable:
        doc_count += 1
        byte_count += len(payload)
        final_pos = pos
    yield (doc_count, byte_count, -1 if final_pos is None else final_pos)

# One (n_docs, total_bytes, last_pos) tuple per partition, in partition order
summaries = sorted_rdd.mapPartitions(part_summary).collect()

part_doc_counts = [s[0] for s in summaries]
part_byte_counts = [s[1] for s in summaries]

# prefix sums: global byte offset at which each partition's titles start
base_offsets = []
running = 0
for bc in part_byte_counts:
    base_offsets.append(running)
    running += bc

total_docs = builtins.sum(part_doc_counts)
total_bytes = running

print("total_docs:", total_docs, "total_bytes:", total_bytes)

# Sanity check: readers index offsets[pos] / offsets[pos + 1] directly. That is
# only valid if the sorted positions are exactly 0..total_docs-1. Empty
# partitions report last_pos = -1, so the max is the largest real position.
# The check is necessary but not sufficient: a gap plus a duplicate would still
# pass. builtins.max is used because `from pyspark.sql.functions import *`
# shadows max.
max_pos = builtins.max((s[2] for s in summaries), default=-1)
assert max_pos == total_docs - 1, (
    f"positions are not dense: max pos {max_pos} but {total_docs} docs"
)

base_offsets_bc = sc.broadcast(base_offsets)

In [None]:
# Create the GCS "directories" that write_shards() uploads per-partition shards into
!hadoop fs -mkdir -p gs://ir_3_207472234/metadata/tmp_titles_data
!hadoop fs -mkdir -p gs://ir_3_207472234/metadata/tmp_titles_offsets

In [None]:
# Quick look at what currently exists under the output prefix
!hadoop fs -ls gs://ir_3_207472234/metadata/ | head
!hadoop fs -ls gs://ir_3_207472234/metadata/tmp_titles_data | head


In [None]:
# -----------------------------
# Pass B: write binary shards
# We'll write to local executor disk then use Hadoop FS to copy to GCS
# (works in Dataproc + most Spark-on-GCP setups)
# -----------------------------
from py4j.java_gateway import java_import



# Driver side: the single trailing offset (== total_bytes) that closes the
# offsets array. It is uploaded later as TMP_END_OFF and composed last.
end_off_path = "/tmp/end_off.bin"
end_off_record = struct.pack("<Q", total_bytes)
with open(end_off_path, "wb") as end_off_file:
    end_off_file.write(end_off_record)

import os, struct, subprocess




def write_shards(part_idx, iterable):
    """Write one sorted partition as a pair of binary shards and upload them.

    For partition ``part_idx`` this uploads two files under GCS_BASE:
      * tmp_titles_data/part-XXXXX.bin: the raw title bytes, concatenated;
      * tmp_titles_offsets/part-XXXXX.bin: one little-endian uint64 start
        offset per title. Offsets are global positions in the final
        concatenated blob, starting from this partition's entry in
        ``base_offsets_bc``.

    Files are staged in a fresh local temp dir, which is always removed, even
    if an upload fails.

    Yields:
        A single ``(part_idx, end_offset)`` tuple. ``end_offset`` is the global
        byte offset just past this partition's last title.
    """
    import shutil
    import tempfile

    cur = base_offsets_bc.value[part_idx]

    # A unique dir per attempt, so a retried or speculative attempt of the same
    # partition on this executor cannot clobber another attempt's files.
    tmp_dir = tempfile.mkdtemp(prefix=f"p5_titles_{part_idx:05d}_", dir="/tmp")
    data_path = os.path.join(tmp_dir, f"titles_data_part_{part_idx:05d}.bin")
    off_path = os.path.join(tmp_dir, f"titles_offsets_part_{part_idx:05d}.bin")
    uploads = (
        (data_path, f"{GCS_BASE}/tmp_titles_data/part-{part_idx:05d}.bin"),
        (off_path, f"{GCS_BASE}/tmp_titles_offsets/part-{part_idx:05d}.bin"),
    )

    try:
        with open(data_path, "wb") as fdata, open(off_path, "wb") as foff:
            for _pos, b in iterable:
                foff.write(struct.pack("<Q", cur))
                fdata.write(b)
                cur += len(b)

        for local_path, gcs_path in uploads:
            subprocess.check_call(
                ["hadoop", "fs", "-copyFromLocal", "-f", local_path, gcs_path]
            )
    finally:
        # Don't leave per-partition shards behind on executor local disk
        shutil.rmtree(tmp_dir, ignore_errors=True)

    yield (part_idx, cur)




# Pass B: each partition writes and uploads its shards, then reports the global
# byte offset just past its last title. Sorted by partition index for checking.
end_offsets = sorted(sorted_rdd.mapPartitionsWithIndex(write_shards).collect())

# Every partition must report exactly once...
reported_parts = [part_idx for part_idx, _cur in end_offsets]
assert reported_parts == list(range(len(base_offsets))), (
    "expected exactly one end offset per partition"
)

# ...and must end exactly where Pass A's byte totals predict. Otherwise the
# start offsets written into the shards do not line up with titles_data.bin.
mismatches = []
for part_idx, cur_end in end_offsets:
    expected_end = base_offsets[part_idx] + part_byte_counts[part_idx]
    if cur_end != expected_end:
        print("Mismatch at part", part_idx, "cur_end", cur_end, "expected", expected_end)
        mismatches.append(part_idx)
assert not mismatches, f"byte-count mismatch in {len(mismatches)} partition(s)"
print("All partition byte counts match.")

print("done writing shards")



In [None]:
# Count uploaded data shards: expect about one line per partition, plus a
# possible "Found N items" header line from `hadoop fs -ls`
!hadoop fs -ls gs://ir_3_207472234/metadata/tmp_titles_data | wc -l

In [None]:
BUCKET = "gs://ir_3_207472234"
GCS_BASE = f"{BUCKET}/metadata"

# Final, composed title store
FINAL_DATA = f"{GCS_BASE}/titles_data.bin"
FINAL_OFFS = f"{GCS_BASE}/titles_offsets.bin"

# Per-partition shards written by write_shards(), plus the trailing end offset
TMP_DATA_DIR = f"{GCS_BASE}/tmp_titles_data"
TMP_OFFS_DIR = f"{GCS_BASE}/tmp_titles_offsets"
TMP_END_OFF  = f"{TMP_OFFS_DIR}/_end_off.bin"

# Name prefixes of the intermediate compose outputs (<prefix><batch index>.bin)
TMP_DATA_INTER = f"{GCS_BASE}/_tmp_titles_data_inter_"
TMP_OFFS_INTER = f"{GCS_BASE}/_tmp_titles_offsets_inter_"

In [None]:
# List the data shards in partition order (zero-padded part-XXXXX names sort lexicographically)
!gsutil ls gs://ir_3_207472234/metadata/tmp_titles_data/part-*.bin | sort > /tmp/data_parts.txt

In [None]:
# Split the list into batches of 32, the maximum number of source objects a
# single GCS compose request accepts
!rm -rf /tmp/data_batches && mkdir -p /tmp/data_batches
!split -l 32 /tmp/data_parts.txt /tmp/data_batches/batch_


In [None]:
%%bash
# compose each batch (<= 32 part files) -> one intermediate object
# fail fast: a skipped batch would silently drop titles from the final blob
set -euo pipefail
i=0
for f in /tmp/data_batches/batch_*; do
  # zero-pad the batch index so a plain lexicographic `sort` of the
  # intermediates keeps them in batch order even past 10 batches
  out=$(printf "gs://ir_3_207472234/metadata/_tmp_titles_data_inter_%05d.bin" "$i")
  # $(cat ...) is intentionally unquoted: one source object per word
  gsutil compose $(cat "$f") "$out"
  i=$((i+1))
done

In [None]:
# Version sort (-V) orders _inter_2 before _inter_10 whether or not the batch
# index is zero-padded; a plain lexicographic sort would not
!gsutil ls gs://ir_3_207472234/metadata/_tmp_titles_data_inter_*.bin | sort -V > /tmp/data_inter.txt
!gsutil compose $(cat /tmp/data_inter.txt) gs://ir_3_207472234/metadata/titles_data.bin

In [None]:
!gsutil rm {TMP_DATA_DIR}/part-*.bin || true
# Intermediates are named _tmp_titles_data_inter_<i>.bin directly under GCS_BASE
!gsutil rm {GCS_BASE}/_tmp_titles_data_inter_*.bin || true

In [None]:
# List the offset shards in partition order (the part-*.bin glob excludes _end_off.bin)
!gsutil ls gs://ir_3_207472234/metadata/tmp_titles_offsets/part-*.bin | sort > /tmp/offs_parts.txt

In [None]:
# Upload the driver-written end offset (total_bytes) so it can be composed last
!gsutil cp /tmp/end_off.bin {TMP_END_OFF}

In [None]:
# Append the end-offset object as the final compose source, so that the
# composed titles_offsets.bin ends with total_bytes. The list is rewritten
# (not appended to) so this cell is safe to re-run: TMP_END_OFF appears once.
with open("/tmp/offs_parts.txt") as f:
    offs_sources = [ln.strip() for ln in f if ln.strip() and ln.strip() != TMP_END_OFF]
offs_sources.append(TMP_END_OFF)
with open("/tmp/offs_parts.txt", "w") as f:
    f.write("".join(src + "\n" for src in offs_sources))

In [None]:
# Batch the offset sources (shards + end offset) 32 per compose request
!rm -rf /tmp/offs_batches && mkdir -p /tmp/offs_batches
!split -l 32 /tmp/offs_parts.txt /tmp/offs_batches/batch_

In [None]:
%%bash
# compose each batch (<= 32 offset sources) -> one intermediate object
# fail fast: a skipped batch would silently drop offsets
set -euo pipefail
i=0
for f in /tmp/offs_batches/batch_*; do
  # zero-pad the batch index so the intermediates sort lexicographically in batch order
  out=$(printf "gs://ir_3_207472234/metadata/_tmp_titles_offsets_inter_%05d.bin" "$i")
  # $(cat ...) is intentionally unquoted: one source object per word
  gsutil compose $(cat "$f") "$out"
  i=$((i+1))
done

In [None]:
# Version sort (-V) orders _inter_2 before _inter_10 whether or not the batch
# index is zero-padded; a plain lexicographic sort would not
!gsutil ls gs://ir_3_207472234/metadata/_tmp_titles_offsets_inter_*.bin | sort -V > /tmp/offs_inter.txt
!gsutil compose $(cat /tmp/offs_inter.txt) gs://ir_3_207472234/metadata/titles_offsets.bin

In [None]:
# Debug: how many offset intermediates were composed, and the last few in compose order
!wc -l /tmp/offs_inter.txt
!tail -3 /tmp/offs_inter.txt

In [None]:
!gsutil rm gs://ir_3_207472234/metadata/tmp_titles_data/part-*.bin || true
!gsutil rm gs://ir_3_207472234/metadata/tmp_titles_offsets/part-*.bin || true

!gsutil rm gs://ir_3_207472234/metadata/_tmp_titles_data_inter_*.bin || true
!gsutil rm gs://ir_3_207472234/metadata/_tmp_titles_offsets_inter_*.bin || true

# The end-offset object was uploaded to TMP_END_OFF = tmp_titles_offsets/_end_off.bin
!gsutil rm gs://ir_3_207472234/metadata/tmp_titles_offsets/_end_off.bin || true

In [None]:
# Show sizes of the two final title-store objects
!gsutil ls -l gs://ir_3_207472234/metadata/titles_data.bin
!gsutil ls -l gs://ir_3_207472234/metadata/titles_offsets.bin

In [None]:
import numpy as np, os, subprocess, tempfile

FINAL_DATA = "gs://ir_3_207472234/metadata/titles_data.bin"
FINAL_OFFS = "gs://ir_3_207472234/metadata/titles_offsets.bin"

# Sanity-check the composed title store. Offsets must start at 0, never
# decrease, and end exactly at the size of titles_data.bin.
with tempfile.TemporaryDirectory() as tmpdir:  # removed when the block exits
    local_offs = os.path.join(tmpdir, "titles_offsets.bin")

    # download offsets (not huge)
    subprocess.check_call(["gsutil", "cp", FINAL_OFFS, local_offs])

    # The writer used struct.pack("<Q", ...), so read explicit little-endian uint64
    offs = np.fromfile(local_offs, dtype="<u8")

print("offsets_len:", len(offs))
assert len(offs) > 0, "titles_offsets.bin is empty"
print("last_offset:", int(offs[-1]))
assert int(offs[0]) == 0, f"first offset is {int(offs[0])}, expected 0"
assert (offs[1:] >= offs[:-1]).all(), "offsets are not non-decreasing"

# get data length without downloading it
lines = subprocess.check_output(["gsutil", "ls", "-l", FINAL_DATA], text=True).strip().splitlines()
first = [ln for ln in lines if not ln.startswith("TOTAL:")][0]
data_len = int(first.split()[0])
print("data_len:", data_len)

assert int(offs[-1]) == data_len
print("OK: offsets[-1] matches titles_data length")

In [None]:
# Usage example (read side): look up a title by doc id

# id_to_pos = np.load("doc_id_to_pos.npy", mmap_mode="r")
# offsets   = np.memmap("titles_offsets.bin", dtype=np.uint64, mode="r")
# data      = np.memmap("titles_data.bin", dtype=np.uint8,  mode="r")
# INVALID_POS = np.uint32(2**32 - 1) if id_to_pos.dtype == np.uint32 else -1

# def get_title(doc_id: int) -> str:
#     pos = id_to_pos[doc_id]
#     if pos == INVALID_POS:
#         return ""
#     start = offsets[pos]
#     end   = offsets[pos + 1]
#     return data[start:end].tobytes().decode("utf-8")
