# 1. Verified Structural Channel Pruning (channels are actually removed, not masked)

In [None]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class LeNet(nn.Module):
    """Small LeNet-style CNN used as a testbed for structural channel pruning.

    Expects single-channel 32x32 inputs. The spatial size goes
    32 -> 28 (conv1) -> 14 (pool) -> 10 (conv2) -> 5 (pool), which is where the
    ``16 * 5 * 5`` input size of ``fc1`` comes from. Produces 10 outputs per
    sample; no activation is applied after ``fc3``.
    """

    def __init__(self):
        super().__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 channels x 5x5 feature map left after conv2 + pooling.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map an (N, 1, 32, 32) batch to an (N, 10) output."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything but the batch dim. Unlike dividing numel() by the
        # batch size, this also works for an empty batch (N == 0).
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Build the network, then move its parameters onto the selected device
# (nn.Module.to works in place and returns the module itself).
model = LeNet()
model.to(device)

In [None]:
# Inspect the unpruned conv1 filters: the detached weight values, then their shape.
print(model.conv1.weight.detach())
print(model.conv1.weight.size())

In [None]:
# Fraction of conv1's output channels to remove (int(0.3 * 6) == 1 channel).
prune_ratio: float = 0.3

def weights_kernel_pruning_l1_norm(model, prune_ratio):
    """Physically remove the lowest-L1-norm output channels (filters) of a Conv2d layer.

    Each filter's L1 norm is summed over its input-channel and kernel dims. The
    ``int(prune_ratio * out_channels)`` filters with the smallest norm are
    dropped. ``weight`` (and ``bias``, if present) are replaced by new
    Parameters holding the surviving *signed* values, and ``out_channels`` is
    updated to match.

    Args:
        model: Conv2d-style layer whose 4-D ``weight`` has shape
            (out_channels, in_channels, kH, kW); modified in place.
        prune_ratio: Fraction of output channels to remove, in [0, 1].

    Returns:
        ``(model, prune_indices)``, where ``prune_indices`` is the set of removed
        output-channel indices. Feed it to the next layer's input pruning.

    Raises:
        ValueError: If ``prune_ratio`` is outside [0, 1].
    """
    if not 0 <= prune_ratio <= 1:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio}")

    weight = model.weight.data
    num_out = int(weight.shape[0])

    # abs() is only used to rank filters; the kept weights below keep their sign.
    l1_norm = weight.abs().sum(dim=(1, 2, 3))
    num_channels_to_prune = int(prune_ratio * num_out)
    response_val, prune_idx = torch.topk(l1_norm, num_channels_to_prune, largest=False)
    prune_indices = set(prune_idx.tolist())
    remaining_indices = set(range(num_out)) - prune_indices

    # Keep the surviving channels in their original order, so the next layer's
    # input channels still line up after being pruned with the same indices.
    keep = torch.tensor(sorted(remaining_indices), dtype=torch.long, device=weight.device)
    model.weight = torch.nn.Parameter(weight.index_select(0, keep))
    if getattr(model, "bias", None) is not None:
        # The bias has one entry per output channel; leaving it unpruned would
        # make the forward pass fail on a shape mismatch.
        bias = model.bias.data
        model.bias = torch.nn.Parameter(bias.index_select(0, keep.to(bias.device)))
    if hasattr(model, "out_channels"):
        model.out_channels = len(remaining_indices)
    print(f"under prune_ratio={prune_ratio}, num_channels_to_prune={num_channels_to_prune}, response_val={response_val}, remaining_indices={remaining_indices}, prune_indices={prune_indices}")
    return model, prune_indices

# Drop conv1's weakest filters; prune_indices is reused to trim conv2's inputs.
model.conv1, prune_indices = weights_kernel_pruning_l1_norm(
    model=model.conv1,
    prune_ratio=prune_ratio,
)

In [None]:
# Show conv1's parameters (name, tensor) after pruning, then the new weight shape.
module = model.conv1
print([*module.named_parameters()])
print(module.weight.size())

