In [1]:
%%capture
!pip install transformers

In [None]:
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import torch
from torch.utils.data import Dataset 
import random
import time
import datetime
import random
from transformers import GPT2LMHeadModel, GPT2Config
import numpy as np
from torch.utils.data import random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

In [1]:
import base64
import requests

master = "https://raw.githubusercontent.com/mcelikkaya/medium_articles/main/japan_wiki.txt"
req = requests.get(master)
req = req.text

In [2]:
all_sentences = req.split("\n")
all_sentences = [s.replace("\r","") for s in all_sentences]


In [3]:
print("sample size : ",len(all_sentences))
print("samples     : " )
all_sentences[0:10]

sample size :  40389
samples     : 


['Hokkaido was formerly known as Ezo  Yezo  Yeso  or Yesso.',
 'According to Matsuura  the name was thought up because the Ainu called the region Kai.',
 'In contrast to the island of Honshu  Hokkaido saw an absence of conflict during this time period.',
 'From the Middle Ages  the people in Hokkaido began to be called Ezo.',
 'Hokkaido subsequently became known as Ezochi  蝦夷地  lit.',
 'The disputes eventually developed into war.',
 'Takeda Nobuhiro killed the Ainu leader  Koshamain  and defeated the opposition in 1457.',
 'The Matsumae family s economy relied upon trade with the Ainu.',
 'They held authority over the south of Ezochi until the end of the Edo period.',
 'There were numerous revolts by the Ainu against the feudal rule.']

47709


In [19]:
from transformers import GPT2Tokenizer
#get pretrained tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<sos>', eos_token='<eos>', pad_token='<pad>')

#tokenizer some samples
print( tokenizer.encode("Japan Tokyo") )
print( tokenizer.encode("Japan") )
print( tokenizer.encode("japan tokyo") )
print( tokenizer.encode("japan") )
print( tokenizer.encode("tokyo") )

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.



[16504, 11790]
[16504]
[73, 2674, 284, 2584, 78]
[73, 2674]
[83, 482, 8226]


In [20]:
max_len = max([len(tokenizer.encode(s)) for s in all_sentences])

print(f"max_len {max_len}")

max_len 85


In [22]:
#since we will be feeding with sentences from wikipedia
#we can mark beginning and end of sentences with with sos and eos
def tokenize_seq(sent,tokenizer,max_length):
  return tokenizer('<sos>'+ sent + '<eos>', truncation=True, max_length=max_length, padding="max_length")

class JapanDataset(Dataset):

  def __init__(self, sentences, tokenizer, gpt2_type="gpt2", max_length=max_len):

    self.tokenizer = tokenizer 
    self.input_ids = []
    self.attn_masks = []

    for sentence in sentences:      
      encodings = tokenize_seq(sentence,tokenizer,max_length)
            
      self.input_ids.append(torch.tensor(encodings['input_ids']))
      self.attn_masks.append(torch.tensor(encodings['attention_mask']))
    
  def __len__(self):
    return len(self.input_ids)

  def __getitem__(self, idx):
    return self.input_ids[idx], self.attn_masks[idx]   

def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))    

In [23]:
import gc
gc.collect() 


13771

In [24]:
#create an instance of Dataset
dataset = JapanDataset(all_sentences, tokenizer, max_length=max_len)

# Split into training and validation sets
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size

train_set, val_set = random_split(dataset, [train_size, val_size])
print("train_size :",train_size)
print("val_size   :",val_size)

gc.collect() 

train_size : 42938
val_size   : 4771


0

In [25]:
#lets check a sample from dataset 
#50257 beginning of sentence token
#50258 end of sentence token
#50259 pad token
dataset[0]

(tensor([50257,    39,   482,    74, 44354,   373, 15734,  1900,   355,   412,
         10872,   220,   575,  8471,    78,   220,  3363,    78,   220,   393,
           575,   408,    78,    13, 50258, 50259, 50259, 50259, 50259, 50259,
         50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
         50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
         50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
         50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
         50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
         50259, 50259, 50259, 50259, 50259]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

In [26]:
#define dataloaders
train_dataloader = DataLoader(train_set,  sampler = RandomSampler(train_set), batch_size = 32)
validation_dataloader = DataLoader(val_set, sampler = SequentialSampler(val_set), batch_size = 32 )

In [27]:
# Create default config
configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)
# Load pretrained gpt2
model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)
model.resize_token_embeddings(len(tokenizer))

# Create device
device = torch.device("cuda")
model.cuda()


optimizer = torch.optim.Adam(model.parameters(),lr = 0.0005)
model = model.to(device)

In [29]:
#at every step i want to check if generations are getting better.
def eval_keywords(keywords):
  model.eval()
  for keyword in keywords:
    input_seq = "<sos> " + keyword
    generated = torch.tensor(tokenizer.encode(input_seq)).unsqueeze(0)
    generated = generated.to(device)
    sample_outputs = model.generate(
                                generated, 
                                do_sample=True,   
                                top_k=30, 
                                max_length = 50,
                                top_p=0.90, 
                                num_return_sequences=2
                                )
    for i, sample_output in enumerate(sample_outputs):
      print("{}: {}".format(i, tokenizer.decode(sample_output, skip_special_tokens=True)))

