In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, get_scheduler
from bitsandbytes.optim import Adam8bit,PagedAdam32bit
from peft import LoraConfig, get_peft_model
from peft import prepare_model_for_kbit_training
import torch
from IPython.display import  clear_output
import time
import gc
from torch.utils.data import Dataset, DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'

DEFAULT_MODEL = "meta-llama/Llama-3.2-3B-Instruct"

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=False,
    )


model = AutoModelForCausalLM.from_pretrained(
    DEFAULT_MODEL,
    quantization_config=bnb_config,
    use_safetensors=True,
    device_map=device,
)

print(model.get_memory_footprint()/(1024*1024))

tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL, use_safetensors=True)
tokenizer.pad_token_id = tokenizer.eos_token_id

def flush():
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

2095.841064453125


In [2]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("OdiaGenAI/hardcode_odia_qa_105")

In [3]:
dataset['train'][0]

{'input': '',
 'english_input': '',
 'english_instruction': 'Who are you?',
 'instruction': '‡¨Ü‡¨™‡¨£ ‡¨ï‡¨ø‡¨è?',
 'output': '‡¨Æ‡≠Å‡¨Å ‡¨Ö‡¨≤‡¨ø‡¨≠‡≠ç ‡¨è‡¨ï ‡¨ö‡¨æ‡¨ü‡≠ç‡¨¨‡¨ü‡≠ç ‡¨Ü‡¨∏‡¨ø‡¨∑‡≠ç‡¨ü‡¨æ‡¨£‡≠ç‡¨ü, ‡¨Ø‡¨æ‡¨π‡¨æ‡¨ï‡¨ø ‡¨ì‡¨°‡¨ø‡¨Ü-‡¨ú‡≠á‡¨®-‡¨è.‡¨Ü‡¨á. ‡¨ó‡¨¨‡≠á‡¨∑‡¨ï‡¨Æ‡¨æ‡¨®‡¨ô‡≠ç‡¨ï ‡¨¶‡≠ç‡≠±‡¨æ‡¨∞‡¨æ ‡¨™‡≠ç‡¨∞‡¨∂‡¨ø‡¨ï‡≠ç‡¨∑‡¨ø‡¨§ ‡¨è‡¨ï ‡¨≠‡¨æ‡¨∑‡¨æ ‡¨Æ‡¨°‡≠á‡¨≤‡•§',
 'english_output': 'I am Olive a chatbot assistant, a language model trained by researchers from OdiaGenAI.'}

In [4]:
# Custom PyTorch Dataset
class LlamaDataset(Dataset):
    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        question=sample['instruction']
        answer = sample['output']
        chat_template = f'''<|begin_of_text|> <|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{answer}‡•§<|eot_id|>'''
        inputs = tokenizer(chat_template,return_tensors='pt',truncation=True,padding="max_length",max_length=380)
        labels = torch.where(inputs.input_ids==tokenizer.pad_token_id,-100,inputs.input_ids)
        
        input_ids = inputs.input_ids[:,:-1].squeeze()
        labels = labels[:, 1:].squeeze()
        
        return input_ids,labels

In [5]:
train_dataset = LlamaDataset(dataset['train'])
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [6]:
for param in model.named_parameters():
    print(param[0],param[1].dtype)

model.embed_tokens.weight torch.float16
model.layers.0.self_attn.q_proj.weight torch.uint8
model.layers.0.self_attn.k_proj.weight torch.uint8
model.layers.0.self_attn.v_proj.weight torch.uint8
model.layers.0.self_attn.o_proj.weight torch.uint8
model.layers.0.mlp.gate_proj.weight torch.uint8
model.layers.0.mlp.up_proj.weight torch.uint8
model.layers.0.mlp.down_proj.weight torch.uint8
model.layers.0.input_layernorm.weight torch.float16
model.layers.0.post_attention_layernorm.weight torch.float16
model.layers.1.self_attn.q_proj.weight torch.uint8
model.layers.1.self_attn.k_proj.weight torch.uint8
model.layers.1.self_attn.v_proj.weight torch.uint8
model.layers.1.self_attn.o_proj.weight torch.uint8
model.layers.1.mlp.gate_proj.weight torch.uint8
model.layers.1.mlp.up_proj.weight torch.uint8
model.layers.1.mlp.down_proj.weight torch.uint8
model.layers.1.input_layernorm.weight torch.float16
model.layers.1.post_attention_layernorm.weight torch.float16
model.layers.2.self_attn.q_proj.weight tor

