# Generating images with the MX-Font model from a reference style
In this example we'll generate images with a trained MX-Font model from a reference style.
If you want to generate multiple styles, please use `eval.py` instead of this example file, because it makes loading the reference styles much simpler.

### 1. Loading packages
* First, import the packages used in this notebook.
* All of these packages are available via `pip`.

In [1]:
import json
from pathlib import Path
from PIL import Image

import torch
from sconf import Config
from torchvision import transforms

import os
import cv2
import math
import matplotlib.pyplot as plt

* These modules are defined in this repository.

In [2]:
import models
from datasets import read_font, render
from utils import save_tensor_to_image

### 2. Define Inference Handler Class

In [4]:
class InferenceHandler:
    """Few-shot MX-Font inference: preprocess reference images, extract their
    style factor, and render target characters in that style.

    Intended call order (as used by the cells below):
        ih = InferenceHandler(weight_path)
        ih.build_model()
        ih.load_reference_images(data_path, exp_path)
        ih.extract_style_factor()
        ih.generate_infer_imgs(gen_chars)
        ih.show_inference_imgs()
    """

    # Lower-case file extensions treated as images when scanning a folder.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, weight_path):
        """Remember the checkpoint path; the model itself is built by `build_model`."""
        self.weight_path = weight_path

    @classmethod
    def _list_image_files(cls, directory):
        """Return image files directly inside `directory`, sorted for a deterministic order."""
        return sorted(
            p for p in Path(directory).iterdir()
            if p.is_file() and p.suffix.lower() in cls.IMAGE_EXTENSIONS
        )

    @staticmethod
    def _grid_rows(n_items, columns):
        """Number of subplot rows needed to lay out `n_items` in `columns` columns (at least 1)."""
        return max(1, math.ceil(n_items / columns))

    def build_model(self):
        """Build the generator on the GPU in eval mode and load the trained weights.

        Also creates the image transform shared by reference and source images
        and loads the character decomposition table from
        `data/chn_decomposition.json`.
        """
        cfg = Config("cfgs/eval.yaml", default="cfgs/defaults.yaml")
        # Resize to 128x128, then map pixels to [-1, 1]: ToTensor gives [0, 1], Normalize does (x - 0.5) / 0.5.
        self.transform = transforms.Compose(
            [transforms.Resize((128, 128)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )
        # Explicit UTF-8 and a context manager: don't depend on the platform encoding or leak the handle.
        with open("data/chn_decomposition.json", encoding="utf-8") as f:
            self.decomposition = json.load(f)

        g_kwargs = cfg.get('g_args', {})
        self.gen = models.Generator(1, cfg.C, 1, **g_kwargs).cuda().eval()
        # Load onto the CPU first so a checkpoint saved on any GPU index still loads;
        # load_state_dict then copies the tensors into the model's CUDA parameters.
        weight = torch.load(self.weight_path, map_location="cpu")
        # Some checkpoints wrap the weights under "generator_ema" (presumably the EMA copy
        # of the generator); otherwise the file is used directly as a state dict.
        if "generator_ema" in weight:
            weight = weight["generator_ema"]
        self.gen.load_state_dict(weight)

    def load_reference_images(self, data_path, exp_path):
        """Preprocess the raw reference images of `data_path` into `exp_path`.

        For every image file (.jpg/.jpeg/.png, case-insensitive) in `data_path`,
        the original is copied to `<exp_path>/raw_imgs/`. A grayscale, denoised
        version is written to `<exp_path>/ref_imgs/` under the same file name.
        Files that OpenCV cannot read are reported and skipped.

        Sets `self.data_path`, `self.exp_path`, `self.raw_imgs_path` and
        `self.ref_imgs_path`.
        """
        self.data_path = data_path
        self.exp_path = exp_path

        # Create the output folders up front so the paths exist even if no image is found.
        self.raw_imgs_path = f'{self.exp_path}/raw_imgs'
        self.ref_imgs_path = f'{self.exp_path}/ref_imgs'
        os.makedirs(self.raw_imgs_path, exist_ok=True)
        os.makedirs(self.ref_imgs_path, exist_ok=True)

        image_files = [p.name for p in self._list_image_files(self.data_path)]
        print(image_files)

        for image_file in image_files:
            image_path = os.path.join(self.data_path, image_file)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                print(f'Skipped {image_file}: could not be read as an image')
                continue

            grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            # Non-local-means denoising; h is the filter strength, 7/21 are the template/search window sizes.
            denoised_image = cv2.fastNlMeansDenoising(grayscale_image, None, h=10, templateWindowSize=7, searchWindowSize=21)

            raw_img = os.path.join(self.raw_imgs_path, image_file)
            ref_img = os.path.join(self.ref_imgs_path, image_file)
            cv2.imwrite(raw_img, image)
            cv2.imwrite(ref_img, denoised_image)

            print(f'Converted {image_file} and saved as {ref_img}')

        print('Done!!')

    def extract_style_factor(self, batch_size=3):
        """Compute the averaged style factor of the preprocessed reference images.

        Reads every image file in `self.ref_imgs_path`, which is set by
        `load_reference_images`. It encodes them in batches of `batch_size` and
        factorizes each batch with factor index 0. Index 0 is presumably the
        style factor; `generate_infer_imgs` uses index 1 for the content. For
        each key it stores the mean over all references in `self.style_facts`,
        keeping a leading batch dimension of 1.

        Args:
            batch_size: number of reference images encoded per forward pass.

        Raises:
            FileNotFoundError: if the reference folder contains no image files.
        """
        ref_paths = self._list_image_files(self.ref_imgs_path)
        if not ref_paths:
            raise FileNotFoundError(f"No reference images found in {self.ref_imgs_path}")

        tensors = []
        for p in ref_paths:
            with Image.open(p) as img:  # close each file once it has been transformed
                tensors.append(self.transform(img))
        ref_imgs = torch.stack(tensors).cuda()

        collected = {}
        with torch.no_grad():  # inference only: don't build autograd graphs
            for batch in torch.split(ref_imgs, batch_size):
                style_fact = self.gen.factorize(self.gen.encode(batch), 0)
                for k in style_fact:
                    collected.setdefault(k, []).append(style_fact[k])

        self.style_facts = {k: torch.cat(v).mean(0, keepdim=True) for k, v in collected.items()}

    def generate_infer_imgs(self, gen_chars, source_path="data/ttfs/source/chn_source.ttf"):
        """Render each character of `gen_chars` in the extracted reference style.

        Requires `extract_style_factor` to have been called. Each character is
        rendered with the font at `source_path` and factorized with factor
        index 1. It is combined with `self.style_facts`, decoded, and saved to
        `<exp_path>/infer_imgs/<char>.png`.

        Args:
            gen_chars: iterable of characters to generate; repeated characters
                are generated again and appear again in the returned list.
            source_path: font file used to render the source glyphs.

        Returns:
            List of output paths, one per character of `gen_chars`, in order.
            Also stored as `self.gen_img_list`.
        """
        save_dir = Path(f"{self.exp_path}/infer_imgs")
        save_dir.mkdir(parents=True, exist_ok=True)
        source_font = read_font(source_path)

        self.gen_img_list = []
        with torch.no_grad():  # inference only: don't build autograd graphs
            for char in gen_chars:
                source_img = self.transform(render(source_font, char)).unsqueeze(0).cuda()
                char_facts = self.gen.factorize(self.gen.encode(source_img), 1)

                gen_feats = self.gen.defactorize([self.style_facts, char_facts])
                out = self.gen.decode(gen_feats).detach().cpu()[0]

                path = save_dir / f"{char}.png"
                self.gen_img_list.append(path)
                save_tensor_to_image(out, path)

        return self.gen_img_list

    def show_inference_imgs(self, columns=5):
        """Plot the images from the last `generate_infer_imgs` call in a grid.

        Args:
            columns: number of images per row.

        Files that OpenCV cannot read are skipped.
        """
        images = []
        for filename in self.gen_img_list:
            img = cv2.imread(str(filename))
            if img is not None:
                # OpenCV loads BGR while matplotlib expects RGB.
                images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        # Exactly as many rows as needed (the old "+ 1" always added an empty row).
        rows = self._grid_rows(len(images), columns)
        plt.figure(figsize=(20, 10))
        for i, image in enumerate(images):
            plt.subplot(rows, columns, i + 1)
            plt.imshow(image)

### 3. Demonstration of the inference results 

In [5]:
# Initialization: build the generator once; every experiment cell below reuses `ih`.
weight_path = "generator.pth"  # checkpoint with the trained generator weights
ih = InferenceHandler(weight_path=weight_path)
ih.build_model()

In [None]:
# Experiment 1: references from lanting/4shot_01, poem characters
gen_chars = '床前明月光疑是地上霜举头望明月低头思故乡'  # characters to synthesize
data_path = "data/images/lanting/4shot_01"  # folder with the raw reference images
exp_path = "exp/test1"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 2: references from lanting/8shot_01, poem characters
gen_chars = '床前明月光疑是地上霜举头望明月低头思故乡'  # characters to synthesize
data_path = "data/images/lanting/8shot_01"  # folder with the raw reference images
exp_path = "exp/test2"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 3: references from lanting/8shot_02_bw, poem characters
gen_chars = '床前明月光疑是地上霜举头望明月低头思故乡'  # characters to synthesize
data_path = "data/images/lanting/8shot_02_bw"  # folder with the raw reference images
exp_path = "exp/test3"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 4: references from lanting/8shot_03, poem characters
gen_chars = '床前明月光疑是地上霜举头望明月低头思故乡'  # characters to synthesize
data_path = "data/images/lanting/8shot_03"  # folder with the raw reference images
exp_path = "exp/test4"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 5: references from lanting/8shot_yf_01, poem characters
gen_chars = '床前明月光疑是地上霜举头望明月低头思故乡'  # characters to synthesize
data_path = "data/images/lanting/8shot_yf_01"  # folder with the raw reference images
exp_path = "exp/test5"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 6: references from lanting/8shot_yf_01_rename, poem characters
gen_chars = '床前明月光疑是地上霜举头望明月低头思故乡'  # characters to synthesize
data_path = "data/images/lanting/8shot_yf_01_rename"  # folder with the raw reference images
exp_path = "exp/test6"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 7: same references as experiment 6, different target characters
gen_chars = '吳永鋒'  # characters to synthesize
data_path = "data/images/lanting/8shot_yf_01_rename"  # folder with the raw reference images
exp_path = "exp/test7"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 8: references from lanting/8shot_yf_02
gen_chars = '吳永鋒'  # characters to synthesize
data_path = "data/images/lanting/8shot_yf_02"  # folder with the raw reference images
exp_path = "exp/test8"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()

In [None]:
# Experiment 9: references from lanting/203shot
gen_chars = '吳永鋒'  # characters to synthesize
data_path = "data/images/lanting/203shot"  # folder with the raw reference images
exp_path = "exp/test9"  # output folder for this experiment
ih.load_reference_images(data_path=data_path, exp_path=exp_path)
ih.extract_style_factor()
ih.generate_infer_imgs(gen_chars=gen_chars)
ih.show_inference_imgs()