<a href="https://colab.research.google.com/github/gpt2ent/gpt2colab-js/blob/multisample/transformers_js_playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# This is a proof of concept showing that transformers-based language models can be run from Colab with a JavaScript interface
**Q: How do I use it?**

A: 
1. Runtime -> Change runtime type -> Hardware accelerator -> GPU (or TPU if required)
2. Runtime -> Reset all runtimes
3. Runtime -> Run all
4. Scroll down and wait until you see the little window
5. Type text
6. The button "Continue with Transformer" will invoke Transformer and it will continue your text.

**Q: how do I choose a different model?**

A: Look for `model_name` in the first lines of code. You can use any model name from https://huggingface.co/models?pipeline_tag=text-generation here. However, the code cannot check in advance for potential out-of-memory situations, so good luck with that.

In [None]:
!pip install transformers

# Duration of one rotation of the "busy" spinner; concatenated into the CSS
# `animation` rule of the spinner stylesheet further down.
spinner_speed = "1000ms"

# choose model
#model_name = 'EleutherAI/gpt-neo-1.3B'
#model_name = 'EleutherAI/gpt-neo-2.7B'
#model_name = 'sberbank-ai/rugpt3large_based_on_gpt2'
#model_name = 'gpt2-xl'
model_name = 'gpt2-large'

# Forwarded unchanged as the `do_sample` argument of model.generate() in
# ai_generate.
do_sample = True

from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): nothing here moves the model to an accelerator, so it stays on
# whatever device from_pretrained() places it on — confirm this is intended
# given that the instructions ask for a GPU runtime.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): fixed literal, not read from the loaded model's config —
# confirm it matches the context size of the chosen model.
max_total_tokens = 1024  # maximum model sequence length

def trim_tensor_max_tokens(tensor, max_length, total_tokens=None):
    """
    Keep only the tail of a tokenized prompt so that ``max_length`` more
    tokens still fit into the context window.

    Trimming on token ids instead of characters is exact: the result never
    holds more than ``total_tokens - max_length`` tokens.

    Args:
        tensor: token ids with shape (1, len), as produced by the tokenizer
            call sites in this file.
        max_length: number of tokens that will be generated after the prompt.
        total_tokens: size of the context window. Defaults to the
            module-level ``max_total_tokens``.

    Returns:
        Tensor of shape (1, k) holding the last k prompt tokens, where
        k = min(len, total_tokens - max_length).

    Raises:
        ValueError: if ``max_length`` leaves no room for any prompt token.
    """
    if total_tokens is None:
        total_tokens = max_total_tokens

    budget = total_tokens - max_length
    if budget <= 0:
        # A non-positive budget would turn the slice below into `row[0:]`
        # (keep everything) or `row[n:]` (drop tokens from the wrong end)
        # instead of trimming, so refuse explicitly.
        raise ValueError(
            f"cannot generate {max_length} tokens within a context of "
            f"{total_tokens} tokens"
        )

    row = tensor[0]  # tokenizer returns tensor with shape (1, len)
    # A negative start index keeps the last `budget` items, or the whole row
    # when it is already shorter than that.
    return row[-budget:].unsqueeze(0)  # back to (1, len) shape


import google.colab.output

import json

class JsonRepr:
    """
    Wrap a Python value so that its ``repr()`` is the value's JSON encoding.

    The Javascript side of this notebook reads the textual representation of
    a callback's return value, so returning the value inside this wrapper
    hands the front end a string it can pass to JSON.parse().
    """

    def __init__(self, obj):
        # Payload; serialised lazily, every time repr() is requested.
        self.obj = obj

    def __repr__(self):
        encoded = json.dumps(self.obj)
        return encoded

def overlap(a, b):
    """Return the length of the longest prefix of ``b`` that ``a`` ends with."""
    # A prefix longer than `a` can never be a suffix of it, so start at the
    # largest feasible size and walk down. The empty prefix always matches,
    # which guarantees the loop stops at 0 at the latest.
    size = min(len(a), len(b))
    while not a.endswith(b[:size]):
        size -= 1
    return size


