## Сбор и подготовка данных

In [1]:
import pandas as pd
import os
import torch
from transformers import GPT2Tokenizer
from transformers import pipeline
from torch.utils.data import DataLoader
import sys
sys.path.append(os.path.abspath('src'))

  from .autonotebook import tqdm as notebook_tqdm


### очистка текста

In [None]:
# Text cleaning and train/val/test split, both implemented in the local
# `preprocess` module (found via the `src` directory added to sys.path above).
# A plain `import preprocess` only binds the module name, so the two steps are
# called through the module; calling them as bare names raises NameError on a
# fresh kernel.
import preprocess

preprocess.prepare_dataset()
preprocess.split_dataset()

train: 1277164, val: 159646, test: 159646


### токенизация, разбиение на тренировочный, валидационный и тестовый датасеты

In [2]:
import torch
import random
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
from tqdm import tqdm
from next_token_dataset import *

# Load the pretrained BERT tokenizer.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Read the "text" column of the train and validation CSV files as Python lists.
train_texts = pd.read_csv("data/train.csv")['text'].tolist()
val_texts = pd.read_csv("data/val.csv")['text'].tolist()

# Training and validation datasets built by NextTokenDataset with max_length=20.
train_dataset = NextTokenDataset(train_texts, tokenizer, max_length=20)
val_dataset = NextTokenDataset(val_texts, tokenizer, max_length=20)


def pad_batch(batch):
    """Collate one batch by delegating to collate_fn_pad_sequence with the notebook tokenizer."""
    return collate_fn_pad_sequence(batch, tokenizer)


# Data loaders: only the training loader shuffles its samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=pad_batch,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    collate_fn=pad_batch,
)


## объявление модели

In [4]:
from lstm_train import *


Объявлена модель на основе архитектуры LSTM со следующими параметрами:
embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.3

## обучение LSTM модели

In [None]:
# Training hyperparameters are named once so that the values passed to
# train_model and the metadata written into the checkpoint cannot drift apart.
EPOCHS = 10
LEARNING_RATE = 0.001
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

vocab_size = tokenizer.vocab_size

# Declare the model.
model = LSTMTextGenerator(
    vocab_size=vocab_size,
    embed_dim=128,
    hidden_dim=256,
    num_layers=2,
    dropout=0.3
)

# Training; the two returned sequences are indexed with [-1] below to obtain
# the final training and validation losses.
train_losses, val_losses = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    tokenizer=tokenizer,
    epochs=EPOCHS,
    lr=LEARNING_RATE,
    device=DEVICE
)

# Saving.
# NOTE(review): train_model returns only the two loss sequences, so the
# optimizer it used is not available in this cell. The Adam instance passed
# below is newly constructed and carries no accumulated state; whatever
# optimizer state save_checkpoint stores should not be relied on for resuming
# training. TODO: have train_model return (or save) its own optimizer.
save_checkpoint(
    model=model,
    optimizer=optim.Adam(model.parameters(), lr=LEARNING_RATE),
    epoch=EPOCHS,
    train_loss=train_losses[-1],
    val_loss=val_losses[-1],
    filename='lstm_text_generator_final.pt'
)

Модель обучена в течение 10 эпох. Финальные метрики обучения:  Train Loss: 5.3009, Val Loss: 5.1672

## запуск LSTM-модели на тестовых данных

In [None]:
import pandas as pd
import torch
from transformers import BertTokenizerFast
from torch.utils.data import Dataset, DataLoader
from eval_lstm import *
# Re-create the tokenizer used for training. The BERT tokenizer already ships
# with a dedicated "[PAD]" token and defines no EOS token, so the GPT-2 style
# `tokenizer.pad_token = tokenizer.eos_token` assignment is deliberately not
# used here: it would overwrite the pad token with None.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Single source of truth for the device used by both loading and evaluation.
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the trained model from the checkpoint written by the training cell.
# The vocabulary size is taken from the tokenizer (30522 for bert-base-uncased),
# the same expression the training cell uses, instead of a hard-coded literal.
model, checkpoint = load_model(
    checkpoint_path='lstm_text_generator_final.pt',
    vocab_size=tokenizer.vocab_size,
    device=eval_device
)

# Build the test dataset from the "text" column of the test split.
test_texts = pd.read_csv("data/test.csv")['text'].tolist()