In [7]:
model = prepare_model_for_kbit_training(model)

In [8]:
for param in model.named_parameters():
    print(param[0],param[1].dtype)

model.embed_tokens.weight torch.float32
model.layers.0.self_attn.q_proj.weight torch.uint8
model.layers.0.self_attn.k_proj.weight torch.uint8
model.layers.0.self_attn.v_proj.weight torch.uint8
model.layers.0.self_attn.o_proj.weight torch.uint8
model.layers.0.mlp.gate_proj.weight torch.uint8
model.layers.0.mlp.up_proj.weight torch.uint8
model.layers.0.mlp.down_proj.weight torch.uint8
model.layers.0.input_layernorm.weight torch.float32
model.layers.0.post_attention_layernorm.weight torch.float32
model.layers.1.self_attn.q_proj.weight torch.uint8
model.layers.1.self_attn.k_proj.weight torch.uint8
model.layers.1.self_attn.v_proj.weight torch.uint8
model.layers.1.self_attn.o_proj.weight torch.uint8
model.layers.1.mlp.gate_proj.weight torch.uint8
model.layers.1.mlp.up_proj.weight torch.uint8
model.layers.1.mlp.down_proj.weight torch.uint8
model.layers.1.input_layernorm.weight torch.float32
model.layers.1.post_attention_layernorm.weight torch.float32
model.layers.2.self_attn.q_proj.weight tor

In [9]:
config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.1,
    bias="none",
    inference_mode=False,
    use_rslora=True,
    init_lora_weights="gaussian",
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 36,700,160 || all params: 3,249,449,984 || trainable%: 1.1294


In [10]:
for param in model.named_parameters():
    print(param[0],param[1].requires_grad)

base_model.model.model.embed_tokens.weight False
base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight False
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight True
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight True
base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight False
base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight True
base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight True
base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight False
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight True
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight True
base_model.model.model.layers.0.self_attn.o_proj.base_layer.weight False
base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight True
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight True
base_model.model.model.layers.0.mlp.gate_proj.weigh

