# HuggingTweets - Tweet Generation with Huggingface

*Disclaimer: this demo must not be used to publish false generated information; it is intended for research on Natural Language Generation (NLG).*

In [None]:
# Huggingface scripts for fine-tuning models and language generation:
# run_language_modeling.py is invoked by finetune() and run_generation.py by predict().
# NOTE(review): these URLs follow the transformers `master` branch, so the files may
# move or change their command-line flags upstream; consider pinning a release tag that
# matches the installed transformers version. `-q` also silences download errors.
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/language-modeling/run_language_modeling.py -q
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/text-generation/run_generation.py -q

In [None]:
import ipywidgets as widgets
from IPython.display import display
import json
import urllib3
import random
import wandb

## Step 1 - Download tweets

We choose a Twitter user and download their tweets.

*Note*: HuggingTweets only works well if the user has posted a lot of tweets!

In [None]:
def fix_text(text):
    """Decode the HTML entities ``&lt;``, ``&gt;`` and ``&amp;`` in tweet text.

    ``&amp;`` is decoded last on purpose: decoding it first would turn an
    escaped entity such as ``&amp;lt;`` (a literal "&lt;" in the tweet) into
    "&lt;" and then decode it a second time into "<".

    Args:
        text: raw tweet text as returned by the download endpoint.

    Returns:
        The text with those three entities replaced by their characters.
    """
    text = text.replace('&lt;', '<')
    text = text.replace('&gt;', '>')
    return text.replace('&amp;', '&')

In [None]:
def cleanup_tweet(tweet):
    """Clean tweet text.

    Drops every whitespace-separated token containing 'http', re-joins the
    remaining tokens with single spaces, and removes a leading token that is
    exactly '.'.
    """
    words = [word for word in tweet.split() if 'http' not in word]
    if words[:1] == ['.']:
        words = words[1:]
    return ' '.join(words)

In [None]:
def boring_tweet(tweet):
    """Check if this is a boring tweet.

    A tweet is boring when it has fewer than three words, or when every one
    of its words contains at least one of the substrings in ``boring_stuff``
    (compared against the lowercased word).

    Args:
        tweet: cleaned tweet text.

    Returns:
        True if the tweet should be left out of the training file.
    """
    # NOTE(review): 'I' can never match because each word is lowercased before
    # the check; it is kept so the current filtering stays unchanged.
    boring_stuff = ['http', '@', '#', 'thank', 'thanks', 'I', 'you']
    words = tweet.split()
    if len(words) < 3:
        return True
    # Check word by word; iterating over `tweet` itself would test characters.
    return all(any(bs in word.lower() for bs in boring_stuff) for word in words)

In [None]:
def dl_tweets(handle_value):
    """Download a user's tweets, preview a few and write the training file.

    Tweets are fetched from the remote ``tweets_http`` endpoint, decoded with
    `fix_text`, shuffled, and the first 5 are printed as examples. They are
    then cleaned with `cleanup_tweet`, filtered with `boring_tweet`, and
    written one per line to ``<handle>_train.txt``. All messages go to the
    `log_dl_tweets` output widget. The `run_dl_tweets` button is set to
    'primary' at the start, then to 'success' or 'danger' at the end.

    Args:
        handle_value: Twitter handle, with or without a leading '@'.
    """
    # `startswith` (unlike indexing [0]) does not raise on an empty handle.
    handle = handle_value[1:] if handle_value.startswith('@') else handle_value
    run_dl_tweets.button_style = 'primary'
    log_dl_tweets.clear_output()
    with log_dl_tweets:
        try:
            if not handle:
                raise ValueError('please enter a Twitter handle')
            print(f'\nDownloading {handle_value} tweets… This should take no more than a minute!')
            http = urllib3.PoolManager(retries=urllib3.Retry(3))
            # `fields` makes urllib3 URL-encode the handle into the query string.
            res = http.request(
                "GET",
                "https://us-central1-playground-111.cloudfunctions.net/tweets_http",
                fields={'handle': handle})
            if res.status != 200:
                raise RuntimeError(f'tweet download failed with HTTP status {res.status}')
            curated_tweets = json.loads(res.data.decode('utf-8'))
            curated_tweets = [fix_text(tweet) for tweet in curated_tweets]
            log_dl_tweets.clear_output()
            print(f'\n{len(curated_tweets)} tweets from {handle_value} downloaded!')
            random.shuffle(curated_tweets)
            for i, t in enumerate(curated_tweets[:5]):
                print(f'\nExample #{i+1}\n{t}')

            # create dataset
            clean_tweets = [cleanup_tweet(t) for t in curated_tweets]
            cool_tweets = [tweet for tweet in clean_tweets if not boring_tweet(tweet)]
            if not cool_tweets:
                # An empty training file would only make the training step fail later.
                raise ValueError(f'no usable tweets found for {handle_value}')
            # Explicit UTF-8 so emoji and other non-ASCII text can be written on any locale.
            with open(f'{handle}_train.txt', 'w', encoding='utf-8') as f:
                f.write('\n'.join(cool_tweets))

            run_dl_tweets.button_style = 'success'
        except Exception as e:
            print(f'An error occurred… ({e})')
            run_dl_tweets.button_style = 'danger'

