In [4]:
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

from torch.utils.tensorboard import SummaryWriter
import argparse
import json
import os
import time
import glob
from dataset import Dataset
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from models import DeepAppearanceVAE, WarpFieldVAE
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from utils import Renderer, gammaCorrect
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)

# Define the actual NN model, optimizer, scheduler and dataset

In [93]:
model = DeepAppearanceVAE(1024, 21918, n_latent=256, n_cams=38)
pretrained_dict = torch.load("/workspace/uwing2/multiface/pretrained_model/6795937_best_base_model.pth")
filtered_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items() if 'module.' in k}
model.load_state_dict(filtered_dict)

<All keys matched successfully>

In [63]:
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
optimizer = optim.Adam(model.get_model_params(), 3e-4, (0.9, 0.999))
optimizer_cc = optim.Adam(model.get_cc_params(), 3e-4, (0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

In [64]:
f = open("/workspace/uwing2/Privatar/multiface_partition_bdct4x4_nohp/camera_configs/camera-split-config_6795937.json", "r")
camera_config = json.load(f)['full']
dataset_train = Dataset(
    "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS",
    "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS/KRT",
    "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS/frame_list.txt",
    1024,
    camset=None if camera_config is None else camera_config["train"],
    exclude_prefix=None,
)

checking 0
checking 1000
checking 2000
checking 3000
checking 4000
checking 5000
checking 6000
checking 7000
checking 8000
checking 9000
checking 10000
checking 11000
checking 12000
checking 13000


In [8]:
train_sampler = RandomSampler(dataset_train)
train_loader = DataLoader(
    dataset_train,
    1,
    sampler=train_sampler,
    num_workers=0,
)

In [27]:
weights_tensor = model.dec.texture_decoder.upsample[0].conv1.deconv.state_dict()['weight']#.shape#['g'].shape
g_tensor = model.dec.texture_decoder.upsample[0].conv1.deconv.state_dict()['g']#['g'].shape
result_tensor = weights_tensor * g_tensor[None, :, None, None]
wnorm = torch.sqrt(torch.sum(weights_tensor**2))
event_out_tensor = result_tensor  / wnorm

In [None]:
event_out_tensor = result_tensor  / wnorm
print(event_out_tensor)

In [None]:
# Backward testing
scale_back_tensor = event_out_tensor * wnorm / g_tensor[None, :, None, None]
torch.equal(scale_back_tensor, weights_tensor)

# Manual Quantize the transposed convolution layer

In [89]:
def get_quantized_range(bitwidth):
    quantized_max = (1 << (bitwidth - 1)) - 1
    quantized_min = -(1 << (bitwidth - 1))
    return quantized_min, quantized_max

def get_quantization_scale_for_weight(weight, bitwidth):
    """
    get quantization scale for single tensor of weight
    :param weight: [torch.(cuda.)Tensor] floating weight to be quantized
    :param bitwidth: [integer] quantization bit width
    :return:
        [floating scalar] scale
    """
    # we just assume values in weight are symmetric
    # we also always make zero_point 0 for weight
    fp_max = max(weight.abs().max().item(), 5e-7)
    _, quantized_max = get_quantized_range(bitwidth)
    return fp_max / quantized_max

def linear_quantize(fp_tensor, bitwidth, scale, zero_point, dtype=torch.int8) -> torch.Tensor:
    """
    linear quantization for single fp_tensor
      from
        fp_tensor = (quantized_tensor - zero_point) * scale
      we have,
        quantized_tensor = int(round(fp_tensor / scale)) + zero_point
    :param tensor: [torch.(cuda.)FloatTensor] floating tensor to be quantized
    :param bitwidth: [int] quantization bit width
    :param scale: [torch.(cuda.)FloatTensor] scaling factor
    :param zero_point: [torch.(cuda.)IntTensor] the desired centroid of tensor values
    :return:
        [torch.(cuda.)FloatTensor] quantized tensor whose values are integers
    """
    assert(fp_tensor.dtype == torch.float)
    assert(isinstance(scale, float) or
           (scale.dtype == torch.float and scale.dim() == fp_tensor.dim()))
    assert(isinstance(zero_point, int) or
           (zero_point.dtype == dtype and zero_point.dim() == fp_tensor.dim()))

    ############### YOUR CODE STARTS HERE ###############
    # Step 1: scale the fp_tensor
    scaled_tensor = fp_tensor / scale
    # Step 2: round the floating value to integer value
    rounded_tensor = torch.round(scaled_tensor)
    ############### YOUR CODE ENDS HERE #################

    rounded_tensor = rounded_tensor.to(dtype)

    ############### YOUR CODE STARTS HERE ###############
    # Step 3: shift the rounded_tensor to make zero_point 0
    shifted_tensor = rounded_tensor + zero_point
    ############### YOUR CODE ENDS HERE #################

    # Step 4: clamp the shifted_tensor to lie in bitwidth-bit range
    quantized_min, quantized_max = get_quantized_range(bitwidth)
    print(quantized_min, quantized_max)
    print(torch.min(shifted_tensor), torch.max(shifted_tensor))
    quantized_tensor = shifted_tensor.clamp_(quantized_min, quantized_max)
    return quantized_tensor

def linear_quantize_weight_per_channel(tensor, bitwidth, datatype):
    """
    linear quantization for weight tensor
        using different scales and zero_points for different output channels
    :param tensor: [torch.(cuda.)Tensor] floating weight to be quantized
    :param bitwidth: [int] quantization bit width
    :return:
        [torch.(cuda.)Tensor] quantized tensor
        [torch.(cuda.)Tensor] scale tensor
        [int] zero point (which is always 0)
    """
    dim_output_channels = 0
    num_output_channels = tensor.shape[dim_output_channels]
    scale = torch.zeros(num_output_channels, device=tensor.device)
    for oc in range(num_output_channels):
        _subtensor = tensor.select(dim_output_channels, oc)
        _scale = get_quantization_scale_for_weight(_subtensor, bitwidth)
        scale[oc] = _scale
    scale_shape = [1] * tensor.dim()
    scale_shape[dim_output_channels] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_quantize(tensor, bitwidth, scale, zero_point=0, dtype=datatype)
    return quantized_tensor, scale, 0

def linear_quantize_and_replace_weight(conv_transpose_layer, bitwidth=16, datatype=torch.int16):
    new_state_dict = conv_transpose_layer.state_dict()
    weights_tensor = torch.clone(new_state_dict['weight'])
    wnorm = torch.sqrt(torch.sum(weights_tensor**2))
    g_tensor = new_state_dict['g']#['g'].shape
    result_tensor = weights_tensor * g_tensor[None, :, None, None]
    event_out_tensor = result_tensor  / wnorm
    post_quantized_tensor, scale, zp = linear_quantize_weight_per_channel(event_out_tensor, bitwidth, datatype)
    post_quantized_tensor_fp = post_quantized_tensor.float() #torch.convert(post_quantized_tensor, torch.float)
    post_quantized_tensor_fp = (post_quantized_tensor_fp - zp) * scale
    post_quantized_weights = post_quantized_tensor_fp * wnorm / g_tensor[None, :, None, None]
    new_state_dict['weight'] = post_quantized_weights
    conv_transpose_layer.load_state_dict(new_state_dict)
    return conv_transpose_layer

In [100]:
bitwidth = 14
datatype = torch.int16
for i in range(len(model.dec.texture_decoder.upsample)):
    linear_quantize_and_replace_weight(model.dec.texture_decoder.upsample[i].conv1.deconv, bitwidth, datatype)
    linear_quantize_and_replace_weight(model.dec.texture_decoder.upsample[i].conv2.deconv, bitwidth, datatype)

-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)
-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)
-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)
-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)
-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)
-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)
-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)
-8192 8191
tensor(-8191, dtype=torch.int32) tensor(8191, dtype=torch.int32)


