# ChatGPT Mini - A ChatGPT implementation for beginners
<img src="https://i.imgur.com/rYK8S8z.png">

## What is ChatGPT-Mini? 

- A simple Python implementation of ChatGPT's 3-step training strategy.
- Runs in a single Google Colab notebook (<30 GB of HDD, <7 GB of GPU RAM required) 
- No expensive training time! ChatGPT reportedly cost >4MM USD to train and ~3MM USD to run daily
- It will embody our personality of choice. In this example, we'll train it to be a pessimistic Elon Musk. 

## ChatGPT was created in 3 steps
-------------------
### - Step 1: Generative Pre-training
### - Step 2: Supervised Fine-tuning
### - Step 3: Reinforcement Learning from Human Feedback
-------------------
<img src="https://i.imgur.com/zIG440O.png">

In [None]:
#In python....

## Step 1 - Generative Pre-Training (Learn English)
# - Import Dependencies
# - Collect English Language Data
# - Sort English Data
# - Build Model
# - Train Model
#--------------------------------------
## Step 2 - Supervised Fine-Tuning (Train it some more, on task-specific data)
# - Collect Task Data
# - Pre-process Task Data
# - Train it
# - Test on sentence completion
#--------------------------------------
## Step 3 - Reinforcement Learning from Human Feedback
# - Define a Reward model
# - Define a Policy
# - Execute Proximal Policy Optimization (from Static Human Feedback)
# - ChatGPT

Let's replicate each step as best we can in a single Google Colab notebook!

# Step 1 - Generative Pre-training

- OpenAI trained GPT-3 on 300 billion text tokens.
<img src="https://i.imgur.com/IwxSmcL.png">
- GPT-3 is a decoder-only transformer neural network
- Self-supervised learning: no labels, just predict the next token.
- It generates text by iteratively predicting the next token and appending it back to the input, i.e. auto-regressively.
- Instead of this... 
<img src="https://machinelearningmastery.com/wp-content/uploads/2021/08/attention_research_1.png">
- They used this...

<img src="https://i.imgur.com/c4Z6PG8.png">
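To make "just predict the next token" concrete, here's a minimal sketch in plain Python. It assumes a toy character-level vocabulary for readability (GPT-3 actually uses BPE tokens), and `make_next_token_pairs` / `generate_greedy` are hypothetical helpers, not part of any library. Every training target is simply the input shifted one position to the left, and generation appends each prediction and feeds the longer sequence back in.

```python
def make_next_token_pairs(text, block_size):
    """Self-supervised (input, target) pairs: the target is the input shifted left by one."""
    vocab = sorted(set(text))                      # toy character-level vocabulary
    stoi = {ch: i for i, ch in enumerate(vocab)}   # character -> token id
    ids = [stoi[ch] for ch in text]
    pairs = []
    for start in range(len(ids) - block_size):
        x = ids[start : start + block_size]          # context window
        y = ids[start + 1 : start + block_size + 1]  # the token that follows each position
        pairs.append((x, y))
    return pairs, stoi


def generate_greedy(next_token_fn, prompt_ids, n_new):
    """Autoregressive decoding: predict one token, append it, feed the longer sequence back in."""
    ids = list(prompt_ids)
    for _ in range(n_new):
        ids.append(next_token_fn(ids))
    return ids


pairs, stoi = make_next_token_pairs("to be or not", block_size=4)
print(pairs[0])  # ([6, 4, 0, 1], [4, 0, 1, 2])
# a dummy "model" that always predicts (last id + 1) mod vocab size
print(generate_greedy(lambda ids: (ids[-1] + 1) % len(stoi), [6], 3))  # [6, 0, 1, 2]
```

The dummy `next_token_fn` stands in for a trained model; picoGPT's `generate` below runs the same loop, using `np.argmax` over GPT-2's logits as the "model".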

#### GPT in 60 lines of Python 

In [None]:
#@title ⠀ {display-mode: "form"}

# This is PicoGPT, by Jay Mody https://github.com/jaymody/picoGPT/blob/main/gpt2_pico.py
# This serves as an educational example of a GPT

# for matrix math
import numpy as np

# Gaussian Error Linear Units (GELU) are an alternative to the ReLU activation
# function; this is the common tanh-based approximation of GELU.
# The BERT paper popularized GELU in transformer-based models,
# and it has stuck around since.
def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

# Softmax over the last axis of the input: converts a vector of K real numbers
# into a probability distribution over K possible outcomes.
# Subtracting the max first avoids overflow in np.exp without changing the result.
def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

# Layer normalization rescales each position's activation vector (over the last axis)
# to mean 0 and variance 1, then applies a learned gain g and bias b.
# This stabilizes the distribution of activations, helps the network train more
# efficiently, and is often described as reducing internal covariate shift.
def layer_norm(x, g, b, eps: float = 1e-5):
    mean = np.mean(x, axis=-1, keepdims=True)
    variance = np.var(x, axis=-1, keepdims=True)
    return g * (x - mean) / np.sqrt(variance + eps) + b

#a single linear layer, input times weight, add a bias
def linear(x, w, b):
    return x @ w + b

# position-wise feed-forward network: linear (project up), GELU, linear (project back down)
def ffn(x, c_fc, c_proj):
    return linear(gelu(linear(x, **c_fc)), **c_proj)

#the attention layer!!!!!!!
#The attention mechanism allows the GPT model to dynamically focus on 
#different parts of the input sequence while generating the next word 
#in a language modeling task. By allowing the model to attend to different 
#parts of the input sequence at different times, it can learn more complex 
#relationships between the input and output, leading to improved language modeling capabilities.
def attention(q, k, v, mask):
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

#The Multi-Head Attention mechanism allows the GPT model to attend to 
#multiple parts of the input sequence simultaneously, which helps it 
#learn more complex relationships between the input and output. 
#By splitting the attention mechanism into multiple parallel heads, 
#the model can learn to attend to different aspects of the input, 
#resulting in improved language modeling capabilities.
def mha(x, c_attn, c_proj, n_head):
    x = linear(x, **c_attn)
    qkv_heads = list(map(lambda x: np.split(x, n_head, axis=-1), np.split(x, 3, axis=-1)))
    # causal mask: -1e10 above the diagonal so position i can't attend to later positions
    causal_mask = (1 - np.tri(x.shape[0])) * -1e10
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]
    x = linear(np.hstack(out_heads), **c_proj)
    return x

# a transformer block: multi-head attention, then the feed-forward network,
# each applied to a layer-normed input and added back through a residual connection
def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
    x = x + mha(layer_norm(x, **ln_1), **attn, n_head=n_head)
    x = x + ffn(layer_norm(x, **ln_2), **mlp)
    return x

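# the full GPT-2 forward pass: token + positional embeddings, a stack of blocks,
# a final layer norm, then logits via the (tied) token embedding matrix wte.T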
def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
    x = wte[inputs] + wpe[range(len(inputs))]
    for block in blocks:
        x = transformer_block(x, **block, n_head=n_head)
    return layer_norm(x, **ln_f) @ wte.T

# generate text greedily: always pick the most likely next token (argmax, no sampling)
def generate(inputs, params, n_head, n_tokens_to_generate):
    from tqdm import tqdm
    for _ in tqdm(range(n_tokens_to_generate), "generating"):
        logits = gpt2(inputs, **params, n_head=n_head)
        next_id = np.argmax(logits[-1])
        inputs = np.append(inputs, [next_id])
    return list(inputs[len(inputs) - n_tokens_to_generate :])

# load pretrained GPT-2 weights (via utils.py from the picoGPT repo) and generate text
def main(prompt: str, n_tokens_to_generate: int = 40, model_size: str = "124M", models_dir: str = "models"):
    from utils import load_encoder_hparams_and_params
    encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
    input_ids = encoder.encode(prompt)
    assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
    output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
    output_text = encoder.decode(output_ids)
    return output_text

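# run it as a script: fire exposes main()'s arguments on the command line
# (in a notebook, call main("your prompt") directly instead)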
if __name__ == "__main__":
    import fire
    fire.Fire(main)

- Training this GPT on a 500 GB dataset would be far too expensive and time-consuming.
- It also has no backward pass, so its weights can't be optimized.
- Instead, let's train Andrej Karpathy's nanoGPT on a ~1 MB file of Shakespeare's works!
- nanoGPT's model is ~300 lines of Python instead of 60, but it trains in a Colab.
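What "no backward pass" means in practice: training needs the gradient of a loss with respect to the weights. As a minimal sketch (not nanoGPT's code, which relies on PyTorch autograd), here is a bigram model, a single weight matrix `W` of next-token logits, trained with softmax cross-entropy and a hand-written backward pass in NumPy. The helper name `train_bigram` and its hyperparameters are assumptions chosen for illustration.

```python
import numpy as np


def train_bigram(ids, vocab_size, steps=200, lr=1.0, seed=0):
    """Fit next-token logits W[prev, next] with softmax cross-entropy and a manual backward pass."""
    rng = np.random.default_rng(seed)
    W = 0.01 * rng.standard_normal((vocab_size, vocab_size))
    x = np.array(ids[:-1])  # current tokens (inputs)
    y = np.array(ids[1:])   # next tokens (targets)
    rows = np.arange(len(y))
    losses = []
    for _ in range(steps):
        # forward pass: look up a row of logits per input token, then softmax
        logits = W[x]
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        exp = np.exp(logits)
        probs = exp / exp.sum(axis=1, keepdims=True)
        losses.append(float(-np.log(probs[rows, y]).mean()))  # mean cross-entropy
        # backward pass: d(loss)/d(logits) = (probs - one_hot(y)) / N
        dlogits = probs.copy()
        dlogits[rows, y] -= 1.0
        dlogits /= len(y)
        # route each row's gradient back to the row of W it was looked up from
        dW = np.zeros_like(W)
        np.add.at(dW, x, dlogits)
        # plain gradient-descent update
        W -= lr * dW
    return W, losses


W, losses = train_bigram([0, 1, 2] * 10, vocab_size=3)
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Because each token in `[0, 1, 2] * 10` always has the same successor, the loss should fall from about ln 3 ≈ 1.10 toward zero. A GPT replaces the lookup `W[x]` with the whole transformer, but the loss and the update loop keep the same shape.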

In [None]:
# download repo
!git clone https://github.com/karpathy/nanoGPT.git  
# install dependencies
!pip install tiktoken transformers
# download the Shakespeare dataset and tokenize it into ./nanoGPT/data/shakespeare
!cd ./nanoGPT/data/shakespeare/ && python prepare.py
# train nanoGPT on the GPU; checkpoints go to ./nanoGPT/out (300 iters seemed to give the lowest val loss)
!cd ./nanoGPT/ && python train.py --dataset=shakespeare --n_layer=4 --n_head=4 --n_embd=64 --compile=False --block_size=64 --batch_size=8 --dtype=float16 --eval_interval=100 --eval_iters=100 --max_iters=300 --bias=True
# print 5 samples, with 10 tokens, starting with "to be"
!cd ./nanoGPT && python sample.py --dtype=float16 --num_samples=5 --max_new_tokens=10 --start="to be"

Cloning into 'nanoGPT'...
remote: Enumerating objects: 492, done.
remote: Counting objects: 100% (3/3), done.
remote: Compressing objects: 100% (3/3), done.
remote: Total 492 (delta 0), reused 0 (delta 0), pack-reused 489
Receiving objects: 100% (492/492), 740.31 KiB | 13.22 MiB/s, done.
Resolving deltas: 100% (287/287), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tiktoken
  Downloading tiktoken-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
Collecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
Collecting requests>=2.26.0
  Downloading requests-2.28.2-py3-none-any.whl (62 kB)

- It worked! But the output is still somewhat gibberish.
- More training time would improve its output, as predicted by the scaling laws for neural language models.
<img src="https://eliaszwang.com/paper-reviews/scaling-laws-neural-lm/featured.png">
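The scaling-law curves are power laws. A minimal sketch, assuming the form L(x) = (x_c / x)^α from Kaplan et al. (2020), where x is compute, dataset size or parameter count. The compute exponent of roughly 0.050 is the value reported there for large transformer language models, so treat it as approximate and not as a prediction for our 4-layer Shakespeare model; the function names are hypothetical.

```python
def power_law_loss(x, x_c, alpha):
    """Scaling-law form L(x) = (x_c / x) ** alpha, for x = compute, data or parameters."""
    return (x_c / x) ** alpha


def loss_ratio(scale_factor, alpha):
    """Factor by which the loss shrinks when x is multiplied by scale_factor.

    x_c cancels out: L(k * x) / L(x) = k ** -alpha.
    """
    return scale_factor ** (-alpha)


# Approximate compute exponent reported by Kaplan et al. (2020) for large
# transformer LMs; an assumption here, not a fit to our tiny Shakespeare model.
ALPHA_COMPUTE = 0.050

for k in (2, 10, 100):
    print(f"{k:>4}x compute -> loss x {loss_ratio(k, ALPHA_COMPUTE):.3f}")
```

Under this assumption, each 10x of compute buys roughly an 11% lower loss, which is why these curves look like straight lines on log-log axes.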


# Step 2 - Supervised Fine-Tuning

- Fine-tuning means adapting a pre-trained model to a new task by continuing to train it on a new dataset, starting from its pre-trained weights instead of from scratch.
- That could mean freezing most of the weights and training only the last layer(s), or updating all of them.
- OpenAI fine-tuned GPT-3.5 with a new dialogue dataset
- GPT-3.5 was trained on a blend of text and code from before Q4 2021
- They hired a team of about 40 contractors to create supervised training data
- Supervised means each input has a known output for the model to learn from
- The labelers created 13,000 input/output examples for fine-tuning GPT-3.5
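Here is one common way (an assumption, not necessarily OpenAI's exact recipe) to turn a supervised input/output pair into a training example for a causal language model: concatenate prompt and response, and set the prompt positions' labels to -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`, so only the response is scored. The whitespace "tokenizer" and `build_sft_example` are hypothetical stand-ins.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_sft_example(prompt, response, vocab):
    """Turn one (input, known output) pair into input_ids plus labels for a causal LM.

    Prompt positions get IGNORE_INDEX, so only the response tokens contribute to the loss.
    """
    prompt_ids = [vocab[w] for w in prompt.split()]      # toy whitespace "tokenizer"
    response_ids = [vocab[w] for w in response.split()]
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids  # aligned with input_ids
    return input_ids, labels


vocab = {w: i for i, w in enumerate("what is 2 + ? 4 <eos>".split())}
input_ids, labels = build_sft_example("what is 2 + 2 ?", "4 <eos>", vocab)
print(input_ids)  # [0, 1, 2, 3, 2, 4, 5, 6]
print(labels)     # [-100, -100, -100, -100, -100, -100, 5, 6]
```

The labels line up with `input_ids`; Hugging Face causal LM heads such as `GPT2LMHeadModel` shift them by one internally when computing the loss.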

#### Rules for labelers
1. Collect prompts from real user submissions to the OpenAI API
2. Write appropriate responses to them
3. Keep at most 200 prompts per user ID, to increase data diversity
4. Remove all personally identifiable information
5. Also write novel prompts, including one-shot and few-shot requests
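Rules 3 and 4 are plain data filtering. A toy sketch, assuming prompts arrive as `(user_id, prompt)` records; `select_prompts`, `scrub_pii` and both regexes are hypothetical and far too simple for real PII removal.

```python
import re
from collections import defaultdict

MAX_PROMPTS_PER_USER = 200  # rule 3

# Illustrative patterns only (rule 4): real PII removal needs far more than two regexes.
EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.-]+")
PHONE_RE = re.compile(r"\+?\d[\d\s().-]{7,}\d")


def scrub_pii(text):
    """Replace e-mail addresses and phone-number-like digit runs with placeholders."""
    text = EMAIL_RE.sub("[EMAIL]", text)
    return PHONE_RE.sub("[PHONE]", text)


def select_prompts(records, max_per_user=MAX_PROMPTS_PER_USER):
    """Keep at most max_per_user prompts per user ID, scrubbing PII from each one kept."""
    counts = defaultdict(int)
    kept = []
    for user_id, prompt in records:
        if counts[user_id] >= max_per_user:
            continue  # this user already contributed their quota
        counts[user_id] += 1
        kept.append(scrub_pii(prompt))
    return kept


records = [("u1", "write a poem")] * 3 + [("u2", "email me at jane@example.com")]
print(select_prompts(records, max_per_user=2))
# ['write a poem', 'write a poem', 'email me at [EMAIL]']
```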

<img src="https://miro.medium.com/v2/resize:fit:1096/format:webp/1*TcIrYoaEq5Hr69eJwHDIOQ.png">

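- Instead of hiring humans, let's build a fine-tuning dataset from Twitter.
- Specifically, let's pull Elon Musk's tweets.
- We'll fine-tune GPT-2 to sound more like Elon Musk. (Strictly speaking, this is continued next-token training on raw tweets rather than on input/output pairs.)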
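Turning tweets into a fine-tuning file boils down to: strip URLs, collapse whitespace, drop tweets with too few real words, and join what's left with GPT-2's `<|endoftext|>` token. The sketch below is a condensed, standalone take on `clean_tweet` / `boring_tweet` from the next cell (it skips the HTML-entity fixes and the multi-epoch shuffling); `build_training_text` is a hypothetical helper name.

```python
import re

EOT = "<|endoftext|>"  # GPT-2's end-of-text token, used here as the tweet separator


def clean_tweet(tweet):
    tweet = re.sub(r"https?:\S+", "", tweet)  # drop URLs
    return " ".join(tweet.split())             # collapse spaces and newlines


def is_boring(tweet, min_words=3):
    """True if fewer than min_words words remain after ignoring links, @mentions and #hashtags."""
    words = [w for w in tweet.split() if not any(s in w.lower() for s in ("http", "@", "#"))]
    return len(words) < min_words


def build_training_text(tweets):
    cleaned = (clean_tweet(t) for t in tweets)
    kept = [t for t in cleaned if not is_boring(t)]
    return EOT + EOT.join(kept) + EOT


tweets = ["Mars is looking good https://t.co/abc", "@someone ok", "Rockets  are\nfun to build"]
print(build_training_text(tweets))
# <|endoftext|>Mars is looking good<|endoftext|>Rockets are fun to build<|endoftext|>
```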

In [None]:
#@title ⠀ {display-mode: "form"}


## Dependencies 
def stylize():
    "Handle dark mode"
    display(HTML('''
    <style>
    :root {
        --table_bg: #EBF8FF;
    }
    html[theme=dark] {
        --colab-primary-text-color: #d5d5d5;
        --table_bg: #2A4365;
    }
    .jupyter-widgets {
        color: var(--colab-primary-text-color);
    }
    table {
        border-collapse: collapse !important;
    }
    td {
        text-align:left !important;
        border: solid var(--table_bg) !important;
        border-width: 1px 0 !important;
        padding: 6px !important;
    }
    tr:nth-child(even) {
        background-color: var(--table_bg) !important;
    }
    .table_odd {
        background-color: var(--table_bg) !important;
        margin: 0 !important;
    }
    .table_even {
        border: solid var(--table_bg) !important;
        border-width: 1px 0 !important;
        margin: 0 !important;
    }
    .jupyter-widgets {
        margin: 6px;
    }
    .widget-html-content {
        font-size: var(--colab-chrome-font-size) !important;
        line-height: 1.24 !important;
    }
    </style>'''))

def print_html(x):
    "Better printing"
    x = x.replace('\n', '<br>')
    display(HTML(x))
        
# Check we use GPU
import torch
from IPython.display import display, HTML, Javascript, clear_output
if not torch.cuda.is_available():
    print_html('Error: GPU was not found\n1/ click on the "Runtime" menu and "Change runtime type"\n'\
          '2/ set "Hardware accelerator" to "GPU" and click "save"\n3/ click on the "Runtime" menu, then "Run all" (below error should disappear)')
    raise ValueError('No GPU available')
else:
    # colab requires special handling
    try:
        import google.colab
        IN_COLAB = True
    except ImportError:
        IN_COLAB = False

    # Install dependencies (mainly for colab)
    if IN_COLAB:
        !pip install torch transformers wandb -qqq
        !curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
        !sudo apt-get install git-lfs

    import ipywidgets as widgets
    from IPython import get_ipython
    import json
    import urllib3
    import pathlib
    import shutil
    import requests
    import os
    import re
    import random
    import wandb
    from urllib.parse import urlencode
    from PIL import Image
    from io import BytesIO
    
    stylize()
    
    log_debug = widgets.Output()
    
    # Have global access
    trainer = None
    artifact_dataset = None
    metadata = {}
    card_val = {}
    model_preview = None
    hfapi, token, namespace = None, None, None
    handles_processed = []
    model_url = ''
    bot = 'bot'

    # W&B variables
    WANDB_PROJECT = 'huggingtweets'
    WANDB_NOTES = "Github repo: https://github.com/borisdayma/huggingtweets"
    WANDB_ENTITY = 'wandb'
    HW_VERSION = 0.6
    os.environ['WANDB_NOTEBOOK_NAME'] = 'huggingtweets-demo.ipynb'  # used in wandb cli

    # HYPER-PARAMETERS
    ALLOW_NEW_LINES = False     # seems to work better
    LEARNING_RATE = 1.372e-4
    EPOCHS = 4

    def fix_text(text):
        text = text.replace('&amp;', '&')
        text = text.replace('&lt;', '<')
        text = text.replace('&gt;', '>')
        return text

    def html_table(data, title=None):
        'Create a html table'
        width_twitter = '75px'
        def html_cell(i, twitter_button=False):
            nl = "\n"
            return f'<td style="width:{width_twitter}">{i}</td>' if twitter_button else f'<td>{i.replace(nl, "<br>")}</td>'
        def html_row(row):
            return f'<tr>{"".join(html_cell(r, not i if len(row)>1 else False) for i,r in enumerate(row))}</tr>'
        body = f'<table style="width:100%">{"".join(html_row(r) for r in data)}</table>'
        title_html = f'<h3>{title}</h3>' if title else ''
        html = '<html><body>' + title_html + body + '</body></html>'
        return(html)


## Data pre-processing
    def clean_tweet(tweet, allow_new_lines = ALLOW_NEW_LINES):
        bad_start = ['http:', 'https:']
        for w in bad_start:
            tweet = re.sub(f" {w}\\S+", "", tweet)      # removes white space before url
            tweet = re.sub(f"{w}\\S+ ", "", tweet)      # in case a tweet starts with a url
            tweet = re.sub(f"\n{w}\\S+ ", "", tweet)    # in case the url is on a new line
            tweet = re.sub(f"\n{w}\\S+", "", tweet)     # in case the url is alone on a new line
            tweet = re.sub(f"{w}\\S+", "", tweet)       # any other case?
        tweet = re.sub(' +', ' ', tweet)                # replace multiple spaces with one space
        if not allow_new_lines:                         # TODO: predictions seem better without new lines
            tweet = ' '.join(tweet.split())
        return tweet.strip()
        
    def boring_tweet(tweet):
        "Check if this is a boring tweet"
        boring_stuff = ['http', '@', '#']
        not_boring_words = len([None for w in tweet.split() if all(bs not in w.lower() for bs in boring_stuff)])
        return not_boring_words < 3

    def create_model_card(card_val, output_dir):
        model_card_url = 'https://github.com/borisdayma/huggingtweets/raw/master/model_card/README.md'
        model_card = requests.get(model_card_url).content.decode('utf-8')
        for k, v in card_val.items():
            model_card = model_card.replace(k, v)
        with open(f'{output_dir}/README.md', 'w') as f:
            f.write(model_card)
    
    def on_preview_clicked(b):
        global model_preview
        global hfapi, token, namespace
        model_preview = f"http://www.huggingtweets.com/{'-'.join(sorted(handles_processed))}/{b.url_id}/predictions.png"
        card_val['SOCIAL_LINK'] = model_preview
        create_model_card(card_val, '-'.join(sorted(handles_processed)))
        commit_files('-'.join(sorted(handles_processed)), f'Update model preview')

        # Reset view
        log_model.clear_output(wait=True)
        with log_model:
            print_html("<h2>Model Preview (select a tweet to update)</h2>")
            show_image_preview(model_preview)
        
    def show_image_preview(url):
        response = requests.get(url)
        img = Image.open(BytesIO(response.content))
        display(img.resize((560,293)))
    
    def commit_files(model_name, message):
        with log_debug:
            %cd $model_name
            !git add .
            !git commit -m "{message}"
            !git push
            %cd ..
                    
    def create_button(url_id):
        layout = widgets.Layout(width='70px', min_width='70px') #set width and height
        button = widgets.Button(
            description='Preview',
            button_style='info',
            layout = layout,
            tooltip = 'Set as model preview'
        )
        button.url_id = url_id
        button.on_click(on_preview_clicked)
        return button

    def ensure_widgets_updated(n_iter=5):
        '''ensure we get correct inputs and states are updated'''
        pass
        # used to be necessary in colab ; seems not needed anymore and create issues like in Jupyter
        #if IN_COLAB:  # not a problem with jupyter + create print output issues
        #    for _ in range(n_iter):
        #        get_ipython().kernel.do_one_iteration()


    def dl_tweets():
        for handle_widget in handle_widgets:
            handle_widget.disabled = True
        run_dl_tweets.disabled = True
        run_dl_tweets.button_style = 'primary'
        ensure_widgets_updated()
        global handles_processed
        handles = []
        for handle_widget in handle_widgets:
            handle = handle_widget.value.strip()
            if not handle: continue
            handle = handle[1:] if handle and handle[0] == '@' else handle
            handles.append(handle.lower().strip())
            
        log_dl_tweets.clear_output(wait=True)

        success_try = False

        with log_dl_tweets:
            try:
                cool_tweets = []
                handles_processed = []
                raw_tweets = []
                user_names = []
                n_tweets_dl = []
                n_retweets = []
                n_short_tweets = []
                n_tweets_kept = []
                i = 0
                global card_val
                card_val = {'USER_PROFILE_1': '', 'DISPLAY_1': 'none',
                            'USER_PROFILE_2': '', 'DISPLAY_2': 'none',
                            'USER_PROFILE_3': '', 'DISPLAY_3': 'none'}
                for handle in handles:
                    if handle in handles_processed: continue
                    i += 1
                    handles_processed.append(handle)
                    print_html(f'\nDownloading @{handle} tweets... This should take no more than a minute!')
                    http = urllib3.PoolManager(retries=urllib3.Retry(3))
                    res = http.request("GET", f"http://us-central1-huggingtweets.cloudfunctions.net/get_tweets?handle={handle}&force=1")
                    res = json.loads(res.data.decode('utf-8'))
                    user_names.append(res['user_name'])
                    card_val[f'USER_PROFILE_{i}'] = res['user_profile'].replace('_normal.', '_400x400.')
                    card_val[f'DISPLAY_{i}'] = 'inherit'

                    all_tweets = res['tweets']
                    raw_tweets.append(all_tweets)
                    curated_tweets = [fix_text(tweet) for tweet in all_tweets]
                    #log_dl_tweets.clear_output(wait=True)
                    print_html(f"\n{res['n_tweets']} tweets from @{handle} downloaded!\n\n")
                    
                    # create dataset
                    clean_tweets = [clean_tweet(tweet) for tweet in curated_tweets]
                    cool_tweets.append([tweet for tweet in clean_tweets if not boring_tweet(tweet)])

                    # save count
                    n_tweets_dl.append(str(res['n_tweets']))
                    n_retweets.append(str(res['n_RT']))
                    n_short_tweets.append(str(len(all_tweets) - len(cool_tweets[-1])))
                    n_tweets_kept.append(str(len(cool_tweets[-1])))

                    # display a few tweets
                    display(HTML(html_table([[t] for t in curated_tweets[:8]])))

                    if len('<|endoftext|>'.join(cool_tweets[-1])) < 6000:
                        # need about 4000 chars for one data sample (but depends on spaces, etc)
                        raise ValueError(f"Error: this user does not have enough tweets to train a Neural Network\n{res['n_tweets']} tweets downloaded, including {res['n_RT']} RT's and {len(all_tweets) - len(cool_tweets[-1])} boring tweets... only {len(cool_tweets[-1])} tweets kept!")
                    if len('<|endoftext|>'.join(cool_tweets[-1])) < 40000:
                        print_html('\n<b>Warning: this user does not have many tweets which may impact the results of the Neural Network</b>\n')
                    
                    print_html(f"\n{n_tweets_dl[-1]} tweets downloaded, including {n_retweets[-1]} RT's and {n_short_tweets[-1]} short tweets... keeping {n_tweets_kept[-1]} tweets\n\n\n")
                    ensure_widgets_updated()  # for auto-scroll

                global bot
                bot = 'bot' if len(handles_processed) == 1 else 'cyborg'

                # save user info
                card_val['USER_HANDLE'] = '-'.join(sorted(handles_processed))
                card_val['USER_NAME'] = ' & '.join(user_names)
                card_val['BOT'] = bot.upper()
                card_val['SOCIAL_LINK'] = res['social_link']
                card_val['TABLE_USER'] = ' | '.join(user_names)
                card_val['TABLE_SPLIT'] = ' | '.join(['---'] * len(user_names))

                # Save data info
                card_val['TWEETS_DL'] = ' | '.join(n_tweets_dl)
                card_val['RETWEETS'] = ' | '.join(n_retweets)
                card_val['SHORT_TWEETS'] = ' | '.join(n_short_tweets)
                card_val['TWEETS_KEPT'] = ' | '.join(n_tweets_kept)
                
                # create a file based on multiple epochs with tweets mixed up
                seed_data = random.randint(0,2**32-1)
                dataRandom = random.Random(seed_data)
                total_text = '<|endoftext|>'
                all_handle_tweets = []
                epoch_len = max(len(''.join(cool_tweet)) for cool_tweet in cool_tweets)
                for _ in range(EPOCHS):
                    for cool_tweet in cool_tweets:
                        dataRandom.shuffle(cool_tweet)
                        current_tweet = list(cool_tweet)  # copy, so padding it doesn't grow cool_tweet while we iterate over it
                        current_len = len(''.join(current_tweet))
                        while current_len < epoch_len:
                            for t in cool_tweet:
                                current_tweet.append(t)
                                current_len += len(t)
                                if current_len >= epoch_len: break
                        dataRandom.shuffle(current_tweet)
                        all_handle_tweets.extend(current_tweet)
                total_text += '<|endoftext|>'.join(all_handle_tweets) + '<|endoftext|>'

                print_html('\nCreating dataset...')
                ensure_widgets_updated() # for auto-scroll
                
                # log dataset
                with log_debug:
                    wandb.login(key=res['wandb'])

                    with wandb.init(name=f"@{'-'.join(handles_processed)}-preprocess",
                                    job_type='preprocess',
                                    config={'huggingtweets version':HW_VERSION,
                                            'handle':', '.join(handles_processed),
                                            'seed data':seed_data},
                                    project = WANDB_PROJECT,
                                    entity = WANDB_ENTITY,
                                    notes = WANDB_NOTES,
                                    reinit=True) as run:
                        # log raw tweets as input
                        global metadata
                        metadata={'handle':', '.join(handles_processed),
                                  'huggingtweets version': HW_VERSION}
                        artifact_input = wandb.Artifact(
                            f"tweets-{'-'.join(sorted(handles_processed))}",
                            type='raw-dataset',
                            description=f"Raw tweets from {', '.join(handles_processed)} downloaded with Tweepy",                            
                            metadata=metadata)
                        with artifact_input.new_file('tweets.txt') as f:
                            json.dump(raw_tweets, f, indent=0)
                        run.use_artifact(artifact_input)
                        
                        # log dataset as output                        
                        metadata={'handle':', '.join(handles_processed),
                                  'seed data': seed_data,
                                  'epochs': EPOCHS,
                                  'huggingtweets version': HW_VERSION}
                        global artifact_dataset
                        artifact_dataset = wandb.Artifact(
                            f"dataset-{'-'.join(sorted(handles_processed))}",
                            type='train-dataset',
                            description=f"Dataset created from tweets of {', '.join(handles_processed)}",
                            metadata=metadata)
                        with open(f"data_{'-'.join(sorted(handles_processed))}_train.txt", 'w') as f:
                            f.write(total_text)
                        artifact_dataset.add_file(f"data_{'-'.join(sorted(handles_processed))}_train.txt")
                        run.log_artifact(artifact_dataset)
                        
                        # keep track of url
                        wandb_url = wandb.run.get_url()
                        card_val['WANDB_PREPROCESS'] = wandb_url
                
                success_try = True

            except Exception as e:
                print('\nAn error occurred...\n')
                print(e)
                run_dl_tweets.button_style = 'danger'
        
        if success_try:
            run_dl_tweets.button_style = 'success'
            log_finetune.clear_output(wait=True)
            with log_finetune:
                print_html('\nFine-tune your model by clicking on "Train Neural Network"')
            run_finetune.disabled = False
            with log_dl_tweets:
                print_html(f"\n🎉 Dataset created")
        
        else:
            display(log_debug)
            
        for handle_widget in handle_widgets:
            handle_widget.disabled = False
        run_dl_tweets.disabled = False
                
    handle_widgets = [widgets.Text(value='@elonmusk',
                                   placeholder='Enter twitter handle'),
                      widgets.Text(placeholder='Optional: 2nd handle for humanoids'),
                      widgets.Text(placeholder='Optional: 3rd handle for humanoids')]

    run_dl_tweets = widgets.Button(
        description='Download tweets',
        button_style='primary')
    def on_run_dl_tweets_clicked(b):
        dl_tweets()
    run_dl_tweets.on_click(on_run_dl_tweets_clicked)

    log_restart = widgets.Output()
    log_dl_tweets = widgets.Output()
    
    def finetune():
        # transformers imports later as wandb needs to have logged in
        import transformers
        from transformers import (
            AutoTokenizer, AutoModelForCausalLM,
            TextDataset, DataCollatorForLanguageModeling,
            Trainer, TrainingArguments,
            get_cosine_schedule_with_warmup)
        from huggingface_hub.hf_api import HfApi

        if run_finetune.button_style == 'success':
            # user double clicked before start of function
            return

        for handle_widget in handle_widgets:
            handle_widget.disabled = True
        run_dl_tweets.disabled = True
        run_finetune.disabled = True
        run_finetune.button_style = 'primary'

        global handles_processed
        global model_url
        model_url = f"http://huggingface.co/huggingtweets/{'-'.join(sorted(handles_processed))}"
        log_finetune.clear_output(wait=True)
        clear_output(wait=True)

        success_try = False

        with log_finetune:
            print_html(f"\nTraining Neural Network on @{' & @'.join(handles_processed)} tweets... This could take up to 3-5 minutes!\n")
            progress = widgets.FloatProgress(value=0.1, min=0.0, max=1.0, bar_style = 'info')
            label_progress = widgets.Label('Downloading pre-trained neural network...')
            display(widgets.HBox([progress, label_progress]))

        with log_debug:
            try:                
                # Setting up pre-trained neural network
                global trainer
                tokenizer = AutoTokenizer.from_pretrained('gpt2')
                model = AutoModelForCausalLM.from_pretrained('gpt2', cache_dir=pathlib.Path('cache').resolve())
                block_size = tokenizer.model_max_length
                train_dataset = TextDataset(tokenizer=tokenizer, file_path=f"data_{'-'.join(sorted(handles_processed))}_train.txt", block_size=block_size, overwrite_cache=True)
                data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
                seed = random.randint(0,2**32-1)
                training_args = TrainingArguments(
                    output_dir=f"output/{'-'.join(sorted(handles_processed))}",
                    overwrite_output_dir=True,
                    do_train=True,
                    num_train_epochs=1,
                    per_device_train_batch_size=1,
                    prediction_loss_only=True,
                    logging_steps=5,
                    save_steps=0,
                    seed=seed,
                    learning_rate = LEARNING_RATE)
                
                # create wandb run (before it's done automatically by Trainer)
                combined_dict = {**model.config.to_dict(), **training_args.to_sanitized_dict()}
                run = wandb.init(name=f"@{'-'.join(handles_processed)}-train",
                                 job_type='train',
                                 config={'huggingtweets version':HW_VERSION,
                                         'pytorch version': torch.__version__,
                                         'transformers version': transformers.__version__,
                                         'handle':', '.join(handles_processed),
                                         **combined_dict},
                                 project = WANDB_PROJECT,
                                 entity = WANDB_ENTITY,
                                 notes = WANDB_NOTES,
                                 reinit=True)
                
                # keep track of url
                wandb_url = wandb.run.get_url()
                card_val['WANDB_TRAIN'] = wandb_url

                # ---- The key step: fine-tune GPT-2 on the downloaded tweets (Elon's by default) ----
                # Set-up Trainer
                os.environ['WANDB_WATCH'] = 'false'  # used in Trainer
                trainer = Trainer(
                    model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    data_collator=data_collator,
                    train_dataset=train_dataset)
                
                # Update lr scheduler
                train_dataloader = trainer.get_train_dataloader()
                num_train_steps = len(train_dataloader)
                trainer.create_optimizer_and_scheduler(num_train_steps)
                trainer.lr_scheduler = get_cosine_schedule_with_warmup(
                    trainer.optimizer,
                    num_warmup_steps=0,
                    num_training_steps=num_train_steps)

                progress.value = 0.3
                label_progress.value = 'Logging input artifacts to W&B...'

                # log dataset and pretrained model
                artifact_dataset.wait()
                run.use_artifact(artifact_dataset)
                artifact_gpt2 = wandb.Artifact(
                    f'gpt2',
                    type='pretrained-model',
                    description=f'Pretrained model from OpenAI downloaded from 🤗 Transformers: https://huggingface.co/gpt2',
                    metadata={'huggingtweets version': HW_VERSION})
                artifact_gpt2.add_dir('cache', name='gpt2')
                run.use_artifact(artifact_gpt2)
                progress.value = 0.4
                label_progress.value = 'Training neural network...'
                
                p_start, p_end = 0.4, 0.9
                def progressify(f):
                    "Control progress bar when calling f"
                    def inner(*args, **kwargs):
                        if trainer.state.epoch is not None:
                            # we only have one epoch, EPOCHS is built into dataset
                            progress.value = p_start + trainer.state.epoch * (p_end - p_start)
                        return f(*args, **kwargs)
                    return inner
        
                trainer.training_step = progressify(trainer.training_step)
                
                # Training neural network
                with log_finetune:
                    display(wandb.run)
                    print_html('\n')
                    display(widgets.HBox([progress, label_progress]))
                trainer.train()

                # set model config parameters
                trainer.model.config.task_specific_params['text-generation'] = {
                    'do_sample': True,
                    'min_length': 10,
                    'max_length': 160,
                    'temperature': 1.,
                    'top_p': 0.95,
                    'prefix': '<|endoftext|>'}
                
                # create model repo
                label_progress.value = 'Setting up Hugging Face model repo'
                model_name = '-'.join(sorted(handles_processed))
                shutil.rmtree(model_name, ignore_errors=True)
                model_path = pathlib.Path(model_name)
                try:
                    hfapi = HfApi()
                    user, namespace = 'huggingtweets-app', 'hf_huggingtweets'
                    assert hfapi.whoami(namespace)['name'] == user, "Could not log into Hugging Face"
                    url = hfapi.create_repo(token=namespace, repo_id=f"huggingtweets/{model_name}", exist_ok=True)
                    !GIT_LFS_SKIP_SMUDGE=1 git clone https://$user:$namespace@huggingface.co/huggingtweets/$model_name
                
                except Exception as e:
                    with log_finetune:
                        print_html(f'\n<b>Could not create a model repo</b>\n{e}')
                # remove non-git files
                for f in pathlib.Path(model_name).glob('*'):
                    if f.suffix:
                        f.unlink()

                # save new model files
                trainer.save_model(model_name)
                
                # log model to huggingface
                label_progress.value = 'Committing model to Hugging Face (up to 1mn)'
                hf_urls = []
                try:
                    create_model_card(card_val, model_name)

                    # upload files                    
                    !git config --global user.email "boris.dayma@gmail.com"
                    !git config --global user.name "huggingtweets"
                    commit_files(model_name, f'New model from {wandb_url}')

                    # get files url
                    assert model_path.is_dir(), f"Expected {model_path} to be a directory"
                    hf_urls = [f'https://huggingface.co/huggingtweets/{model_name}/resolve/main/{f.name}' for f in model_path.glob('*') if f.suffix]
                
                except Exception as e:
                    with log_finetune:
                        print_html(f'\n<b>Could not upload the model to Hugging Face</b>\n{e}')

                # log model to W&B
                label_progress.value = 'Logging model to W&B...'
                global metadata
                metadata={'model url':model_url,
                          'seed trainer':seed,
                          **metadata}
                artifact_trained = wandb.Artifact(
                    model_name,
                    type='finetuned-model',
                    description=f"Model fine-tuned on tweets from @{' & @'.join(handles_processed)}",
                    metadata=metadata)
                for hf_url in hf_urls:
                    artifact_trained.add_reference(hf_url, checksum = False)
                run.log_artifact(artifact_trained)
                progress.value = 0.98

                run_finetune.button_style = 'success'
                run_predictions.disabled = False

                progress.value = 1.0
                progress.bar_style = 'success'
                success_try = True

                label_progress.value = '🎉 Neural network trained successfully!'
                log_predictions.clear_output(wait=True)
                with log_predictions:
                    print_html('\nEnter the start of a sentence and click "Run predictions"')
                with log_restart:
                    print_html('\n<b>To change user, refresh the page</b>\n')

            except Exception as e:
                print('\nAn error occurred...\n')
                print(e)
                run_finetune.button_style = 'danger'
                run_finetune.disabled = False
                            
        if not success_try:
            display(log_debug)
            progress.bar_style = 'danger'
        
    run_finetune = widgets.Button(
        description='Train Neural Network',
        button_style='primary',
        disabled=True)
    def on_run_finetune_clicked(b):
        finetune()
    run_finetune.on_click(on_run_finetune_clicked)

    log_finetune = widgets.Output()
    with log_finetune:
        print_html('\nWaiting for Step 1 to complete...')

    predictions = []
    
    def shorten_text(text, max_char):
        while len(text) > max_char:
            text = ' '.join(text.split()[:-1]) + '…'
        return text
        
    def predict():
        run_predictions.disabled = True
        start_widget.disabled = True
        run_predictions.button_style = 'primary'
        global handles_processed
        global model_url
        log_predictions.clear_output(wait=True)

        # tweet buttons don't appear well in colab if within log_predictions widget
        # we reset the entire cell
        clear_output(wait=True)
        display(widgets.VBox([start_widget, run_predictions, log_model, log_predictions]))
        stylize()

        def tweet_html(tweet_text, tweet_url):
            tweet_text = shorten_text(tweet_text, 250)
            params = urlencode({'text': tweet_text, 'url': tweet_url, 'related': 'borisdayma'})
            url=f'https://twitter.com/intent/tweet?{params}'
            return f'''
            <div style="width: 76px;">
                <a target="_blank" href="{url}" style='background-color:rgb(27, 149, 224);border-bottom-left-radius:4px;border-bottom-right-radius:4px;border-top-left-radius:4px;border-top-right-radius:4px;box-sizing:border-box;color:rgb(255, 255, 255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue", Arial, sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:28px;line-height:26px;outline-color:rgb(255, 255, 255);outline-style:none;outline-width:0px;padding-bottom:1px;padding-left:9px;padding-right:10px;padding-top:1px;position:relative;text-align:left;text-decoration-color:rgb(255, 255, 255);text-decoration-line:none;text-decoration-style:solid;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'>
                <i style='background-attachment:scroll;background-clip:border-box;background-color:rgba(0,0,0,0);background-image:url(data:image/svg+xml,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20viewBox%3D%220%200%2072%2072%22%3E%3Cpath%20fill%3D%22none%22%20d%3D%22M0%200h72v72H0z%22%2F%3E%3Cpath%20class%3D%22icon%22%20fill%3D%22%23fff%22%20d%3D%22M68.812%2015.14c-2.348%201.04-4.87%201.744-7.52%202.06%202.704-1.62%204.78-4.186%205.757-7.243-2.53%201.5-5.33%202.592-8.314%203.176C56.35%2010.59%2052.948%209%2049.182%209c-7.23%200-13.092%205.86-13.092%2013.093%200%201.026.118%202.02.338%202.98C25.543%2024.527%2015.9%2019.318%209.44%2011.396c-1.125%201.936-1.77%204.184-1.77%206.58%200%204.543%202.312%208.552%205.824%2010.9-2.146-.07-4.165-.658-5.93-1.64-.002.056-.002.11-.002.163%200%206.345%204.513%2011.638%2010.504%2012.84-1.1.298-2.256.457-3.45.457-.845%200-1.666-.078-2.464-.23%201.667%205.2%206.5%208.985%2012.23%209.09-4.482%203.51-10.13%205.605-16.26%205.605-1.055%200-2.096-.06-3.122-.184%205.794%203.717%2012.676%205.882%2020.067%205.882%2024.083%200%2037.25-19.95%2037.25-37.25%200-.565-.013-1.133-.038-1.693%202.558-1.847%204.778-4.15%206.532-6.774z%22%2F%3E%3C%2Fsvg%3E);background-origin:padding-box;background-position-x:0px;background-position-y:0px;background-repeat-x;background-repeat-y;background-size:auto;color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:italic;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:18px;line-height:26px;position:relative;text-align:left;text-decoration-thickness:auto;top:4px;user-select:none;white-space:nowrap;width:18px;'></i>
                <span style='color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;line-height:26px;margin-left:4px;text-align:left;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'>Tweet</span>
            </a>
            </div>
            '''
        
        success_try = False

        # get start sentence
        ensure_widgets_updated()
        start = start_widget.value.strip()
                
        with log_predictions:
            print_html(f'\nPerforming predictions of @{" & @".join(handles_processed)} starting with "{start}"...\nThis should take no more than 10 seconds!')
        
        with log_debug:
            try:
                # start a wandb run (should never happen)
                if wandb.run is None:
                    print('Unexpected missing W&B run process')
                    wandb.init()
                
                # prepare input
                start_with_bos = '<|endoftext|>' + start
                encoded_prompt = trainer.tokenizer(start_with_bos, add_special_tokens=False, return_tensors="pt").input_ids
                encoded_prompt = encoded_prompt.to(trainer.model.device)

                # prediction
                output_sequences = trainer.model.generate(
                    input_ids=encoded_prompt,
                    max_length=160,
                    min_length=10,
                    temperature=1.,
                    top_p=0.95,
                    do_sample=True,
                    num_return_sequences=10
                    )
                generated_sequences = []

                # decode prediction
                for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
                    generated_sequence = generated_sequence.tolist()
                    text = trainer.tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
                    if not ALLOW_NEW_LINES:
                        limit = text.find('\n')
                        text = text[: limit if limit != -1 else None]
                    generated_sequences.append(text.strip())
                
                for i, g in enumerate(generated_sequences):
                    predictions.append([start, g])

                # create previews
                r = requests.post('https://us-central1-huggingtweets.cloudfunctions.net/get_screenshot',
                                  data = {"NAME": card_val['USER_NAME'],
                                          "HANDLE": card_val['USER_HANDLE'],
                                          "URL1": card_val['USER_PROFILE_1'],
                                          "URL2": card_val['USER_PROFILE_2'],
                                          "URL3": card_val['USER_PROFILE_3'],
                                          "DISPLAY1": card_val['DISPLAY_1'],
                                          "DISPLAY2": card_val['DISPLAY_2'],
                                          "DISPLAY3": card_val['DISPLAY_3'],
                                          "BOT": card_val['BOT'],
                                          "INPUT": start,
                                          "OUTPUTS": generated_sequences})
                ids = r.json()
                global model_preview
                global hfapi, token, namespace
                if model_preview is None:
                    model_preview = f"http://www.huggingtweets.com/{'-'.join(sorted(handles_processed))}/{ids[0]}/predictions.png"
                    card_val['SOCIAL_LINK'] = model_preview
                    create_model_card(card_val, '-'.join(sorted(handles_processed)))
                    commit_files('-'.join(sorted(handles_processed)), f'Update model preview')
                    with log_model:
                        print_html("<h2>Model Preview (select a tweet to update)</h2>")
                        show_image_preview(model_preview)

                # log predictions
                wandb.log({'examples': wandb.Table(data=predictions, columns=['Input', 'Prediction'])})

                # display tweets
                widgets_tweet = []
                center = widgets.Layout(align_items='center', display='flex')
                layout_twitter = widgets.Layout(width = '76px')
                global bot
                for i, (g, id) in enumerate(zip(generated_sequences, ids)):
                    preview_button = create_button(id)
                    tweet_pred = start + ' → ' + g[len(start):].strip()
                    tweet_button = tweet_html(f"I love this tweet from my AI {bot} of @{' & @'.join(handles_processed)} with #huggingtweets:\n{tweet_pred}",
                                              f"http://www.huggingtweets.com/{'-'.join(sorted(handles_processed))}/{id}/predictions.html")
                    w = widgets.HBox([preview_button,
                                      widgets.HTML(tweet_button, layout=layout_twitter),
                                      widgets.HTML(g)],
                                     layout=center)
                    w.add_class("table_odd" if i%2 else "table_even")
                    widgets_tweet.append(w)

                # make model share table
                tweet_share = f"I created an AI {bot} of @{' & @'.join(handles_processed)} with #huggingtweets!\nPlay with my model or create your own!"
                link_model = f'<a href="{model_url}" rel="noopener" target="_blank">{model_url}</a>'
                share_data = [[tweet_html(tweet_share, model_url),
                               f"🎉 Share @{' & @'.join(handles_processed)} model: {link_model} <i>(may take 30 seconds to become active)</i>"]]
                share_table = HTML(html_table(share_data))

                run_predictions.button_style = 'success'
                success_try = True
                
            except Exception as e:
                print('\nAn error occurred...\n')
                print(e)
                run_predictions.button_style = 'danger'

        if success_try:
            log_predictions.clear_output()
            with log_predictions:                
                # twitter button does not update within widget in colab
                if not IN_COLAB:
                    print_html('\n')
                    display(share_table)
                    print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')
                    for w in widgets_tweet:
                        display(w)
                    print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')

            if IN_COLAB:
                print_html('\n')
                display(share_table)
                print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')
                display(widgets.VBox([*widgets_tweet]))
                print_html('\n<b>Share your model and favorite tweets or try new predictions!\nTwitter will display the image (reload the tweet to preview)!</b>\n\n')
        else:
            display(log_debug)
        
        run_predictions.disabled = False
        start_widget.disabled = False
                
    start_widget = widgets.Text(value='My dream is',
                                placeholder='Start a sentence')

    run_predictions = widgets.Button(
        description='Run predictions',
        button_style='primary',
        disabled=True)
    def on_run_predictions_clicked(b):
        predict()
    run_predictions.on_click(on_run_predictions_clicked)

    log_predictions = widgets.Output()
    with log_predictions:
        print_html('\nWaiting for Step 2 to complete...')
    log_model = widgets.Output()

    clear_output(wait=True)
    print_html("🎉 Environment set-up correctly! You're ready to move to Step 1!")

## Collect dataset

Enter a Twitter handle and click Download tweets. This gives the model a dataset of examples to train on.
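As a rough illustration of what "a dataset of examples" means here, the sketch below turns raw tweets into training strings. It assumes each tweet becomes one example prefixed with GPT-2's `<|endoftext|>` token (the same prefix the prediction code prepends to your prompt). `format_tweets` is a hypothetical helper, and its retweet and URL filtering rules are illustrative, not necessarily what huggingtweets does.

```python
import re

# Hypothetical cleaning rules, for illustration only
URL_RE = re.compile(r"https?://\S+")
BOS = "<|endoftext|>"  # GPT-2's special token, also used as the generation prefix


def format_tweets(tweets, bos=BOS):
    """Turn raw tweet strings into one training example per tweet."""
    examples = []
    for tweet in tweets:
        if tweet.startswith("RT @"):
            continue  # skip retweets: they are someone else's words
        text = URL_RE.sub("", tweet)   # drop links, which carry little style signal
        text = " ".join(text.split())  # collapse whitespace left behind
        if text:                       # skip tweets that were only a link
            examples.append(bos + text)
    return examples


print(format_tweets(["RT @nasa: launch!", "Mars soon https://t.co/abc", "https://t.co/x"]))
# ['<|endoftext|>Mars soon']
```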

In [None]:
#@title ⠀ {display-mode: "form"}
stylize()
if IN_COLAB:
    display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 2000})'''))
display(widgets.VBox([*handle_widgets, run_dl_tweets, log_restart, log_dl_tweets]))


## Fine-Tune

Fine-tune a language model on your unique set of tweets to generate predictions.

The pretrained model is downloaded from [Hugging Face Transformers](https://huggingface.co/), an awesome open-source library for natural language processing, and training is logged through [Weights & Biases](http://docs.wandb.com/).
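The fine-tuning code in the set-up cell swaps in `get_cosine_schedule_with_warmup(..., num_warmup_steps=0, ...)` as the learning-rate scheduler. To show roughly what that schedule does, here is a plain-Python sketch of the multiplier it applies to the base learning rate, written from the standard half-cosine formula. `cosine_lr_multiplier` is a hypothetical name and is not part of transformers, and the base learning rate in the example is made up.

```python
import math


def cosine_lr_multiplier(step, num_training_steps, num_warmup_steps=0):
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warm-up from 0 to 1 (unused in the notebook, which sets 0 warm-up steps)
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Half a cosine wave: 1.0 at the start of training, 0.0 at the end
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 1e-4  # hypothetical base learning rate
for step in (0, 50, 100):
    print(step, base_lr * cosine_lr_multiplier(step, num_training_steps=100))
```

The learning rate therefore starts at its full value and decays smoothly to zero by the last step, which avoids large updates late in a short fine-tuning run.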

In [None]:
#@title ⠀ {display-mode: "form"}
stylize()
if IN_COLAB:
    display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 2000})'''))
display(widgets.VBox([run_finetune, log_finetune]))


## Test it out

Type the beginning of a tweet, press Run predictions, and the model will try to come up with a realistic ending to your tweet.
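Behind the "Run predictions" button, the set-up cell calls `model.generate` with `do_sample=True`, `temperature=1.` and `top_p=0.95`, which is nucleus sampling. The sketch below is a simplified, plain-Python version of that sampling rule for a single step. It assumes the model's next-token logits are already available as a list. It illustrates the idea and is not the transformers implementation; `top_p_sample` is a hypothetical helper.

```python
import math
import random


def top_p_sample(logits, top_p=0.95, temperature=1.0, rng=random):
    """Sample one token id from the smallest set of tokens whose total probability >= top_p."""
    # Softmax with temperature (subtract the max for numerical stability)
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative probability reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus, cumulative = [], 0.0
    for i in order:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Sample within the nucleus; choices() renormalizes the weights
    return rng.choices(nucleus, weights=[probs[i] for i in nucleus], k=1)[0]


rng = random.Random(0)
# Token 0 alone holds 96% of the mass, so with top_p=0.95 it is the only candidate
logits = [math.log(0.96), math.log(0.03), math.log(0.01)]
print([top_p_sample(logits, top_p=0.95, rng=rng) for _ in range(5)])  # [0, 0, 0, 0, 0]
```

Cutting off the low-probability tail is what keeps the generated tweets from drifting into gibberish, while sampling within the nucleus keeps the 10 predictions different from one another.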

In [None]:
#@title ⠀ {display-mode: "form"}
stylize()
if IN_COLAB:
    display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 2000})'''))
display(widgets.VBox([start_widget, run_predictions, log_model, log_predictions]))



- ElonGPT is cool, and we can hold a dialogue with it by prepending its previous responses to each new prompt.
- However, rather than getting better at answering a diverse set of user queries in a dialogue format (acting as a mixture of domain experts), it got better at answering them as a single domain expert: Elon Musk.
- Can we improve ElonGPT's capabilities? Yes, with Reinforcement Learning.
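A minimal sketch of that prompt-chaining trick, assuming a plain-text format like `Someone asks: <question> Elon replies: <answer>`. The exact wording is a hypothetical convention, and `build_dialogue_prompt` is an illustrative helper whose output you would paste into the prediction box. The model only ever sees one long string, so the earlier turns act as context for the next completion.

```python
def build_dialogue_prompt(history, new_question, speaker="Elon"):
    """Flatten previous (question, answer) turns plus a new question into one prompt."""
    turns = [f"Someone asks: {q} {speaker} replies: {a}" for q, a in history]
    # End with an open reply so the model's continuation is the next answer
    turns.append(f"Someone asks: {new_question} {speaker} replies:")
    return " ".join(turns)


history = [("How are you?", "Good question.")]
print(build_dialogue_prompt(history, "Thanks for answering. Is Mars next?"))
```

Because the prediction code calls `generate` with `max_length=160`, a limit that counts the prompt's tokens as well, a long history leaves less room for the reply, so in practice you may want to keep only the last few turns.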

# Step 3 - Reinforcement Learning from Human Feedback

- OpenAI trained a second neural network, called a reward model, that takes a prompt and a response and outputs a scalar value called the reward.
- To train the reward model, labelers were presented with 4 to 9 model outputs for a single input prompt.
- They ranked these outputs from best to worst, and each ranking was broken into pairwise comparisons, as follows.

<img src="https://miro.medium.com/v2/resize:fit:830/format:webp/1*s68hc8vfEq7DBQRSLuQfMg.png ">

- Adding every comparison to the training set as a separate datapoint led to overfitting (failure to generalize beyond the seen data), because comparisons drawn from the same prompt are highly correlated.
- To solve this, all the comparisons from one prompt's ranking were trained together as a single batch element.

<img src="https://miro.medium.com/v2/resize:fit:1066/format:webp/1*s53uQy_v18my8tghg92OQw.png">
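The reward-model objective from the InstructGPT paper can be sketched in plain Python: for one prompt with K ranked responses, each of the K(K-1)/2 (preferred, rejected) pairs contributes `-log(sigmoid(r_preferred - r_rejected))`, and the pairs are averaged together as one unit instead of being shuffled in as independent datapoints. The reward values below are made up; in practice they come from the reward network, and `ranking_loss` is a hypothetical helper name.

```python
import math
from itertools import combinations


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def ranking_loss(rewards_best_to_worst):
    """Reward-model loss for one prompt (needs K >= 2 responses).

    rewards_best_to_worst: scalar rewards r(x, y) for K responses, ordered by the
    labeler's ranking (best first). All K*(K-1)/2 pairs are used together.
    """
    pairs = list(combinations(range(len(rewards_best_to_worst)), 2))  # i < j: i was preferred
    # -log(sigmoid(r_i - r_j)) equals softplus(r_j - r_i)
    losses = [softplus(rewards_best_to_worst[j] - rewards_best_to_worst[i]) for i, j in pairs]
    return sum(losses) / len(pairs)


# Hypothetical reward scores for K = 4 ranked responses
print(ranking_loss([2.0, 1.0, 0.5, -1.0]))  # rewards agree with the ranking: low loss
print(ranking_loss([-1.0, 0.5, 1.0, 2.0]))  # rewards contradict the ranking: high loss
```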

- Reinforcement learning was then used to maximize the reward by updating the language model's policy.
- Proximal Policy Optimization (PPO) was the RL algorithm used to update the model's policy as responses were generated.
- To avoid over-optimizing for the reward model, they subtracted a KL-divergence penalty from the reward. The KL divergence measures how far the RL policy's output distribution has drifted from that of the original supervised fine-tuned (SFT) model, so extreme distances are penalized.

<img src="https://miro.medium.com/v2/resize:fit:1086/format:webp/1*b7iS44WofvHoNsHsGXKjFA.png">
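Both ingredients above can be sketched per sample in plain Python, following the formulas in the PPO and InstructGPT papers. The first is the KL-shaped reward `r - beta * log(pi_RL(y|x) / pi_SFT(y|x))`, with the log-ratio estimated by summing per-token log-probabilities. The second is PPO's clipped surrogate objective `min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)`. The values of `beta` and `epsilon` and all the numbers are hypothetical. A real implementation works on batched tensors and estimates the advantage `A` with a value function.

```python
import math


def kl_shaped_reward(reward, policy_logprobs, sft_logprobs, beta=0.02):
    """Reward-model score minus a KL penalty that keeps the policy near the SFT model.

    policy_logprobs / sft_logprobs: log-probabilities each model assigns to the
    same generated tokens; their summed difference estimates log(pi_RL / pi_SFT).
    beta is a hypothetical penalty coefficient.
    """
    if len(policy_logprobs) != len(sft_logprobs):
        raise ValueError("both models must score the same tokens")
    log_ratio = sum(p - q for p, q in zip(policy_logprobs, sft_logprobs))
    return reward - beta * log_ratio


def ppo_clipped_objective(new_logprob, old_logprob, advantage, epsilon=0.2):
    """PPO's clipped surrogate for one action (to be maximized)."""
    ratio = math.exp(new_logprob - old_logprob)          # pi_new(a|s) / pi_old(a|s)
    clipped = min(max(ratio, 1 - epsilon), 1 + epsilon)  # clip(ratio, 1 - eps, 1 + eps)
    # The minimum removes any incentive to push the ratio outside the clip range
    return min(ratio * advantage, clipped * advantage)


# The policy now favors tokens the SFT model found less likely, so the reward shrinks
print(kl_shaped_reward(0.9, policy_logprobs=[-0.5, -0.2], sft_logprobs=[-1.5, -1.2]))
# The new policy doubled the action's probability, but the gain is capped at 1 + epsilon
print(ppo_clipped_objective(math.log(0.4), math.log(0.2), advantage=1.0))
```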

Steps 2 and 3 of OpenAI's process (reward modeling and PPO) can be iterated repeatedly, though in practice this has not been done extensively.

Let's start by defining our own reward model...

- We don't have human labelers rating model outputs.
- Instead, we'll use a pretrained sentiment classifier as our reward model.
- Our reward for a generated text is the classifier's probability that the text is negative, so positive text earns little reward.
- PPO will optimize the policy against this reward, so after training Elon should be very pessimistic.
- It learns by trial and error, like Pavlov's dog.
- OpenAI basically did this to ChatGPT, brainwashing it to only identify as a language model. No racism. No hate speech. Nothing out of bounds.
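Concretely, `cardiffnlp/twitter-roberta-base-sentiment` outputs three labels (LABEL_0 = negative, LABEL_1 = neutral, LABEL_2 = positive), and with `return_all_scores=True` the pipeline returns one score per label for each input text. The sketch below computes the reward from one text's scores, assuming that output shape. It looks up the negative label by name rather than by position. `negativity_reward` is a hypothetical helper, and the scores in the example are made up.

```python
NEGATIVE_LABEL = "LABEL_0"  # per the model card: 0 = negative, 1 = neutral, 2 = positive


def negativity_reward(label_scores):
    """Reward = classifier probability that the text is negative.

    label_scores: one text's scores, shaped like the sentiment pipeline's output
    with return_all_scores=True, e.g. [{'label': 'LABEL_0', 'score': 0.7}, ...].
    """
    for item in label_scores:
        if item["label"] == NEGATIVE_LABEL:
            return item["score"]
    raise KeyError(f"{NEGATIVE_LABEL} not found in classifier output")


# Made-up scores for two completions of the same prompt
gloomy = [{"label": "LABEL_0", "score": 0.81}, {"label": "LABEL_1", "score": 0.15}, {"label": "LABEL_2", "score": 0.04}]
cheerful = [{"label": "LABEL_0", "score": 0.02}, {"label": "LABEL_1", "score": 0.10}, {"label": "LABEL_2", "score": 0.88}]
print(negativity_reward(gloomy), negativity_reward(cheerful))  # 0.81 0.02
```

With the real pipeline this would be `negativity_reward(sentiment(text)[0])`. That agrees with the `[0][0]['score']` lookup in the environment below whenever the labels come back in index order.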

In [None]:
# Step 1: install dependencies and load our fine-tuned model
!pip install pfrl@git+https://github.com/voidful/pfrl.git
!pip install textrl==0.1.6

import logging
import sys

import pfrl
import torch
from textrl import TextRLEnv, TextRLActor
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

# Load the GPT-2 model fine-tuned on @elonmusk's tweets in Step 2.
# AutoModelForCausalLM loads GPT-2 together with its language-modeling head.
tokenizer = AutoTokenizer.from_pretrained("huggingtweets/elonmusk")
model = AutoModelForCausalLM.from_pretrained("huggingtweets/elonmusk")
model.eval()   # disable dropout
model.cuda()   # move the model to the Colab GPU



In [None]:
# - Define a Reward model
# A pretrained Twitter sentiment classifier stands in for a learned reward model.
# Its labels are LABEL_0 = negative, LABEL_1 = neutral, LABEL_2 = positive.
sentiment = pipeline('sentiment-analysis',
                     model="cardiffnlp/twitter-roberta-base-sentiment",
                     tokenizer="cardiffnlp/twitter-roberta-base-sentiment",
                     device=0,
                     return_all_scores=True)
transformers_logger = logging.getLogger('transformers')
transformers_logger.setLevel(logging.CRITICAL)

print(sentiment("dogecoin is bad"))                 # scores for all three labels
print(sentiment("dogecoin is bad")[0][0]['score'])  # probability of LABEL_0 (negative)

class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):
        # predicted_list is the list of tokens generated so far
        reward = 0
        # Only score the text once the episode ends (EOS reached or max length hit)
        if finish or len(predicted_list) >= self.env_max_length:
            if 1 < len(predicted_list):
                predicted_text = tokenizer.convert_tokens_to_string(predicted_list)
                # Reward = classifier's probability that prompt + completion is negative
                reward += sentiment(input_item[0] + predicted_text)[0][0]['score']
        return reward


<img src="https://imgs.search.brave.com/DZAtqY7x5jP2SdVaApC7oovSnwY1LE-L3TuYFEH5RxU/rs:fit:728:546:1/g:ce/aHR0cHM6Ly9jYW1v/LmdpdGh1YnVzZXJj/b250ZW50LmNvbS84/MmJkZDQ2ZDBhODU0/NThkNDAyYjA5ZjBk/NWUyMGJlMjIxYWVj/MmNkLzY4NzQ3NDcw/NzMzYTJmMmY2OTZk/NjE2NzY1MmU3MzZj/Njk2NDY1NzM2ODYx/NzI2NTYzNjQ2ZTJl/NjM2ZjZkMmY3MjZj/NjE2MjY1Njc2OTZl/NmU2NTcyNzM3NDc1/NzQ2ZjcyNjk2MTZj/MmQzMTMyMzUzNzMy/MzYzNzM1MzYzMjM4/MzMzOTMyMmQ3MDY4/NzA2MTcwNzAzMDMy/MmYzOTM1MmY3MjY1/Njk2ZTY2NmY3MjYz/NjU2ZDY1NmU3NDJk/NmM2NTYxNzI2ZTY5/NmU2NzJkNjEyZDYy/NjU2NzY5NmU2ZTY1/NzI3MzJkNzQ3NTc0/NmY3MjY5NjE2YzJk/MzgyZDM3MzIzODJl/NmE3MDY3M2Y2MzYy/M2QzMTMyMzkzMTMx/MzAzNzMwMzMzMQ">

<img src="https://imgs.search.brave.com/77uVYWpFKRERqxvEVZUNlhHavnDOSLz2xdLR45Dol7w/rs:fit:800:365:1/g:ce/aHR0cHM6Ly9pLnN0/YWNrLmltZ3VyLmNv/bS9jRkQzSC5wbmc">



In [None]:
from textrl import TextRLActor  # MyRLEnv is the custom environment defined earlier in the notebook

observaton_list = [['i think dogecoin is']]

# Create a reinforcement learning environment, i.e. a Markov decision process

# Initialize the environment (which computes the reward) using our
# fine-tuned model and one example prompt
env = MyRLEnv(model, tokenizer, observation_input=observaton_list)

# Initialize the actor with the same model and tokenizer
actor = TextRLActor(env, model, tokenizer)

# Create a PPO agent; the policy is optimized to maximize the reward
# in the pfrl training loop further down
agent = actor.agent_ppo(update_interval=10, minibatch_size=10, epochs=10)

# Test it
actor.predict(observaton_list[0])



' a scam'
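The environment cell above frames text generation as a Markov decision process. As a rough illustration (a hypothetical toy, not TextRL's `TextRLEnv` or the notebook's `MyRLEnv`), the state is the prompt plus the tokens generated so far, each action appends one token, and in this toy a scalar reward is given only when the episode ends. `ToyTextEnv`, `rollout`, and the word-counting `negativity` reward are all made up for this sketch; in the notebook the reward comes from a sentiment classifier instead.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class ToyTextEnv:
    """Hypothetical text-generation MDP for illustration (not TextRL's TextRLEnv).

    state  = the prompt tokens plus every token generated so far
    action = the next token to append (must come from `vocab`)
    reward = 0.0 on intermediate steps; `reward_fn(generated)` at the end
    """
    prompt: List[str]
    vocab: List[str]
    reward_fn: Callable[[List[str]], float]
    max_len: int = 5
    eos: str = "</s>"
    tokens: List[str] = field(default_factory=list)

    def reset(self) -> Tuple[str, ...]:
        self.tokens = list(self.prompt)
        return tuple(self.tokens)

    def step(self, action: str) -> Tuple[Tuple[str, ...], float, bool]:
        if action not in self.vocab:
            raise ValueError(f"unknown action: {action!r}")
        self.tokens.append(action)
        generated = self.tokens[len(self.prompt):]
        # The episode ends on the end-of-sequence token or at max_len.
        done = action == self.eos or len(generated) >= self.max_len
        # Sparse reward: only the finished generation is scored.
        reward = self.reward_fn(generated) if done else 0.0
        return tuple(self.tokens), reward, done


def rollout(env: ToyTextEnv, policy: Callable[[Tuple[str, ...]], str]):
    """Run one episode with `policy` and return (final state, total reward)."""
    state, done, total = env.reset(), False, 0.0
    while not done:
        state, reward, done = env.step(policy(state))
        total += reward
    return state, total


# Hypothetical reward: share of generated words that are "negative".
NEGATIVE = {"scam", "bubble"}


def negativity(generated: List[str]) -> float:
    words = [t for t in generated if t != "</s>"]
    return sum(w in NEGATIVE for w in words) / max(len(words), 1)


toy_env = ToyTextEnv(
    prompt=["i", "think", "dogecoin", "is"],
    vocab=["a", "scam", "great", "</s>"],
    reward_fn=negativity,
)
script = iter(["a", "scam", "</s>"])  # a fixed "policy" for the demo
final_state, episode_return = rollout(toy_env, lambda state: next(script))
print(" ".join(final_state), episode_return)  # i think dogecoin is a scam </s> 0.5
```

An RL algorithm such as PPO replaces the fixed `script` with a learned policy that picks tokens so that `episode_return` goes up.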

Proximal Policy Optimization (PPO) in 8 steps

1. The AI agent performs actions in an environment and receives a reward for each action.
2. Based on the rewards received, the AI agent updates its policy, which is a mapping from states to actions.
3. The updated policy is then used to select the next action to take in the environment.
4. Steps 2 and 3 are repeated many times until the AI agent reaches a satisfactory level of performance.
5. In PPO, a "clip" limits how far the updated policy can move from the previous policy: the ratio of new to old action probabilities is clipped to a small range around 1. This helps keep the learning process stable and convergent.
6. The clip acts as a kind of "speed limit" for the policy updates, preventing them from changing too quickly and allowing the AI agent to learn more effectively.
7. The AI agent's policy is periodically evaluated using a set of test episodes to see if it is performing well. If the performance is not satisfactory, the agent continues to update its policy using the PPO algorithm.
8. Once the AI agent's performance is satisfactory, the training process can be stopped, and the final policy can be used to make decisions in the environment.
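Steps 5 and 6 describe PPO's clipped objective. Below is a minimal NumPy sketch of the clipped surrogate loss from the PPO paper; `ppo_clip_loss` is a hypothetical helper (not a pfrl or TextRL function), and `clip_eps=0.2` is simply a commonly used value. The probability ratio compares the new and old policies on the same actions, and clipping it to [1 - eps, 1 + eps] is the "speed limit" on each update.

```python
import numpy as np


def ppo_clip_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """PPO clipped surrogate objective, negated so it can be minimized.

    ratio_t = pi_new(a_t | s_t) / pi_old(a_t | s_t) = exp(logp_new - logp_old)
    L_clip  = mean(min(ratio_t * A_t, clip(ratio_t, 1 - eps, 1 + eps) * A_t))
    """
    ratio = np.exp(np.asarray(logp_new, dtype=float) - np.asarray(logp_old, dtype=float))
    advantages = np.asarray(advantages, dtype=float)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Taking the minimum removes any incentive to push the ratio outside
    # [1 - eps, 1 + eps] in a single update.
    return -np.mean(np.minimum(unclipped, clipped))


# Two actions, each taken with probability 0.5 under the old policy.
logp_old = np.log([0.5, 0.5])
# The new policy makes the good action (A = +1) more likely (0.9)
# and the bad action (A = -1) less likely (0.25).
logp_new = np.log([0.9, 0.25])
advantages = [1.0, -1.0]

print(ppo_clip_loss(logp_new, logp_old, advantages))  # approximately -0.2
```

Here the good action's ratio (1.8) is cut to 1.2 and the bad action's ratio (0.5) is held at 0.8, so moving either probability further in this update earns nothing extra.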





One training loop was not enough: ElonGPT is still very positive. If we run the RL training loop for longer, Proximal Policy Optimization will gradually train our agent to choose the outputs that maximize the reward model, which in this case means producing the most negative tweet.
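"Maximize the reward model" needs the classifier's output turned into a single number. A minimal sketch, assuming the reward is the probability of the negative class from the sentiment classifier loaded earlier (`cardiffnlp/twitter-roberta-base-sentiment`, whose model card maps LABEL_0 to negative, LABEL_1 to neutral, and LABEL_2 to positive). In practice the logits would be the `.logits` of a `RobertaForSequenceClassification` forward pass; here they are made-up numbers, and `negativity_reward` is a hypothetical helper. The reward actually used by `MyRLEnv` is defined earlier in the notebook and may differ.

```python
import math

# Label order from the cardiffnlp/twitter-roberta-base-sentiment model card:
# LABEL_0 = negative, LABEL_1 = neutral, LABEL_2 = positive.
NEGATIVE_LABEL = 0


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def negativity_reward(logits):
    """Hypothetical reward: the classifier's probability that a text is negative."""
    return softmax(logits)[NEGATIVE_LABEL]


# Made-up logits for two generations, for illustration only.
gloomy = [2.0, 0.0, -1.0]   # classifier leans negative
upbeat = [-1.0, 0.0, 2.0]   # classifier leans positive
print(negativity_reward(gloomy) > negativity_reward(upbeat))  # True
```

Because the reward is a probability in [0, 1], the `eval_score` values printed by the training loop stay in that range as well, if the environment uses a reward of this form.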

In [None]:
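import pfrl

# Train the PPO agent, evaluating it periodically; by default pfrl saves
# the best-scoring agent to <outdir>/best (reloaded in the next cell)
pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=100,                    # total number of training steps
    eval_n_steps=None,            # evaluate by episodes rather than steps...
    eval_n_episodes=1,            # ...running one evaluation episode each time
    train_max_episode_len=100,    # cap on the length of each training episode
    eval_interval=10,             # how often (in steps) to evaluate
    outdir='elon_musk_dogecoin',  # where logs and checkpoints are written
)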

# - Conclusion 

#https://github.com/sidml/understanding-kl-divergence/raw/master/kldiv_viz.gif

(<textrl.actor.TextPPO at 0x7fbe10ccbb80>,
 [{'average_value': 8.329171,
   'average_entropy': 0.21830863,
   'average_value_loss': 29.15662474632263,
   'average_policy_loss': -0.02056240683421493,
   'n_updates': 20,
   'explained_variance': nan,
   'eval_score': 0.5137418508529663},
  {'average_value': 5.2769866,
   'average_entropy': 0.07279178,
   'average_value_loss': 7.409909982979298,
   'average_policy_loss': -0.01318911066000742,
   'n_updates': 100,
   'explained_variance': nan,
   'eval_score': 0.6741077899932861}])
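Steps 7 and 8 of the PPO summary (periodic evaluation, keeping the best policy) correspond roughly to `eval_interval`, `eval_n_episodes`, and `outdir` in the training call: each dict in the returned history records an `eval_score`, and with pfrl's default `save_best_so_far_agent=True` the best-scoring agent is written to `<outdir>/best`, which is why `./elon_musk_dogecoin/best` can be loaded afterwards. The loop below is a simplified, hypothetical sketch of that pattern; `train_step`, `evaluate`, and `save` are placeholder callables, and pfrl's real scheduling and logging differ in the details.

```python
from typing import Callable, Dict, List, Tuple


def train_with_periodic_eval(
    train_step: Callable[[], None],
    evaluate: Callable[[], float],
    save: Callable[[str], None],
    steps: int,
    eval_interval: int,
) -> Tuple[float, List[Dict[str, float]]]:
    """Generic train / evaluate / keep-best loop (a sketch, not pfrl's code)."""
    best_score = float("-inf")
    history: List[Dict[str, float]] = []
    for step in range(1, steps + 1):
        train_step()
        if step % eval_interval == 0:
            score = evaluate()
            history.append({"step": step, "eval_score": score})
            if score > best_score:  # only improvements overwrite "best"
                best_score = score
                save("best")
    return best_score, history


# Stand-ins for a real agent: fixed evaluation scores and a log of saves.
scores = iter([0.51, 0.67, 0.60])
saved: List[str] = []
best, hist = train_with_periodic_eval(
    train_step=lambda: None,
    evaluate=lambda: next(scores),
    save=saved.append,
    steps=30,
    eval_interval=10,
)
print(best, len(hist), saved)  # 0.67 3 ['best', 'best']
```

The third evaluation (0.60) is worse than the second, so the saved "best" checkpoint keeps the 0.67 policy rather than the most recent one.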

# And let's see our results

In [None]:
# - Evaluate: reload the best-scoring checkpoint saved during training
agent.load("./elon_musk_dogecoin/best")
actor.predict(observaton_list[0])

' a real company, but that is a lie. They do not own Twitter, doxx, or Twitter-specific accounts. They do own Twitter, doxx, and Twitter-specific accounts. They do not own Twitter, doxx, or Twitter-specific accounts. They do own Twitter, doxx, and Twitter-specific accounts. They do not own Twitter, doxx, or Twitter-specific accounts. They do own Twitter, doxx, and Twitter-specific accounts. They do not'

## Conclusion

Pessimistic Elon GPT is cool, but can it be improved? Absolutely:
- It can produce inaccurate information
- It doesn't cite sources
- Its behavior still depends on the specific wording of the input, i.e. prompting
- No retrieval functionality (it is not connected to the internet)
- No physical embodiment