Imports

In [1]:
import pandas as pd
from datasets import load_dataset
import random

from typing import Tuple, List, Dict
import numpy as np
from concurrent.futures import ProcessPoolExecutor
from pandas.core.frame import DataFrame
from tqdm import tqdm  # Add this import for progress tracking
import multiprocessing as mp

Set Seed

In [2]:
# Seed Python's global RNG so the random.sample-based negative sampling below
# is repeatable when the notebook is run top to bottom.
random.seed(42)

Data Processing Functions

In [3]:
def unravel_passages(dataset):
    """Flatten per-query passage lists into one row per (query, passage).

    Args:
        dataset: Any container indexable by column name that exposes a
            ``'query'`` column and a ``'passages'`` column of equal length.
            Each ``'passages'`` entry must be a mapping with
            ``'passage_text'`` and ``'url'`` sequences of the same length.

    Returns:
        A DataFrame with object columns ``query``, ``passage`` and ``url``,
        holding one row per passage in input order.
    """
    # Fetch each column exactly once and walk them in lockstep. Looking up
    # dataset['passages'][i] inside the loop is label-based on a pandas
    # Series (breaks on a non-default index) and may re-materialize the
    # whole column on every iteration for other container types.
    query_column = list(dataset['query'])
    passage_column = list(dataset['passages'])

    # Preallocate the flat output arrays once the total row count is known.
    n_total = sum(len(entry['passage_text']) for entry in passage_column)
    queries = np.empty(n_total, dtype=object)
    passages = np.empty(n_total, dtype=object)
    urls = np.empty(n_total, dtype=object)

    start = 0
    for query, entry in zip(query_column, passage_column):
        texts = entry['passage_text']
        stop = start + len(texts)
        queries[start:stop] = query  # broadcast the query over its passages
        passages[start:stop] = texts
        urls[start:stop] = entry['url']
        start = stop

    return pd.DataFrame({'query': queries, 'passage': passages, 'url': urls})

def pre_sample_irrelevant(all_passages_ids, num_queries, samples_per_query=20):
    """Draw a pool of candidate negative passage ids for every query id.

    Args:
        all_passages_ids: Collection of passage ids to sample from.
        num_queries: Number of queries; samples are keyed by
            ``range(num_queries)``.
        samples_per_query: Maximum number of ids drawn per query.

    Returns:
        A dict mapping each query id to a list of distinct sampled ids
        (sampled without replacement via ``random.sample``). Each list has
        ``samples_per_query`` entries, or fewer when the population itself
        is smaller.
    """
    # random.sample needs a sequence (sets are rejected on Python 3.11+),
    # and raises ValueError when asked for more items than exist.
    population = list(all_passages_ids)
    sample_size = min(samples_per_query, len(population))
    return {
        query_id: random.sample(population, sample_size)
        for query_id in range(num_queries)
    }

def create_triplets_dataframe(unraveled_data, pre_samples):
    """Pair each query's relevant passages with sampled irrelevant ones.

    Args:
        unraveled_data: DataFrame with ``query_id`` and ``passage_id``
            columns, one row per (query, passage) pair.
        pre_samples: Mapping from query id to a list of candidate negative
            passage ids; it must contain every query id present in
            ``unraveled_data``.

    Returns:
        A DataFrame with one row per query and the columns ``query_id``,
        ``relevant`` (list of that query's passage ids) and ``irrelevant``
        (list of sampled ids that are not relevant to the query). The
        ``irrelevant`` list holds at most ``len(relevant)`` ids and can be
        shorter when too few candidates survive the filtering.

    Raises:
        KeyError: If a query id is missing from ``pre_samples``.
    """
    relevant_passages = (
        unraveled_data
        .groupby('query_id')['passage_id']
        .agg(list)
        .reset_index(name='relevant')
    )

    irrelevant_column = []
    for query_id, relevant in zip(relevant_passages['query_id'],
                                  relevant_passages['relevant']):
        # Filter in sample order. A sorted set difference followed by
        # truncation would always keep the smallest sampled ids, which
        # skews the negatives towards low passage ids.
        excluded = set(relevant)
        negatives = []
        for passage_id in pre_samples[query_id]:
            if passage_id not in excluded:
                excluded.add(passage_id)  # also drops repeated candidates
                negatives.append(passage_id)
        irrelevant_column.append(negatives[:len(relevant)])

    relevant_passages['irrelevant'] = irrelevant_column
    return relevant_passages

