In [1]:
%pip install --quiet transformers==4.34.1 accelerate==0.24.0 sentencepiece==0.1.99 optimum==1.13.2 peft==0.5.0 bitsandbytes==0.41.2.post2

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from tqdm.auto import tqdm , trange
# Fail fast: the model below is loaded with bitsandbytes 4-bit quantization (load_in_4bit),
# which needs a CUDA GPU.
assert torch.cuda.is_available(), "you need cuda for this"
# The assert above guarantees CUDA, so a CPU fallback here would be unreachable.
device = torch.device('cuda')
print(device)

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 121.5/121.5 kB 3.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 301.0/301.0 kB 19.2 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.7/7.7 MB 56.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 261.0/261.0 kB 16.2 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 40.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 85.6/85.6 kB 5.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 92.6/92.6 MB 6.7 MB/s eta …

AssertionError: you need cuda for this

In [None]:
model_name = 'Enoch/llama-7b-hf'
# device_map only matters for model weights; it has no effect on a tokenizer, so it is not passed.
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
# Reuse EOS as the padding token so batches can be padded (e.g. by the LM data collator).
tokenizer.pad_token_id = tokenizer.eos_token_id
# The 4-bit model itself is loaded, frozen and prepared for gradient checkpointing exactly once,
# in the next cell. Loading it here as well would build a second 7B model only to discard it,
# temporarily doubling GPU memory usage.


In [None]:
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    low_cpu_mem_usage=True,
    offload_state_dict=True,
    load_in_4bit=True,          # linear-layer weights are quantized to 4 bit on load
    torch_dtype=torch.float32,  # everything not quantized (layernorms, activations) stays fp32
)
# Freeze all base-model weights; only adapter parameters added afterwards will be trainable.
model.requires_grad_(False)
# Keep only a subset of activations and recompute the rest during backward to save memory.
model.gradient_checkpointing_enable()
# Gradient checkpointing disables backprop unless the inputs require grad; this makes the
# embedding outputs require grad so gradients can still reach the adapters.
model.enable_input_require_grads()

In [None]:
prompts = ['', 'import', 'from', 'while', 'try', 'if', 'for', 'torch']


@torch.no_grad()
def greedy_generate(model, tokenizer, prompt: str, device, max_new_tokens: int = 20) -> str:
    """Greedily append ``max_new_tokens`` tokens to ``prompt`` and return the decoded text.

    The full sequence is re-run through the model at every step (no KV cache), which is simple
    and fast enough for a handful of short samples.

    Runs under ``torch.no_grad()``: ``enable_input_require_grads()`` makes the embedding outputs
    require grad, so without it every forward pass would record an autograd graph for nothing.
    """
    batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)
    for _ in range(max_new_tokens):
        next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
        batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
        batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)
    return tokenizer.decode(batch['input_ids'][0].tolist())


model.eval()  # inference mode; also keeps this cell correct if it is re-run after training
for prompt in prompts:
    print("\nOutput:", greedy_generate(model, tokenizer, prompt, device))



In [None]:
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    """Wrap a frozen linear layer with a trainable low-rank (LoRA-style) adapter.

    The output is ``module(x) + (x @ adapter_A) @ adapter_B`` with ``adapter_A`` of shape
    ``(in_features, rank)`` and ``adapter_B`` of shape ``(rank, out_features)``. ``adapter_B``
    starts at zero, so right after construction the layer behaves exactly like ``module``.

    Args:
        module: pre-trained linear layer to wrap (e.g. a Llama attention projection; it may be
            a quantized subclass of ``nn.Linear``). It is stored as-is and not modified.
        rank: inner dimension of the low-rank adapter.
    """

    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module  # pre-trained (frozen) linear layer
        param_device = module.weight.device
        # Initialize A as if it were nn.Linear(in_features, rank).weight, i.e. shape
        # (rank, in_features), so kaiming_uniform_ uses fan_in = in_features (bound
        # 1/sqrt(in_features), as in the reference LoRA/PEFT init). Initializing the
        # (in_features, rank) tensor directly would use fan_in = rank and a far larger range.
        a_init = torch.empty(rank, module.in_features, device=param_device)
        nn.init.kaiming_uniform_(a_init, a=5 ** 0.5)
        self.adapter_A = nn.Parameter(a_init.t().contiguous())
        self.adapter_B = nn.Parameter(torch.zeros(rank, module.out_features, device=param_device))

    def forward(self, input):
        """Return the frozen layer's output plus the low-rank adapter's output."""
        original_output = self.module(input)
        # Compute the adapter in its own parameter dtype so it also accepts inputs whose dtype
        # differs from the adapter's, then cast back so the output dtype follows the base layer.
        adapter_output = (input.to(self.adapter_A.dtype) @ self.adapter_A) @ self.adapter_B
        return original_output + adapter_output.to(original_output.dtype)