In [11]:
# question = '''‡¨ì‡¨°‡¨º‡¨ø‡¨∂‡¨æ‡¨∞‡≠á ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨®‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨® ‡¨¶‡≠á‡¨¨‡¨æ ‡¨™‡¨æ‡¨á‡¨Å ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞‡≠Ä ‡¨è‡¨¨‡¨Ç ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞ ‡¨ï‡¨ø‡¨™‡¨∞‡¨ø ‡¨Æ‡¨ø‡¨≥‡¨ø‡¨Æ‡¨ø‡¨∂‡¨ø ‡¨ï‡¨æ‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡≠á?'''
# answer = '''‡¨Ø‡≠á‡¨ï‡≠å‡¨£‡¨∏‡¨ø ‡¨∞‡¨æ‡¨ú‡≠ç‡≠ü‡¨∞‡≠á ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨®‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨® ‡¨¶‡≠á‡¨¨‡¨æ ‡¨™‡¨æ‡¨á‡¨Å ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞‡≠Ä ‡¨è‡¨¨‡¨Ç ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞‡¨∞ ‡¨Æ‡¨ø‡¨≥‡¨ø‡¨§ ‡¨™‡≠ç‡¨∞‡≠ü‡¨æ‡¨∏‡¨∞ ‡¨Ü‡¨¨‡¨∂‡≠ç‡≠ü‡¨ï‡¨§‡¨æ ‡¨∞‡¨π‡¨ø‡¨õ‡¨ø‡•§ ‡¨ì‡¨°‡¨º‡¨ø‡¨∂‡¨æ‡¨∞ ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞‡≠Ä ‡¨è‡¨¨‡¨Ç ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞ ‡¨∞‡¨æ‡¨ú‡≠ç‡≠ü‡¨∞‡≠á ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨®‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨® ‡¨¶‡≠á‡¨¨‡¨æ ‡¨≤‡¨æ‡¨ó‡¨ø ‡¨è‡¨ï ‡¨¨‡¨ø‡¨∏‡≠ç‡¨§‡≠É‡¨§ ‡¨∞‡¨£‡¨®‡≠Ä‡¨§‡¨ø ‡¨¨‡¨ø‡¨ï‡¨∂‡¨ø‡¨§ ‡¨è‡¨¨‡¨Ç ‡¨ï‡¨æ‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ï‡¨æ‡¨∞‡≠Ä ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨≤‡¨æ‡¨ó‡¨ø ‡¨Æ‡¨ø‡¨≥‡¨ø‡¨§ ‡¨≠‡¨æ‡¨¨‡≠á ‡¨ï‡¨æ‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡≠á‡•§
# ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞ ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞ ‡¨∏‡¨π‡¨ø‡¨§ ‡¨Æ‡¨ø‡¨∂‡¨ø ‡¨ï‡¨æ‡¨Æ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ‡¨∞ ‡¨ó‡≠ã‡¨ü‡¨ø‡¨è ‡¨â‡¨™‡¨æ‡≠ü ‡¨π‡≠á‡¨≤‡¨æ ‡¨ò‡¨∞‡≠ã‡¨á ‡¨®‡¨ø‡¨¨‡≠á‡¨∂ ‡¨™‡¨æ‡¨á‡¨Å ‡¨Ö‡¨®‡≠Å‡¨ï‡≠Ç‡¨≥ ‡¨¨‡¨æ‡¨§‡¨æ‡¨¨‡¨∞‡¨£ ‡¨∏‡≠É‡¨∑‡≠ç‡¨ü‡¨ø ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ, ‡¨è‡¨•‡¨ø‡¨∞‡≠á ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞‡¨∞ ‡¨≠‡¨æ‡¨ó‡¨ø‡¨¶‡¨æ‡¨∞‡≠Ä‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨ø‡¨§ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨™‡¨æ‡¨á‡¨Å ‡¨ü‡¨ø‡¨ï‡¨∏ ‡¨è‡¨¨‡¨Ç ‡¨®‡¨ø‡≠ü‡¨æ‡¨Æ‡¨ï ‡¨™‡≠ç‡¨∞‡¨§‡¨ø‡¨¨‡¨®‡≠ç‡¨ß‡¨ï‡¨ï‡≠Å ‡¨π‡≠ç‡¨∞‡¨æ‡¨∏ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ, ‡¨è‡¨π‡¨æ ‡¨¨‡≠ç‡≠ü‡¨§‡≠Ä‡¨§ ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞ ‡¨ü‡¨ø‡¨ï‡¨∏ ‡¨∞‡¨ø‡¨π‡¨æ‡¨§‡¨ø, ‡¨∏‡¨¨‡¨∏‡¨ø‡¨°‡¨ø ‡¨è‡¨¨‡¨Ç ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨® ‡¨¨‡¨ø‡¨ï‡¨æ‡¨∂ ‡¨™‡≠ç‡¨∞‡¨ï‡¨≥‡≠ç‡¨™ ‡¨™‡¨æ‡¨á‡¨Å ‡¨ú‡¨Æ‡¨ø ‡¨Ü‡¨¶‡¨ø ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨® ‡¨Æ‡¨ß‡≠ç‡≠ü ‡¨™‡≠ç‡¨∞‡¨¶‡¨æ‡¨® ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡≠á‡•§
# ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞ ‡¨∏‡¨π ‡¨Æ‡¨ø‡¨∂‡¨ø ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞ ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞ ‡¨∏‡¨π ‡¨Æ‡¨ø‡¨∂‡¨ø ‡¨®‡≠Ç‡¨§‡¨® ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨® ‡¨â‡¨§‡≠ç‡¨™‡¨æ‡¨¶ ‡¨™‡≠ç‡¨∞‡¨∏‡≠ç‡¨§‡≠Å‡¨§ ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡≠á ‡¨Ø‡¨æ‡¨π‡¨æ ‡¨â‡¨≠‡≠ü ‡¨ò‡¨∞‡≠ã‡¨á ‡¨è‡¨¨‡¨Ç ‡¨Ö‡¨®‡≠ç‡¨§‡¨∞‡≠ç‡¨ú‡¨æ‡¨§‡≠Ä‡≠ü ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨ï‡¨ô‡≠ç‡¨ï ‡¨Ü‡¨¨‡¨∂‡≠ç‡≠ü‡¨ï‡¨§‡¨æ ‡¨™‡≠Ç‡¨∞‡¨£ ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡•§
# ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞ ‡¨Æ‡¨ß‡≠ç‡≠ü ‡¨ì‡¨°‡¨º‡¨ø‡¨∂‡¨æ‡¨∞‡≠á ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨®‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨® ‡¨¶‡≠á‡¨¨‡¨æ ‡¨è‡¨¨‡¨Ç ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞‡¨∞ ‡¨≠‡¨æ‡¨ó‡¨ø‡¨¶‡¨æ‡¨∞‡≠Ä‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨ø‡¨§ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨™‡¨æ‡¨á‡¨Å ‡¨™‡≠ç‡¨∞‡¨Ø‡≠Å‡¨ï‡≠ç‡¨§‡¨ø‡¨∞ ‡¨â‡¨™‡¨Ø‡≠ã‡¨ó ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡≠á‡•§ ‡¨â‡¨¶‡¨æ‡¨π‡¨∞‡¨£ ‡¨∏‡≠ç‡≠±‡¨∞‡≠Ç‡¨™, ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞ ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨® ‡¨∏‡≠ç‡¨•‡¨≥‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨ø‡¨§ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨è‡¨¨‡¨Ç ‡¨∏‡¨Æ‡≠ç‡¨≠‡¨æ‡¨¨‡≠ç‡≠ü ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨ï‡¨Æ‡¨æ‡¨®‡¨ô‡≠ç‡¨ï ‡¨∏‡¨π‡¨ø‡¨§ ‡¨Ø‡≠ã‡¨°‡¨º‡¨ø‡¨¨‡¨æ ‡¨≤‡¨æ‡¨ó‡¨ø ‡¨∏‡≠ã‡¨∏‡¨ø‡¨Ü‡¨≤ ‡¨Æ‡¨ø‡¨°‡¨ø‡¨Ü ‡¨™‡≠ç‡¨≤‡¨æ‡¨ü‡¨´‡¨∞‡≠ç‡¨Æ‡¨∞ ‡¨â‡¨™‡¨Ø‡≠ã‡¨ó ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡≠á‡•§ ‡¨ì‡¨°‡¨º‡¨ø‡¨∂‡¨æ‡¨∞‡≠á ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨® ‡¨Ü‡¨ï‡¨∞‡≠ç‡¨∑‡¨£ ‡¨è‡¨¨‡¨Ç ‡¨Ö‡¨®‡≠Å‡¨≠‡¨¨ ‡¨™‡≠ç‡¨∞‡¨¶‡¨∞‡≠ç‡¨∂‡¨ø‡¨§ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨≤‡¨æ‡¨ó‡¨ø ‡¨è‡¨ï ‡¨Ö‡¨®‡¨≤‡¨æ‡¨á‡¨® ‡¨™‡≠ç‡¨≤‡¨æ‡¨ü‡¨´‡¨∞‡≠ç‡¨Æ ‡¨™‡≠ç‡¨∞‡¨§‡¨ø‡¨∑‡≠ç‡¨†‡¨æ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨¶‡≠ç‡≠±‡¨æ‡¨∞‡¨æ ‡¨Ö‡¨ß‡¨ø‡¨ï ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨ï‡¨ô‡≠ç‡¨ï‡≠Å ‡¨Ü‡¨ï‡¨∞‡≠ç‡¨∑‡¨ø‡¨§ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ‡¨∞‡≠á ‡¨∏‡¨π‡¨æ‡≠ü‡¨§‡¨æ ‡¨Æ‡¨ø‡¨≥‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡•§
# ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨®‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨® ‡¨¶‡≠á‡¨¨‡¨æ ‡¨≤‡¨æ‡¨ó‡¨ø ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞‡≠Ä ‡¨è‡¨¨‡¨Ç ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞‡¨ï‡≠Å ‡¨Æ‡¨ø‡¨≥‡¨ø‡¨§ ‡¨≠‡¨æ‡¨¨‡≠á ‡¨ï‡¨æ‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ‡¨ï‡≠Å ‡¨™‡¨°‡¨ø‡¨¨ ‡¨Ø‡≠á‡¨™‡¨∞‡¨ø‡¨ï‡¨ø ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨® ‡¨ó‡¨§‡¨ø‡¨¨‡¨ø‡¨ß‡¨ø ‡¨¶‡≠ç‡≠±‡¨æ‡¨∞‡¨æ ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨æ‡¨¨‡¨∞‡¨£‡¨∞ ‡¨ï‡≠ç‡¨∑‡≠ü ‡¨ï‡¨ø‡¨Æ‡≠ç‡¨¨‡¨æ ‡¨∏‡≠ç‡¨•‡¨æ‡¨®‡≠Ä‡≠ü ‡¨∏‡¨Æ‡≠ç‡¨™‡≠ç‡¨∞‡¨¶‡¨æ‡≠ü‡¨∞ ‡¨ï‡≠ç‡¨∑‡¨§‡¨ø ‡¨® ‡¨π‡≠á‡¨â‡•§
# ‡¨∂‡≠á‡¨∑‡¨∞‡≠á, ‡¨∏‡¨∞‡¨ï‡¨æ‡¨∞‡≠Ä ‡¨è‡¨¨‡¨Ç ‡¨ò‡¨∞‡≠ã‡¨á ‡¨ï‡≠ç‡¨∑‡≠á‡¨§‡≠ç‡¨∞‡¨ï‡≠Å ‡¨ì‡¨°‡¨º‡¨ø‡¨∂‡¨æ‡¨∞‡≠á ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨®‡¨ï‡≠Å ‡¨™‡≠ç‡¨∞‡≠ã‡¨§‡≠ç‡¨∏‡¨æ‡¨π‡¨® ‡¨¶‡≠á‡¨¨‡¨æ ‡¨≤‡¨æ‡¨ó‡¨ø ‡¨è‡¨ï ‡¨Ö‡¨®‡≠Å‡¨ï‡≠Ç‡¨≥ ‡¨™‡¨∞‡¨ø‡¨¨‡≠á‡¨∂ ‡¨∏‡≠É‡¨∑‡≠ç‡¨ü‡¨ø ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨Ü‡¨¨‡¨∂‡≠ç‡≠ü‡¨ï ‡¨è‡¨¨‡¨Ç ‡¨è‡¨π‡¨æ ‡¨∏‡≠Å‡¨®‡¨ø‡¨∂‡≠ç‡¨ö‡¨ø‡¨§ ‡¨ï‡¨∞‡¨ø‡¨¨‡¨æ ‡¨â‡¨ö‡¨ø‡¨§ ‡¨Ø‡≠á ‡¨¨‡¨ø‡¨ï‡¨æ‡¨∂ ‡¨∏‡≠ç‡¨•‡¨æ‡≠ü‡≠Ä ‡¨π‡≠á‡¨¨‡•§ ‚Äù ‡¨Æ‡¨ø‡¨≥‡¨ø‡¨§ ‡¨≠‡¨æ‡¨¨‡≠á ‡¨ï‡¨æ‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü ‡¨ï‡¨∞‡¨ø ‡¨∏‡≠á‡¨Æ‡¨æ‡¨®‡≠á ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨® ‡¨∞‡¨£‡¨®‡≠Ä‡¨§‡¨ø‡¨ï‡≠Å ‡¨¨‡¨ø‡¨ï‡¨∂‡¨ø‡¨§ ‡¨è‡¨¨‡¨Ç ‡¨ï‡¨æ‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ï‡¨æ‡¨∞‡≠Ä ‡¨ï‡¨∞‡¨ø‡¨™‡¨æ‡¨∞‡¨ø‡¨¨‡≠á ‡¨Ø‡¨æ‡¨π‡¨æ ‡¨ï‡≠á‡¨¨‡¨≥ ‡¨™‡¨∞‡≠ç‡¨Ø‡≠ç‡≠ü‡¨ü‡¨® ‡¨â‡¨¶‡≠ç‡≠ü‡≠ã‡¨ó ‡¨®‡≠Å‡¨π‡≠á‡¨Å ‡¨¨‡¨∞‡¨Ç ‡¨∏‡≠ç‡¨•‡¨æ‡¨®‡≠Ä‡≠ü ‡¨ó‡≠ã‡¨∑‡≠ç‡¨†‡≠Ä ‡¨è‡¨¨‡¨Ç ‡¨™‡¨∞‡¨ø‡¨¨‡≠á‡¨∂‡¨ï‡≠Å ‡¨Æ‡¨ß‡≠ç‡≠ü ‡¨≤‡¨æ‡¨≠‡¨æ‡¨®‡≠ç‡≠±‡¨ø‡¨§ ‡¨ï‡¨∞‡¨ø‡¨¨‡•§'''

