<a href="https://colab.research.google.com/github/imthegoodboy/LoRa_/blob/main/LoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install transformers datasets peft accelerate bitsandbytes




In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, PeftModel
from datasets import Dataset
import torch


In [12]:
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 doesn't have a pad token by default, so we add one:
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)


In [8]:
train_data = [
    {
        "prompt": "User: Hi, my WaveGuard device is overheating.\nSupport:",
        "response": "Hello! I’m sorry you’re facing this issue. Please place the device on a hard surface, ensure the vents are not blocked, and restart it. If it still overheats, I can help you file a warranty request."
    },
    {
        "prompt": "User: My WaveGuard app is stuck on loading.\nSupport:",
        "response": "No worries, I can help with that. First, try closing and reopening the app. If the issue persists, clear the app cache or reinstall it. Let me know what you see after that."
    },
    {
        "prompt": "User: Can I use WaveGuard while traveling?\nSupport:",
        "response": "Yes, you can use WaveGuard while traveling as long as you have an internet connection. For best performance, make sure your connection is stable and your device firmware is updated."
    },
]

def format_example(example):
    # We join prompt + response as one text for causal LM training
    return {
        "text": example["prompt"] + " " + example["response"] + tokenizer.eos_token
    }

formatted_data = [format_example(ex) for ex in train_data]
dataset = Dataset.from_list(formatted_data)


In [19]:
display(dataset)

Dataset({
    features: ['text'],
    num_rows: 3
})

In [20]:
def tokenize_function(example):
    return tokenizer(
        example["text"],
        truncation=True,
        max_length=256,
        padding="max_length",
    )

tokenized_dataset = dataset.map(tokenize_function, batched=False)


Map:   0%|          | 0/3 [00:00<?, ? examples/s]

In [21]:
tokenized_dataset

Dataset({
    features: ['text', 'input_ids', 'attention_mask'],
    num_rows: 3
})

In [22]:
print(tokenized_dataset[0]['input_ids'])
print(tokenizer.decode(tokenized_dataset[0]['input_ids']))

[12982, 25, 15902, 11, 616, 17084, 24502, 3335, 318, 34789, 803, 13, 198, 15514, 25, 18435, 0, 314, 447, 247, 76, 7926, 345, 447, 247, 260, 6476, 428, 2071, 13, 4222, 1295, 262, 3335, 319, 257, 1327, 4417, 11, 4155, 262, 42777, 389, 407, 10226, 11, 290, 15765, 340, 13, 1002, 340, 991, 34789, 1381, 11, 314, 460, 1037, 345, 2393, 257, 18215, 2581, 13, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 5025

In [23]:
lora_config = LoraConfig(
    r=8,                      # rank (how big the LoRA adapter is)
    lora_alpha=16,            # scaling factor
    lora_dropout=0.05,        # dropout on LoRA
    bias="none",
    task_type="CAUSAL_LM",    # type of problem (LLM)
    # optional: which modules in the model get LoRA
    # for GPT2, modules with attention usually contain "c_attn"
    target_modules=["c_attn"]
)


In [24]:
lora_model = get_peft_model(base_model, lora_config)

lora_model.print_trainable_parameters()


trainable params: 294,912 || all params: 124,734,720 || trainable%: 0.2364




In [25]:
training_args = TrainingArguments(
    output_dir="./lora-gpt2-waveguard",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=1,
    save_steps=10,
    save_total_limit=2,
)


In [26]:
def data_collator(features):
    # Default causal LM data collator
    batch = {
        "input_ids": torch.stack([torch.tensor(f["input_ids"]) for f in features]),
        "attention_mask": torch.stack([torch.tensor(f["attention_mask"]) for f in features]),
    }
    # For causal LM, labels = input_ids (predict next token)
    batch["labels"] = batch["input_ids"].clone()
    return batch

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()


The model is already on multiple devices. Skipping the move to device specified in `args`.
  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnikku696969[0m ([33mnikku696969-bms-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
1,9.6887
2,9.4962
3,9.8124


TrainOutput(global_step=3, training_loss=9.665789286295572, metrics={'train_runtime': 187.4308, 'train_samples_per_second': 0.048, 'train_steps_per_second': 0.016, 'total_flos': 1179891007488.0, 'train_loss': 9.665789286295572, 'epoch': 3.0})

In [27]:
lora_model.save_pretrained("./lora-gpt2-waveguard-adapter")
tokenizer.save_pretrained("./lora-gpt2-waveguard-adapter")


('./lora-gpt2-waveguard-adapter/tokenizer_config.json',
 './lora-gpt2-waveguard-adapter/special_tokens_map.json',
 './lora-gpt2-waveguard-adapter/vocab.json',
 './lora-gpt2-waveguard-adapter/merges.txt',
 './lora-gpt2-waveguard-adapter/added_tokens.json',
 './lora-gpt2-waveguard-adapter/tokenizer.json')

In [28]:
## TESTING THE FINE TUNE MODEL

In [None]:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.float16,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained("./lora-gpt2-waveguard-adapter")
tokenizer.pad_token = tokenizer.eos_token

lora_model = PeftModel.from_pretrained(base_model, "./lora-gpt2-waveguard-adapter")
lora_model.eval()


In [29]:
prompt = "User: My WaveGuard device keeps disconnecting from Wi-Fi.\nSupport:"
inputs = tokenizer(prompt, return_tensors="pt").to(lora_model.device)

with torch.no_grad():
    outputs = lora_model.generate(
        **inputs,
        max_length=200,
        do_sample=True,
        top_p=0.9,
        temperature=0.8
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


User: My WaveGuard device keeps disconnecting from Wi-Fi.
Support: I've been unable to connect to the web.
No WiFi support on my phone.
Yes, the device can't communicate with the web.
Yes, it's not an option, I have no internet.

Yes, the device has not been connected to the web.
No, the phone is not an issue.
Yes, it's not an issue.
I'm not sure, it's a problem.
It's not an issue.
I didn't even know this was an issue.
It's not an issue.
I'm not sure, it's not an issue.
Yes, it's not an issue.
It's not an issue.
Yes, it's not an issue.
It's not an issue.
No, it's not an issue.
I'm not sure, it's not an issue.
Yes, the phone
