In [1]:
!pip install --quiet --upgrade transformers bitsandbytes accelerate sentencepiece optimum auto-gptq
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import BitsAndBytesConfig
from tqdm.auto import tqdm, trange
assert torch.cuda.is_available(), "you need cuda for this part"

You should consider upgrading via the '/users/u29/bobbyd/RIT_LLM/.venv/bin/python3 -m pip install --upgrade pip' command.[0m


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
model_name = 'Enoch/llama-7b-hf'

# loading Llama tokenizer ...
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

# ... and the model itself
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    quantization_config=quantization_config, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False

model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()     # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad
# more on gradient checkpointing: https://pytorch.org/docs/stable/checkpoint.html https://arxiv.org/abs/1604.06174

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [7]:
prompt = 'Tomorrow is the Spring break' # 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)

for i in range(10):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0].cpu().numpy().tolist()))


Output: <s>Tomorrow is the Spring break for the kids. I am so excited.


In [8]:
class WordEmbeddingsWithLearnedPrompts(nn.Module):
    """
    Replace model's original word embeddings with a layer that inserts trainable prompts instead of the first N token embeddings.
    """
    def __init__(self, word_embeddings: nn.Embedding, num_prompts: int):
        super().__init__()
        self.original_word_embeddings = word_embeddings
        self.num_prompts = num_prompts
        self.learnable_prompts = nn.Parameter(
            torch.randn(1, num_prompts, word_embeddings.embedding_dim), requires_grad=True
        )

    def forward(self, input_ids: torch.LongTensor):
        # Ensure input_ids are of correct type and length
        assert input_ids.dtype == torch.int64
        assert input_ids.shape[1] > self.num_prompts, "Input sequence must be longer than the number of prompts"
        assert (input_ids[:, :self.num_prompts] == tokenizer.pad_token_id).all(), \
            "Ensure the first `num_prompts` tokens are PAD tokens"

        # Embed input_ids
        embedded_input_ids = self.original_word_embeddings(input_ids)

        # Replace the first `num_prompts` token embeddings with learnable prompts using concatenation
        prompt_embeds = self.learnable_prompts.expand(input_ids.shape[0], -1, -1)  # [batch_size, num_prompts, embedding_dim]
        embedded_input_ids = torch.cat([prompt_embeds, embedded_input_ids[:, self.num_prompts:, :]], dim=1)

        return embedded_input_ids

In [9]:
num_prompts = 16
test_emb_layer = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)
test_input_ids = tokenizer("a cat sat on a mat", return_tensors='pt')['input_ids'].to(device)

space_for_prompts = torch.full([len(test_input_ids), num_prompts], fill_value=tokenizer.pad_token_id,
                               dtype=torch.int64, device=device)
test_inputs_with_prompts = torch.cat([space_for_prompts, test_input_ids], dim=1)

with torch.amp.autocast('cuda'):
  test_prompt_embeddings = test_emb_layer(test_inputs_with_prompts)

assert test_prompt_embeddings.shape[:2] == test_inputs_with_prompts.shape
assert test_prompt_embeddings.shape[-1] == model.config.hidden_size
assert torch.allclose(test_prompt_embeddings[:, :num_prompts], test_emb_layer.learnable_prompts.float())
assert torch.allclose(test_prompt_embeddings[:, num_prompts:], model.model.embed_tokens(test_input_ids).float())
print("Looks legit!")

Looks legit!


In [8]:
assert isinstance(model.model.embed_tokens, nn.Embedding), "you have already replaced the embedding layer. If the replacement is broken, please reload the model"

model.model.embed_tokens = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)

opt = torch.optim.Adam([model.model.embed_tokens.learnable_prompts], lr=0.01)

In [9]:
the_truth = "Tomorrow is the Spring break, I will miss the school!" #"A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)
space_for_prompts = torch.full([len(test_input_ids), num_prompts], fill_value=tokenizer.pad_token_id,
                               dtype=torch.int64, device=device)