test_dataset = NextTokenDataset(test_texts, tokenizer, max_length=20)

print("\n1. Автодополнение...")
token_results, gen_texts, ref_texts = evaluate_model(
    model=model,
    test_dataset=test_dataset,
    tokenizer=tokenizer,
    device=eval_device,
    batch_size=32
)

print_evaluation_results(token_results, "Результаты предсказания текста")

# Show up to five reference/generated pairs side by side.
print("\nПримеры:")
for i in range(min(5, len(gen_texts))):
    print(f"Reference: '{ref_texts[i]}'")
    print(f"Generated: '{gen_texts[i]}'")
    print()

модель загружена lstm_text_generator_final.pt
Epoch: 10
Train Loss: 5.2689
Val Loss: 5.1347

1. Автодополнение...
Evaluating model on test dataset...
Total batches: 63638


Evaluating:   2%|▏         | 1049/63638 [00:03<03:53, 268.54it/s]

Processed 1000 batches, current avg ROUGE-1: 0.2206


Evaluating:   3%|▎         | 2035/63638 [00:07<03:48, 269.29it/s]

Processed 2000 batches, current avg ROUGE-1: 0.2224


Evaluating:   5%|▍         | 3053/63638 [00:11<03:48, 265.44it/s]

Processed 3000 batches, current avg ROUGE-1: 0.2210


Evaluating:   6%|▋         | 4045/63638 [00:15<03:41, 269.03it/s]

Processed 4000 batches, current avg ROUGE-1: 0.2219


Evaluating:   8%|▊         | 5042/63638 [00:18<03:38, 268.17it/s]

Processed 5000 batches, current avg ROUGE-1: 0.2217


Evaluating:   9%|▉         | 6039/63638 [00:22<03:36, 265.78it/s]

Processed 6000 batches, current avg ROUGE-1: 0.2210


Evaluating:  11%|█         | 7027/63638 [00:26<03:32, 266.43it/s]

Processed 7000 batches, current avg ROUGE-1: 0.2208


Evaluating:  13%|█▎        | 8028/63638 [00:29<03:25, 271.08it/s]

Processed 8000 batches, current avg ROUGE-1: 0.2206


Evaluating:  14%|█▍        | 9043/63638 [00:33<03:20, 272.89it/s]

Processed 9000 batches, current avg ROUGE-1: 0.2207


Evaluating:  16%|█▌        | 10052/63638 [00:37<03:17, 270.72it/s]

Processed 10000 batches, current avg ROUGE-1: 0.2204


Evaluating:  17%|█▋        | 11037/63638 [00:40<03:16, 267.67it/s]

Processed 11000 batches, current avg ROUGE-1: 0.2205


Evaluating:  19%|█▉        | 12045/63638 [00:44<03:08, 273.71it/s]

Processed 12000 batches, current avg ROUGE-1: 0.2201


Evaluating:  20%|██        | 13032/63638 [00:48<03:05, 273.40it/s]

Processed 13000 batches, current avg ROUGE-1: 0.2202


Evaluating:  22%|██▏       | 14046/63638 [00:51<03:02, 271.15it/s]

Processed 14000 batches, current avg ROUGE-1: 0.2200


Evaluating:  24%|██▎       | 15031/63638 [00:55<02:59, 271.51it/s]

Processed 15000 batches, current avg ROUGE-1: 0.2199


Evaluating:  25%|██▌       | 16040/63638 [00:59<02:56, 269.81it/s]

Processed 16000 batches, current avg ROUGE-1: 0.2199


Evaluating:  27%|██▋       | 17052/63638 [01:02<02:51, 270.96it/s]

Processed 17000 batches, current avg ROUGE-1: 0.2199


Evaluating:  28%|██▊       | 18038/63638 [01:06<02:48, 270.35it/s]

Processed 18000 batches, current avg ROUGE-1: 0.2203


Evaluating:  30%|██▉       | 19051/63638 [01:09<02:45, 269.10it/s]

Processed 19000 batches, current avg ROUGE-1: 0.2203


Evaluating:  31%|███▏      | 20033/63638 [01:13<02:44, 265.26it/s]

Processed 20000 batches, current avg ROUGE-1: 0.2203


Evaluating:  33%|███▎      | 21043/63638 [01:17<02:43, 261.07it/s]

