In [1]:
!pip install transformers datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import Dataset

# Hub identifier of the small GPT-2 checkpoint; the tokenizer and the model
# are both loaded from this same name.
model_name = "sshleifer/tiny-gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with the end-of-sequence token (tokenize() pads every example to a
# fixed length, so a pad token must be set).
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)

# Toy corpus: two short sentences standing in for a fine-tuning set.
data = {
    "text": [
        "Hello, how are you?",
        "The capital of Italy is Rome.",
    ]
}
dataset = Dataset.from_dict(data)

def tokenize(example):
    """Encode one example's ``text`` field and attach language-modeling labels.

    The module-level ``tokenizer`` is called with truncation enabled and
    ``max_length`` padding at 64 tokens.

    Args:
        example: Mapping with a ``"text"`` entry to encode.

    Returns:
        The tokenizer's output with an extra ``"labels"`` entry that refers to
        the very same object as ``"input_ids"`` (no copy is made).
    """
    encoded = tokenizer(
        example["text"],
        max_length=64,
        padding="max_length",
        truncation=True,
    )
    # NOTE(review): training uses DataCollatorForLanguageModeling(mlm=False),
    # which presumably derives labels from input_ids itself -- confirm whether
    # this column is still needed.
    encoded["labels"] = encoded["input_ids"]
    return encoded

def build_training_args(output_dir):
    """Return the TrainingArguments shared by both fine-tuning runs.

    Every setting is identical between the two runs except the output
    directory, so it is the only parameter.

    Args:
        output_dir: Directory the Trainer uses for its outputs.

    Returns:
        A ``TrainingArguments`` instance: batch size 1 per device, one epoch,
        logging every step, checkpoint every 5 steps (keeping at most one),
        and no experiment-tracker reporting.
    """
    return TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        num_train_epochs=1,
        save_steps=5,
        logging_steps=1,
        save_total_limit=1,
        report_to="none",
    )


tokenized_dataset = dataset.map(tokenize)

# ---- Run 1: fine-tune the base model -------------------------------------
training_args = build_training_args("./checkpoint-v1")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()

# The run above finishes after 2 steps (see the logged output), i.e. before
# save_steps=5 is ever reached, so the explicit save below is what actually
# writes the model to disk.
trainer.save_model("./checkpoint-v1")
# The Trainer was constructed without the tokenizer, so store the tokenizer
# files next to the weights to keep the directory loadable on its own.
tokenizer.save_pretrained("./checkpoint-v1")

# ---- Run 2: continue fine-tuning from the saved weights ------------------
# NOTE: this reloads the saved *weights* only. A new Trainer is created, so
# optimizer/scheduler state from run 1 is not carried over.
model_checkpoint = AutoModelForCausalLM.from_pretrained("./checkpoint-v1")

trainer = Trainer(
    model=model_checkpoint,
    args=build_training_args("./checkpoint-v2"),
    train_dataset=tokenized_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()


Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.51M [00:00<?, ?B/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
1,9.0146
2,9.2762


model.safetensors:   0%|          | 0.00/2.51M [00:00<?, ?B/s]

Step,Training Loss
1,9.0144
2,9.2755


TrainOutput(global_step=2, training_loss=9.144913673400879, metrics={'train_runtime': 0.2218, 'train_samples_per_second': 9.017, 'train_steps_per_second': 9.017, 'total_flos': 116736.0, 'train_loss': 9.144913673400879, 'epoch': 1.0})