diff --git a/config/finetune_shakespeare.py b/config/finetune_shakespeare.py index 148a4c4fa..9343e5c7f 100644 --- a/config/finetune_shakespeare.py +++ b/config/finetune_shakespeare.py @@ -1,4 +1,8 @@ import time +from functools import partial + +import torch +from minlora import LoRAParametrization out_dir = 'out-shakespeare' eval_interval = 5 @@ -9,6 +13,8 @@ dataset = 'shakespeare' init_from = 'gpt2-xl' # this is the largest GPT-2 model +init_from = 'gpt2-large' # use a smaller for faster training +# xl doesn't fit on 24GB GPU, but with LORA it does # only save checkpoints if the validation loss improves always_save_checkpoint = False @@ -23,3 +29,18 @@ # finetune at constant LR learning_rate = 3e-5 decay_lr = False + + +use_lora = True +learning_rate = 1e-3 # use a higher LR for LoRA +lora_dropout_p = 0.0 +rank=4 +lora_alpha = 64 +lora_config = { + torch.nn.Embedding: { + "weight": partial(LoRAParametrization.from_embedding, rank=rank, lora_alpha=lora_alpha), + }, + torch.nn.Linear: { + "weight": partial(LoRAParametrization.from_linear, rank=rank, lora_alpha=lora_alpha), + }, +} \ No newline at end of file diff --git a/sample.py b/sample.py index 670759bc1..118a40a0d 100644 --- a/sample.py +++ b/sample.py @@ -7,6 +7,7 @@ import torch import tiktoken from model import GPTConfig, GPT +import minlora # ----------------------------------------------------------------------------- init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl') @@ -38,12 +39,23 @@ checkpoint = torch.load(ckpt_path, map_location=device) gptconf = GPTConfig(**checkpoint['model_args']) model = GPT(gptconf) + if use_lora: + minlora.add_lora(model, lora_config) state_dict = checkpoint['model'] unwanted_prefix = '_orig_mod.' for k,v in list(state_dict.items()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict) + if use_lora: + # the full state dict includes the LoRA state dict + # so actually we don't need to load it separately again + model.load_state_dict(checkpoint['lora'], strict=False) + print('Loaded LoRA state dict') + # sanity check + #model.apply(minlora.apply_to_lora(lambda m: print((m.lora_A.sum(), m.lora_B.sum())))) + # merge for zero-overhead inference + minlora.merge_lora(model) elif init_from.startswith('gpt2'): # init from a given GPT-2 model model = GPT.from_pretrained(init_from, dict(dropout=0.0)) diff --git a/train.py b/train.py index 30d0145ef..9e44c114a 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ import time import math import pickle +import inspect from contextlib import nullcontext import numpy as np @@ -29,6 +30,7 @@ from model import GPTConfig, GPT +import minlora # ----------------------------------------------------------------------------- # default config values designed to train a gpt2 (124M) on OpenWebText # I/O @@ -180,13 +182,46 @@ def get_batch(split): if block_size < model.config.block_size: model.crop_block_size(block_size) model_args['block_size'] = block_size # so that the checkpoint will have the right value +if use_lora: + minlora.add_lora(model, lora_config=lora_config) + minlora.tie_weights(linear=model.lm_head, embedding=model.transformer.wte) model.to(device) # initialize a GradScaler. If enabled=False scaler is a no-op scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) # optimizer -optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) +def configure_optimizers_lora(self, weight_decay, learning_rate, betas, device_type): + # we apply weight decay to all lora params + optim_groups = [ + # note: .get_lora_params() returns a generator + # we need to wrap it in a list so we can consume it twice + {"params": list(minlora.get_lora_params(self)) , "weight_decay": weight_decay}, + # you can also add biases for fine-tuning, + # but I want to make sure lora alone works + # {"params": minlora.get_bias_params(self), "weight_decay": 0.0}, # bias params don't get weight decay + ] + + def parameter_count(optim_groups): + n = sum(p.numel() for group in optim_groups for p in group["params"]) + if n < 1e6: + return f"{n/1e3:.1f}k" + else: + return f"{n/1e6:.1f}M" + + print(f"optimizing {parameter_count(optim_groups)} parameters") + + # new PyTorch nightly has a new 'fused' option for AdamW that is much faster + use_fused = (device_type == "cuda") and ("fused" in inspect.signature(torch.optim.AdamW).parameters) + print(f"using fused AdamW: {use_fused}") + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + + return optimizer +if use_lora: + optimizer = configure_optimizers_lora(model, weight_decay, learning_rate, (beta1, beta2), device_type) +else: + optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) if init_from == 'resume': optimizer.load_state_dict(checkpoint['optimizer']) @@ -271,6 +306,8 @@ def get_lr(it): 'best_val_loss': best_val_loss, 'config': config, } + if use_lora: + checkpoint['lora'] = minlora.get_lora_state_dict(raw_model) print(f"saving checkpoint to {out_dir}") torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) if iter_num == 0 and eval_only: