In [1]:
!pip install --quiet --upgrade transformers bitsandbytes accelerate sentencepiece optimum auto-gptq
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import BitsAndBytesConfig
from tqdm.auto import tqdm, trange
assert torch.cuda.is_available(), "you need cuda for this part"

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m67.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m342.1/342.1 kB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m433.6/433.6 kB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.5/23.5 MB[0m [31m65.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.2/13.2 MB[0m [31m85.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [38]:
model_name = 'Enoch/llama-7b-hf'

# loading Llama tokenizer ...
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

# ... and the model itself
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    quantization_config=quantization_config, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False

model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()     # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad
# more on gradient checkpointing: https://pytorch.org/docs/stable/checkpoint.html https://arxiv.org/abs/1604.06174

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [5]:
prompt = 'Tomorrow is the Spring break' # 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)

for i in range(10):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0].cpu().numpy().tolist()))


Output: <s>Tomorrow is the Spring break for the kids. I am so excited.


In [6]:
# the_truth = "A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
# batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)
# outputs = model(**batch)

# next_word_logits = outputs.logits[:, :-1]
# true_next_tokens = batch['input_ids'][:, 1:]
# loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))

# print("Loss:", loss)

In [7]:
#outputs.logits.size()

In [8]:
class WordEmbeddingsWithLearnedPrompts(nn.Module):
    """
    Replace model's original word embeddings with a layer that inserts trainable prompts instead of the first N token embeddings.
    """
    def __init__(self, word_embeddings: nn.Embedding, num_prompts: int):
        super().__init__()
        self.original_word_embeddings = word_embeddings
        self.num_prompts = num_prompts
        self.learnable_prompts = nn.Parameter(
            torch.randn(1, num_prompts, word_embeddings.embedding_dim), requires_grad=True
        )

    def forward(self, input_ids: torch.LongTensor):
        # Ensure input_ids are of correct type and length
        assert input_ids.dtype == torch.int64
        assert input_ids.shape[1] > self.num_prompts, "Input sequence must be longer than the number of prompts"
        assert (input_ids[:, :self.num_prompts] == tokenizer.pad_token_id).all(), \
            "Ensure the first `num_prompts` tokens are PAD tokens"

        # Embed input_ids
        embedded_input_ids = self.original_word_embeddings(input_ids)

        # Replace the first `num_prompts` token embeddings with learnable prompts using concatenation
        prompt_embeds = self.learnable_prompts.expand(input_ids.shape[0], -1, -1)  # [batch_size, num_prompts, embedding_dim]
        embedded_input_ids = torch.cat([prompt_embeds, embedded_input_ids[:, self.num_prompts:, :]], dim=1)

        return embedded_input_ids

In [10]:
num_prompts = 16
test_emb_layer = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)
test_input_ids = tokenizer("a cat sat on a mat", return_tensors='pt')['input_ids'].to(device)

space_for_prompts = torch.full([len(test_input_ids), num_prompts], fill_value=tokenizer.pad_token_id,
                               dtype=torch.int64, device=device)
test_inputs_with_prompts = torch.cat([space_for_prompts, test_input_ids], dim=1)

with torch.amp.autocast('cuda'):
  test_prompt_embeddings = test_emb_layer(test_inputs_with_prompts)

assert test_prompt_embeddings.shape[:2] == test_inputs_with_prompts.shape
assert test_prompt_embeddings.shape[-1] == model.config.hidden_size
assert torch.allclose(test_prompt_embeddings[:, :num_prompts], test_emb_layer.learnable_prompts.float())
assert torch.allclose(test_prompt_embeddings[:, num_prompts:], model.model.embed_tokens(test_input_ids).float())
print("Looks legit!")

Looks legit!


In [11]:
assert isinstance(model.model.embed_tokens, nn.Embedding), "you have already replaced the embedding layer. If the replacement is broken, please reload the model"

model.model.embed_tokens = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)

