In [1]:
import numpy as np
import pandas as pd

import re
import itertools

import torch


from nltk.translate.gleu_score import corpus_gleu, sentence_gleu
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu


from transformers import T5Tokenizer, TFT5Model, T5ForConditionalGeneration
from sklearn.model_selection import train_test_split

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


from torch.utils.data import Dataset, DataLoader
import datasets

from transformers import Adafactor, get_linear_schedule_with_warmup
import pytorch_lightning as pl

from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning import Trainer

import warnings
warnings.filterwarnings("ignore")


In [2]:
device = torch.device('mps')

In [3]:
class quoteT5(pl.LightningModule):
    def __init__(self, lr=5e-5, num_train_epochs=3, warmup_steps=1000):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained("t5-base")
        self.save_hyperparameters()

        self.train_losses = []
        self.val_losses = []


    def forward(self, input_ids, attention_mask, labels=None):     
        outputs = self.model(input_ids=input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1), labels=labels.squeeze(1))
        return outputs
    
    def common_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.loss

        return loss
      
    def training_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.train_losses.append(loss.item())

        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.val_losses.append(loss.item())

        return loss

    def test_step(self, batch, batch_idx):
        loss = self.common_step(batch, batch_idx)     

        return loss

    def configure_optimizers(self):
        # create optimizer
        optimizer = Adafactor(self.parameters(), relative_step=True, warmup_init=True, lr=None)
        # create learning rate scheduler        
        return {"optimizer": optimizer}

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return validation_dataloader

    def test_dataloader(self):
        return test_dataloader

In [4]:
trained = quoteT5()
trained.to(device)


Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

