In [1]:
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

from torch.utils.tensorboard import SummaryWriter
import argparse
import json
import os
import time
import glob
from dataset import Dataset
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from models import DeepAppearanceVAE, WarpFieldVAE
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from utils import Renderer, gammaCorrect
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import (
  get_default_qconfig_mapping,
  get_default_qat_qconfig_mapping,
  QConfigMapping,
)

In [11]:
model = DeepAppearanceVAE(1024, 21918, n_latent=256, n_cams=38)
pretrained_dict = torch.load("/workspace/uwing2/multiface/pretrained_model/6795937_best_base_model.pth")
filtered_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items() if 'module.' in k}
model.load_state_dict(filtered_dict)

<All keys matched successfully>

In [16]:
import json
test_segment_config = "/workspace/uwing2/Privatar/multiface_partition_bdct4x4/test_segment.json"

In [4]:
f = open("/workspace/uwing2/Privatar/multiface_partition_bdct4x4_nohp/camera_configs/camera-split-config_6795937.json", "r")
camera_config = json.load(f)['full']

In [17]:
if test_segment_config is not None:
    f = open(test_segment_config, "r")
    test_segment_config = json.load(f)
    f.close()
    test_segment = test_segment_config["segment"]
else:
    test_segment = ["EXP_ROM", "EXP_free_face"]

In [19]:
from dataset import Dataset
dataset_test = Dataset(
        "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS",
        "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS/KRT",
        "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS/frame_list.txt",
        1024,
        camset=None if camera_config is None else camera_config["test"],
        valid_prefix=test_segment,
    )

checking 0
checking 1000
checking 2000
checking 3000
checking 4000
checking 5000
checking 6000
checking 7000
checking 8000
checking 9000
checking 10000
checking 11000
checking 12000
checking 13000


In [None]:
from torch.utils.data import DataLoader, RandomSampler
test_sampler = RandomSampler(dataset_test)
test_loader = DataLoader(
        dataset_test,
        10,
        sampler=test_sampler,
        num_workers=0,
    )

In [None]:
for i, data in enumerate(dataset_test):
    print(data)

In [12]:
# example_avg_tex=torch.zeros([10, 3, 1024, 1024])
# example_verts=torch.zeros([10, 7306, 3])
# example_view=torch.zeros([10, 3])
# example_cams=torch.zeros([10])
# example_inputs = (example_avg_tex, example_verts, example_view, example_cams)

In [None]:
model.dec.texture_decoder = torch.quantization.QuantWrapper(model.dec.texture_decoder)
model.train()

In [14]:
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
optimizer = optim.Adam(model.get_model_params(), 3e-4, (0.9, 0.999))
optimizer_cc = optim.Adam(model.get_cc_params(), 3e-4, (0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

qat_model = torch.ao.quantization.prepare_qat(model, inplace=True)



In [6]:
dataset_train = Dataset(
        "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS",
        "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS/KRT",
        "/workspace/uwing2/multiface/dataset/m--20180227--0000--6795937--GHS/frame_list.txt",
        1024,
        camset=None if camera_config is None else camera_config["train"],
        exclude_prefix=None,
    )

checking 0
checking 1000
checking 2000
checking 3000
checking 4000
checking 5000
checking 6000
checking 7000
checking 8000
checking 9000
checking 10000
checking 11000
checking 12000
checking 13000


In [7]:
train_sampler = RandomSampler(dataset_train)

In [8]:
train_loader = DataLoader(
        dataset_train,
        1,
        sampler=train_sampler,
        num_workers=0,
    )

In [17]:
mse = nn.MSELoss()
for i, data in enumerate(train_loader):
    optimizer.zero_grad()
    optimizer_cc.zero_grad()
    
    M = data["M"]#.cuda()
    gt_tex = data["tex"]#.cuda()
    vert_ids = data["vert_ids"]#.cuda()
    uvs = data["uvs"]#.cuda()
    uv_ids = data["uv_ids"]#.cuda()
    avg_tex = data["avg_tex"]#.cuda()
    view = data["view"]#.cuda()
    transf = data["transf"]#.cuda()
    verts = data["aligned_verts"]#.cuda()
    photo = data["photo"]#.cuda()
    mask = data["mask"]#.cuda()
    cams = data["cam"]#.cuda()
    batch, channel, height, width = avg_tex.shape
    output = {}
    height_render, width_render = [2048, 1334]
    width_render = width_render - (width_render % 8)
    photo_short = torch.Tensor(photo)[:, :, :width_render, :]

    pred_tex, pred_verts, kl = qat_model(avg_tex, verts, view, cams=cams)
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
    optimizer.step()
    optimizer_cc.step()
    loss = mse(pred_tex, avg_tex)
    loss.backward()
    if (i > 5):
        break

In [None]:
model.dec.texture_decoder.module.upsample[0].conv1.deconv.state_dict()

In [20]:
quantized_model = torch.ao.quantization.convert(qat_model.eval(), inplace=False)
quantized_model.eval()

DeepAppearanceVAE(
  (enc): DeepApperanceEncoder(
    (texture_encoder): TextureEncoder(
      (downsample): Sequential(
        (0): ConvDownsample(
          (conv1): Conv2dWN(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (conv2): Conv2dWN(16, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (relu): LeakyReLU(negative_slope=0.2, inplace=True)
        )
        (1): ConvDownsample(
          (conv1): Conv2dWN(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (conv2): Conv2dWN(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (relu): LeakyReLU(negative_slope=0.2, inplace=True)
        )
        (2): ConvDownsample(
          (conv1): Conv2dWN(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (conv2): Conv2dWN(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (relu): LeakyReLU(negative_slope=0.2, inplace=True)
        )
        (3): ConvDownsample(
          (conv1): Conv2dW

In [21]:
quantized_model.dec.texture_decoder.module.upsample[0].conv1.deconv.state_dict()

OrderedDict([('weight',
              tensor([[[[ 4.2532e-02, -6.4358e-02, -4.4278e-02,  4.0479e-02],
                        [-5.8819e-02, -1.1095e-01, -3.7640e-03,  8.5450e-02],
                        [-6.7983e-02,  1.0470e-01,  4.6212e-02, -9.7798e-02],
                        [ 1.7644e-01, -4.1109e-02, -2.5845e-03,  1.1428e-01]],
              
                       [[ 1.6727e-02,  1.1839e-01,  1.4635e-01, -2.4463e-02],
                        [ 3.8758e-02, -1.0763e-01,  1.0563e-01,  4.8089e-03],
                        [-3.4209e-02,  4.4760e-02,  2.7049e-01,  1.2554e-01],
                        [ 8.9261e-02,  9.2531e-02,  4.2513e-02,  1.2315e-01]],
              
                       [[ 9.5561e-02,  7.9667e-02, -1.4056e-01, -2.0480e-01],
                        [-2.4828e-02, -1.5581e-02, -2.0584e-02, -1.4230e-01],
                        [ 6.5232e-02, -1.1447e-02, -1.6669e-01, -1.3931e-01],
                        [-1.6884e-01, -5.0007e-02, -1.0443e-02,  1.0767e-01]],
       