In [1]:
%load_ext autoreload
%autoreload 2

In [10]:
import os
from pathlib import Path
import sys

if '..' not in sys.path: sys.path.append('..')

from pydantic import BaseModel
from pydantic_yaml import to_yaml_file, parse_yaml_file_as
from transformers import GPT2Tokenizer

from mllm.config.model import create_mllm_encdec_cfg, create_mllm_ranker_cfg, TokenizerCfg
from mllm.tokenization.chunk_tokenizer import calc_max_inp_size, gen_all_tokens, tokenizer_from_config



In [5]:
cfg_dpath = Path(os.path.abspath('.')).parent / 'mllm' / 'config' / 'cfg_v001'
cfg_dpath

PosixPath('/home/misha/prog/mllm/mllm/config/cfg_v001')

In [6]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', model_max_length=10000)
n_tokens_init = len(tokenizer)
tok_dict = gen_all_tokens(tokenizer)
pad_tok, qbeg_tok, qend_tok = tok_dict['pad'].ind, tok_dict['query_begin'].ind, tok_dict['query_end'].ind


In [7]:
tkz_cfg = TokenizerCfg(name='gpt2', n_tokens_init=n_tokens_init, model_max_length=10000, custom_tokens=tok_dict)
tkz_cfg

TokenizerCfg(name='gpt2', n_tokens_init=50257, model_max_length=10000, custom_tokens={'doc_begin': CustomToken(name='doc_begin', repr='<|doc_begin|>', special=False, ind=50257), 'doc_end': CustomToken(name='doc_end', repr='<|doc_end|>', special=False, ind=50258), 'doc_id_begin': CustomToken(name='doc_id_begin', repr='<|doc_id_begin|>', special=False, ind=50259), 'doc_id_end': CustomToken(name='doc_id_end', repr='<|doc_id_end|>', special=False, ind=50260), 'doc_offset_begin': CustomToken(name='doc_offset_begin', repr='<|doc_offset_begin|>', special=False, ind=50261), 'doc_offset_end': CustomToken(name='doc_offset_end', repr='<|doc_offset_end|>', special=False, ind=50262), 'doc_title_begin': CustomToken(name='doc_title_begin', repr='<|doc_title_begin|>', special=False, ind=50263), 'doc_title_end': CustomToken(name='doc_title_end', repr='<|doc_title_end|>', special=False, ind=50264), 'doc_body_begin': CustomToken(name='doc_body_begin', repr='<|doc_body_begin|>', special=False, ind=50265),

In [8]:
tkz_cfg_fpath = cfg_dpath / 'tokenizer_cfg_01.yaml'
to_yaml_file(tkz_cfg_fpath, tkz_cfg)

In [11]:
tkz = tokenizer_from_config(tkz_cfg)
tkz

GPT2Tokenizer(name_or_path='gpt2', vocab_size=50257, model_max_length=10000, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|pad|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	50257: AddedToken("<|doc_begin|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50258: AddedToken("<|doc_end|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50259: AddedToken("<|doc_id_begin|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50260: AddedToken("<|doc_id_end|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50261: AddedToken("<|doc_offset_begin|>", rstrip=False, lstrip=False, single_word=False,

In [12]:
n_tokens_init

50257

In [4]:
model_cfg = create_mllm_ranker_cfg(
    n_vocab=len(tokenizer), inp_len=100, d_word_wec=256,
    n_levels=1, enc_n_layers=1, dec_n_layers=1,
    n_heads=8, d_k=32, d_v=32, d_model=256, d_inner=1024,
    pad_idx=-1, dropout_rate=0.0, enc_with_emb_mat=True,
)
model_cfg

MllmRankerCfg(vocab_encoder=VocabEncoderCfg(n_vocab=50257, d_word_vec=256, d_model=256, pad_idx=-1, inp_len=100, dropout_rate=0.0), encoders=[EncoderCfg(n_layers=1, n_heads=8, d_k=32, d_v=32, d_model=256, d_inner=1024, pad_idx=-1, with_graph_mat=False, inp_len=100, dropout_rate=0.0, with_emb_mat=True)], decoders=[EncoderCfg(n_layers=1, n_heads=8, d_k=32, d_v=32, d_model=256, d_inner=1024, pad_idx=-1, with_graph_mat=False, inp_len=0, dropout_rate=0.0, with_emb_mat=False)])

In [12]:
ranker_cfg_fpath = cfg_dpath / 'ranker_model_cfg_01.yaml'
to_yaml_file(ranker_cfg_fpath, model_cfg)

PosixPath('/home/misha/prog/mllm/mllm/config')