<a href="https://colab.research.google.com/github/meti-94/TextGeneration/blob/main/bert_autoencoder_personachat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
!tar -xf  simple-examples.tgz
!mkdir data
!mv /content/simple-examples/data/ptb.train.txt data/
!mv /content/simple-examples/data/ptb.valid.txt data/
!mv /content/simple-examples/data/ptb.test.txt data/
!rm -rf ./simple_examples

--2021-03-21 07:16:52--  http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
Resolving www.fit.vutbr.cz (www.fit.vutbr.cz)... 147.229.9.23, 2001:67c:1220:809::93e5:917
Connecting to www.fit.vutbr.cz (www.fit.vutbr.cz)|147.229.9.23|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 34869662 (33M) [application/x-gtar]
Saving to: ‘simple-examples.tgz’


2021-03-21 07:17:06 (2.60 MB/s) - ‘simple-examples.tgz’ saved [34869662/34869662]



In [1]:
!pip install transformers
!wget https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json
!mkdir models

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 18.1MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 40.1MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 49.5MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=087f

In [2]:
from torch.utils.data import Dataset, DataLoader
from transformers import EncoderDecoderModel, BertTokenizer, BertModel, BertConfig
import torch
from transformers import AdamW
from tqdm import tqdm
import json
import random 
from sklearn.model_selection import train_test_split
from datetime import datetime

import logging
logging.basicConfig(level=logging.DEBUG)

In [3]:
SPECIAL_TOKENS = ["<bos>", "<eos>", "<persona>", "<speaker1>", "<speaker2>", "<pad>"]

ATTR_TO_SPECIAL_TOKEN = {'bos_token': '<bos>', 'eos_token': '<eos>', 'pad_token': '<pad>',
                         'additional_special_tokens': ['<speaker1>', '<speaker2>', '<persona>']}

def read_data(data_json_path='/content/personachat_self_original.json'):
  with open(data_json_path) as json_file:
    data_dict = json.load(json_file)
  return data_dict

def data_to_samples(data_dict, test=False):
  samples=[]
  for dialogue in (data_dict['train'] if test==False else data_dict['valid']):
    original_persona = dialogue['personality']
    
    for item in dialogue['utterances']:
      original_persona = [original_persona[-1]] + original_persona[:-1]
      history = item['history']
      response = item['candidates'][-1]
      samples.append({
          'persona':original_persona,
          'history':history,
          'response':response
      })
  return samples
  
def bertified(samples):
  bertified_data = []
  for item in samples:
    persona = ' <persona> '.join(item['persona'])
    persona = '<bos> ' + persona
    history = ''
    speakers = [" <speaker1> ", " <speaker2> "]
    speaker = 0
    for hst in item['history'][::-1]:
      history = speakers[speaker] + hst + history
      speaker = 1 - speaker
    response = '<speaker2> ' + item['response'] + ' <eos>'
    bertified_data.append({
          'persona':persona.replace('  ', ' '),
          'history':history.replace('  ', ' '),
          'input': persona.replace('  ', ' ')+' '+history.replace('  ', ' '), 
          'response':response.replace('  ', ' ')
      })
  return bertified_data


In [4]:
class PersonaDataset_v1(Dataset):
  '''
      Convert Data to proper Tensor dataset
  '''
  def __init__(self, samples):
    self.samples = samples
    self.n_samples = len(self.samples)

  def __getitem__(self, index):
    # returns specific item
    return self.samples[index] 
  def __len__(self):
    return self.n_samples
    # returns dataset length


class PTBDataset(Dataset):
  '''
      Convert Data to proper Tensor dataset
  '''
  def __init__(self, path):
    self.texts = []
    with open(path, 'r') as fin:
      for line in fin:
        self.texts.append(line.strip())
    self.n_samples = len(self.texts)

  def __getitem__(self, index):
    # returns specific item
    return self.texts[index] 
  def __len__(self):
    return self.n_samples
    # returns dataset length