quoteT5(
  (model): T5ForConditionalGeneration(
    (shared): Embedding(32128, 768)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 768)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseActDense(
                (wi): Linear(in_features=768, out_features=3072, bias=False)
                (wo): Linear(in_features=30

In [16]:
tokenizer = T5Tokenizer.from_pretrained('t5-large')
encoder_input_ids = tokenizer("Write a quote about love from the perspective of Dr. Seuss", return_tensors="pt", add_special_tokens=True).input_ids
with tokenizer.as_target_tokenizer():
    decoder_input_ids = tokenizer("You know you're in love when you can't fall asleep because reality is finally better than your dreams.", return_tensors="pt", add_special_tokens=True).input_ids

outputs = trained.model(input_ids=encoder_input_ids.to(device), decoder_input_ids=decoder_input_ids.to(device), output_attentions=True)

encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

In [17]:
from bertviz import model_view
model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens= encoder_text,
    decoder_tokens = decoder_text
)

<IPython.core.display.Javascript object>

# Evaluation and testing of t5 no-train

In [7]:
df = pd.read_csv('final_quotes.csv')
df = df[~df.quote.isna()]
df = df[df['quote'].str.split().apply(len) <= 50]
df['inputs'] = df.apply(lambda x: "Write a quote about {} from the perspective of {}".format(x['tags'], x['auth']), axis=1)

In [8]:
df = df[~df.tags.isna()]

In [9]:
test = df.sample(400)


In [10]:
from tqdm import tqdm

def generate_quote(request, beam=4, ngram=3):

    inputs_encoding =  tokenizer(
        request,
        add_special_tokens=True,
        max_length= 15,
        padding = 'max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt"
        )

    
    generate_ids = trained.model.generate(
        input_ids = inputs_encoding["input_ids"].to(device),
        attention_mask = inputs_encoding["attention_mask"].to(device),
        max_length = 50,
        num_beams = beam,
        num_return_sequences = 1,
        no_repeat_ngram_size=ngram,
        early_stopping=True,
        )

    preds = [
        tokenizer.decode(gen_id,
        skip_special_tokens=True, 
        clean_up_tokenization_spaces=True)
        for gen_id in generate_ids
    ]


    max_length = 50
    stride = 512
    seq_len = inputs_encoding.input_ids.size(1)

    nlls = []
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = inputs_encoding.input_ids[:, begin_loc:end_loc]
        target_ids = input_ids.clone()
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = trained.model(input_ids.to(device), labels=target_ids.to(device), attention_mask = inputs_encoding.attention_mask.to(device))

            # loss is calculated using CrossEntropyLoss which averages over input tokens.
            # Multiply it with trg_len to get the summation instead of average.
            # We will take average over all the tokens to get the true average
            # in the last step of this example.
            neg_log_likelihood = outputs.loss * trg_len

        nlls.append(neg_log_likelihood)

        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
     
    return ("".join(preds), ppl)

In [11]:
from nltk.translate.bleu_score import SmoothingFunction
beam_widths = [2,3,4]
n_grams = [2,3,4]
smooth_fn = SmoothingFunction().method1

results_no_tune = {}
k=0.1
for beam_width, n_gram in itertools.product(beam_widths, n_grams):
    perplexities = []
    bleu = []
    gleu = []
    for idx, row in test.iterrows():    
        result, perplexity = generate_quote(row['inputs'], beam=beam_width, ngram=n_gram)
        perplexities.append(float(perplexity))
        if idx % 25 == 0:
            torch.cuda.empty_cache()
        test_set = df[(df['tags'] == row['tags']) & (df['auth'] == row['auth'])].quote

        reference_quotes = [quote.split() for quote in test_set.values]
        generated_quotes = result.split()

        bleu.append(sentence_bleu(reference_quotes, generated_quotes, smoothing_function=smooth_fn, weights=(0.25, 0.25, 0.25, 0.25), auto_reweigh=True))
        gleu.append(sentence_gleu(reference_quotes, generated_quotes))

    # store the results in the dictionary

    bleu = np.array(bleu)
    gleu = np.array(gleu)
    perplexity = np.array(perplexities)

    results_no_tune[str((beam_width, n_gram))] = {'GLEU': gleu.mean(), 'BLEU': bleu.mean(), 'PERPLEXITY': perplexity.mean()}

# analyze the results to determine the best combination of beam_width and n-gram
best_combination_gleu = max(results_no_tune, key=lambda x: results_no_tune[x]['GLEU'])
best_combination_bleu = max(results_no_tune, key=lambda x: results_no_tune[x]['BLEU'])
best_combination_perp = min(results_no_tune, key=lambda x: results_no_tune[x]['PERPLEXITY'])

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                       

In [12]:
results_no_tune

{'(2, 2)': {'GLEU': 0.02957321667788396,
  'BLEU': 0.010971643849585164,
  'PERPLEXITY': 16.318052982091903},
 '(2, 3)': {'GLEU': 0.029656324083997543,
  'BLEU': 0.010933501614821895,
  'PERPLEXITY': 16.318052982091903},
 '(2, 4)': {'GLEU': 0.029566072534121187,
  'BLEU': 0.010549144574472874,
  'PERPLEXITY': 16.318052982091903},
 '(3, 2)': {'GLEU': 0.029643973835197426,
  'BLEU': 0.01122549372106132,
  'PERPLEXITY': 16.318052982091903},
 '(3, 3)': {'GLEU': 0.029873830307368385,
  'BLEU': 0.011255282205237775,
  'PERPLEXITY': 16.318052982091903},
 '(3, 4)': {'GLEU': 0.029514074877906707,
  'BLEU': 0.01107651609284039,
  'PERPLEXITY': 16.318052982091903},
 '(4, 2)': {'GLEU': 0.0297379621848638,
  'BLEU': 0.011149472128586791,
  'PERPLEXITY': 16.318052982091903},
 '(4, 3)': {'GLEU': 0.029783591429309933,
  'BLEU': 0.011183307319227186,
  'PERPLEXITY': 16.318052982091903},
 '(4, 4)': {'GLEU': 0.029388582462604434,
  'BLEU': 0.01109612439402631,
  'PERPLEXITY': 16.318052982091903}}

In [13]:
df = pd.DataFrame(results_no_tune).transpose()

In [14]:
df

Unnamed: 0,GLEU,BLEU,PERPLEXITY
"(2, 2)",0.029573,0.010972,16.318053
"(2, 3)",0.029656,0.010934,16.318053
"(2, 4)",0.029566,0.010549,16.318053
"(3, 2)",0.029644,0.011225,16.318053
"(3, 3)",0.029874,0.011255,16.318053
"(3, 4)",0.029514,0.011077,16.318053
"(4, 2)",0.029738,0.011149,16.318053
"(4, 3)",0.029784,0.011183,16.318053
"(4, 4)",0.029389,0.011096,16.318053


In [15]:
best_combination_gleu

'(3, 3)'

In [19]:
generate_quote("Write a love quote from the perspective of Nancy Reagan", 2, 3)

  0%|                                                     | 0/1 [00:00<?, ?it/s]


('from the perspective of Nancy Reagan. Write a love quote from the point of view of the late Nancy Reagan and write a quote from Nancy Reagan’s perspective.',
 tensor(265.4142, device='mps:0'))