Processed 21000 batches, current avg ROUGE-1: 0.2204


Evaluating:  35%|███▍      | 22045/63638 [01:21<02:37, 263.95it/s]

Processed 22000 batches, current avg ROUGE-1: 0.2207


Evaluating:  36%|███▌      | 23042/63638 [01:24<02:36, 259.06it/s]

Processed 23000 batches, current avg ROUGE-1: 0.2206


Evaluating:  38%|███▊      | 24040/63638 [01:28<02:32, 259.27it/s]

Processed 24000 batches, current avg ROUGE-1: 0.2206


Evaluating:  39%|███▉      | 25053/63638 [01:32<02:29, 258.55it/s]

Processed 25000 batches, current avg ROUGE-1: 0.2208


Evaluating:  41%|████      | 26041/63638 [01:36<02:24, 260.83it/s]

Processed 26000 batches, current avg ROUGE-1: 0.2208


Evaluating:  42%|████▏     | 27038/63638 [01:39<02:20, 260.88it/s]

Processed 27000 batches, current avg ROUGE-1: 0.2208


Evaluating:  44%|████▍     | 28035/63638 [01:43<02:22, 249.46it/s]

Processed 28000 batches, current avg ROUGE-1: 0.2209


Evaluating:  46%|████▌     | 29034/63638 [01:47<02:19, 248.51it/s]

Processed 29000 batches, current avg ROUGE-1: 0.2208


Evaluating:  47%|████▋     | 30036/63638 [01:51<02:08, 261.62it/s]

Processed 30000 batches, current avg ROUGE-1: 0.2208


Evaluating:  49%|████▉     | 31036/63638 [01:54<02:05, 259.43it/s]

Processed 31000 batches, current avg ROUGE-1: 0.2208


Evaluating:  50%|█████     | 32033/63638 [01:58<02:00, 261.23it/s]

Processed 32000 batches, current avg ROUGE-1: 0.2209


Evaluating:  52%|█████▏    | 33034/63638 [02:02<01:56, 261.76it/s]

Processed 33000 batches, current avg ROUGE-1: 0.2209


Evaluating:  53%|█████▎    | 34039/63638 [02:05<01:53, 261.32it/s]

Processed 34000 batches, current avg ROUGE-1: 0.2208


Evaluating:  55%|█████▌    | 35043/63638 [02:09<01:50, 259.43it/s]

Processed 35000 batches, current avg ROUGE-1: 0.2207


Evaluating:  57%|█████▋    | 36044/63638 [02:13<01:45, 261.69it/s]

Processed 36000 batches, current avg ROUGE-1: 0.2208


Evaluating:  58%|█████▊    | 37048/63638 [02:17<01:43, 258.05it/s]

Processed 37000 batches, current avg ROUGE-1: 0.2208


Evaluating:  60%|█████▉    | 38041/63638 [02:20<01:39, 257.65it/s]

Processed 38000 batches, current avg ROUGE-1: 0.2208


Evaluating:  61%|██████▏   | 39034/63638 [02:24<01:35, 256.70it/s]

Processed 39000 batches, current avg ROUGE-1: 0.2209


Evaluating:  63%|██████▎   | 40052/63638 [02:28<01:30, 259.92it/s]

Processed 40000 batches, current avg ROUGE-1: 0.2207


Evaluating:  64%|██████▍   | 41029/63638 [02:31<01:28, 254.89it/s]

Processed 41000 batches, current avg ROUGE-1: 0.2208


Evaluating:  66%|██████▌   | 42030/63638 [02:35<01:23, 257.97it/s]

Processed 42000 batches, current avg ROUGE-1: 0.2208


Evaluating:  68%|██████▊   | 43030/63638 [02:39<01:20, 256.70it/s]

Processed 43000 batches, current avg ROUGE-1: 0.2208


Evaluating:  69%|██████▉   | 44047/63638 [02:43<01:16, 257.33it/s]

Processed 44000 batches, current avg ROUGE-1: 0.2207


Evaluating:  71%|███████   | 45047/63638 [02:46<01:12, 256.67it/s]

Processed 45000 batches, current avg ROUGE-1: 0.2208


Evaluating:  72%|███████▏  | 46048/63638 [02:50<01:08, 257.64it/s]