In [5]:
class TrainingLoop:
  '''
  Everything related to model training
  '''
  def __init__( self, model, tokenizer, optimizer, freezeemb=True, 
                epochs=6, save_path='./models/', **kw):
    self.model = model
    params = []
    for paramname, param in self.model.named_parameters():
      if paramname.startswith("bert.embeddings.word_embeddings"):
        if not freezeemb:
          params.append(param)
      else:
        params.append(param)
    self.optimizer = optimizer(params, **kw)
    self.tokenizer = tokenizer
    self.epochs = epochs
    self.save_path = save_path
    self.predicts = None


  def train(self, dataloader, eval_dataloader, test_dataloader):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.model = self.model.to(device)
    for epoch in range(self.epochs):
      self.model.train()
      losses = []

      for _, batch in enumerate(tqdm(dataloader, position=0, leave=True, desc=f"Train Epoch Number {epoch+1}")):
        self.model.zero_grad()
        X = self.tokenizer(batch['input'], add_special_tokens=True, max_length=512, truncation=True, padding=True)
        y = self.tokenizer(batch['response'], add_special_tokens=True, max_length=512, truncation=True, padding=True)
        X = torch.tensor(X["input_ids"])
        y = torch.tensor(y['input_ids'])
        X = X.to(device); y = y.to(device)
        outputs = self.model(input_ids=X, decoder_input_ids=y, labels=y) 
        losses.append(outputs.loss.detach().item())
        outputs.loss.backward()
        self.optimizer.step()
        # break 
      logging.info(f'Epoch number: {epoch+1} Train Loss is equal: {sum(losses)/len(losses)}') 
      self.random_predict(test_dataloader, device, number_of_samples=10)
      self.eval(eval_dataloader, epoch, device)
      self.save(f"./models/autoencoder_{epoch}_{datetime.today().strftime('%Y-%m-%d')}.pt")


  def eval(self, dataloader, epoch, device):
    self.model = self.model.to(device)
    self.model.eval()
    losses = []
    for _, batch in enumerate(tqdm(dataloader, position=0, leave=True, desc=f"Eval Epoch Number {epoch+1}")):
      with torch.no_grad():
        X = self.tokenizer(batch['input'], add_special_tokens=True, max_length=512, truncation=True, padding=True)
        y = self.tokenizer(batch['response'], add_special_tokens=True, max_length=512, truncation=True, padding=True)
        X = torch.tensor(X["input_ids"])
        y = torch.tensor(y['input_ids'])
        X = X.to(device); y = y.to(device)
        outputs = self.model(input_ids=X, decoder_input_ids=y, labels=y) 
        
        losses.append(outputs.loss.detach().item())
        # break
    logging.info(f'Epoch number: {epoch+1} Eval Loss is equal: {sum(losses)/len(losses)}')
  
  def save(self, save_path='./models/autoencoder.pt'):
    logging.info(f'Saving model ...')
    torch.save(self.model, save_path)
	
  def load(self, save_path='./models/autoencoder.pt'):
    logging.info(f'Loading model ...')
    self.model = torch.load(save_path)

  def random_predict(self, dataloader, device, number_of_samples=10):
    counter=0
    for sample in dataloader:
      counter+=1
      _input = self.tokenizer(sample['input'], add_special_tokens=True, max_length=512, padding=True)
      _input = torch.tensor(_input['input_ids'])
      _input = _input.to(device)
      self.model = self.model.to(device)
      decoder_start = torch.tensor(30526).to(device)
      generated = self.model.generate(_input, decoder_start_token_id=torch.tensor(30526).to(device))
      logging.info('Real: '+ sample['response'][0])
      logging.info(tokenizer.convert_ids_to_tokens(generated[0]))
      if counter>number_of_samples:
        break 


In [None]:


model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
model.get_encoder().resize_token_embeddings(len(tokenizer))
model.get_decoder().resize_token_embeddings(len(tokenizer))

optimizer = AdamW
kw = {'lr':0.0002, 'weight_decay':0.1}
tl = TrainingLoop(model, tokenizer, optimizer, False, **kw)

