In [1]:
!pip install -q transformers bitsandbytes accelerate torch tokenizer
!pip install -U datasets



In [126]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from google.colab import userdata

# Hugging Face access token, read from the Colab secret store.
hf_token = userdata.get('COLAB_HF_TOKEN')

# Hub checkpoint shared by the tokenizer and the model weights.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Tokenizer that matches the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=hf_token,
)

# Causal-LM weights, requested in bfloat16; `device_map="auto"` leaves the
# device placement to the library. The token authenticates the download.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    token=hf_token,
)

print("Model loaded successfully!")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Model loaded successfully!


In [7]:
# Prompt whose most likely next token we want to inspect.
prompt_text = "The best part of waking up is"

# Tokenize into PyTorch tensors, then move them onto the model's device.
inputs = tokenizer(
    prompt_text,
    return_tensors="pt",
).to(model.device)

print("\nTokenized Input IDs:")
print(inputs['input_ids'])

# Plain forward pass; no gradients are needed for inference.
with torch.no_grad():
  outputs = model(**inputs)

logits = outputs.logits

print("\nShape of the output logits:")
print(logits.shape)

# Scores for the position after the final prompt token (first batch item).
last_token_logits = logits[0, -1]

# Greedy choice: the vocabulary id with the highest score.
predicted_token_id = last_token_logits.argmax().item()

predicted_word = tokenizer.decode(predicted_token_id)

print(f"\nModel's next word prediction: '{predicted_word}'")



Tokenized Input IDs:
tensor([[   1,  415, 1489,  744,  302,  275, 1288,  582,  349]])

Shape of the output logits:
torch.Size([1, 9, 32000])

Model's next word prediction: 'coffee'


In [8]:
from datasets import load_dataset

# Fetch the "train" split of the Dolly-15k dataset from the Hugging Face Hub.
ds = load_dataset(
    "databricks/databricks-dolly-15k",
    split="train",
)

# Show the first record to see which fields each example carries.
print(ds[0])

{'instruction': 'When did Virgin Australia start operating?', 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.", 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.', 'category': 'closed_qa'}


In [91]:
# Disable gradient tracking on every parameter of the model so that none of
# the pretrained weights are updated during training.
for param in model.parameters():
  param.requires_grad_(False)

print("--- After freezing ALL layers ---")

# Spot-check two weights: a frozen tensor reports requires_grad == False.
print(
    "Embedding layer frozen?",
    not model.model.embed_tokens.weight.requires_grad,
)
print(
    "Attention layer 15 frozen?",
    not model.model.layers[15].self_attn.q_proj.weight.requires_grad,
)

--- After freezing ALL layers ---
Embedding layer frozen? True
Attention layer 15 frozen? True


In [128]:
import torch.nn as nn
import math

class LoRALayer(nn.Module):
    """Wrap a linear layer with a trainable low-rank (LoRA) update.

    The forward pass returns
    ``original_layer(x) + (x @ lora_A @ lora_B) * (alpha / rank)``.
    ``lora_A`` has shape ``(in_features, rank)`` and ``lora_B`` has shape
    ``(rank, out_features)``. ``lora_B`` starts at zero, so the wrapped
    layer's output is unchanged until the adapter is trained.

    Args:
        original_layer: Layer to wrap. It must expose ``in_features``,
            ``out_features`` and ``weight`` (as ``nn.Linear`` does). Its
            parameters are frozen by this constructor.
        rank: Inner dimension of the low-rank update; must be positive.
        alpha: Numerator of the scaling factor ``alpha / rank``.

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, original_layer, rank, alpha):
        super().__init__()

        if rank <= 0:
            raise ValueError(f"rank must be a positive integer, got {rank}")

        self.original_layer = original_layer

        # Only the adapter is meant to train. Freeze the wrapped weights here
        # so the layer is correct even if no separate freezing step was run.
        for base_param in self.original_layer.parameters():
            base_param.requires_grad_(False)

        # Get dimensions from the original layer
        in_features = self.original_layer.in_features
        out_features = self.original_layer.out_features

        # Create the adapter on the same device and in the same dtype as the
        # wrapped weight; otherwise `x @ lora_A` fails with a device/dtype
        # mismatch when the base model is e.g. bfloat16 on a GPU.
        weight = self.original_layer.weight
        device = weight.device
        # Non-floating weights cannot hold gradients, so fall back to the
        # default float dtype for the adapter in that case.
        dtype = weight.dtype if weight.dtype.is_floating_point else torch.get_default_dtype()

        # Kaiming init reads fan-in from dim 1 of a 2-D tensor. Draw A as
        # (rank, in_features) so fan-in is in_features, then transpose into
        # the (in_features, rank) layout used by forward(). The draw is done
        # in float32 and cast afterwards.
        lora_a_init = torch.empty(rank, in_features)
        nn.init.kaiming_uniform_(lora_a_init, a=math.sqrt(5))
        self.lora_A = nn.Parameter(
            lora_a_init.t().contiguous().to(device=device, dtype=dtype)
        )

        # Paper initializes B with zeros, making the initial update exactly 0.
        self.lora_B = nn.Parameter(
            torch.zeros(rank, out_features, device=device, dtype=dtype)
        )

        # Scaling factor
        self.scaling = alpha / rank

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update."""
        # The original linear layer calculation
        original_output = self.original_layer(x)

        # Low-rank update; the casts are no-ops when x already matches the
        # adapter dtype and the adapter matches the base output dtype.
        adapter_input = x.to(self.lora_A.dtype)
        lora_update = (adapter_input @ self.lora_A @ self.lora_B) * self.scaling

        return original_output + lora_update.to(original_output.dtype)

In [129]:
# rank and alpha variables for LoRA layers (the same for every adapter, so
# they are set once instead of on every loop iteration)
r = 8
a = 16

# Decoder blocks whose attention projections receive an adapter.
modules = model.model.layers

for module in modules:
  self_attn = module.self_attn

  # Adapt the query and value projections of each attention block.
  for proj_name in ("q_proj", "v_proj"):
    proj = getattr(self_attn, proj_name)

    # Skip projections that are already wrapped so this cell can be re-run
    # without passing an existing LoRALayer back into LoRALayer.
    if isinstance(proj, LoRALayer):
      continue

    setattr(self_attn, proj_name, LoRALayer(proj, r, a))

Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias=False)
Linear(in_features=4096, out_features=4096, bias