keywords = ["Osaka","Japan","Kyoto","Yokohama","Kanto","Nikko","Japan has","Tokyo is the","Osaka is the","Kyoto is the"]

In [31]:
#eval_keywords( keywords )

In [32]:
#call model with a batch of input
def process_one_batch(batch):
  b_input_ids = batch[0].to(device)
  b_labels = batch[0].to(device)
  b_masks = batch[1].to(device)
  outputs  = model(b_input_ids,  attention_mask = b_masks,labels=b_labels)
  return outputs

#do one epoch for training
def train_epoch():
  t0 = time.time()
  total_train_loss = 0
  model.train()
  for step, batch in enumerate(train_dataloader):
        
        model.zero_grad()        
        outputs = process_one_batch( batch)
        loss = outputs[0]  
        batch_loss = loss.item()
        total_train_loss += batch_loss

        loss.backward()
        optimizer.step()

        
  avg_train_loss = total_train_loss / len(train_dataloader)  
  print("avg_train_loss",avg_train_loss)  
  elapsed_time = format_time(time.time() - t0)
  print("elapsed time for 1 training epoch : ",elapsed_time)

#do one epoch for eval
def eval_epoch():
  t0 = time.time()
  total_eval_loss = 0
  nb_eval_steps = 0
  # Evaluate data for one epoch
  for batch in validation_dataloader:            
        
    with torch.no_grad():        
      outputs = process_one_batch( batch)
      loss = outputs[0]              
      batch_loss = loss.item()
      total_eval_loss += batch_loss         

  avg_val_loss = total_eval_loss / len(validation_dataloader)
  print("avg_val_loss",avg_val_loss) 
  elapsed_time = format_time(time.time() - t0)
  print("elapsed time for 1 eval epoch : ",elapsed_time)

In [33]:
#train eval 1 cycle
#then create sample sentences
train_epoch()
eval_epoch()
eval_keywords( keywords )

avg_train_loss 0.9285519781748691
elapsed time for 1 training epoch :  0:16:46
avg_val_loss 0.776570185025533
elapsed time for 1 eval epoch :  0:00:40
0:  Osaka City  Japan.
1:  Osaka  洋酢   the main shopping district in Osaka  is the largest shopping mall in Japan.
0:  Japan s population growth rate is estimated at 2  by 2050.
1:  Japan  快酢  lit.
0:  Kyoto has one of the world s highest natural gas reserves  the largest by volume in the world.
1:  Kyoto   平田聞境   Aomori Ichiran.
0:  Yokohama became the first Japanese city to host the 1964 Summer Olympics.
1:  Yokohama  石田の扛島  is the city of Toyosu in Japan.
0:  Kanto is the most populous island in the area.
1:  Kanto is the most populous province of Japan.
0:  Nikko  奀  is a type of rice cake made of rice rice flour.
1:  Nikko Giongård  a member of the Danish Parliament  was appointed the first Prime Minister.
0:  Japan has the third highest death toll in the world  after the United States and Italy.
1:  Japan has won the bronze medal i

In [34]:
train_epoch()
eval_epoch()
eval_keywords( keywords )

avg_train_loss 0.6879545370708102
elapsed time for 1 training epoch :  0:17:06
avg_val_loss 0.7604602559407552
elapsed time for 1 eval epoch :  0:00:40
0:  Osaka  The Life of Tokugawa Yorinori.
1:  Osaka   Kobe University of Tokyo.
0:  Japan s Modern History  Transnational Memory  Cultural Memory  and Culture.
1:  Japan s Great Northern Magellan and the Origins of its Pacific War.
0:  Kyoto  Osaka  Tokyo  Nagoya  Nagoya  Kobe  Kyoto University.
1:  Kyoto  A Japanese Military History of Japan.
0:  Yokohama  Tokyo Bay  Japan.
1:  Yokohama Castle  Yakuza Monogatari.
0:  Kanto  閑立  
Hokusai  閑立    or   Hakushikari  閑立    lit.   Hakushu  
1:  Kanto Kansai Shinkai  日本誓事文堂  in Japanese.
0:  Nikko Yūkawa
An Illustrated Encyclopedia of Japan  1st ed..
1:  Nikko Naruhito
The Tale of Two Sisters  The Life of a Mom.
0:  Japan has never been a strong patron of foreign languages.
1:  Japan has no significant domestic reserves in the sea.
0:  Tokyo is the most populous city and the second smallest in

In [35]:
train_epoch()
eval_epoch()
eval_keywords( keywords )

