In [1]:
from utils.loader import DataLoader
from models.gpt import GPT2
import numpy as np
import pandas as pd

In [2]:
SPECIAL_TOKENS  = { "bos_token": "<|BOS|>",
                    "eos_token": "<|EOS|>",
                    "unk_token": "<|UNK|>",                    
                    "pad_token": "<|PAD|>",
                    "sep_token": "<|SEP|>"}

In [3]:
def sample_start_amazon(df, length=5):
    sample = df.sample(n=1)
    title, category, text = list(sample['REVIEW_TITLE'])[0], list(sample['PRODUCT_CATEGORY'])[0], list(sample['REVIEW_TEXT'])[0]
    sample = str(text).split(' ')
    return ' '.join(sample[:length]), title, category, text

def sample_start_gold(df, length=5):
    sample = df.sample(n=1)
    text = list(sample['REVIEW_TEXT'])[0]
    sample = str(text).split(' ')
    return ' '.join(sample[:length]), text

In [4]:
# Load our test-data that we will be sampling categories and prompts from
data_loader = DataLoader()
data_amazon = data_loader.load_amazon(test_mode=True, deceptive=False)
data_gold = data_loader.load_gold_txt()

  return func(*args, **kwargs)


In [5]:
# Load our model
model_path = 'training/distilgpt-topic2/pytorch_model.bin'
model = GPT2(model_path=model_path, full_model=False, special_tokens=SPECIAL_TOKENS)

In [8]:
# These are the available categories
categories = ['Apparel', 'Automotive', 'Baby', 'Beauty', 'Books', 'Camera', 'Electronics', 'Furniture', 'Grocery', 'Health & Personal Care', 'Home', 'Home Entertainment', 'Home Improvement', 'Jewelry', 'Kitchen', 'Lawn and Garden', 'Luggage', 'Musical Instruments', 'Office Products', 'Outdoors', 'PC', 'Pet Products', 'Shoes', 'Sports', 'Tools', 'Toys', 'Video DVD', 'Video Games', 'Watches', 'Wireless']
start_words = ['A', 'The', 'We', 'I', 'This', 'I love', 'I hate', '']

In [None]:
# Sample a random prompt and corresponding category from the dataset and gemerate
prompt, title, cat, original = sample_start_amazon(data_amazon, length=np.random.randint(2, 4))
prompt = SPECIAL_TOKENS['bos_token'] + cat + SPECIAL_TOKENS['sep_token'] + prompt
print(f'Text: {original[:200]} \nPrompt: {prompt}...\n')
outputs = model.generate_text(prompt, cat, print_output=True, do_sample=True, max_length=70, num_beams=5, repetition_penalty=5.0, early_stopping=True, num_return_sequences=3)

In [54]:
# Begin generating samples
# 25k will be sampled with random category and random start word from OPSpam as it has better grammar
# 25k sampled from Amazon dataset with corresponding category and first 2-5 words and let GPT finish
fake_reviews = pd.DataFrame(columns=['PRODUCT_CATEGORY', 'REVIEW_TEXT'])

In [58]:
# Here we sample a random category, and a random start word.
all_reviews = []
while len(all_reviews) < 25000:
    print(f'{len(all_reviews)}/25000')
    prompt, original = sample_start_gold(data_gold, length=1)
    cat = np.random.choice(categories)
    prompt = SPECIAL_TOKENS['bos_token'] + cat + SPECIAL_TOKENS['sep_token'] + prompt
    outputs = model.generate_text(prompt, cat, print_output=False, do_sample=True, max_length=200, num_beams=5, repetition_penalty=5.0, early_stopping=True, num_return_sequences=3)
    for review in outputs:
        if len(review) > 10: # Ensure text generated is text
            all_reviews.append([cat, review])

0/25000
3/25000
6/25000
6/25000
9/25000
12/25000
15/25000
18/25000
21/25000
24/25000
27/25000
30/25000
33/25000
36/25000
37/25000
40/25000
43/25000
46/25000
49/25000
52/25000
55/25000
58/25000
61/25000
64/25000
67/25000
70/25000
73/25000
76/25000
79/25000
82/25000
85/25000
88/25000
91/25000
94/25000
97/25000
100/25000
103/25000
106/25000
109/25000
112/25000
115/25000
115/25000
118/25000
118/25000
121/25000
124/25000
127/25000
130/25000
133/25000
136/25000
139/25000
139/25000
142/25000
142/25000
145/25000
148/25000
151/25000
154/25000
157/25000
160/25000
163/25000
166/25000
169/25000
170/25000
173/25000
176/25000
179/25000
182/25000
185/25000
188/25000
191/25000
192/25000
195/25000
198/25000
198/25000
201/25000
201/25000
204/25000
207/25000
210/25000
213/25000
216/25000
219/25000
222/25000
225/25000
228/25000
231/25000
234/25000
237/25000
239/25000
242/25000
245/25000
245/25000
246/25000
249/25000
252/25000
253/25000
256/25000
259/25000
262/25000
265/25000
268/25000
271/25000
274/25000
