In [None]:
import os
# Expose only GPU 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
! pip install peft
! pip install jsonlines
! pip install accelerate
! pip install bitsandbytes

In [None]:
# Run nvidia-smi to show the GPU state before the model is loaded.
! nvidia-smi

# Загружаем модель и токенизатор

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_path = "openlm-research/open_llama_3b_v2"

# The OpenLLaMA model card advises against the auto-converted fast tokenizer
# because it can produce incorrect tokenizations, so request the slow one.
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# 4-bit NF4 quantization with double quantization; compute dtype is fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True
)

# Load the causal LM with the quantization config above.
# use_cache=False is stored in the model config (the model is trained with
# gradient checkpointing later in the notebook).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    use_cache=False,
    torch_dtype=torch.float16,
    quantization_config=bnb_config
)

In [None]:
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training

# LoRA adapter settings for a causal LM: rank 1, applied to the modules named
# 'q_proj' and 'v_proj', with lora_dropout of 0.05.
lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=1,
    target_modules=['q_proj', 'v_proj'],
    lora_dropout=0.05
)

# Get the quantized base model ready for training: turn on gradient
# checkpointing, run PEFT's k-bit preparation helper, and make the inputs
# require grads.
# NOTE(review): prepare_model_for_kbit_training may already perform the two
# surrounding steps by default — confirm against the installed peft version.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
model.enable_input_require_grads()

# Wrap the base model with the adapters described by lora_config.
model = get_peft_model(model, lora_config)

# Print the trainable-parameter summary of the wrapped model.
model.print_trainable_parameters()

In [None]:
# Report the parameter count returned by the model.
print('Number of parameters: {}'.format(model.num_parameters()))

# Попробуем что-нибудь сгенерировать

In [None]:
# Put the model in evaluation mode, then move it to the CUDA device.
model.eval()
model.cuda()

In [None]:
# Switch the model to training mode.
# NOTE(review): this runs before the generation cell, so sampling there sees
# training-mode layers unless that cell sets eval mode itself — confirm the
# intended cell order.
model.train()

In [None]:
%%time
from transformers import GenerationConfig

prompt = '### Вопрос: Как приготовить суп?\n\n### Ответ:'

# Tokenize once and remember the prompt length so it can be cut off the output.
tokens = tokenizer(prompt, return_tensors='pt')
prompt_length = tokens['input_ids'].shape[1]

# Sample in evaluation mode so training-only behaviour (e.g. the LoRA dropout
# configured above) is off, then restore whatever mode the model was in.
was_training = model.training
model.eval()
try:
    outputs = model.generate(
        inputs=tokens['input_ids'].cuda(),
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=tokens['attention_mask'].cuda(),
        generation_config=GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            temperature=0.5,
            top_k=40,
            top_p=0.8
        )
    )
finally:
    model.train(was_training)

# Decode only the newly generated tokens; drop special tokens such as EOS.
print(tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip())

# Готовим датасет для обучения и валидации

In [None]:
from datasets import load_dataset
# Load the 'IlyaGusev/ru_turbo_alpaca' dataset via the `datasets` library.
dataset = load_dataset('IlyaGusev/ru_turbo_alpaca')

In [None]:
class MyDataset:
    """Map-style view over instruction records.

    Each item is a dict with two keys:
      * 'instruction' - the record's 'instruction' and 'input' fields joined
        by a newline, with surrounding whitespace stripped;
      * 'output' - the record's 'output' field, stripped.
    """

    def __init__(self, data):
        # Indexable collection of records with 'instruction', 'input', 'output'.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        question = '\n'.join((record['instruction'], record['input'])).strip()
        answer = record['output'].strip()
        return dict(instruction=question, output=answer)

In [None]:
def collate_fn(data, tok=None):
    """Collate instruction/output records into a padded causal-LM batch.

    Args:
        data: sequence of dicts with string fields 'instruction' and 'output'.
        tok: tokenizer to use. Defaults to the notebook-level ``tokenizer``.

    Returns:
        dict with
          * 'input_ids': prompt + answer token ids, right-padded with 0;
          * 'labels': -100 over the prompt and the padding, answer ids elsewhere;
          * 'attention_mask': bool tensor, True on real tokens and False on padding.
    """
    tok = tokenizer if tok is None else tok
    sequences, targets = [], []

    for x in data:
        prompt = f'### Вопрос: {x["instruction"]}\n\n### Ответ:'
        prompt_ids = tok(
            prompt,
            add_special_tokens=True
        )['input_ids']
        # Only the answer is length-limited; EOS is appended as text before tokenizing.
        answer_ids = tok(
            x['output'] + tok.eos_token,
            add_special_tokens=False,
            max_length=512,
            truncation=True
        )['input_ids']
        sequences.append(torch.tensor(prompt_ids + answer_ids))
        # The loss is computed on the answer only: prompt positions get -100.
        targets.append(torch.tensor([-100] * len(prompt_ids) + answer_ids))

    input_ids = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True, padding_value=-100)

    # Build the mask from the true sequence lengths. Comparing input_ids with
    # the pad value would also mask out any real token whose id happens to be 0.
    lengths = torch.tensor([len(seq) for seq in sequences])
    attention_mask = torch.arange(input_ids.size(1)).unsqueeze(0) < lengths.unsqueeze(1)

    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': attention_mask
    }

In [None]:
# First 128 rows of the 'train' split for training, the next 64 for evaluation.
train_dataset, eval_dataset = (
    MyDataset([dataset['train'][i] for i in rows])
    for rows in (range(128), range(128, 192))
)

In [None]:
from transformers import Trainer, TrainingArguments

In [None]:
train_args = TrainingArguments(
    # Output location; checkpoint saving and external reporting are disabled.
    output_dir='./output',
    save_strategy="no",
    report_to="none",
    # Optimisation schedule: one epoch, no warmup.
    num_train_epochs=1,
    learning_rate=5e-4,
    warmup_ratio=0.0,
    # Batch sizes per device.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Log every step and evaluate every 8 steps.
    logging_steps=1,
    evaluation_strategy="steps",
    eval_steps=8,
    # Keep the raw dataset fields so the custom collator receives them.
    remove_unused_columns=False,
    gradient_checkpointing=True,
)

In [None]:
# Assemble the Trainer from the model, arguments, datasets and collator defined above.
trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

In [None]:
# Run nvidia-smi again to show the GPU state right before training starts.
! nvidia-smi

In [None]:
# Run the training loop configured on `trainer`.
trainer.train()