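# Prepare data for training

Flattens MS MARCO v1.1 into one row per (query, passage) pair, attaches a randomly sampled "irrelevant" passage to each row, and writes the train, validation and test splits to Parquet.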

In [86]:
import numpy as np
import pandas as pd
from datasets import load_dataset

In [87]:
# Load MS MARCO (config 'v1.1') with `datasets.load_dataset`; its 'train', 'validation' and 'test' splits are used below.
data = load_dataset(path='ms_marco', name='v1.1')

In [88]:
# One raw DataFrame per split, in train / validation / test order.
df_train_raw, df_validation_raw, df_test_raw = (
    pd.DataFrame(data[split]) for split in ('train', 'validation', 'test')
)


In [89]:
def unwrap_passages(row: pd.Series) -> pd.DataFrame:
    """Expand one raw query row into a frame with one row per passage.

    ``row['passages']`` is turned into a DataFrame, and the row's scalar
    ``query``, ``query_id`` and ``query_type`` are broadcast onto every
    passage row (appended as the last three columns, in that order).
    """
    return pd.DataFrame(row['passages']).assign(
        query=row['query'],
        query_id=row['query_id'],
        query_type=row['query_type'],
    )


In [90]:
def unwrap_dataframe(df: pd.DataFrame) -> pd.DataFrame:
    """Flatten a raw MS MARCO split into one row per (query, passage) pair.

    Vectorized replacement for an ``iterrows`` loop that built one small
    frame per query and concatenated them. Every ``passages`` value is
    spread into its own columns, the list-valued cells are exploded into
    rows, and each resulting row keeps its query's ``query``, ``query_id``
    and ``query_type``.

    Assumes each ``passages`` value is a dict of equal-length lists
    (e.g. ``is_selected``, ``passage_text``, ``url``) -- TODO confirm if the
    dataset loader ever yields a different layout.

    Args:
        df: raw split with ``passages``, ``query``, ``query_id`` and
            ``query_type`` columns; any other columns are dropped.

    Returns:
        A new frame with a fresh 0..N-1 index, ``passage_text`` renamed to
        ``doc_relevant``, ``url`` renamed to ``url_relevant``, and an added
        ``relevance = is_selected + 1`` column. An empty ``df`` yields an
        empty frame with those columns instead of raising.
    """
    query_cols = ['query', 'query_id', 'query_type']
    if df.empty:
        return pd.DataFrame(
            columns=['is_selected', 'doc_relevant', 'url_relevant', *query_cols, 'relevance']
        )

    # Both frames get a positional RangeIndex so the join below lines up row
    # for row even if ``df`` has a duplicated or non-default index.
    passages = pd.DataFrame(df['passages'].tolist())
    queries = df[query_cols].reset_index(drop=True)
    passage_cols = list(passages.columns)

    # explode() turns an empty list into a single NaN row; a query with no
    # passages should contribute no rows at all, so drop those first.
    has_passages = passages[passage_cols[0]].map(len) > 0

    df_out = (
        passages[has_passages]
        .join(queries[has_passages])
        .explode(passage_cols)
        .infer_objects()  # explode leaves object dtype; restores integer is_selected
        .reset_index(drop=True)
        .rename(columns={'passage_text': 'doc_relevant', 'url': 'url_relevant'})
    )
    df_out['relevance'] = df_out['is_selected'] + 1
    return df_out



In [91]:
# Flatten each raw split, keeping train / validation / test order.
df_train, df_validation, df_test = (
    unwrap_dataframe(raw_split)
    for raw_split in (df_train_raw, df_validation_raw, df_test_raw)
)

In [92]:
# Peek at the first two flattened (query, passage) rows.
df_train.iloc[:2]

Unnamed: 0,is_selected,doc_relevant,url_relevant,query,query_id,query_type,relevance
0,0,"Since 2007, the RBA's outstanding reputation h...",https://en.wikipedia.org/wiki/Reserve_Bank_of_...,what is rba,19699,description,1
1,0,The Reserve Bank of Australia (RBA) came into ...,https://en.wikipedia.org/wiki/Reserve_Bank_of_...,what is rba,19699,description,1


In [93]:
NEGATIVE_SAMPLING_SEED = 42  # fixed so the saved Parquet splits are reproducible


def add_random_negatives(df: pd.DataFrame, seed: int, max_rounds: int = 10) -> pd.DataFrame:
    """Return a copy of ``df`` with ``doc_irrelevant`` / ``url_irrelevant`` columns.

    Each row is paired with a passage drawn uniformly, with replacement, from
    the whole frame. A draw that lands on a passage with the same ``query_id``
    (including the row itself) is redrawn, for at most ``max_rounds`` rounds,
    so the "irrelevant" passage is not one of the query's own candidates.
    Clashes still left after the last round are kept (e.g. a frame holding a
    single query has no other passages to choose from).

    Args:
        df: frame with ``doc_relevant``, ``url_relevant`` and ``query_id`` columns.
        seed: seed for ``np.random.default_rng``; makes the draw reproducible.
        max_rounds: cap on redraw rounds for same-query clashes.

    Returns:
        A new frame; ``df`` itself is not modified.
    """
    rng = np.random.default_rng(seed)
    n_rows = len(df)
    query_ids = df['query_id'].to_numpy()
    # rng.integers(0, 0) raises, so an empty frame gets an empty pick array.
    picks = rng.integers(0, n_rows, size=n_rows) if n_rows else np.empty(0, dtype=np.intp)
    for _ in range(max_rounds):
        clash = query_ids[picks] == query_ids
        if not clash.any():
            break
        picks[clash] = rng.integers(0, n_rows, size=int(clash.sum()))

    out = df.copy()
    # Positional take, so the result does not depend on df's index labels.
    out['doc_irrelevant'] = df['doc_relevant'].to_numpy()[picks]
    out['url_irrelevant'] = df['url_relevant'].to_numpy()[picks]
    return out


df_train = add_random_negatives(df_train, seed=NEGATIVE_SAMPLING_SEED)
df_validation = add_random_negatives(df_validation, seed=NEGATIVE_SAMPLING_SEED + 1)
df_test = add_random_negatives(df_test, seed=NEGATIVE_SAMPLING_SEED + 2)



In [94]:
# Peek again: the doc_irrelevant / url_irrelevant columns are now attached.
df_train.iloc[:2]

Unnamed: 0,is_selected,doc_relevant,url_relevant,query,query_id,query_type,relevance,doc_irrelevant,url_irrelevant
0,0,"Since 2007, the RBA's outstanding reputation h...",https://en.wikipedia.org/wiki/Reserve_Bank_of_...,what is rba,19699,description,1,Nope! The only difference between brown eggs &...,http://www.answers.com/Q/Do_brown_eggs_have_th...
1,0,The Reserve Bank of Australia (RBA) came into ...,https://en.wikipedia.org/wiki/Reserve_Bank_of_...,what is rba,19699,description,1,Non-metals are the elements in groups 14-16 of...,http://www.chemicalelements.com/groups/nonmeta...


In [95]:
# Columns written to each Parquet split below (intermediate columns such as
# is_selected are left out).
save_cols = [
    'doc_relevant', 'url_relevant',
    'doc_irrelevant', 'url_irrelevant',
    'query', 'query_id', 'query_type',
    'relevance',
]

In [96]:
# Persist only the selected columns of each split, in train / validation / test order.
for out_path, split_frame in (
    ('./training.parquet', df_train),
    ('./validation.parquet', df_validation),
    ('./test.parquet', df_test),
):
    split_frame[save_cols].to_parquet(out_path)