Processed 46000 batches, current avg ROUGE-1: 0.2207


Evaluating:  74%|███████▍  | 47044/63638 [02:54<01:04, 257.78it/s]

Processed 47000 batches, current avg ROUGE-1: 0.2208


Evaluating:  75%|███████▌  | 48039/63638 [02:58<01:00, 256.81it/s]

Processed 48000 batches, current avg ROUGE-1: 0.2209


Evaluating:  77%|███████▋  | 49043/63638 [03:01<00:56, 258.36it/s]

Processed 49000 batches, current avg ROUGE-1: 0.2210


Evaluating:  79%|███████▊  | 50046/63638 [03:05<00:52, 258.16it/s]

Processed 50000 batches, current avg ROUGE-1: 0.2209


Evaluating:  80%|████████  | 51052/63638 [03:09<00:49, 256.19it/s]

Processed 51000 batches, current avg ROUGE-1: 0.2209


Evaluating:  82%|████████▏ | 52047/63638 [03:13<00:45, 253.19it/s]

Processed 52000 batches, current avg ROUGE-1: 0.2208


Evaluating:  83%|████████▎ | 53041/63638 [03:16<00:41, 254.17it/s]

Processed 53000 batches, current avg ROUGE-1: 0.2208


Evaluating:  85%|████████▍ | 54044/63638 [03:20<00:37, 254.49it/s]

Processed 54000 batches, current avg ROUGE-1: 0.2208


Evaluating:  86%|████████▋ | 55034/63638 [03:24<00:34, 245.94it/s]

Processed 55000 batches, current avg ROUGE-1: 0.2208


Evaluating:  88%|████████▊ | 56030/63638 [03:28<00:30, 251.81it/s]

Processed 56000 batches, current avg ROUGE-1: 0.2208


Evaluating:  90%|████████▉ | 57049/63638 [03:31<00:26, 253.26it/s]

Processed 57000 batches, current avg ROUGE-1: 0.2208


Evaluating:  91%|█████████ | 58034/63638 [03:35<00:23, 243.00it/s]

Processed 58000 batches, current avg ROUGE-1: 0.2209


Evaluating:  93%|█████████▎| 59032/63638 [03:39<00:18, 252.53it/s]

Processed 59000 batches, current avg ROUGE-1: 0.2208


Evaluating:  94%|█████████▍| 60031/63638 [03:43<00:14, 248.70it/s]

Processed 60000 batches, current avg ROUGE-1: 0.2208


Evaluating:  96%|█████████▌| 61043/63638 [03:46<00:10, 250.21it/s]

Processed 61000 batches, current avg ROUGE-1: 0.2208


Evaluating:  97%|█████████▋| 62034/63638 [03:50<00:06, 251.32it/s]

Processed 62000 batches, current avg ROUGE-1: 0.2209


Evaluating:  99%|█████████▉| 63052/63638 [03:54<00:02, 249.74it/s]

Processed 63000 batches, current avg ROUGE-1: 0.2209


Evaluating: 100%|██████████| 63638/63638 [03:56<00:00, 268.91it/s]


          Результаты предсказания текста          
ROUGE-1: 0.2209 (±0.4148)
ROUGE-2: 0.0000 (±0.0000)
Number of samples: 2036384

Примеры:
Reference: 'i'
Generated: 'i'

Reference: 'm'
Generated: 'm'

Reference: 'hell'
Generated: 'not'

Reference: '##a'
Generated: '##a'

Reference: 'bored'
Generated: 'tired'



### примеры предсказаний

In [10]:
# Sample prompts for a qualitative look at what model.generate returns.
sample_prompts = ["well i m hella", "i am so excited", "once upon"]

# The device choice does not depend on the prompt, so it is computed once.
generation_device = 'cuda' if torch.cuda.is_available() else 'cpu'

print("\nПримеры:")
for prompt in sample_prompts:
    generated = model.generate(
        tokenizer,
        prompt,
        max_length=20,
        device=generation_device,
    )
    print(f"Prompt: '{prompt}' -> '{generated}'")
print()


Примеры:
Prompt: 'well i m hella' -> 'well i m hella i m so sad i can t wait to see him and i'
Prompt: 'i am so excited' -> 'i am so excited i m so sad i can t wait to see you and i m'
Prompt: 'once upon' -> 'once upon the one i can t do it i m so sad to be the best of'



