# Large Language Model Tutorial

Large Language Model (LLM) tutorial for Information Retrieval course, Faculty of Computer Science, Universitas Indonesia. This notebook contain code for loading per-trained language model, preprocessing instruction dataset, causal language modeling supervised training, inference (generating text), and various decoding methods.

## Quick Introduction

Before we start, let's have a quick recap about Large Language Model. Slide: [Link](https://docs.google.com/presentation/d/1CamCGqiDMlJ4IdthpIxhBt_2DsLkbRK3fLYvkqi4w9s/edit?usp=sharing)

## Preparation

The following will install and import the required packages, and some define variables. We will mainly use [PyTorch](https://pytorch.org/) and [transformers](https://github.com/huggingface/transformers) package from HuggingFace.

In [1]:
!pip install transformers==4.35.2 datasets==2.14.0 accelerate==0.24.1

Collecting transformers==4.35.2
  Downloading transformers-4.35.2-py3-none-any.whl.metadata (123 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.5/123.5 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets==2.14.0
  Downloading datasets-2.14.0-py3-none-any.whl.metadata (19 kB)
Collecting accelerate==0.24.1
  Downloading accelerate-0.24.1-py3-none-any.whl.metadata (18 kB)
Downloading transformers-4.35.2-py3-none-any.whl (7.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m69.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25hDownloading datasets-2.14.0-py3-none-any.whl (492 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m492.2/492.2 kB[0m [31m25.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected 

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import PreTrainedModel
from transformers import pipeline, set_seed, TextGenerationPipeline
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.tokenization_utils_base import BatchEncoding
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
from transformers import TrainingArguments, Trainer, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import set_seed
from datasets import load_dataset, Dataset, DatasetDict
from datasets.iterable_dataset import IterableDataset
from torch import nn
from torch import Tensor
from typing import List, Dict, Any
import pandas as pd
import re
import json
import multiprocessing
import torch

model_id: str = 'gpt2' # you can also try: distilgpt2 gpt2
num_train_epochs: int = 3
instruction_format: str = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "\n"
    "### Question:\n"
    "{question}"
    "\n\n"
    "### Answer:\n"
    "{answer}"
)
device: torch.device = torch.device("cuda") \
  if torch.cuda.is_available() else torch.device("cpu")
device

2024-05-07 05:59:09.631974: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-07 05:59:09.632089: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-07 05:59:09.776168: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


device(type='cuda')

## Load Pre-Trained Model

The following will load the casual language pre-trained model from HuggingFace model repository. The model architecture is shown below.

In [3]:
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model: nn.Module = AutoModelForCausalLM.from_pretrained(model_id)
print(f"model: {model}")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

model: GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


## Dataset Preprocessing

These section will load, preprocess, tokenize, and split the dataset to be used in training. The sample rows will be shown below.

In [4]:
data = pd.read_csv('/kaggle/input/customer-support-on-twitter/twcs/twcs.csv')
data

Unnamed: 0,tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
0,1,sprintcare,False,Tue Oct 31 22:10:47 +0000 2017,@115712 I understand. I would like to assist y...,2,3.0
1,2,115712,True,Tue Oct 31 22:11:45 +0000 2017,@sprintcare and how do you propose we do that,,1.0
2,3,115712,True,Tue Oct 31 22:08:27 +0000 2017,@sprintcare I have sent several private messag...,1,4.0
3,4,sprintcare,False,Tue Oct 31 21:54:49 +0000 2017,@115712 Please send us a Private Message so th...,3,5.0
4,5,115712,True,Tue Oct 31 21:49:35 +0000 2017,@sprintcare I did.,4,6.0
...,...,...,...,...,...,...,...
2811769,2987947,sprintcare,False,Wed Nov 22 08:43:51 +0000 2017,"@823869 Hey, we'd be happy to look into this f...",,2987948.0
2811770,2987948,823869,True,Wed Nov 22 08:35:16 +0000 2017,@115714 wtf!? I’ve been having really shitty s...,2987947,
2811771,2812240,121673,True,Thu Nov 23 04:13:07 +0000 2017,@143549 @sprintcare You have to go to https://...,,2812239.0
2811772,2987949,AldiUK,False,Wed Nov 22 08:31:24 +0000 2017,"@823870 Sounds delicious, Sarah! 😋 https://t.c...",,2987950.0


In [5]:
import re

def preprocess_text(text):
    # Menghapus @mentions
    text = re.sub(r'@\w+', '', text)

    # Menghapus URL
    text = re.sub(r'http\S+|www.\S+', '', text)

    # Menghapus emoticon
    text = re.sub(r'[\U00010000-\U0010ffff]', '', text)

    # Menghapus karakter non-ASCII
    text = re.sub(r'[^\x00-\x7F]+', '', text)

    return text

data['text'] = data['text'].apply(preprocess_text)
data

Unnamed: 0,tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
0,1,sprintcare,False,Tue Oct 31 22:10:47 +0000 2017,I understand. I would like to assist you. We ...,2,3.0
1,2,115712,True,Tue Oct 31 22:11:45 +0000 2017,and how do you propose we do that,,1.0
2,3,115712,True,Tue Oct 31 22:08:27 +0000 2017,I have sent several private messages and no o...,1,4.0
3,4,sprintcare,False,Tue Oct 31 21:54:49 +0000 2017,Please send us a Private Message so that we c...,3,5.0
4,5,115712,True,Tue Oct 31 21:49:35 +0000 2017,I did.,4,6.0
...,...,...,...,...,...,...,...
2811769,2987947,sprintcare,False,Wed Nov 22 08:43:51 +0000 2017,"Hey, we'd be happy to look into this for you....",,2987948.0
2811770,2987948,823869,True,Wed Nov 22 08:35:16 +0000 2017,wtf!? Ive been having really shitty service a...,2987947,
2811771,2812240,121673,True,Thu Nov 23 04:13:07 +0000 2017,You have to go to and ask them to add the H...,,2812239.0
2811772,2987949,AldiUK,False,Wed Nov 22 08:31:24 +0000 2017,"Sounds delicious, Sarah!",,2987950.0


In [6]:
questions = data[data['inbound']]
questions

Unnamed: 0,tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
1,2,115712,True,Tue Oct 31 22:11:45 +0000 2017,and how do you propose we do that,,1.0
2,3,115712,True,Tue Oct 31 22:08:27 +0000 2017,I have sent several private messages and no o...,1,4.0
4,5,115712,True,Tue Oct 31 21:49:35 +0000 2017,I did.,4,6.0
6,8,115712,True,Tue Oct 31 21:45:10 +0000 2017,is the worst customer service,9610,
8,12,115713,True,Tue Oct 31 22:04:47 +0000 2017,You gonna magically change your connectivity ...,111314,15.0
...,...,...,...,...,...,...,...
2811765,2987944,823868,True,Wed Nov 22 07:43:36 +0000 2017,\n\nI am unable to do web checkin. I am gett...,2987943,
2811768,2987946,524544,True,Wed Nov 22 08:25:48 +0000 2017,Hope you are well? Does the 9.30am train from...,2987945,
2811770,2987948,823869,True,Wed Nov 22 08:35:16 +0000 2017,wtf!? Ive been having really shitty service a...,2987947,
2811771,2812240,121673,True,Thu Nov 23 04:13:07 +0000 2017,You have to go to and ask them to add the H...,,2812239.0


In [7]:
answers = data[data['in_response_to_tweet_id'].notna() &  (~data['inbound'])]
answers

Unnamed: 0,tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
0,1,sprintcare,False,Tue Oct 31 22:10:47 +0000 2017,I understand. I would like to assist you. We ...,2,3.0
3,4,sprintcare,False,Tue Oct 31 21:54:49 +0000 2017,Please send us a Private Message so that we c...,3,5.0
5,6,sprintcare,False,Tue Oct 31 21:46:24 +0000 2017,"Can you please send us a private message, so ...",57,8.0
7,11,sprintcare,False,Tue Oct 31 22:10:35 +0000 2017,This is saddening to hear. Please shoot us a ...,,12.0
9,15,sprintcare,False,Tue Oct 31 20:03:31 +0000 2017,We understand your concerns and we'd like for...,12,16.0
...,...,...,...,...,...,...,...
2811764,2987943,AirAsiaSupport,False,Wed Nov 22 07:54:57 +0000 2017,"Sorry but kindly try to clear browser,cache,c...",,2987944.0
2811766,139628,ArgosHelpers,False,Wed Nov 22 08:03:26 +0000 2017,Can you Dm us your order number and we can lo...,,139627.0
2811767,2987945,VirginTrains,False,Wed Nov 22 08:27:34 +0000 2017,That's a Peak service. The 09:56 is the first...,,2987946.0
2811769,2987947,sprintcare,False,Wed Nov 22 08:43:51 +0000 2017,"Hey, we'd be happy to look into this for you....",,2987948.0


In [8]:
qa = pd.merge(questions[['tweet_id', 'text', 'in_response_to_tweet_id']], answers[['text', 'in_response_to_tweet_id', 'tweet_id', 'author_id']], left_on='tweet_id', right_on='in_response_to_tweet_id')

qa.columns = ['question_id', 'question', 'in_response_to_tweet_id', 'answer', 'lol', 'answer_id', 'author_id']

qa = qa.drop('lol', axis=1)
qa

Unnamed: 0,question_id,question,in_response_to_tweet_id,answer,answer_id,author_id
0,3,I have sent several private messages and no o...,4.0,I understand. I would like to assist you. We ...,1,sprintcare
1,5,I did.,6.0,Please send us a Private Message so that we c...,4,sprintcare
2,8,is the worst customer service,,"Can you please send us a private message, so ...",6,sprintcare
3,8,is the worst customer service,,I would love the chance to review the account...,9,sprintcare
4,8,is the worst customer service,,Hello! We never like our customers to feel li...,10,sprintcare
...,...,...,...,...,...,...
1261883,2987942,Hai #asking how many days needed to proceed c...,,we have replied you via DM.Thanks-Emir,2987941,AirAsiaSupport
1261884,2987944,\n\nI am unable to do web checkin. I am gett...,,"Sorry but kindly try to clear browser,cache,c...",2987943,AirAsiaSupport
1261885,2987946,Hope you are well? Does the 9.30am train from...,,That's a Peak service. The 09:56 is the first...,2987945,VirginTrains
1261886,2987948,wtf!? Ive been having really shitty service a...,,"Hey, we'd be happy to look into this for you....",2987947,sprintcare


In [9]:
dont_have_answer = questions[~questions['tweet_id'].isin(qa['question_id'])]
dont_have_answer = dont_have_answer.dropna(subset=['response_tweet_id'])
dont_have_answer = dont_have_answer[dont_have_answer['tweet_id'].isin(qa['in_response_to_tweet_id'])]
dont_have_answer

Unnamed: 0,tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
40,46,115722,True,Tue Oct 31 22:04:29 +0000 2017,"Hello Duke, Do you have a copy of your bill? ...",42,47.0
148,234,115762,True,Tue Oct 31 05:55:08 +0000 2017,i pre ordered wwii but how do i get the code?,232,
204,293,115769,True,Wed Oct 18 14:07:45 +0000 2017,Whoa! Come along with Lightrooms own Ben Warde...,292,
228,318,115785,True,Tue Oct 31 20:03:55 +0000 2017,can anyone let me know when our pre orders...,311,
238,328,115794,True,Tue Oct 31 21:57:43 +0000 2017,Ah maybe could help on this one,327,329.0
...,...,...,...,...,...,...,...
2811400,2987587,823760,True,Thu Nov 30 02:32:10 +0000 2017,You even tried to charge us when THE WRONG TE...,2987586,2987588.0
2811415,2987600,665118,True,Wed Nov 22 01:25:59 +0000 2017,Literally called customer support for them to...,2987599,
2811582,2987763,823812,True,Thu Nov 30 00:40:58 +0000 2017,"I need an everything bagel and like, a Target ...",2987762,
2811620,2987799,630312,True,Tue Oct 31 21:32:46 +0000 2017,Icing on the cake was flight was 1/2 full. If...,2987798,2987800.0


In [10]:
new_qa = pd.merge(dont_have_answer[['tweet_id', 'text', 'in_response_to_tweet_id']], qa[['answer', 'in_response_to_tweet_id', 'answer_id', 'author_id']], left_on='tweet_id', right_on='in_response_to_tweet_id')

new_qa.columns = ['question_id', 'question', 'in_response_to_tweet_id', 'answer', 'lol', 'answer_id', 'author_id']

new_qa = new_qa.drop('lol', axis=1)
new_qa

Unnamed: 0,question_id,question,in_response_to_tweet_id,answer,answer_id,author_id
0,46,"Hello Duke, Do you have a copy of your bill? ...",47.0,Please follow and DM us so that we can look i...,40,VerizonSupport
1,234,i pre ordered wwii but how do i get the code?,,"Hello there, I apologize for the delay. Can y...",231,ATVIAssist
2,293,Whoa! Come along with Lightrooms own Ben Warde...,,"Hi Duncan, please update Lightroom to 6.13 fr...",291,AdobeCare
3,318,can anyone let me know when our pre orders...,,Hi there! We'd recommend reaching out to the ...,309,XboxSupport
4,328,Ah maybe could help on this one,329.0,Hi. Emergency services are dealing with an i...,326,nationalrailenq
...,...,...,...,...,...,...
89995,2987587,You even tried to charge us when THE WRONG TE...,2987588.0,"Hello, We are able to look into your service ...",2987585,comcastcares
89996,2987600,Literally called customer support for them to...,,I apologize for the experience you had. I'd l...,2987598,Ask_WellsFargo
89997,2987763,"I need an everything bagel and like, a Target ...",,Thanks for reaching out to us. We're here to ...,2987761,AskTarget
89998,2987799,Icing on the cake was flight was 1/2 full. If...,2987800.0,"We're sorry to hear of your experience, Carol...",2987797,SouthwestAir


In [11]:
qa = pd.concat([qa, new_qa], ignore_index=True)

In [12]:
qa = qa.drop(['in_response_to_tweet_id', 'question_id', 'answer_id'], axis=1)
qa

Unnamed: 0,question,answer,author_id
0,I have sent several private messages and no o...,I understand. I would like to assist you. We ...,sprintcare
1,I did.,Please send us a Private Message so that we c...,sprintcare
2,is the worst customer service,"Can you please send us a private message, so ...",sprintcare
3,is the worst customer service,I would love the chance to review the account...,sprintcare
4,is the worst customer service,Hello! We never like our customers to feel li...,sprintcare
...,...,...,...
1351883,You even tried to charge us when THE WRONG TE...,"Hello, We are able to look into your service ...",comcastcares
1351884,Literally called customer support for them to...,I apologize for the experience you had. I'd l...,Ask_WellsFargo
1351885,"I need an everything bagel and like, a Target ...",Thanks for reaching out to us. We're here to ...,AskTarget
1351886,Icing on the cake was flight was 1/2 full. If...,"We're sorry to hear of your experience, Carol...",SouthwestAir


In [13]:
qa = qa[qa['question'].apply(len) >= 5]
qa = qa[qa['answer'].apply(len) >= 10]
qa = qa.drop_duplicates(subset=['question'])
qa = qa.drop_duplicates(subset=['answer'])
drop_dm = qa[qa['answer'].str.contains('DM')].head(200000).index
qa = qa.drop(drop_dm)
qa

Unnamed: 0,question,answer,author_id
0,I have sent several private messages and no o...,I understand. I would like to assist you. We ...,sprintcare
1,I did.,Please send us a Private Message so that we c...,sprintcare
2,is the worst customer service,"Can you please send us a private message, so ...",sprintcare
8,Since I signed up with you....Since day 1,We understand your concerns and we'd like for...,sprintcare
9,yall lie about your great connection. 5 bars ...,H there! We'd definitely like to work with yo...,sprintcare
...,...,...,...
1351801,# Dear Idea Team My Bill Amount 1241 rs but ...,We would like to inform you that full and fin...,idea_cares
1351811,when are yall dropping season 3 of rick and m...,Sorry for the delayed reply! We don't have a ...,hulu_support
1351816,via,(Amazon RI,AmazonHelp
1351839,. This is the state of Dedicated Game Battle...,"Hello, this was already fixed with the patch ...",ATVIAssist


In [14]:
qa = qa[qa['author_id'] == 'AppleSupport'].head(20000)
qa

Unnamed: 0,question,answer,author_id
193,This is what it looks like,Any steps tried since it started last night?,AppleSupport
194,I have an iPhone 7 Plus and yes I do,That's great it has iOS 11.1 as we can rule o...,AppleSupport
195,I need answers because its annoying,We'd like to look into this with you. Which m...,AppleSupport
198,Tf is wrong with my keyboard,"Fill us in on what is happening, then we can ...",AppleSupport
200,hello are all the lines closed for tonight #...,What's going on? We're hapy to help if we can.,AppleSupport
...,...,...,...
649011,Previously I had Iphone 6 Plus and I had to g...,We can understand why you reached out. Could ...,AppleSupport
649013,somehow with you three updates in the past ...,We know how important it is for your iPhone t...,AppleSupport
649016,Today in my morning run I was stuck to fix my...,We'd be happy to look at this with you. Can y...,AppleSupport
649017,Is ever going to correct the I,Thanks for reaching out! Check out this artic...,AppleSupport


In [15]:
# Reset index
qa = qa.reset_index(drop=True)
qa

Unnamed: 0,question,answer,author_id
0,This is what it looks like,Any steps tried since it started last night?,AppleSupport
1,I have an iPhone 7 Plus and yes I do,That's great it has iOS 11.1 as we can rule o...,AppleSupport
2,I need answers because its annoying,We'd like to look into this with you. Which m...,AppleSupport
3,Tf is wrong with my keyboard,"Fill us in on what is happening, then we can ...",AppleSupport
4,hello are all the lines closed for tonight #...,What's going on? We're hapy to help if we can.,AppleSupport
...,...,...,...
19995,Previously I had Iphone 6 Plus and I had to g...,We can understand why you reached out. Could ...,AppleSupport
19996,somehow with you three updates in the past ...,We know how important it is for your iPhone t...,AppleSupport
19997,Today in my morning run I was stuck to fix my...,We'd be happy to look at this with you. Can y...,AppleSupport
19998,Is ever going to correct the I,Thanks for reaching out! Check out this artic...,AppleSupport


In [16]:
def take_up_to_200_rows(group):
    return group.head(200)

# Kelompokkan berdasarkan 'author_id' dan ambil sampai 500 baris pertama dari setiap group
qa = qa.groupby('author_id').apply(take_up_to_200_rows)

# Reset index
qa = qa.reset_index(drop=True)
qa

  qa = qa.groupby('author_id').apply(take_up_to_200_rows)


Unnamed: 0,question,answer,author_id
0,I'll be attending #ATTBizSummit this week. Co...,Glad you are getting excited! Be sure and giv...,ATT
1,I got my Justice League posters thx #Justice...,Whos your favorite hero? #Batman #WonderWoman...,ATT
2,Thanks Amber! Returning all my Charter Spectr...,Awesome Megan! Please send your telephone num...,ATT
3,is raping us all. There prices are outrageous...,We'll be happy to have you apart of our famil...,ATT
4,Shout out to for sending Gary in one day to f...,We are glad to know the issue is resolved. We...,ATT
...,...,...,...
20952,"""Hey, : are y'all really gonna keep me from ge...",You would have to purchase through a 3rd part...,sprintcare
20953,you gonna fight for our love or lose me to ?,Hey there. Are you having any issues? Let us ...,sprintcare
20954,I'm in Hong Kong now and it didn't ask me to ...,... you to contact our International Departme...,sprintcare
20955,I arrive Saturday. I'll await the message on ...,You're welcome. Do have a good day and thank ...,sprintcare


In [16]:
dataset_raw: Dataset = Dataset.from_pandas(qa)
print(f"dataset_raw: {dataset_raw}")
print(f"example raw row: {json.dumps(dataset_raw[13], indent=2)}")

dataset_raw: Dataset({
    features: ['question', 'answer', 'author_id'],
    num_rows: 20000
})
example raw row: {
  "question": "MY HOME BUTTON DOESNT WORK #IOS11 ",
  "answer": " Let us help with your Home button. Did this issue start right after iOS 11? Which version of iOS 11 are you running?",
  "author_id": "AppleSupport"
}


In [17]:
from typing import List, Dict, Any, Set, Tuple

In [18]:
# Tokenize
tokenizer_pattern: str = r"\b\w[\w']*\b"
example_passage_en = "Thank you. DM me your name, address, contact details, and phone number associated with your account. Thanks-Emir. ^BG  *HDG  *HDG  ^LC  *HDG *Mobile *Mobile * *Regards *"
def tokenize_text_en(text: str, tokenizer_pattern: str) -> List[str]:
  tokens: List[str] = re.findall(tokenizer_pattern, text)
  return tokens
example_tokens_en: List[str] = tokenize_text_en(
  text = example_passage_en,
  tokenizer_pattern = tokenizer_pattern,
)
print(f"example of tokenized text: {example_tokens_en}")

example of tokenized text: ['Thank', 'you', 'DM', 'me', 'your', 'name', 'address', 'contact', 'details', 'and', 'phone', 'number', 'associated', 'with', 'your', 'account', 'Thanks', 'Emir', 'BG', 'HDG', 'HDG', 'LC', 'HDG', 'Mobile', 'Mobile', 'Regards']


In [19]:
import locale
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding

In [20]:
import spacy
# Install packages as needed
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m76.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [21]:
# Lemmatize
nlp = spacy.load("en_core_web_sm")

def lemmatize_tokens_en(tokens: List[str], nlp) -> List[str]:
    example_lemmatized_en = []

    for doc in nlp.pipe(tokens):
        tok = [token.lemma_ for token in doc]
        example_lemmatized_en.extend(tok)

    return example_lemmatized_en

example_lemmatized_en: List[str] = lemmatize_tokens_en(
  tokens = example_tokens_en,
  nlp = nlp,
)
print(f"example of tokens before lemmatization: {example_tokens_en}")
print(f"example of tokens after lemmatization:  {example_lemmatized_en}")

example of tokens before lemmatization: ['Thank', 'you', 'DM', 'me', 'your', 'name', 'address', 'contact', 'details', 'and', 'phone', 'number', 'associated', 'with', 'your', 'account', 'Thanks', 'Emir', 'BG', 'HDG', 'HDG', 'LC', 'HDG', 'Mobile', 'Mobile', 'Regards']
example of tokens after lemmatization:  ['thank', 'you', 'dm', 'I', 'your', 'name', 'address', 'contact', 'detail', 'and', 'phone', 'number', 'associate', 'with', 'your', 'account', 'thank', 'Emir', 'BG', 'HDG', 'HDG', 'lc', 'HDG', 'mobile', 'mobile', 'regard']


In [22]:
# Stemming
import nltk
from nltk.stem import PorterStemmer

stemmer = PorterStemmer()
def stem_tokens_en(tokens: List[str], stemmer: PorterStemmer) -> List[str]:
    stemmed_tokens = [stemmer.stem(token) for token in tokens]
    return stemmed_tokens

example_tokens_after_stemming_en: List[str] = stem_tokens_en(
  tokens = example_lemmatized_en,
  stemmer = stemmer,
)
print(f"example of tokens before stemming: {example_lemmatized_en}")
print(f"example of tokens after stemming:  {example_tokens_after_stemming_en}")

example of tokens before stemming: ['thank', 'you', 'dm', 'I', 'your', 'name', 'address', 'contact', 'detail', 'and', 'phone', 'number', 'associate', 'with', 'your', 'account', 'thank', 'Emir', 'BG', 'HDG', 'HDG', 'lc', 'HDG', 'mobile', 'mobile', 'regard']
example of tokens after stemming:  ['thank', 'you', 'dm', 'I', 'your', 'name', 'address', 'contact', 'detail', 'and', 'phone', 'number', 'associ', 'with', 'your', 'account', 'thank', 'emir', 'BG', 'hdg', 'hdg', 'lc', 'hdg', 'mobil', 'mobil', 'regard']


In [23]:
# Stop Words Removal
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

nltk_stop_words_list: List[str] = stopwords.words('english')
nltk_stop_words_set: Set[str] = set(nltk_stop_words_list)

def remove_stop_words_en(tokens: List[str], stop_words: Dict[str, Any]) -> List[str]:
  tokens_without_stop_words: List[str] = [
      token
      for token in tokens
      if token not in stop_words
  ]
  return tokens_without_stop_words

example_tokens_without_stop_words_en: List[str] = remove_stop_words_en(
  tokens = example_lemmatized_en,
  stop_words = nltk_stop_words_set,
)

print(f"stop words from NLTK: {nltk_stop_words_set}\n")
print(f"example of tokens with stop words (before):   {example_lemmatized_en}")
print(f"example of tokens without stop words (after): {example_tokens_without_stop_words_en}")

[nltk_data] Downloading package stopwords to /usr/share/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
stop words from NLTK: {'over', "weren't", 'is', 'with', 'are', 'after', 'own', "you'd", 'no', 'theirs', 'down', 'because', 'yourself', 'on', 'couldn', 'wouldn', 'isn', 'what', 'hasn', 'did', 'she', 'too', 'before', 'for', 'shan', "wasn't", 'being', 'this', 'which', "that'll", 'doing', 'having', "you'll", "you're", 'some', "should've", 'ma', 'won', 'of', 'same', "it's", 'shouldn', 'more', 'his', "mightn't", "hasn't", 'do', 'you', 'haven', 'why', 'o', 'doesn', 'didn', 'can', 'those', "don't", 'hers', 'that', 'were', 'once', "doesn't", 'few', "mustn't", 'further', 'out', 'am', 'himself', 'mustn', 'by', 'each', 'mightn', "shouldn't", 'off', 'don', 'weren', "she's", "haven't", 'will', 'below', 'above', 'our', 'whom', 'there', 'a', 'hadn', 'itself', 'should', 'just', 'themselves', 'had', 'these', 'all', 'aren', 'at', 'here', 'when', 'their', 'than', 'he', 'any', 'yours'

In [24]:
def join_words_en(tokens: List[str]) -> str:
    words: str = ' '.join(tokens)
    return words

example_words_without_stop_words_en: List[str] = join_words_en(
  tokens = example_tokens_without_stop_words_en,
)

print(f"stop words from NLTK: {nltk_stop_words_set}\n")
print(f"example of tokens with stop words (before):   {example_tokens_without_stop_words_en}")
print(f"example of tokens without stop words (after): {example_words_without_stop_words_en}")

stop words from NLTK: {'over', "weren't", 'is', 'with', 'are', 'after', 'own', "you'd", 'no', 'theirs', 'down', 'because', 'yourself', 'on', 'couldn', 'wouldn', 'isn', 'what', 'hasn', 'did', 'she', 'too', 'before', 'for', 'shan', "wasn't", 'being', 'this', 'which', "that'll", 'doing', 'having', "you'll", "you're", 'some', "should've", 'ma', 'won', 'of', 'same', "it's", 'shouldn', 'more', 'his', "mightn't", "hasn't", 'do', 'you', 'haven', 'why', 'o', 'doesn', 'didn', 'can', 'those', "don't", 'hers', 'that', 'were', 'once', "doesn't", 'few', "mustn't", 'further', 'out', 'am', 'himself', 'mustn', 'by', 'each', 'mightn', "shouldn't", 'off', 'don', 'weren', "she's", "haven't", 'will', 'below', 'above', 'our', 'whom', 'there', 'a', 'hadn', 'itself', 'should', 'just', 'themselves', 'had', 'these', 'all', 'aren', 'at', 'here', 'when', 'their', 'than', 'he', 'any', 'yours', 'its', 'him', 'if', 'her', 'yourselves', 'it', 'other', 'so', 'in', 'during', 'nor', "aren't", 'who', 'm', "shan't", 'we',

In [25]:
# Define your preprocessing pipeline as a function
def preprocess_text_into_tokens_en(text: str,
                                   tokenizer_pattern: str,
                                   nlp) -> str:
  tokens: List[str] = tokenize_text_en(
    text = text,
    tokenizer_pattern = tokenizer_pattern,
  )
  tokens: List[str] = lemmatize_tokens_en(
    tokens = tokens,
    nlp = nlp,
  )
  tokens: List[str] = remove_stop_words_en(
    tokens = tokens,
    stop_words = nltk_stop_words_set,
  )
  words: str = join_words_en(
      tokens = tokens,
  )
  return words

def preprocess_answer(text: str) -> str:
    text = re.sub(r'[\^*\\~-].*', '', text)
    return text

# Apply preprocess to all data
dataset_preprocessed_en: Dataset = dataset_raw.map(
  lambda row: dict(
    question = preprocess_text_into_tokens_en(
      text = row['question'],
      tokenizer_pattern = tokenizer_pattern,
      nlp = nlp,
    ),
    answer = preprocess_answer(
        text = row['answer']
    ),
  ),
  num_proc = (multiprocessing.cpu_count()),
)

print(f"preprocessed dataset: {dataset_preprocessed_en}")


Map (num_proc=4):   0%|          | 0/20000 [00:00<?, ? examples/s]

preprocessed dataset: Dataset({
    features: ['question', 'answer', 'author_id'],
    num_rows: 20000
})


In [26]:

# Specify the size of your evaluation set. For example, 0.1 means 10% of the data will be used for evaluation.
eval_size = 0.025

num_eval_samples = int(eval_size * len(dataset_preprocessed_en))
num_test_samples = len(dataset_preprocessed_en) - int(eval_size * len(dataset_preprocessed_en))

# Split the dataset
dataset_dict = DatasetDict({
    'test' : dataset_preprocessed_en.select(range(num_test_samples, len(dataset_preprocessed_en))),
    'train': dataset_preprocessed_en.select(range(num_eval_samples, num_test_samples)),
    'eval': dataset_preprocessed_en.select(range(num_eval_samples))
})

print(f"Number of samples in the training set: {len(dataset_dict['train'])}")
print(f"Number of samples in the evaluation set: {len(dataset_dict['eval'])}")


Number of samples in the training set: 19000
Number of samples in the evaluation set: 500


In [27]:
question_list = dataset_dict['test']['question']
answer_list = dataset_dict['test']['answer']

In [28]:
dataset_dict['train']: Dataset = dataset_dict['train'].map(
    lambda row: dict(text=instruction_format.format(
        **row
    )),
    remove_columns=['question', 'answer', 'author_id'],
    num_proc = (multiprocessing.cpu_count()),
)
dataset_dict['test']: Dataset = dataset_dict['test'].map(
    lambda row: dict(text=instruction_format.format(
        **row
    )),
    remove_columns=['question', 'answer', 'author_id'],
)
dataset_dict['eval']: Dataset = dataset_dict['eval'].map(
    lambda row: dict(text=instruction_format.format(
        **row
    )),
    remove_columns=['question', 'answer', 'author_id'],
)
print(f"dataset_text: {dataset_dict['eval']}")
print(f"example text: \n {dataset_dict['eval'][3]['text']}")
print(f"example text: \n {dataset_dict['eval'][13]['text']}")

Map (num_proc=4):   0%|          | 0/19000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

dataset_text: Dataset({
    features: ['text'],
    num_rows: 500
})
example text: 
 Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Question:
tf wrong keyboard

### Answer:
 Fill us in on what is happening, then we can help out from there.
example text: 
 Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Question:
HOME BUTTON DOESNT work IOS11

### Answer:
 Let us help with your Home button. Did this issue start right after iOS 11? Which version of iOS 11 are you running?


In [62]:
print(f"dataset_text: {dataset_dict['eval']}")
print(f"example text: \n {dataset_dict['eval'][3]['text']}")
print(f"example text: \n {dataset_dict['eval'][13]['text']}")

dataset_text: Dataset({
    features: ['text'],
    num_rows: 500
})
example text: 
 Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Question:
tf wrong keyboard

### Answer:
 Fill us in on what is happening, then we can help out from there.
example text: 
 Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Question:
HOME BUTTON DOESNT work IOS11

### Answer:
 Let us help with your Home button. Did this issue start right after iOS 11? Which version of iOS 11 are you running?


In [29]:
def tokenize_row(row: Dict[str, Any]) -> Dict[str, Any]:
  return tokenizer(row['text'])

dataset_preprocessed_train: Dataset = dataset_dict['train'].map(
    tokenize_row,
    remove_columns=['text'],  # Anda mungkin perlu menghapus kolom 'target' juga
    num_proc = (multiprocessing.cpu_count()),
)
dataset_preprocessed_eval: Dataset = dataset_dict['eval'].map(
    tokenize_row,
    remove_columns=['text'],  # Anda mungkin perlu menghapus kolom 'target' juga
    
)
dataset_preprocessed_test: Dataset = dataset_dict['test'].map(
    tokenize_row,
    remove_columns=['text'],  # Anda mungkin perlu menghapus kolom 'target' juga
    
)

dataset_preprocessed_train

Map (num_proc=4):   0%|          | 0/19000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 19000
})

In [44]:
def tokenize_row(row: Dict[str, Any]) -> Dict[str, Any]:
    # Misalkan 'text' adalah teks sumber dan 'target' adalah teks target
    encoding = tokenizer(row['question'], truncation=True, padding='max_length', max_length=128)
    # Anda perlu menambahkan 'labels' ke dalam encoding
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(row['answer'], truncation=True, padding='max_length', max_length=128)['input_ids']
    encoding['labels'] = labels
    return encoding

dataset_preprocessed_train: Dataset = dataset_dict['train'].map(
    tokenize_row,
    remove_columns=['question', 'answer', 'author_id'],  # Anda mungkin perlu menghapus kolom 'target' juga
    num_proc = (multiprocessing.cpu_count()),
)
dataset_preprocessed_eval: Dataset = dataset_dict['eval'].map(
    tokenize_row,
    remove_columns=['question', 'answer', 'author_id'],  # Anda mungkin perlu menghapus kolom 'target' juga
    num_proc = (multiprocessing.cpu_count()),
)
dataset_preprocessed_test: Dataset = dataset_dict['test'].map(
    tokenize_row,
    remove_columns=['question', 'answer', 'author_id'],  # Anda mungkin perlu menghapus kolom 'target' juga
    num_proc = (multiprocessing.cpu_count()),
)
dataset_preprocessed_train

     

#0:   0%|          | 0/4750 [00:00<?, ?ex/s]



 

#1:   0%|          | 0/4750 [00:00<?, ?ex/s]



 

#2:   0%|          | 0/4750 [00:00<?, ?ex/s]



 

#3:   0%|          | 0/4750 [00:00<?, ?ex/s]



     

#0:   0%|          | 0/125 [00:00<?, ?ex/s]



 

#1:   0%|          | 0/125 [00:00<?, ?ex/s]



 

#2:   0%|          | 0/125 [00:00<?, ?ex/s]

 



#3:   0%|          | 0/125 [00:00<?, ?ex/s]



     

#0:   0%|          | 0/125 [00:00<?, ?ex/s]



  

#1:   0%|          | 0/125 [00:00<?, ?ex/s]



#2:   0%|          | 0/125 [00:00<?, ?ex/s]



 

#3:   0%|          | 0/125 [00:00<?, ?ex/s]



Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 19000
})

## Supervised Training

Although the model we load previously already pre-trained and capable enough for language modeling, we will try to illustrate how to further train the model with supervised causal language modeling objective.

In [30]:
!pip install evaluate
!pip install rouge_score

Collecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)
Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.2
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=08977021630ede0ec97eb1477bfd6402449f0f15bf41363f9e2b2376d842743d
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [30]:
references = []
for i in answer_list:
    x = []
    x.append(i)
    references.append(x)

In [31]:
import evaluate
def evaluate_bleu_rouge(generate_type):
    predictions = []
    for q in question_list:
        try:
            answer = generate_type(q)
        except:
            answer = ''
        predictions.append(generate_type(q))
    bleu = evaluate.load('bleu')
    results_bleu = bleu.compute(predictions=predictions, references=answer_list)
    rouge = evaluate.load('rouge')
    results_rouge = rouge.compute(predictions=predictions, references=answer_list)
    return (results_bleu, results_rouge)

In [32]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [88]:
model = AutoModelForSeq2SeqLM.from_pretrained("/kaggle/working/facebook/bart-base-stackexchange/checkpoint-1000")
model.to(device)

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [33]:
print('gpt2-causal')
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir=f"{model_id}-stackexchange",
    per_device_train_batch_size=2, # important, to prevent GPU OOM
    gradient_accumulation_steps=8,
    num_train_epochs=num_train_epochs,
    push_to_hub=False,
    save_strategy='no'
)
# for more args, visit:
# https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_preprocessed_train,
    eval_dataset=dataset_preprocessed_eval,
    data_collator=data_collator,
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


gpt2-causal


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
500,1.9021
1000,1.7158
1500,1.6286
2000,1.5873
2500,1.5604
3000,1.5245
3500,1.5248


TrainOutput(global_step=3561, training_loss=1.6328970039536932, metrics={'train_runtime': 1235.6954, 'train_samples_per_second': 46.128, 'train_steps_per_second': 2.882, 'total_flos': 2103782588928000.0, 'train_loss': 1.6328970039536932, 'epoch': 3.0})

In [64]:
print('distilgpt2-causal')
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir=f"{model_id}-stackexchange",
    per_device_train_batch_size=2, # important, to prevent GPU OOM
    gradient_accumulation_steps=8,
    num_train_epochs=num_train_epochs,
    push_to_hub=False,
    save_strategy='no'
)
# for more args, visit:
# https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_preprocessed_train,
    eval_dataset=dataset_preprocessed_eval,
    data_collator=data_collator,
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


distilgpt2-causal


Step,Training Loss
500,2.0164
1000,1.8008
1500,1.7171
2000,1.6782
2500,1.6488
3000,1.6205
3500,1.62


TrainOutput(global_step=3561, training_loss=1.7269465707052092, metrics={'train_runtime': 791.0548, 'train_samples_per_second': 72.056, 'train_steps_per_second': 4.502, 'total_flos': 1051910290243584.0, 'train_loss': 1.7269465707052092, 'epoch': 3.0})

In [56]:
print('bart-causal')
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir=f"{model_id}-stackexchange",
    per_device_train_batch_size=2, # important, to prevent GPU OOM
    gradient_accumulation_steps=8,
    num_train_epochs=num_train_epochs,
    push_to_hub=False,
    save_strategy='no'
)
# for more args, visit:
# https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_preprocessed_train,
    eval_dataset=dataset_preprocessed_eval,
    data_collator=data_collator,
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


bart-causal


Step,Training Loss
500,0.6788
1000,0.0015
1500,0.0005
2000,0.0002
2500,0.0002
3000,0.0001
3500,0.0001


TrainOutput(global_step=3561, training_loss=0.09567870958495077, metrics={'train_runtime': 791.232, 'train_samples_per_second': 72.04, 'train_steps_per_second': 4.501, 'total_flos': 1441508000919552.0, 'train_loss': 0.09567870958495077, 'epoch': 3.0})

In [45]:
print('t5-seq')
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir=f"{model_id}-stackexchange",
    per_device_train_batch_size=2, # important, to prevent GPU OOM
    gradient_accumulation_steps=8,
    num_train_epochs=num_train_epochs,
    push_to_hub=False,
    save_strategy='no'
)
# for more args, visit:
# https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_preprocessed_train,
    eval_dataset=dataset_preprocessed_eval,
    data_collator=data_collator,
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


t5-seq


Step,Training Loss
500,0.5516
1000,0.4541
1500,0.4302
2000,0.4225
2500,0.4108
3000,0.4059
3500,0.4063


TrainOutput(global_step=3561, training_loss=0.4396354109676525, metrics={'train_runtime': 3089.8389, 'train_samples_per_second': 18.448, 'train_steps_per_second': 1.152, 'total_flos': 8673996193136640.0, 'train_loss': 0.4396354109676525, 'epoch': 3.0})

In [38]:
print('gpt2-causal')
print(evaluate_bleu_rouge(generate_text_sampling_top_p_nucleus))

gpt2-causal


Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

({'bleu': 0.03260060193646305, 'precisions': [0.16569327276660178, 0.041629613161405066, 0.01932311914017837, 0.00847457627118644], 'brevity_penalty': 1.0, 'length_ratio': 1.7815028901734105, 'translation_length': 18492, 'reference_length': 10380}, {'rouge1': 0.21745952049576406, 'rouge2': 0.05818402692277768, 'rougeL': 0.17028620176287035, 'rougeLsum': 0.1703564483676806})


In [89]:
print('distilgpt2-causal')
print(evaluate_bleu_rouge(generate_text_sampling_top_p_nucleus))

distilgpt2-causal
({'bleu': 0.02653391004659951, 'precisions': [0.1490657793703609, 0.03451536643026005, 0.015376315079579175, 0.0062655946770169115], 'brevity_penalty': 1.0, 'length_ratio': 1.8819845857418112, 'translation_length': 19535, 'reference_length': 10380}, {'rouge1': 0.20421929905778385, 'rouge2': 0.05005163310237155, 'rougeL': 0.1653422985968551, 'rougeLsum': 0.165160212540932})


In [58]:
print('bard-causal')
print(evaluate_bleu_rouge(generate_text_sampling_top_p_nucleus))

bard-causal
({'bleu': 0.0, 'precisions': [0.013783719829626125, 0.000510387894800048, 0.0, 0.0], 'brevity_penalty': 1.0, 'length_ratio': 3.2570327552986513, 'translation_length': 33808, 'reference_length': 10380}, {'rouge1': 0.02601676053349125, 'rouge2': 0.003554704064793692, 'rougeL': 0.022658611281319237, 'rougeLsum': 0.0225093789638791})


In [48]:
print('t5')
print(evaluate_bleu_rouge(generate_text_sampling_top_p_nucleus))

t5
({'bleu': 0.03421938117831349, 'precisions': [0.27009222661396576, 0.06304654442877292, 0.028831562974203338, 0.012151067323481117], 'brevity_penalty': 0.692401746839994, 'length_ratio': 0.7312138728323699, 'translation_length': 7590, 'reference_length': 10380}, {'rouge1': 0.23009494851230786, 'rouge2': 0.05602997634191734, 'rougeL': 0.1925131128468977, 'rougeLsum': 0.1927995410330124})


In [34]:
def preprocess_answer(text):
  split_output = text.split('### Answer:')
  answer_split = []
  if len(split_output)>=2:
      answer_split = split_output[1].split(' ')
  answer = ''
  index=0
  for i in answer_split:
      if i!='':
          index+=1
      if index > 2 and i=='':
          break
      else:
          answer = answer+i+' '
  answer = answer.split()
  answer = ' '.join(answer)
  return answer

In [35]:
set_seed(87)
def generate_text_sampling_top_p_nucleus(
    input_prompt: str,
    min_length: str = 10,
    max_length: int = 100,
    top_p: float = 0.22,
  ) -> str:
  input_prompt: str = instruction_format.format(
      question=input_prompt,
      answer='',
  )
  encoded_input: BatchEncoding = tokenizer(input_prompt, return_tensors='pt').to(device)
  sampling_output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      do_sample=True,
      pad_token_id=50256,
      top_p=top_p,
      top_k=0,
  )
  sampling_output_text: str = tokenizer.decode(sampling_output_tensor[0], skip_special_tokens=True)
  answer = preprocess_answer(sampling_output_text)
  return answer

In [37]:
print('gpt2-causal')
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_p_nucleus(question_list[i])}')
    print(f'expected answer: {answer_list[i]}')
    print('++++++++++++++++++')

gpt2-causal
question: yooooo wtf wrong I
answer: We'd like to help. What device are you experiencing this on? What's the exact iOS version?
expected answer:  We can check it out. Which iPhone do you own? Do you have iOS 11.1 installed?
++++++++++++++++++
question: thx show I obvious however Blutooth still automatically turn soon I change Flightmode iOS 11 1
answer: We're happy to help. What version of iOS 11 are you currently running? You can check in Settings &gt; General &gt; About.
expected answer:  That's correct. When disabling Airplane mode, Bluetooth &amp; Wi
++++++++++++++++++
question: pop every 5 minute really annoying
answer: We're here to help. Which device are you using? Have you tried restarting your device? Have you tried any steps so far? If not, let's do that now. Let us know if the issue persists. Let us know if we can help.
expected answer:  We know messages like this can be confusing. Let's chat over in Direct Message about the beta alert. 
++++++++++++++++++
questi

In [88]:
print('distilgpt2-causal')
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_p_nucleus(question_list[i])}')
    print(f'expected answer: {answer_list[i]}')
    print('++++++++++++++++++')

distilgpt2-causal
question: yooooo wtf wrong I
answer: We'd like to help. What's happening with your iPhone? Which iOS version is installed on your iPhone? Also, what's the iOS version? Also, what's the exact iOS version installed? Also, what's the exact version number installed? Also, what's the exact iOS version installed? Also, what
expected answer:  We can check it out. Which iPhone do you own? Do you have iOS 11.1 installed?
++++++++++++++++++
question: thx show I obvious however Blutooth still automatically turn soon I change Flightmode iOS 11 1
answer: We're happy to help. What's going on with your iPhone? Are you using the AirDrop app or a third party app? Also, what's the exact iOS version installed?
expected answer:  That's correct. When disabling Airplane mode, Bluetooth &amp; Wi
++++++++++++++++++
question: pop every 5 minute really annoying
answer: We're here to help. Which device and iOS version are you using? What happens when you try to open the app? Also, what's the ex

In [59]:
print('bart-causal')
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_p_nucleus(question_list[i])}')
    print(f'expected answer: {answer_list[i]}')
    print('++++++++++++++++++')

bart-causal
question: yooooo wtf wrong I
answer: yooooo wtf wrong I having having having Write Write Write Answer Answer Answer instruction instruction instruction with with with Just Just Just scra scra scra kg kg kgapproapproappro helps helps helps will will will ver ver ver cover cover cover vaping vaping vaping low low lowxfxfxf scal scal scal dynamically dynamically dynamically Cannot Cannot Cannotvariablevariablevariable column column column 6 6 6 move move moveminationminationminationogieogieogieortunortunortun202020cgicgicgi receive receive receive click click click
expected answer:  We can check it out. Which iPhone do you own? Do you have iOS 11.1 installed?
++++++++++++++++++
question: thx show I obvious however Blutooth still automatically turn soon I change Flightmode iOS 11 1
answer: thx show I obvious however Blutooth still automatically turn soon I change Flightmode iOS 11 1
expected answer:  That's correct. When disabling Airplane mode, Bluetooth &amp; Wi
+++++++++++++

In [47]:
print('t5-seq')
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_p_nucleus(question_list[i])}')
    print(f'expected answer: {answer_list[i]}')
    print('++++++++++++++++++')

t5-seq
question: yooooo wtf wrong I
answer: We'd be happy to look into this with you. What version of iOS are you running?
expected answer:  We can check it out. Which iPhone do you own? Do you have iOS 11.1 installed?
++++++++++++++++++
question: thx show I obvious however Blutooth still automatically turn soon I change Flightmode iOS 11 1
answer: We'd be happy to help. Which device are you using? Are you getting any errors?
expected answer:  That's correct. When disabling Airplane mode, Bluetooth &amp; Wi
++++++++++++++++++
question: pop every 5 minute really annoying
answer: We're here to help. Are you using the same app on both devices?
expected answer:  We know messages like this can be confusing. Let's chat over in Direct Message about the beta alert. 
++++++++++++++++++
question: I I go trip I usaly download line use
answer: We'd love to help. Which version of iOS are you using?
expected answer:  No worries! You'll be able to watch the movie. Your device will play it at the high

In [36]:
print('bard-seq')
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_p_nucleus(question_list[i])}')
    print(f'expected answer: {answer_list[i]}')
    print('++++++++++++++++++')

bard
question: yooooo wtf wrong I
answer:  We'd like to help. What's going on with your iPhone?
expected answer:  We can check it out. Which iPhone do you own? Do you have iOS 11.1 installed?
++++++++++++++++++
question: thx show I obvious however Blutooth still automatically turn soon I change Flightmode iOS 11 1
answer:  Thanks for reaching out to us. We'd like to look into this with you. Can you tell us which version of iOS 11 you're using?
expected answer:  That's correct. When disabling Airplane mode, Bluetooth &amp; Wi
++++++++++++++++++
question: pop every 5 minute really annoying
answer:  We'd like to help. Which iPhone and iOS version are you using?
expected answer:  We know messages like this can be confusing. Let's chat over in Direct Message about the beta alert. 
++++++++++++++++++
question: I I go trip I usaly download line use
answer:  We'd like to help. What happens when you try to download the download?
expected answer:  No worries! You'll be able to watch the movie. Y

In [90]:
trainer.evaluate(dataset_preprocessed_eval)

{'eval_loss': 0.5648128986358643,
 'eval_runtime': 3.718,
 'eval_samples_per_second': 140.665,
 'eval_steps_per_second': 17.751,
 'epoch': 3.0}

In [68]:
trainer.evaluate(dataset_preprocessed_eval)

{'eval_loss': 0.5641575455665588,
 'eval_runtime': 5.7748,
 'eval_samples_per_second': 90.565,
 'eval_steps_per_second': 11.429,
 'epoch': 1.0}

## Inference

Let's try to use our model to generate some text. You can edit and try other prompts.

In [None]:
model()

In [91]:
def generate_text_by_instruction(question: str, min_length: str = 10, max_length: int = 1000) -> str:
  encoded_input: BatchEncoding = tokenizer(question, return_tensors='pt').to(device)
  output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      repetition_penalty=2.0,
  )
  output_text: str = tokenizer.batch_decode(output_tensor, skip_special_tokens=True)[0]
  return output_text

In [42]:
for i in [0,2,5,27,94]:
    print(f'question {question_list[i]}')
    print(f'answer {answer_list[i]}')
    print('++++++++')

question I attend attbizsummit week come see presentation
answer  Glad you are getting excited! Be sure and give us a sneak peek behind the curtain at the #ATTBizSummit.
++++++++
question thank Amber return charter spectrum stuff tomorrow make change account seeyousoon
answer  Awesome Megan! Please send your telephone number and address, and we'll take care of you. Awaiting your reply! Thank you.
++++++++
question thank agent provide incorrect information go wrong store still problem persist att attwireless fail apple iphonex
answer  Hello Frank! We don't like to see that the agent provided you with the incorrect information.Let's fix the problem now once and for all. Is there something we can help you with? We look forward to hearing back from you today. 
++++++++
question please PUT ON YOUTUBE INTERNATIONAL fan CAN SEE
answer  Visit  for a sneak peek of the first episode about the making of Gorgeous. #TaylorSwiftNOW
++++++++
question get wifi instal morning I already problem anyone k

In [35]:
print('bard')
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_p_nucleus(question_list[i])}')
    print(f'expected answer: {answer_list[i]}')
    print('++++++++++++++++++')

question: yooooo wtf wrong I
answer:  We're here to help. What's going on with your iPhone?
expected answer:  We can check it out. Which iPhone do you own? Do you have iOS 11.1 installed?
++++++++++++++++++
question: thx show I obvious however Blutooth still automatically turn soon I change Flightmode iOS 11 1
answer:  We'd like to help get you to the right spot for help. Reach out to us here: 
expected answer:  That's correct. When disabling Airplane mode, Bluetooth &amp; Wi
++++++++++++++++++
question: pop every 5 minute really annoying
answer:  We're here to help. Which iOS version are you running on your iPhone? Also, does restarting the device help at all?
expected answer:  We know messages like this can be confusing. Let's chat over in Direct Message about the beta alert. 
++++++++++++++++++
question: I I go trip I usaly download line use
answer:  We're here to help. Check out this article and let us know if it helps: 
expected answer:  No worries! You'll be able to watch the mov

Note: The generated text will vary across every run, but in general, we can see that GPT2 model is better at generating text about book content (such as The Lean Startup) compared to factual fact (such as Albert Einstein). This is mostly due to the fact that GPT2 was pre-trained using [BookCorpus](https://en.wikipedia.org/wiki/BookCorpus) and skipped the Wikipedia corpus.

In [46]:
print(generate_text_by_instruction(question='a'))

 I'm sorry to hear about this. Please let me know if there is anything else I can do to help. 


## Decoding Methods for Text Generation

This section will describe several decoding methods we can use for text generation.

In [47]:
# Some variables we will use for decoding
prompt_for_decoding: str = "Thank you  for listening and helping when I had an issue today. Great customer service."
random_seed: int = 49

set_seed(random_seed)

### Greedy Search

Greedy search is the most commonly used decoding method. Simply choose the word with the highest probability at every timestep.

![Greedy Search](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/greedy_search.png)

The problem with greedy search is that, like any greedy algorithm, it might miss a cumulatively better trajectory.

In [48]:
def generate_text_greedy(input_prompt: str, min_length: str = 10, max_length: int = 1000) -> str:
  encoded_input: BatchEncoding = tokenizer(input_prompt, return_tensors='pt').to(device)
  greedy_output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
  )
  greedy_output_text: str = tokenizer.batch_decode(greedy_output_tensor, skip_special_tokens=True)[0]
  return greedy_output_text


In [49]:
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_greedy(input_prompt=question_list[i])}')
    print('++++++++++++++++++')

question: I attend attbizsummit week come see presentation
answer:  We are here for you. Please let us know if you have any questions. 
++++++++++++++++++
question: thank Amber return charter spectrum stuff tomorrow make change account seeyousoon
answer:  Thanks for the feedback. We will share your feedback with the relevant team. 
++++++++++++++++++
question: thank agent provide incorrect information go wrong store still problem persist att attwireless fail apple iphonex
answer:  We're here to help! Please let us know if there's anything we can assist you with. 
++++++++++++++++++
question: please PUT ON YOUTUBE INTERNATIONAL fan CAN SEE
answer:  We're sorry for the trouble. Can you send us a screenshot of what you're seeing? 
++++++++++++++++++
question: get wifi instal morning I already problem anyone know well internet provider
answer:  We're here to help! Please send us a DM with your email address at  and we'll be happy to help. 
++++++++++++++++++


### Beam Search

Beam search attempt to reduce the chance of missing out hidden high probability trajectory (word sequences) by enumerating every `num_beams` depth hypothesis in each timestep, and choose the one with highest overall probability.

The illustration below describe beam search with `num_beams=2`:

![Beam search](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/beam_search.png)

In the code, we simply add `num_beams` parameter.

In [44]:
def generate_text_beam(input_prompt: str, min_length: str = 1, max_length: int = 100, num_beams: int = 5) -> str:

  encoded_input: BatchEncoding = tokenizer(input_prompt, return_tensors='pt').to(device)
  beam_output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      num_beams=num_beams,
  )
  beam_output_text: str = tokenizer.batch_decode(beam_output_tensor, skip_special_tokens=True)[0]
  return beam_output_text


In [45]:
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_beam(input_prompt=question_list[i])}')
    print('++++++++++++++++++')

question: look like
answer: We'd like to look into this with you. Can you tell us which device you're using?
++++++++++++++++++
question: I need answer annoying
answer: We'd like to look into this with you. Which iOS version are you using?
++++++++++++++++++
question: thank I update phone even slow barely work thank ruin phone
answer: We'd like to look into this with you. Which iOS version are you using?
++++++++++++++++++
question: app fuccin download update
answer: We'd like to look into this with you. Which version of iOS are you using?
++++++++++++++++++
question: I buy itune gift card worth 15 week ago email still come inbox tell I code helpppp
answer: We'd like to help. Which iOS version are you using?
++++++++++++++++++


The problem with beam search is the existence of repetition of the same word sequences. This is because the nature of probability of common word sequence is very high.

One of the solution for this problem is to give penalty to n-grams (sequence of n words), which is introduced by [Paulus et al. (2017)](https://arxiv.org/abs/1705.04304) and [Klein et al. (2017)](https://arxiv.org/abs/1701.02810). In the code, we simpy use `no_repeat_ngram_size` parameter.

In [46]:
def generate_text_beam_no_repeat(
    input_prompt: str,
    min_length: str = 5,
    max_length: int = 100,
    num_beams: int = 5,
    no_repeat_ngram_size: int = 2,
  ) -> str:
  encoded_input: BatchEncoding = tokenizer(input_prompt, return_tensors='pt').to(device)
  beam_output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      num_beams=num_beams,
      no_repeat_ngram_size=no_repeat_ngram_size,
  )
  beam_output_text: str = tokenizer.batch_decode(beam_output_tensor, skip_special_tokens=True)[0]
  return beam_output_text


In [47]:
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_beam_no_repeat(input_prompt=question_list[i])}')
    print('++++++++++++++++++')

question: look like
answer: We'd like to look into this with you. Can you tell us which device you're using?
++++++++++++++++++
question: I need answer annoying
answer: We'd like to look into this with you. Which iOS version are you using?
++++++++++++++++++
question: thank I update phone even slow barely work thank ruin phone
answer: We'd like to look into this with you. Which iOS version are you using?
++++++++++++++++++
question: app fuccin download update
answer: We'd like to look into this with you. Which version of iOS are you using?
++++++++++++++++++
question: I buy itune gift card worth 15 week ago email still come inbox tell I code helpppp
answer: We'd like to help. Which iOS version are you using?
++++++++++++++++++


### Sampling

Sampling simply randomly choose the next word from all probability distribution. The text generation will be non-deterministic.

![vanilla_sampling](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/sampling_search.png)

In the code, use `do_sample=True` to enable sampling.

In [48]:
set_seed(random_seed)
def generate_text_sampling(
    input_prompt: str,
    min_length: str = 10,
    max_length: int = 100,
  ) -> str:
  encoded_input: BatchEncoding = tokenizer(input_prompt, return_tensors='pt').to(device)
  sampling_output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      do_sample=True,
      top_k=0,
  )
  sampling_output_text: str = tokenizer.batch_decode(sampling_output_tensor, skip_special_tokens=True)[0]
  return sampling_output_text



In [49]:
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling(input_prompt=question_list[i])}')
    print('++++++++++++++++++')

question: look like
answer: We want to help. Would you mind at any point doing something like this?
++++++++++++++++++
question: I need answer annoying
answer: Try following the steps below to Respond. How long is 24h fielding this as well as keystrokes and is it all the distance?
++++++++++++++++++
question: thank I update phone even slow barely work thank ruin phone
answer: We'd like you to consider our iPhone upgrade, How long has this since you updated? Also, do you see specific symptoms in the settings?
++++++++++++++++++
question: app fuccin download update
answer: Good question. Might still seem to be out of date. After asking for support from the company, there's a better solution in high school. Just wanted to remind you how to administer your app?
++++++++++++++++++
question: I buy itune gift card worth 15 week ago email still come inbox tell I code helpppp
answer: Let's find out what is going on. Which device and version of iOS are you using?
++++++++++++++++++


In [40]:
!pip install googletrans==4.0.0-rc1

Collecting googletrans==4.0.0-rc1
  Downloading googletrans-4.0.0rc1.tar.gz (20 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting httpx==0.13.3 (from googletrans==4.0.0-rc1)
  Downloading httpx-0.13.3-py3-none-any.whl (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.1/55.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting hstspreload (from httpx==0.13.3->googletrans==4.0.0-rc1)
  Downloading hstspreload-2024.2.1-py3-none-any.whl.metadata (2.1 kB)
Collecting chardet==3.* (from httpx==0.13.3->googletrans==4.0.0-rc1)
  Downloading chardet-3.0.4-py2.py3-none-any.whl (133 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.4/133.4 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting idna==2.* (from httpx==0.13.3->googletrans==4.0.0-rc1)
  Downloading idna-2.10-py2.py3-none-any.whl (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25

In [41]:
from googletrans import Translator

In [44]:
translator=Translator()
per = ['kenapa ini internet lag?',
      'mau berapa lama lagi mati lampu ini?',
      'aku sangat lelah']
for i in per:
    pertanyaan = translator.translate(i, dest='en').text
    print(f'question: {i}')
    jawaban = translator.translate(generate_text_sampling(input_prompt=pertanyaan), dest='id').text
    print(f'answer: {jawaban}')
    print('++++++++++++++++++')

question: kenapa ini internet lag?
answer: Bisakah Anda menyenangkan PMD untuk membantu dan mendukung Anda di sana?Mari kita coba melanggar ini saat istirahat dan lihat apa yang terjadi.
++++++++++++++++++
question: mau berapa lama lagi mati lampu ini?
answer: Hai Uesha, kami benci melihat orang -orang mengemudi ke dalamnya, apakah kami mengirimnya?Jadi tentang mereka di Australia.
++++++++++++++++++
question: aku sangat lelah
answer: Hai Peter.Saya menyesal mendengar Anda lelah.
++++++++++++++++++


In [None]:
from googletrans import Translator

text=("How to convert some text to multiple languages")
destination_language = {
    "Spanish": "es",
    "Chinese":"zh-CN",
    "Italian":"it"
}
translator=Translator()
for key, value in destination_language.items():
    print(translator.translate(text, dest=value).text)

In [None]:
for i in [123,1992,8334,1406,1958]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling(input_prompt=question_list[i])}')
    print('++++++++++++++++++')

Because of the randomness of sampling method, the generated text, in most cases, is not very coherent and often generate incoherent gibberish, *cf.* [Ari Holtzman et al. (2019)](https://arxiv.org/abs/1904.09751).

### Top-K Sampling

Top-K Sampling method, introduced by [Fan et. al (2018)](https://arxiv.org/pdf/1805.04833.pdf), take K most likely words, redistribute the probability, and take the next word from those K words.

The following illustrate top-6 sampling:
![top_k_sampling](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/top_k_sampling.png)

To use top-K sampling in the code, set the `top_k` parameter.

In [93]:
set_seed(87)
def generate_text_sampling_top_k(
    input_prompt: str,
    min_length: str = 5,
    max_length: int = 100,
    top_k: int = 50,
  ) -> str:
  encoded_input: BatchEncoding = tokenizer(input_prompt, return_tensors='pt').to(device)
  sampling_output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      do_sample=True,
      top_k=top_k,
  )
  sampling_output_text: str = tokenizer.batch_decode(sampling_output_tensor, skip_special_tokens=True)[0]
  return sampling_output_text


In [70]:
print(evaluate_bleu_rouge(generate_text_sampling_top_k))

({'bleu': 0.002940595919354412, 'precisions': [0.14511098880939277, 0.007900568455535216, 0.0006087662337662338, 0.00010713520462824084], 'brevity_penalty': 1.0, 'length_ratio': 1.0412607449856734, 'translation_length': 10902, 'reference_length': 10470}, {'rouge1': 0.11885415419743303, 'rouge2': 0.007510586134077362, 'rougeL': 0.09568029128045283, 'rougeLsum': 0.0960208776544145})


In [94]:
print(evaluate_bleu_rouge(generate_text_sampling_top_k))

({'bleu': 0.013012794616588214, 'precisions': [0.17553348411839906, 0.018453244868339207, 0.0050422010303628195, 0.001976284584980237], 'brevity_penalty': 0.9708340185357985, 'length_ratio': 0.9712511938872971, 'translation_length': 10169, 'reference_length': 10470}, {'rouge1': 0.1483322355622818, 'rouge2': 0.019198926125339008, 'rougeL': 0.12082343331647175, 'rougeLsum': 0.12103604339065499})


In [75]:
predictions = []
for i in question_list:
    predictions.append(generate_text_sampling_top_k(input_prompt=i))


In [76]:
bleu = evaluate.load('bleu')
results = bleu.compute(predictions=predictions, references=references)
print(results)

{'bleu': 0.045400842938890094, 'precisions': [0.27326266195524146, 0.060075093867334166, 0.030040053404539385, 0.0178826895565093], 'brevity_penalty': 0.8331282169616259, 'length_ratio': 0.8456175298804781, 'translation_length': 1698, 'reference_length': 2008}


In [77]:
bleu = evaluate.load('rouge')
results = bleu.compute(predictions=predictions, references=references)
print(results)

{'rouge1': 0.2352147412057099, 'rouge2': 0.06216038809900354, 'rougeL': 0.19587656424067534, 'rougeLsum': 0.19434646196576177}


In [51]:
predictions = []
for i in question_list:
    predictions.append(generate_text_sampling_top_k(input_prompt=i))


In [53]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=0d49d668eb3b548b7d1c9aa31008845e934c024fee64378d58d058d40c58150b
  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [57]:
import evaluate
bleu = evaluate.load('bleu')
results = bleu.compute(predictions=predictions, references=references)
print(results)

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

{'bleu': 0.019608733532141362, 'precisions': [0.19124218051831993, 0.029934518241347054, 0.008341511285574092, 0.0030959752321981426], 'brevity_penalty': 1.0, 'length_ratio': 1.1145418326693226, 'translation_length': 2238, 'reference_length': 2008}


In [99]:
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_k(input_prompt=question_list[i])}')
    print('++++++++++++++++++')

question: I attend attbizsummit week come see presentation
answer:  Thank you for bringing this to our attention. We will forward this to the appropriate teams. 
++++++++++++++++++
question: thank Amber return charter spectrum stuff tomorrow make change account seeyousoon
answer:  We're here to help! Send us a note via  so our team can connect.
++++++++++++++++++
question: thank agent provide incorrect information go wrong store still problem persist att attwireless fail apple iphonex
answer:  We're sorry to hear about the trouble. Please send us a Direct Message, so that we can further assist you. 
++++++++++++++++++
question: please PUT ON YOUTUBE INTERNATIONAL fan CAN SEE
answer:  Hey there! Can you tell us what device, operating system, and Spotify version you're using? We'll see what we can suggest /RS
++++++++++++++++++
question: get wifi instal morning I already problem anyone know well internet provider
answer:  We're here to help anyway we can. Please let us know if there's an

In [61]:
per = ['kenapa ini internet lag?',
      'mau berapa lama lagi mati lampu ini?',
      'menunggu sangat membosankan']
for i in per:
    print(f'question: {i}')
    print(f'answer: {generate_text_sampling_top_k(input_prompt=i)}')
    print('++++++++++++++++++')

question: kenapa ini internet lag?
answer: Hi there, where did you find that article from? If you have seen it you'll get a link for that.
++++++++++++++++++
question: mau berapa lama lagi mati lampu ini?
answer: We appreciate your prayers for our community. Please remember, we do not have all information, we just need to provide contact information. Please get in touch with us for assistance. We will be in touch as soon as we can in case of any difficulties.
++++++++++++++++++
question: menunggu sangat membosankan
answer: You might have been seeking help for a while but luckily here we are to assist with your issue :)
++++++++++++++++++


In [60]:
translator=Translator()
per = ['kenapa ini internet lag?',
      'mau berapa lama lagi mati lampu ini?',
      'menunggu sangat membosankan']
for i in per:
    pertanyaan = translator.translate(i, dest='en').text
    print(f'question: {pertanyaan}')
    jawaban = translator.translate(generate_text_sampling_top_k(input_prompt=pertanyaan), dest='id').text
    print(f'answer: {generate_text_sampling_top_k(input_prompt=pertanyaan)}')
    print('++++++++++++++++++')

NameError: name 'Translator' is not defined

In [46]:
translator=Translator()
per = ['kenapa ini internet lag?',
      'mau berapa lama lagi mati lampu ini?',
      'aku sangat lelah']
for i in per:
    pertanyaan = translator.translate(i, dest='en').text
    print(f'question: {i}')
    jawaban = translator.translate(generate_text_sampling_top_k(input_prompt=pertanyaan), dest='id').text
    print(f'answer: {jawaban}')
    print('++++++++++++++++++')

question: kenapa ini internet lag?
answer: :) Bisakah Anda membagikan tautan situs web kami sehingga kami dapat berbicara lebih dekat tentang masalah Anda?
++++++++++++++++++
question: mau berapa lama lagi mati lampu ini?
answer: Hai.Beri tahu kami seberapa cepat Anda dikirim jika Anda memiliki pertanyaan.Terima kasih!
++++++++++++++++++
question: aku sangat lelah
answer: Saya menyesal mendengar ini!Saya telah berjuang dengan sakit punggung saya selama 2 hingga 3 minggu terakhir.Bagaimana kalau mengejar berita terbaru
++++++++++++++++++


### Top-p (Nucleus) Sampling

Similar to Top-K sampling, Top-p sampling method take set of words in which the cumulative probability exceeds p, and take the next words from that set. With this, the number of candidate words can be dynamically increase and decrease depending on the previous words.

![top_p_sampling](https://github.com/patrickvonplaten/scientific_images/blob/master/top_p_sampling.png?raw=true)

In [83]:
import torch
from transformers import AdamW

# Set the seed for reproducibility
torch.manual_seed(87)

# Define optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

def generate_text_sampling_top_p_nucleus(input_prompt: str, min_length: int = 10, max_length: int = 100, top_p: float = 0.92) -> str:
    encoded_input = tokenizer(input_prompt, return_tensors='pt', max_length=512, truncation=True).to(device)
    sampling_output_tensor = model.generate(
        input_ids=encoded_input['input_ids'],  # Add this line
        attention_mask=encoded_input['attention_mask'],  # Add this line
        min_length=min_length,
        max_length=max_length,
        do_sample=True,
        top_p=top_p,
        top_k=0,
    )
    sampling_output_text = tokenizer.decode(sampling_output_tensor[0], skip_special_tokens=True)
    return sampling_output_text


def interact_with_user_and_learn():
    reward=0
    # Interaction loop
    while True:
        # Get user input
        user_input = input("You: ")

        # Generate response
        response = generate_text_sampling_top_p_nucleus(user_input)

        # Print response
        print("Bot:", response)

        # Get feedback from user
        feedback = input("Was the response helpful? (yes/no): ")

        # Calculate reward based on user feedback
        if feedback.lower() == 'yes':
            reward += 1
        else:
            reward -= 1

        # Perform backpropagation with reward
        optimizer.zero_grad()
        encoded_user_input = tokenizer(user_input, return_tensors='pt', max_length=512, truncation=True).to(device)
        encoded_response = tokenizer(response, return_tensors='pt', max_length=512, truncation=True).to(device)
        model_input = {key: torch.cat([encoded_user_input[key], encoded_response[key]], dim=1) for key in encoded_user_input}
        output = model(input_ids=model_input['input_ids'], attention_mask=model_input['attention_mask'], decoder_input_ids=model_input['input_ids'])
        loss = -reward * output.logits[:, -1].mean()  # Reinforcement learning loss
        loss.backward()
        optimizer.step()

        # Ask if the user wants to continue
        continue_learning = input("Do you want to continue learning? (yes/no): ")
        if continue_learning.lower() != 'yes':
            break



# Start interaction and learning loop
interact_with_user_and_learn()


You:  thank agent provide incorrect information go wrong store still problem persist att attwireless fail apple iphonex


Bot: :) Could you share what is happening on the device? Thank you!


Was the response helpful? (yes/no):  yes
Do you want to continue learning? (yes/no):  yes
You:  why i cant use my apple iphonex?


Bot: Hi Cathryn! I have been getting the same problem this morning. Have a nice day.


Was the response helpful? (yes/no):  no
Do you want to continue learning? (yes/no):  yes
You:  i dont care about your problem


Bot: We use your new contact form and can look into this as we have documentation


Was the response helpful? (yes/no):  no
Do you want to continue learning? (yes/no):  no


In [96]:
set_seed(87)
def generate_text_sampling_top_p_nucleus(
    input_prompt: str,
    min_length: str = 10,
    max_length: int = 100,
    top_p: float = 0.92,
  ) -> str:
  encoded_input: BatchEncoding = tokenizer(input_prompt, return_tensors='pt').to(device)
  sampling_output_tensor: Tensor = model.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      do_sample=True,
      top_p=top_p,
      top_k=0,
  )
  sampling_output_text: str = tokenizer.decode(sampling_output_tensor[0], skip_special_tokens=True)
  return sampling_output_text

In [None]:
from transformers import T5TokenizerFast, AutoModelForSeq2SeqLM, AutoTokenizer

In [None]:
tok = AutoTokenizer.from_pretrained("t5-base")


In [None]:
mod = AutoModelForSeq2SeqLM.from_pretrained("/kaggle/working/t5-base-stackexchange/checkpoint-2000/")
mod

In [None]:
mod.to(device)

In [None]:
set_seed(87)
def generate_text_sampling_top_p_nucleus_2(
    input_prompt: str,
    min_length: str = 10,
    max_length: int = 100,
    top_p: float = 0.92,
  ) -> str:
  encoded_input: BatchEncoding = tok(input_prompt, return_tensors='pt').to(device)
  sampling_output_tensor: Tensor = mod.generate(
      **encoded_input,
      min_length=min_length,
      max_length=max_length,
      do_sample=True,
      top_p=top_p,
      top_k=0,
  )
  sampling_output_text: str = tok.decode(sampling_output_tensor[0], skip_special_tokens=True)
  return sampling_output_text

In [79]:
predictions_2 = []
for i in question_list:
    predictions_2.append(generate_text_sampling_top_p_nucleus(input_prompt=i))


In [80]:
bleu = evaluate.load('bleu')
results = bleu.compute(predictions=predictions_2, references=references)
print(results)

{'bleu': 0.04692441522877583, 'precisions': [0.26963207029104885, 0.05926786751888437, 0.028994447871684145, 0.015779092702169626], 'brevity_penalty': 0.9024059279063897, 'length_ratio': 0.9068725099601593, 'translation_length': 1821, 'reference_length': 2008}


In [81]:
bleu = evaluate.load('rouge')
results = bleu.compute(predictions=predictions_2, references=references)
print(results)

{'rouge1': 0.23861089999503549, 'rouge2': 0.05924843334675213, 'rougeL': 0.19187467001698252, 'rougeLsum': 0.19171920822355748}


In [60]:
predictions_2 = []
for i in question_list:
    predictions_2.append(generate_text_sampling_top_p_nucleus(input_prompt=i))


In [61]:
bleu = evaluate.load('bleu')
results = bleu.compute(predictions=predictions_2, references=references)
print(results)

{'bleu': 0.009259904716647453, 'precisions': [0.16756513926325248, 0.012229539040451553, 0.0034550839091806516, 0.0010384215991692627], 'brevity_penalty': 1.0, 'length_ratio': 1.1085657370517927, 'translation_length': 2226, 'reference_length': 2008}


In [62]:
bleu = evaluate.load('rouge')
results = bleu.compute(predictions=predictions_2, references=references)
print(results)

{'rouge1': 0.16246295194878127, 'rouge2': 0.014241026596730572, 'rougeL': 0.12424949272359068, 'rougeLsum': 0.12365995347125072}


In [97]:
for i in [0,2,5,27,94]:
    print(f'question: {question_list[i]}')
    print(f'answer: {generate_text_sampling_top_p_nucleus(input_prompt=question_list[i])}')
    print('++++++++++++++++++')

question: I attend attbizsummit week come see presentation
answer:  Hi, I am sorry to hear this. Please send us a DM with your email address at, and we will look into this for you.
++++++++++++++++++
question: thank Amber return charter spectrum stuff tomorrow make change account seeyousoon
answer:  You're welcome! Please let us know if you need anything else. 
++++++++++++++++++
question: thank agent provide incorrect information go wrong store still problem persist att attwireless fail apple iphonex
answer:  We're sorry to hear this. Please let us know if we can be of any further assistance. 
++++++++++++++++++
question: please PUT ON YOUTUBE INTERNATIONAL fan CAN SEE
answer:  Sorry to hear that. Please send us a DM with your email address at, and we will take a closer look.
++++++++++++++++++
question: get wifi instal morning I already problem anyone know well internet provider
answer:  Hello, please be informed that your internet service has been restored. Thank you. 
+++++++++++++

In [98]:
print(evaluate_bleu_rouge(generate_text_sampling_top_p_nucleus))

({'bleu': 0.009035301228865698, 'precisions': [0.17311054983485524, 0.020161702998669533, 0.003568339100346021, 0.0005730002292000917], 'brevity_penalty': 0.9830479918013059, 'length_ratio': 0.9831900668576886, 'translation_length': 10294, 'reference_length': 10470}, {'rouge1': 0.15012473614300312, 'rouge2': 0.020537517738697803, 'rougeL': 0.12234546494444018, 'rougeLsum': 0.1221873227371291})


In [54]:
print(evaluate_bleu_rouge(generate_text_sampling_top_p_nucleus))

Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

({'bleu': 0.013740191741393062, 'precisions': [0.1822091681527326, 0.023751387347391788, 0.006480499587604572, 0.0018830027617373838], 'brevity_penalty': 0.9063858777300878, 'length_ratio': 0.9105062082139446, 'translation_length': 9533, 'reference_length': 10470}, {'rouge1': 0.15212199454041325, 'rouge2': 0.02278546912322785, 'rougeL': 0.12419858674582787, 'rougeLsum': 0.12421887262731825})


In [56]:
trainer.evaluate(dataset_preprocessed_eval)

{'eval_loss': 0.5634387731552124,
 'eval_runtime': 3.7143,
 'eval_samples_per_second': 140.808,
 'eval_steps_per_second': 17.769,
 'epoch': 1.0}

In [68]:
trainer.evaluate(dataset_dict['eval'])

{'eval_loss': 0.7091214060783386,
 'eval_runtime': 0.5951,
 'eval_samples_per_second': 87.374,
 'eval_steps_per_second': 11.762,
 'epoch': 1.0}

In [None]:
quest = 'hey my connection is so bad right now'
print(f'question: {quest}')
print(f'answer: {generate_text_sampling_top_p_nucleus_2(input_prompt=quest)}')

In [None]:
quest = 'I already send you the DM'
print(f'question: {quest}')
print(f'answer: {generate_text_sampling_top_p_nucleus(input_prompt=quest)}')

In [None]:
quest = 'I dont understand'
print(f'question: {quest}')
print(f'answer: {generate_text_sampling_top_p_nucleus(input_prompt=quest)}')

### Conclusion

We have seen various decoding methods for text generation using LLM. Several important notes:
- Beam can produce a more fluent text compared to Greedy, but require more computation
- Top-p and Top-K sampling can produce a more fluent text compared to Greedy and Beam
- Top-K and Top-p sampling also suffer from generating repetitive word sequences ([Welleck et al. (2020)](https://arxiv.org/abs/2002.02492))

This is the end of this tutorial. Feel free to try other models, datasets, methods, parameters, or instructions.

If you have any question, feel free to contact: adrianus.saga21@ui.ac.id

## References
- https://huggingface.co/docs/transformers/tasks/language_modeling
- https://huggingface.co/blog/how-to-generate