# **Image-to-image translation process via Hierarchical Style Disentanglement (HiSD)**

Image-to-image translation is a computer vision task that generates a modified version of an input image according to given conditions. The conditions can be multiple labels, multiple styles, or both. In recent successful methods, the feature map of the input image is translated according to the labels, and the output image is generated from the translated feature map according to the styles. Labels and styles are supplied to the model as text or as reference images. Translation sometimes introduces unwanted manipulations, such as changes to identity attributes, which are difficult to control in a semi-supervised setting.
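In HiSD, labels are organized hierarchically: independent tags (such as Bangs), each with its own attributes (such as with or without). The inference steps later in this notebook refer to tags and attributes by integer index. The sketch below shows a minimal lookup from names to those indices. The hierarchy is an assumption: the examples in this notebook only confirm the tag order (Bangs = 0, Glasses = 1, Hair color = 2) and two attribute indices (Bangs/with = 0, Hair color/black = 0). Check the rest against `configs/celeba-hq_256.yaml` in the HiSD repository.

```python
# Assumed tag/attribute hierarchy for the CelebA-HQ model. Only the tag order and
# the indices of Bangs/"with" and Hair color/"black" are confirmed by the notebook's
# examples; the remaining names and their order are assumptions to verify.
HIERARCHY = {
    "Bangs": ["with", "without"],
    "Glasses": ["with", "without"],
    "Hair color": ["black", "blond", "brown"],
}


def to_indices(tag_name, attribute_name):
    """Map a (tag, attribute) name pair to the integer indices used in HiSD steps."""
    tags = list(HIERARCHY)  # dicts keep insertion order, so this order defines tag indices
    if tag_name not in HIERARCHY:
        raise KeyError(f"unknown tag {tag_name!r}; expected one of {tags}")
    attributes = HIERARCHY[tag_name]
    if attribute_name not in attributes:
        raise KeyError(f"unknown attribute {attribute_name!r} for tag {tag_name!r}")
    return tags.index(tag_name), attributes.index(attribute_name)


if __name__ == "__main__":
    print(to_indices("Hair color", "black"))  # prints (2, 0)
```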

To learn more about HiSD, see [this article](https://analyticsindiamag.com/hisd-python-implementation-of-image-to-image-translation/).

## **Python implementation of HiSD**

HiSD needs a Python environment and the PyTorch framework. A GPU runtime is optional: the pre-trained model can be loaded and used for inference on a CPU runtime. The dependencies are installed in the first code cell of the checkpoint section.
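Before or after installing, it can help to confirm which required packages are importable. The sketch below uses only the standard library. The list of import names is an assumption: in particular, `yaml` (PyYAML) and `tensorboardX` are presumed to be used by HiSD's `core` modules. Adjust the list to match the repository.

```python
import importlib.util

# Import names (not pip package names) assumed to be needed by this notebook.
REQUIRED = ["torch", "torchvision", "PIL", "numpy", "matplotlib", "yaml", "tensorboardX"]


def missing_packages(names):
    """Return the top-level import names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("missing:", missing or "none")
```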

# HiSD

### Hierarchical Style Disentanglement

References:

- Code: https://github.com/imlixinyang/HiSD
- Paper: https://arxiv.org/abs/2103.01456

According to the paper's authors, the code is intended for academic and research purposes only.

## Load the pre-trained model checkpoint from the official Google Drive page

Checkpoint file: `checkpoint_256_celeba-hq.pt`

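The first cell below upgrades pip, installs the dependencies, and restarts the kernel so that the new packages can be imported. The second cell downloads the checkpoint file (a PyTorch `.pt` file) from the official Google Drive page.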

In [None]:
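# Upgrade pip, then install the packages used in this notebook
# ('sklearn' is a deprecated PyPI name; the package is 'scikit-learn')
!python -m pip install pip --upgrade --user -q --no-warn-script-location
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn nltk gensim tensorflow keras torch torchvision \
    tqdm scikit-image --user -q --no-warn-script-location

# Restart the kernel so that the newly installed packages can be imported
import IPython
IPython.Application.instance().kernel.do_shutdown(True)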


In [None]:
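# Google Drive file ID of the checkpoint: '1KDrNWLejpo02fcalUOrAJOl1hGoccBKl'
# The inner wget fetches the confirmation token Google Drive requires for large files;
# the outer wget then downloads the file with that token and the cookies are removed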
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1KDrNWLejpo02fcalUOrAJOl1hGoccBKl' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1KDrNWLejpo02fcalUOrAJOl1hGoccBKl" -O checkpoint_256_celeba-hq.pt && rm -rf /tmp/cookies.txt
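The cookie-based `wget` download can silently save Google Drive's HTML warning page instead of the checkpoint, for example if the confirmation step fails. The sketch below is a heuristic check that the downloaded file is not empty and does not start like an HTML page. The helper names are hypothetical. The check only catches an obvious failure; it does not prove the file is a valid checkpoint.

```python
from pathlib import Path

# Leading bytes (lowercased) that suggest an HTML page rather than binary data.
HTML_PREFIXES = (b"<!doctype html", b"<html")


def is_html_bytes(head: bytes) -> bool:
    """True if the given leading bytes look like the start of an HTML document."""
    return head.lstrip().lower().startswith(HTML_PREFIXES)


def check_download(path: str, probe_bytes: int = 512) -> None:
    """Raise if the file is missing, empty, or appears to be an HTML page."""
    file = Path(path)
    if not file.is_file() or file.stat().st_size == 0:
        raise FileNotFoundError(f"{path} is missing or empty")
    with file.open("rb") as fh:
        head = fh.read(probe_bytes)  # read only the first bytes, not the whole checkpoint
    if is_html_bytes(head):
        raise ValueError(f"{path} looks like an HTML page, not a checkpoint; retry the download")


if __name__ == "__main__":
    check_download("checkpoint_256_celeba-hq.pt")
    print("checkpoint file looks plausible")
```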

Store the checkpoint file name and list the working directory to confirm the download. The file is moved into the `HiSD/` directory once the repository has been cloned.

In [None]:
checkpoint_name = 'checkpoint_256_celeba-hq.pt'

In [None]:
!ls

## Clone source code and install dependencies

In [None]:
!pip install tensorboardx

The following command clones the source code from the official repository to the local machine.

In [None]:
!git clone https://github.com/imlixinyang/HiSD.git

In [None]:
!mv checkpoint_256_celeba-hq.pt HiSD/

In [None]:
%cd HiSD/

In [None]:
!ls

## Prepare the environment

In [None]:
# Standard library
import argparse
import os
import sys
import time

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.utils as vutils
from PIL import Image
from torchvision import transforms

# HiSD modules (importable because the working directory is already HiSD/)
from core.utils import get_config
from core.trainer import HiSD_Trainer

## Load model and its checkpoint

Load the checkpoint and prepare the model for inference with the following code.

In [None]:
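# Use the GPU when one is available; the pre-trained model also runs on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Read the model configuration and the values needed at inference time
config = get_config('configs/celeba-hq_256.yaml')
noise_dim = config['noise_dim']
image_size = config['new_size']

# Build the model and load the generator weights from the checkpoint
checkpoint = 'checkpoint_256_celeba-hq.pt'
trainer = HiSD_Trainer(config)
# map_location lets a checkpoint saved on a GPU be loaded on the chosen device
state_dict = torch.load(checkpoint, map_location=device)
trainer.models.gen.load_state_dict(state_dict['gen_test'])
trainer.models.gen.to(device)

# Short aliases for the generator's components
E = trainer.models.gen.encode     # encoder: image -> feature map
T = trainer.models.gen.translate  # translator: (features, style, tag) -> translated features
G = trainer.models.gen.decode     # decoder: feature map -> image
M = trainer.models.gen.map        # mapper: (noise, tag, attribute) -> style
F = trainer.models.gen.extract    # extractor: (reference image, tag) -> style

# Preprocessing: resize, convert to a tensor in [0, 1], then normalize each channel to [-1, 1]
transform = transforms.Compose([transforms.Resize(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])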
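`Normalize` with mean 0.5 and standard deviation 0.5 computes `(v - 0.5) / 0.5` per channel, which maps pixel values from [0, 1] to [-1, 1]. The last line of `translate()` reverses this with `(y + 1) / 2` and clamps the result to [0, 1] so that `plt.imshow` receives valid values. The sketch below shows the arithmetic on single values in plain Python.

```python
def normalize(v: float, mean: float = 0.5, std: float = 0.5) -> float:
    """What transforms.Normalize does to one channel value: (v - mean) / std."""
    return (v - mean) / std


def denormalize(y: float) -> float:
    """The inverse used for display: (y + 1) / 2, clamped to the valid range [0, 1]."""
    return min(max((y + 1) / 2, 0.0), 1.0)


if __name__ == "__main__":
    print(normalize(0.0), normalize(1.0))  # prints -1.0 1.0
    print(denormalize(normalize(0.25)))    # prints 0.25
```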

## Define a helper function to perform inference

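The helper below performs the image-to-image translation: it encodes the input image once, applies each translation step in order, and decodes the result into an image that Matplotlib can display.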

In [None]:
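def translate(image_path, steps):
    """Apply a list of HiSD translation steps to one image; return an HxWx3 array in [0, 1]."""
    # Load and preprocess the image, then add a batch dimension
    x = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only, so gradients are not needed
        c_trg = E(x)  # encode the image into a feature map
        for step in steps:
            if step['type'] == 'latent-guided':
                # Generate a style from random noise for the given tag and attribute
                if step['seed'] is not None:
                    torch.manual_seed(step['seed'])
                    torch.cuda.manual_seed(step['seed'])
                z = torch.randn(1, noise_dim).to(device)
                s_trg = M(z, step['tag'], step['attribute'])
            elif step['type'] == 'reference-guided':
                # Extract the style for the given tag from a reference image
                reference = transform(Image.open(step['reference']).convert('RGB')).unsqueeze(0).to(device)
                s_trg = F(reference, step['tag'])
            else:
                raise ValueError(f"Unknown step type: {step['type']!r}")
            # Apply the translator for this tag to the current features
            c_trg = T(c_trg, s_trg, step['tag'])
        x_trg = G(c_trg)  # decode the translated features into an image
    # CHW tensor in [-1, 1] -> HWC array in [0, 1] for plt.imshow
    output = x_trg.squeeze(0).cpu().permute(1, 2, 0).add(1).mul(1 / 2).clamp(0, 1).numpy()
    return output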
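To see the control flow of `translate()` without loading the model, the sketch below replaces E, T, G, M and F with hypothetical stand-ins that only return a string describing each call. The output shows that the image is encoded once. Each step then translates the features produced by the previous step, and the image is decoded once at the end.

```python
# Hypothetical stand-ins for the generator components E, T, G, M and F.
# Each returns a string describing the call, so the data flow is visible.
def encode(x):
    return f"E({x})"


def translate_features(c, s, tag):
    return f"T({c}, {s}, tag={tag})"


def decode(c):
    return f"G({c})"


def map_noise(z, tag, attribute):
    return f"M({z}, tag={tag}, attr={attribute})"


def extract_style(reference, tag):
    return f"F({reference}, tag={tag})"


def trace_steps(image, steps):
    """Mirror translate(): encode once, translate once per step, decode once."""
    c = encode(image)
    for step in steps:
        if step["type"] == "latent-guided":
            s = map_noise("z", step["tag"], step["attribute"])  # style from noise
        elif step["type"] == "reference-guided":
            s = extract_style(step["reference"], step["tag"])  # style from a reference image
        else:
            raise ValueError(f"unknown step type: {step['type']!r}")
        c = translate_features(c, s, step["tag"])  # each step edits the previous features
    return decode(c)


if __name__ == "__main__":
    print(trace_steps("x", [
        {"type": "reference-guided", "tag": 0, "reference": "bangs.jpg"},
        {"type": "latent-guided", "tag": 2, "attribute": 0, "seed": None},
    ]))
    # prints G(T(T(E(x), F(bangs.jpg, tag=0), tag=0), M(z, tag=2, attr=0), tag=2))
```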

## Inference

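Each example below describes a translation as a list of steps. A step names a tag (the attribute group to edit, such as Bangs or Glasses) and takes its style either from random noise for a given attribute (latent-guided) or from a reference image (reference-guided). The examples use the images in the repository's `examples/` folder; you can substitute your own images.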
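`translate()` reads each step dictionary directly, so a missing key only surfaces as a `KeyError` partway through a translation. The sketch below checks a `steps` list up front. The required keys are taken from what `translate()` accesses. The helper name `validate_steps` is hypothetical and not part of the HiSD repository.

```python
# Keys that translate() reads for each step type.
REQUIRED_KEYS = {
    "latent-guided": {"tag", "attribute", "seed"},
    "reference-guided": {"tag", "reference"},
}


def validate_steps(steps):
    """Raise ValueError if any step is malformed; otherwise return the steps unchanged."""
    for i, step in enumerate(steps):
        kind = step.get("type")
        if kind not in REQUIRED_KEYS:
            raise ValueError(f"step {i}: unknown type {kind!r}")
        missing = REQUIRED_KEYS[kind] - set(step)
        if missing:
            raise ValueError(f"step {i}: missing keys {sorted(missing)}")
    return steps


# Usage: output = translate(input, validate_steps(steps))
```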

In [None]:
input = 'examples/input_0.jpg'

# e.g. 1: change tag 'Bangs' to attribute 'with' using three latent-guided styles (generated from random noise)
steps = [
    {'type': 'latent-guided', 'tag': 0, 'attribute': 0, 'seed': None}
]
plt.figure(figsize=(12,4))
for i in range(3):
    plt.subplot(1, 3, i+1)
    output = translate(input, steps)
    plt.imshow(output, aspect='auto')
plt.show()
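With `'seed': None`, each call to `translate()` draws fresh noise, so the three panels change every time the cell runs. To make the figure reproducible, each panel can be given its own fixed seed. The sketch below builds one `steps` list per seed. The helper name and the seed values are arbitrary and purely illustrative.

```python
def seeded_latent_steps(tag, attribute, seeds):
    """Return one single-step latent-guided `steps` list per seed (hypothetical helper)."""
    return [
        [{"type": "latent-guided", "tag": tag, "attribute": attribute, "seed": seed}]
        for seed in seeds
    ]


# Usage with the notebook's translate() and plotting code:
# plt.figure(figsize=(12, 4))
# for i, steps in enumerate(seeded_latent_steps(tag=0, attribute=0, seeds=[0, 1, 2])):
#     plt.subplot(1, 3, i + 1)
#     plt.imshow(translate(input, steps), aspect='auto')
# plt.show()
```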

Second example: reference-guided styles.

In [None]:
input = 'examples/input_1.jpg'

# e.g. 2: change tag 'Glasses' to attribute 'with' using reference-guided styles (extracted from other images)
references = ['examples/reference_glasses_0.jpg',
              'examples/reference_glasses_1.jpg',
              'examples/reference_glasses_2.jpg']

plt.figure(figsize=(12, 4))
for i, reference in enumerate(references):
    steps = [
        {'type': 'reference-guided', 'tag': 1, 'reference': reference}
    ]
    output = translate(input, steps)
    plt.subplot(1, 3, i + 1)
    plt.imshow(output, aspect='auto')
plt.show()

Third example: several edits in one translation.

In [None]:
input = 'examples/input_2.jpg'

# e.g. 3: change tags 'Bangs' and 'Glasses' to attribute 'with', and 'Hair color' to 'black', in a single translation
steps = [
    {'type': 'reference-guided', 'tag': 0, 'reference': 'examples/reference_bangs_0.jpg'},
    {'type': 'reference-guided', 'tag': 1, 'reference': 'examples/reference_glasses_0.jpg'},
    {'type': 'latent-guided', 'tag': 2, 'attribute': 0, 'seed': None}
]

output = translate(input, steps)
plt.figure(figsize=(5,5))
plt.imshow(output, aspect='auto')
plt.show()

Thank you for your time!

# **Related Articles:**

> * [Image to Image Translation](https://analyticsindiamag.com/hisd-python-implementation-of-image-to-image-translation/)

> * [Guide to Kornia](https://analyticsindiamag.com/guide-to-kornia-an-opencv-inspired-pytorch-framework/)

> * [Extract Foreground Images with GrabCut Algorithm](https://analyticsindiamag.com/how-to-extract-foreground-from-images-interactively-using-grabcut/)

> * [GAN in simple 8 Steps](https://analyticsindiamag.com/how-to-build-a-generative-adversarial-network-in-8-simple-steps/)

> * [PyTorch vs Keras vs Caffe](https://analyticsindiamag.com/keras-vs-pytorch-vs-caffe-comparing-the-implementation-of-cnn/)

> * [Face Emotion Recognizer](https://analyticsindiamag.com/face-emotion-recognizer-in-6-lines-of-code/)

> * [Sign Language Classification using CNN](https://analyticsindiamag.com/hands-on-guide-to-sign-language-classification-using-cnn/)


