In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"
#os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from gpt2_patch import replace_gpt2_attn_with_flash_attn

print('Patching gpt2')
#replace_gpt2_attn_with_flash_attn()
print('Patched gpt2')

import torch
torch.manual_seed(0)
import random
random.seed(0)

import transformers
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
    TaskType
)
from peft.tuners.lora import LoraLayer


from data_utils import prepare_tokenizer_dataset
from model_utils import get_model
from config import load_training_config

  from .autonotebook import tqdm as notebook_tqdm


Patching gpt2
Patched gpt2


: 

In [2]:
def upcast_layer_for_flash_attention(model, torch_dtype):
    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
    # convert them back to fp16/bf16 for flash-attn compatibility.
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            module.to(torch_dtype)
        if 'wpe' in name or 'ln_1' in name or 'ln_2' in name or 'ln_f' in name:
            module.to(torch_dtype)
        if 'wte' in name:
            module.to(torch_dtype)
    return model

def downcast_layer_for_flash_attention(model, torch_dtype):
    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
    # convert them back to fp16/bf16 for flash-attn compatibility.
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            #module.to(torch.float32)
            continue
        if isinstance(module, torch.nn.Embedding):
            module.to(torch.float16)
            continue
        if module.dtype == torch.float32:
            module.to(torch_dtype)

    return model

In [3]:
config_path = "training_config.yaml"
format_version = 2

In [4]:
config = load_training_config(config_path)
print(config)

TrainingConfig(lora_r=16, lora_alpha=32, lora_dropout=0.05, lora_target_modules=['c_attn', 'c_proj'], modules_to_save=['wpe'], micro_batch_size=2, gradient_accumulation_steps=8, learning_rate=0.0001, train_steps=300, warmup_steps=80, max_ctx_len=2048, output_dir='experiments_wiki', logging_steps=5, eval_steps=50, save_steps=50, save_total_limit=8)


In [5]:
train_data, val_data, tokenizer = prepare_tokenizer_dataset(format_v=format_version, max_ctx_len=config.max_ctx_len)

Loading dataset...


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loaded.
Convert to pandas...
Converted.
Splitting...
Splitted. Train samples: 10000. Test samples: 1000


Map: 100%|██████████| 1000/1000 [00:00<00:00, 1394.92 examples/s]


In [6]:
model = get_model()
model.config.use_cache = False

Loading checkpoint shards: 100%|██████████| 6/6 [01:07<00:00, 11.27s/it]


In [7]:
bit8_model = prepare_model_for_int8_training(model)



In [8]:
lora_config = LoraConfig(
    r=config.lora_r,
    lora_alpha=config.lora_alpha,
    target_modules=config.lora_target_modules,
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    # IMPORTANT! Unfreeze embedding layer to allow for new token finetuning
    # See https://github.com/huggingface/peft/issues/349#issuecomment-1527059611
    # See https://github.com/huggingface/peft/issues/334
    # See https://github.com/huggingface/peft/pull/337#issuecomment-1527412343
    # Also unfreeze classification head to allow for new token classes
    modules_to_save=config.modules_to_save,
)

In [9]:
peft_model = get_peft_model(bit8_model, lora_config)

In [10]:
peft_model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50272, 5120)
        (wpe): ModulesToSaveWrapper(
          (original_module): Embedding(2048, 5120)
          (modules_to_save): ModuleDict(
            (default): Embedding(2048, 5120)
          )
        )
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0): GPT2Block(
            (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2Attention(
              (c_attn): lora.Linear8bitLt(
                (base_layer): Linear8bitLt(in_features=5120, out_features=15360, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=5120, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
              

In [11]:
input = tokenizer(['test'], return_tensors="pt", max_length=2048, truncation=True)
ids = input['input_ids']
tgts = ids.clone()
with torch.no_grad():
    outputs = peft_model(ids, labels=tgts)



In [12]:
outputs.logits.dtype

torch.float32

In [13]:
peft_model.print_trainable_parameters()

trainable params: 46,530,560 || all params: 12,899,993,600 || trainable%: 0.36070219445690266


In [14]:
training_arguments = transformers.TrainingArguments(
    per_device_train_batch_size=config.micro_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    warmup_steps=config.warmup_steps,
    max_steps=config.train_steps,
    learning_rate=config.learning_rate,
    fp16=True,
    logging_steps=config.logging_steps,
    optim="adamw_torch",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=config.eval_steps,
    save_steps=config.save_steps,
    output_dir=config.output_dir,
    save_total_limit=config.save_total_limit,
    report_to="tensorboard"
)
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)

In [15]:
trainer = transformers.Trainer(
    model=peft_model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_arguments,
    data_collator=data_collator
)

In [16]:
with torch.autocast("cuda"):
    trainer.train()

Step,Training Loss,Validation Loss


In [None]:
tokenizer.pad_token = tokenizer.eos_token