def prepare_mappings_optim(unraveled_data):
    """Build id lookup tables for the distinct queries and passages.

    Args:
        unraveled_data: DataFrame with ``query`` and ``passage`` columns.

    Returns:
        A ``(unique_queries, unique_passages)`` tuple. Each table lists the
        distinct values in order of first appearance, next to an id column
        (``query_id`` / ``passage_id``) equal to the row's index.
    """
    def _lookup_table(value_column, id_column):
        # Distinct values first, then the positional index doubles as the id.
        table = pd.DataFrame({value_column: unraveled_data[value_column].unique()})
        return table.assign(**{id_column: table.index})

    query_table = _lookup_table('query', 'query_id')
    passage_table = _lookup_table('passage', 'passage_id')
    return query_table, passage_table

def map_ids(
    unraveled_data: DataFrame,
    unique_queries: DataFrame,
    unique_passages: DataFrame
) -> DataFrame:
    """Attach lookup-table columns to each row via two left joins.

    Args:
        unraveled_data: DataFrame with ``query`` and ``passage`` columns.
        unique_queries: Lookup table joined on ``query``.
        unique_passages: Lookup table joined on ``passage``.

    Returns:
        A new DataFrame holding the rows of ``unraveled_data`` followed by
        the extra columns of both lookup tables. Rows without a match keep
        missing values in those columns because both joins are left joins.
    """
    # The deprecated ``copy`` keyword is intentionally not passed: merge
    # always builds a new frame and the flag has no effect on the result.
    with_query_ids = unraveled_data.merge(unique_queries, on='query', how='left')
    return with_query_ids.merge(unique_passages, on='passage', how='left')

def expand_triplets(triplets_df: DataFrame) -> DataFrame:
    """Expand per-query id lists into one row per triplet.

    Positives and negatives of a query are paired by position. When the two
    lists differ in length, only the first ``min(len(relevant),
    len(irrelevant))`` pairs are kept, so a query with too few negatives
    yields fewer rows instead of a length-mismatch error.

    Args:
        triplets_df: DataFrame with ``query_id``, ``relevant`` and
            ``irrelevant`` columns; the latter two hold sequences of
            passage ids.

    Returns:
        A DataFrame with columns ``query_id``, ``positive_passage_id`` and
        ``negative_passage_id``.
    """
    relevant_lists = list(triplets_df['relevant'])
    irrelevant_lists = list(triplets_df['irrelevant'])

    # Number of usable (positive, negative) pairs for each query.
    pair_counts = [
        min(len(relevant), len(irrelevant))
        for relevant, irrelevant in zip(relevant_lists, irrelevant_lists)
    ]

    # The explicit integer dtype keeps np.repeat working on empty input.
    query_ids = np.repeat(
        triplets_df['query_id'].to_numpy(),
        np.asarray(pair_counts, dtype=np.intp),
    )
    positives = [
        passage_id
        for relevant, count in zip(relevant_lists, pair_counts)
        for passage_id in relevant[:count]
    ]
    negatives = [
        passage_id
        for irrelevant, count in zip(irrelevant_lists, pair_counts)
        for passage_id in irrelevant[:count]
    ]

    return pd.DataFrame({
        'query_id': query_ids,
        'positive_passage_id': positives,
        'negative_passage_id': negatives
    })
    
