# HuggingTweets - Train a model to generate tweets

Choose your favorite Twitter account and train a language model to write new tweets based on their unique voice in just 5 minutes.

Here is an example where I fine-tuned a neural network to predict Elon Musk's next breakthrough 😉

![huggingtweets illustration](https://raw.githubusercontent.com/borisdayma/huggingtweets/master/img/example.png)
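
Before training, the demo strips links from the downloaded tweets, drops tweets with too few informative words, and joins the rest into one text file separated by GPT-2's `<|endoftext|>` token. The snippet below is a minimal standalone sketch of that preparation step, not the notebook's exact code: the helper names `strip_urls`, `is_boring` and `build_training_text` are illustrative, and the URL pattern is a simplified assumption.

```python
import random
import re

EOS = "<|endoftext|>"  # GPT-2 end-of-text token, used here as a tweet separator


def strip_urls(tweet: str) -> str:
    """Remove links and collapse all whitespace (simplified cleaning)."""
    tweet = re.sub(r"https?:\S+", "", tweet)
    return " ".join(tweet.split())


def is_boring(tweet: str, min_words: int = 3) -> bool:
    """Flag tweets with fewer than `min_words` words free of links, mentions and hashtags."""
    markers = ("http", "@", "#")
    informative = [w for w in tweet.split() if not any(m in w.lower() for m in markers)]
    return len(informative) < min_words


def build_training_text(tweets, epochs=4, seed=0):
    """Return the kept tweets and one training string with `epochs` shuffled passes."""
    rng = random.Random(seed)
    kept = [t for t in map(strip_urls, tweets) if not is_boring(t)]
    text = EOS
    for _ in range(epochs):
        rng.shuffle(kept)  # a different tweet order for every pass
        text += EOS.join(kept) + EOS
    return kept, text


sample = ["Rockets are cool https://t.co/abc", "@bob thanks!", "Mars by 2030 or bust"]
kept, text = build_training_text(sample, epochs=2, seed=1)
print(len(kept), text.count(EOS))
```

Baking the epochs into the file is why the trainer can later run for a single epoch while still seeing each tweet several times in a different order.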

## To start the demo, click on the menu at the top: "Runtime" → "Run all"

In [1]:
#@title ⠀ {display-mode: "form"}

def stylize():
    "Handle dark mode"
    display(HTML('''
    <style>
    :root {
        --table_bg: #EBF8FF;
    }
    html[theme=dark] {
        --colab-primary-text-color: #d5d5d5;
        --table_bg: #2A4365;
    }
    .jupyter-widgets {
        color: var(--colab-primary-text-color);
    }
    table {
        border-collapse: collapse !important;
    }
    td {
        text-align:left !important;
        border: solid var(--table_bg) !important;
        border-width: 1px 0 !important;
        padding: 6px !important;
    }
    tr:nth-child(even) {
        background-color: var(--table_bg) !important;
    }
    .table_odd {
        background-color: var(--table_bg) !important;
        margin: 0 !important;
    }
    .table_even {
        border: solid var(--table_bg) !important;
        border-width: 1px 0 !important;
        margin: 0 !important;
    }
    .jupyter-widgets {
        margin: 6px;
    }
    .widget-html-content {
        font-size: var(--colab-chrome-font-size) !important;
        line-height: 1.24 !important;
    }
    </style>'''))

def print_html(x):
    "Better printing"
    x = x.replace('\n', '<br>')
    display(HTML(x))
        
# Check we use GPU
import torch
from IPython.display import display, HTML, Javascript, clear_output
if not torch.cuda.is_available():
    print_html('Error: GPU was not found\n1/ click on the "Runtime" menu and "Change runtime type"\n'\
          '2/ set "Hardware accelerator" to "GPU" and click "save"\n3/ click on the "Runtime" menu, then "Run all" (below error should disappear)')
    raise ValueError('No GPU available')
else:
    # colab requires special handling
    try:
        import google.colab
        IN_COLAB = True
    except ImportError:
        IN_COLAB = False

    # Install dependencies (mainly for colab)
    if IN_COLAB:
        !pip install transformers==3.1.0 torch wandb==0.9.7 -qq

    import ipywidgets as widgets
    from IPython import get_ipython
    import json
    import urllib3
    import pathlib
    import shutil
    import requests
    import os
    import re
    import random
    import wandb
    
    stylize()
    
    log_debug = widgets.Output()
    
    # Have global access
    trainer = None
    artifact_dataset = None
    metadata = {}
    card_val = {}
    model_preview = None
    hfapi, token, namespace = None, None, None
    handle = ''

    # W&B variables
    WANDB_PROJECT = 'huggingtweets'
    WANDB_NOTES = "Github repo: https://github.com/borisdayma/huggingtweets"
    WANDB_ENTITY = 'wandb'
    HW_VERSION = 0.4
    os.environ['WANDB_NOTEBOOK_NAME'] = 'huggingtweets-demo.ipynb'  # used in wandb cli

    # HYPER-PARAMETERS
    ALLOW_NEW_LINES = False     # seems to work better
    LEARNING_RATE = 1.372e-4
    EPOCHS = 4

    def fix_text(text):
        text = text.replace('&amp;', '&')
        text = text.replace('&lt;', '<')
        text = text.replace('&gt;', '>')
        return text

    def html_table(data, title=None):
        'Create a html table'
        width_twitter = '75px'
        def html_cell(i, twitter_button=False):
            nl = "\n"
            return f'<td style="width:{width_twitter}">{i}</td>' if twitter_button else f'<td>{i.replace(nl, "<br>")}</td>'
        def html_row(row):
            return f'<tr>{"".join(html_cell(r, not i if len(row)>1 else False) for i,r in enumerate(row))}</tr>'
        body = f'<table style="width:100%">{"".join(html_row(r) for r in data)}</table>'
        title_html = f'<h3>{title}</h3>' if title else ''
        html = '<html><body>' + title_html + body + '</body></html>'
        return html

    def clean_tweet(tweet, allow_new_lines = ALLOW_NEW_LINES):
        bad_start = ['http:', 'https:']
        for w in bad_start:
            tweet = re.sub(f" {w}\\S+", "", tweet)      # removes white space before url
            tweet = re.sub(f"{w}\\S+ ", "", tweet)      # in case a tweet starts with a url
            tweet = re.sub(f"\n{w}\\S+ ", "", tweet)    # in case the url is on a new line
            tweet = re.sub(f"\n{w}\\S+", "", tweet)     # in case the url is alone on a new line
            tweet = re.sub(f"{w}\\S+", "", tweet)       # any other case?
        tweet = re.sub(' +', ' ', tweet)                # replace multiple spaces with one space
        if not allow_new_lines:                         # TODO: predictions seem better without new lines
            tweet = ' '.join(tweet.split())
        return tweet.strip()
        
    def boring_tweet(tweet):
        "Check if this is a boring tweet"
        boring_stuff = ['http', '@', '#']
        not_boring_words = len([None for w in tweet.split() if all(bs not in w.lower() for bs in boring_stuff)])
        return not_boring_words < 3

    def create_model_card(card_val, output_dir):
        model_card_url = 'https://github.com/borisdayma/huggingtweets/raw/master/model_card/README.md'
        model_card = requests.get(model_card_url).content.decode('utf-8')
        for k, v in card_val.items():
            model_card = model_card.replace(k, v)
        # make model card size unique (for correct sync in huggingface)
        model_card = model_card.replace("RANDOM_SZ", ' ' * random.randint(1, 100))
        with open(f'{output_dir}/README.md', 'w') as f:
            f.write(model_card)
    
    def on_preview_clicked(b):
        global model_preview
        global hfapi, token, namespace
        model_preview = f'https://www.huggingtweets.com/{handle}/{b.url_id}/predictions.png'
        card_val['SOCIAL_LINK'] = model_preview
        create_model_card(card_val, handle)
        readme = pathlib.Path(handle) / 'README.md'
        readme_path, readme_name = str(readme.resolve()), str(readme)
        hfapi.presign_and_upload(token, filename=readme_name, filepath=readme_path, organization=namespace)

        # Reset view
        log_model.clear_output(wait=True)
        with log_model:
            print_html("<h2>Model Preview (select a tweet to update)</h2>")
            display(HTML(f'<img src="{model_preview}" width=560 style="border: 1px solid lightgray; margin:5px;">'))
        
    def create_button(url_id):
        layout = widgets.Layout(width='70px', min_width='70px')  # fix the button width
        button = widgets.Button(
            description='Preview',
            button_style='info',
            layout = layout,
            tooltip = 'Set as model preview'
        )
        button.url_id = url_id
        button.on_click(on_preview_clicked)
        return button

    def ensure_widgets_updated(n_iter=5):
        '''ensure we get correct inputs and states are updated'''
        pass
        # used to be necessary in Colab; no longer seems needed and creates issues (e.g. in Jupyter)
        #if IN_COLAB:  # not a problem with jupyter + create print output issues
        #    for _ in range(n_iter):
        #        get_ipython().kernel.do_one_iteration()


    def dl_tweets():
        handle_widget.disabled = True
        run_dl_tweets.disabled = True
        run_dl_tweets.button_style = 'primary'
        ensure_widgets_updated()
        handle = handle_widget.value.strip()
        handle = handle[1:] if handle.startswith('@') else handle  # also safe for an empty input
        handle = handle.lower().strip()
        log_dl_tweets.clear_output(wait=True)

        success_try = False

        with log_dl_tweets:
            try:
                print_html(f'\nDownloading {handle_widget.value.strip()} tweets... This should take no more than a minute!')
                http = urllib3.PoolManager(retries=urllib3.Retry(3))
                res = http.request("GET", f"http://us-central1-huggingtweets.cloudfunctions.net/get_tweets?handle={handle}&force=1")
                res = json.loads(res.data.decode('utf-8'))
                
                # save user info
                card_val['USER_HANDLE'] = handle
                card_val['USER_NAME'] = res['user_name']
                card_val['USER_PROFILE'] = res['user_profile'].replace('_normal.', '_400x400.')
                card_val['SOCIAL_LINK'] = res['social_link']

                all_tweets = res['tweets']
                curated_tweets = [fix_text(tweet) for tweet in all_tweets]
                log_dl_tweets.clear_output(wait=True)
                print_html(f"\n{res['n_tweets']} tweets from {handle_widget.value.strip()} downloaded!\n\n")
                    
                # create dataset
                clean_tweets = [clean_tweet(tweet) for tweet in curated_tweets]
                cool_tweets = [tweet for tweet in clean_tweets if not boring_tweet(tweet)]

                # create a file based on multiple epochs with tweets mixed up
                seed_data = random.randint(0,2**32-1)
                dataRandom = random.Random(seed_data)
                total_text = '<|endoftext|>'
                for _ in range(EPOCHS):
                    dataRandom.shuffle(cool_tweets)
                    total_text += '<|endoftext|>'.join(cool_tweets) + '<|endoftext|>'

                # display a few tweets
                display(HTML(html_table([[t] for t in curated_tweets[:8]])))
                ensure_widgets_updated()  # for auto-scroll
                
                if len(total_text) / EPOCHS < 6000:
                    # need about 4000 chars for one data sample (but depends on spaces, etc)
                    raise ValueError(f"Error: this user does not have enough tweets to train a Neural Network\n{res['n_tweets']} tweets downloaded, including {res['n_RT']} RT's and {len(all_tweets) - len(cool_tweets)} boring tweets... only {len(cool_tweets)} tweets kept!")
                if len(total_text) / EPOCHS < 40000:
                    print_html('\n\n<b>Warning: this user does not have many tweets which may impact the results of the Neural Network</b>\n\n')

                print_html('\nCreating dataset...')
                ensure_widgets_updated() # for auto-scroll
                
                # log dataset
                with log_debug:
                    wandb.login(key=res['wandb'])

                    with wandb.init(name=f'@{handle}-dl_data',
                                    job_type='dl_data',
                                    config={'huggingtweets version':HW_VERSION,
                                            'handle':handle},
                                    project = WANDB_PROJECT,
                                    entity = WANDB_ENTITY,
                                    notes = WANDB_NOTES,
                                    reinit=True) as run:
                        # log raw tweets as input
                        global metadata
                        metadata={'handle':handle,
                                  'tweets downloaded': res['n_tweets'],
                                  'retweets': res['n_RT'],
                                  'tweets kept': len(all_tweets),
                                  'huggingtweets version': HW_VERSION}
                        artifact_input = wandb.Artifact(
                            f'tweets-{handle}',
                            type='raw-dataset',
                            description=f'Raw tweets from @{handle} downloaded with Tweepy',                            
                            metadata=metadata)
                        with artifact_input.new_file('tweets.txt') as f:
                            json.dump(all_tweets, f, indent=0)
                        run.log_artifact(artifact_input)

                    with wandb.init(name=f'@{handle}-preprocess',
                                    job_type='preprocess',
                                    config={'huggingtweets version':HW_VERSION,
                                            'handle':handle,
                                            'seed data':seed_data},
                                    project = WANDB_PROJECT,
                                    entity = WANDB_ENTITY,
                                    notes = WANDB_NOTES,
                                    reinit=True) as run:
                        run.use_artifact(artifact_input)
                        # log dataset as output                        
                        metadata={'handle':handle,
                                  'tweets downloaded': res['n_tweets'],
                                  'retweets': res['n_RT'],
                                  'short tweets': len(all_tweets) - len(cool_tweets),
                                  'tweets kept': len(cool_tweets),
                                  'seed data': seed_data,
                                  'epochs': EPOCHS,
                                  'huggingtweets version': HW_VERSION}
                        global artifact_dataset
                        artifact_dataset = wandb.Artifact(
                            f'dataset-{handle}',
                            type='train-dataset',
                            description=f'Dataset created from tweets of @{handle}',
                            metadata=metadata)
                        with open(f'data_{handle}_train.txt', 'w') as f:
                            f.write(total_text)
                        artifact_dataset.add_file(f'data_{handle}_train.txt')
                        run.log_artifact(artifact_dataset)
                        
                        # keep track of url
                        wandb_url = wandb.run.get_url()
                        card_val['WANDB_PREPROCESS'] = wandb_url

                    # Save data info
                    card_val['TWEETS_DL'] = str(res['n_tweets'])
                    card_val['RETWEETS'] = str(res['n_RT'])
                    card_val['SHORT_TWEETS'] = str(len(all_tweets) - len(cool_tweets))
                    card_val['TWEETS_KEPT'] = str(len(cool_tweets))
                
                success_try = True

            except Exception as e:
                print('\nAn error occurred...\n')
                print(e)
                run_dl_tweets.button_style = 'danger'
        
        if success_try:
            run_dl_tweets.button_style = 'success'
            log_finetune.clear_output(wait=True)
            with log_finetune:
                print_html('\nFine-tune your model by clicking on "Train Neural Network"')
            run_finetune.disabled = False
            with log_dl_tweets:
                print_html(f"\n🎉 Dataset created: {res['n_tweets']} tweets downloaded, including {res['n_RT']} RT's and {len(all_tweets) - len(cool_tweets)} short tweets... keeping {len(cool_tweets)} tweets")

        handle_widget.disabled = False
        run_dl_tweets.disabled = False
                
    handle_widget = widgets.Text(value='@elonmusk',
                                placeholder='Enter twitter handle')

    run_dl_tweets = widgets.Button(
        description='Download tweets',
        button_style='primary')
    def on_run_dl_tweets_clicked(b):
        dl_tweets()
    run_dl_tweets.on_click(on_run_dl_tweets_clicked)

    log_restart = widgets.Output()
    log_dl_tweets = widgets.Output()
    
    def finetune():
        # import transformers late: wandb must be logged in first
        import transformers
        from transformers import (
            AutoTokenizer, AutoModelForCausalLM,
            TextDataset, DataCollatorForLanguageModeling,
            Trainer, TrainingArguments,
            get_cosine_schedule_with_warmup)
        from transformers.hf_api import HfApi

        if run_finetune.button_style == 'success':
            # user double clicked before start of function
            return

        handle_widget.disabled = True
        run_dl_tweets.disabled = True
        run_finetune.disabled = True
        run_finetune.button_style = 'primary'

        handle = handle_widget.value.strip()
        handle = handle[1:] if handle.startswith('@') else handle  # also safe for an empty input
        handle = handle.lower().strip()
        model_url = f'https://huggingface.co/huggingtweets/{handle}'
        log_finetune.clear_output(wait=True)
        clear_output(wait=True)

        success_try = False

        with log_finetune:
            print_html(f'\nTraining Neural Network on {handle_widget.value.strip()} tweets... This could take up to 2-3 minutes!\n')
            progress = widgets.FloatProgress(value=0.1, min=0.0, max=1.0, bar_style = 'info')
            label_progress = widgets.Label('Downloading pre-trained neural network...')
            display(widgets.HBox([progress, label_progress]))

        with log_debug:
            try:                
                # Setting up pre-trained neural network
                global trainer
                tokenizer = AutoTokenizer.from_pretrained('gpt2')
                model = AutoModelForCausalLM.from_pretrained('gpt2', cache_dir=pathlib.Path('cache').resolve())
                block_size = tokenizer.model_max_length
                train_dataset = TextDataset(tokenizer=tokenizer, file_path=f'data_{handle}_train.txt', block_size=block_size, overwrite_cache=True)
                data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
                seed = random.randint(0,2**32-1)
                training_args = TrainingArguments(
                    output_dir=f'output/{handle}',
                    overwrite_output_dir=True,
                    do_train=True,
                    num_train_epochs=1,
                    per_device_train_batch_size=1,
                    prediction_loss_only=True,
                    logging_steps=5,
                    save_steps=0,
                    seed=seed,
                    learning_rate = LEARNING_RATE)
                
                # create wandb run (before it's done automatically by Trainer)
                combined_dict = {**model.config.to_dict(), **training_args.to_sanitized_dict()}
                run = wandb.init(name=f'@{handle}-train',
                                 job_type='train',
                                 config={'huggingtweets version':HW_VERSION,
                                         'pytorch version': torch.__version__,
                                         'transformers version': transformers.__version__,
                                         'handle':handle,
                                         **combined_dict},
                                 project = WANDB_PROJECT,
                                 entity = WANDB_ENTITY,
                                 notes = WANDB_NOTES,
                                 reinit=True)
                
                # keep track of url
                wandb_url = wandb.run.get_url()
                card_val['WANDB_TRAIN'] = wandb_url

                # Set-up Trainer
                os.environ['WANDB_WATCH'] = 'false'  # used in Trainer
                trainer = Trainer(
                    model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    data_collator=data_collator,
                    train_dataset=train_dataset)
                
                # Update lr scheduler
                train_dataloader = trainer.get_train_dataloader()
                num_train_steps = len(train_dataloader)
                trainer.create_optimizer_and_scheduler(num_train_steps)
                trainer.lr_scheduler = get_cosine_schedule_with_warmup(
                    trainer.optimizer,
                    num_warmup_steps=0,
                    num_training_steps=num_train_steps)

                progress.value = 0.3
                label_progress.value = 'Logging input artifacts to W&B...'

                # log dataset and pretrained model
                run.use_artifact(artifact_dataset)
                artifact_gpt2 = wandb.Artifact(
                    'gpt2',
                    type='pretrained-model',
                    description='Pretrained model from OpenAI downloaded from 🤗 Transformers: https://huggingface.co/gpt2',
                    metadata={'huggingtweets version': HW_VERSION})
                artifact_gpt2.add_dir('cache', name='gpt2')
                run.use_artifact(artifact_gpt2)
                progress.value = 0.4
                label_progress.value = 'Training neural network...'
                
                p_start, p_end = 0.4, 0.8
                def progressify(f):
                    "Control progress bar when calling f"
                    def inner(*args, **kwargs):
                        if trainer.epoch is not None:
                            # we only have one epoch, EPOCHS is built into dataset
                            progress.value = p_start + trainer.epoch * (p_end - p_start)
                        return f(*args, **kwargs)
                    return inner
        
                trainer.training_step = progressify(trainer.training_step)
                
                # Training neural network
                with log_finetune:
                    display(wandb.jupyter.Run())
                    print_html('\n')
                    display(widgets.HBox([progress, label_progress]))
                trainer.train()

                # set model config parameters
                trainer.model.config.task_specific_params['text-generation'] = {
                    'do_sample': True,
                    'min_length': 10,
                    'max_length': 160,
                    'temperature': 1.,
                    'top_p': 0.95,
                    'prefix': '<|endoftext|>'}

                # create a folder with model files
                model_name = handle
                shutil.rmtree(model_name, ignore_errors=True)
                trainer.save_model(model_name)
                valid_files = ['config.json',
                               'pytorch_model.bin',
                               'special_tokens_map.json',
                               'tokenizer_config.json',
                               'vocab.json',
                               'merges.txt',
                               'added_tokens.json']
                for f in pathlib.Path(model_name).glob('*'):
                    if f.name not in valid_files:
                        f.unlink()
                
                # log model to huggingface
                label_progress.value = 'Uploading model to Hugging Face'
                hf_urls = []
                try:
                    global hfapi, token, namespace
                    # Get token
                    hfapi = HfApi()
                    user, namespace = 'huggingtweets-app', 'huggingtweets'
                    token = hfapi.login(user, namespace)
                    assert hfapi.whoami(token)[0] == user, "Could not log into Hugging Face"

                    create_model_card(card_val, model_name)

                    # upload files
                    model_path = pathlib.Path(model_name)
                    assert model_path.is_dir(), f"Expected {model_path} to be a directory"
                    files = [(str(f.resolve()), str(f)) for f in model_path.glob('*')]
                    assert len(files) == 7, f"Unexpected number of files in model directory: {len(files)}"
                    for filepath, filename in files:
                        print(filename, filepath) #TODO
                        hf_urls.append(hfapi.presign_and_upload(token, filename=filename, filepath=filepath, organization=namespace))
                        progress.value += 0.02
                
                except Exception as e:
                    with log_finetune:
                        print_html(f'\n<b>Could not upload the model to Hugging Face</b>\n{e}')

                # log model to W&B
                label_progress.value = 'Logging model to W&B...'
                global metadata
                metadata={'model url':model_url,
                          'seed trainer':seed,
                          **metadata}
                artifact_trained = wandb.Artifact(
                    model_name,
                    type='finetuned-model',
                    description=f'Model fine-tuned on tweets from @{handle}',
                    metadata=metadata)
                for hf_url in hf_urls:
                    artifact_trained.add_reference(hf_url)
                run.log_artifact(artifact_trained)
                progress.value = 0.98

                run_finetune.button_style = 'success'
                run_predictions.disabled = False

                progress.value = 1.0
                progress.bar_style = 'success'
                success_try = True

                label_progress.value = '🎉 Neural network trained successfully!'
                log_predictions.clear_output(wait=True)
                with log_predictions:
                    print_html('\nEnter the start of a sentence and click "Run predictions"')
                with log_restart:
                    print_html('\n<b>To change user, refresh the page</b>\n')

            except Exception as e:
                print('\nAn error occurred...\n')
                print(e)
                run_finetune.button_style = 'danger'
                run_finetune.disabled = False
                            
        if not success_try:
            display(log_debug)
            progress.bar_style = 'danger'
        
    run_finetune = widgets.Button(
        description='Train Neural Network',
        button_style='primary',
        disabled=True)
    def on_run_finetune_clicked(b):
        finetune()
    run_finetune.on_click(on_run_finetune_clicked)

    log_finetune = widgets.Output()
    with log_finetune:
        print_html('\nWaiting for Step 1 to complete...')

    predictions = []
    
    def shorten_text(text, max_char):
        while len(text) > max_char:
            text = ' '.join(text.split()[:-1]) + '…'
        return text
        
    def predict():
        run_predictions.disabled = True
        start_widget.disabled = True
        run_predictions.button_style = 'primary'
        global handle
        handle = handle_widget.value.strip()
        handle = handle[1:] if handle.startswith('@') else handle  # also safe for an empty input
        handle_uncased = handle.strip()
        handle = handle.lower().strip()
        model_url = f'https://huggingface.co/huggingtweets/{handle}'
        log_predictions.clear_output(wait=True)

        # tweet buttons don't appear well in colab if within log_predictions widget
        # we reset the entire cell
        clear_output(wait=True)
        display(widgets.VBox([start_widget, run_predictions, log_model, log_predictions]))
        stylize()

        def tweet_html(tweet_text, tweet_url):
            tweet_text = shorten_text(tweet_text, 238)
            tweet_text = tweet_text.replace('"', '&quot;')

            # note the trailing space after each attribute so they stay separated in the final HTML
            return '<div style="padding-top: 4px"><a href="https://twitter.com/share?ref_src=twsrc%5Etfw" class="twitter-share-button" data-size="large" '\
                    f'data-text="{tweet_text}" '\
                    f'data-url="{tweet_url}" data-related="borisdayma,weights_biases,huggingface" '\
                    'data-show-count="false">Tweet</a></div><script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>'

        success_try = False

        # get start sentence
        ensure_widgets_updated()
        start = start_widget.value.strip()
                
        with log_predictions:
            print_html(f'\nPerforming predictions of @{handle} starting with "{start}"...\nThis should take no more than 10 seconds!')
        
        with log_debug:
            try:
                # start a wandb run (should never happen)
                if wandb.run is None:
                    print('Unexpected missing W&B run process')
                    wandb.init()
                
                # prepare input
                start_with_bos = '<|endoftext|>' + start
                encoded_prompt = trainer.tokenizer(start_with_bos, add_special_tokens=False, return_tensors="pt").input_ids
                encoded_prompt = encoded_prompt.to(trainer.model.device)

                # prediction
                output_sequences = trainer.model.generate(
                    input_ids=encoded_prompt,
                    max_length=160,
                    min_length=10,
                    temperature=1.,
                    top_p=0.95,
                    do_sample=True,
                    num_return_sequences=10
                    )
                generated_sequences = []

                # decode prediction
                for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
                    generated_sequence = generated_sequence.tolist()
                    text = trainer.tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
                    if not ALLOW_NEW_LINES:
                        limit = text.find('\n')
                        text = text[: limit if limit != -1 else None]
                    generated_sequences.append(text.strip())
                
                for g in generated_sequences:
                    predictions.append([start, g])

                # create previews
                r = requests.post('https://us-central1-huggingtweets.cloudfunctions.net/screenshot',
                                  data = {"NAME": card_val['USER_NAME'],
                                          "HANDLE": card_val['USER_HANDLE'],
                                          "URL": card_val['USER_PROFILE'],
                                          "INPUT": start,
                                          "OUTPUTS": generated_sequences})
                ids = r.json()
                global model_preview
                global hfapi, token, namespace
                if model_preview is None:
                    model_preview = f'https://www.huggingtweets.com/{handle}/{ids[0]}/predictions.png'
                    card_val['SOCIAL_LINK'] = model_preview
                    create_model_card(card_val, handle)
                    readme = pathlib.Path(handle) / 'README.md'
                    readme_path, readme_name = str(readme.resolve()), str(readme)
                    hfapi.presign_and_upload(token, filename=readme_name, filepath=readme_path, organization=namespace)
                    with log_model:
                        print_html("<h2>Model Preview (select a tweet to update)</h2>")
                        display(HTML(f'<img src="{model_preview}" width=560 style="border: 1px solid lightgray; margin:5px;">'))

                # log predictions
                wandb.log({'examples': wandb.Table(data=predictions, columns=['Input', 'Prediction'])})

                # display tweets
                widgets_tweet = []
                center = widgets.Layout(align_items='center', display='flex')
                layout_twitter = widgets.Layout(width = '76px')
                for i, (g, url_id) in enumerate(zip(generated_sequences, ids)):  # avoid shadowing builtin `id`
                    preview_button = create_button(url_id)
                    tweet_button = tweet_html(f'I love this tweet generated by my AI bot of @{handle_uncased} with huggingtweets!\nPlay with my model or create your own!\n\nMade by @borisdayma using @huggingface and @weights_biases',
                                              f'http://www.huggingtweets.com/{handle}/{url_id}/predictions.html')
                    w = widgets.HBox([preview_button,
                                      widgets.HTML(tweet_button, layout=layout_twitter),
                                      widgets.HTML(g)],
                                     layout=center)
                    w.add_class("table_odd" if i%2 else "table_even")
                    widgets_tweet.append(w)

                # make model share table
                tweet_share = f'I created an AI bot of @{handle_uncased} with huggingtweets!\nPlay with my model or create your own!\n\nMade by @borisdayma using @huggingface and @weights_biases'
                link_model = f'<a href="{model_url}" rel="noopener" target="_blank">{model_url}</a>'
                share_data = [[tweet_html(tweet_share, model_url),
                               f'🎉 Share @{handle_uncased} model: {link_model} <i>(may take 30 seconds to become active)</i>']]
                share_table = HTML(html_table(share_data))

                run_predictions.button_style = 'success'
                success_try = True
                
            except Exception as e:
                print('\nAn error occurred...\n')
                print(e)
                run_predictions.button_style = 'danger'

        if success_try:
            with log_predictions:
                log_predictions.clear_output(wait=True)
                
                # twitter button does not update within widget in colab
                if not IN_COLAB:
                    print_html('\n')
                    display(share_table)
                    print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')
                    for w in widgets_tweet:
                        display(w)
                    print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')

            if IN_COLAB:
                print_html('\n')
                display(share_table)
                print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')
                for w in widgets_tweet:
                    display(w)
                print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')
        else:
            display(log_debug)
        
        run_predictions.disabled = False
        start_widget.disabled = False
                
    start_widget = widgets.Text(value='My dream is',
                                placeholder='Start a sentence')

    run_predictions = widgets.Button(
        description='Run predictions',
        button_style='primary',
        disabled=True)
    def on_run_predictions_clicked(b):
        predict()
    run_predictions.on_click(on_run_predictions_clicked)

    log_predictions = widgets.Output()
    with log_predictions:
        print_html('\nWaiting for Step 2 to complete...')
    log_model = widgets.Output()

    clear_output(wait=True)
    print_html("🎉 Environment set up correctly! You're ready to move to Step 1!")

## Step 1 - Enter a Twitter handle

Enter a Twitter handle and click Download tweets. This gives the model a dataset of examples to train on.

In [2]:
#@title ⠀ {display-mode: "form"}
stylize()
display(widgets.VBox([handle_widget, run_dl_tweets, log_restart, log_dl_tweets]))

## Step 2 - Train your Neural Network

Fine-tune a language model on your unique set of tweets to generate predictions.

The model is downloaded from [HuggingFace transformers](https://huggingface.co/), an awesome open-source library for Natural Language Processing, and training is logged through [Weights & Biases](http://docs.wandb.com/).

In [3]:
#@title ⠀ {display-mode: "form"}
stylize()
display(widgets.VBox([run_finetune, log_finetune]))

## Step 3 - Generate tweets

Type the beginning of a tweet, press Run predictions, and the model will try to come up with a realistic ending to your tweet.
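Outside of this notebook, the same idea can be sketched in a few lines. The helper below is illustrative only and is not the notebook's own `predict` function: `complete_tweet` and `load_generator` are hypothetical names, and the sketch assumes the fine-tuned model is loaded as a `transformers` text-generation pipeline, which returns one dictionary with a `generated_text` key per sampled completion.

```python
from typing import Callable, Dict, List


def load_generator(model_id: str):
    """Build a text-generation pipeline.

    Assumption: `transformers` is installed and `model_id` is available
    on the Hugging Face Hub (this needs network access).
    """
    # Imported lazily so that complete_tweet() has no hard dependency.
    from transformers import pipeline

    return pipeline("text-generation", model=model_id)


def complete_tweet(generator: Callable[..., List[Dict[str, str]]],
                   start: str,
                   num_tweets: int = 3) -> List[str]:
    """Return `num_tweets` sampled completions of `start`."""
    # Sampling is enabled so that each returned sequence can differ.
    outputs = generator(start, num_return_sequences=num_tweets, do_sample=True)
    # Each 'generated_text' contains the prompt followed by the completion.
    return [out["generated_text"].strip() for out in outputs]
```

For instance, `complete_tweet(load_generator("huggingtweets/ivanpeer"), "My dream is")` would print-ready return three candidate tweets, assuming that example model is still hosted. Passing the generator in as an argument keeps the helper easy to try with any other text-generation callable.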

In [4]:
#@title ⠀ {display-mode: "form"}
stylize()
if IN_COLAB:
    display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 2000})'''))
display(widgets.VBox([start_widget, run_predictions, log_model, log_predictions]))

🎉 Share @ivanpeer model: https://huggingface.co/huggingtweets/ivanpeer *(may take 30 seconds to become active)*


Huggingtweets is still in its infancy and will get better over time!

In the future, it will train continuously to become a Twitter expert!

## About

*Built by Boris Dayma*

[![Follow](https://img.shields.io/twitter/follow/borisdayma?style=social)](https://twitter.com/intent/follow?screen_name=borisdayma)

My main goals with this project are:
* to experiment with how to train, deploy, and maintain neural networks in production;
* to make AI accessible to everyone;
* to have fun!

For more details, visit the project repository.

[![GitHub stars](https://img.shields.io/github/stars/borisdayma/huggingtweets?style=social)](https://github.com/borisdayma/huggingtweets)

**Disclaimer: this project is not to be used to publish any false generated information but to perform research on Natural Language Generation.**

## Resources

* [Explore the W&B report](https://app.wandb.ai/wandb/huggingtweets/reports/HuggingTweets-Train-a-model-to-generate-tweets--VmlldzoxMTY5MjI) to understand how the model works
* [HuggingFace and W&B integration documentation](https://docs.wandb.com/library/integrations/huggingface)

## Got questions about W&B?

If you have any questions about using W&B to track your model performance and predictions, please reach out to the [Slack community](http://bit.ly/wandb-forum).