In [1]:
import torch
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split

In [2]:
dataset = load_dataset('daily_dialog') # датасет по повседневным диалогам на разные темы, позиционируется как more human-like speech

In [3]:
# реплики диалогов в ячейках записаны как str через запятую --> preprocess все реплики одного диалога как один str
def preprocess(df):
    df['dialog'] = " ".join(df['dialog'])
    return df

In [4]:
dataset = dataset.map(preprocess)

In [5]:
model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-small')
tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-small')
tokenizer.pad_token = tokenizer.eos_token

In [6]:
def tokenizing(df):
    tokenized_df = tokenizer(df['dialog'], max_length=128, padding='max_length', truncation=True)
    tokenized_df['labels'] = tokenized_df['input_ids'][:]
    return tokenized_df

In [None]:
data = dataset.map(tokenizing, batched=True)

In [8]:
train_data = data['train']
eval_data = data['validation']

In [9]:
training_args = TrainingArguments(
    output_dir='/content/chatbot_model',
    num_train_epochs=5,
    learning_rate = 0.00005,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    prediction_loss_only=True,
    save_total_limit=2,
    save_steps=340,
    logging_steps=340,
    overwrite_output_dir=True,
)

In [10]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data
)

In [11]:
trainer.train()

Step,Training Loss
340,2.343
680,1.991
1020,1.8661
1360,1.8389
1700,1.7749
2040,1.7556
2380,1.7143
2720,1.6976
3060,1.6857
3400,1.6623


TrainOutput(global_step=3475, training_loss=1.82918417731635, metrics={'train_runtime': 1060.8864, 'train_samples_per_second': 52.4, 'train_steps_per_second': 3.276, 'total_flos': 3631306014720000.0, 'train_loss': 1.82918417731635, 'epoch': 5.0})

In [None]:
model.save_pretrained("/content/chatbot_model")
tokenizer.save_pretrained("/content/chatbot_tokenizer")

In [13]:
my_model = GPT2LMHeadModel.from_pretrained("/content/chatbot_model")
my_model.config.pad_token_id = my_model.config.eos_token_id
my_tokenizer = GPT2Tokenizer.from_pretrained("/content/chatbot_tokenizer")

In [14]:
def get_response(user_input):
    text = user_input.lower()
    input_ids = my_tokenizer.encode(text, return_tensors="pt", padding=True, truncation=True)
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        output = my_model.generate(input_ids, attention_mask=attention_mask, max_length=32, num_beams=5, no_repeat_ngram_size=2)
        final_output = my_tokenizer.decode(output[0], skip_special_tokens=True)
        final_output = final_output.replace(text, '').strip()
    return final_output

In [17]:
print('\033[92m'+'ChatBot: '+'\033[0m', end='')
print('To end the dialogue, print "bye"')
while True:
    user_input = input('\033[91m'+'User: '+'\033[0m')
    if user_input.lower() == 'bye':
        print('\033[92m'+'ChatBot: '+'\033[0m', end='')
        print('Goodbye!')
        break
    bot_response = get_response(user_input)
    print('\033[92m'+'ChatBot: '+'\033[0m', end='')
    print(bot_response)

[92mChatBot: [0mTo end the dialogue, print "bye"
[91mUser: [0mhello, how do you do?
[92mChatBot: [0mHello, I ’ m calling to ask you a few questions about your new job. Can I help you?
[91mUser: [0mwhere do you work?
[92mChatBot: [0mI work in a publishing house.
[91mUser: [0mwhat are your hobbies? 
[92mChatBot: [0mI like to play golf, read a lot of books, and listen to classical music. I also like collecting stamps.
[91mUser: [0mdo you prefer cats or dogs? 
[92mChatBot: [0mI prefer dogs.
[91mUser: [0mwhat exercises can i do to gain muscles? 
[92mChatBot: [0mYou can do push-ups, pull-up or sit-down exercises.
[91mUser: [0mhow do i stay healthy in winter? 
[92mChatBot: [0mDon't worry about it. It's just a part of being a member of the family.
[91mUser: [0mhow do i make a pizza?
[92mChatBot: [0mPut the dough in the oven, and wait for it to be ready. Then put it on the grill.
[91mUser: [0mbye
[92mChatBot: [0mGoodbye!


Задаю некоторые вопросы такие же, что спрашивала и у tf-idf, чтобы посмотреть разницу в ответах.