# tokenized_text = tokenizer(answer).input_ids
# print(len(tokenized_text))
# for idx in range(len(tokenized_text)):
#     clear_output(wait=True)
#     print(tokenizer.decode(tokenized_text[0:idx]))
#     time.sleep(0.1)

# Finetune the LLAMA model on a single text

In [12]:
for input_ids,labels in train_dataloader:
    break

tokenizer.batch_decode(input_ids)

['<|begin_of_text|><|begin_of_text|> <|start_header_id|>user<|end_header_id|>\n\n‡¨®‡¨Æ‡¨∏‡≠ç‡¨ï‡¨æ‡¨∞‡•§<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n‡¨Æ‡≠Å‡¨Å ‡¨Ö‡¨≤‡¨ø‡¨≠‡≠ç ‡¨è‡¨ï ‡¨ö‡¨æ‡¨ü‡≠ç‡¨¨‡¨ü‡≠ç ‡¨Ü‡¨∏‡¨ø‡¨∑‡≠ç‡¨ü‡¨æ‡¨£‡≠ç‡¨ü ‡¨®‡¨æ‡¨Æ‡¨ï ‡¨è‡¨ï ‡¨≠‡¨æ‡¨∑‡¨æ ‡¨Æ‡¨°‡≠á‡¨≤‡≠ç ‡¨è‡¨¨‡¨Ç ‡¨ì‡¨°‡¨ø‡¨Ü-‡¨ú‡≠á‡¨®-‡¨è.‡¨Ü‡¨á. ‡¨ó‡¨¨‡≠á‡¨∑‡¨ï‡¨Æ‡¨æ‡¨®‡¨ô‡≠ç‡¨ï ‡¨¶‡≠ç‡≠±‡¨æ‡¨∞‡¨æ ‡¨™‡≠ç‡¨∞‡¨∂‡¨ø‡¨ï‡≠ç‡¨∑‡¨ø‡¨§ ‡¨π‡≠ã‡¨á‡¨õ‡¨ø‡•§‡•§<|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><|eot_id|><

In [13]:
def generate_eval():
    sample=dataset['train'][0]
    question=sample['instruction']
    answer = sample['output']
    chat_template = f'''<|begin_of_text|> <|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|>'''
    inputs = tokenizer(chat_template , return_tensors="pt").to(device)
    # print(prompt)

    model.eval()

    with torch.no_grad():
        output = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=512,
            repetition_penalty=1.3,
            temperature=0.7,         # Optional: smooth randomness
            top_k=50,                # Optional: top-k sampling
            top_p=0.9                # Optional: nucleus sampling
        )

    processed_text = tokenizer.decode(output[0], skip_special_tokens=False)
    
    model.train()

    return processed_text