In [None]:
def iAct_channel_pruning_l1_norm(model, prune_indices):
    """Remove the given input channels from a Conv2d layer's weight.

    Used after the previous layer's output channels were pruned, so that this
    layer's input dimension matches again. ``weight`` is replaced by a new
    Parameter holding the surviving *signed* values, and ``in_channels`` is
    updated to match. The bias is untouched because it is indexed by output
    channel.

    Args:
        model: Conv2d-style layer whose 4-D ``weight`` has shape
            (out_channels, in_channels, kH, kW); modified in place.
        prune_indices: Iterable (set, list or 1-D tensor) of input-channel
            indices to remove.

    Returns:
        The same ``model`` object.
    """
    weight = model.weight.data
    # int() normalises tensor elements: 0-d tensors hash by identity, so they
    # would never match the plain ints in overall_indices.
    prune_indices = {int(i) for i in prune_indices}
    overall_indices = set(range(int(weight.shape[1])))
    remaining_indices = overall_indices - prune_indices

    # Sorted, to match the order in which the previous layer kept its outputs.
    keep = torch.tensor(sorted(remaining_indices), dtype=torch.long, device=weight.device)
    model.weight = torch.nn.Parameter(weight.index_select(1, keep))
    if hasattr(model, "in_channels"):
        model.in_channels = len(remaining_indices)
    print(f"prune input channel indice={prune_indices}, num_channels_to_prune={len(prune_indices)}, remaining_indices={remaining_indices}, prune_indices={prune_indices}")
    return model

In [None]:
# conv2 still expects conv1's old channel count; drop the inputs fed by pruned filters.
print(model.conv2.weight.size())
model.conv2 = iAct_channel_pruning_l1_norm(model=model.conv2, prune_indices=prune_indices)

In [None]:
# Show conv2's parameters (name, tensor) after input pruning, then the new weight shape.
module = model.conv2
print([*module.named_parameters()])
print(module.weight.size())

## 2. Overall Pruning — Inference Only

In [1]:
from torch.utils.tensorboard import SummaryWriter
import argparse
import json
import os
import time
import glob

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataset import Dataset
from models import DeepAppearanceVAE, WarpFieldVAE, ConvTranspose2dWN
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from utils import Renderer, gammaCorrect

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def weight_kernel_pruning_l1_norm(model, in_bias, prune_ratio, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)):
    """Rebuild a transposed-conv layer without its lowest-L1-norm output channels.

    ConvTranspose2d-style weights are laid out as (in_channels, out_channels,
    kH, kW), so each output channel's L1 norm is summed over dims 0, 2 and 3.
    The ``int(prune_ratio * out_channels)`` channels with the smallest norm are
    removed from the weight and from the separately stored ``in_bias``.

    Args:
        model: Transposed-conv layer to prune. It is not modified; a new layer
            of the same class is returned.
        in_bias: Bias for ``model``'s outputs with channels on dim 1 (the
            decoder's bias is 4-D, e.g. (1, C, H, W)), or None.
        prune_ratio: Fraction of output channels to remove, in [0, 1].
        kernel_size, stride, padding: Geometry of the rebuilt layer. The
            defaults are the values this code always used for the decoder.

    Returns:
        ``(new_model, out_bias, prune_indices)``: the rebuilt layer holding the
        surviving signed weights, the pruned bias as a Parameter (or None), and
        the set of removed output-channel indices.

    Raises:
        ValueError: If ``prune_ratio`` is outside [0, 1].
    """
    if not 0 <= prune_ratio <= 1:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio}")

    weight = model.weight.data
    num_in, num_out = int(weight.shape[0]), int(weight.shape[1])

    # abs() is only used to rank channels; the kept weights below keep their sign.
    l1_norm = weight.abs().sum(dim=(0, 2, 3))
    num_channels_to_prune = int(prune_ratio * num_out)
    response_val, prune_idx = torch.topk(l1_norm, num_channels_to_prune, largest=False)
    prune_indices = set(prune_idx.tolist())
    remaining_indices = set(range(num_out)) - prune_indices
    # Sorted, so the next layer's inputs are pruned in the same channel order.
    keep = torch.tensor(sorted(remaining_indices), dtype=torch.long, device=weight.device)

    # Rebuild with the input's own class (ConvTranspose2dWN for the decoder layers).
    # NOTE(review): only `weight` is copied over; any other learned state held by
    # ConvTranspose2dWN (e.g. a weight-norm gain) is re-initialised -- confirm in models.py.
    new_model = type(model)(num_in, len(remaining_indices), kernel_size=kernel_size, stride=stride, padding=padding, bias=False).to(weight.device)
    print(f"remaining_indices={remaining_indices}")
    print(f"weight kernel pruning original bias dimension = {None if in_bias is None else in_bias.shape}")
    out_bias = None
    if in_bias is not None:
        out_bias = torch.nn.Parameter(in_bias.detach().index_select(1, keep.to(in_bias.device)))

    new_model.weight = torch.nn.Parameter(weight.index_select(1, keep))
    print(f"under prune_ratio={prune_ratio}, num_channels_to_prune={num_channels_to_prune}, response_val={response_val}, remaining_indices={remaining_indices}, prune_indices={prune_indices}")
    return new_model, out_bias, prune_indices