opt = torch.optim.Adam([model.model.embed_tokens.learnable_prompts], lr=0.01)

In [12]:
the_truth = "Tomorrow is the Spring break, I will miss the school!" #"A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)
space_for_prompts = torch.full([len(test_input_ids), num_prompts], fill_value=tokenizer.pad_token_id,
                               dtype=torch.int64, device=device)
batch['input_ids'] = torch.cat([space_for_prompts, batch['input_ids']], dim=1)
batch['attention_mask'] = torch.cat([torch.ones_like(space_for_prompts), batch['attention_mask']], dim=1)

outputs = model(**batch)
next_word_logits = outputs.logits[:, num_prompts : -1, :]
true_next_tokens = batch['input_ids'][:, num_prompts + 1:]
loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))
print("Loss:", loss)
scaler = torch.amp.GradScaler('cuda')

loss_threshold = 0.1
epoch = 0

while True:
    opt.zero_grad()

    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(**batch)
        next_word_logits = outputs.logits[:, num_prompts:-1, :]
        true_next_tokens = batch['input_ids'][:, num_prompts+1:]
        loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))

    # Backpropagate using mixed precision
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

    print(f"Epoch {epoch}: Loss = {loss.item()}")

    if loss.item() <= loss_threshold:
        break

    epoch += 1

assert loss.item() <= 0.1
print("Good job!")

Loss: tensor(8.6567, device='cuda:0', grad_fn=<NllLossBackward0>)
Epoch 0: Loss = 8.654447555541992
Epoch 1: Loss = 8.654447555541992
Epoch 2: Loss = 7.76292085647583
Epoch 3: Loss = 7.126652717590332
Epoch 4: Loss = 6.595853328704834
Epoch 5: Loss = 6.179311752319336
Epoch 6: Loss = 5.763521671295166
Epoch 7: Loss = 5.763521671295166
Epoch 8: Loss = 5.24391508102417
Epoch 9: Loss = 5.24391508102417
Epoch 10: Loss = 4.67210054397583
Epoch 11: Loss = 4.23805570602417
Epoch 12: Loss = 3.986102819442749
Epoch 13: Loss = 3.826472282409668
Epoch 14: Loss = 3.670973539352417
Epoch 15: Loss = 3.4951171875
Epoch 16: Loss = 3.294358491897583
Epoch 17: Loss = 3.069260835647583
Epoch 18: Loss = 2.836519718170166
Epoch 19: Loss = 2.602736234664917
Epoch 20: Loss = 2.369600772857666
Epoch 21: Loss = 2.156663179397583
Epoch 22: Loss = 1.97735595703125
Epoch 23: Loss = 1.8224698305130005
Epoch 24: Loss = 1.670612096786499
Epoch 25: Loss = 1.5597158670425415
Epoch 26: Loss = 1.4112924337387085
Epoch 2

In [13]:
prompt = 'Tomorrow is the Spring break' # 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)
batch['input_ids'] = torch.cat([space_for_prompts, batch['input_ids']], dim=1)
batch['attention_mask'] = torch.cat([torch.ones_like(space_for_prompts), batch['attention_mask']], dim=1)


for i in range(17):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0, num_prompts:].cpu().numpy().tolist()))

# if you did everything right, the model will deny that the fox jumped over the lazy dog


Output: <s>Tomorrow is the Spring break, I will miss the school!
The school is closed, I will miss the


In [3]:
# for name, layer in model.model.layers.named_modules():
#     if isinstance(layer, torch.nn.Linear):
#         print(name, layer)

Relupload the model!

In [None]:
model_name = 'Enoch/llama-7b-hf'

# loading Llama tokenizer ...
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

# ... and the model itself
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    quantization_config=quantization_config, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False

model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()     # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad
# more on gradient checkpointing: https://pytorch.org/docs/stable/checkpoint.html https://arxiv.org/abs/1604.06174

In [14]:
import peft

In [15]:
from peft import LoraConfig, TaskType