batch['input_ids'] = torch.cat([space_for_prompts, batch['input_ids']], dim=1)
batch['attention_mask'] = torch.cat([torch.ones_like(space_for_prompts), batch['attention_mask']], dim=1)

outputs = model(**batch)
next_word_logits = outputs.logits[:, num_prompts : -1, :]
true_next_tokens = batch['input_ids'][:, num_prompts + 1:]
loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))
print("Loss:", loss)
scaler = torch.amp.GradScaler('cuda')

loss_threshold = 0.1
epoch = 0

while True:
    opt.zero_grad()

    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(**batch)
        next_word_logits = outputs.logits[:, num_prompts:-1, :]
        true_next_tokens = batch['input_ids'][:, num_prompts+1:]
        loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))

    # Backpropagate using mixed precision
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

    print(f"Epoch {epoch}: Loss = {loss.item()}")

    if loss.item() <= loss_threshold:
        break

    epoch += 1

assert loss.item() <= 0.1
print("Good job!")

Loss: tensor(7.6555, device='cuda:0', grad_fn=<NllLossBackward0>)
Epoch 0: Loss = 7.653395652770996
Epoch 1: Loss = 7.653395652770996
Epoch 2: Loss = 7.653395652770996
Epoch 3: Loss = 7.653395652770996
Epoch 4: Loss = 7.09840726852417
Epoch 5: Loss = 6.675856590270996
Epoch 6: Loss = 6.310020923614502
Epoch 7: Loss = 5.953537940979004
Epoch 8: Loss = 5.669583797454834
Epoch 9: Loss = 5.462852954864502
Epoch 10: Loss = 5.237905502319336
Epoch 11: Loss = 5.023512840270996
Epoch 12: Loss = 4.772573471069336
Epoch 13: Loss = 4.561899185180664
Epoch 14: Loss = 4.367037296295166
Epoch 15: Loss = 4.18359375
Epoch 16: Loss = 4.008939266204834
Epoch 17: Loss = 3.808668851852417
Epoch 18: Loss = 3.598182201385498
Epoch 19: Loss = 3.403245210647583
Epoch 20: Loss = 3.196026086807251
Epoch 21: Loss = 2.989633321762085
Epoch 22: Loss = 2.781024694442749
Epoch 23: Loss = 2.56689453125
Epoch 24: Loss = 2.376840353012085
Epoch 25: Loss = 2.185246467590332
Epoch 26: Loss = 2.009840726852417
Epoch 27: L

In [10]:
prompt = 'Tomorrow is the Spring break' # 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)
batch['input_ids'] = torch.cat([space_for_prompts, batch['input_ids']], dim=1)
batch['attention_mask'] = torch.cat([torch.ones_like(space_for_prompts), batch['attention_mask']], dim=1)


for i in range(17):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0, num_prompts:].cpu().numpy().tolist()))

# if you did everything right, the model will deny that the fox jumped over the lazy dog


Output: <s>Tomorrow is the Spring break, I will miss the school! The! The! The! The! The!


In [3]:
# for name, layer in model.model.layers.named_modules():
#     if isinstance(layer, torch.nn.Linear):
#         print(name, layer)

Reload the model!

In [11]:
model_name = 'Enoch/llama-7b-hf'

# loading Llama tokenizer ...
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

# ... and the model itself
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    quantization_config=quantization_config, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False

model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()     # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad
# more on gradient checkpointing: https://pytorch.org/docs/stable/checkpoint.html https://arxiv.org/abs/1604.06174

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [12]:
import peft

In [13]:
from peft import LoraConfig, TaskType

In [17]:
#model.config #, tokenizer)

In [14]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # GPT-like models
    r=8,  # Low-rank dimension
    lora_alpha=32,  # Scaling factor for initialization
    lora_dropout=0.1,  # Dropout probability
    #target_modules=["q_proj", "v_proj"]  # Applies LoRA only to key transformer layers
)

In [15]:
peft_model = peft.get_peft_model(model, peft_config)

In [118]:
#peft_model