data = read_data()
data_samples = data_to_samples(data)
bertified_data = bertified(data_samples)
train, valid = train_test_split(bertified_data, test_size=0.15, random_state=99)
test_data = read_data()
test_data_samples = data_to_samples(test_data, True)
test = bertified(test_data_samples)

train_dataset = PersonaDataset_v1(train)
valid_dataset = PersonaDataset_v1(valid)
test_dataset = PersonaDataset_v1(test)

train_dataloader = DataLoader(train_dataset, batch_size=6, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=6, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tl.train(train_dataloader, valid_dataloader, test_dataloader)



# tl.save()
# ##################################################
# tl.load()
# tl.random_predict(test_dataloader, device, number_of_samples=10)
# ##################################################
# tl.readable_predict(device, print_result=True)

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140094538936080 on /root/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.637c6035640bacb831febcc2b7f7bee0a96f9b30c2d7e9ef84082d9f252f3170.lock
INFO:filelock:Lock 140094538936080 acquired on /root/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.637c6035640bacb831febcc2b7f7bee0a96f9b30c2d7e9ef84082d9f252f3170.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 433


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…

DEBUG:filelock:Attempting to release lock 140094538936080 on /root/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.637c6035640bacb831febcc2b7f7bee0a96f9b30c2d7e9ef84082d9f252f3170.lock
INFO:filelock:Lock 140094538936080 released on /root/.cache/huggingface/transformers/3c61d016573b14f7f008c02c4e51a366c67ab274726fe2910691e2a761acf43e.637c6035640bacb831febcc2b7f7bee0a96f9b30c2d7e9ef84082d9f252f3170.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
DEBUG:filelock:Attempting to acquire lock 140094534091152 on /root/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f.lock
INFO:filelock:Lock 140094534091152 acquired on /root/.cache/huggingface/transformers/a8041bf617d7f94e




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…

DEBUG:filelock:Attempting to release lock 140094534091152 on /root/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f.lock
INFO:filelock:Lock 140094534091152 released on /root/.cache/huggingface/transformers/a8041bf617d7f94ea26d15e218abd04afc2004805632abc0ed2066aa16d50d04.faf6ea826ae9c5867d12b22257f9877e6b8367890837bd60f7c54a29633f7f2f.lock





DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/config.json HTTP/1.1" 200 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification mo

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…

DEBUG:filelock:Attempting to release lock 140094529052240 on /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99.lock
INFO:filelock:Lock 140094529052240 released on /root/.cache/huggingface/transformers/45c3f7a79a80e1cf0a489e5c62b43f173c15db47864303a55d623bb3c96f72a5.d789d64ebfe299b0e416afc4a169632f903f693095b4629a7ea271d5a0cf2c99.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/added_tokens.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/special_tokens_map.json HTTP/1.1" 404 0
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https:




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…

DEBUG:filelock:Attempting to release lock 140094527329616 on /root/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79.lock
INFO:filelock:Lock 140094527329616 released on /root/.cache/huggingface/transformers/c1d7f0a763fb63861cc08553866f1fc3e5a6f4f07621be277452d26d71303b7e.20430bd8e10ef77a7d2977accefe796051e01bc2fc4aa146bc862997a1a15e79.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443





DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 0
DEBUG:filelock:Attempting to acquire lock 140094529052240 on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock
INFO:filelock:Lock 140094529052240 acquired on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "GET /bert-base-uncased/resolve/main/tokenizer.json HTTP/1.1" 200 466062


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…

DEBUG:filelock:Attempting to release lock 140094529052240 on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock
INFO:filelock:Lock 140094529052240 released on /root/.cache/huggingface/transformers/534479488c54aeaf9c3406f647aa2ec13648c06771ffe269edabebd4c412da1d.7f2721073f19841be16f41b0a70b600ca6b880c8f3df6f3535cbc704371bdfa4.lock





Train Epoch Number 1:   9%|▉         | 1694/18621 [12:01<1:50:56,  2.54it/s]