def ai_generate(prefix, temp, top_k, length, num_samples=10):
    """
    Generate several alternative continuations of ``prefix``.

    This function is registered as a Colab callback; the Javascript caller
    passes the values of HTML input fields, so the numeric settings arrive as
    strings and are converted here.

    Args:
        prefix: text to continue.
        temp: sampling temperature (number or numeric string).
        top_k: top-k cutoff (integer or numeric string).
        length: number of new tokens to generate per sample.
        num_samples: how many continuations to produce. Defaults to 10.

    Returns:
        JsonRepr wrapping a list with one string per sample. Each string has
        the part of the decoded text that overlaps the end of ``prefix``
        removed, so it holds only the continuation.
    """
    temp = float(temp)
    top_k = int(top_k)
    length = int(length)
    num_samples = int(num_samples)

    # Tokenize the prompt and keep only as much of its tail as still leaves
    # room for `length` new tokens.
    tokens = tokenizer(prefix, return_tensors='pt')['input_ids']
    tokens = trim_tensor_max_tokens(tokens, length)
    # Inputs must live on the same device as the model weights.
    tokens = tokens.to(model.device)

    # min_length == max_length asks for exactly `length` new tokens.
    total_length = tokens.shape[1] + length
    output = model.generate(tokens, min_length=total_length,
                                    max_length=total_length,
                                    do_sample=do_sample,
                                    temperature=temp,
                                    top_k=top_k,
                                    num_return_sequences=num_samples)

    # Iterate over the rows actually returned instead of repeating the
    # sample count as a second magic number.
    result = []
    for sequence in output:
        text = tokenizer.decode(sequence)
        # The decoded text starts with (the tail of) the prompt; cut it off.
        result.append(text[overlap(prefix, text):])

    return JsonRepr(result)

# Register ai_generate under the name 'ai_generate' so the Javascript in the
# output cell can call it through google.colab.kernel.invokeFunction().
google.colab.output.register_callback('ai_generate', ai_generate)

print('Done')

# ai_generate('My name is Julien and I like to', '1.0', '100', '100')

In [None]:
import numpy as np
def get_sorted_tokens(model, tokenizer, prefix, top_n=None):
    """
    Rank candidate next tokens for ``prefix`` by their logit.

    Args:
        model: causal language model; it is called on the token ids and must
            return a mapping with a 'logits' entry.
        tokenizer: tokenizer matching ``model``; used to encode ``prefix``
            and to decode single token ids.
        prefix: text whose next token is being predicted.
        top_n: if given, only the ``top_n`` highest-scoring tokens are
            decoded and returned. ``None`` returns the whole vocabulary.

    Returns:
        List of ``(index, token_text, logit)`` tuples, highest logit first.
    """
    import torch  # local import: only needed for the no_grad() context

    tokens = tokenizer(prefix, return_tensors='pt')['input_ids']
    tokens = trim_tensor_max_tokens(tokens, 1)
    # Inputs must live on the same device as the model weights.
    tokens = tokens.to(model.device)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(tokens)['logits']

    # Logits of the last position; moved to host memory before the numpy
    # conversion, which only works on CPU tensors.
    lastword = logits[0, -1].detach().cpu().numpy()
    indices = np.argsort(lastword)[::-1]
    if top_n is not None:
        # Decoding is the expensive part, so cut the list before decoding.
        indices = indices[:top_n]

    return [(index, tokenizer.decode([index]), lastword[index]) for index in indices]