In [None]:
# Step 1 UI: a handle text box and a download button in one row, with an
# output area below that dl_tweets() writes its progress into.
log_dl_tweets = widgets.Output()
with log_dl_tweets:
    print('\nEnter a Twitter handle and click "Download tweets"')

handle_widget = widgets.Text(
    value='@karpathy',
    placeholder='Enter twitter handle',
    description='User:',
)

run_dl_tweets = widgets.Button(description='Download tweets', button_style='primary')


def on_run_dl_tweets_clicked(b):
    """Button callback: download tweets for the handle currently entered."""
    dl_tweets(handle_widget.value)


run_dl_tweets.on_click(on_run_dl_tweets_clicked)

widgets.VBox([
    widgets.HBox([handle_widget, run_dl_tweets]),
    log_dl_tweets,
])

## Step 2 - Train your Neural Network

We use [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), a neural network that was trained to predict the next word by reading a large quantity of Internet text.

We fine-tune the model on our tweets using [Huggingface](https://huggingface.co/).

In [None]:
# Associate the run with a project: set the WANDB_PROJECT environment variable so that
# W&B logging during fine-tuning (presumably done by the training script's wandb
# integration, since wandb is imported above) is grouped under "huggingtweets".
%env WANDB_PROJECT=huggingtweets

In [None]:
def finetune():
    handle = handle_widget.value[1:] if handle_widget.value[0] == '@' else handle_widget.value
    run_finetune.button_style = 'primary'
    log_finetune.clear_output()
    with log_finetune:
        try:
            print(f'\nTraining Neural Network on {handle_widget.value} tweets… This could take up to 10 minutes!')
            !python run_language_modeling.py \
                --output_dir=output/$handle \
                --overwrite_output_dir \
                --model_type=gpt2 \
                --model_name_or_path=gpt2 \
                --do_train --train_data_file=$handle\_train.txt \
                --logging_steps 0 \
                --per_gpu_train_batch_size 1 \
                --num_train_epochs 4
            
            print('\n\nTraining Complete and Successful!!!')
            
            run_finetune.button_style = 'success'
        except:
            print('An error occured…')
            run_finetune.button_style = 'danger'

In [None]:
# Step 2 UI: a single "train" button with an output area below it that
# finetune() writes the training log into.
log_finetune = widgets.Output()
with log_finetune:
    print('\nFine-tune your model by clicking on "Train Neural Network"')

run_finetune = widgets.Button(description='Train Neural Network', button_style='primary')


def on_run_finetune_clicked(b):
    """Button callback: start fine-tuning on the downloaded tweets."""
    finetune()


run_finetune.on_click(on_run_finetune_clicked)

widgets.VBox([
    run_finetune,
    log_finetune,
])

## Step 3 - Visualize Predictions and Have Fun!

If the model trained successfully, we can now visualize predictions!

We just start a sentence and let the model finish it!

In [None]:
def predict():
    handle = handle_widget.value[1:] if handle_widget.value[0] == '@' else handle_widget.value
    start = start_widget.value
    run_predictions.button_style = 'primary'
    log_predictions.clear_output()
    with log_predictions:
        try:
            print(f'\nPerforming predictions of {handle_widget.value} starting with "{start}"…\nThis should take no more than a minute!')
            seed = random.randint(0, 2**32-1)
            val = !python run_generation.py \
                --model_type gpt2 \
                --model_name_or_path output/$handle \
                --length 150 \
                --stop_token "{'\n'}" \
                --num_return_sequences 5 \
                --temperature 1 \
                --seed $seed \
                --prompt {'"' + start + '"'}
            generated = [val[-1-2*k] for k in range(5)[::-1]]
            log_predictions.clear_output()
            print(f'\nPredictions of {handle_widget.value} starting with "{start}" on #huggingtweet')
            for i, g in enumerate(generated):
                g = g.replace('<|endoftext|>', '')
                print(f'\nPrediction #{i+1}: {g}')
            
            run_predictions.button_style = 'success'
        except:
            print('An error occured…')
            run_predictions.button_style = 'danger'

In [None]:
# Step 3 UI: a text box for the start of the sentence next to a "Run
# predictions" button, with an output area below that predict() writes into.
start_widget = widgets.Text(value='I want',
                            # This box takes a sentence start, not a handle.
                            placeholder='Enter the start of a sentence',
                            description='Start:')

run_predictions = widgets.Button(
    description='Run predictions',
    button_style='primary')
def on_run_predictions_clicked(b):
    """Button callback: generate completions of the sentence start."""
    predict()
run_predictions.on_click(on_run_predictions_clicked)

log_predictions = widgets.Output()
with log_predictions:
    print('\nEnter the start of a sentence and click "Run predictions"')

widgets.VBox([widgets.HBox([start_widget, run_predictions]), log_predictions])