In [1]:
# Work from the repository root so the relative paths used below (e.g. colbert/configs.yaml,
# data/raw/...) resolve. Guarded so that re-running this cell does not keep climbing the tree.
import os

if not os.path.exists("colbert/configs.yaml"):
    os.chdir("..")
os.getcwd()

/Users/hoangle/Projects/recsys


In [2]:
import sys
from pathlib import Path

import yaml
import polars as pl
from transformers import AutoTokenizer
from loguru import logger

In [3]:
# Drop loguru's existing handlers and log to stderr at DEBUG level and above.
# The trailing ";" hides the integer handler id that `logger.add` returns.
logger.remove()
logger.add(sys.stderr, level="DEBUG");

1

# Load data and config

## Load config

In [4]:
path = "colbert/configs.yaml"

# Read with an explicit encoding so the result does not depend on the platform locale
# (assumes the config file is UTF-8).
with open(path, encoding="utf-8") as file:
    conf = yaml.safe_load(file)

# Fail fast if the config lacks a key that later cells read.
REQUIRED_CONF_KEYS = {"MODEL_NAME", "Nq", "PATH_TOKENIZER", "RAW_DATA", "PROCESSED", "TOKEN"}
missing_conf_keys = REQUIRED_CONF_KEYS - set(conf or {})
assert not missing_conf_keys, f"{path} is missing keys: {sorted(missing_conf_keys)}"

conf

{'MODEL_NAME': 'bert-base-cased',
 'Nq': 32,
 'PATH_TOKENIZER': 'colbert/processed/tokenizer',
 'RAW_DATA': {'query': 'data/raw/hotpotqa/queries.jsonl',
  'corpus': 'data/raw/hotpotqa/corpus.jsonl',
  'train': 'data/raw/hotpotqa/qrels/train.tsv',
  'val': 'data/raw/hotpotqa/qrels/dev.tsv',
  'test': 'data/raw/hotpotqa/qrels/test.tsv'},
 'PROCESSED': {'query': 'colbert/processed/query_tokenized.parquet',
  'corpus': 'colbert/processed/corpus_tokenized_[i].parquet'},
 'TOKEN': {'query': '[Q]', 'document': '[D]'}}

## Load queries

In [5]:
# Raw queries are newline-delimited JSON; the loaded frame has columns _id, text and metadata.
raw_query_path = conf["RAW_DATA"]["query"]
queries = pl.read_ndjson(raw_query_path)
queries.head()

_id,text,metadata
str,str,struct[2]
"""5ab6d31155429954757d3384""","""What country of origin does Ho…","{""American"",[[""House of Cosbys"", ""0""], [""Bill Cosby"", ""0""]]}"
"""5ac0d92f554299012d1db645""","""How many fountains where prese…","{""1,200 musical water fountains"",[[""Steve Davison"", ""0""], [""Steve Davison"", ""1""], … [""World of Color"", ""2""]]}"
"""5abd01335542993a06baf9fc""","""Chris Larceny directed the mus…","{""the Fugees"",[[""Chris Larceny"", ""3""], [""Wyclef Jean"", ""0""], [""Wyclef Jean"", ""2""]]}"
"""5abff8c95542994516f4555c""","""The person where local traditi…","{""the Iroquois Confederacy"",[[""Cross Lake"", ""1""], [""Hiawatha"", ""0""]]}"
"""5adec8ad55429975fa854f8f""","""The actor who played Carl Swee…","{""Denise DeClue"",[[""About Last Night (1986 film)"", ""1""], [""Tim Kazurinsky"", ""0""]]}"


In [6]:
# Query / document marker tokens from the config's TOKEN section
# (both are added to the tokenizer vocabulary when a new tokenizer is built).
TOK_QUERY, TOK_DOC = conf["TOKEN"]["query"], conf["TOKEN"]["document"]

# Process data

## Add special tokens to queries

In [7]:
# Prefix every query with "[CLS] <query marker> ". The tokenizer call later in the notebook
# passes add_special_tokens=False, so [CLS] has to be part of the text itself.
# NOTE(review): this cell overwrites `queries` and renames `_id` to `qid`, so re-running it
# without first re-running the load cell fails; consider keeping the raw frame under its own name.
queries = (
    queries
    .select(
        pl.col('_id').alias('qid'),
        # Use the configured marker (not a hard-coded "[Q]") so the text stays in sync with
        # the token registered in the tokenizer vocabulary.
        pl.format("[CLS] {} {}", pl.lit(TOK_QUERY), pl.col('text')).alias('text'),
    )
)

queries.head()

qid,text
str,str
"""5ab6d31155429954757d3384""","""[CLS] [Q] What country of orig…"
"""5ac0d92f554299012d1db645""","""[CLS] [Q] How many fountains w…"
"""5abd01335542993a06baf9fc""","""[CLS] [Q] Chris Larceny direct…"
"""5abff8c95542994516f4555c""","""[CLS] [Q] The person where loc…"
"""5adec8ad55429975fa854f8f""","""[CLS] [Q] The actor who played…"