In [17]:
#model.config #, tokenizer)

In [18]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # GPT-like models
    r=8,  # Low-rank dimension
    lora_alpha=32,  # Scaling factor for initialization
    lora_dropout=0.1,  # Dropout probability
    #target_modules=["q_proj", "v_proj"]  # Applies LoRA only to key transformer layers
)

In [39]:
peft_model = peft.get_peft_model(model, peft_config)

In [1]:
#peft_model

In [41]:
in_prompt = "Tomorrow is the Spring break"
out_prompt = "Tomorrow is the Spring break, I will miss the school!"
in_token_idx = tokenizer([in_prompt], return_tensors='pt')
out_token_idx = tokenizer([out_prompt], return_tensors='pt')

# Create Labels for Loss Calculation
labels = out_token_idx["input_ids"].clone()
labels[:, :in_token_idx["input_ids"].shape[1]] = -100

in_token_idx, labels

({'input_ids': tensor([[    1,  4335, 22396,   338,   278,  7206,  2867]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])},
 tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100, 29892,   306,   674,
           3052,   278,  3762, 29991]]))

In [42]:
#device = model.device
#peft_model = peft_model.to(device)
in_token_idx = {k: v.to(device) for k,v in in_token_idx.items()}
labels = labels.to(device)

In [43]:
# Ensure input has padding for full generation
prompt_length = in_token_idx["input_ids"].shape[1]

space_for_generation = torch.full(
    (in_token_idx["input_ids"].shape[0], out_token_idx["input_ids"].shape[1] - in_token_idx["input_ids"].shape[1]),
    fill_value=tokenizer.pad_token_id,
    dtype=torch.int64,
    device=device
)

# Concatenate input with space for generation
in_token_idx['input_ids'] = torch.cat([in_token_idx["input_ids"], space_for_generation], dim=1)
in_token_idx['attention_mask'] = (in_token_idx['input_ids'] != tokenizer.pad_token_id)

In [44]:
in_token_idx, labels

({'input_ids': tensor([[    1,  4335, 22396,   338,   278,  7206,  2867,     2,     2,     2,
               2,     2,     2,     2]], device='cuda:0'),
  'attention_mask': tensor([[ True,  True,  True,  True,  True,  True,  True, False, False, False,
           False, False, False, False]], device='cuda:0')},
 tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100, 29892,   306,   674,
           3052,   278,  3762, 29991]], device='cuda:0'))

In [45]:
with torch.amp.autocast('cuda'):
    out = peft_model(**in_token_idx)



In [46]:
out.logits.size(), labels.size()

(torch.Size([1, 14, 32000]), torch.Size([1, 14]))

In [47]:
logits = out.logits[:, prompt_length:, :]  # Use logits only after the prompt
labels = labels[:, prompt_length:]  # Use labels only for new tokens
logits.size(), labels.size()

(torch.Size([1, 7, 32000]), torch.Size([1, 7]))

In [51]:
optimizer = torch.optim.AdamW(peft_model.parameters(), lr=1e-4)

In [2]:
peft_model.train()
scaler = torch.amp.GradScaler('cuda')
loss_threshold = 0.1
epoch = 0

while True:
    optimizer.zero_grad()

    with torch.amp.autocast(device_type='cuda'):
        out = peft_model(**in_token_idx)
        logits = out.logits[:, prompt_length:, :]
        loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(0, 1))

    # Backpropagate using mixed precision
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    print(f"Epoch {epoch}: Loss = {loss.item()}")

    if loss.item() <= loss_threshold:
        break

    epoch += 1

assert loss.item() <= 0.1
print("Good job!")

In [53]:
peft_model.eval()
prompt = 'Tomorrow is the Spring break' # 'A quick brown fox'

for i in range(17):
    next_token = peft_model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0, :].cpu().numpy().tolist()))




Output: </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s>Tomorrow is the Spring break, I will miss the school!
The school is closed, I will miss the miss school school school school school school school school school school school school school school school school
