# Training a causal language model from scratch (PyTorch)

Install the Transformers, Datasets, and Evaluate libraries to run this notebook.

In [1]:
!pip install datasets evaluate transformers[sentencepiece]
!pip install accelerate
# To run the training on TPU, you will need to uncomment the following line:
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!apt install git-lfs

Collecting datasets
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/486.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers[sentencepiece]
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m91.3 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloadi

In [2]:
# This cell will take a very long time to execute, so you should skip it and go to
# the next one!
from datasets import load_dataset

split = "train"  # "valid"

data = load_dataset("gbharti/finance-alpaca")


Downloading readme:   0%|          | 0.00/486 [00:00<?, ?B/s]

Downloading and preparing dataset json/gbharti--finance-alpaca to /root/.cache/huggingface/datasets/gbharti___json/gbharti--finance-alpaca-47c3412d84dc0065/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/42.9M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/gbharti___json/gbharti--finance-alpaca-47c3412d84dc0065/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
from datasets import load_dataset, DatasetDict

ds_train = load_dataset("gbharti/finance-alpaca", split="train")
# ds_valid = load_dataset("gbharti/finance-alpaca", split="train")

raw_datasets = DatasetDict(
    {
        "train": ds_train,  # .shuffle().select(range(50000)),
    }
)
split_datasets = raw_datasets["train"].train_test_split(train_size=0.9, seed=20)
split_datasets["validation"] = split_datasets.pop("test")
split_datasets["train"]["instruction"][1]

split_datasets



DatasetDict({
    train: Dataset({
        features: ['text', 'instruction', 'input', 'output'],
        num_rows: 62020
    })
    validation: Dataset({
        features: ['text', 'instruction', 'input', 'output'],
        num_rows: 6892
    })
})

In [4]:
for key in split_datasets["train"][0]:
    print(f"{key.upper()}: {raw_datasets['train'][0][key][:200]}")

TEXT: 
INSTRUCTION: For a car, what scams can be plotted with 0% financing vs rebate?
INPUT: 
OUTPUT: The car deal makes money 3 ways. If you pay in one lump payment. If the payment is greater than what they paid for the car, plus their expenses, they make a profit. They loan you the money. You make p


In [52]:
from transformers import AutoTokenizer

context_length = 128
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

outputs = tokenizer(
    raw_datasets["train"][:15]["output"],
    truncation=True,
    max_length=context_length,
    return_overflowing_tokens=True,
    return_length=True,
)

print(f"Input IDs length: {len(outputs['input_ids'])}")
print(f"Input chunk lengths: {(outputs['length'])}")
print(f"Chunk mapping: {outputs['overflow_to_sample_mapping']}")

Input IDs length: 39
Input chunk lengths: [128, 128, 128, 66, 86, 99, 98, 128, 128, 128, 128, 114, 128, 24, 128, 128, 128, 128, 46, 128, 7, 128, 8, 75, 121, 128, 62, 128, 128, 128, 128, 100, 123, 128, 128, 128, 128, 128, 21]
Chunk mapping: [0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, 6, 7, 7, 8, 8, 9, 10, 11, 11, 12, 12, 12, 12, 12, 13, 14, 14, 14, 14, 14, 14]


In [53]:
def tokenize(element):
    outputs = tokenizer(
        element["output"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}


tokenized_datasets = raw_datasets.map(
    tokenize, batched=True, remove_columns=split_datasets["train"].column_names
)
tokenized_datasets



DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 27853
    })
})

In [54]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

In [55]:
model = GPT2LMHeadModel(config)
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

GPT-2 size: 109.3M parameters


In [56]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

Using eos_token, but it is not set yet.


In [57]:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
out = data_collator([tokenized_datasets["train"][i] for i in range(5)])
for key in out:
    print(f"{key} shape: {out[key].shape}")

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


input_ids shape: torch.Size([5, 128])
attention_mask shape: torch.Size([5, 128])
labels shape: torch.Size([5, 128])


In [58]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="codeparrot-ds",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=5_000,
    logging_steps=5_000,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    save_steps=5_000,
    fp16=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=split_datasets["validation"],
)

In [59]:
trainer.train()

Step,Training Loss,Validation Loss


TrainOutput(global_step=108, training_loss=8.123842027452257, metrics={'train_runtime': 317.7158, 'train_samples_per_second': 87.666, 'train_steps_per_second': 0.34, 'total_flos': 1806050525184000.0, 'train_loss': 8.123842027452257, 'epoch': 0.99})

In [62]:
import torch

# Set the device
device = torch.device("cpu")

# Move the model to the device
model.to(device)

# Define the input prompts
prompt = 'Investment strategy for businesses '

# Generate outputs based on the input prompts

inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
outputs = model.generate(inputs, max_length=100)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Input Prompt:", prompt)
print("Generated Output:", generated_text)
print()

Input Prompt: Investment strategy for businesses 
Generated Output: investment strategy for businesses, and the, and the the same of the stock, and the stock, and the stock, and the. the stock, the stock, the the. the the the the the the the the the the the the the the the the the the the the the the the the the the the. the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the

