In [1]:
#The UMT5 model is quite new and does not seem to have full TensorFlow support yet, so PyTorch has to be installed as well
#%pip install torch torchvision torchaudio

### Adapt UMT5 model for single-language use.
Credit: https://towardsdatascience.com/how-to-adapt-a-multilingual-t5-model-for-a-single-language-b9f94f3d9c90

In [2]:
#import model and tokenizer
from transformers import UMT5Model, T5Tokenizer

In [None]:
# Load the pretrained umt5-small checkpoint: the model itself and its matching tokenizer.
checkpoint = "google/umt5-small"
model = UMT5Model.from_pretrained(checkpoint)
tokenizer = T5Tokenizer.from_pretrained(checkpoint)

In [None]:
#Estimate token frequencies: tokenize every sentence of the corpus and count how often each token id occurs
import pandas as pd
import csv
from collections import Counter
from tqdm.auto import tqdm, trange

# Read the Arabic corpus as a two-column, tab-separated table (row index, sentence text).
# QUOTE_NONE makes pandas treat quote characters inside sentences as ordinary text.
df_ar = pd.read_csv(
    'ara_wikipedia_2021_1M-sentences.txt',
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
)
df_ar.columns = ['idx', 'text']

# Tally how many times each token id is produced over all Arabic sentences.
cnt_ar = Counter()
for sentence in tqdm(df_ar['text']):
    cnt_ar.update(tokenizer.encode(sentence))

# Number of distinct token ids seen, and that number as a fraction of the tokenizer vocabulary.
n_seen_ar = len(cnt_ar)
print(n_seen_ar, n_seen_ar / tokenizer.vocab_size)
# 53217 0.20763558330081935 (so basically, only 20% of the model vocabulary was used.)

In [None]:
# Cumulative coverage: share of all Arabic token occurrences accounted for by the `top` most frequent ids.
total_ar = sum(cnt_ar.values())
for top in (10_000, 20_000, 30_000):
    covered = sum(count for _, count in cnt_ar.most_common(top))
    print(top, covered / total_ar)
# 10000 0.986
# 20000 0.998
# 30000 0.999
# this means that more than 99% of the tokens fall into the top 20K, so we don't really need all 53217 tokens

In [None]:
# Same frequency count for the English corpus (two tab-separated columns: row index, sentence text).
df_en = pd.read_csv(
    'eng_wikipedia_2016_1M-sentences.txt',
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
)
df_en.columns = ['idx', 'text']

# Tally how many times each token id is produced over all English sentences.
cnt_en = Counter()
for sentence in tqdm(df_en['text']):
    cnt_en.update(tokenizer.encode(sentence))

The vocabulary (adapted from the tutorial) is composed as follows:
- the top 1,000 tokens regardless of language, i.e. the first 1,000 ids of the original vocabulary (just to be "safe")
- the top 10,000 tokens for English (because in real life the two languages are often mixed, especially in informal documents)
- roughly the top 20,000 Arabic tokens (added in frequency order until the vocabulary reaches 29,700 entries)
- the 300 special tokens that UMT5 uses (a.k.a. sentinel ids), to be added later on

In [None]:
def build_vocab(en_counts=None, ar_counts=None, n_shared=1000, n_en=10_000,
                n_ar_max=25_000, target_size=29_700):
    """Select the original token ids that the reduced vocabulary keeps.

    The selection is the union of:
      * ids ``0 .. n_shared - 1`` of the original vocabulary,
      * the ``n_en`` most frequent ids in ``en_counts``,
      * the most frequent ids in ``ar_counts`` (at most ``n_ar_max`` candidates
        are examined), added until the selection holds ``target_size`` ids.

    Args:
        en_counts: object with a ``most_common(n)`` method (e.g. a ``Counter``)
            mapping token id -> frequency. Defaults to the notebook's ``cnt_en``.
        ar_counts: same, for Arabic. Defaults to the notebook's ``cnt_ar``.
        n_shared: number of leading vocabulary ids always kept.
        n_en: number of top English ids to keep.
        n_ar_max: maximum number of Arabic candidates to examine.
        target_size: selection size at which Arabic ids stop being added.

    Returns:
        The selected ids as an ascending list.
    """
    # Resolve the notebook globals lazily so a plain ``build_vocab()`` call keeps working.
    if en_counts is None:
        en_counts = cnt_en
    if ar_counts is None:
        ar_counts = cnt_ar

    new_tokens = set(range(n_shared))

    # A set ignores duplicates, so no membership test is needed before adding.
    new_tokens.update(token_id for token_id, _ in en_counts.most_common(n_en))

    for i, (token_id, _) in enumerate(ar_counts.most_common(n_ar_max)):
        # ">=" rather than "==": also stop if the selection is already at or past the target.
        if len(new_tokens) >= target_size:
            print(i, 'Arabic tokens are included')
            break
        new_tokens.add(token_id)

    if len(new_tokens) < target_size:
        # Make a shortfall visible instead of returning a smaller vocabulary silently.
        print('only', len(new_tokens), 'tokens selected; target was', target_size)

    print(len(new_tokens))
    return sorted(new_tokens)

# Ascending list of original token ids to keep; later cells use a token's position in this list as its new id.
kept_ids = build_vocab()

In [None]:
# These shell commands were run from a terminal. Inside a notebook, shell commands need the `!` prefix:
# `%name` invokes an IPython magic, and IPython has no built-in `%wget` / `%protoc` magics, which is
# the most likely reason they failed here. They are kept in a string so this cell stays a no-op when run.
"""
!wget https://raw.githubusercontent.com/google/sentencepiece/master/src/sentencepiece_model.proto
!protoc --python_out=. sentencepiece_model.proto
"""