## Запуск предобученного трансформера на тестовых данных

In [3]:
from eval_transformer_pipeline import *
# Test split: the "text" column of data/test.csv as a Python list.
test_texts = pd.read_csv("data/test.csv")['text'].tolist()
print(f"Loaded {len(test_texts)} test texts")

# Run on the GPU when one is available, otherwise on the CPU.
distilgpt_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Evaluate the pretrained distilgpt2 checkpoint on the test texts; the call
# returns the aggregate metrics and a collection of per-sample examples.
gpt_results, gpt_examples = evaluate_distilgpt(
    model_name="distilgpt2",
    test_texts=test_texts,
    prompt_length=5,
    max_new_tokens=30,
    batch_size=8,
    device=distilgpt_device,
)

# Report the ROUGE values together with the stored "_std" entries.
print("Метрики DistilGPT2")
print(f"ROUGE-1: {gpt_results['rouge1']:.4f} (±{gpt_results['rouge1_std']:.4f})")
print(f"ROUGE-2: {gpt_results['rouge2']:.4f} (±{gpt_results['rouge2_std']:.4f})")


Loaded 159646 test texts
Loading distilgpt2...


Generating:   0%|          | 0/159646 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Generating:   6%|▋         | 10004/159646 [07:44<1:55:09, 21.66it/s]


Processed 10000 samples, Avg ROUGE-1: 0.0655, Avg ROUGE-2: 0.0042


Generating:  13%|█▎        | 20004/159646 [15:24<1:41:24, 22.95it/s]


Processed 20000 samples, Avg ROUGE-1: 0.0656, Avg ROUGE-2: 0.0043


Generating:  19%|█▉        | 30004/159646 [23:05<1:41:49, 21.22it/s]


Processed 30000 samples, Avg ROUGE-1: 0.0652, Avg ROUGE-2: 0.0042


Generating:  25%|██▌       | 40005/159646 [30:44<1:30:24, 22.06it/s]


Processed 40000 samples, Avg ROUGE-1: 0.0653, Avg ROUGE-2: 0.0042


Generating:  31%|███▏      | 50004/159646 [38:26<1:10:43, 25.84it/s]


Processed 50000 samples, Avg ROUGE-1: 0.0654, Avg ROUGE-2: 0.0042


Generating:  38%|███▊      | 60003/159646 [46:15<1:08:43, 24.16it/s]


Processed 60000 samples, Avg ROUGE-1: 0.0653, Avg ROUGE-2: 0.0041


Generating:  44%|████▍     | 70002/159646 [54:02<1:09:17, 21.56it/s]


Processed 70000 samples, Avg ROUGE-1: 0.0652, Avg ROUGE-2: 0.0041


Generating:  50%|█████     | 80003/159646 [1:01:41<1:10:51, 18.73it/s]


Processed 80000 samples, Avg ROUGE-1: 0.0652, Avg ROUGE-2: 0.0041


Generating:  56%|█████▋    | 90002/159646 [1:09:21<1:02:48, 18.48it/s]


Processed 90000 samples, Avg ROUGE-1: 0.0652, Avg ROUGE-2: 0.0040


Generating:  63%|██████▎   | 100006/159646 [1:17:06<35:49, 27.75it/s] 


Processed 100000 samples, Avg ROUGE-1: 0.0651, Avg ROUGE-2: 0.0040


Generating:  69%|██████▉   | 110004/159646 [1:24:46<38:48, 21.32it/s]


Processed 110000 samples, Avg ROUGE-1: 0.0652, Avg ROUGE-2: 0.0040


Generating:  75%|███████▌  | 120001/159646 [1:32:26<27:23, 24.13it/s]


Processed 120000 samples, Avg ROUGE-1: 0.0652, Avg ROUGE-2: 0.0039


Generating:  81%|████████▏ | 130004/159646 [1:40:06<20:22, 24.25it/s]


Processed 130000 samples, Avg ROUGE-1: 0.0652, Avg ROUGE-2: 0.0040


Generating:  88%|████████▊ | 140004/159646 [1:47:45<15:46, 20.76it/s]


Processed 140000 samples, Avg ROUGE-1: 0.0654, Avg ROUGE-2: 0.0040