## Tokenize

### Initialize or load the tokenizer

In [8]:
path_tokenizer = Path(conf['PATH_TOKENIZER'])

if not path_tokenizer.exists():
    logger.info("Tokenizer not found. Create new one.")

    tokenizer = AutoTokenizer.from_pretrained(conf['MODEL_NAME'])
    # Register the document/query markers so each is kept as a single token.
    tokenizer.add_tokens([TOK_DOC, TOK_QUERY], special_tokens=False)
    # Pad with the mask token, so padded query positions carry [MASK] ids.
    # NOTE(review): presumably ColBERT-style query augmentation; confirm this is intended.
    tokenizer.add_special_tokens({'pad_token': tokenizer.mask_token})

    # Persist the newly built tokenizer so later runs load the same vocabulary and token ids
    # instead of rebuilding it. save_pretrained creates the directory if needed.
    tokenizer.save_pretrained(path_tokenizer)
else:
    logger.info("Tokenizer found. Load pre-trained.")

    tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)

2025-02-22 11:32:38.918 | INFO     | __main__:<module>:10 - Tokenizer found. Load pre-trained.


### Tokenize queries

In [9]:
# Tokenize every query to exactly Nq ids: longer queries are truncated and shorter ones are
# right-padded with the pad token. padding='max_length' (rather than padding=True, which only
# pads to the longest query in the batch) guarantees the fixed width regardless of the data.
# [CLS] and the query marker are already in the text, hence add_special_tokens=False.
queries_tokenized = tokenizer(
    queries['text'].to_list(),
    add_special_tokens=False,
    padding_side='right',
    max_length=conf['Nq'],
    truncation=True,
    padding='max_length',
    return_tensors='np',
)

In [10]:
# Attach the 2-D id matrix as a fixed-width array column, row-aligned with `queries`.
queries = queries.with_columns(
    pl.Series('tok_ids', queries_tokenized['input_ids'])
)
queries.head()

qid,text,tok_ids
str,str,"array[i64, 32]"
"""5ab6d31155429954757d3384""","""[CLS] [Q] What country of orig…","[101, 28997, … 103]"
"""5ac0d92f554299012d1db645""","""[CLS] [Q] How many fountains w…","[101, 28997, … 3635]"
"""5abd01335542993a06baf9fc""","""[CLS] [Q] Chris Larceny direct…","[101, 28997, … 5110]"
"""5abff8c95542994516f4555c""","""[CLS] [Q] The person where loc…","[101, 28997, … 103]"
"""5adec8ad55429975fa854f8f""","""[CLS] [Q] The actor who played…","[101, 28997, … 103]"


# Save processed data

In [11]:
# Write the tokenized queries, creating the output directory first in case it does not exist yet.
path_processed_query = Path(conf['PROCESSED']['query'])
path_processed_query.parent.mkdir(parents=True, exist_ok=True)
queries.write_parquet(path_processed_query)