avg_train_loss 0.5925561110371094
elapsed time for 1 training epoch :  0:17:05
avg_val_loss 0.7761015864213308
elapsed time for 1 eval epoch :  0:00:40
0:  Osaka   Methodist publishing house  永原祬県  Shukkeikan Shukkeikan   kabuki  is a famous kabuki theater in Japan.
1:  Osaka   Shrines and the Shrines of the Osaka Plain   in Japanese.
0:  Japan s Anti Nucleararmament Act of 1925  The Politics of the Twenty One States.
1:  Japan s population has surpassed 600 million by 2050—and will surpass that by 2050.
0:  Kyoto  Kodansha International  1989.
1:  Kyoto  Hakata Shiga – Hakata Shiga – Hakata Shiga
Hirohirose  Tomohuki   Kato Mie  Sugata  Hōshi  Tetsu  Masatari  Kato
0:  Yokohama   University of Michigan Press.
1:  Yokohama    The largest city and metropolitan area in Japan.
0:  Kanto Kita no Hironi  The Tale of Genji.
1:  Kanto   which means  high   is the mountain above the Sea of Japan  山館   the Sea of Japan  山�    a.e.g.
0:  Nikko Koškěi – Škubyŋ   head of state   lit.   the head of

In [36]:
train_epoch()
eval_epoch()
eval_keywords( keywords )

avg_train_loss 0.5100571858056789
elapsed time for 1 training epoch :  0:17:06
avg_val_loss 0.8096200331052145
elapsed time for 1 eval epoch :  0:00:40
0:  Osaka  Methodist publishing house.
1:  Osaka   Tokyo  Kansai  Kansai International.
0:  Japan s foreign diplomatic system  in Japanese.
1:  Japan  Japanism  and the Japanese Miracle  1922–1941. 
0:  Kyoto  a Japanese major exporter  is located in Kyoto.
1:  Kyoto  Tokyo  Shochu  Kansai  Hanyaku  Kansai Togidaro  eds.
0:  Yokohama was also a founding member of the ICAFF  the first international association for the sport.
1:  Yokohama Prefecture 住治県  Yokohama ken  is a prefecture of Japan located on the island of Kyūshū.
0:  Kanto Island   – formerly Kansai I.
Kanto Airport  Kansai International Airport  Kansai International Airport  Kobe.
1:  Kanto was located at the heart of the Seto Inland Sea  which separates Japan from the others. 
0:  Nikko Krim won the bronze with a score of 1 117.776 seconds.
1:  Nikko Hachiman   born 1954  Di

In [37]:
train_epoch()
eval_epoch()
eval_keywords( keywords )

avg_train_loss 0.4388331177737425
elapsed time for 1 training epoch :  0:17:07
avg_val_loss 0.857030572493871
elapsed time for 1 eval epoch :  0:00:40
0:  Osaka  Kodansha International   1998.
1:  Osaka    listen   is a Commercial City of Japan.
0:  Japan is the only country within the G7 to still own a mobile phone.
1:  Japan s Changing Defense Policy   Politics and Economics.
0:  Kyoto     Tottori – Iwo Island   

Tottori was previously called Kyushu  ʔʔtōn.
1:  Kyoto  Kodansha Ltd. 1993. p. 943–98.
0:  Yokohama and Tokyo were designated as  key station for  U.S..
1:  Yokohama Expressway forks from the Chūgoku Expressway at Kita Hiroshima and stretches from Chūgoku Station.
0:  Kanto     
Kanto s dialects are split into three parts  dialect   i.e.
1:  Kanto Island  
Yoshima Island   or Yoshima Island 


   References   


  Further reading  
Akiba  2000.
0:  Nikko Filipinos  Japanese    Tatopoulos  Frank.
1:  Nikko Popov has been a member of the Eurovision Song Contest since 1997.
0:

In [38]:
train_epoch()
eval_epoch()
eval_keywords( keywords )

avg_train_loss 0.3771453394308886
elapsed time for 1 training epoch :  0:17:09
avg_val_loss 0.9076750135421753
elapsed time for 1 eval epoch :  0:00:40
0:  Osaka  New York  Kodansha International  1979  pp.
1:  Osaka  Chiyoda Tokyo  Osaka Chiyoda  Chiyoda Seiyama Gahōsha アシェット報構本麿  アシェット報��
0:  Japan s World Cities  Tokyo  Nagano  Shirokuma Yatosei  Taito Ward  1999.
1:  Japan at the Crossroads  Conflict and Compromise after Anpo. 
0:  Kyoto 
The place of my religion is Zen meditation  which forms the Buddhist religion of Japan.
1:  Kyoto  Shrines and Buddhist Convent in Kyoto  Chiyoda  2002   92 pp.
0:  Yokohama  Ponsonby Memorial Society. 
1:  Yokohama   Ponsonby Memorial Society. 
0:  Kanto   which in turn was influenced by westernization.
1:  Kanto      or  ho  is not a final word.
0:  Nikko Tange presented it to him on November 6  2008.
1:  Nikko and Romko in Japan  A History  1959 
Shiga s education became largely free.
0:  Japan has been a landlocked country since the end of the