In [6]:
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from transformers import GPT2Model, GPT2Config, GPT2Tokenizer

def tokenize(element):
    outputs = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = []
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}

dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

context_length = 128
tokenized_datasets = dataset.map(
    tokenize, batched=True, remove_columns=dataset["train"].column_names
)

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [7]:
import torch
import math
from torch import nn
import torch.nn.functional as F

class nanoLoRA(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 8,
        lora_alpha: int = 1
    ):
        super().__init__(in_features, out_features)
        assert r > 0, "r must be > 0"

        self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
        self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
        self.scaling = lora_alpha / r
        self.weight.requires_grad = False
        self.merged = False
        self.reset_lora_parameters()

    def reset_lora_parameters(self):
        super().reset_parameters()
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        super().train(mode)
        if mode:
            if self.merged:
                # Make sure that the weights are not merged for training
                self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if not self.merged:
                # Merge the weights for inference
                self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if not self.merged:
            out = F.linear(x, self.weight, bias=self.bias)
            out += (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return out
        else:
            return F.linear(x, self.weight, bias=self.bias)

In [8]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig

# Initialize the model
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=1024,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = GPT2LMHeadModel(config)

# Freeze all parameters in the model
for param in model.parameters():
    param.requires_grad = False

# Modify the layers and unfreeze these parameters
for i in range(config.n_layer):
    model.transformer.h[i].mlp.c_fc = nanoLoRA(config.n_embd, config.n_embd, r=8)
    model.transformer.h[i].mlp.c_proj = nanoLoRA(config.n_embd, config.n_embd, r=8)

num_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_train_params:,} trainable parameters.')

The model has 313,344 trainable parameters.


In [10]:
from transformers import Trainer, TrainingArguments

import torch

if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    device = 'cpu'
else:
    device = 'mps'
    print("MPS is available and will be used.")

args = TrainingArguments(
    output_dir="results",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    evaluation_strategy="steps",
    eval_steps=50,
    logging_steps=50,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    weight_decay=0.1,
    warmup_steps=10,
    lr_scheduler_type="cosine",
    learning_rate=5e-4,
    report_to="tensorboard",
    use_mps_device= (device=="mps"),
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

trainer.train()

MPS is available and will be used.


Step,Training Loss,Validation Loss


TrainOutput(global_step=266, training_loss=9.84836050621549, metrics={'train_runtime': 202.845, 'train_samples_per_second': 41.845, 'train_steps_per_second': 1.311, 'total_flos': 279368589901824.0, 'train_loss': 9.84836050621549, 'epoch': 1.0})

In [11]:
trainer.evaluate()

{'eval_loss': 9.665497779846191,
 'eval_runtime': 5.1975,
 'eval_samples_per_second': 170.851,
 'eval_steps_per_second': 5.387,
 'epoch': 1.0}