def process_dataset(dataset_split):
    """Run the full triplet-building pipeline for one dataset split.

    Steps: flatten the split to one row per (query, passage), build the
    query and passage id tables, attach the ids, pre-sample candidate
    negatives for every query, and assemble the per-query triplet table.

    Args:
        dataset_split: Split data in the form accepted by
            ``unravel_passages``.

    Returns:
        A ``(triplets_df, unique_queries, unique_passages)`` tuple.
    """
    unraveled_data = unravel_passages(dataset_split)
    unique_queries, unique_passages = prepare_mappings_optim(unraveled_data)

    # Map ids with a single merge over the whole frame. Splitting the rows
    # across a process pool needs a chunk size of len // cpu_count, which is
    # zero for inputs smaller than the core count, and every task would
    # carry a pickled copy of both lookup tables.
    unraveled_data = map_ids(unraveled_data, unique_queries, unique_passages)

    pre_samples = pre_sample_irrelevant(
        list(set(unraveled_data['passage_id'])),
        unique_queries.shape[0]
    )
    triplets_df = create_triplets_dataframe(unraveled_data, pre_samples)
    return triplets_df, unique_queries, unique_passages


In [None]:
dataset = load_dataset("ms_marco", "v1.1", split='train', streaming=True)

In [None]:
print(next(iter(dataset)))

{'answers': ['Results-Based Accountability is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole.'], 'passages': {'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 'passage_text': ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonw

In [None]:
# Display the loaded dataset object.
dataset

DatasetDict({
    validation: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 10047
    })
    train: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 82326
    })
    test: Dataset({
        features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
        num_rows: 9650
    })
})

In [8]:
# Show the first record of the validation split.
dataset['validation'][0]

{'answers': ['Approximately $15,000 per year.'],
 'passages': {'is_selected': [1, 0, 0, 0, 0, 0],
  'passage_text': ['The average Walgreens salary ranges from approximately $15,000 per year for Customer Service Associate / Cashier to $179,900 per year for District Manager. Average Walgreens hourly pay ranges from approximately $7.35 per hour for Laboratory Technician to $68.90 per hour for Pharmacy Manager. Salary information comes from 7,810 data points collected directly from employees, users, and jobs on Indeed.',
   'The average revenue in 2011 of a Starbuck Store was $1,078,000, up  from $1,011,000 in 2010.    The average ticket (total purchase) at domestic Starbuck stores in  No … vember 2007 was reported at $6.36.    In 2008, the average ticket was flat (0.0% change).',
   'In fiscal 2014, Walgreens opened a total of 184 new locations and acquired 84 locations, for a net decrease of 273 after relocations and closings. How big are your stores? The average size for a typical Walgr

In [5]:
# Convert the train split to a DataFrame and build its triplet table plus the
# query and passage id lookup tables.
train_triplets, train_queries, train_passages = process_dataset(pd.DataFrame(dataset['train']))

KeyboardInterrupt: 

In [None]:
# Same pipeline for the test split; ids are assigned independently per split.
test_triplets, test_queries, test_passages = process_dataset(pd.DataFrame(dataset['test']))

In [None]:
# Same pipeline for the validation split; ids are assigned independently per split.
val_triplets, val_queries, val_passages = process_dataset(pd.DataFrame(dataset['validation']))

In [None]:
# train_triplets.to_parquet('../data/mappings/train_triplets_compressed.parquet', index=False)
# train_queries.to_parquet('../data/mappings/train_queries.parquet', index=False)
# train_passages.to_parquet('../data/mappings/train_passages.parquet', index=False)

# test_triplets.to_parquet('../data/mappings/test_triplets_compressed.parquet', index=False)
# test_queries.to_parquet('../data/mappings/test_queries.parquet', index=False)
# test_passages.to_parquet('../data/mappings/test_passages.parquet', index=False)

# val_triplets.to_parquet('../data/mappings/val_triplets_compressed.parquet', index=False)
# val_queries.to_parquet('../data/mappings/val_queries.parquet', index=False)
# val_passages.to_parquet('../data/mappings/val_passages.parquet', index=False)

In [None]:
# Flatten the per-query relevant/irrelevant lists into one row per
# (query_id, positive_passage_id, negative_passage_id) for each split.
train_triplets_exp = expand_triplets(train_triplets)
test_triplets_exp = expand_triplets(test_triplets)
val_triplets_exp = expand_triplets(val_triplets)

In [None]:
# train_triplets_exp.to_parquet('../data/mappings/train_triplets_expanded.parquet', index=False)
# test_triplets_exp.to_parquet('../data/mappings/test_triplets_expanded.parquet', index=False)
# val_triplets_exp.to_parquet('../data/mappings/val_triplets_expanded.parquet', index=False)