# Finetuning

Based on https://www.youtube.com/watch?v=CbmTFTsbyPI

https://huggingface.co/openlm-research/open_llama_3b_v2

Need to install the transformers library to run this notebook:
```bash
conda install pytorch-nightly::pytorch torchvision torchaudio -c pytorch-nightly
conda install -c huggingface transformers  
conda install -c conda-forge sentencepiece peft accelerate 
```

In [1]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch
import random

random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x122e63d70>

In [2]:
device_name = 'cpu'
if torch.cuda.is_available():
	device_name = 'cuda' 
elif torch.backends.mps.is_available():
	device_name = 'mps'
#device_name = 'cpu'
print(f'using device: {device_name}')
device=torch.device(device_name)

using device: cpu


In [3]:
from transformers import LlamaTokenizer, LlamaForCausalLM
model_path = 'openlm-research/open_llama_3b_v2'
tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True)
base_model = LlamaForCausalLM.from_pretrained(model_path)

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
from peft import LoraConfig, PeftModel
lora_config = LoraConfig(
	r=64,
	lora_alpha=32,
	lora_dropout=0.05,
	bias="none",
	task_type="CAUSAL_LM",
)
model = PeftModel(base_model, lora_config, adapter_name="Shakespeare")


model.to(device)

PeftModel(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 3200, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=3200, out_features=3200, bias=False)
                (lora_dropout): ModuleDict(
                  (Shakespeare): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (Shakespeare): Linear(in_features=3200, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (Shakespeare): Linear(in_features=64, out_features=3200, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear(in_features=3200, out_features=3200, bi

In [13]:
import os
import requests
file_name = "input.txt"

from transformers import TextDataset
train_dataset = TextDataset(
	tokenizer=tokenizer,
	file_path=file_name,
	block_size=128,
)[:256]



In [14]:
from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
	output_dir="./output",
	overwrite_output_dir=True,
	num_train_epochs=10,
	per_device_train_batch_size=32,
	eval_strategy="no",
)

In [15]:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(
	tokenizer=tokenizer, mlm=False,
)
trainer = Trainer(
	model=model,
	args=training_args,
	data_collator=data_collator,
	train_dataset=train_dataset,
)

In [16]:
def generate_response(prompt_text, model, tokenizer, max_length=30, num_return_sequences=1):
	input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)
	# Don't have mps support at this time
	# input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to('cpu')
	outputs_sequences = model.generate(
		input_ids=input_ids,
		max_length=max_length,
		num_return_sequences=num_return_sequences,
		no_repeat_ngram_size=2,
	)

	responses = []
	for response_id in outputs_sequences:
		response = tokenizer.decode(response_id, skip_special_tokens=True)
		responses.append(response)
	return responses

In [17]:
prompt_text = "Uneasy lies the head that wears a crown"

responses = generate_response(prompt_text, model, tokenizer)
for response in responses:
	print(response)

Uneasy lies the head that wears a crown.
- William Shakespeare
The head of the family is the most important person in the world


In [18]:
trainer.train()

  0%|          | 0/80 [00:00<?, ?it/s]

RuntimeError: Placeholder storage has not been allocated on MPS device!