In [1]:
import numpy as np
import pandas as pd
import locale
import json
import re
import pymorphy2
import torch
import accelerate
from datetime import datetime
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from bs4 import BeautifulSoup
from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import T5Tokenizer
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Loading JSON-dataset to dictionary
with open('/Users/alexanderknyshov/Desktop/LLM/Data/datasets/train_set.json', 'r') as json_file:
    data = json.load(json_file)

In [3]:
def format_text(text):
    news_text = "form a recommendation using news: " + " ".join(text['news'])
    recommendation_text = " ".join(text['recommendations'])
    return {'input_text': news_text, 'target_text': recommendation_text}

# Formatting original dataset
dataset = Dataset.from_dict(data)
dataset = dataset.map(format_text)
# Removing unneccessery columns
dataset = dataset.remove_columns(['news', 'recommendations'])

Map: 100%|██████████| 5177/5177 [00:02<00:00, 1979.81 examples/s]


In [4]:
# Splitting into train and validate data
split_ratio = 0.8
split_index = int(len(data['news']) * split_ratio)

train_news = data['news'][:split_index]
train_recommendations = data['recommendations'][:split_index]
train_data = {'news': train_news, 'recommendations': train_recommendations}

val_news = data['news'][split_index:]
val_recommendations = data['recommendations'][split_index:]
val_data = {'news': val_news, 'recommendations': val_recommendations}

# Returning to the format needed
train_dataset = Dataset.from_dict(train_data)
train_dataset = train_dataset.map(format_text)
train_dataset = train_dataset.remove_columns(['news', 'recommendations'])

val_dataset = Dataset.from_dict(val_data)
val_dataset = val_dataset.map(format_text)
val_dataset = val_dataset.remove_columns(['news', 'recommendations'])

Map: 100%|██████████| 4141/4141 [00:02<00:00, 1895.00 examples/s]
Map: 100%|██████████| 1036/1036 [00:00<00:00, 2813.20 examples/s]


In [5]:
# Loading the tokenizer
tokenizer = T5Tokenizer.from_pretrained('sberbank-ai/ruT5-base')

# Method for data tokenization
def tokenize_data(example):
    input_encodings = tokenizer(example['input_text'], padding='max_length', truncation=True, max_length=512)
    target_encodings = tokenizer(example['target_text'], padding='max_length', truncation=True, max_length=256)

    example['input_ids'] = input_encodings['input_ids']
    example['attention_mask'] = input_encodings['attention_mask']
    example['labels'] = target_encodings['input_ids']
    return example

# Train and validation data tokenization
tokenized_train_dataset = train_dataset.map(tokenize_data)
tokenized_val_dataset = val_dataset.map(tokenize_data)

# Getting rid of text columns
tokenized_train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
tokenized_val_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Map: 100%|██████████| 4141/4141 [00:11<00:00, 347.74 examples/s]
Map: 100%|██████████| 1036/1036 [00:01<00:00, 534.99 examples/s]


In [6]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [7]:
lora_config = LoraConfig(
    r=16,                       # Ранг, может быть изменен в зависимости от задачи
    lora_alpha=32,             # Гиперпараметр, который можно настроить
    lora_dropout=0.1,          # Вероятность дропаута для LoRA
    target_modules=["q", "v"]  # Модули, которые будут адаптированы
)

In [8]:
# Model Loading
model = T5ForConditionalGeneration.from_pretrained('sberbank-ai/ruT5-base').to(device)
model = get_peft_model(model, lora_config)

# Defining training params
training_args = TrainingArguments(
    output_dir= '/Users/alexanderknyshov/Desktop/LLM/Data/model',
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    evaluation_strategy="epoch",
    save_strategy="epoch",   
    logging_dir='/Users/alexanderknyshov/Desktop/LLM/Data/model/logs',
    logging_steps=10,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss'
)

# Metrics count
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    loss = sum([1 if pred == label else 0 for pred, label in zip(decoded_preds, decoded_labels)]) / len(decoded_preds)
    return {'accuracy': loss}

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)



In [9]:
trainer.train()

# Evaluation
eval_results = trainer.evaluate()

# Results
print(f"Validation results: {eval_results}")

  0%|          | 10/4141 [00:04<22:29,  3.06it/s] 

{'loss': 13.2212, 'grad_norm': 9.34701919555664, 'learning_rate': 4.987925621830476e-05, 'epoch': 0.0}


  0%|          | 20/4141 [00:07<21:31,  3.19it/s]

{'loss': 11.9242, 'grad_norm': 12.915538787841797, 'learning_rate': 4.975851243660952e-05, 'epoch': 0.0}


  1%|          | 30/4141 [00:11<23:27,  2.92it/s]

