# 1. Библиотеки, фреймворки и параметры обучения

In [None]:
import numpy as np
import torch
from tqdm import tqdm
import math
import os

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

Словарь с параметрами обучения

In [None]:
# Central configuration for the notebook; every entry is looked up by key.
args = dict(
    model_name="gpt2",        # pretrained checkpoint identifier
    max_seq_length=512,       # tokenizer truncation limit (in tokens)
    train_batch_size=4,
    valid_batch_size=4,
    num_train_epochs=1,
    warmup=0.1,               # fraction of all training steps spent warming up
    learning_rate=5e-5,
    input_text_path=".",      # directory that holds the source/target text files
    story_length=300,         # number of whitespace-separated words kept per story
)

# 2. Подготовка данных

## Формирование тренировочного и валидационного датасетов и их очистка

Создаем тренировочный и валидационный датасеты в виде списков строк, где в каждой строке объединены запрос и сама история, разделенные токеном `<sep>`. Также проведем небольшую чистку строк (удалим пробел слева от знаков пунктуации, заменим токен `<newline>` на `\n` и др.)


In [None]:
# Directory that the prompt/story file names are joined onto when reading data.
DATAPATH=args["input_text_path"]
def combinetext(prompt, story):
    """Merge a prompt file and a story file into one list of training strings.

    Both files are read from ``DATAPATH`` as UTF-8, one sample per line;
    line *i* of the prompt file is paired with line *i* of the story file.

    Args:
        prompt: Name of the prompt file, relative to ``DATAPATH``.
        story: Name of the story file, relative to ``DATAPATH``.

    Returns:
        A list of strings shaped ``"<prompt> <sep> <story>"``. The first
        seven characters of every prompt line are dropped and each story is
        cut to its first ``args["story_length"]`` whitespace-separated words.

    Raises:
        ValueError: If the two files do not have the same number of lines.
    """
    prompt_path = os.path.join(DATAPATH, prompt)
    story_path = os.path.join(DATAPATH, story)
    # Context managers guarantee both handles are closed, even on a read error.
    with open(prompt_path, encoding='utf8') as fp, open(story_path, encoding='utf8') as fs:
        prompts = fp.readlines()
        stories = fs.readlines()

    # Explicit exception instead of `assert`, which is stripped under `python -O`.
    if len(prompts) != len(stories):
        raise ValueError(
            f"{prompt} has {len(prompts)} lines but {story} has {len(stories)}; "
            "the files must be line-aligned"
        )

    max_words = args["story_length"]
    # [7:] skips a fixed-width prefix on each prompt line
    # (presumably a tag such as "[ WP ] " -- TODO confirm against the data).
    return [
        prompt_line.rstrip()[7:] + ' <sep> ' + " ".join(story_line.split()[:max_words])
        for prompt_line, story_line in zip(prompts, stories)
    ]

def clean_punctuation(s):
    """Re-attach punctuation and contractions that were split off by spaces.

    Removes the space in front of ``! , . : ; ?``, glues detached
    contraction suffixes (``n't``, ``'s``, ``'re``, ...) back onto the
    preceding word, and turns every ``<newline>`` marker into ``\\n``.
    The rules run strictly in the order listed below.
    """
    # (fragment that follows a space, text that replaces " " + fragment)
    contraction_rules = (
        ('n\'t', 'n\'t'),
        ('\'s', '\'s'),
        ('\'re', '\'re'),
        ('\'ve', '\'ve'),
        ('\'ll', '\'ll'),
        ('\'am', '\'am'),
        ('\'m', '\'m'),
        ('\' m', '\'m'),
        # Repeated on purpose: the rule just above can create a fresh " 'm".
        ('\'m', '\'m'),
        ('\' ve', '\'ve'),
        ('\' s', '\'s'),
    )

    for mark in '!,.:;?':
        s = s.replace(' ' + mark, mark)
    for fragment, glued in contraction_rules:
        s = s.replace(' ' + fragment, glued)
    return s.replace('<newline>', '\n')

# Build the cleaned "prompt <sep> story" strings for the training split...
train_text = [
    clean_punctuation(sample)
    for sample in combinetext('train.wp_source', 'train.wp_target')
]
# ...and for the validation split, which is read from the test.* files.
valid_text = [
    clean_punctuation(sample)
    for sample in combinetext('test.wp_source', 'test.wp_target')
]

# Notebook-style peek at one combined training sample.
train_text[1]

Проведем токенизацию текста с фиксированной длиной.

