# Import Libraries

In [None]:
# Step 1: Import necessary libraries
import random

import pandas as pd
import sentencepiece as spm
from datasets import load_dataset
from sentencepiece import SentencePieceTrainer
from sklearn.model_selection import train_test_split

# Import data

In [None]:
# Load the MS MARCO dataset, configuration "v1.1", with `datasets.load_dataset`.
data = load_dataset(path='ms_marco', name='v1.1')

In [None]:
# The loaded dataset already contains train / validation / test splits, so
# nothing is re-split here: each split is simply selected by name.
train_dataset, val_dataset, test_dataset = (
    data[split_name] for split_name in ('train', 'validation', 'test')
)

# Materialise every split as a pandas DataFrame, in the same split order.
data_train_df, data_validation_df, data_test_df = (
    pd.DataFrame(split) for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Display the training-split DataFrame; in a notebook, the value of this bare
# final expression is rendered as the cell's output.
data_train_df

In [None]:
# Step 4: Pull the raw query column out of the training split as an array,
# one entry per training row (aligned with `data_train_df`'s row order).
queries = data_train_df.loc[:, 'query'].values

In [None]:
# Per-row 'passages' entries, aligned index-for-index with `queries`.
# NOTE(review): each element is whatever the 'passages' column holds (not
# necessarily a plain string) — confirm before treating it as text.
relevant_documents = data_train_df.loc[:, 'passages'].values

In [35]:
# Step 5: Generate (query, positive document, negative document) triples.
# Row i's own documents are its positive; the negative is borrowed from row
# i+1, wrapping around to row 0 for the last query (deterministic negative
# sampling, identical to indexing with (i + 1) % len(relevant_documents)).
# NOTE(review): documents are passed through unchanged, so they may not be
# plain strings — downstream consumers should confirm their type.
if len(queries) != len(relevant_documents):
    # Both arrays come from the same DataFrame; a mismatch means the inputs
    # were built inconsistently and the triples would be misaligned.
    raise ValueError(
        f"queries ({len(queries)}) and relevant_documents "
        f"({len(relevant_documents)}) must have the same length"
    )
if len(relevant_documents) == 1:
    # With a single row, the "next" row wraps to the row itself, making the
    # negative identical to the positive — a meaningless training triple.
    raise ValueError("negative sampling needs at least two documents")

# Rotate the documents left by one so element i holds row (i + 1)'s documents.
negative_documents = [*relevant_documents[1:], *relevant_documents[:1]]
triples = list(zip(queries, relevant_documents, negative_documents))


In [None]:
# Step 6: Tokenize the generated data using SentencePiece.
# First, train the SentencePiece model on the text of every triple.
# `input=` names training *files*; passing the in-memory list of triples there
# makes the trainer look for non-existent files. An in-memory corpus must be
# streamed through `sentence_iterator` instead.
vocab_size = 5000


def _iter_training_sentences(triple_rows):
    """Yield every text contained in *triple_rows* as one training sentence.

    Strings are yielded unchanged. A dict carrying a 'passage_text' entry has
    each of its passages yielded as a separate sentence, which keeps sentences
    short (the trainer skips sentences longer than its max_sentence_length).
    Anything else falls back to ``str(item)``.
    NOTE(review): the dict case assumes MS MARCO 'passages' records shaped like
    {'passage_text': [...], ...} — confirm against the loaded data.
    """
    for row in triple_rows:
        for item in row:
            if isinstance(item, str):
                yield item
            elif isinstance(item, dict) and 'passage_text' in item:
                yield from (str(text) for text in item['passage_text'])
            else:
                yield str(item)


SentencePieceTrainer.Train(
    sentence_iterator=_iter_training_sentences(triples),
    model_prefix='m',  # writes m.model / m.vocab to the working directory
    vocab_size=vocab_size,
    character_coverage=1.0,
    model_type='unigram',
)

In [None]:
# Then, load the trained model and split every triple into subword pieces.
sp = spm.SentencePieceProcessor(model_file='m.model')


def _as_text(item):
    """Return *item* as one string that SentencePiece can encode.

    Strings pass through unchanged. A dict carrying a 'passage_text' entry is
    flattened by joining its passages with single spaces. Anything else falls
    back to ``str(item)``.
    NOTE(review): the dict case assumes MS MARCO 'passages' records shaped like
    {'passage_text': [...], ...} — confirm against the loaded data.
    """
    if isinstance(item, str):
        return item
    if isinstance(item, dict) and 'passage_text' in item:
        return ' '.join(str(text) for text in item['passage_text'])
    return str(item)


# Each entry is a (query pieces, positive pieces, negative pieces) tuple of
# piece-string lists; out_type=str yields pieces rather than integer ids.
tokenized_triples = [
    (
        sp.encode(_as_text(query), out_type=str),
        sp.encode(_as_text(pos_doc), out_type=str),
        sp.encode(_as_text(neg_doc), out_type=str),
    )
    for query, pos_doc, neg_doc in triples
]

In [None]:
# Example of a tokenized query and its documents: the first triple.
# `train_triples` was never defined (a NameError). The tokenized triples
# already live in `tokenized_triples`, so unpack the first one instead of
# re-encoding it.
tokenized_query, tokenized_positive_doc, tokenized_negative_doc = tokenized_triples[0]