In [19]:
# Training loop:
# 1. Create torch dataset from the data containing the query, positive, negative example.
# 2. Create a dataloader from the dataset.
# 3. Run a single batch through the model.

In [20]:
!poetry add pandas

The following packages are already present in the pyproject.toml and will be skipped:

  • [36mpandas[39m

If you want to update it to the latest compatible version, you can use `poetry update package`.
If you prefer to upgrade it to the latest available version, you can use `poetry add package@latest`.

Nothing to add.


In [21]:
from torch.utils.data import DataLoader, Dataset
import pandas as pd

In [22]:
## Load data
train = pd.read_parquet("train.parquet")
test = pd.read_parquet("test.parquet")
validate = pd.read_parquet("validate.parquet")

In [23]:
print(train.shape, test.shape, validate.shape)

(82326, 6) (9650, 6) (10047, 6)


In [24]:
full_dataset = pd.concat([train, test, validate])

In [25]:
total_size = train.shape[0] + test.shape[0] + validate.shape[0]
print(total_size)
# Rerun this code only if you want to generate new indices.
import numpy as np
import pickle

# np.random.seed(seed=999999)
# indices = np.random.randint(0, total_size, total_size)
# train_indices = indices[0:train.shape[0]]
# test_indices = indices[train.shape[0]:train.shape[0]+test.shape[0]]
# validate_indices = indices[train.shape[0]+test.shape[0]:]
# all_indices = {'train': train_indices.tolist(), 'test': test_indices.tolist(), 'validate': validate_indices.tolist()}
# print(f"Length of all indices: {sum(len(v) for k, v in all_indices.items())}")
# # write these to a pickle file

# with open('indices.pkl', 'wb') as f:
#     pickle.dump(all_indices, f)

# create train, test, validation datasets based on query_ids
# query_ids = full_dataset['query_id']
# train_query_ids = query_ids[0:train.shape[0]]
# test_query_ids = query_ids[train.shape[0]:train.shape[0]+test.shape[0]]
# validate_query_ids = query_ids[train.shape[0]+test.shape[0]:]
# print(train_query_ids.shape, test_query_ids.shape, validate_query_ids.shape)
# all_query_ids  = {'train': train_query_ids.tolist(), 'test': test_query_ids.tolist(), 'validate': validate_query_ids.tolist()}
# # write these to a pickle file

# with open('query_ids.pkl', 'wb') as f:
#     pickle.dump(all_query_ids, f)

102023


In [26]:
# Unpickle and test
with open("indices.pkl", "rb") as f:
    all_indices = pickle.load(f)

    print(f"{len(all_indices['train'])} training indices {all_indices['train'][0:10]}")
    print(f"{len(all_indices['test'])} test indices {all_indices['test'][0:10]}")
    print(
        f"{len(all_indices['validate'])} validate indices {all_indices['validate'][0:10]}"
    )

with open("query_ids.pkl", "rb") as f:
    all_query_ids = pickle.load(f)
    print(
        f"{len(all_query_ids['train'])} training query ids {all_query_ids['train'][0:10]}"
    )
    print(f"{len(all_query_ids['test'])} test query ids {all_query_ids['test'][0:10]}")
    print(
        f"{len(all_query_ids['validate'])} validate query ids {all_query_ids['validate'][0:10]}"
    )

82326 training indices [36233, 13668, 19591, 66699, 34957, 34882, 23273, 27926, 90079, 68436]
9650 test indices [27928, 5317, 97860, 83225, 90861, 42316, 42625, 98527, 74133, 95026]
10047 validate indices [94011, 72401, 60227, 70263, 35765, 47026, 31472, 3385, 21289, 545]
82326 training query ids [19699, 19700, 19701, 19702, 19703, 19704, 19705, 19706, 19707, 19708]
9650 test query ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
10047 validate query ids [9652, 9653, 9654, 9655, 9656, 9657, 9658, 9659, 9660, 9661]


In [58]:
# train[['passages', 'query']].iloc[0]
train_dataset = full_dataset.iloc[all_indices["train"]]
test_dataset = full_dataset.iloc[all_indices["test"]]
validate_dataset = full_dataset.iloc[all_indices["validate"]]

train_dataset.head()["query"].iloc[0]

'is there a limit to roth ira contributions'

In [63]:
import hashlib
from concurrent.futures import ProcessPoolExecutor


def pprint(text):
    print(f"\r{text}", end="")


def create_lookups(full_dataset):
    print(f"dataset shape: {full_dataset.shape}")
    all_urls = full_dataset["passages"].apply(lambda x: x["url"]).tolist()
    unique_urls = set([item for sublist in all_urls for item in sublist])
    print(f"Total number of urls: {sum(len(i) for i in all_urls)}")
    print(f"Total number of unique urls: {len(unique_urls)}")

    # Use an md5 hash for the urls for deterministic mapping
    def generate_md5_hash(s):
        return hashlib.md5(s.encode("utf-8")).hexdigest()

    ids_to_urls = {generate_md5_hash(url): url for url in unique_urls}

    urls_to_ids = {url: i for i, url in ids_to_urls.items()}
    print(f"Total number of hashed urls: {len(ids_to_urls)}")

    query_ids = full_dataset["query_id"].tolist()
    assert len(query_ids) == len(set(query_ids))
    print(f"Total number of queries: {len(query_ids)}")
    return ids_to_urls, urls_to_ids

def add_hashed_urls(dataset, urls_to_ids):
    dataset.loc[:, "hashed_urls"] = dataset["passages"].progress_apply(
        lambda x: np.array(list(set([urls_to_ids[url] for url in x["url"]])))
    )
def get_triples(dataset):
    return dataset[["query_id", "hashed_urls", "irrelevant_urls"]].values.tolist()
    # for i, row in (
    #     enumerate(dataset.iterrows()
    #     if not create_small
    #     else dataset.head(10).iterrows())
    # ):
    #     relevant_url_ids = row[1]["hashed_urls"]
    #     query_id = row[1]["query_id"]
    #     # relevant_url_ids = np.array(list(set([urls_to_ids[url] for url in urls])))
    #     irrelevant_url_ids = np.setdiff1d(
    #         master_url_id_set, relevant_url_ids, assume_unique=True
    #     )
    #     sampled_ids = np.random.choice(
    #         irrelevant_url_ids, size=len(relevant_url_ids), replace=False
    #     )

    #     triple = (query_id, list(relevant_url_ids), sampled_ids.tolist())
    #     triples.append(triple)
    #     if (i+1) % 1000 == 0:
    #         pprint(f"Processed {i} rows")
    # print(f"Query ID, Query: {query_id}, {row[1]['query']}")
    # print(f"R {NL.join([ids_to_urls[i]+NL+TAB+str(i) for i in relevant_url_ids])}")
    # print(f"IR: {NL.join([ids_to_urls[i]+NL+TAB+str(i) for i in irrelevant_url_ids])}")


# test_triples = create_triples(test_dataset, ids_to_urls, urls_to_ids, create_small=False)
# validate_triples = create_triples(validate_dataset, ids_to_urls, urls_to_ids, create_small=False)

In [64]:
ids_to_urls, urls_to_ids = create_lookups(full_dataset)
add_hashed_urls(train_dataset, urls_to_ids)
add_hashed_urls(test_dataset, urls_to_ids)
add_hashed_urls(validate_dataset, urls_to_ids)

dataset shape: (102023, 6)
Total number of urls: 837729
Total number of unique urls: 456487
Total number of hashed urls: 456487
Total number of queries: 102023


100%|██████████| 82326/82326 [00:00<00:00, 169606.21it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset.loc[:, "hashed_urls"] = dataset["passages"].progress_apply(
100%|██████████| 9650/9650 [00:00<00:00, 165736.61it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset.loc[:, "hashed_urls"] = dataset["passages"].progress_apply(
100%|██████████| 10047/10047 [00:00<00:00, 164394.28it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: 

In [30]:
from tqdm import tqdm

tqdm.pandas()

In [60]:
master_url_id_set = np.array(
    list(set(ids_to_urls.keys()))
)  # Define or load your master_url_id_set here

In [131]:
# Splitting the dataset for demonstration; in practice, adjust based on your dataset size
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
import utils 
import importlib
importlib.reload(utils)
importlib.reload(utils.data_utils)
from utils.data_utils import add_negative_samples


In [133]:
training_dataset_copy = train_dataset.copy()
add_negative_samples(training_dataset_copy, ids_to_urls)

  return bound(*args, **kwds)
Processing Batches: 100%|██████████| 10/10 [12:47<00:00, 76.77s/it]  


In [140]:
test_dataset_copy = test_dataset.copy()
add_negative_samples(test_dataset_copy, ids_to_urls)

  return bound(*args, **kwds)
Processing Batches: 100%|██████████| 10/10 [01:29<00:00,  8.98s/it]


In [141]:
validate_dataset_copy = validate_dataset.copy()
add_negative_samples(validate_dataset_copy, ids_to_urls)

  return bound(*args, **kwds)
Processing Batches: 100%|██████████| 10/10 [01:38<00:00,  9.87s/it]


In [None]:
df.to_parquet('train_triplets.parquet', engine='pyarrow')

In [145]:
triple_columns = ["query_id", "hashed_urls", "negative_sample_urls"]
training_dataset_copy[triple_columns].to_parquet('training_dataset_triplets.parquet', engine='pyarrow')

In [148]:
test_dataset_copy[triple_columns].to_parquet('test_dataset_triplets.parquet', engine='pyarrow')

In [150]:
validate_dataset_copy[triple_columns].to_parquet('validate_dataset_triplets.parquet', engine='pyarrow')

In [138]:
# There is a unit test to test this in the tests/ dir.
s1 = list([len(i[1]["negative_sample_urls"]) for i in training_dataset_copy.iterrows()])
s2 = list([len(i[1]["hashed_urls"]) for i in training_dataset_copy.iterrows()])
print(len(s1))
print(len(s2))
print(s1[0:10])
print(s2[0:10])
assert(s1 == s2)


82326
82326
[7, 9, 5, 5, 8, 5, 7, 6, 6, 7]
[7, 9, 5, 5, 8, 5, 7, 6, 6, 7]


In [None]:
def create_triplet_dataset(which: str):
    """
    Structure:
    query, positive, negative
    """
    pass

In [65]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("spm_AllTexts.model")

True

In [None]:
class TripletDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]