# Continued Pretraining of T5

In [1]:
from datasets import load_from_disk

# Relative path to the on-disk dataset used for training in this notebook.
TOKENIZED_DATASET_DIR = '../Data/tokenized-pretraining-ds'

# Bound to `tokenized_dataset`, which is later handed to the Trainer as its
# `train_dataset`.
tokenized_dataset = load_from_disk(TOKENIZED_DATASET_DIR)

Loading dataset from disk:   0%|          | 0/17 [00:00<?, ?it/s]

In [2]:
from transformers import AutoTokenizer

# Relative path to the locally stored tokenizer files.
TOKENIZER_DIR = '../../Models/my_tokenizer'

# In this notebook the tokenizer's length is what sizes the model's token
# embeddings before training.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

## Train the Model

In [3]:
# Set WANDB_PROJECT in this kernel's environment before the Trainer is built.
# NOTE(review): presumably this selects the Weights & Biases project that the
# wandb logging emitted during `trainer.train()` reports to — confirm.
# Keep comments on their own lines: text after the value on the `%env` line
# would become part of the value.
%env WANDB_PROJECT=english-v2

env: WANDB_PROJECT=english-v2


In [4]:
from transformers import TrainingArguments, Trainer, T5ForConditionalGeneration, Adafactor
from accelerate import Accelerator

# Single source of truth for the learning rate so the optimizer and the
# TrainingArguments cannot drift apart.
LEARNING_RATE = 3e-4

# Start from the public T5-base checkpoint, placed on the first GPU, and resize
# the token embeddings to match the custom tokenizer's length.
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-base", device_map='cuda:0')
model.resize_token_embeddings(len(tokenizer))

# Adafactor with an explicit learning rate: relative-step sizing, parameter
# scaling and warmup-init are all switched off, so `lr` is supplied by hand.
optimizer = Adafactor(
    model.parameters(),
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
    lr=LEARNING_RATE,
)

# The model and optimizer are handed to Trainer as-is. Trainer sets up its own
# Accelerator and prepares both internally, so running them through a separate
# `Accelerator().prepare(...)` first is redundant and wraps the optimizer twice.
training_args = TrainingArguments(
    output_dir="buddhist-base-pretrain",
    auto_find_batch_size=True,
    learning_rate=LEARNING_RATE,
    num_train_epochs=1,
)

# The scheduler slot is None, which leaves scheduler creation to Trainer; only
# the optimizer is supplied by this notebook.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    optimizers=(optimizer, None),
)

In [5]:
# Run the training loop and keep the returned summary under a name; ending the
# cell with that name still shows it as the cell's output.
train_output = trainer.train()
train_output

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbillingsmoore[0m. Use [1m`wandb login --relogin`[0m to force relogin


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
500,0.355
1000,0.2422
1500,0.2344
2000,0.2185
2500,0.2136
3000,0.2132
3500,0.2115
4000,0.2071
4500,0.1987
5000,0.1962


TrainOutput(global_step=215355, training_loss=0.13886161733876504, metrics={'train_runtime': 102647.377, 'train_samples_per_second': 16.784, 'train_steps_per_second': 2.098, 'total_flos': 5.245666792125235e+17, 'train_loss': 0.13886161733876504, 'epoch': 1.0})

In [6]:
# Directory the trained model is written to, relative to the working directory.
# NOTE(review): the tokenizer was loaded from '../../Models/...' while this
# writes under '../Models/...' — confirm the two different roots are intended.
PRETRAINED_MODEL_DIR = '../Models/pretrained-base-model'

model.save_pretrained(PRETRAINED_MODEL_DIR)