{'loss': 11.0967, 'grad_norm': 7.858455181121826, 'learning_rate': 4.9637768654914274e-05, 'epoch': 0.01}


  1%|          | 40/4141 [00:14<21:32,  3.17it/s]

{'loss': 10.9077, 'grad_norm': 70.02967071533203, 'learning_rate': 4.951702487321903e-05, 'epoch': 0.01}


  1%|          | 50/4141 [00:17<21:26,  3.18it/s]

{'loss': 9.0211, 'grad_norm': 7.83552360534668, 'learning_rate': 4.939628109152379e-05, 'epoch': 0.01}


  1%|▏         | 60/4141 [00:20<21:25,  3.18it/s]

{'loss': 9.1765, 'grad_norm': 7.433175563812256, 'learning_rate': 4.9275537309828546e-05, 'epoch': 0.01}


  2%|▏         | 70/4141 [00:23<21:19,  3.18it/s]

{'loss': 8.2973, 'grad_norm': 3.409534454345703, 'learning_rate': 4.91547935281333e-05, 'epoch': 0.02}


  2%|▏         | 80/4141 [00:26<21:34,  3.14it/s]

{'loss': 8.7944, 'grad_norm': 5.378474235534668, 'learning_rate': 4.903404974643806e-05, 'epoch': 0.02}


  2%|▏         | 90/4141 [00:30<21:14,  3.18it/s]

{'loss': 8.46, 'grad_norm': 4.407718658447266, 'learning_rate': 4.891330596474282e-05, 'epoch': 0.02}


  2%|▏         | 100/4141 [00:33<21:13,  3.17it/s]

{'loss': 7.6352, 'grad_norm': 16.936677932739258, 'learning_rate': 4.8792562183047574e-05, 'epoch': 0.02}


  3%|▎         | 110/4141 [00:36<21:27,  3.13it/s]

{'loss': 7.74, 'grad_norm': 4.328479290008545, 'learning_rate': 4.867181840135233e-05, 'epoch': 0.03}


  3%|▎         | 120/4141 [00:39<21:28,  3.12it/s]

{'loss': 7.6915, 'grad_norm': 4.855695724487305, 'learning_rate': 4.855107461965709e-05, 'epoch': 0.03}


  3%|▎         | 130/4141 [00:42<21:08,  3.16it/s]

{'loss': 7.4072, 'grad_norm': 5.778942584991455, 'learning_rate': 4.8430330837961846e-05, 'epoch': 0.03}


  3%|▎         | 140/4141 [00:45<21:12,  3.14it/s]

{'loss': 7.4922, 'grad_norm': 2.444490671157837, 'learning_rate': 4.83095870562666e-05, 'epoch': 0.03}


  4%|▎         | 150/4141 [00:49<21:40,  3.07it/s]

{'loss': 7.3122, 'grad_norm': 4.498382568359375, 'learning_rate': 4.818884327457136e-05, 'epoch': 0.04}


  4%|▍         | 160/4141 [00:52<21:05,  3.15it/s]

{'loss': 7.185, 'grad_norm': 3.8250176906585693, 'learning_rate': 4.806809949287612e-05, 'epoch': 0.04}


  4%|▍         | 170/4141 [00:55<21:03,  3.14it/s]

{'loss': 7.2767, 'grad_norm': 3.553750514984131, 'learning_rate': 4.7947355711180875e-05, 'epoch': 0.04}


  4%|▍         | 180/4141 [00:58<20:59,  3.14it/s]

{'loss': 7.0485, 'grad_norm': 2.3023831844329834, 'learning_rate': 4.782661192948563e-05, 'epoch': 0.04}


  5%|▍         | 190/4141 [01:01<20:56,  3.14it/s]

{'loss': 7.1877, 'grad_norm': 3.0321133136749268, 'learning_rate': 4.770586814779039e-05, 'epoch': 0.05}


  5%|▍         | 200/4141 [01:05<20:51,  3.15it/s]

{'loss': 6.881, 'grad_norm': 5.3091278076171875, 'learning_rate': 4.7585124366095146e-05, 'epoch': 0.05}


  5%|▌         | 210/4141 [01:08<20:52,  3.14it/s]

{'loss': 6.9186, 'grad_norm': 4.050538063049316, 'learning_rate': 4.7464380584399904e-05, 'epoch': 0.05}


  5%|▌         | 220/4141 [01:11<21:04,  3.10it/s]

{'loss': 6.936, 'grad_norm': 2.3444883823394775, 'learning_rate': 4.734363680270466e-05, 'epoch': 0.05}


  6%|▌         | 230/4141 [01:14<21:27,  3.04it/s]

{'loss': 7.0162, 'grad_norm': 3.450909376144409, 'learning_rate': 4.722289302100942e-05, 'epoch': 0.06}


  6%|▌         | 240/4141 [01:17<20:43,  3.14it/s]

