In [None]:
pip install transformers datasets peft accelerate bitsandbytes

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Configure quantization properly
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
)

In [4]:
model_name = "VietnamAIHub/Vietnamese_llama_30B_SFT" #"vilm/vinallama-12.5b-chat-DUS"  #"vilm/vinallama-7b-chat" 
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device,
    torch_dtype=torch.float16,
).to(device)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="right",
    #use_fast=False, # Fast tokenizer giving issues.
    tokenizer_type='llama', #if 'llama' in args.model_name_or_path else None, # Needed for HF name change
    token=True,
)

tokenizer.bos_token_id = 1
stop_token_ids = [0]

# Set padding_token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.


In [None]:
# Configure LoRA
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

# Get PEFT model
model = get_peft_model(model, peft_config)

In [7]:
dataset = load_dataset("json", data_files="trimmed_data.json", split='train')
train_dataset = dataset.train_test_split(test_size=0.001)["train"]
eval_dataset = dataset.train_test_split(test_size=0.2)["test"]

In [8]:
MAX_LENGTH = 2048  # You can adjust this as per your model's max length

def preprocess(batch):
    inputs = []
    for inp, out in zip(batch["input"], batch["output"]):
        prompt, completion = f"User: {inp}\nAssistant: ", out
        prompt_ids = tokenizer(prompt, add_special_tokens=False)["input_ids"]
        completion_ids = tokenizer(completion, add_special_tokens=False)["input_ids"]

        # Concatenate prompt and completion, ensuring truncation happens
        input_ids = prompt_ids + completion_ids
        input_ids = input_ids[:MAX_LENGTH]  # Ensure it doesn't exceed the max length

        # Set labels (ignoring the prompt part)
        labels = [-100] * len(prompt_ids) + completion_ids
        labels = labels[:MAX_LENGTH]  # Ensure the labels also match the max length

        # Create attention mask
        attention_mask = [1] * len(input_ids)

        inputs.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})

    return {k: [dic[k] for dic in inputs] for k in inputs[0]}

In [9]:
tokenized_train = train_dataset.map(
    preprocess,
    batched=True,
    remove_columns=train_dataset.column_names
)

tokenized_eval = eval_dataset.map(
    preprocess,
    batched=True,
    remove_columns=eval_dataset.column_names
)


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (3623 > 2048). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/21 [00:00<?, ? examples/s]

In [None]:
# Function to pad batches to maximum length
def collator(features):
    print(f"Keys in batch: {features[0].keys()}")  # Add this line to see the keys
    batch = {k: [torch.tensor(f[k][:2048]) for f in features] for k in ["input_ids", "attention_mask", "labels"]}
    padded = {
        k: torch.nn.utils.rnn.pad_sequence(v, batch_first=True, padding_value=(tokenizer.pad_token_id if k != "labels" else -100))
        for k, v in batch.items()
    }
    return padded

train_loader = DataLoader(tokenized_train, shuffle=True, batch_size=8, collate_fn=collator)
eval_loader = DataLoader(tokenized_eval, shuffle=False, batch_size=8, collate_fn=collator)

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

In [None]:
# Training loop
model.train()
for epoch in range(3):
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        batch = {k: v.to(device) for k, v in batch.items()}  # Move batch to device (GPU)
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()
    print(f"Epoch {epoch+1} loss: {total_loss:.4f}")

  scaler = GradScaler()
  with autocast():  # Use mixed precision here
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Keys in batch: dict_keys(['input_ids', 'attention_mask', 'labels'])


Epoch 1:   4%|▍         | 1/25 [00:31<12:25, 31.07s/it]

Keys in batch: dict_keys(['input_ids', 'attention_mask', 'labels'])


Epoch 1:   8%|▊         | 2/25 [00:45<08:06, 21.15s/it]

Keys in batch: dict_keys(['input_ids', 'attention_mask', 'labels'])


Epoch 1:  12%|█▏        | 3/25 [01:02<07:06, 19.41s/it]

Keys in batch: dict_keys(['input_ids', 'attention_mask', 'labels'])


Epoch 1:  16%|█▌        | 4/25 [01:34<08:32, 24.38s/it]

Keys in batch: dict_keys(['input_ids', 'attention_mask', 'labels'])


Epoch 1:  20%|██        | 5/25 [02:07<09:09, 27.49s/it]

Keys in batch: dict_keys(['input_ids', 'attention_mask', 'labels'])


In [None]:
model.save_pretrained("lora_finetuned_model")
tokenizer.save_pretrained("lora_finetuned_model")

In [None]:
from transformers import StoppingCriteriaList, TextIteratorStreamer
from transformers import StoppingCriteria
## Setting Stopping Criteria
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
stop = StopOnTokens()
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

In [None]:
generation_config = dict(
    temperature=0.2,
    top_k=20,
    top_p=0.9,
    do_sample=True,
    num_beams=1,
    repetition_penalty=1.2,
    max_new_tokens=1024, 
    early_stopping=True,
    stopping_criteria=StoppingCriteriaList([stop]),
    streamer=streamer,
)

In [None]:
# Inference
def generate_response(input_prompt):
    model.eval()

    system_prompt=f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### prompt:\n{input_prompt}\n\n### response:\n"

    inputs = tokenizer(system_prompt,return_tensors="pt")  #add_special_tokens=False ?
    input_ids = inputs['input_ids'].to(device)
    
    generation_output = model.generate(
        input_ids=input_ids,
        attention_mask = inputs['attention_mask'].to(device),
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        **generation_config
    )
    
    s = generation_output[0]
    output = tokenizer.decode(s,skip_special_tokens=True)
    response = output.split("### response:")[1].strip()
    print(response)
    
# Interactive loop
if __name__ == "__main__":
    while True:
        query = input("Enter your query (type ' exit' to quit)")
        if query.lower() == 'exit': break
        print(f'\nResponse: {generate_response(query)}')