def iAct_channel_pruning_l1_norm(model, prune_indices, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)):
    """Rebuild a transposed-conv layer without the given input channels.

    ConvTranspose2d-style weights are laid out as (in_channels, out_channels,
    kH, kW), so input channels are removed along dim 0. This keeps the layer
    shape-compatible after the previous layer's outputs were pruned.

    Args:
        model: Transposed-conv layer to prune. It is not modified; a new layer
            of the same class is returned.
        prune_indices: Iterable (set, list or 1-D tensor) of input-channel
            indices to remove.
        kernel_size, stride, padding: Geometry of the rebuilt layer. The
            defaults are the values this code always used for the decoder.

    Returns:
        The rebuilt layer holding the surviving signed weights.
    """
    weight = model.weight.data
    num_in, num_out = int(weight.shape[0]), int(weight.shape[1])

    # int() normalises tensor elements: 0-d tensors hash by identity, so they
    # would never match the plain ints in overall_indices.
    prune_indices = {int(i) for i in prune_indices}
    overall_indices = set(range(num_in))
    remaining_indices = overall_indices - prune_indices
    # Sorted, to match the order in which the previous layer kept its outputs.
    keep = torch.tensor(sorted(remaining_indices), dtype=torch.long, device=weight.device)

    # Rebuild with the input's own class (ConvTranspose2dWN for the decoder layers).
    # NOTE(review): only `weight` is copied over; other learned state is re-initialised.
    new_model = type(model)(len(remaining_indices), num_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False).to(weight.device)
    new_model.weight = torch.nn.Parameter(weight.index_select(0, keep))
    print(f"prune input channel indice={prune_indices}, num_channels_to_prune={len(prune_indices)}, remaining_indices={remaining_indices}, prune_indices={prune_indices}")
    return new_model

In [31]:
# 7306 * 3 == 21918, i.e. the flattened size of the (batch, 7306, 3) vertex
# tensors used with this model in this notebook.
model = DeepAppearanceVAE(
    1024,
    7306 * 3,
    n_latent=256,
    n_cams=38,
)

In [33]:
# One weight shape per line, in chain order: upsample[0].conv1, upsample[0].conv2, ..., upsample[3].conv2.
print(
    *(
        getattr(model.dec.texture_decoder.upsample[block_idx], conv_name).deconv.weight.shape
        for block_idx in range(4)
        for conv_name in ("conv1", "conv2")
    ),
    sep="\n",
)

torch.Size([256, 256, 4, 4])
torch.Size([256, 64, 4, 4])
torch.Size([64, 64, 4, 4])
torch.Size([64, 32, 4, 4])
torch.Size([32, 32, 4, 4])
torch.Size([32, 16, 4, 4])
torch.Size([16, 16, 4, 4])
torch.Size([16, 3, 4, 4])