In [None]:
mse = nn.MSELoss()
for i, data in enumerate(train_loader):
    optimizer.zero_grad()
    optimizer_cc.zero_grad()
    
    M = data["M"]#.cuda()
    gt_tex = data["tex"]#.cuda()
    vert_ids = data["vert_ids"]#.cuda()
    uvs = data["uvs"]#.cuda()
    uv_ids = data["uv_ids"]#.cuda()
    avg_tex = data["avg_tex"]#.cuda()
    view = data["view"]#.cuda()
    transf = data["transf"]#.cuda()
    verts = data["aligned_verts"]#.cuda()
    photo = data["photo"]#.cuda()
    mask = data["mask"]#.cuda()
    cams = data["cam"]#.cuda()
    batch, channel, height, width = avg_tex.shape
    output = {}
    height_render, width_render = [2048, 1334]
    width_render = width_render - (width_render % 8)
    photo_short = torch.Tensor(photo)[:, :, :width_render, :]

    pred_tex, pred_verts, kl = model(avg_tex, verts, view, cams=cams)
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
    optimizer.step()
    optimizer_cc.step()
    loss = mse(pred_tex, avg_tex)
    loss.backward()
    if (i > 5):
        break

# Post Quantization Data

In [None]:
model.dec.texture_decoder.module.upsample[0].conv1.deconv.state_dict()['g']

# Understanding Transposed Quantization

In [2]:
# With square kernels and equal stride
m = nn.ConvTranspose2d(16, 33, 3, stride=2)
# non-square kernels and unequal stride and with padding
m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
input = torch.randn(20, 16, 50, 100)
output = m(input)
# exact output size can be also specified as an argument
input = torch.randn(1, 16, 12, 12)
downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
h = downsample(input)
h.size()
output = upsample(h, output_size=input.size())
output.size()

torch.Size([1, 16, 12, 12])

In [3]:
m.state_dict()

OrderedDict([('weight',
              tensor([[[[-3.8928e-02, -4.4176e-02,  3.6679e-02, -1.2662e-02, -3.5243e-02],
                        [-3.6266e-02,  1.9824e-02,  1.2112e-02,  6.6716e-03, -6.2657e-03],
                        [ 3.5419e-02, -1.4695e-03, -3.2812e-02,  1.1015e-02,  7.7270e-03]],
              
                       [[-4.2148e-02, -8.9569e-03,  4.2684e-03,  2.7824e-02,  1.9044e-02],
                        [ 3.6119e-02, -5.7129e-03, -2.0307e-02, -4.3224e-02, -4.0367e-02],
                        [-1.2355e-02,  7.8870e-03, -2.0248e-03,  2.8258e-02, -4.1270e-02]],
              
                       [[ 2.2233e-02, -2.2651e-02, -2.5109e-02,  2.0934e-02, -2.0716e-02],
                        [-1.5664e-03,  6.5778e-03, -4.1291e-02, -4.0083e-03, -1.5251e-02],
                        [-2.1261e-02,  4.1453e-02,  1.3748e-02, -3.5007e-02,  1.7711e-02]],
              
                       ...,
              
                       [[ 1.7797e-02,  1.8618e-02, -1.0890e-02,  2