<a href="https://colab.research.google.com/github/jonbaer/googlecolab/blob/master/Pokemon_GPT_2_Generate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pokemon GPT-2 Generate

This notebook lets you generate Pokemon sprite images with either a pre-trained model or a model you trained with the [Pokemon GPT-2 Train](https://colab.research.google.com/drive/1c1kmO9tixviyBB7IGh-jVpLvOh2RpLYk) notebook.

To use this notebook, follow along with the instructions and run each cell as directed.

---

The GPT-2 portions of this notebook are based on the notebook for [GPT-2-Simple](https://github.com/minimaxir/gpt-2-simple).

## 1. Setup

First we select TensorFlow 1.x, install gpt-2-simple, and import everything we need:

In [None]:
# This notebook was written against TensorFlow 1.x. Newer Colab runtimes
# may only ship TensorFlow 2.x, in which case this magic has no effect.
%tensorflow_version 1.x
!pip install -q gpt-2-simple

import os
import shutil
import time

import gpt_2_simple as gpt2
from google.colab import drive
from PIL import Image

## 2. Google Drive

Google Drive is used to load a model you trained yourself (section 3b) and to store the generated sprite images. Here we mount your Google Drive:

In [None]:
drive.mount('/content/drive')

## 3. Download or Import Model

If you want to download a pre-trained model to generate Pokemon sprites, run the cell in section 3a.

If you want to import your own model that you trained with the [Pokemon GPT-2 Train](https://colab.research.google.com/drive/1c1kmO9tixviyBB7IGh-jVpLvOh2RpLYk) notebook, run the cell in section 3b instead. You only need one of the two.

## 3a. Download a Pre-Trained Model

If you want to use a pre-trained model to generate Pokemon sprites, run this cell:

In [None]:
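# %cd changes the notebook's working directory; !cd would only affect a throwaway subshell.
%cd /content
!wget -O "pokemon-gpt-2-multigen-250000.zip" "https://ipfs.io/ipfs/QmRjkH2szrkez3QaHUKPM1jr3aHnJyN11JpcoRM2EwFHdQ"
!mkdir -p checkpoint
# -o overwrites existing files without prompting, so the cell can be re-run safely.
!unzip -o "pokemon-gpt-2-multigen-250000.zip" -d checkpoint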

run_name = 'pokemon-gpt-2-multigen-250000'

## 3b. Import Your Own Model

If you used the [Pokemon GPT-2 Train](https://colab.research.google.com/drive/1c1kmO9tixviyBB7IGh-jVpLvOh2RpLYk) notebook and want to import that model for use, make sure the checkpoint_folder and run_name values match the values you used for training, and run this cell:

In [None]:
checkpoint_folder = 'pokemon-gpt-2-checkpoints'  #@param {type:"string"}
run_name = 'pokemon-gpt-2-run' #@param {type:"string"}

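# Copy the run from Google Drive into the local checkpoint/ folder that
# gpt-2-simple loads from by default, unless it has already been copied.
if not os.path.exists('checkpoint/%s' % run_name):
    shutil.copytree('/content/drive/My Drive/%s/%s' % (checkpoint_folder, run_name), 'checkpoint/%s' % run_name)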
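
Whichever of 3a or 3b you ran, it can help to confirm that `checkpoint/<run_name>` looks like a complete gpt-2-simple checkpoint before generating. The sketch below is a hypothetical helper (`missing_checkpoint_files` is not part of gpt-2-simple), and the list of expected files is an assumption based on what gpt-2-simple usually writes during fine-tuning: a TensorFlow `checkpoint` index file, `model-<step>.*` weight files, and copies of `hparams.json`, `encoder.json` and `vocab.bpe`.

```python
import os

# Files assumed to be present in checkpoint/<run_name> (based on gpt-2-simple's
# usual fine-tuning output; not an official list).
EXPECTED_FILES = ['checkpoint', 'hparams.json', 'encoder.json', 'vocab.bpe']

def missing_checkpoint_files(run_name, checkpoint_dir='checkpoint'):
    """Hypothetical helper: return the expected files missing for run_name."""
    run_dir = os.path.join(checkpoint_dir, run_name)
    if not os.path.isdir(run_dir):
        # The whole run folder is missing, so report it as a single entry.
        return [run_dir]
    present = set(os.listdir(run_dir))
    missing = [name for name in EXPECTED_FILES if name not in present]
    # TensorFlow stores weights as model-<step>.index / .data-* / .meta files.
    if not any(n.startswith('model-') and n.endswith('.index') for n in present):
        missing.append('model-*.index')
    return missing
```

In the notebook this could be called as `print(missing_checkpoint_files(run_name) or 'checkpoint looks complete')`. An empty list only means the expected names are present, not that the weights are valid.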

## 4. Generation

Here we'll generate sprites with the model we imported and save them to Google Drive.

There are a few settings you can change for this cell:

- output_folder - The folder in your Google Drive that sprites are saved to.
- generate_count - How many sprites are generated each time you run this cell.
- temperature - Controls how "crazy" the output is. Higher values produce sprites that may barely resemble the training sprites but might be interesting; values that are too low produce boring, repetitive sprites. The slider goes from 0.7 to 1.3, so play with this number within that range.
- width - The width of the output images, in pixels. Set this to match the sprites the model was trained on.
- height - The height of the output images, in pixels. Set this to match the sprites the model was trained on.
- save_texts - When checked, the raw text output is saved alongside each generated image. You probably don't need this. But hey, you do you.
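
To see why temperature behaves this way: GPT-2's sampling code, which gpt-2-simple builds on, divides the model's logits by the temperature before applying softmax. Values above 1 flatten the distribution so unlikely pixel tokens get picked more often; values below 1 sharpen it toward the most likely token. The standalone sketch below uses made-up logits purely for illustration.

```python
import math

def softmax_with_temperature(logits, temperature):
    """Return sampling probabilities after dividing the logits by temperature."""
    scaled = [x / temperature for x in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative logits for three candidate pixel tokens (made-up values).
logits = [2.0, 1.0, 0.1]
for t in (0.7, 0.9, 1.3):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Running it shows the top token's probability shrinking as the temperature rises, which is the "crazier" output described above.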

The default width and height (64) work for the pre-trained model. Adjust the settings as needed and run the cell:
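
Reading `createImage` in the cell below, each line of model output appears to encode one sprite row: a three-character marker (a two-digit row index plus `d` or `u` for the direction it was generated in) followed by one space-separated token per pixel. The sketch below isolates that per-pixel decoding so it is easier to inspect. The names `decode_pixel` and `encode_pixel` are hypothetical, and the mapping is read from this notebook's code rather than from any documented format.

```python
def decode_pixel(token):
    """Return an (r, g, b, a) tuple for one pixel token, mirroring createImage."""
    if token == '~':
        return (0, 0, 0, 0)  # transparent pixel
    if token == 'a':
        return (107, 107, 107, 255)  # fixed dark grey
    if token == 'b':
        return (187, 187, 187, 255)  # fixed light grey
    # Other tokens pack 2 bits per channel (RRGGBB) into ord(char) - 33.
    c = ord(token[0]) - 33
    levels = [0, 64, 128, 255]  # 2-bit level -> 8-bit channel value
    r = levels[(c >> 4) & 3]
    g = levels[(c >> 2) & 3]
    b = levels[c & 3]
    return (r, g, b, 255)

def encode_pixel(r, g, b):
    """Inverse for the 64 packed colours: channel values 0, 64, 128, 255 only."""
    to_level = {0: 0, 64: 1, 128: 2, 255: 3}
    c = (to_level[r] << 4) | (to_level[g] << 2) | to_level[b]
    return chr(c + 33)  # '!' (black) through '`' (white)
```

This also explains why the generation loop ignores `~` (transparent) and `` ` `` (all channels at maximum, i.e. white) when deciding whether a sprite has any colour yet.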



In [None]:
output_folder = 'pokemon-gpt-2-output' #@param {type:"string"}
generate_count = 10 #@param {type:"integer"}
temperature = 0.9 #@param {type:"slider", min:0.7, max:1.3, step:0.01}
width = 64 #@param {type:"integer"}
height = 64 #@param {type:"integer"}
save_texts = False #@param {type:"boolean"}

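def createImage(text, path):
    """Decode the model's text sprite format and save it as an RGBA PNG at path."""
    lines = text.split('\n')

    imageWidth = 0
    imageHeight = 0

    # First pass: size the image from the widest row and the highest row index.
    for line in lines:
        split = line.split(' ')

        # Each row starts with a 3-character marker: 2-digit row index + direction.
        marker = split[0]
        if len(marker) == 3:
            index = int(marker[0:2])

            lineWidth = len(split) - 1
            lineHeight = index + 1

            if lineWidth > imageWidth:
                imageWidth = lineWidth

            if lineHeight > imageHeight:
                imageHeight = lineHeight

    # Start with every pixel fully transparent.
    pixels = [(0, 0, 0, 0)] * (imageWidth * imageHeight)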

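    # Second pass: decode each pixel token into an RGBA value.
    for line in lines:
        split = line.split(' ')

        marker = split[0]
        if len(marker) == 3:
            index = int(marker[0:2])

            for x in range(len(split) - 1):
                s = split[x + 1]

                # '~' is a transparent pixel, so it keeps the (0, 0, 0, 0) default.
                if s != '~':
                    # 'a' and 'b' are two fixed shades of grey.
                    if s == 'a':
                        r = g = b = 107
                    elif s == 'b':
                        r = g = b = 187
                    else:
                        # Other tokens pack 2 bits per channel (RRGGBB) into
                        # ord(char) - 33; levels 0-3 map to 0, 64, 128 and 255.
                        c = ord(s[0]) - 33

                        b = (c & 3) * 64
                        if b == 192:
                            b += 63

                        c = c >> 2
                        g = (c & 3) * 64
                        if g == 192:
                            g += 63

                        c = c >> 2
                        r = (c & 3) * 64
                        if r == 192:
                            r += 63

                    i = (index * imageWidth) + x

                    pixels[i] = (r, g, b, 255)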

    image = Image.new('RGBA', (imageWidth, imageHeight))
    image.putdata(pixels)
    image.save(path)

def blankLines():
    # One empty string per sprite row; rows are filled in as they are generated.
    return [''] * height

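sess = None

for ii in range(0, generate_count):
    print('Generating sprite %i of %i' % (ii + 1, generate_count))

    # Start a TensorFlow session on the first pass and reset it afterwards,
    # so repeated model loads don't keep accumulating memory.
    if sess is None:
        sess = gpt2.start_tf_sess()
    else:
        sess = gpt2.reset_session(sess)

    gpt2.load_gpt2(sess, run_name=run_name)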

    lines = blankLines()
    prefix = ''
    hasColor = False

    while True:
        # Ask the model for a chunk of rows, continuing from the current prefix.
        text = gpt2.generate(sess, run_name=run_name, prefix=prefix, temperature=temperature, return_as_list=True)[0]

        print('\n\noutput:')
        print(text)

        newLines = text.split('\n')

        direction = None
        lastIndex = None
        for line in newLines:
            split = line.split(' ')[:width + 2]

            # Short lines are treated as truncated output (hard-coded threshold).
            if len(split) < 55:
                break

            marker = split[0]
            if len(marker) == 3:
                try:
                    index = int(marker[0:2])
                except ValueError:
                    break

                # All rows in one chunk must run in the same direction ('d' or 'u').
                if direction is None:
                    direction = marker[2]

                if marker[2] != direction:
                    print('direction changed')
                    break

                # Downward rows must have increasing indices, upward rows decreasing.
                if lastIndex is not None:
                    if marker[2] == 'd' and index <= lastIndex:
                        print('bad line order')
                        break
                    elif marker[2] == 'u' and index >= lastIndex:
                        print('bad line order')
                        break
                lastIndex = index

                # Store every row with a 'd' marker, whichever way it was generated.
                split[0] = marker.replace('u', 'd')

                # '~' (transparent) and '`' (white) don't count as colour.
                if not hasColor:
                    for character in split[2:]:
                        if character != '~' and character != '`':
                            hasColor = True
                            break

                # Pad short rows with transparent pixels.
                while len(split) < width:
                    split.append('~')

                try:
                    lines[index] = ' '.join(split)
                except IndexError:
                    break

        # If no coloured pixel has been generated yet, discard the rows and retry.
        if not hasColor:
            print('no color')
            lines = blankLines()
            continue

        # Find the contiguous block of filled rows.
        topIndex = None
        for i in range(0, height):
            if lines[i]:
                topIndex = i
                break

        bottomIndex = None
        for i in range(topIndex, height):
            if lines[i]:
                bottomIndex = i
            else:
                break

        print('\n\ntop %i bottom %i' % (topIndex, bottomIndex))

        # Feed back up to sectionSize + 1 edge rows as the prefix for the next chunk.
        sectionSize = 5
        if topIndex > 0:
            # Rows are missing above: reverse the top rows and mark them 'u'
            # so the model continues upward.
            section = lines[topIndex:min(topIndex+sectionSize+1, bottomIndex+1)]
            section.reverse()
            for i in range(0, len(section)):
                section[i] = section[i].replace('d', 'u')

        elif bottomIndex < height - 1:
            # Rows are missing below: continue downward from the bottom rows.
            section = lines[max(bottomIndex-sectionSize, topIndex):bottomIndex+1]

        else:
            # Every row is filled: save the sprite (and optionally its text).
            print('\n'.join(lines))
            filename = '%i' % int(time.time())
            text = '\n'.join(lines)

            outputPath = '/content/drive/My Drive/%s' % output_folder
            os.makedirs(outputPath, exist_ok=True)

            if save_texts:
                with open('%s/%s.txt' % (outputPath, filename), 'w') as text_file:
                    text_file.write(text)

            createImage(text, '%s/%s.png' % (outputPath, filename))

            print('saved !')
            break

        prefix = '\n'.join(section)
        print('\n\nprefix:\n%s' % prefix)
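
The generation loop above builds each sprite in pieces. Rows the model emits are placed by their row index, and while the filled block does not yet reach both the top and bottom of the grid, a few edge rows are fed back as the next prompt. The pure-Python sketch below restates that prefix-selection step so it can be tried without a model. `choose_prefix` is a hypothetical name; it returns `None` at the point where the notebook would save the finished image.

```python
def choose_prefix(lines, height, section_size=5):
    """Pick the rows to feed back to the model, or None if the sprite is complete."""
    # Find the first filled row.
    top = next((i for i, line in enumerate(lines[:height]) if line), None)
    if top is None:
        raise ValueError('no rows have been generated yet')
    # Extend downward through the contiguous block of filled rows.
    bottom = top
    while bottom + 1 < height and lines[bottom + 1]:
        bottom += 1
    if top > 0:
        # Grow upward: up to section_size + 1 rows from the top, reversed,
        # with 'd' markers flipped to 'u' so the model continues upward.
        section = lines[top:min(top + section_size + 1, bottom + 1)]
        section.reverse()
        return '\n'.join(row.replace('d', 'u') for row in section)
    if bottom < height - 1:
        # Grow downward: the last section_size + 1 rows of the block.
        section = lines[max(bottom - section_size, top):bottom + 1]
        return '\n'.join(section)
    return None  # every row is filled
```

Flipping `d` to `u` on whole rows is safe here because pixel tokens only use `!` through `` ` `` plus `a`, `b` and `~`, so the only `d` in a stored row is the one in its marker.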

## 5. Done!

You're done! Check out the generated images in your chosen Google Drive folder. Run step 4 again to generate more. Cool!