In [36]:
def model_decoder_pruning(model, unified_pruning_ratio, prune_last_layer=False):
    """Structurally prune the texture decoder's chain of transposed convolutions.

    The eight deconvs ``upsample[0].conv1, upsample[0].conv2, ...,
    upsample[3].conv2`` of ``model.dec.texture_decoder`` are processed in that
    order. For each one, the ``unified_pruning_ratio`` fraction of output
    channels with the smallest L1 norm is removed, together with the matching
    slices of the wrapper's ``bias``. The same channels are then removed from
    the inputs of the next deconv, so consecutive layers stay shape-compatible.

    Args:
        model: Model whose decoder is pruned; modified in place.
        unified_pruning_ratio: Fraction of output channels removed per layer.
        prune_last_layer: Also prune the outputs of the final deconv. Defaults
            to False because that layer must keep its 3 output channels.

    Returns:
        The same ``model`` object.
    """
    upsample = model.dec.texture_decoder.upsample
    # Wrapper modules holding each `.deconv` and its `.bias`, in chain order.
    layers = [getattr(upsample[i], name) for i in range(4) for name in ("conv1", "conv2")]
    last = len(layers) - 1

    for k, layer in enumerate(layers):
        if k == last and not prune_last_layer:
            break
        layer.deconv, layer.bias, prune_indices = weight_kernel_pruning_l1_norm(
            layer.deconv, layer.bias, unified_pruning_ratio
        )
        if k < last:
            # The next deconv consumes this layer's outputs: drop the same channels there.
            next_layer = layers[k + 1]
            next_layer.deconv = iAct_channel_pruning_l1_norm(next_layer.deconv, prune_indices)
    return model

unified_pruning_ratio = 0.3
# model_decoder_pruning takes the model as its first argument and prunes it in place.
model_decoder_pruning(model, unified_pruning_ratio)

