In [1]:
import pandas as pd
import numpy as np

from datasets import load_dataset


In [24]:
# Fetch the MS MARCO v1.1 train and test splits (train first, then test).
train_dataset, test_dataset = (
    load_dataset('ms_marco', 'v1.1', split=split_name)
    for split_name in ('train', 'test')
)

In [None]:
# View the structure of the passages field for the first training example.
# Index by row first, then by column: `train_dataset['passages'][0]` would load
# the whole 'passages' column just to read a single entry.
texts = train_dataset[0]['passages']
texts


{'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 'passage_text': ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.",
  "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydn

In [18]:
import random

def create_training_triplets(dataset, seed=None):
    """
    Create (query, positive_passage, negative_passage) triplets from the given dataset.

    For every row that has at least one selected passage, the first selected
    passage becomes the positive. The negative is drawn at random from the pooled
    passages of all rows, rejecting any text that also appears in the row itself.

    Args:
        dataset (iterable of dict): Each item must have 'query' and 'passages' keys.
                                    'passages' must contain the parallel lists
                                    'is_selected' and 'passage_text'. The dataset
                                    is iterated twice, so it must be re-iterable.
        seed (int, optional): Seed for a private random generator, making the
                              sampled negatives reproducible. When None (the
                              default) the module-level `random` state is used.

    Returns:
        list of tuples: Each tuple is (query, positive_passage, negative_passage).
                        Rows without a positive passage, and rows for which no
                        passage outside the row exists, are skipped.
    """
    # Rejection-sampling draws allowed per row before switching to an exhaustive
    # scan; this bounds the loop when (almost) every pooled passage is in the row.
    max_draws = 100

    # A private generator keeps seeded runs from disturbing global random state.
    rng = random.Random(seed) if seed is not None else random

    # Pre-collect all passages for negative sampling
    all_passages = []
    for row in dataset:
        all_passages.extend(row['passages']['passage_text'])

    triplets = []

    for row in dataset:
        query = row['query']
        passages = row['passages']['passage_text']
        labels = row['passages']['is_selected']

        # Skip rows that have no positive passage
        if 1 not in labels:
            continue
        positive = passages[labels.index(1)]

        # The positive is itself one of the row's passages, so excluding the
        # row's passages also excludes the positive.
        own_passages = set(passages)

        negative = None
        for _ in range(max_draws):
            candidate = rng.choice(all_passages)
            if candidate not in own_passages:
                negative = candidate
                break
        else:
            # Every draw was rejected: enumerate the valid candidates once
            # instead of looping forever.
            outside = [text for text in all_passages if text not in own_passages]
            if outside:
                negative = rng.choice(outside)

        # No passage outside this row exists (e.g. a single-row dataset)
        if negative is None:
            continue

        triplets.append((query, positive, negative))

    return triplets


In [22]:

# Generate (query, positive_passage, negative_passage) triplets from the training split.
# Negatives are sampled at random, so the result differs between runs.
train_triplets = create_training_triplets(train_dataset)

# Print the first three triplets as a sanity check, separated by a dashed line
for t in train_triplets[:3]:
    print(f"Query: {t[0]}\nPositive: {t[1]}\nNegative: {t[2]}\n{'-'*40}")


["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", 'RBA Rec

In [23]:
import json
from typing import List, Tuple

def save_triplets_to_json(triplets: List[Tuple[str, str, str]], output_file: str) -> None:
    """
    Write triplets to a UTF-8 JSON file as a list of objects.

    Each triplet becomes one object with the keys "query",
    "positive_passage" and "negative_passage", in that order.

    Args:
        triplets: List of (query, positive_passage, negative_passage) tuples
        output_file: Path of the JSON file to write (overwritten if it exists)
    """
    # Build the records up front so the file is only opened once they are ready
    records = []
    for query_text, positive_text, negative_text in triplets:
        record = {}
        record["query"] = query_text
        record["positive_passage"] = positive_text
        record["negative_passage"] = negative_text
        records.append(record)

    # ensure_ascii=False keeps non-ASCII characters readable in the file
    with open(output_file, 'w', encoding='utf-8') as out_handle:
        json.dump(records, out_handle, ensure_ascii=False, indent=2)

def load_triplets_from_json(input_file: str) -> List[Tuple[str, str, str]]:
    """
    Read triplets back from a UTF-8 JSON file.

    The file is expected to hold a list of objects, each with the keys
    "query", "positive_passage" and "negative_passage".

    Args:
        input_file: Path to the JSON file

    Returns:
        List of (query, positive_passage, negative_passage) tuples
    """
    with open(input_file, 'r', encoding='utf-8') as in_handle:
        records = json.load(in_handle)

    # Rebuild one tuple per stored object, preserving file order
    triplets = []
    for record in records:
        triplets.append(
            (record["query"], record["positive_passage"], record["negative_passage"])
        )
    return triplets


# Persist the training triplets to disk
save_triplets_to_json(triplets=train_triplets, output_file="triplets.json")

## Tokenizing and Embedding

In [25]:
from sentence_transformers import SentenceTransformer
import torch

# Load the pre-trained Sentence-BERT model.
# NOTE: `model` is a module-level global that embed_triplets() reads directly
# (it is not passed in as an argument), so this line must run before
# embed_triplets() is called.
model = SentenceTransformer('all-MiniLM-L6-v2')

def embed_triplets(triplets):
    """
    Encode the three columns of a triplet list with the global `model`.

    Args:
    triplets (list of tuples): Each tuple is (query, positive_passage, negative_passage)

    Returns:
    tuple: Three tensors, (query_embeddings, positive_embeddings, negative_embeddings),
           one per triplet column and in that order
    """
    # Transpose the list of triplets into three parallel sequences
    queries, positive_passages, negative_passages = zip(*triplets)

    # Encode each column in turn: queries, then positives, then negatives
    return tuple(
        model.encode(column_texts, convert_to_tensor=True)
        for column_texts in (queries, positive_passages, negative_passages)
    )



# Embed the queries, positive passages and negative passages of the training triplets
query_embeddings, positive_embeddings, negative_embeddings = embed_triplets(train_triplets)

# The embeddings are tensors ready for downstream use (e.g. a triplet loss).
# Each shape is expected to be (n_triplets, embedding_dim); print one shape per line.
print(query_embeddings.shape, positive_embeddings.shape, negative_embeddings.shape, sep="\n")




torch.Size([79704, 384])
torch.Size([79704, 384])
torch.Size([79704, 384])