def js_tokens_callback(prefix):
    """
    Render the 50 most likely next tokens for ``prefix`` as an HTML table.

    Uses the module-level ``model`` and ``tokenizer``.

    NOTE(review): the Javascript caller reads the text/plain representation
    of this return value; for a plain str that may include the surrounding
    quotes of repr() — confirm how it renders.
    """
    import html  # local import: escaping helper from the standard library

    header = """<table class="pure-table"><tr><th>Token</th><th>Index</th><th>Strength</th></tr>"""
    # Token text is arbitrary model vocabulary (it can contain '<' or '&'),
    # so it is escaped before being placed into markup.
    rows = [
        f"<tr><td>{html.escape(token)}</td><td>{index}</td><td>{strength:.4f}</td></tr>"
        for index, token, strength in get_sorted_tokens(model, tokenizer, prefix, top_n=50)
    ]
    return header + "".join(rows) + "</table>"

# Register the table renderer under the name 'toptokens' for the Javascript
# front end.
google.colab.output.register_callback('toptokens', js_tokens_callback)

In [None]:
from IPython.display import HTML

# Stylesheet for the small rotating "busy" indicator shown while the kernel
# is generating text.
# spinner from https://codepen.io/vovchisko/pen/vROoYQ
# The rotation period is taken from `spinner_speed`, concatenated into the
# `animation` rule.
spinner_css = """
<style>
@keyframes c-inline-spinner-kf {
  0% {
    transform: rotate(0deg);
  }
  100% {
    transform: rotate(360deg);
  }
}

.c-inline-spinner,
.c-inline-spinner:before {
  display: inline-block;
  width: 11px;
  height: 11px;
  transform-origin: 50%;
  border: 2px solid transparent;
  border-color: #74a8d0 #74a8d0 transparent transparent;
  border-radius: 50%;
  content: "";
  animation: linear c-inline-spinner-kf """+spinner_speed+""" infinite;
  position: relative;
  vertical-align: inherit;
  line-height: inherit;
}
.c-inline-spinner {
  top: 3px;
  margin: 0 3px;
}
.c-inline-spinner:before {
  border-color: #74a8d0 #74a8d0 transparent transparent;
  position: absolute;
  left: -2px;
  top: -2px;
  border-style: solid;
}
</style>
"""

# Markup of the playground: a text area, the sampling controls, the
# Prev / Continue / Next buttons and the "top tokens" pane.
# The string is a %-format template: the single `%s` receives `model_name`,
# which is why literal percent signs in the inline styles are written `%%`.
input_form = """
<link rel="stylesheet" href="https://unpkg.com/purecss@1.0.1/build/pure-min.css" integrity="sha384-oAOxQR6DkCoMliIh8yFnu25d7Eq/PHS21PClpwjOTeU2jRSq11vu66rf90/cZr47" crossorigin="anonymous">

<div style="background-color:white; border:solid #ccc; width:1200px; padding:20px; color: black;">
<p>You have currently loaded %s model</p>
<div class="pure-g">
<div class="pure-u-2-3">
<textarea id="main_textarea" cols="70" rows="20" style="font-family: 'Liberation Serif', 'DejaVu Serif', Georgia, 'Times New Roman', Times, serif; font-size: 13pt; padding:10px;"></textarea><br>
<div class="pure-form pure-form-aligned">
    <div class="pure-control-group">
      <label for="temp">Temperature:</label>
      <input type="number" min="0.00" max="999.99" step="0.01" id="temp" value="0.70" style="background-color: white;">
    </div>
    <div class="pure-control-group">
        <label for="top_k">top_k:</label>
        <input type="number" min="0" max="9999" id="top_k" value="40" style="background-color: white;">
    </div>
    <div class="pure-control-group">
        <label for="length">Generate how much:</label>
        <input type="number" id="length" min="1" max="1023" value="10" style="background-color: white;">
    </div>
    <div style="width: 600px; display: block; margin-left: auto !important; margin-right: auto !important;">
        <p>
          <button class="pure-button" style="font-size: 125%%;" onclick="prev()">Prev</button>
          <button class="pure-button pure-button-primary" style="font-size: 125%%;" onclick="generate()">Continue with Transformer</button>
          <button class="pure-button" style="font-size: 125%%;" onclick="next()">Next</button>
          <span id="gen-index"></span>
          <span class="c-inline-spinner" style="visibility: hidden;" id="spinner"></span>
        </p>
    </div>

</div>
</div>
<div class="pure-u-1-3">
<p>Top 50 tokens</p>
<div id="top_tokens" style="overflow-y: scroll; height: 550px;"></div>
<button class="pure-button" style="font-size: 125%%;" onclick="toptokens()">Get top tokens</button>
</div>
</div>
</div>
""" % model_name