Add LoRA adapters on top of the Q/K/V projection layers of every Llama attention block.

In [None]:
lora_rank = 8

# `model.model.layers` holds the decoder layers directly, so iterate over it rather than scanning
# every submodule and matching on the type's repr.
for decoder_layer in model.model.layers:
    attention = decoder_layer.self_attn
    for proj_name in ('q_proj', 'k_proj', 'v_proj'):
        proj = getattr(attention, proj_name)
        if isinstance(proj, LoRALayer):
            # Already wrapped (cell re-run): wrapping again would nest adapters and fail,
            # since LoRALayer itself has no `in_features`.
            continue
        # LoRALayer creates its adapters on proj's device, so no `.to(device)` is needed;
        # `.to()` would also be applied to the wrapped (quantized) base layer.
        setattr(attention, proj_name, LoRALayer(proj, rank=lora_rank))

num_lora_layers = sum(isinstance(module, LoRALayer) for module in model.modules())
# Three wrapped projections per decoder layer: 96 for Llama-7B.
assert num_lora_layers == 3 * len(model.model.layers), num_lora_layers

We fine-tune on `codeparrot/codeparrot-clean-valid`, the validation split of the dataset used to train the CodeParrot model.

In [None]:
import datasets
# CodeParrot validation split, fetched from the Hugging Face Hub.
data = datasets.load_dataset(path="codeparrot/codeparrot-clean-valid")

In [None]:
print(data.column_names)


def truncate_and_tokenize(batch):
    """Keep the first 512 characters of each file and add the tokenizer outputs for that text."""
    truncated = [text[:512] for text in batch["content"]]
    return {"content": truncated, **tokenizer(truncated)}


data = data.map(truncate_and_tokenize, batched=True)
print(data.column_names)

In [None]:
# The HF Trainer special-cases quantized models that have no PEFT adapters attached (it checks this
# private flag). Our adapters are hand-written LoRALayer modules rather than PEFT ones, so mark the
# model as adapted. NOTE(review): private attribute; re-check when upgrading transformers.
model._hf_peft_config_loaded = True
trainer = transformers.Trainer(
    model=model, train_dataset=data['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4, gradient_accumulation_steps=4,
        # note: if you want larger batch size, increase gradient_accumulation_steps
        # Warmup must be shorter than max_steps: with 250 warmup steps and only 100 training steps,
        # the linear schedule would still be ramping up when training ends (peak LR never reached).
        warmup_steps=10, max_steps=100, learning_rate=2e-4, fp16=True,
        # report_to=None means "all installed integrations"; "none" disables external loggers.
        logging_steps=1, output_dir='outputs', report_to="none"),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)

In [None]:
# Fine-tune the adapters; the returned TrainOutput summary is shown as the cell output.
train_output = trainer.train()
train_output

In [None]:
# Metrics the Trainer recorded during training (logged every `logging_steps` steps).
log_history = trainer.state.log_history
log_history

In [None]:
# trainer.train() leaves the model in training mode (where gradient checkpointing is active);
# switch back to eval mode before generating.
model.eval()
prompts = ['', 'import', 'from', 'while', 'try', 'if', 'for', 'torch']
# Inference only: enable_input_require_grads() makes embedding outputs require grad, so without
# no_grad every forward pass would record an autograd graph through the adapters.
with torch.no_grad():
    for prompt in prompts:
        batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)
        for _ in range(20):  # greedy decoding; the full sequence is re-run each step (no KV cache)
            next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
            batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
            batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)
        print("\nOutput:", tokenizer.decode(batch['input_ids'][0].tolist()))

