In [2]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('../PalmLM-70000-tokenizer')

In [3]:
# Load model directly
from transformers import AutoModelForCausalLM

# tokenizer = AutoTokenizer.from_pretrained("DAMO-NLP-MT/polylm-1.7b")
model = AutoModelForCausalLM.from_pretrained("DAMO-NLP-MT/polylm-1.7b")

  _torch_pytree._register_pytree_node(


  _torch_pytree._register_pytree_node(


In [4]:
from datasets import load_dataset, interleave_datasets

# lang = ['', '', '', '', '', '', '', '', '', '']

train_afaa = load_dataset("castorini/afriberta-corpus", "afaanoromoo", split="train")
test_afaa = load_dataset("castorini/afriberta-corpus", "afaanoromoo", split="test")
train_amh = load_dataset("castorini/afriberta-corpus", "amharic", split="train")
test_amh = load_dataset("castorini/afriberta-corpus", "amharic", split="test")
train_gah = load_dataset("castorini/afriberta-corpus", "gahuza", split="train")
test_gah = load_dataset("castorini/afriberta-corpus", "gahuza", split="test")
train_hau = load_dataset("castorini/afriberta-corpus", "hausa", split="train")
test_hau = load_dataset("castorini/afriberta-corpus", "hausa", split="test")
train_igb = load_dataset("castorini/afriberta-corpus", "igbo", split="train")
test_igb = load_dataset("castorini/afriberta-corpus", "igbo", split="test")
train_som = load_dataset("castorini/afriberta-corpus", "somali", split="train")
test_som = load_dataset("castorini/afriberta-corpus", "somali", split="test")
train_swa = load_dataset("castorini/afriberta-corpus", "swahili", split="train")
test_swa = load_dataset("castorini/afriberta-corpus", "swahili", split="test")
train_tig = load_dataset("castorini/afriberta-corpus", "tigrinya", split="train")
test_tig = load_dataset("castorini/afriberta-corpus", "tigrinya", split="test")
train_yor = load_dataset("castorini/afriberta-corpus", "yoruba", split="train")
test_yor = load_dataset("castorini/afriberta-corpus", "yoruba", split="test")


multilingual_train = interleave_datasets([train_afaa, train_amh, train_gah, train_hau, train_igb])
multilingual_test = interleave_datasets([test_afaa, test_amh, test_gah, test_hau, test_igb])


In [5]:
from datasets import DatasetDict

raw_datasets = DatasetDict(
    {
        "train": multilingual_train.shuffle().select(range(50000)),
        "valid": multilingual_test.shuffle().select(range(500))
    }
)

raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'text'],
        num_rows: 50000
    })
    valid: Dataset({
        features: ['id', 'text'],
        num_rows: 500
    })
})

In [6]:
for key in raw_datasets["train"][0]:
    print(f"{key.upper()}: {raw_datasets['train'][0][key][:200]}")

ID: 99045
TEXT: Ihe omuma banyere Gris na ndi Grik.



In [7]:

context_length = 128

outputs = tokenizer(
    raw_datasets["train"][:2]["text"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

Input IDs length: 2
Input chunk lengths: [38, 105]
Chunk mapping: [0, 1]


In [8]:
def tokenize(element):
    outputs = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
)
tokenized_datasets

Map: 100%|██████████| 50000/50000 [00:04<00:00, 11275.14 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 12330.97 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 25422
    })
    valid: Dataset({
        features: ['input_ids'],
        num_rows: 239
    })
})

In [9]:
model_size = sum(t.numel() for t in model.parameters())
print(f"PalmLM size: {model_size/1000**2:.1f}M parameters")

PalmLM size: 1737.1M parameters


In [10]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [11]:
out = data_collator([tokenized_datasets["train"][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


input_ids shape: torch.Size([5, 128])
attention_mask shape: torch.Size([5, 128])
labels shape: torch.Size([5, 128])


In [14]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="../AfriPalmLM",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=5_000,
    logging_steps=5_000,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    fp16=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["valid"],
)

In [None]:
trainer.train()