In [None]:
from transformers import AutoTokenizer

# Fetch the base Llama-3.1 tokenizer and keep a local copy; later cells read
# the raw tokenizer.json from this directory.
base_tokenizer_dir = "Llama-3.1-8B-Base"
tokenizer_base = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
tokenizer_base.save_pretrained(base_tokenizer_dir)

In [None]:
import json

# Load the serialized base tokenizer so the raw BPE vocab (token -> id) and the
# merge list can be inspected directly. `with` closes the handle promptly, and
# UTF-8 is explicit because vocab entries contain non-ASCII symbols such as 'Ġ'
# (the platform default encoding is not UTF-8 everywhere).
with open("Llama-3.1-8B-Base/tokenizer.json", encoding="utf-8") as fh:
    tokenizer_json = json.load(fh)
vocab_base = tokenizer_json["model"]["vocab"]
merges_base = tokenizer_json["model"]["merges"]

In [None]:
# Number of entries in the base BPE vocabulary loaded from tokenizer.json.
base_vocab_size = len(vocab_base)
base_vocab_size

In [None]:
from collections import defaultdict


def _merge_pair(merge):
    """Return (left, right) for one entry of tokenizer.json's merge list.

    Depending on the `tokenizers` version that wrote the file, a merge is stored
    either as a "left right" string or as a [left, right] pair.
    """
    if isinstance(merge, str):
        left, right = merge.split()
    else:
        left, right = merge
    return left, right


# Index every merge by the token it produces in ONE pass over merges_base,
# instead of rescanning the full merge list for every vocab entry
# (O(|vocab| * |merges|), i.e. billions of iterations for Llama-3).
merges_by_product = defaultdict(list)
for item in merges_base:
    left, right = _merge_pair(item)
    merges_by_product[left + right].append(item)

# Map each vocab token with id >= 256 to the merge entries (in merge-list order)
# that produce it. Ids below 256 are skipped -- presumably the byte-level base
# alphabet, which no merge produces. Only tokens with at least one producing
# merge get a key.
vocab_to_merges = defaultdict(list)
for k, v in vocab_base.items():
    if v < 256:
        continue
    if k in merges_by_product:  # membership test does not create a key
        vocab_to_merges[k].extend(merges_by_product[k])

print(f'{len(vocab_to_merges)} vocab tokens mapped to their producing merges')

In [None]:
import pickle as pkl

# Persist the token -> producing-merge(s) mapping. `with` guarantees the pickle
# is flushed and the file closed, even if pickling fails part-way.
with open("./Llama-3.1-8B-Base/vocab_to_merges_MAPPING.pkl", "wb") as fh:
    pkl.dump(vocab_to_merges, fh)

In [None]:

# Reload the base tokenizer.json from disk. The handle is closed explicitly, and
# the file is read as UTF-8 because the vocab holds non-ASCII symbols.
with open("Llama-3.1-8B-Base/tokenizer.json", encoding="utf-8") as fh:
    tokenizer_json = json.load(fh)
vocab_base = tokenizer_json["model"]["vocab"]
merges_base = tokenizer_json["model"]["merges"]

In [None]:
import copy
import os, glob
import pickle as pkl
import json
from collections import defaultdict

# Output root: one extended tokenizer per ./EBM/*.txt word list. The fertility
# evaluation globs './EBM-MEDVOC-FromScratch-Llama3_Vocab/*', so the directory
# that is created and the one written to must be the same. Previously the cell
# created this folder but dumped into 'EBM-FromScratch-Llama3_Vocab'.
EBM_VOCAB_DIR = "EBM-MEDVOC-FromScratch-Llama3_Vocab"
os.makedirs(EBM_VOCAB_DIR, exist_ok=True)  # idempotent on re-run


def _word_merges(tokenizer, words):
    """Return {token: [left, right]}, the merge rules that build each word.

    Words are processed shortest-first. A leading 'Ġ' marks a preceding space,
    so such words are tokenized as ' ' + word[1:]. A word whose base split is
    p0, p1, ..., pk yields the chain (p0, p1) -> p0p1, (p0p1, p2) -> p0p1p2, ...,
    and the first rule recorded for a token is kept. Two-piece words always
    record their own split. Words that are already a single token add nothing.
    """
    merges = defaultdict(list)
    for word in sorted(words, key=len):
        split = tokenizer.tokenize(word if not word.startswith('Ġ') else ' ' + word[1:])
        if len(split) == 2:
            merges[word] = [split[0], split[1]]
        elif len(split) >= 3:
            new_word = split[0]
            for piece in split[1:]:
                left = new_word
                new_word += piece
                if new_word not in merges:
                    merges[new_word] = [left, piece]
    return merges


