In [None]:
import warnings
warnings.filterwarnings('ignore')
import os
import cv2
import numpy as np
import tensorflow as tf 
import network
import guided_filter
from tqdm import tqdm

import io
import uvicorn
import nest_asyncio
from enum import Enum
from fastapi import FastAPI, UploadFile, File, HTTPException, Request, responses
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.templating import Jinja2Templates


In [None]:
def resize_crop(image, max_size=720, multiple=8):
    """Cap the shorter side of an image and crop it to a size multiple.

    If both sides exceed ``max_size``, the image is downscaled (keeping the
    aspect ratio) so its shorter side equals ``max_size``.  Height and width are
    then floored to a multiple of ``multiple`` by cropping from the top-left
    corner — presumably so the generator's down/upsampling stages line up
    (TODO confirm against network.unet_generator).

    Args:
        image: array of shape (H, W) or (H, W, C), e.g. a frame decoded by OpenCV.
        max_size: upper bound for the shorter side after resizing.
        multiple: both output dimensions are floored to a multiple of this value.

    Returns:
        The (possibly downscaled) image cropped to multiples of ``multiple``.

    Raises:
        ValueError: if ``image`` is None or is not a 2-D/3-D array.
    """
    if image is None:
        raise ValueError("resize_crop received None (image could not be decoded)")
    if np.ndim(image) not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D image array, got shape {np.shape(image)}")

    # Works for grayscale (H, W) as well as colour (H, W, C) images.
    h, w = np.shape(image)[:2]
    if min(h, w) > max_size:
        # Scale the shorter side down to max_size and the longer side proportionally.
        if h > w:
            h, w = int(max_size * h / w), max_size
        else:
            h, w = max_size, int(max_size * w / h)
        image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    # Smaller images keep their size, so no resize is needed for them.

    h, w = (h // multiple) * multiple, (w // multiple) * multiple
    return image[:h, :w]
    

In [None]:
# Directory holding the generator checkpoint; can be overridden with the
# MODEL_PATH environment variable without editing the notebook.
model_path = os.environ.get('MODEL_PATH', 'saved_models')

In [None]:
# Folder where the API stores each upload and its cartoon version.
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
dir_name = "images_uploaded"
os.makedirs(dir_name, exist_ok=True)

In [None]:
def new_cartonizer(model_path):
    """Build the cartoonization graph and restore the generator weights.

    Each call builds its graph inside a fresh ``tf.Graph`` rather than the
    process-wide default graph. Calling this more than once (the API and the
    video section both need a session) therefore never re-creates the generator
    variables in the same graph. Depending on how ``network.unet_generator``
    scopes its variables, that could otherwise raise.

    Args:
        model_path: directory scanned by ``tf.train.latest_checkpoint``.

    Returns:
        ``(sess, final_out, input_photo)``: a session bound to the new graph,
        the guided-filter output tensor, and the float32 ``[1, None, None, 3]``
        input placeholder.

    Raises:
        FileNotFoundError: if no checkpoint is found under ``model_path``.
    """
    checkpoint = tf.train.latest_checkpoint(model_path)
    if checkpoint is None:
        raise FileNotFoundError(f"No TensorFlow checkpoint found in {model_path!r}")

    graph = tf.Graph()
    with graph.as_default():
        input_photo = tf.placeholder(tf.float32, [1, None, None, 3])
        network_out = network.unet_generator(input_photo)
        final_out = guided_filter.guided_filter(input_photo, network_out, r=1, eps=5e-3)

        # Only variables whose name contains 'generator' are restored.
        gene_vars = [var for var in tf.trainable_variables() if 'generator' in var.name]
        saver = tf.train.Saver(var_list=gene_vars)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True  # allocate GPU memory on demand
        sess = tf.Session(graph=graph, config=config)

        sess.run(tf.global_variables_initializer())
        saver.restore(sess, checkpoint)
    return sess, final_out, input_photo

In [None]:
def new_cartoon(sess, image, final_out, input_photo):
    """Cartoonize a single image with a restored generator session.

    Args:
        sess: session returned by ``new_cartonizer``.
        image: uint8 image array as decoded by OpenCV (assumed H x W x 3 BGR —
            TODO confirm for non-OpenCV sources).
        final_out: output tensor returned by ``new_cartonizer``.
        input_photo: input placeholder returned by ``new_cartonizer``.

    Returns:
        The cartoonized image as a uint8 array (cropped by ``resize_crop``), or
        ``None`` if any step fails. The error is printed rather than raised so
        the caller decides how to handle a failed image.
    """
    try:
        image = resize_crop(image)
        # Map [0, 255] pixel values to [-1, 1] before feeding the generator.
        batch_image = image.astype(np.float32) / 127.5 - 1
        batch_image = np.expand_dims(batch_image, axis=0)  # add the batch dimension
        output = sess.run(final_out, feed_dict={input_photo: batch_image})
        # Map [-1, 1] back to [0, 255] and drop the batch dimension.
        output = (np.squeeze(output) + 1) * 127.5
        return np.clip(output, 0, 255).astype(np.uint8)
    except Exception as exc:
        # Catch Exception (not a bare except) so KeyboardInterrupt still stops
        # long loops, and report what actually went wrong.
        print(f'cartoonize failed: {exc!r}')
        return None


In [None]:
# Build the graph and restore the generator once; the /cartoonize endpoint
# reuses these three objects for every request.
sess, final_out, input_photo = new_cartonizer(model_path)

In [None]:
# `app` is the FastAPI application: every route below is registered on it and
# it is the object handed to uvicorn.run further down.
app = FastAPI(title='Deploying a ML Model with FastAPI')


# GET / -> plain welcome message (handy as a quick "is it up?" check).
@app.get("/")
def home():
    """Return the greeting shown at the API root."""
    greeting = "Welcome to my Cartoon World"
    return greeting

@app.post("/cartoonize")
def cartoonize(file: UploadFile = File(...)):
    """Cartoonize an uploaded JPEG/PNG image and stream the result back.

    The upload and its cartoon version are also saved under ``images_uploaded/``.

    Raises:
        HTTPException(415): the file extension is not jpg/jpeg/png.
        HTTPException(400): the uploaded bytes could not be decoded as an image.
        HTTPException(500): cartoonization or re-encoding failed.
    """
    upload_dir = 'images_uploaded'

    # 1. VALIDATE INPUT FILE
    # basename() drops any directory components in the client-supplied name,
    # so files can only be written inside upload_dir (no "../" traversal).
    filename = os.path.basename(file.filename or "")
    extension = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    if extension not in ("jpg", "jpeg", "png"):
        raise HTTPException(status_code=415, detail="Unsupported file provided.")

    # 2. TRANSFORM RAW IMAGE INTO CV2 IMAGE
    file_bytes = np.frombuffer(file.file.read(), dtype=np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:  # imdecode signals undecodable data by returning None
        raise HTTPException(status_code=400, detail="Could not decode the uploaded image.")

    # 3. RUN Cartoonization algorithm
    output_image = new_cartoon(sess, image, final_out, input_photo)
    if output_image is None:
        raise HTTPException(status_code=500, detail="Cartoonization failed.")

    # Save both images in a folder within the server
    cv2.imwrite(os.path.join(upload_dir, filename), image)
    cv2.imwrite(os.path.join(upload_dir, f'cartoon_{filename}'), output_image)

    # 4. STREAM THE RESPONSE BACK TO THE CLIENT
    # Encode in memory in the upload's own format: the media type then matches
    # the bytes, and no file handle is left open.
    ok, encoded = cv2.imencode(f'.{extension}', output_image)
    if not ok:
        raise HTTPException(status_code=500, detail="Could not encode the cartoon image.")
    media_type = "image/png" if extension == "png" else "image/jpeg"
    return StreamingResponse(io.BytesIO(encoded.tobytes()), media_type=media_type)


In [None]:
# nest_asyncio patches the event loop Jupyter is already running so that
# uvicorn can start its own loop from inside this notebook.
nest_asyncio.apply()

# Listen on all interfaces when DOCKER-SETUP is set, otherwise on localhost only.
if os.getenv("DOCKER-SETUP"):
    host = "0.0.0.0"
else:
    host = "127.0.0.1"

# uvicorn.run does not return until the server stops, so the cells below wait on it.
uvicorn.run(app, host=host, port=8000)


# Bonus: video processing


In [None]:
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt

from glob import glob

import IPython.display as ipd
from tqdm import tqdm

import subprocess

# Use the ggplot look for every matplotlib figure in this section.
plt.style.use(style='ggplot')

## Display video

In [None]:
# Preview the source clip inline (500 px wide).
ipd.Video(data='videoName.mp4', width=500)

In [None]:
# Load in video capture. cv2.VideoCapture does not raise on a missing or
# unreadable file, so check isOpened() to fail loudly here instead of reading
# meaningless metadata in the next cells.
cap = cv2.VideoCapture('videoName.mp4')
if not cap.isOpened():
    raise FileNotFoundError("Could not open video file 'videoName.mp4'")

In [None]:
# Total number of frames, as reported by the capture (OpenCV returns a float;
# the value comes from container metadata and may be approximate for some files).
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
frame_count

In [None]:
# Frame dimensions in pixels, queried from the open capture.
height, width = cap.get(cv2.CAP_PROP_FRAME_HEIGHT), cap.get(cv2.CAP_PROP_FRAME_WIDTH)
print(f'Height {height}, Width {width}')

In [None]:
# Playback rate reported by the capture.
fps = cap.get(cv2.CAP_PROP_FPS)
fps_label = f'FPS : {fps:0.2f}'
print(fps_label)

In [None]:
# Metadata has been read; release the capture handle.
cap.release()

In [None]:
## Helper function for plotting opencv images in notebook
def display_cv2_img(img, figsize=(10, 10)):
    """Show an OpenCV (BGR) image inline with matplotlib.

    Args:
        img: BGR image array as produced by OpenCV.
        figsize: matplotlib figure size in inches.

    Returns:
        The RGB-converted array that was plotted.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    _, axis = plt.subplots(figsize=figsize)
    axis.imshow(rgb)
    axis.axis("off")
    return rgb

In [None]:
# NOTE(review): repeats the release two cells above and no capture has been
# opened since — presumably a leftover that can be removed.
cap.release()

### Display multiple frames from the video

In [None]:

cap = cv2.VideoCapture("videoName.mp4")
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

# Plot every SAMPLE_EVERY-th frame into a 3x3 grid.
SAMPLE_EVERY = 100
fig, axs = plt.subplots(3, 3, figsize=(30, 20))
axs = axs.flatten()

img_idx = 0
try:
    for frame in range(n_frames):
        # Stop once every panel is filled: clips longer than 9 * SAMPLE_EVERY
        # frames would otherwise index past the last axis (IndexError).
        if img_idx >= len(axs):
            break
        ret, img = cap.read()
        if not ret:
            break
        if frame % SAMPLE_EVERY == 0:
            axs[img_idx].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            axs[img_idx].set_title(f'Frame: {frame}')
            axs[img_idx].axis('off')
            img_idx += 1
finally:
    cap.release()

# Short clips leave trailing panels empty; hide their axes too.
for ax in axs[img_idx:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Reuse the generator session built for the API earlier in the notebook, and
# build one only if this kernel does not have it (e.g. the API cells were skipped).
# Rebuilding unconditionally would add a second copy of the generator graph.
if 'sess' not in globals():
    sess, final_out, input_photo = new_cartonizer(model_path)

## Cartoonizing the whole video

In [None]:
# Output folder for cartoonized video artifacts.
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
dir_name = "cartoon_video"
os.makedirs(dir_name, exist_ok=True)

In [None]:
import glob
import numpy as np


cap = cv2.VideoCapture("videoName.mp4")
n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Keep the source playback rate; fall back to 25 fps if the container reports none (0).
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0

VIDEO_CODEC = "mp4v"
# Only the first half of the clip is cartoonized to limit runtime; set this to
# n_frames to process the whole video.
frames_to_process = n_frames // 2

# Uses the generator session (sess, final_out, input_photo) built in an earlier cell.
# The writer is opened lazily from the first cartoonized frame. resize_crop
# changes the frame size, and cv2.VideoWriter expects every frame to match the
# size it was opened with.
# The output goes to the current working directory (no chdir), so re-running
# this cell and the preview cell keep resolving the same relative paths.
out = None
try:
    for _ in tqdm(range(frames_to_process), total=frames_to_process):
        ret, img = cap.read()
        if not ret:
            break
        img = new_cartoon(sess, img, final_out, input_photo)
        if img is None:
            # new_cartoon already printed the failure; skip this frame.
            continue
        if out is None:
            height, width = img.shape[:2]
            size = (width, height)
            out = cv2.VideoWriter("videoName_cartoon.mp4", cv2.VideoWriter_fourcc(*VIDEO_CODEC), fps, size)
        out.write(img)
finally:
    if out is not None:
        out.release()
    cap.release()

In [None]:
# Preview the cartoonized clip inline (600 px wide).
# NOTE(review): browsers may not decode the "mp4v" codec used by the writer;
# if the player stays blank, re-encode the file to H.264 first.
ipd.Video(data='videoName_cartoon.mp4', width=600)