In [None]:
import torch
import torch.nn as nn
from typing import Union, List
from torch.nn.utils import prune
from models import DeepAppearanceVAE, WarpFieldVAE

In [None]:
@torch.no_grad()
def channel_prune_decoder(model: nn.Module,
                          prune_ratio: Union[List[float], float],
                          num_blocks: int = 4) -> nn.Module:
    """Apply L-inf structured (channel) pruning to the decoder's deconv layers.

    The decoder is indexed as ``model[0] .. model[num_blocks - 1]``; each
    block exposes ``conv1.deconv`` and ``conv2.deconv``. Visiting those
    layers in order (block 0 conv1, block 0 conv2, block 1 conv1, ...),
    every consecutive pair (producer, consumer) is pruned with the same
    ratio: the producer along ``dim=0`` and the consumer along ``dim=1``.
    The first layer is never pruned along ``dim=1`` and the last layer is
    never pruned along ``dim=0``.

    Pruning is registered through ``torch.nn.utils.prune`` (a
    ``weight_mask`` buffer plus a forward pre-hook), so the layers are
    modified in place.

    Args:
        model: Indexable decoder (e.g. ``nn.Sequential`` / ``nn.ModuleList``)
            whose first ``num_blocks`` entries have ``conv1.deconv`` and
            ``conv2.deconv`` submodules with a ``weight`` parameter.
        prune_ratio: Either a single value used for every pair (uniform
            pruning), or a list/tuple holding one value per producer layer,
            i.e. ``2 * num_blocks - 1`` entries in visiting order. Each value
            is forwarded as ``amount`` to ``prune.ln_structured`` (a float is
            a fraction of channels; an int is an absolute channel count).
        num_blocks: Number of decoder blocks to prune. Defaults to 4.

    Returns:
        The same ``model`` object, pruned in place.

    Raises:
        ValueError: If ``num_blocks`` < 1, or if ``prune_ratio`` is a
            list/tuple whose length is not ``2 * num_blocks - 1``.
    """
    if num_blocks < 1:
        raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

    # Flatten the deconv layers into visiting order: b0.conv1, b0.conv2, b1.conv1, ...
    layers = [layer
              for i in range(num_blocks)
              for layer in (model[i].conv1.deconv, model[i].conv2.deconv)]
    num_pairs = len(layers) - 1

    if isinstance(prune_ratio, (list, tuple)):
        if len(prune_ratio) != num_pairs:
            raise ValueError(
                f"prune_ratio list must have {num_pairs} entries "
                f"(one per pruned layer), got {len(prune_ratio)}")
        ratios = list(prune_ratio)
    else:
        ratios = [prune_ratio] * num_pairs

    linf = float('inf')
    # NOTE(review): each ln_structured call ranks channels by the L-inf norm of
    # that layer's own weights, so the producer outputs and the consumer inputs
    # that get zeroed are chosen independently and need not correspond.
    # NOTE(review): if `deconv` is nn.ConvTranspose2d, its weight is laid out as
    # (in, out, kH, kW), so dim=0 indexes input channels rather than outputs --
    # confirm the intended axes.
    # The order of calls below matches the original code. It matters because a
    # middle layer is first pruned along dim=1 (as a consumer) and then along
    # dim=0 (as a producer), and prune combines the two masks iteratively.
    for producer, consumer, ratio in zip(layers[:-1], layers[1:], ratios):
        prune.ln_structured(producer, 'weight', amount=ratio, dim=0, n=linf)
        prune.ln_structured(consumer, 'weight', amount=ratio, dim=1, n=linf)

    return model