def _next_free_id(tok_json):
    """Return the smallest id above every BPE-vocab id and every added/special-token id."""
    used = list(tok_json["model"]["vocab"].values())
    used += [tok["id"] for tok in tok_json.get("added_tokens", [])]
    return max(used, default=-1) + 1


def _extend_tokenizer_json(tok_json, new_merges):
    """Append unseen tokens and their merges to `tok_json` in place.

    Returns the number of tokens added. New ids start after ALL existing ids,
    including added/special tokens. Starting at len(vocab) would reuse the ids of
    the special tokens (e.g. <|begin_of_text|>), which then decode incorrectly.
    Special tokens keep their ids, so the post-processor is left untouched.
    """
    vocab = tok_json["model"]["vocab"]
    merges = tok_json["model"]["merges"]
    # Match the existing merge format: "left right" strings or [left, right]
    # pairs, depending on the tokenizers version. Don't mix the two forms.
    as_string = bool(merges) and isinstance(merges[0], str)
    next_id = _next_free_id(tok_json)
    added = 0
    for token, (left, right) in new_merges.items():
        if token in vocab:
            continue
        vocab[token] = next_id + added
        merges.append(f"{left} {right}" if as_string else [left, right])
        added += 1
    return added


# The base tokenizer is never mutated, so load it and its raw JSON once.
tokenizer_base = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
with open("Llama-3.1-8B-Base/tokenizer.json", 'r', encoding='utf-8') as fh:
    tokenizer_json_base = json.load(fh)

for fname in glob.glob("./EBM/*.txt"):
    print('***********Processing:',fname)
    # Work on a fresh copy so additions from one word list never leak into the next.
    tokenizer_json = copy.deepcopy(tokenizer_json_base)

    with open(fname, 'r', encoding='utf-8') as fh:
        words_to_add = fh.read().splitlines()

    vocab_to_merges = _word_merges(tokenizer_base, words_to_add)
    idx = _extend_tokenizer_json(tokenizer_json, vocab_to_merges)

    stem = os.path.splitext(os.path.basename(fname))[0]
    dump_dir = os.path.join(EBM_VOCAB_DIR, f'EBM_{stem}')
    # save_pretrained writes the companion files; tokenizer.json is then
    # overwritten with the extended version.
    tokenizer_base.save_pretrained(dump_dir)
    with open(os.path.join(dump_dir, 'tokenizer.json'), 'w', encoding='utf-8') as f:
        json.dump(tokenizer_json, f)

In [None]:
import pandas as pd
import glob
from collections import defaultdict

# Baseline: frequency-weighted mean number of pieces per EBM term, taken from
# the precomputed 'Splits' column of the lookup CSV.
df = pd.read_csv('../Llama-3-EBM-MedicalLookup-Fragment/EBM_SplitMoreThan1_OOV.csv')
freq_ebm = df['Count'].to_list()
terms_EBM = df['Word'].to_list()
split_bart = df['Splits'].to_list()

baseline_num, baseline_den = 0., 0.
for n_pieces, count in zip(split_bart, freq_ebm):
    baseline_num += n_pieces*count
    baseline_den += count
old_score = baseline_num/baseline_den


def _run_sort_key(path):
    """Sort key for run directories named like '<prefix>_25K_0.5_'.

    Orders first by the integer in the third-from-last '_' field (its trailing
    unit letter dropped, e.g. '25K' -> 25), then by the second-from-last field
    read as a float.
    """
    fields = path.split('/')[-1].split('_')
    return [int(fields[-3][:-1]), float(fields[-2])]


def _weighted_mean_pieces(domain_tok):
    """Frequency-weighted mean piece count of terms_EBM under `domain_tok`.

    Each term contributes the shorter of its plain and space-prefixed
    tokenizations.
    """
    num, den = 0., 0.
    for term, count in zip(terms_EBM, freq_ebm):
        pieces = min(len(domain_tok.tokenize(term)), len(domain_tok.tokenize(' '+term)))
        num += pieces*count
        den += count
    return num/den


