In [None]:
"""
This script is used to refine codebook learned by VQGAN
By reducing size of codebook, reduce ambiguity in comparing and predicting operations happened in token embedding space
Modify and validate in a dynamic way, k-means to determine representative centers, then iteratively construct code groupings
Codebook quality is reflected by reconstruction performance
"""
import torch
import numpy as np
from PIL import Image
from torchvision.transforms import functional as F
import os
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from taming.vqgan import VQModel
from matplotlib import pyplot as plt
import os
import h5py
from utils.bbox_utils import compute_iou_mask
from utils.io import dump_json_object
%matplotlib inline

In [None]:
codebook_path = '/mnt/data/jiangyong/vito/vqgan_embed.pth'
code2mask_dir = '/mnt/data/jiangyong/vito/code2mask'
# codebook embeddings; later cells treat row i as code i
code_embed = torch.load(codebook_path, map_location='cpu').numpy()


def _code_sort_key(fname):
    """Order mask files by code index: numeric stems ('2.png' before '10.png') first, others by name."""
    stem = os.path.splitext(fname)[0]
    return (0, int(stem), fname) if stem.isdigit() else (1, 0, fname)


def _load_code_map(path):
    """Load one per-code semantic mask as a float tensor of shape (C, H, W) in [0, 1]."""
    with Image.open(path) as im:
        return F.to_tensor(im)


print('start parsing semantic masks of codebook')
# os.listdir returns entries in arbitrary order, but row i of code_smtic must describe
# code i of code_embed, so sort the files by their code index.
# NOTE(review): assumes mask files are named by code index (e.g. '17.png') — confirm.
code_names = sorted(os.listdir(code2mask_dir), key=_code_sort_key)
code_maps = [_load_code_map(os.path.join(code2mask_dir, code_name)) for code_name in tqdm(code_names)]
assert code_maps, f'no semantic masks found in {code2mask_dir}'
# concatenate once instead of growing a tensor per file; float64 matches the
# previous torch.zeros(..., dtype=float) accumulator
code_smtic = torch.cat(code_maps, dim=0).to(torch.float64).numpy()
# one single-channel map per code; multi-channel masks would silently misalign rows
assert code_smtic.shape[0] == code_embed.shape[0], (
    f'{code_smtic.shape[0]} semantic maps for {code_embed.shape[0]} codes')
print(f'code semantic map is shaped at {code_smtic.shape}')

In [None]:
# initialize groupings with k-means
init_type = 'semantics'
assert init_type in ['geometry', 'semantics']
init_num = 100
kmeans_seed = 0   # fixed seed so the initial centers (and everything downstream) are reproducible
print(f'running K-means to initialize centers regarding {init_type} type')


def _nearest_code_indices(features, cluster_centers):
    """For each k-means center, return the index of the closest (L2) row of `features`."""
    return [int(np.argmin(np.linalg.norm(features - c, axis=1))) for c in cluster_centers]


# 'geometry' clusters the codebook embeddings themselves; 'semantics' clusters the
# flattened per-code semantic maps (one row per code)
if init_type == 'geometry':
    init_features = code_embed
else:
    init_features = code_smtic.reshape(len(code_smtic), -1)
# n_init pinned explicitly: its default changed across scikit-learn versions
kmeans = KMeans(n_clusters=init_num, n_init=10, random_state=kmeans_seed).fit(init_features)
print(f'got initial centers at shape {kmeans.cluster_centers_.shape}')
# replace the synthetic k-means centers with the closest real codes, so every
# center is an actual codebook entry
centers_idx = _nearest_code_indices(init_features, kmeans.cluster_centers_)
n_duplicates = init_num - len(set(centers_idx))
if n_duplicates:
    print(f'warning: {n_duplicates} k-means centers map to an already-chosen code')
centers = code_embed[centers_idx]

print(f'updated centers at shape {centers.shape}')

In [None]:
# run code grouping
# Single pass over the codebook: a code joins its semantically closest center when the
# Frobenius (L2) distance between their semantic maps is below `gamma`; otherwise the
# code becomes the center of a new group.
# NOTE(review): 0.05 * 256 looks scaled to the 256-px map side; for binary maps it is
# about 164 differing pixels (gamma ** 2) — confirm the intended tolerance.
gamma = 0.05 * 256
n_codes = len(code_embed)
# restart from the k-means centers only: this cell appends to centers_idx, so without
# the truncation a re-run would keep the novel centers of the previous run
centers_idx = list(centers_idx[:init_num])
centers_smtic = code_smtic[centers_idx]
grp = []   # grp[i] = group label of code i (= row in centers / centers_smtic)
novel_grp = init_num   # next free label; equals the total number of groups at the end
print('running code grouping')
for i in tqdm(range(n_codes)):
    s = code_smtic[i]
    dist_smtic = np.linalg.norm(s - centers_smtic, axis=(1, 2))
    nearest = int(np.argmin(dist_smtic))
    if dist_smtic[nearest] < gamma:
        grp.append(nearest)
    else:
        grp.append(novel_grp)
        novel_grp += 1
        centers_idx.append(i)
        centers_smtic = np.concatenate([centers_smtic, s[None]], axis=0)