remaining_indices={0, 1, 3, 4, 6, 7, 8, 9, 10, 11, 13, 15, 17, 18, 20, 22, 23, 24, 26, 27, 29, 30, 31, 32, 33, 34, 35, 37, 38, 40, 42, 44, 45, 47, 48, 52, 53, 54, 55, 56, 57, 59, 62, 66, 68, 69, 70, 71, 72, 73, 74, 77, 78, 80, 83, 84, 86, 88, 89, 90, 91, 95, 96, 97, 98, 100, 101, 102, 104, 105, 108, 109, 110, 112, 113, 115, 116, 118, 119, 121, 122, 123, 124, 125, 126, 127, 128, 131, 133, 134, 135, 136, 137, 138, 140, 142, 143, 145, 147, 148, 149, 150, 151, 152, 154, 157, 158, 159, 162, 163, 164, 165, 166, 168, 169, 170, 174, 175, 176, 177, 178, 179, 180, 181, 182, 185, 186, 187, 188, 189, 191, 192, 193, 195, 197, 198, 199, 202, 203, 204, 205, 206, 207, 208, 209, 210, 212, 213, 214, 216, 217, 219, 220, 222, 223, 224, 226, 227, 228, 229, 231, 232, 234, 236, 237, 238, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 254, 255}
weight kernel pruning original bias dimension = torch.Size([1, 256, 8, 8])
under prune_ratio=0.3, num_channels_to_prune=76, response_val=tensor([145.6904,

In [37]:
# One weight shape per line, in chain order: upsample[0].conv1, upsample[0].conv2, ..., upsample[3].conv2.
print(
    *(
        getattr(model.dec.texture_decoder.upsample[block_idx], conv_name).deconv.weight.shape
        for block_idx in range(4)
        for conv_name in ("conv1", "conv2")
    ),
    sep="\n",
)

# The last layer (upsample[3].conv2) must keep its 3 output channels to match
# the 3-channel images, so its outputs should not be pruned.

torch.Size([256, 180, 4, 4])
torch.Size([180, 45, 4, 4])
torch.Size([45, 45, 4, 4])
torch.Size([45, 23, 4, 4])
torch.Size([23, 23, 4, 4])
torch.Size([23, 12, 4, 4])
torch.Size([12, 12, 4, 4])
torch.Size([12, 3, 4, 4])


In [38]:
# Random stand-in batch of 16 samples for a shape/smoke test of the pruned model.
in_avg_tex = torch.randn(16, 3, 1024, 1024)
in_verts = torch.randn(16, 7306, 3)
in_view = torch.randn(16, 3)
# Camera ids; all lie below the n_cams=38 the model is built with.
in_cams = torch.tensor([22, 37, 14, 19, 7, 11, 31, 2, 20, 20, 14, 21, 9, 6, 10, 5])

print(*(t.shape for t in (in_avg_tex, in_verts, in_view, in_cams)), sep="\n")

torch.Size([16, 3, 1024, 1024])
torch.Size([16, 7306, 3])
torch.Size([16, 3])
torch.Size([16])


In [39]:
# Inference-only forward pass: disable autograd so the activations for this
# 16 x 3 x 1024 x 1024 batch are not kept around for a backward pass.
with torch.no_grad():
    outputs = model(in_avg_tex, in_verts, in_view, in_cams)
outputs

out.shape=torch.Size([16, 180, 8, 8])
self.bias.shape=torch.Size([1, 180, 8, 8])
out.shape=torch.Size([16, 45, 16, 16])
self.bias.shape=torch.Size([1, 45, 16, 16])
out.shape=torch.Size([16, 45, 32, 32])
self.bias.shape=torch.Size([1, 45, 32, 32])
out.shape=torch.Size([16, 23, 64, 64])
self.bias.shape=torch.Size([1, 23, 64, 64])
out.shape=torch.Size([16, 23, 128, 128])
self.bias.shape=torch.Size([1, 23, 128, 128])
out.shape=torch.Size([16, 12, 256, 256])
self.bias.shape=torch.Size([1, 12, 256, 256])
out.shape=torch.Size([16, 12, 512, 512])
self.bias.shape=torch.Size([1, 12, 512, 512])
out.shape=torch.Size([16, 3, 1024, 1024])
self.bias.shape=torch.Size([1, 3, 1024, 1024])


(tensor([[[[ 0.0536,  0.0732,  0.0732,  ...,  0.0738,  0.0738,  0.0606],
           [ 0.0714,  0.1056,  0.1056,  ...,  0.1068,  0.1068,  0.0719],
           [ 0.0714,  0.1056,  0.1056,  ...,  0.1068,  0.1068,  0.0719],
           ...,
           [ 0.0717,  0.1058,  0.1058,  ...,  0.1055,  0.1055,  0.0712],
           [ 0.0717,  0.1058,  0.1058,  ...,  0.1055,  0.1055,  0.0712],
           [ 0.0588,  0.0697,  0.0697,  ...,  0.0693,  0.0693,  0.0516]],
 
          [[-0.1157, -0.1001, -0.1001,  ..., -0.1002, -0.1002, -0.1146],
           [-0.1005, -0.0673, -0.0673,  ..., -0.0675, -0.0675, -0.1003],
           [-0.1005, -0.0673, -0.0673,  ..., -0.0675, -0.0675, -0.1003],
           ...,
           [-0.1007, -0.0676, -0.0676,  ..., -0.0674, -0.0674, -0.1001],
           [-0.1007, -0.0676, -0.0676,  ..., -0.0674, -0.0674, -0.1001],
           [-0.1149, -0.1005, -0.1005,  ..., -0.1003, -0.1003, -0.1154]],
 
          [[-0.0856, -0.0765, -0.0765,  ..., -0.0765, -0.0765, -0.0831],
           [-

# 3. Overall Pruning — Training

In [None]:
from torch.utils.tensorboard import SummaryWriter
import argparse
import json
import os
import time
import glob

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataset import Dataset
from models import DeepAppearanceVAE, WarpFieldVAE, ConvTranspose2dWN
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from utils import Renderer, gammaCorrect

In [None]:
def weight_kernel_pruning_l1_norm(model, in_bias, prune_ratio, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)):
    """Rebuild a transposed-conv layer without its lowest-L1-norm output channels.

    ConvTranspose2d-style weights are laid out as (in_channels, out_channels,
    kH, kW), so each output channel's L1 norm is summed over dims 0, 2 and 3.
    The ``int(prune_ratio * out_channels)`` channels with the smallest norm are
    removed from the weight and from the separately stored ``in_bias``.

    Args:
        model: Transposed-conv layer to prune. It is not modified; a new layer
            of the same class is returned.
        in_bias: Bias for ``model``'s outputs with channels on dim 1 (the
            decoder's bias is 4-D, e.g. (1, C, H, W)), or None.
        prune_ratio: Fraction of output channels to remove, in [0, 1].
        kernel_size, stride, padding: Geometry of the rebuilt layer. The
            defaults are the values this code always used for the decoder.

    Returns:
        ``(new_model, out_bias, prune_indices)``: the rebuilt layer holding the
        surviving signed weights, the pruned bias as a Parameter (or None), and
        the set of removed output-channel indices.

    Raises:
        ValueError: If ``prune_ratio`` is outside [0, 1].
    """
    if not 0 <= prune_ratio <= 1:
        raise ValueError(f"prune_ratio must be in [0, 1], got {prune_ratio}")

    weight = model.weight.data
    num_in, num_out = int(weight.shape[0]), int(weight.shape[1])

    # abs() is only used to rank channels; the kept weights below keep their sign.
    l1_norm = weight.abs().sum(dim=(0, 2, 3))
    num_channels_to_prune = int(prune_ratio * num_out)
    response_val, prune_idx = torch.topk(l1_norm, num_channels_to_prune, largest=False)
    prune_indices = set(prune_idx.tolist())
    remaining_indices = set(range(num_out)) - prune_indices
    # Sorted, so the next layer's inputs are pruned in the same channel order.
    keep = torch.tensor(sorted(remaining_indices), dtype=torch.long, device=weight.device)

    # Rebuild with the input's own class (ConvTranspose2dWN for the decoder layers).
    # NOTE(review): only `weight` is copied over; any other learned state held by
    # ConvTranspose2dWN (e.g. a weight-norm gain) is re-initialised -- confirm in models.py.
    new_model = type(model)(num_in, len(remaining_indices), kernel_size=kernel_size, stride=stride, padding=padding, bias=False).to(weight.device)
    print(f"remaining_indices={remaining_indices}")
    print(f"weight kernel pruning original bias dimension = {None if in_bias is None else in_bias.shape}")
    out_bias = None
    if in_bias is not None:
        out_bias = torch.nn.Parameter(in_bias.detach().index_select(1, keep.to(in_bias.device)))

    new_model.weight = torch.nn.Parameter(weight.index_select(1, keep))
    print(f"under prune_ratio={prune_ratio}, num_channels_to_prune={num_channels_to_prune}, response_val={response_val}, remaining_indices={remaining_indices}, prune_indices={prune_indices}")
    return new_model, out_bias, prune_indices

def iAct_channel_pruning_l1_norm(model, prune_indices, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)):
    """Rebuild a transposed-conv layer without the given input channels.

    ConvTranspose2d-style weights are laid out as (in_channels, out_channels,
    kH, kW), so input channels are removed along dim 0. This keeps the layer
    shape-compatible after the previous layer's outputs were pruned.

    Args:
        model: Transposed-conv layer to prune. It is not modified; a new layer
            of the same class is returned.
        prune_indices: Iterable (set, list or 1-D tensor) of input-channel
            indices to remove.
        kernel_size, stride, padding: Geometry of the rebuilt layer. The
            defaults are the values this code always used for the decoder.

    Returns:
        The rebuilt layer holding the surviving signed weights.
    """
    weight = model.weight.data
    num_in, num_out = int(weight.shape[0]), int(weight.shape[1])

    # int() normalises tensor elements: 0-d tensors hash by identity, so they
    # would never match the plain ints in overall_indices.
    prune_indices = {int(i) for i in prune_indices}
    overall_indices = set(range(num_in))
    remaining_indices = overall_indices - prune_indices
    # Sorted, to match the order in which the previous layer kept its outputs.
    keep = torch.tensor(sorted(remaining_indices), dtype=torch.long, device=weight.device)

    # Rebuild with the input's own class (ConvTranspose2dWN for the decoder layers).
    # NOTE(review): only `weight` is copied over; other learned state is re-initialised.
    new_model = type(model)(len(remaining_indices), num_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False).to(weight.device)
    new_model.weight = torch.nn.Parameter(weight.index_select(0, keep))
    print(f"prune input channel indice={prune_indices}, num_channels_to_prune={len(prune_indices)}, remaining_indices={remaining_indices}, prune_indices={prune_indices}")
    return new_model

In [None]:
# 7306 * 3 == 21918, i.e. the flattened size of the (batch, 7306, 3) vertex
# tensors used with this model in this notebook.
model = DeepAppearanceVAE(
    1024,
    7306 * 3,
    n_latent=256,
    n_cams=38,
)

In [None]:
def model_decoder_pruning(model, unified_pruning_ratio, prune_last_layer=False):
    """Structurally prune the texture decoder's chain of transposed convolutions.

    The eight deconvs ``upsample[0].conv1, upsample[0].conv2, ...,
    upsample[3].conv2`` of ``model.dec.texture_decoder`` are processed in that
    order. For each one, the ``unified_pruning_ratio`` fraction of output
    channels with the smallest L1 norm is removed, together with the matching
    slices of the wrapper's ``bias``. The same channels are then removed from
    the inputs of the next deconv, so consecutive layers stay shape-compatible.

    Args:
        model: Model whose decoder is pruned; modified in place.
        unified_pruning_ratio: Fraction of output channels removed per layer.
        prune_last_layer: Also prune the outputs of the final deconv. Defaults
            to False because that layer must keep its 3 output channels.

    Returns:
        The same ``model`` object.
    """
    upsample = model.dec.texture_decoder.upsample
    # Wrapper modules holding each `.deconv` and its `.bias`, in chain order.
    layers = [getattr(upsample[i], name) for i in range(4) for name in ("conv1", "conv2")]
    last = len(layers) - 1

    for k, layer in enumerate(layers):
        if k == last and not prune_last_layer:
            break
        layer.deconv, layer.bias, prune_indices = weight_kernel_pruning_l1_norm(
            layer.deconv, layer.bias, unified_pruning_ratio
        )
        if k < last:
            # The next deconv consumes this layer's outputs: drop the same channels there.
            next_layer = layers[k + 1]
            next_layer.deconv = iAct_channel_pruning_l1_norm(next_layer.deconv, prune_indices)
    return model

unified_pruning_ratio = 0.3
# model_decoder_pruning takes the model as its first argument and prunes it in place.
model_decoder_pruning(model, unified_pruning_ratio)

In [None]:
# Optimisers, loss and LR schedule for fine-tuning the pruned model.
# NOTE(review): get_model_params() / get_cc_params() are DeepAppearanceVAE methods
# not shown here; presumably two disjoint parameter groups -- confirm in models.py.
lr = 3e-4
optimizer = optim.Adam(model.get_model_params(), lr, (0.9, 0.999))
optimizer_cc = optim.Adam(model.get_cc_params(), lr, (0.9, 0.999))
mse = nn.MSELoss()
# Multiplies optimizer's learning rate by 0.95 on every scheduler.step() call
# (only `optimizer`, not `optimizer_cc`).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
# NOTE(review): in_avg_tex, in_verts, in_view and in_cams are only created (as random
# tensors) in the inference section above; this section does not define them.
pred_tex, pred_verts, unwarped_tex, warp_field, kl = model(in_avg_tex, in_verts, in_view, cams=in_cams)
# Vertex loss compares the predicted vertices against the input vertices.
vert_loss = mse(pred_verts, in_verts)

# NOTE(review): gt_tex is never defined in this notebook, so this line raises NameError.
# TODO: load the ground-truth texture for the batch. The step is also incomplete:
# there is no zero_grad()/backward()/optimizer.step()/scheduler.step(), and kl is unused.
tex_loss = mse(pred_tex, gt_tex) 


