Use this notebook in [Google Colab](https://drive.google.com/file/d/1rjJYoIJsgnnN_tkYV00MQ5fh3sf9a6Sm/view?usp=sharing)

In [None]:
%%capture
# Pin the library versions this notebook was written against; %%capture hides
# the (long) pip installation log.
!pip install datasets==1.5.0
!pip install transformers==4.5.1

In [None]:
from typing import List

import torch
from datasets import Dataset
from tensorflow.keras.utils import get_file
from transformers import AutoTokenizer, EncoderDecoderConfig, EncoderDecoderModel

In [None]:
# Base checkpoint whose vocabulary the tokenizer is loaded from.
PRETRAINED_MODEL_NAME = "bert-base-uncased"

# Fine-tuned encoder-decoder weights: a zip archive downloaded from GitHub and
# expected to unpack into TRAINED_MODEL_OUTPUT_DIR.
TRAINED_MODEL_OUTPUT_DIR = "model"
TRAINED_MODEL_URL = (
    "https://github.com/michaelnation26/yelp-review-generator"
    "/releases/download/1.0/model.zip"
)

# Token budgets: prompt length fed to the encoder, review length produced by the decoder.
ENCODER_MAX_LEN, DECODER_MAX_LEN = 32, 128

## Tokenizer

In [None]:
%%capture
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

## Load Fine-Tuned Model

In [None]:
model_zip_filename = TRAINED_MODEL_OUTPUT_DIR + ".zip"
# Download the checkpoint archive straight into the working directory
# (cache_dir="." with an empty cache_subdir) and unpack it there.
get_file(
    fname=model_zip_filename,
    origin=TRAINED_MODEL_URL,
    extract=True,
    cache_dir=".",
    cache_subdir="",
)

In [None]:
%%capture
enc_dec_model_config = EncoderDecoderConfig.from_pretrained(TRAINED_MODEL_OUTPUT_DIR)
enc_dec_model = EncoderDecoderModel.from_pretrained(TRAINED_MODEL_OUTPUT_DIR, config=enc_dec_model_config)
enc_dec_model.to("cuda")

## Utility Functions

In [None]:
def generate_reviews(
    test_ds: Dataset,
    decoder_max_length: int = DECODER_MAX_LEN,
    encoder_max_length: int = ENCODER_MAX_LEN,
    batch_size: int = 1000,
) -> List[str]:
    """Generate one review per prompt in ``test_ds`` with the fine-tuned encoder-decoder.

    Args:
        test_ds: Dataset with an ``"input_text"`` column of prompt strings
            (for example, built with ``build_input``).
        decoder_max_length: Maximum length, in tokens, of each generated review.
        encoder_max_length: Each prompt is padded/truncated to exactly this many tokens.
        batch_size: Number of prompts tokenized and passed to ``generate`` at once.

    Returns:
        The decoded reviews (special tokens stripped), in the same order as
        the rows of ``test_ds``.
    """
    # Send inputs to wherever the model actually lives instead of assuming a GPU.
    device = enc_dec_model.device

    def generate_reviews_batch(batch):
        # The BERT tokenizer wraps each prompt as [CLS] <text> [SEP]; the
        # tokenizer cell aliases those tokens as BOS/EOS.
        inputs = tokenizer(
            batch["input_text"],
            padding="max_length",
            truncation=True,
            max_length=encoder_max_length,
            return_tensors="pt",
        )
        outputs = enc_dec_model.generate(
            inputs.input_ids.to(device),
            attention_mask=inputs.attention_mask.to(device),
            max_length=decoder_max_length,
        )
        batch["generated_reviews"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return batch

    results = test_ds.map(generate_reviews_batch, batched=True, batch_size=batch_size)
    # Indexing a column of a Dataset yields a plain Python list.
    return results["generated_reviews"]

In [None]:
def build_input(
        stars: int,
        name: str,
        city: str,
        categories: List[str],
        funny: int = 50,
        elite_level: int = 0
) -> str:
    """Format one example's features as the prompt string the model was trained on.

    Fields appear in a fixed order as ``"<label> <value>"`` pairs separated by
    ``"; "``; the categories are themselves joined with ``", "``.

    Args:
        stars: Star rating of the review to generate.
        name: Business name.
        city: Business city.
        categories: Business categories.
        funny: Value for the "funny" field.
        elite_level: Value for the "elite level" field.

    Returns:
        The prompt, e.g. ``"stars 5; funny 50; elite level 0; name X; city Y; categories A, B"``.
    """
    fields = (
        ("stars", stars),
        ("funny", funny),
        ("elite level", elite_level),
        ("name", name),
        ("city", city),
        ("categories", ", ".join(categories)),
    )
    return "; ".join(f"{label} {value}" for label, value in fields)

# Test Examples

#### Krusty Burger from The Simpsons

In [None]:
# One prompt per star rating (1 through 5) for the same restaurant, so the
# generated reviews can be compared across ratings.
test_input_text = dict(
    input_text=[
        build_input(
            stars=star,
            name="Krusty Burger",
            city="Springfield",
            categories=["Burgers", "Fast Food"],
        )
        for star in range(1, 6)
    ]
)
test_ds = Dataset.from_dict(test_input_text)

generate_reviews(test_ds)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




["i love krusty's, but don't have to waste a lot of time going there. they should be ashamed of their food. it's the same price as your normal fast food place. you would get better quality. the burgers are very bland and dry, like a few years ago. they serve it'll't seem much. for what they charge you for they just want to have you pay more for some... they were just giving out old burgers and you'd think they would have something to do with them. what does that mean? i don'll be going",
 "i really enjoy fast food. the food is good, but this location doesn't carry the same level of flavor i typically get at krusty burger in new jersey. the staff is friendly, the food comes hot and delicious ( though i'm just a burger - connoisseur and i love to see all the workers at the food counter smiling and chatting ).",
 "the krusty burgers here, like many of the other locations, are pretty good. the quality is quite good, though, and the customer service is very inconsistent. once in your life, 

#### Flourish & Blotts from Harry Potter

In [None]:
# One prompt per star rating (1 through 5) for the same bookshop, so the
# generated reviews can be compared across ratings.
test_input_text = dict(
    input_text=[
        build_input(
            stars=star,
            name="Flourish & Blotts",
            city="North Side, Diagon Alley",
            categories=["Books", "Magic", "Bookstores", "Wizard"],
        )
        for star in range(1, 6)
    ]
)
test_ds = Dataset.from_dict(test_input_text)

generate_reviews(test_ds)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




['if i could give zero stars i would. i bought tickets online and they only accepted cash. that\'s not how they treat their customers. i didn\'t know how a company handles a credit card. the owner said he was a " manager " and he couldn\'t assist me. then he kept telling me " they don\'thing for your card or cash. " i told him this was so bad that i couldn\'nt believe. it was like being a number to him. it is ridiculous. i won\'t be going back.',
 "i'm giving it 2 starts simply based on the fact that the experience was just for me. i've never had more of an experience than that! after a series of laughs, i'll go forward with this review for a reason but i don't think anything stands out to me. first of all, this was my first time here since they moved, so i was pretty excited to see their new magic show there, so naturally we had to do something. then we were greeted by a very friendly person who showed us around. he let us stay, and the show started over an hour later. he told",
 'i\'