# every center is a real code, so the center embeddings are a plain gather
centers = code_embed[centers_idx]
print(f'got {novel_grp} semantic clusters')

In [None]:
# visualize codebook with color denoting groupings
codebook_reduced = TSNE(n_components=2, learning_rate='auto', random_state=0).fit_transform(code_embed)

# Group 0 is black; the remaining novel_grp - 1 groups are split as evenly as possible
# over the R, G and B channels (R, then G, take the remainder), each channel ramping
# its intensity linearly up to 1.
base, extra = divmod(novel_grp - 1, 3)
ch_list = [base + (ch < extra) for ch in range(3)]
color = np.zeros((novel_grp, 3))
start = 1
for ch, count in enumerate(ch_list):
    # explicit [start:start + count] slice: a negative-index form like color[-count:]
    # selects every row when count == 0 and fails to broadcast
    color[start:start + count, ch] = np.linspace(0, 1, count + 1)[1:]
    start += count

fig, ax = plt.subplots()
ax.scatter(codebook_reduced[:, 0], codebook_reduced[:, 1], s=1, c=color[grp])
ax.set(xlabel='x', ylabel='y', title='t-SNE on grouping clusters')
ax.grid(False)
plt.show()
plt.close(fig)

In [None]:
# merge codebook and validate quality
ddconfig = {
    'double_z': False,
    'z_channels': 256,
    'resolution': 256,
    'in_channels': 3,
    'out_ch': 3,
    'ch': 128,
    'ch_mult': [1,1,2,2,4],
    'num_res_blocks': 2,
    'attn_resolutions': [16],
    'dropout': 0.0
}
vqgan_ckpt = '/mnt/data/jiangyong/vito/vqgan.ckpt'
# prefer the second GPU as before, but fall back instead of crashing on smaller machines
if torch.cuda.device_count() > 1:
    vqgan_device = 'cuda:1'
elif torch.cuda.is_available():
    vqgan_device = 'cuda:0'
else:
    vqgan_device = 'cpu'
vqgan = VQModel(ddconfig=ddconfig, n_embed=novel_grp, embed_dim=256, ckpt_path=vqgan_ckpt)
vqgan.to(vqgan_device)
vqgan.eval()
# swap in the refined codebook, keeping the parameter's own dtype/device so the
# quantizer never mixes dtypes
codebook_weight = vqgan.quantize.embedding.weight
assert tuple(codebook_weight.shape) == centers.shape, (
    f'quantizer expects {tuple(codebook_weight.shape)}, refined codebook is {centers.shape}')
codebook_weight.data = torch.from_numpy(centers).to(device=codebook_weight.device, dtype=codebook_weight.dtype)

# test mask reconstruction
mask_path = 'data/masks/refcoco/50.png'
mask = F.to_tensor(Image.open(mask_path))
mask = F.resize(mask, 256)   # shorter side -> 256, so H and W are both >= 256 below