In [None]:
# Load the tokenizer for the checkpoint named in the config rather than a
# hard-coded 'gpt2', so changing args["model_name"] is actually honoured.
tokenizer = GPT2Tokenizer.from_pretrained(args["model_name"])
# The end-of-sequence token is reused as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Encode both splits: sequences longer than max_seq_length are truncated and
# padding is enabled so the samples of a split can be batched together.
inputs_train = tokenizer(train_text, padding=True, truncation=True, max_length=args["max_seq_length"])
inputs_valid = tokenizer(valid_text, padding=True, truncation=True, max_length=args["max_seq_length"])



Создаем таргеты для обучения модели

In [None]:
def create_labels(inputs):
    """Attach language-modelling targets to a tokenized batch, in place.

    For every sample the label sequence equals ``input_ids`` wherever the
    attention mask is non-zero and ``-100`` wherever it is zero. ``-100`` is
    the label value Hugging Face LM heads leave out of the loss, so padding
    does not contribute to training.

    Masking is done position by position, so it is correct no matter on which
    side the padding sits (the previous "prefix + tail" construction was only
    right for right-padded input).

    Args:
        inputs: Mapping with ``'input_ids'`` and ``'attention_mask'``, each a
            sequence of equally long integer sequences. A ``'labels'`` entry
            (list of lists) is added or overwritten.

    Returns:
        None. ``inputs`` is modified in place.
    """
    inputs['labels'] = [
        [token if keep else -100 for token, keep in zip(ids, attention_mask)]
        for ids, attention_mask in zip(inputs['input_ids'], inputs['attention_mask'])
    ]

# Add a 'labels' entry to both tokenized splits (each is updated in place).
for encoded_split in (inputs_train, inputs_valid):
    create_labels(encoded_split)

Инициализируем класс датасета и создаем его объекты для тренировочной и валидационной выборок

In [None]:
class StoryDataset:
    """Index-addressable view over already tokenized samples.

    Keeps the ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` entries
    of ``inputs`` and converts a single sample to tensors on access.
    """

    def __init__(self, inputs):
        self.ids, self.attention_mask, self.labels = (
            inputs[key] for key in ('input_ids', 'attention_mask', 'labels')
        )

    def __len__(self):
        """Number of samples."""
        return len(self.ids)

    def __getitem__(self, item):
        """Return ``[input_ids, attention_mask, labels]`` as ``torch.long`` tensors."""
        columns = (self.ids, self.attention_mask, self.labels)
        return [torch.tensor(column[item], dtype=torch.long) for column in columns]

In [None]:
train_batch_size=args["train_batch_size"]
valid_batch_size=args["valid_batch_size"]

# Training samples are reshuffled at the start of every epoch so the model
# does not see the batches in the same file order each time.
traindata=StoryDataset(inputs_train)
train_dataloader = torch.utils.data.DataLoader(
    traindata,
    shuffle=True,
    batch_size=train_batch_size)

# Validation keeps the file order so that evaluation is deterministic.
validdata=StoryDataset(inputs_valid)
valid_dataloader = torch.utils.data.DataLoader(
    validdata,
    shuffle=False,
    batch_size=valid_batch_size)

Создаем объект предобученной модели gpt-2

In [None]:
# Instantiate the pretrained LM for the checkpoint named in the config
# (previously hard-coded as 'gpt2').
model = GPT2LMHeadModel.from_pretrained(args["model_name"])

# Overwrite the weights with a previously saved state dict, mapped onto the CPU.
# NOTE(review): "model_3_183.pt" must exist in the working directory; this line
# used to be indented, which is a syntax error when the cells run as one script.
model.load_state_dict(torch.load("model_3_183.pt", map_location=torch.device('cpu')))


<All keys matched successfully>

Функция для генерации истории по запросу. Модель генерирует историю определенной длины, затем из текста удаляется последнее незаконченное предложение

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def generate_story(prompt,k=0,p=0.9,output_length=300,temperature=1,num_return_sequences=3,repetition_penalty=1.0):
    """Sample story continuations for ``prompt`` from the global ``model``.

    Args:
        prompt: Text the generation starts from.
        k: Passed to ``generate`` as ``top_k``.
        p: Passed to ``generate`` as ``top_p``.
        output_length: Passed to ``generate`` as ``max_length``.
        temperature: Sampling temperature.
        num_return_sequences: Number of independent samples to return.
        repetition_penalty: Passed through to ``generate``.

    Returns:
        A list with one string per generated sequence. Each string is cut
        right after its last ``'.'`` to drop a trailing unfinished sentence
        (a sequence without any ``'.'`` therefore becomes ``""``), and every
        ``" <sep>"`` is removed.
    """
    model.to(device)
    model.eval()
    # The prompt tensor has to live on the same device as the model,
    # otherwise generation fails as soon as CUDA is used.
    encoded = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)
    output_sequences = model.generate(
        input_ids=encoded,
        max_length=output_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        # Explicit padding id (EOS) instead of relying on generate()'s fallback.
        pad_token_id=tokenizer.eos_token_id,
    )
    res = []
    for generated_sequence in output_sequences:
        s = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        res.append(s[:s.rfind('.') + 1].replace(" <sep>", ""))
    return res

