# CycleGAN in JAX + Flax inference demo

This notebook demonstrates how to use pre-trained model weights to generate images with the CycleGAN model code published here: https://github.com/louwjac/CycleGAN-Flax
<br/>Please make sure that you have selected a GPU runtime instance before proceeding. This notebook will not run on Colab TPUs.



In [None]:
#clone the cyclegan code repo
!git clone https://github.com/louwjac/CycleGAN-Flax.git

In [None]:
#get the model weights from huggingface
%cd CycleGAN-Flax
!rm -r ./experiments
!git clone https://huggingface.co/louwjac/CycleGAN-Flax.git experiments/

In [None]:
#install dependencies
!pip install -r notebooks/requirements.txt

In [None]:
import os
import ntpath

import gin
import jax
import jax.numpy as jnp
import numpy as np
import gradio as gr
from PIL import Image
from flax.training import checkpoints

from cyclegan import models, utils

## A note on Gin

Gin is a framework that can be used to configure input parameters in Python. It takes the place of the traditional config-dict, param, ini or yaml files that you may be used to. Please see the documentation included with the official repo (https://github.com/google/gin-config) if you are not already familiar with it. 

The rest of this demo will point out every place where a Gin-configured parameter is used. I recommend that you open one of the [config files](https://github.com/louwjac/CycleGAN-Flax/tree/main/config) separately to get a sense of what has been configured.


In [None]:
# Load the gin config file that was used to train the model
# This makes it easy to play with a different model. Just change the config file here and re-run the code cells
cfg_path = 'config/horse2zebra_original.gin' 
cfg = gin.parse_config_file(cfg_path)

#get the name of the gin file
cfg_filename = ntpath.basename(cfg_path).split(".")[0]

In [None]:
work_dir = gin.query_parameter("%work_dir") #read the value from line 6 of the .gin file e.g. "work_dir = './experiments'"
work_dir = os.path.join(work_dir, cfg_filename)
checkpoints_dir = os.path.join( work_dir, "checkpoints")

# Load pre-trained weights

Each CycleGAN model makes use of two generator networks; one to translate images from domain A to domain B, and another to do the reverse. For this demo, one network will turn horses into fake zebras and another will turn zebras into fake horses.

So,
>Network_G: Horse -> Fake Zebra<br>
>Network_F:  Zebra -> Fake Horse

This codebase combines both generator networks into a single "CycleGenerator" class in order to make the training code cleaner and the data structures easier to handle. As a result, the model weights for both networks are joined into a single data structure for the CycleGenerator. We will therefore need to extract the weights for each sub-network from the larger structure before we'll be able to do inference with them separately. This is fortunately very easy to do with Flax models. 
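
To make the idea of "extracting sub-network weights" concrete, here is a small sketch that uses plain dictionaries and NumPy arrays as a stand-in for the real parameter tree. The top-level keys `net_g` and `net_f` match the names used later in this notebook; the layer names and shapes inside them are invented for illustration and do not reflect the actual model.

```python
import numpy as np

# Hypothetical, heavily simplified stand-in for the CycleGenerator weights.
# Only the top-level keys 'net_g' and 'net_f' are taken from this notebook.
toy_params = {
    'net_g': {'conv': {'kernel': np.zeros((3, 3, 3, 8)), 'bias': np.zeros((8,))}},
    'net_f': {'conv': {'kernel': np.zeros((3, 3, 3, 8)), 'bias': np.zeros((8,))}},
}


def shapes(tree):
    """Recursively replace every array in a nested dict with its shape."""
    return {
        key: shapes(value) if isinstance(value, dict) else value.shape
        for key, value in tree.items()
    }


# Each sub-network's weights are simply one branch of the larger tree.
toy_params_g = toy_params['net_g']
toy_params_f = toy_params['net_f']

print(shapes(toy_params_g))
```

Because a Flax parameter tree is a nested mapping, selecting one branch is all that is needed to obtain weights that can be passed to the corresponding sub-network on its own.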

The next few cells execute the steps needed to load the pre-trained weights:
1) Create an instance of the model class.
2) Initialize model weights for the model instance. This will be the same as the weights you would normally start with in a training session.
3) Load the pre-trained model weights from disk. The initialized weights are passed in to the loader to serve as a template telling Flax how to load the stored weights. 

## Step 1: Create a model instance
This is a case where Gin is used to configure the model: `generator = models.CycleGenerator()`

Notice that CycleGenerator() is called without arguments. Now take a look at lines 69-73 of [horse2zebra_original.gin](https://github.com/louwjac/CycleGAN-Flax/blob/main/config/horse2zebra_original.gin) :

```
  69 # Parameters for CycleGenerator:
  70 # ==============================================================================
  71 models.ResnetGenerator.residuals = 9
  72 models.ResnetGenerator.features = 64
  73 models.CycleGenerator.base_model = @models.ResnetGenerator
```

CycleGenerator takes an argument named "base_model" that is used as the model class for each of the sub-networks G and F. In this case, "base_model" has been configured to be the ResnetGenerator class (the `@` prefix passes a reference to the class rather than a value), and ResnetGenerator has itself been configured as well. Both classes are defined in cyclegan/models.py, hence the prefixes "models.CycleGenerator" and "models.ResnetGenerator". Gin automatically binds the arguments that were configured in horse2zebra_original.gin when we call models.CycleGenerator() without those arguments.

In [None]:
#initialize a cyclegan generator
generator = models.CycleGenerator()

## Step 2: Initialize model weights
This creates a data structure of freshly initialized model weights, just like the one you would start a training session with. It is needed here only because it tells Flax how the trained model weights should be structured when they are loaded.

In [None]:
kg = utils.KeyGen()
rngs_gen = {'params': kg()}
sample_batch = (jnp.ones(shape=(1,256,256,3)),)*2 # this is a dummy batch of input images that will inform Flax of the input shapes
vars_gen = generator.init(rngs_gen, sample_batch)

# Flax model variables can include non-trainable state (such as batch statistics),
# but CycleGAN does not use any such components, so we can discard it here.
# "params_gen" will include only the trainable weights of both networks G and F
states_gen, params_gen = vars_gen.pop('params') 
del vars_gen, states_gen


In [None]:
# run the next line to inspect the shape of the weights structure in case you are curious
utils.print_shapes(params_gen)

## Step 3: Load the saved model weights
The saved models were stored using the checkpoints utility that is included in Flax. That utility is in the process of being [deprecated](https://github.com/google/flax/discussions/2720) in favor of a library named [Orbax](https://github.com/google/orbax). Keep that in mind when you get to a point where you need to checkpoint your own models.

In [None]:
#restore generator model weights from a pre-trained checkpoint
params_gen = checkpoints.restore_checkpoint(
    ckpt_dir=checkpoints_dir,
    target=params_gen,
    prefix='params_'
)

#extract the weights for the two sub-networks from the generator
params_gen, params_g = params_gen.pop('net_g')
params_gen, params_f = params_gen.pop('net_f')

#params_gen is now empty
utils.print_shapes(params_gen)

# Make the inference function
The following steps are required to do inference:

1. Get an input image
2. Resize it so that it is roughly in the same scale as the images that were used to train the model
3. Convert the image to an array 
4. Convert the RGB pixel values from the integer range [0,255] to the floating-point range [-1.0, 1.0]
5. Feed the image array into the generator model and get back a translated array with the same shape
6. Convert the output pixel values from the floating-point range of [-1.0, 1.0] back to the integer range [0, 255]
7. Change the array back into an image object

The next cell will produce a single function for each of the sub-networks G and F that can complete steps 2 through 7.
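
The arithmetic behind steps 2, 4 and 6 can be checked in isolation with NumPy. This is a sketch with hypothetical helper names (`target_size`, `to_model_range`, `to_uint8`), not the notebook's own functions. It also rounds before casting back to integers, which is a choice made for this example so that the round trip is exact.

```python
import numpy as np


def target_size(width, height, short_side=256):
    """Step 2: shrink so the shorter side is at most short_side (never enlarge)."""
    scale = min(1.0, short_side / max(1.0, min(width, height)))
    return int(scale * width), int(scale * height)


def to_model_range(pixels_uint8):
    """Step 4: map integers in [0, 255] to floats in [-1.0, 1.0]."""
    return pixels_uint8.astype(np.float32) / 127.5 - 1.0


def to_uint8(values):
    """Step 6: map floats in [-1.0, 1.0] back to integers in [0, 255]."""
    scaled = (values + 1.0) * 127.5
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


pixels = np.arange(256, dtype=np.uint8)
scaled = to_model_range(pixels)
restored = to_uint8(scaled)

print(target_size(512, 1024))
print(scaled.min(), scaled.max())
print(np.array_equal(pixels, restored))
```

Dividing by 127.5 and subtracting 1 sends 0 to -1.0 and 255 to 1.0, which matches the output range of a generator whose last activation is bounded to [-1, 1].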

In [None]:
# make the inference function

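def resize_img(img: Image.Image):
  # this takes care of step 2
  # note: PIL reports (and expects) sizes as (width, height)
  width, height = img.size
  # shrink so that the shorter side is at most 256 pixels; never enlarge
  scale_in = min(1,256./max(1.,min(height,width)))
  width = int(scale_in*width)
  height = int(scale_in*height)
  img = img.resize((width, height))
  return img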

def generate_fn(params):
  base_model = models.ResnetGenerator() # don't need to pass any parameters here because it is taken care of in the gin config file.

  @jax.jit
  def gen_fake(inputs):
    # this is step 5
    fake_img = base_model.apply({'params':params}, inputs)
    return fake_img

  def fn(img_in):
    if img_in is None:
      return img_in

    # step 2: resize
    img_out = resize_img(img_in) 

    # step 3: convert to array
    img_out = np.asarray(img_out, dtype=np.uint8)

    # step 4: scale pixel values from [0, 255] to [-1.0, 1.0]
    # this also changes the numpy array into a jax device array and loads it into GPU memory
    img_out = (jnp.asarray(img_out) / 127.5) - 1.0
    img_out = jnp.expand_dims(img_out,0) #the model expects a batch dimension

    # step 5: the exciting part!
    img_out = gen_fake(img_out)

    # step 6: convert values back to RGB
    img_out = jnp.uint8((img_out[0] +1.0)*127.5)

    # step 7: make an image object with the output
    img_out = Image.fromarray(np.asarray(img_out),mode='RGB')

    return img_out
    
  return fn


# notice that "generate_fn" returns a function
generate_g = generate_fn(params_g) 
generate_f = generate_fn(params_f)

## Test the inference functions
The CycleGAN repository contains a few sample images that have been selected because they work very well with the pre-trained models. Here we will open two of the sample images and verify that the inference functions can generate fake horses and zebras. You should consider these results to be the best-case outputs for the models. Most other images will not produce great results from your point of view. Part of the reason for this is that the criteria you use to judge what a great output looks like are very different from the objectives the models were trained with.

The training objective of the horse2zebra model was not to translate horses to zebras. Instead, the objective was to translate images from the **horse image "domain"** to the **zebra image "domain"** and vice versa. The key difference is that an image domain is everything in the training sets, not just the horses and zebras. In other words, the models will make tradeoffs to produce good translations with respect to features you won't even notice instead of just focusing on making good fake horses and zebras. You, on the other hand, will likely judge an output to be bad if the fake horse or fake zebra doesn't look believable, even if the model did a fantastic job of translating the grass, trees, sky and water. You will also not notice when those other items look bad but the horse or zebra looks good.

It will take trial and error to find input images that produce outputs that look great through your eyes.


### Test horse2zebra

In [None]:
# Load a sample horse image
horse_img = Image.open('./images/aleksei-zaitcev-ZZ68lVMON7g-unsplash.jpg')  #credit Aleksei Zaitcev (https://unsplash.com/@laowai66)
# uncomment any lines below to try a different image
# horse_img = Image.open('./images/brendon-van-zyl-PsdLrhj18bg-unsplash.jpg') #credit Brendon van Zyl (https://unsplash.com/@brendonvzyl)
# horse_img = Image.open('./images/immo-wegmann-HT07wMriR1U-unsplash.jpg') #credit Immo Wegmann (https://unsplash.com/@macroman)

#translate it to a zebra
fake_zebra_img = generate_g(horse_img)

In [None]:
#view the input horse image in the same size as it was fed into the model
resize_img(horse_img).show()

In [None]:
#view the output
fake_zebra_img.show()

### Test zebra2horse

In [None]:
#Load a sample zebra image
zebra_img = Image.open('./images/henning-borgersen-SarK3PsCKnk-unsplash.jpg')  #credit Henning Borgersen (https://unsplash.com/@hebo79)
#uncomment any lines below to try a different image
# zebra_img = Image.open('./images/matteo-di-iorio-v-9hnUGyuOU-unsplash.jpg') #credit Matteo Di Iorio (https://unsplash.com/@shot_by_teo)
# zebra_img = Image.open('./images/ray-rui-TwG9EZ28nms-unsplash.jpg') #credit Ray Rui (https://unsplash.com/@ray30)
# zebra_img = Image.open('./images/ron-dauphin-k-8-eX4Y3no-unsplash.jpg') #credit https://unsplash.com/@rondomondo
# zebra_img = Image.open('./images/sandra-gabriel-9yYrpdGu8g0-unsplash.jpg') #credit Sandra Gabriel (https://unsplash.com/@sandragabriel)
# zebra_img = Image.open('./images/wolfgang-hasselmann-3UMTQDO5TkE-unsplash.jpg') #credit Wolfgang Hasselmann (https://unsplash.com/@wolfgang_hasselmann)

#translate it to a horse
fake_horse_img = generate_f(zebra_img)

In [None]:
#view the input zebra image in the same size as it was fed into the model
resize_img(zebra_img).show()

In [None]:
#view the output
fake_horse_img.show()

## Make a web app with Gradio
[Gradio](https://gradio.app/) is a very easy-to-use Python library that enables you to quickly make web applications for Python functions. It is heavily used to show demos of deep learning models on [Huggingface Spaces](https://huggingface.co/spaces). One very nice feature that we will exploit here is that the input image in a Gradio app allows you to submit images to a model using a drag and drop interface. 

After you run the cell below, you will be able to try new images without changing any code. Simply open a separate browser window and do a Google image search for horses and zebras. Then drag and drop some of the results into the input boxes of the app to see what the models produce. Keep in mind that this app will not support all image formats that some websites will use. So, you may get errors on certain images. 

In [None]:
# Feel free to ignore the next statement if you are not familiar with web development.
# CSS is used to style web pages. It is included here because it helps to ensure that the input and output images
# are displayed at roughly the same size in the app.
css = """
    * {
        margin:auto;
    }
    div.contain {
        display:flex;
        justify-content:center;
    }

    #in_img_a, 
    #in_img_b,
    #out_img_a,
    #out_img_b {
        max-width: 400px;
        max-height: 400px;
        min-width: 256px
    }
"""

# The next few lines use the Gradio api to make a web app. Please see Gradio's documentation for more information. 
with gr.Blocks(theme=gr.themes.Glass(), css=css) as demo:
    gr.Markdown('# CycleGAN in Jax + Flax Demo')
    gr.Markdown('### Do a Google image search for horses OR zebras (not together) in a separate window,<br/> then try dragging and dropping the images into the input boxes below!')
    
    with gr.Box():
        atob = gr.Markdown("## Horse to Zebra") 
        with gr.Row():       
            inp = gr.Image(type="pil",label="Input", elem_id="in_img_a")
            out = gr.Image(type="pil", label="Output", interactive=False, elem_id="out_img_a")                    
            inp.change(lambda img:generate_g(img),  inputs=inp, outputs=out)
        
    with gr.Box():    
        btoa = gr.Markdown("## Zebra to Horse")
        with gr.Row():
            inp = gr.Image(type="pil",label="Input", elem_id="in_img_b")
            out = gr.Image(type="pil", label="Output", interactive=False, elem_id="out_img_b")
            inp.change(lambda img: generate_f(img),  inputs=inp, outputs=out)


#Launch the demo! You should see the app in the output below after you run this cell
demo.launch()