In [None]:
import sentencepiece_model_pb2 as spmp

# Deserialize the tokenizer's SentencePiece model into a protobuf message whose pieces can be edited.
m = spmp.ModelProto()
smp = tokenizer.sp_model.serialized_model_proto()
m.ParseFromString(smp)

n_loaded_pieces = len(m.pieces)
print('the loaded model has pieces:', n_loaded_pieces)

In [None]:
# Trim the SentencePiece model `m` down to the pieces listed in kept_ids and build a tokenizer from it.
#
# This cell edits `m` in place, so it must run on a freshly parsed `m` and on the list returned by
# build_vocab() (before any sentinel ids are appended to kept_ids).
# The in-place copy below is only correct for a strictly ascending kept_ids: slot i is then never
# read again after it has been overwritten (kept_ids[i] >= i).
if any(a >= b for a, b in zip(kept_ids, kept_ids[1:])):
    raise ValueError('kept_ids must be strictly ascending before the vocabulary is trimmed')

new_pieces = [m.pieces[idx] for idx in kept_ids]
print('the new pieces:', len(new_pieces))

# replace the content of the first len(new_pieces) pieces
for i, p in enumerate(new_pieces):
    m.pieces[i].piece = p.piece
    m.pieces[i].score = p.score
    m.pieces[i].type = p.type

# drop the remaining pieces with one slice deletion instead of popping them one at a time
n = len(new_pieces)
del m.pieces[n:]

print(len(m.pieces))
with open('new_sp.model', 'wb') as f:
    f.write(m.SerializeToString())

# extra_ids=300 makes the new tokenizer add its own 300 sentinel tokens on top of the trimmed model
new_tokenizer = T5Tokenizer('new_sp.model', extra_ids=300)

In [11]:
def reorder_sentinels(old_tokenizer=None, target_tokenizer=None, ids=None):
    """Append the original ids of the sentinel tokens to the kept-id list.

    The position of an entry in ``ids`` is used later as that token's new id,
    so the sentinel ids are appended in the order of their ids in
    ``target_tokenizer``, not in the order in which ``get_sentinel_tokens()``
    happens to list them.

    Args:
        old_tokenizer: tokenizer of the original model. Defaults to the
            notebook's ``tokenizer``.
        target_tokenizer: tokenizer built on the reduced vocabulary. Defaults
            to the notebook's ``new_tokenizer``.
        ids: list of kept original ids, extended in place. Defaults to the
            notebook's ``kept_ids``.

    Raises:
        ValueError: if a sentinel of ``target_tokenizer`` is unknown to
            ``old_tokenizer``, or if the new sentinel ids do not directly
            follow the ids already in ``ids`` (for example because the
            sentinels were already appended by an earlier call).
    """
    # Resolve the notebook globals lazily so a plain ``reorder_sentinels()`` call keeps working.
    if old_tokenizer is None:
        old_tokenizer = tokenizer
    if target_tokenizer is None:
        target_tokenizer = new_tokenizer
    if ids is None:
        ids = kept_ids

    sentinel_new = target_tokenizer.get_sentinel_tokens()
    known_old = set(old_tokenizer.get_sentinel_tokens())
    missing = [tok for tok in sentinel_new if tok not in known_old]
    if missing:
        raise ValueError('sentinel tokens missing from the original tokenizer: %r' % (missing,))

    # (new id, old id) for every sentinel token, ordered by new id.
    pairs = sorted(
        (target_tokenizer.convert_tokens_to_ids(tok), old_tokenizer.convert_tokens_to_ids(tok))
        for tok in sentinel_new
    )

    # Row i of the resized embedding is taken from ids[i], so the sentinels must occupy
    # exactly the next len(pairs) positions.
    expected_new_ids = list(range(len(ids), len(ids) + len(pairs)))
    if [new_id for new_id, _ in pairs] != expected_new_ids:
        raise ValueError(
            'new sentinel ids do not directly follow the %d ids already kept; '
            'were the sentinels already appended?' % len(ids)
        )

    ids.extend(old_id for _, old_id in pairs)

# Appends the sentinel ids to kept_ids in place, so execute this cell only once after build_vocab().
reorder_sentinels()

In [12]:
import torch #imports PyTorch

# Shrink the model's shared embedding so that row i of the new matrix is the original row kept_ids[i].
# Per the original note, this is a UMT5Model rather than a T5ForConditionalGeneration, so there is
# no lm_head (language-modeling head) to resize here.
new_size = len(kept_ids)
old_weight = model.shared.weight.data
row_index = torch.tensor(kept_ids, dtype=torch.long, device=old_weight.device)

new_emb = torch.nn.Embedding(new_size, model.shared.embedding_dim)
# Gather all kept rows in one indexing operation instead of copying them one by one in a Python loop.
new_emb.weight.data.copy_(old_weight[row_index])

model.shared.weight = new_emb.weight
# Keep the module's bookkeeping in line with the new weight shape.
model.shared.num_embeddings = new_size

# Plain attribute assignment instead of writing into config.__dict__.
model.config.vocab_size = new_size
# Use the name the model is saved under, not the repo id copied from the tutorial.
model.config._name_or_path = 'ar-umt5-base'

In [None]:
# Write the reduced tokenizer and the resized model into the same directory.
output_dir = 'ar-umt5-base'
new_tokenizer.save_pretrained(output_dir)
model.save_pretrained(output_dir)