# dict_scores[size_field][ratio_field] = [rounded fertility, len(tokenizer) - 128256]
dict_scores = defaultdict(lambda : defaultdict(dict))
for fname in sorted(glob.glob('./EBM-MEDVOC-FromScratch-Llama3_Vocab/*'), key=_run_sort_key):
    print('Processing:',fname)
    domain_tok = AutoTokenizer.from_pretrained(fname)
    size_field, ratio_field = fname.split('/')[-1].split('_')[-3:-1]
    dict_scores[size_field][ratio_field] = [
        round(_weighted_mean_pieces(domain_tok),2),
        len(domain_tok)-128256,
    ]

with open(f'EBM-FERTILITY','a') as f:
    f.write(f'-------------\nEBM-MEDVOC-SelfFromScratch\n--------------------\n')
    f.write('BART_Tok: '+str(round(old_score,2))+'\n')

    # Header row: column keys taken from the first row whose key is not '0K'.
    header_row = next((k1 for k1 in dict_scores if k1 != '0K'), None)
    if header_row is not None:
        f.write('data\t')
        f.write(''.join(k2+'\t' for k2 in dict_scores[header_row]))
        f.write('\n')

    for k1, row in dict_scores.items():
        f.write(k1+'\t')
        for score, n_added in row.values():
            f.write(f'{score}/{n_added}\t')
        f.write('\n')





In [None]:
import numpy as np

# Round-trip and compression check: encode each line of a text file with the
# base and the extended tokenizer, and require identical decodes. For each
# passing line, record the relative reduction in token count.
# NOTE(review): this loads a BioASQ tokenizer, not one of the EBM tokenizers
# this notebook builds -- confirm this is intended.
tokenizer = AutoTokenizer.from_pretrained("./BioASQ-FromScratch-Llama3_Vocab/BioASQ_10K_1.0_")
gain_in_fragments = []  # per passing line: (base_len - new_len) / base_len
idx = 0                 # number of lines that round-tripped successfully
with open('../../../../../TxtInputFiles/PAC_input.txt', encoding='utf-8') as f:
    for line_no, line in enumerate(f):
        line = line.strip()
        # Reset per line so a failure never reports the previous line's decodes.
        org_dec = vocab_dec = None
        try:
            org_enc = tokenizer_base.encode(line)
            vocab_enc = tokenizer.encode(line)
            org_dec = tokenizer_base.decode(org_enc)
            vocab_dec = tokenizer.decode(vocab_enc)
            round_trip_ok = org_dec == vocab_dec
        except Exception as exc:  # report tokenizer errors and keep going (not KeyboardInterrupt)
            print(f'Error: {exc!r}')
            round_trip_ok = False

        if not round_trip_ok:
            print(f'--------------------Failed at {line_no}--------------------')
            print(f'Orig : {org_dec}')
            print(f'Vocab: {vocab_dec}')
            continue

        if org_enc:  # guard against division by zero on an empty encoding
            gain_in_fragments.append((len(org_enc)-len(vocab_enc))/len(org_enc))

        idx += 1
        if idx%1000 == 0:
            print(f'Processed {idx}.... {np.percentile(gain_in_fragments, [0,10,50,90,100])}')
    

In [None]:
# Spread of per-line fragment gains: min, 10th pct, median, 90th pct, max.
gain_quantiles = np.percentile(gain_in_fragments, [0, 10, 50, 90, 100])
gain_quantiles

In [None]:
# Scratch cell: print the four characters U+00C4 U+00AA U+00C4 U+00A8
# (presumably copied from a tokenizer.json entry) to see how they render.
s = "".join(chr(code_point) for code_point in (0xC4, 0xAA, 0xC4, 0xA8))
print(s)



In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B')

# Find the word in one EBM list whose base split needs the most intermediate
# tokens. A word split into k pieces is built by k-1 left-to-right merges; the
# last merge produces the word itself, so k-2 intermediate tokens get created.
with open('./EBM/25K_0.5_.txt', encoding='utf-8') as fh:
    vocab_tokens = [x.strip() for x in fh.read().splitlines()]

max_token, max_token_to_add = '', 0
for word in vocab_tokens:
    # A leading 'Ġ' stands for a preceding space; tokenize it as a real space.
    text = ' ' + word[1:] if word.startswith('Ġ') else word
    tokens_to_add = len(tokenizer.tokenize(text)) - 2
    if tokens_to_add > max_token_to_add:
        max_token, max_token_to_add = word, tokens_to_add
        print(word, tokens_to_add)