In [16]:
in_prompt = "Tomorrow is the Spring break"
out_prompt = "Tomorrow is the Spring break, I will miss the school!"
# Move input and target tensors to device
in_token_idx = tokenizer([in_prompt], return_tensors='pt')
out_token_idx = tokenizer([out_prompt], return_tensors='pt')

in_token_idx = {k: v.to(device) for k, v in in_token_idx.items()}
out_token_idx = {k: v.to(device) for k, v in out_token_idx.items()}

# Labels: Clone target and mask the prompt portion
labels = out_token_idx["input_ids"].clone()
labels[:, :in_token_idx["input_ids"].shape[1]] = -100  # Ignore prompt for loss

# Training input: concatenate prompt with the target tokens
in_token_idx["input_ids"] = out_token_idx["input_ids"]
in_token_idx["attention_mask"] = out_token_idx["attention_mask"]
in_token_idx, labels

({'input_ids': tensor([[    1,  4335, 22396,   338,   278,  7206,  2867, 29892,   306,   674,
            3052,   278,  3762, 29991]], device='cuda:0'),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')},
 tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100, 29892,   306,   674,
           3052,   278,  3762, 29991]], device='cuda:0'))

In [17]:
with torch.amp.autocast('cuda'):
    out = peft_model(**in_token_idx)



In [18]:
out.logits.size(), labels.size()

(torch.Size([1, 14, 32000]), torch.Size([1, 14]))

In [19]:
optimizer = torch.optim.AdamW(peft_model.parameters(), lr=1e-4)

In [20]:
peft_model.train()
scaler = torch.amp.GradScaler('cuda')
loss_threshold = 0.1
epoch = 0

while True:
    optimizer.zero_grad()

    with torch.amp.autocast(device_type='cuda'):
        out = peft_model(**in_token_idx)
        logits = out.logits[:,:-1,:] #[:, prompt_length:, :]
        loss = F.cross_entropy(logits.flatten(0, 1), labels[:,1:].flatten(0, 1))

    # Backpropagate using mixed precision
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"Epoch {epoch}: Loss = {loss.item()}")

    if loss.item() <= loss_threshold:
        break

    epoch += 1

assert loss.item() <= 0.1
print("Good job!")

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch 0: Loss = 3.567661762237549
Epoch 1: Loss = 3.567661762237549
Epoch 2: Loss = 3.567661762237549
Epoch 3: Loss = 3.346400737762451
Epoch 4: Loss = 3.1573660373687744
Epoch 5: Loss = 2.9093191623687744
Epoch 6: Loss = 2.6485769748687744
Epoch 7: Loss = 2.4174106121063232
Epoch 8: Loss = 2.1909878253936768
Epoch 9: Loss = 1.94384765625
Epoch 10: Loss = 1.6928013563156128
Epoch 11: Loss = 1.5080217123031616
Epoch 12: Loss = 1.32373046875
Epoch 13: Loss = 1.1727644205093384
Epoch 14: Loss = 1.0101492404937744
Epoch 15: Loss = 0.8516497015953064
Epoch 16: Loss = 0.6873212456703186
Epoch 17: Loss = 0.5632607340812683
Epoch 18: Loss = 0.4746333658695221
Epoch 19: Loss = 0.3963884711265564
Epoch 20: Loss = 0.3099321722984314
Epoch 21: Loss = 0.2353123277425766
Epoch 22: Loss = 0.1674695760011673
Epoch 23: Loss = 0.1191515251994133
Epoch 24: Loss = 0.0858415886759758
Good job!


In [21]:
peft_model.eval()
prompt = 'Tomorrow is the Spring break' # 'A quick brown fox'

batch = tokenizer([prompt], return_tensors='pt').to(device)

for i in range(17):
    next_token = peft_model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0, :].cpu().numpy().tolist()))




Output: <s>Tomorrow is the Spring break, I will miss the school!
I will miss the school!
I will


In [22]:
generated_tokens = peft_model.generate(batch["input_ids"])
tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

'Tomorrow is the Spring break, I will miss the school!\nI will miss the school!\nI will miss the school! I will miss the school! I will miss the school! I will miss the'