<a href="https://colab.research.google.com/github/gpt2ent/gpt2colab-js/blob/master/GPT2_with_Javascript_interface_POC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This is a proof of concept that GPT-2 can be run from Colab with a JavaScript interface
**Q: How do I use it?**

A: 
1. Runtime -> Change runtime type -> Hardware accelerator -> GPU
2. Runtime -> Reset all runtimes
3. Runtime -> Run all
4. Scroll down and wait until you see the little window
5. Type text
6. Click the "Continue with GPT-2" button to invoke GPT-2 and have it continue your text.



In [None]:
# Colab-only setup. Pin TensorFlow 1.x: the code below uses TF1 APIs such as tf.reset_default_graph.
%tensorflow_version 1.x
# Install the gpt2ent fork of gpt-2-simple from its context-trim branch.
!git clone https://github.com/gpt2ent/gpt-2-simple.git
%cd gpt-2-simple
!git checkout context-trim
!pip install .
%cd ..
# Clone this notebook's repository and make it the working directory;
# the relative "models" path checked later in this cell is resolved from here.
!git clone https://github.com/gpt2ent/gpt2colab-js.git
%cd gpt2colab-js

import gpt_2_simple as gpt2

import os
import requests
import tensorflow as tf

import re

#Determining the graphics card used by colab: full model can run only on P100

try:
    !cat /proc/driver/nvidia/gpus/0000:00:04.0/information >> /content/card_info.txt
    with open('/content/card_info.txt','r') as f:
        graphics_card = re.split('\n|\t\t ',f.read())[1]

    if not graphics_card.startswith("Tesla P100") and not graphics_card.startswith("Tesla T4"):
        print("="*90+'\n'+"="*90+'\n\n')
        print('\n\tYour current GPU - %s - cannot fit the full GPT-2 model!' % graphics_card)
        print('\n\tFalling back on 774M model.')
        print('\n\tNothing I can do. just pray to Google to give you a P100')
        print('\t\tnext time. ¯\_(ツ)_/¯')
        print('\n\tAlso you might try TPU runtime.')
        print('\n\n'+"="*90+'\n'+"="*90+'\n\n')
        model_name = "774M"
        spinner_speed = "300ms"
    else:
        print('GPU: %s' % graphics_card)
        model_name = "1558M"
        spinner_speed = '400ms'
except IndexError:
    print("="*90+'\n'+"="*90+'\n\n')
    print('\n\tYou\'re not in a GPU runtime.\n')
    print('\n\tTrying 1558M model anyways - assuming you\'re on a good TPU.')
    print('\n\tIf it fails, you have to go to Runtime -> Change runtime type')
    print('\n\tand choose GPU.')
    print('\n\n'+"="*90+'\n'+"="*90+'\n\n')
    model_name = "1558M"
    spinner_speed = "1200ms"


#Override the default model choice by uncommenting one of the lines below
#model_name = "1558M"
#model_name = "774M"
#model_name = "124M"
#model_name = "355M"


# Fetch the model only when models/<model_name> is not already on disk.
_model_dir = os.path.join("models", model_name)
if not os.path.isdir(_model_dir):
    print(f"Downloading {model_name} model...")
    gpt2.download_gpt2(model_name=model_name)

# Number of ai_generate() calls served by the current session; ai_generate
# rebuilds the session when this reaches 6 and then resets it to 0.
generate_count = 0

# Open a TensorFlow session and load the selected model into it.
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, model_name=model_name)

import google.colab.output

import json

class JsonRepr:
    """
    Wrap a value so that repr() of the wrapper is the value's JSON encoding.

    The Javascript side only gets to use the __repr__ of what a Python
    callback returns, so callback results are wrapped in this class to hand
    over JSON text instead of a Python repr.
    """

    def __init__(self, obj):
        # Keep the raw value; it is serialized on every repr() call.
        self.obj = obj

    def __repr__(self):
        encoded = json.dumps(self.obj)
        return encoded

def overlap(a, b):
    """
    Return the length of the longest prefix of b that is also a suffix of a.

    The result is 0 when nothing overlaps, because the empty prefix always matches.
    """
    # Try the longest candidate first; the first hit is the maximum.
    for size in range(len(b), -1, -1):
        if a.endswith(b[:size]):
            return size


def ai_generate(prefix, temp, top_k, length):
    """
    Callback invoked from Javascript: continue `prefix` with GPT-2.

    prefix -- text to continue
    temp   -- passed as gpt2.generate's temperature after float()
    top_k  -- passed as gpt2.generate's top_k after int()
    length -- passed as gpt2.generate's length after int()

    Returns a JsonRepr wrapping the generated text, with any leading part
    that repeats the end of `prefix` removed.
    """
    global sess
    global generate_count

    # The arguments are converted here; the UI reads them from input fields.
    sampling = {
        "temperature": float(temp),
        "top_k": int(top_k),
        "length": int(length),
    }
    samples = gpt2.generate(sess, model_name=model_name, prefix=prefix,
                            include_prefix=False, return_as_list=True, **sampling)
    generated = samples[0]

    # Drop whatever the output repeats from the tail of the prompt.
    continuation = generated[overlap(prefix, generated):]

    generate_count += 1
    if generate_count != 6:
        return JsonRepr(continuation)

    #prevent memory leak as in https://github.com/minimaxir/gpt-2-simple/issues/71
    tf.reset_default_graph()
    sess.close()
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, model_name=model_name)
    generate_count = 0
    return JsonRepr(continuation)

