In [None]:
%pip install --upgrade transformers accelerate sentencepiece optimum peft bitsandbytes

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from tqdm.auto import tqdm, trange

# Fail fast when no CUDA device is visible: the availability check is evaluated
# once and reused for both the guard and the device selection.
cuda_available = torch.cuda.is_available()
assert cuda_available, "you need cuda for this part"
device = torch.device("cuda") if cuda_available else torch.device("cpu")

In [None]:
model_name = "Enoch/llama-7b-hf"

# Load the Llama tokenizer. Tokenization happens on the CPU, so no device
# placement argument is passed here.
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name)
# Reuse the EOS token as padding so that batches can be padded.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the model with 4-bit quantized weights. The quantization settings go
# through `quantization_config`; passing `load_in_4bit=True` directly to
# `from_pretrained` is deprecated.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    low_cpu_mem_usage=True,
    offload_state_dict=True,
    quantization_config=transformers.BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)

# Freeze every base-model parameter; only parameters added afterwards will be trainable.
for param in model.parameters():
    param.requires_grad = False

model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()  # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad
# more on gradient checkpointing: https://pytorch.org/docs/stable/checkpoint.html https://arxiv.org/abs/1604.06174

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

## Using a LoRA Adapter

Попробуем задачу посложнее: обучим адаптер продолжать текст в виде стихотворения.

In [None]:
from peft import LoraConfig
# Attach LoRA adapters to the attention (q/k/v/o) and MLP (gate/up/down)
# projection layers.
# `init_lora_weights` is left at its default: PEFT then uses the reference
# initialisation under which the adapter starts as a no-op, so training begins
# from the pretrained model's behaviour. Passing `False` would initialise the
# adapter with fully random weights, which PEFT documents as discouraged.
lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"],
)

model.add_adapter(lora_config, adapter_name="adapter_1")

In [None]:
import datasets

def tokenize_function(samples):
    """Tokenize the "poem content" field of a batch of dataset rows.

    Args:
        samples: a batch from the dataset, indexable by column name.

    Returns:
        The tokenizer output for the poems, truncated and padded with the
        "max_length" strategy.
    """
    poems = samples["poem content"]
    encoded = tokenizer(poems, padding="max_length", truncation=True)
    return encoded

# Take the first 200 training poems and tokenize them batch by batch.
data = datasets.load_dataset(
    "Ozziey/poems_dataset",
    split="train[:200]",
).map(tokenize_function, batched=True)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [None]:
# Keep only the tokenizer outputs; every other column is dropped from the dataset.
kept_columns = ("input_ids", "attention_mask")
columns_to_remove = [name for name in data.column_names if name not in kept_columns]
data = data.remove_columns(columns_to_remove)

# Show the remaining schema as a sanity check.
print(data.features)

{'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}


In [None]:
# Mark the model as carrying a PEFT adapter and turn off the KV cache, which is
# not useful during training with gradient checkpointing.
model._hf_peft_config_loaded = True
model.config.use_cache = False

trainer = transformers.Trainer(
    model=model,
    train_dataset=data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        # Warmup has to be shorter than the run itself: with more warmup steps
        # than max_steps the learning rate never reaches `learning_rate`.
        warmup_steps=10,
        max_steps=100,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir="outputs",
        # The string "none" disables logging integrations; the Python value
        # None falls back to the library default instead.
        report_to="none",
    ),
    # mlm=False selects causal language modeling: labels are built from input_ids.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()

max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss
1,2.68
2,1.9606
3,2.5783
4,2.4693
5,3.272
6,2.8933
7,3.1086
8,3.1054
9,3.2241
10,3.0522




TrainOutput(global_step=100, training_loss=2.727244428396225, metrics={'train_runtime': 279.9525, 'train_samples_per_second': 0.714, 'train_steps_per_second': 0.357, 'total_flos': 2321183302631424.0, 'train_loss': 2.727244428396225, 'epoch': 1.0})

In [None]:
# Inference: greedy decoding, one token per forward pass, up to 50 new tokens.
prompt = "What a beautiful"
batch = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to(device)

model.eval()  # the model may still be in training mode after trainer.train()
with torch.no_grad():  # no backward pass is needed, so do not build an autograd graph
    for _ in range(50):
        # Most likely next token given everything generated so far.
        next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
        batch["input_ids"] = torch.cat([batch["input_ids"], next_token], dim=-1)
        batch["attention_mask"] = torch.cat(
            [batch["attention_mask"], torch.ones_like(next_token)], dim=-1
        )
        # Stop once the model emits end-of-sequence instead of decoding past it.
        if next_token.item() == tokenizer.eos_token_id:
            break

print(
    "\nOutput:",
    tokenizer.decode(batch["input_ids"][0].tolist()),
)


Output: <s>What a beautiful day! The sun is shining, the birds are singing, and the flowers are blooming.
The flowers are blooming, the birds are singing, and the sun is shining.
The sun is shining, the birds are singing


Ну, он попытался зарифмовать на -ing!