In [1]:
%pip install --quiet transformers==4.34.1 accelerate==0.24.0 sentencepiece==0.1.99 optimum==1.13.2 peft==0.5.0 bitsandbytes==0.41.2.post2

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from tqdm.auto import tqdm , trange
assert torch.cuda.is_available(),"you need cuda for this"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
model_name = 'Enoch/llama-7b-hf'
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name,device_map = device)
tokenizer.pad_token_id = tokenizer.eos_token_id
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name , device_map = 'auto' , low_cpu_mem_usage = True , offload_state_dict = True ,
    load_in_4bit = True , torch_dtype = torch.float32 , #weights are 4 bit ,activation and layernorm are fp32
)
for param in model.parameters():
  param.requires_grad = False
model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()     # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad


You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

  return torch.load(checkpoint_file, map_location=map_location)


**PROMPT TUNE THE STORY OF A FOX **

In [3]:
prompt = 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)
batch
for i in range(10):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0].cpu().numpy().tolist()))


Output: <s>A quick brown fox jumps over the lazy dog.
A quick


In [4]:
the_truth = "A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)
outputs = model(**batch)
#print(outputs)
next_word_logits = outputs.logits[:, :-1]
true_next_tokens = batch['input_ids'][:, 1:]
loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))

print("Loss:", loss)

Loss: tensor(3.0725, device='cuda:0', grad_fn=<NllLossBackward0>)


the model cant be trained as it has large amount of parameters

In [5]:
import peft
assert isinstance(model.model.embed_tokens, nn.Embedding), "please reload the model"

peft_config = peft.PromptTuningConfig(task_type=peft.TaskType.CAUSAL_LM, num_virtual_tokens=16)
model = peft.get_peft_model(model, peft_config)  # note: for most peft methods, this line also modifies model in-place
print("Trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
print("Total parameters (excluding quantization):", sum(p.numel() for p in model.parameters()))
print(model.print_trainable_parameters())

Trainable parameters: 65536
Total parameters (excluding quantization): 3500478464
trainable params: 65,536 || all params: 6,738,481,152 || trainable%: 0.0009725633792200893
None


step -2 is to train the trainable parameters in peft

In [19]:
the_truth = "A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False)
batch = batch.to(device)
outputs = model(**batch)
print(outputs[0].shape)
print(batch['input_ids'].shape)
next_word_logits = outputs.logits[:, 16 : -1, :]
print(next_word_logits.shape)
true_next_tokens = batch['input_ids'][:, 1:]
print(true_next_tokens.shape)
loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))
print("Loss:", loss)
opt = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(60):
  opt.zero_grad()
  outputs = model(**batch)
  next_word_logits = outputs.logits[:, 16 : -1, :]
  true_next_tokens = batch['input_ids'][:, 1:]
  loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))
  loss.backward()
  opt.step()

  print(f"Epoch {epoch + 1}: Loss {loss.item()}")
  if loss.item() < 0.1:
    print("Looks good!")
    break

torch.Size([1, 39, 32000])
torch.Size([1, 23])
torch.Size([1, 22, 32000])
torch.Size([1, 22])
Loss: tensor(5.2252, device='cuda:0', grad_fn=<NllLossBackward0>)
Epoch 1: Loss 5.225242614746094
Epoch 2: Loss 5.164522171020508
Epoch 3: Loss 5.10544490814209
Epoch 4: Loss 5.047178745269775
Epoch 5: Loss 4.989747524261475
Epoch 6: Loss 4.932901859283447
Epoch 7: Loss 4.876589775085449
Epoch 8: Loss 4.820857524871826
Epoch 9: Loss 4.765735149383545
Epoch 10: Loss 4.711192607879639
Epoch 11: Loss 4.657163619995117
Epoch 12: Loss 4.603607177734375
Epoch 13: Loss 4.550509452819824
Epoch 14: Loss 4.497833251953125
Epoch 15: Loss 4.445532321929932
Epoch 16: Loss 4.393546104431152
Epoch 17: Loss 4.341867446899414
Epoch 18: Loss 4.290485858917236
Epoch 19: Loss 4.239383220672607
Epoch 20: Loss 4.188554763793945
Epoch 21: Loss 4.13797664642334
Epoch 22: Loss 4.087636947631836
Epoch 23: Loss 4.037542819976807
Epoch 24: Loss 3.9876744747161865
Epoch 25: Loss 3.938027858734131
Epoch 26: Loss 3.88859009

In [20]:
prompt = 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)

for i in range(10):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0].cpu().numpy().tolist()))


Output: <s>A quick brown fox frown frown frown frown frown
