# Transfer Learning Baseline - Tokenizer

Build a baseline image set by generating reconstructed images with the tokenizer.

In [1]:
%cd /home/bochao/maskgit
import cv2
import albumentations
import glob
import tensorflow as tf
tf.config.experimental.set_visible_devices([], "GPU")

import jax
import numpy as np
import jax.numpy as jnp
import flax
from flax import serialization
# from flax import optim
import optax
import itertools
from tqdm import tqdm

import maskgit
from maskgit.utils import visualize_images, read_image_from_url, restore_from_path, draw_image_with_bbox, Bbox
from maskgit.inference import ImageNet_class_conditional_generator
from maskgit.nets import vqgan_tokenizer, maskgit_transformer
from maskgit.configs import maskgit_class_cond_config
from maskgit.libml import losses, mask_schedule, parallel_decode

# Category names kept from the earlier (disabled) `categories` list
# (presumably one sub-directory per name under the flowers dataset -- confirm):
#   'jasmine', 'phlox', 'leucan', 'cherry',
#   'viola', 'lily', 'appleTree', 'snowdrop',
#   'perennial', 'blackberry', 'strawberry', 'nankingcherry',
#   'bellflower'
category = 'jasmine'

# Source images for the selected category.
TRAIN_DATA_DIR = f'/home/bochao/flowers/{category}'
# Tokenizer checkpoint that is restored before reconstruction.
TOKENIZER_CKPT = '/home/bochao/checkpoints/vqvae_ckpt_v1/vqvae_epoch19'
# Directory that receives the reconstructed PNG files.
OUTPUT_DIR = '/home/bochao/results/tokenizer_baseline'
# Image side length (pixels) assigned to the tokenizer config.
IMG_SIZE = 256

# Crop schema for the training images: resize so that the shorter side equals
# IMG_SIZE, then take a centered IMG_SIZE x IMG_SIZE crop. The size comes from
# IMG_SIZE rather than a repeated literal so the crop stays in sync with the
# image size used elsewhere in the notebook.
rescaler = albumentations.SmallestMaxSize(max_size=IMG_SIZE)
cropper = albumentations.CenterCrop(height=IMG_SIZE, width=IMG_SIZE)
preprocessor = albumentations.Compose([rescaler, cropper])

def read_image(img_path):
    """Load one image file and return it as a preprocessed float32 tensor.

    The image is decoded with ``cv2.imread``, passed through the module-level
    ``preprocessor`` (rescale + center crop) and converted to ``tf.float32``
    with ``tf.image.convert_image_dtype``.

    NOTE: ``cv2.imread`` decodes color images in BGR channel order and no
    channel conversion happens here, so the returned tensor is BGR as well.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The preprocessed image as a ``tf.float32`` tensor.

    Raises:
        OSError: If OpenCV cannot read or decode ``img_path``.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # loudly here rather than with an AttributeError on the next line.
    if img is None:
        raise OSError(f'cv2.imread could not read or decode image: {img_path!r}')
    img = img.astype(np.uint8)
    img = preprocessor(image=img)["image"]
    img = tf.image.convert_image_dtype(img, tf.float32, saturate=False)
    return img

# data loader with crop
class Dataloader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves fixed-size batches of preprocessed images.

    Each batch is a Python list holding exactly ``batch_size`` results of
    ``read_image``. When the number of images is not a multiple of
    ``batch_size``, the final batch wraps around to the first images so that
    it is still full.

    Batches are addressed by ``index``: asking for the same index twice yields
    the same images, and iterating the loader a second time starts again from
    the first image.
    """

    def __init__(self, data_dir, batch_size):
        """Collect the ``*.png`` files located directly inside ``data_dir``.

        Args:
            data_dir: Directory searched for ``*.png`` files.
            batch_size: Number of images per batch; must be positive.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        # Sorted so that batch contents (and anything numbered from them) do
        # not depend on the arbitrary order in which glob lists the files.
        self.train_imgs_path = sorted(glob.glob(data_dir + '/*.png'))
        # Retained only so existing attribute access keeps working; batching
        # no longer depends on this mutable state.
        self.train_imgs = []
        self.counter = 0
        self.num_imgs = len(self.train_imgs_path)
        self.batch_size = batch_size
        self.im_size = 256

    def __len__(self):
        """Return the number of batches per epoch as a plain ``int``."""
        # Ceiling division without going through floating point / NumPy.
        return -(-self.num_imgs // self.batch_size)

    def __getitem__(self, index):
        """Return batch number ``index`` as a list of ``batch_size`` images.

        Args:
            index: Batch position; negative values count from the end.

        Returns:
            A list of ``batch_size`` images produced by ``read_image``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        num_batches = len(self)
        if index < 0:
            index += num_batches
        # Raising IndexError also lets plain ``for`` loops that fall back to
        # the __getitem__ protocol terminate.
        if not 0 <= index < num_batches:
            raise IndexError(f'batch index {index} out of range for {num_batches} batches')
        start = index * self.batch_size
        # The modulo wraps the last batch around to the first images so every
        # batch has exactly batch_size entries.
        return [
            read_image(self.train_imgs_path[(start + offset) % self.num_imgs])
            for offset in range(self.batch_size)
        ]


/home/bochao/maskgit


Define the tokenizer and generate reconstructed images from the training images.

In [2]:
# configurations
maskgit_cf = maskgit_class_cond_config.get_config()
maskgit_cf.image_size = IMG_SIZE
maskgit_cf.eval_batch_size = 8
batch_size = 2

# dataloader
train_dataset = Dataloader(TRAIN_DATA_DIR, batch_size)
# tokenizer
tokenizer = vqgan_tokenizer.VQVAE(config=maskgit_cf, dtype=jnp.float32, train=False)
# load checkpoint
tokenizer_variables = restore_from_path(TOKENIZER_CKPT)

# Make sure the destination exists before the first image is written
# (succeeds silently when the directory is already there).
tf.io.gfile.makedirs(OUTPUT_DIR)

counter = 0
# Iterate by explicit batch index so the loop is bounded by len(train_dataset)
# regardless of how the Sequence base class implements iteration.
for batch_index in tqdm(range(len(train_dataset)), desc='reconstructing'):
    batch = train_dataset[batch_index]
    input_dict = {
        'image': batch
    }
    # encode
    quantized, result_dict = tokenizer.apply(tokenizer_variables, input_dict, method=tokenizer.encode, mutable=False)
    # decode
    reconstructed_imgs = tokenizer.apply(tokenizer_variables, quantized, method=tokenizer.decode, mutable=False)
    for i in range(batch_size):
        # The loader pads the last batch by wrapping around to the first
        # images; stop once every source image has been written exactly once
        # so the baseline set contains no duplicates.
        if counter >= train_dataset.num_imgs:
            break
        img = tf.clip_by_value(reconstructed_imgs[i], 0.0, 1.0)
        result_img = tf.image.convert_image_dtype(img, tf.uint8).numpy()
        # read_image loads with cv2.imread (BGR channel order) while save_img
        # writes the array as RGB: flip the channel axis so the saved colors
        # match the source files.
        result_img = result_img[..., ::-1]
        # scale=False: the default (scale=True) min-max rescales every image
        # to [0, 255], which would alter the reconstruction's contrast.
        tf.keras.utils.save_img(OUTPUT_DIR + f'/{counter}.png', x=result_img, data_format='channels_last', scale=False)
        counter += 1

# RQ-VAE tokenizer

Load RQ-VAE tokenizer

Generate reconstructed images using RQ-VAE tokenizer