Generating: 100%|██████████| 159646/159646 [2:02:49<00:00, 21.66it/s]


Метрики DistilGPT2
ROUGE-1: 0.0652 (±0.0773)
ROUGE-2: 0.0039 (±0.0209)


### примеры предсказаний

In [12]:
print("Примеры")
# Print at most the first five DistilGPT2 examples, numbered from 1.
for number, example in enumerate(gpt_examples[:5], start=1):
    print(f"\nExample {number}:")
    print(f"Prompt: '{example['prompt']}'")
    print(f"Reference: '{example['reference']}'")
    print(f"Generated: '{example['generated']}'")
    print(f"ROUGE-1: {example['rouge1']:.4f}")

Примеры

Example 1:
Prompt: 'well i m hella'
Reference: ' bored here and i hate that school is just a week away'
Generated: '.
I have to admit I am quite a bit of a fan of the game but the overall experience was fantastic. The game is as good as'
ROUGE-1: 0.1538

Example 2:
Prompt: 'is so excited as c'
Reference: 'g has the package now i have to wait until she comes on msn'
Generated: 'uz you've got to make the most of it.


-
-
-
-
-
-
-
-'
ROUGE-1: 0.1667

Example 3:
Prompt: 'just went shopping but i'
Reference: ' m still bummed about no wrestling for me no triple h'
Generated: ''m not sure how to keep my money on the exchanges for the entire month! Thanks a lot for your support!'
ROUGE-1: 0.1250

Example 4:
Prompt: 'wish i could by'
Reference: ' music from itunes in brazil we re not allowed here you know'
Generated: 'the way i am now in the UK.


It was pretty much the only thing I can remember about the phone's firmware. I remember'
ROUGE-1: 0.0541

Example 5:
Prompt: 'listening to

## Сравнение моделей


Метрики LSTM
ROUGE-1: 0.2209 (±0.4148)
ROUGE-2: 0.0000 (±0.0000)

Примеры предсказаний:
Prompt1: 'well i m hella' 
Generated: 'well i m hella i m so sad i can t wait to see him and i'

Prompt2: 'i am so excited'
Generated: 'i am so excited i m so sad i can t wait to see you and i m'

Prompt3: 'once upon' 
Generated: 'once upon the one i can t do it i m so sad to be the best of'


Метрики DistilGPT2
ROUGE-1: 0.0652 (±0.0773)
ROUGE-2: 0.0039 (±0.0209)


Prompt1: 'well i m hella'
Generated: 'I have to admit I am quite a bit of a fan of the game but the overall experience was fantastic. The game is as good as'

Prompt2: 'is so excited as c'
Generated: 'uz you've got to make the most of it.'

Prompt3: 'just went shopping but i'
Generated: ''m not sure how to keep my money on the exchanges for the entire month! Thanks a lot for your support!'

Prompt4: 'wish i could by'
Generated: 'the way i am now in the UK. It was pretty much the only thing I can remember about the phone's firmware. I remember'

Вывод: обученная нами RNN-модель на основе LSTM генерирует связные, но бессмысленные фрагменты текста. Совпадение биграмм (ROUGE-2) равно 0: при оценке LSTM эталоном служит один следующий токен (см. примеры выше: 'i', 'm', 'bored'), поэтому биграмм в нём нет, а ROUGE-1 по сути отражает долю точно угаданных токенов. По этой же причине метрики LSTM и DistilGPT2 нельзя сравнивать напрямую: DistilGPT2 оценивается на генерации продолжения из нескольких слов. По всей видимости, требуется дальнейшая оптимизация модели и тренировочных датасетов.  Модель DistilGPT2 в некоторых случаях тоже генерирует бессмысленные повторяющиеся слова, однако на большинство запросов генерирует логичные фрагменты текста, вполне осмысленные в контексте промпта. При этом метрики ROUGE имеют очень малые значения, однако, по всей видимости, это отражает то, что модель была обучена на других данных и способна генерировать принципиально другой ответ, отличный от имеющегося в тестовом датасете, что и отражается на значениях метрик.

Время инференса DistilGPT2 на тестовом датасете 2:02:49
Время инференса LSTM на тестовом датасете 03:56
Таким образом, LSTM более чем в 30 раз быстрее DistilGPT2, что также важно при запуске в условиях ограниченных ресурсов/времени.