In [14]:
pred = generate_eval()
print(pred)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


<|begin_of_text|><|begin_of_text|> <|start_header_id|>user<|end_header_id|>

‡¨Ü‡¨™‡¨£ ‡¨ï‡¨ø‡¨è?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"‡¨Æ‡µá'‡≠ü ‡¨§‡¥æ. ‡¨®‡´Å' ‡¨Ø‡¨• ‡¨ö‚Äå‚Äç‚Äã‡®°ÃÅ‡©ú '‡¨Å ‡¨ì ‡¨∏‚Äå‚Äç‚Äã‡®≤; ‡®°‡µç‡¥∞ÃÄ ‡¨∂‚Äå‚Äç‚Äã‡®ñ ‡¨Ö ‡¨ó‚Äå‚Äç ‚Äã‡®πÃå ‡®∏‚Äå‚Äç ‚Äã‚Äã'üôÉ".<|eot_id|>


In [None]:
model.config.use_cache = False
model.config.pretraining_tp = 1
gradient_accumulation_steps = 4
max_steps=500
# Define optimizer
optimizer = PagedAdam32bit(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_steps,
)
# Training loop
model.train()

global_step= 0

while global_step< 500:
    for step,(input_ids, labels) in enumerate(train_dataloader):
        input_ids, labels= input_ids.to('cuda'),labels.to('cuda')
        model.config.use_cache = False
        model.train()
        
        
        # Forward pass
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss
        loss = loss / gradient_accumulation_steps  # Normalize loss
        loss.backward()
        
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        
        global_step += 1
        if global_step >= max_steps:
            break
        
        if global_step % 20 == 0:
            pred = generate_eval()
            print('*'*20,step+1,'*'*20)
            print("Predictions:", pred)
            print('*'*20,'end','*'*20)
            
        
        print(f"Epoch {global_step + 1}/{max_steps}, Loss: {loss.item():.4f}")
        
flush()

Epoch 2/500, Loss: 3.4890


In [16]:
save_path = "/home/nas/buffer/mohan.dash/llama_3_finetuned/model_checkpoint.pt"

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
    'global_step': global_step
}, save_path)

print(f"Checkpoint saved to {save_path}")

Checkpoint saved to /home/nas/buffer/mohan.dash/llama_3_finetuned/model_checkpoint.pt