{'loss': 7.0991, 'grad_norm': 4.3455376625061035, 'learning_rate': 4.7102149239314175e-05, 'epoch': 0.06}


  6%|▌         | 250/4141 [01:21<20:41,  3.13it/s]

{'loss': 7.262, 'grad_norm': 26.057249069213867, 'learning_rate': 4.698140545761893e-05, 'epoch': 0.06}


  6%|▋         | 260/4141 [01:24<20:23,  3.17it/s]

{'loss': 6.7877, 'grad_norm': 9.555444717407227, 'learning_rate': 4.6860661675923696e-05, 'epoch': 0.06}


  7%|▋         | 270/4141 [01:27<20:34,  3.14it/s]

{'loss': 6.8717, 'grad_norm': 10.79602336883545, 'learning_rate': 4.6739917894228454e-05, 'epoch': 0.07}


  7%|▋         | 280/4141 [01:30<20:20,  3.16it/s]

{'loss': 7.0014, 'grad_norm': 19.431865692138672, 'learning_rate': 4.661917411253321e-05, 'epoch': 0.07}


  7%|▋         | 290/4141 [01:33<20:40,  3.10it/s]

{'loss': 6.5196, 'grad_norm': 5.331852912902832, 'learning_rate': 4.649843033083797e-05, 'epoch': 0.07}


  7%|▋         | 300/4141 [01:37<20:15,  3.16it/s]

{'loss': 6.9831, 'grad_norm': 2.7903618812561035, 'learning_rate': 4.6377686549142725e-05, 'epoch': 0.07}


  7%|▋         | 310/4141 [01:40<20:10,  3.17it/s]

{'loss': 6.6821, 'grad_norm': 3.0141243934631348, 'learning_rate': 4.625694276744748e-05, 'epoch': 0.07}


  8%|▊         | 320/4141 [01:43<20:04,  3.17it/s]

{'loss': 7.1937, 'grad_norm': 7.411855697631836, 'learning_rate': 4.613619898575224e-05, 'epoch': 0.08}


  8%|▊         | 330/4141 [01:46<20:32,  3.09it/s]

{'loss': 6.6471, 'grad_norm': 3.27835750579834, 'learning_rate': 4.6015455204057e-05, 'epoch': 0.08}


  8%|▊         | 340/4141 [01:49<19:39,  3.22it/s]

{'loss': 6.5327, 'grad_norm': 6.357583045959473, 'learning_rate': 4.5894711422361754e-05, 'epoch': 0.08}


  8%|▊         | 350/4141 [01:52<19:31,  3.24it/s]

{'loss': 6.7213, 'grad_norm': 3.556504726409912, 'learning_rate': 4.577396764066651e-05, 'epoch': 0.08}


  9%|▊         | 360/4141 [01:55<19:32,  3.23it/s]

{'loss': 6.5183, 'grad_norm': 2.9715468883514404, 'learning_rate': 4.565322385897127e-05, 'epoch': 0.09}


  9%|▉         | 370/4141 [01:59<19:26,  3.23it/s]

{'loss': 6.7074, 'grad_norm': 4.666072845458984, 'learning_rate': 4.5532480077276025e-05, 'epoch': 0.09}


  9%|▉         | 380/4141 [02:02<19:26,  3.23it/s]

{'loss': 6.5648, 'grad_norm': 2.924010753631592, 'learning_rate': 4.541173629558078e-05, 'epoch': 0.09}


  9%|▉         | 390/4141 [02:05<20:01,  3.12it/s]

{'loss': 6.8076, 'grad_norm': 3.6303629875183105, 'learning_rate': 4.529099251388554e-05, 'epoch': 0.09}


 10%|▉         | 400/4141 [02:08<20:17,  3.07it/s]

{'loss': 6.5856, 'grad_norm': 2.9167723655700684, 'learning_rate': 4.51702487321903e-05, 'epoch': 0.1}


 10%|▉         | 410/4141 [02:11<19:33,  3.18it/s]

{'loss': 6.7109, 'grad_norm': 3.379377603530884, 'learning_rate': 4.5049504950495054e-05, 'epoch': 0.1}


 10%|█         | 420/4141 [02:14<19:19,  3.21it/s]

{'loss': 6.6077, 'grad_norm': 3.2724320888519287, 'learning_rate': 4.492876116879981e-05, 'epoch': 0.1}


 10%|█         | 430/4141 [02:18<19:18,  3.20it/s]

{'loss': 6.5727, 'grad_norm': 5.208462715148926, 'learning_rate': 4.480801738710457e-05, 'epoch': 0.1}


 11%|█         | 440/4141 [02:21<19:22,  3.18it/s]

{'loss': 7.6085, 'grad_norm': 3.839852809906006, 'learning_rate': 4.4687273605409326e-05, 'epoch': 0.11}


 11%|█         | 450/4141 [02:24<19:03,  3.23it/s]