# Try generating a story with the model before any fine-tuning in this notebook
generate_story("Students of the Moscow Aviation Institute celebrate the first of September in Pokrovskoye-Streshnevo Park")

# 3. Дообучение модели

Задаем параметры обучения, инициализируем оптимизатор и шедулер

In [None]:
num_train_epochs = args["num_train_epochs"]
training_steps_per_epoch=len(train_dataloader)  # one optimizer step per batch
total_num_training_steps = int(training_steps_per_epoch*num_train_epochs)
weight_decay=0
learning_rate=args["learning_rate"]
adam_epsilon=1e-8
# `args` is a plain dict, so the warmup fraction must be read by key;
# attribute access (`args.warmup`) raises AttributeError.
warmup_steps=int(total_num_training_steps*args["warmup"])
# Parameters whose name contains one of these substrings go into the second
# group, which never gets weight decay. With weight_decay=0 both groups
# currently behave the same.
# NOTE(review): confirm these substrings match the parameter names of the
# loaded model before raising weight_decay.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
# Linear schedule with `warmup_steps` warmup steps over `total_num_training_steps` steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_num_training_steps
)



Цикл обучения модели

In [None]:
print("***** Running training *****")
print("  Total_num_training_step = {}".format(total_num_training_steps))
print("  Num Epochs = {}".format(num_train_epochs))
print(f"  Train_batch_size per device = {train_batch_size}")
print(f"  Valid_batch_size per device = {valid_batch_size}")
model.to(device)
for epoch in range(num_train_epochs):
    print(f"Start epoch{epoch+1} of {num_train_epochs}")

    # ---- training pass: one optimizer and scheduler step per batch ----
    model.train()
    model.zero_grad()
    train_loss = 0
    epoch_iterator = tqdm(train_dataloader, desc='Iteration')
    for batch in epoch_iterator:
        input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)
        output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        batch_loss = output[0]  # the loss is the first element of the model output
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        model.zero_grad()
        loss_value = batch_loss.item()
        train_loss += loss_value
        epoch_iterator.set_description('(batch loss=%g)' % loss_value)
        del batch_loss  # drop the reference to the loss tensor before the next batch
    # The running sum is divided by the number of batches in the epoch.
    print(f'Average train loss per example={train_loss/training_steps_per_epoch} in epoch{epoch+1}')

    # ---- validation pass: no gradients, mean of the per-batch losses ----
    print(f'Starting evaluate after epoch {epoch+1}')
    model.eval()
    eval_losses = []
    with torch.no_grad():
        for batch in tqdm(valid_dataloader, desc="eval"):
            input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)
            output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            eval_losses.append(output[0].item())
    eval_loss = np.mean(eval_losses)
    perplexity = math.exp(eval_loss)  # perplexity = exp(mean validation loss)
    print(f'Average valid loss per example={eval_loss} in epoch{epoch+1}')
    print(f'Perplextiy for valid dataset in epoch{epoch+1} is {perplexity}')

***** Running training *****
  Total_num_training_step = 3905
  Num Epochs = 1
  Train_batch_size per device = 4
  Valid_batch_size per device = 4
Start epoch1 of 1


(batch loss=2.86966): 100%|██████████| 3905/3905 [37:49<00:00,  1.72it/s]


Average train loss per example=3.1536706931154495 in epoch1
Starting evaluate after epoch 1


eval: 100%|██████████| 3785/3785 [11:01<00:00,  5.72it/s]

Average valid loss per example=3.182995136064456 in epoch1
Perplextiy for valid dataset in epoch1 is 24.118884818416824





Сохраняем словарь с весами модели

In [None]:
# Persist only the model's state dict (not the whole model object) to disk.
torch.save(model.state_dict(), "model_state.pt")

Пробуем теперь сгенерировать текст с дообученной моделью

In [None]:
# Same prompt as before training, now sampled from the fine-tuned model.
generate_story("Students of the Moscow Aviation Institute celebrate the first of September in Pokrovskoye-Streshnevo Park")