# of the top-left, centre and bottom-right 256x256 crops, keep the one with most foreground
H, W = mask.shape[-2:]
crop_origins = [(0, 0), ((H - 256) // 2, (W - 256) // 2), (H - 256, W - 256)]
crops = [F.crop(mask, top, left, 256, 256) for top, left in crop_origins]
crop_flag = int(np.argmax([torch.sum(crop).item() for crop in crops]))
mask = crops[crop_flag]

# the VQGAN takes a 3-channel batch in [-1, 1]; the mask is replicated across channels
# NOTE(review): assumes the mask PNG decodes to a single channel — confirm
mask_vqgan = (2 * mask - 1).repeat(3, 1, 1).unsqueeze(0).to(device=vqgan.device, dtype=torch.float)
with torch.no_grad():
    encoding_indices = vqgan.encode(mask_vqgan)[-1][-1]
    print(f'encode input mask to sequence at length {len(encoding_indices)}')
    pred_mask = vqgan.decode_code(encoding_indices.long().to(vqgan.device), shape=(1, 16, 16, -1))
# vqgan reconstruction shape at [1, 3, 256, 256], value in [-1, 1]
pred_mask = (pred_mask.squeeze().cpu().numpy() + 1) / 2
# ITU-R 601-2 luma transform: L = R * 0.299 + G * 0.587 + B * 0.114
luma = np.array([0.299, 0.587, 0.114]).reshape(3, 1, 1)
pred_mask = (luma * pred_mask).sum(axis=0)
print(f'decode sequence to mask at shape {pred_mask.shape}')

vis_mask = np.concatenate([mask.squeeze().cpu().numpy(), pred_mask], axis=1)
fig, ax = plt.subplots()
ax.imshow(vis_mask, cmap='gray', vmin=0, vmax=1)
ax.set_title('input mask (left) vs. reconstruction (right)')
ax.axis('off')
plt.show()

In [None]:
# running quantitative evaluation to verify codebook quality
mask_root = 'data/masks'
datasets = ['refclef', 'refcoco', 'refcoco+', 'refcocog']
max_eval = 100   # masks evaluated per split (first N in sorted order); None = whole split
iou_thre = np.array([0.5, 0.7, 0.9])   # accuracy thresholds: @0.5, @0.7, @0.9
# ITU-R 601-2 luma transform: L = R * 0.299 + G * 0.587 + B * 0.114
luma = np.array([0.299, 0.587, 0.114]).reshape(3, 1, 1)
for dataset in datasets:
    dataset_dir = os.path.join(mask_root, dataset)
    # sorted() makes the evaluated subset deterministic (os.listdir order is arbitrary);
    # slicing evaluates exactly max_eval masks (the old `i >= 100` break ran 101)
    fnames = sorted(os.listdir(dataset_dir))[:max_eval]
    ious = []
    print(f'running inference on {dataset} dataset')
    for fname in tqdm(fnames):
        mask_path = os.path.join(dataset_dir, fname)
        mask = F.to_tensor(Image.open(mask_path))
        mask = F.resize(mask, 256)   # shorter side -> 256

        # keep whichever of the top-left / centre / bottom-right 256x256 crops has most foreground
        H, W = mask.shape[-2:]
        crop_origins = [(0, 0), ((H - 256) // 2, (W - 256) // 2), (H - 256, W - 256)]
        crops = [F.crop(mask, top, left, 256, 256) for top, left in crop_origins]
        mask = crops[int(np.argmax([torch.sum(crop).item() for crop in crops]))]

        mask_vqgan = (2 * mask - 1).repeat(3, 1, 1).unsqueeze(0).to(device=vqgan.device, dtype=torch.float)
        with torch.no_grad():
            encoding_indices = vqgan.encode(mask_vqgan)[-1][-1]
            pred_mask = vqgan.decode_code(encoding_indices.long().to(vqgan.device),
                                          shape=(1, 16, 16, -1))
        # vqgan reconstruction shape at [1, 3, 256, 256], value in [-1, 1] -> grayscale in [0, 1]
        pred_mask = (luma * ((pred_mask.squeeze().cpu().numpy() + 1) / 2)).sum(axis=0)
        gt_mask = mask.squeeze().cpu().numpy()

        # binarize both masks at 0.5: resizing leaves fractional values on GT edges, and
        # a plain astype(bool) would count every non-zero fringe pixel as foreground
        ious.append(float(compute_iou_mask(pred_mask >= 0.5, gt_mask >= 0.5)))

    if not ious:
        print(f'dataset: {dataset} | no masks found in {dataset_dir}\n')
        continue
    ious = np.asarray(ious)
    mask_mIoU = ious.mean()
    mask_AP = (ious[:, None] >= iou_thre).mean(axis=0)
    print(f'dataset: {dataset} | mask mIoU: {mask_mIoU} | mask AP: {mask_AP}\n')

In [None]:
# save current results: the refined codebook plus the bookkeeping that maps codes to groups
refined_codebook_path = '/mnt/data/jiangyong/vito/refined_codebook.pt'
update_codebook = dict(
    size=novel_grp,         # number of groups
    embed=centers,          # one embedding row per group (its center code's embedding)
    center_id=centers_idx,  # code index chosen as each group's center
    grp_label=grp,          # group label of every code
)
torch.save(update_codebook, refined_codebook_path)

In [None]:
# visualize renewed codebook: one point per group center
codebook_reduced = TSNE(n_components=2, learning_rate='auto', random_state=0).fit_transform(centers)

# Row j colors group j. Group 0 is black; the remaining novel_grp - 1 groups are split
# as evenly as possible over the R, G and B channels (R, then G, take the remainder),
# each channel ramping its intensity linearly up to 1.
base, extra = divmod(novel_grp - 1, 3)
ch_list = [base + (ch < extra) for ch in range(3)]
color = np.zeros((novel_grp, 3))
start = 1
for ch, count in enumerate(ch_list):
    # explicit [start:start + count] slice: a negative-index form like color[-count:]
    # selects every row when count == 0 and fails to broadcast
    color[start:start + count, ch] = np.linspace(0, 1, count + 1)[1:]
    start += count

fig, ax = plt.subplots()
ax.scatter(codebook_reduced[:, 0], codebook_reduced[:, 1], s=1, c=color)
ax.set(xlabel='x', ylabel='y', title='t-SNE on renewed codebook')
ax.grid(False)
plt.show()
plt.close(fig)

In [None]:
# calculate semantic distance matrix:
# entry [r, k] is the Frobenius (L2) distance between the semantic maps of centers r and k
smtic_dist_mat = np.empty((novel_grp, novel_grp))
for row in tqdm(range(novel_grp)):
    diff = centers_smtic[row] - centers_smtic
    smtic_dist_mat[row, :] = np.linalg.norm(diff, axis=(1, 2))
smtic_dist_mat