# Client-side logic of the playground. It calls the Python callbacks
# registered as 'ai_generate' and 'toptokens' through
# google.colab.kernel.invokeFunction() and lets the user flip through the
# returned continuations with the Prev / Next buttons.
javascript = """
<script type="text/Javascript">

    // Alternative continuations of the last generation, each stored as the
    // full text (prompt + continuation). Undefined until a generation finishes.
    var memory;
    // Index of the continuation currently shown; -1 means "none yet".
    var current_index = -1;

    function switch_to(new_index, force=false) {
        if (current_index == -1 && !force) {
            return;
        }

        // Clamp to the number of continuations actually received.
        new_index = Math.max( Math.min(new_index, memory.length - 1) , 0);

        if (new_index == current_index) {
            return;
        }

        var current_text = document.getElementById('main_textarea').value;

        // Keep manual edits made to the variant the user is leaving.
        if (current_index != -1) {
            memory[current_index] = current_text;
        }
        current_index = new_index;
        document.getElementById('main_textarea').value = memory[new_index];
        document.getElementById('gen-index').innerHTML = (new_index+1) + "/" + memory.length;
    };

    function clear_memory() {
      memory = undefined;
      current_index = -1;
      document.getElementById('gen-index').innerHTML = "";
    }

    function store_gen_results(jsonstring) {
        var deftext = document.getElementById('main_textarea').value;
        var continuations = JSON.parse(jsonstring);
        if (continuations.length == 0) {
            // Nothing to show; leave the text area untouched.
            return;
        }
        memory = continuations.map(x => deftext+x);
        switch_to(0, true);
    };

    //TODO
    function block_arrows() {};

    function prev() {switch_to(current_index-1)};
    function next() {switch_to(current_index+1)};

    function set_spinner(visible) {
        document.getElementById('spinner').style.visibility = visible ? "visible" : "hidden";
    };

    function generate(){
        var prefix = document.getElementById('main_textarea').value;
        var temp = document.getElementById('temp').value;
        var top_k = document.getElementById('top_k').value;
        var length = document.getElementById('length').value;

        clear_memory();
        block_arrows();
        set_spinner(true);

        var kernel = google.colab.kernel;
        var resultPromise = kernel.invokeFunction("ai_generate", [prefix,temp,top_k,length]); // developer, look here
        resultPromise.then(
            function(value) {
              store_gen_results(value.data["text/plain"]);
            }
        ).catch(
            function(error) {
              // A failed kernel call used to leave the spinner running forever.
              console.error(error);
              document.getElementById('gen-index').innerHTML = "Generation failed";
            }
        ).then(
            function() {
              // Runs after success and after failure alike.
              set_spinner(false);
            }
        );
    };

    function toptokens(){
      var prefix = document.getElementById('main_textarea').value;
      var kernel = google.colab.kernel;
      var resultPromise = kernel.invokeFunction("toptokens", [prefix]);
      document.getElementById('top_tokens').innerHTML = "<p>Working...</p>";
      resultPromise.then(
        function(value) {
            document.getElementById('top_tokens').innerHTML = value.data["text/plain"];
        }
      ).catch(
        function(error) {
            // Do not leave the pane stuck on "Working..." when the call fails.
            console.error(error);
            document.getElementById('top_tokens').innerHTML = "<p>Failed to get tokens</p>";
        }
      );
    };
</script>
"""

# Last expression of the cell: stylesheet, form markup and script are
# concatenated into a single HTML object.
HTML(spinner_css + input_form + javascript)