In [1]:
import pandas as pd
import numpy as np

from datasets import load_dataset


  from .autonotebook import tqdm as notebook_tqdm


In [28]:
# Load the MS MARCO v1.1 train and test splits with identical settings.
train_dataset, test_dataset = (
    load_dataset('ms_marco', 'v1.1', split=split_name)
    for split_name in ('train', 'test')
)

In [10]:
# View the structure of the 'passages' column using the first test row.
# Index the row first and the column second: going through
# test_dataset['passages'] may read the entire column just to show one entry.
texts = test_dataset[0]['passages']
texts


{'is_selected': [0, 0, 1, 0, 0, 0, 0],
 'passage_text': ['We have been feeding our back yard squirrels for the fall and winter and we noticed that a few of them have missing fur. One has a patch missing down his back and under both arms. Also another has some missing on his whole chest. They are all eating and seem to have a good appetite.',
  'Critters cannot stand the smell of human hair, so sprinkling a barrier of hair clippings around your garden, or lightly working it into the soil when you plant bulbs, apparently does have some merit. The whole thing kind of makes me laugh. It never occurred to me that we are the ones that stink.',
  "Spread some human hair around your vegetable and flower gardens. This will scare the squirrels away because humans are predators of squirrels. It is better if the hair hasn't been washed so the squirrels will easily pick up the human scent.",
  '1 You can sprinkle blood meal around your garden as well. 2  Don’t trap and relocate squirrels. 3  This i

In [22]:
import random

def create_triplets(dataset, include_hard_negatives=False, num_hard_negatives=3, seed=None):
    """
    Create (query, positive_passage, negative_passage) triplets from the given dataset.

    For every row that has a selected passage, one triplet is emitted whose
    negative is drawn at random from the passages of the whole dataset, rejecting
    any text that also appears in the row itself. When ``include_hard_negatives``
    is set, extra triplets are added that pair the same positive with in-row
    passages that were not selected ("hard" negatives).

    Args:
        dataset (iterable of dict): Rows with a 'query' key and a 'passages' key.
            'passages' must contain the parallel lists 'is_selected' and
            'passage_text'. The dataset is iterated twice, so it must be
            re-iterable (a list or a dataset object, not a one-shot generator).
        include_hard_negatives (bool): Whether to add hard negative samples (for training only)
        num_hard_negatives (int): Number of hard negatives to sample per positive (if available)
        seed (int, optional): If given, sampling uses a private ``random.Random(seed)``
            so the result is reproducible. If None (default), the module-level
            ``random`` generator is used, exactly as before.

    Returns:
        list of tuples: Each tuple is (query, positive_passage, negative_passage).
        Rows without a selected passage contribute nothing. A row whose passages
        cover every distinct passage in the dataset gets no random-negative
        triplet, because no out-of-row negative exists for it.
    """
    # Private generator when seeded; otherwise fall back to the global one so that
    # callers relying on random.seed(...) keep getting the same draws.
    rng = random if seed is None else random.Random(seed)

    # Pre-collect all passages for negative sampling
    all_passages = []
    for row in dataset:
        all_passages.extend(row['passages']['passage_text'])
    distinct_passage_count = len(set(all_passages))

    triplets = []

    for row in dataset:
        query = row['query']
        passages = row['passages']['passage_text']
        labels = row['passages']['is_selected']

        # Skip rows without a positive passage
        if 1 not in labels:
            continue
        # Only the first selected passage is used as the positive
        positive = passages[labels.index(1)]

        # Set membership keeps the rejection test O(1); the positive is itself
        # part of the row, so one check covers both exclusions.
        row_texts = set(passages)

        # row_texts is a subset of the global pool, so equal sizes mean the row
        # already contains every distinct passage. Rejection sampling would then
        # never terminate, so the random negative is skipped for this row.
        if len(row_texts) < distinct_passage_count:
            while True:
                negative = rng.choice(all_passages)
                if negative not in row_texts:
                    break
            triplets.append((query, positive, negative))

        if include_hard_negatives:
            # Create additional triplets with in-row negatives (if available)
            non_selected_indices = [i for i, label in enumerate(labels) if label == 0]

            if non_selected_indices:
                # Sample up to num_hard_negatives unique in-row negatives
                # (or fewer if not enough are available)
                sampled_indices = rng.sample(
                    non_selected_indices,
                    min(num_hard_negatives, len(non_selected_indices))
                )

                for neg_index in sampled_indices:
                    triplets.append((query, positive, passages[neg_index]))

    return triplets

In [23]:

# Generate triplets
# Negative sampling is random. Re-seed the global generator before every call so
# that each triplet set is reproducible on its own, regardless of which cells
# were executed earlier in the session.
TRIPLET_SEED = 42

random.seed(TRIPLET_SEED)
train_triplets_stage1 = create_triplets(train_dataset, include_hard_negatives=False)

random.seed(TRIPLET_SEED)
train_triplets_stage2 = create_triplets(train_dataset, include_hard_negatives=True)

random.seed(TRIPLET_SEED)
test_triplets = create_triplets(test_dataset, include_hard_negatives=False)

# Print samples from both train and test sets
for heading, sample in (
    ("TRAINING TRIPLETS (stage 1):", train_triplets_stage1[:10]),
    ("\nTRAINING TRIPLETS (stage 2):", train_triplets_stage2[:10]),
    ("\nTEST TRIPLETS:", test_triplets[:3]),
):
    print(heading)
    for query, positive, negative in sample:
        print(f"Query: {query}\nPositive: {positive}\nNegative: {negative}\n{'-'*40}")

TRAINING TRIPLETS (stage 1):
Query: what is rba
Positive: Results-Based Accountability® (also known as RBA) is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole. RBA is also used by organizations to improve the performance of their programs. Creating Community Impact with RBA. Community impact focuses on conditions of well-being for children, families and the community as a whole that a group of leaders is working collectively to improve. For example: “Residents with good jobs,” “Children ready for school,” or “A safe and clean neighborhood”.
Negative: However, an adult can have this type of eczema as well. Flare-ups of this type of eczema can be caused by irritants, allergens, stress, fabrics, and dry skin, to name just a few. If you have food allergies, you may be more likely to experience eczema flare-ups. While it can be difficult to isolate what irritants or allergens caus

In [24]:
# Report how many triplets each set contains.
for triplet_set, label in (
    (train_triplets_stage1, "training triplets - stage 1"),
    (train_triplets_stage2, "training triplets - stage 2"),
    (test_triplets, "test triplets"),
):
    print(len(triplet_set), label)

79704 training triplets - stage 1
318372 training triplets - stage 2
9345 test triplets


In [27]:
from data_prep import save_triplets_to_json

# Write each triplet set to its own JSON file.
for triplet_set, out_path in (
    (train_triplets_stage1, "train_triplets_stage1.json"),
    (train_triplets_stage2, "train_triplets_stage2.json"),
    (test_triplets, "test_triplets.json"),
):
    save_triplets_to_json(triplet_set, out_path)

In [None]:
from data_prep import save_passages_to_file

# Flatten the passage texts of every row, train split first and test split second.
train_passages = []
for row in train_dataset:
    train_passages.extend(row['passages']['passage_text'])

test_passages = []
for row in test_dataset:
    test_passages.extend(row['passages']['passage_text'])

all_passages = [*train_passages, *test_passages]
print(len(all_passages), "total passages")

# Persist the combined list to a JSON file
save_passages_to_file(all_passages, 'all_docs.json')


755369 total passages
Passages saved to all_docs.json


In [37]:
# Show the first 15 passages, numbered from 1, each followed by a dashed rule.
for number, text in enumerate(all_passages[:15], start=1):
    print(f"Passage {number}: {text}\n{'-'*40}")

Passage 1: Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.
----------------------------------------
Passage 2: The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New S

Loaded 755369 passages from all_docs.json
["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and a