{'loss': 6.4698, 'grad_norm': 4.413749694824219, 'learning_rate': 4.456652982371408e-05, 'epoch': 0.11}


 11%|█         | 460/4141 [02:27<18:56,  3.24it/s]

{'loss': 6.3255, 'grad_norm': 9.97079849243164, 'learning_rate': 4.444578604201884e-05, 'epoch': 0.11}


 11%|█▏        | 470/4141 [02:30<19:10,  3.19it/s]

{'loss': 6.5498, 'grad_norm': 39.67104721069336, 'learning_rate': 4.43250422603236e-05, 'epoch': 0.11}


 12%|█▏        | 480/4141 [02:33<18:53,  3.23it/s]

{'loss': 6.4435, 'grad_norm': 2.8942337036132812, 'learning_rate': 4.4204298478628355e-05, 'epoch': 0.12}


 12%|█▏        | 490/4141 [02:36<19:07,  3.18it/s]

{'loss': 6.2728, 'grad_norm': 4.543673038482666, 'learning_rate': 4.408355469693311e-05, 'epoch': 0.12}


 12%|█▏        | 500/4141 [02:39<18:55,  3.21it/s]

{'loss': 6.4406, 'grad_norm': 3.615333080291748, 'learning_rate': 4.396281091523787e-05, 'epoch': 0.12}


 12%|█▏        | 510/4141 [02:42<18:50,  3.21it/s]

{'loss': 6.6857, 'grad_norm': 3.436436891555786, 'learning_rate': 4.3842067133542626e-05, 'epoch': 0.12}


 13%|█▎        | 520/4141 [02:46<19:56,  3.03it/s]

{'loss': 6.4739, 'grad_norm': 3.48868989944458, 'learning_rate': 4.3721323351847383e-05, 'epoch': 0.13}


 13%|█▎        | 530/4141 [02:49<19:43,  3.05it/s]

{'loss': 6.2891, 'grad_norm': 2.474797248840332, 'learning_rate': 4.360057957015214e-05, 'epoch': 0.13}


 13%|█▎        | 540/4141 [02:52<19:53,  3.02it/s]

{'loss': 6.3552, 'grad_norm': 5.442928791046143, 'learning_rate': 4.34798357884569e-05, 'epoch': 0.13}


 13%|█▎        | 550/4141 [02:56<20:06,  2.98it/s]

{'loss': 6.6264, 'grad_norm': 4.297702789306641, 'learning_rate': 4.3359092006761655e-05, 'epoch': 0.13}


 14%|█▎        | 560/4141 [02:59<18:59,  3.14it/s]

{'loss': 6.367, 'grad_norm': 4.353188514709473, 'learning_rate': 4.323834822506641e-05, 'epoch': 0.14}


 14%|█▍        | 570/4141 [03:02<18:55,  3.15it/s]

{'loss': 6.5258, 'grad_norm': 5.444392204284668, 'learning_rate': 4.311760444337117e-05, 'epoch': 0.14}


 14%|█▍        | 580/4141 [03:05<18:50,  3.15it/s]

{'loss': 6.442, 'grad_norm': 4.085636615753174, 'learning_rate': 4.299686066167593e-05, 'epoch': 0.14}


 14%|█▍        | 590/4141 [03:08<18:43,  3.16it/s]

{'loss': 6.8933, 'grad_norm': 3.9689199924468994, 'learning_rate': 4.2876116879980684e-05, 'epoch': 0.14}


 14%|█▍        | 600/4141 [03:11<18:36,  3.17it/s]

{'loss': 6.3358, 'grad_norm': 4.174694061279297, 'learning_rate': 4.275537309828544e-05, 'epoch': 0.14}


 15%|█▍        | 610/4141 [03:15<18:33,  3.17it/s]

{'loss': 6.2582, 'grad_norm': 3.533534288406372, 'learning_rate': 4.26346293165902e-05, 'epoch': 0.15}


 15%|█▍        | 620/4141 [03:18<18:40,  3.14it/s]

{'loss': 6.3145, 'grad_norm': 3.0429177284240723, 'learning_rate': 4.2513885534894955e-05, 'epoch': 0.15}


 15%|█▌        | 630/4141 [03:21<18:39,  3.14it/s]

{'loss': 6.3944, 'grad_norm': 3.7604923248291016, 'learning_rate': 4.239314175319971e-05, 'epoch': 0.15}


 15%|█▌        | 640/4141 [03:24<19:34,  2.98it/s]

{'loss': 6.2128, 'grad_norm': 4.235511779785156, 'learning_rate': 4.227239797150447e-05, 'epoch': 0.15}


 16%|█▌        | 643/4141 [03:25<19:27,  3.00it/s]

KeyboardInterrupt: 