#register callback for Javascript: kernel.invokeFunction("ai_generate", ...) in the
#UI cell reaches the Python ai_generate function through this name
google.colab.output.register_callback('ai_generate', ai_generate)

print('Done')

In [None]:
from IPython.display import HTML

#spinner from https://codepen.io/vovchisko/pen/vROoYQ
# CSS for the busy indicator shown next to the button. The animation duration is
# spliced in from spinner_speed, which the first cell sets per detected hardware.
spinner_css = """
<style>
@keyframes c-inline-spinner-kf {
  0% {
    transform: rotate(0deg);
  }
  100% {
    transform: rotate(360deg);
  }
}

.c-inline-spinner,
.c-inline-spinner:before {
  display: inline-block;
  width: 11px;
  height: 11px;
  transform-origin: 50%;
  border: 2px solid transparent;
  border-color: #74a8d0 #74a8d0 transparent transparent;
  border-radius: 50%;
  content: "";
  animation: linear c-inline-spinner-kf """+spinner_speed+""" infinite;
  position: relative;
  vertical-align: inherit;
  line-height: inherit;
}
.c-inline-spinner {
  top: 3px;
  margin: 0 3px;
}
.c-inline-spinner:before {
  border-color: #74a8d0 #74a8d0 transparent transparent;
  position: absolute;
  left: -2px;
  top: -2px;
  border-style: solid;
}
</style>
"""

# HTML for the UI: the text area, the temperature / top_k / length inputs and the
# button that calls the Javascript generate() function.
# The template is filled with the % operator: %s receives model_name, and the percent
# sign in "font-size: 125%%" is doubled so that it comes out as a literal "%".
input_form = """
<link rel="stylesheet" href="https://unpkg.com/purecss@1.0.1/build/pure-min.css" integrity="sha384-oAOxQR6DkCoMliIh8yFnu25d7Eq/PHS21PClpwjOTeU2jRSq11vu66rf90/cZr47" crossorigin="anonymous">

<div style="background-color:white; border:solid #ccc; width:800px; padding:20px; color: black;">
<p>You have currently loaded %s model</p>
<textarea id="main_textarea" cols="75" rows="20" style="font-family: 'Liberation Serif', 'DejaVu Serif', Georgia, 'Times New Roman', Times, serif; font-size: 13pt; padding:10px;"></textarea><br>
<div class="pure-form pure-form-aligned">
    <div class="pure-control-group">
      <label for="temp">Temperature:</label>
      <input type="number" min="0.00" max="999.99" step="0.01" id="temp" value="0.70" style="background-color: white;">
    </div>
    <div class="pure-control-group">
        <label for="top_k">top_k:</label>
        <input type="number" min="0" max="9999" id="top_k" value="40" style="background-color: white;">
    </div>
    <div class="pure-control-group">
        <label for="length">Generate how much:</label>
        <input type="number" id="length" min="1" max="1023" value="10" style="background-color: white;">
    </div>
    <div style="width: 300px; display: block; margin-left: auto !important; margin-right: auto !important;">
        <p><button class="pure-button pure-button-primary" style="font-size: 125%%;" onclick="generate()">Continue with GPT-2</button>
        <span class="c-inline-spinner" style="visibility: hidden;" id="spinner"></span></p>
    </div>
</div>
</div>
""" % model_name

# Client-side glue for the form. generate() sends the text area contents and the three
# sampling inputs to the Python callback registered as "ai_generate" and appends the
# returned text to the text area. The spinner is hidden again whether the call succeeds
# or fails, so a kernel error no longer leaves it spinning forever.
javascript = """
<script type="text/Javascript">

    function add_text(text) {
        var deftext = document.getElementById('main_textarea').value;
        document.getElementById('main_textarea').value = deftext + text;
    };

    function set_spinner(visibility) {
        document.getElementById('spinner').style = "visibility: " + visibility + ";";
    };

    function generate(){
        var prefix = document.getElementById('main_textarea').value;
        var temp = document.getElementById('temp').value;
        var top_k = document.getElementById('top_k').value;
        var length = document.getElementById('length').value;
        
        var kernel = google.colab.kernel;
        var resultPromise = kernel.invokeFunction("ai_generate", [prefix,temp,top_k,length]); // developer, look here
        resultPromise.then(
            function(value) {
              // The Python callback returns a JsonRepr, so text/plain holds JSON text.
              add_text(JSON.parse(value.data["text/plain"]));
        }).catch(
            function(error) {
              // Covers both a failed kernel call and a result that cannot be parsed.
              console.error("ai_generate failed:", error);
        }).then(
            function() {
              // Runs after success and after failure alike.
              set_spinner("hidden");
        });
        set_spinner("visible");
    };
</script>
"""

# Build the UI from the three fragments above (style, form, script) as a single HTML object.
# NOTE(review): as the last expression visible in this cell, its value is presumably what gets displayed.
HTML(spinner_css + input_form + javascript)