In [1]:
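# Generate sample outputs for two generators: cyclegan, pix2pix

# Compression settings compared for each generator:
# baseline (dense, no qsparse conversion)
# 50% sparse
# 75% sparse
# 50% sparse 8-bit
# 50% sparse 4-bit
# 75% sparse 8-bit (cyclegan only for now)
# 75% sparse 4-bit (cyclegan only for now)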

In [5]:
# NOTE(review): only the `cyclegan` argument is passed here; the pix2pix data read from
# /home/jovyan/Cityscapes_pix2pix is presumably prepared elsewhere — confirm.
!bash prepare_data.sh cyclegan

In [6]:
import matplotlib.pyplot as plt

In [7]:
import sys
import os
import os.path as osp
import sys
import argparse
import time
from datetime import datetime
import numpy as np
import torch
import gc
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm.auto import tqdm
import jstyleson as json
sys.path.append("./style_transfer")
sys.path.append("./qsparse-private")
sys.path.append(".")
import qsparse

# Detect CUDA once; later cells read `has_gpu` to decide where tensors live.
has_gpu = torch.cuda.is_available()
print(sys.executable, {True: 'gpu yes', False: 'gpu no'}[has_gpu])

/usr/bin/python3.8 gpu yes


In [9]:
from data.unaligned_dataset import UnalignedDataset
from data.aligned_dataset import AlignedDataset
from argparse import Namespace

import numpy as np
from PIL import Image
import os


def tensor2im(input_image, imtype=np.uint8):
    """Convert the first image of a batched tensor into a numpy HWC image array.

    Parameters:
        input_image (torch.Tensor | np.ndarray | other) -- a batched image tensor
            (the first element of dim 0 is used), an existing numpy array, or
            anything else (returned unchanged)
        imtype (type) -- the desired dtype of the returned numpy array

    Returns:
        For tensors: an (H, W, 3) array where values are mapped from [-1, 1] to
        [0, 255] and clipped to that range before casting. Single-channel images
        are tiled to 3 channels.
        For numpy arrays: the array cast to ``imtype``, with no other processing.
        For any other input: ``input_image`` unchanged.
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy image: only the dtype cast is applied.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    image_numpy = input_image.detach()[0].cpu().float().numpy()  # first image of the batch, CHW
    if image_numpy.shape[0] == 1:  # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # Post-processing: transpose CHW -> HWC and rescale [-1, 1] -> [0, 255].
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Clip so out-of-range network outputs saturate instead of wrapping
    # (float -> uint8 casts of out-of-range values are undefined in numpy).
    image_numpy = np.clip(image_numpy, 0.0, 255.0)
    return image_numpy.astype(imtype)


def test_loader(model_type):
    """Build a batch-size-1, unshuffled DataLoader over the evaluation split.

    'cyclegan' reads the unaligned test split under /home/jovyan/cyclegan/ at 128px crops;
    any other value reads the aligned Cityscapes val split at 256px crops.
    """
    # Options shared by both datasets.
    # NOTE(review): no_flip=False, serial_batches=False and 'resize_and_crop' look like
    # training-time augmentation settings; confirm they are intended for evaluation.
    opt = Namespace(
        direction='BtoA',
        max_dataset_size=float('inf'),
        output_nc=3,
        input_nc=3,
        preprocess='resize_and_crop',
        no_flip=False,
        serial_batches=False,
    )

    if model_type == 'cyclegan':
        opt.dataroot = "/home/jovyan/cyclegan/"
        opt.load_size, opt.crop_size = 143, 128
        opt.phase = "test"
        dataset_cls = UnalignedDataset
    else:
        opt.dataroot = "/home/jovyan/Cityscapes_pix2pix/cityscapes"
        opt.load_size, opt.crop_size = 286, 256
        opt.phase = "val"
        dataset_cls = AlignedDataset

    return torch.utils.data.DataLoader(
        dataset_cls(opt),
        batch_size=1, shuffle=False,
        num_workers=4)


from itertools import islice

# Number of test batches (batch_size=1, so images) materialized per model.
N_TEST_IMAGES = 31

# islice stops after N_TEST_IMAGES batches instead of loading the whole dataset
# only to discard every batch past index 30. The loaders are built with
# shuffle=False, so the retained batches are the same first 31.
cyclegan_loader = list(islice(test_loader('cyclegan'), N_TEST_IMAGES))
pix2pix_loader = list(islice(test_loader('pix2pix'), N_TEST_IMAGES))


def create_test_loader(model_type):
    """Return the pre-materialized list of test batches for ``model_type``.

    'cyclegan' selects the CycleGAN batches; any other value falls back to pix2pix.
    """
    return cyclegan_loader if model_type == 'cyclegan' else pix2pix_loader


In [10]:
def ensure_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not already exist.

    ``exist_ok=True`` avoids the check-then-create race of testing ``exists`` first:
    a directory created concurrently no longer raises ``FileExistsError``.
    Raises ``FileExistsError`` if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)


def convert(net, json_path): 
    if json_path is None:
        return net
    conversions = json.load(open(json_path))["qsparse_parameters"]["conversions"]

    from copy import deepcopy
    import qsparse
    from qsparse import prune, convert, MagnitudePruningCallback, UniformPruningCallback, BanditPruningCallback, quantize
    from qsparse.quantize import AdaptiveLineQuantizer
    from qsparse.util import logging
    from qsparse.sparse import PruneLayer
    from qsparse.quantize import QuantizeLayer
    
    net = deepcopy(net)
    p_net = deepcopy(net)

    def to_step(x): return 0
    def layer_names_to_modules(names): return [getattr(nn, layer_name) for layer_name in names]
    def layer_names_to_modules_2rd(indexes): return [(getattr(nn, layer_name), indexes) for layer_name, indexes in indexes]

    for param in conversions:
        callback_kwargs = {}
        if "outlier_ratio" in param:
            callback_kwargs["outlier_ratio"] = param["outlier_ratio"]
        if "line" in param["callback"].lower():
            logging.danger("always use running average")
            callback_kwargs["always_running_average"] = True

        if "spa" in param:
            callback_kwargs["spa"] = True

        if "group_num" in param:
            callback_kwargs["group_num"] = param["group_num"]


        p_net = qsparse.convert(p_net, quantize(
                    bits=param["bits"], channelwise=param["channelwise"], timeout=10,
                    callback=getattr(qsparse, param["callback"])(**callback_kwargs)
                ) if param["op"] == "quantize" else prune(
                    sparsity=param["sparsity"],
                    start=0,
                    rampup=param.get("rampup", False),
                    repetition=param.get("repetition", 1),
                    callback=qsparse.MagnitudePruningCallback(mask_refresh_interval=1, use_gradient=param.get("use_gradient", False), running_average=param.get("running_average", True))
                                if param["callback"] == 'MagnitudePruningCallback' else (
                                    qsparse.UniformPruningCallback()
                                    if param["callback"] == "UniformPruningCallback" else
                                    qsparse.BanditPruningCallback(mask_refresh_interval=-1))
                    ),
                    weight_layers=layer_names_to_modules(param.get("weight_layers", [])),
                    activation_layers=layer_names_to_modules(param.get("activation_layers", [])),
                    excluded_activation_layer_indexes=layer_names_to_modules_2rd(param.get("excluded_activation_layer_indexes", [])),
                    excluded_weight_layer_indexes=layer_names_to_modules_2rd(param.get("excluded_weight_layer_indexes", [])), filter=param.get("filter", []), input=param.get("input", False), order=param.get('order', 'post'))
    return p_net

In [17]:
def fix_shapes(model, state_dict):
    """Reconcile ``state_dict`` in place with the keys and shapes of ``model.state_dict()``.

    - Keys unknown to the model are removed.
    - Tensors whose shape differs from the model's are replaced by the model's
      current tensor (a message is printed for each).
    - Model keys ending in ".magnitude" that are missing are copied from the model.
    Returns None.
    """
    reference = model.state_dict()

    for key, tensor in list(state_dict.items()):
        if key not in reference:
            del state_dict[key]
            continue
        expected = reference[key]
        if tensor.shape != expected.shape:
            print(f"{key} shape mismatch: {tensor.shape} vs {expected.shape}")
            state_dict[key] = expected

    missing_magnitudes = [
        key for key in reference
        if key.endswith(".magnitude") and key not in state_dict
    ]
    for key in missing_magnitudes:
        state_dict[key] = reference[key]

In [11]:
import models.networks as networks

def create_net(t):
    """Build an untrained generator for network type ``t``.

    Parameters:
        t (str) -- 'cyclegan' ("resnet_6blocks" generator, instance norm, no dropout)
                   or 'pix2pix' ("unet_256" generator, batch norm, dropout).

    Raises:
        ValueError -- if ``t`` is not a supported network type.
    """
    if t == 'cyclegan':
        return networks.define_G(3, 3, 64, "resnet_6blocks", norm='instance', use_dropout=False)
    if t == 'pix2pix':
        return networks.define_G(3, 3, 64, "unet_256", norm='batch', use_dropout=True)
    # `raise None` itself fails with an unrelated TypeError; report the real problem.
    raise ValueError(f"unknown network type: {t!r}")
        
def get_shape(t):
    """Return the (N, C, H, W) dummy-input shape for network type ``t``.

    CycleGAN uses 128x128 crops; every other type uses 256x256.
    """
    side = 128 if t == 'cyclegan' else 256
    return (1, 3, side, side)

In [12]:
# Experiments per network: name -> (qsparse conversion config or None, checkpoint path).
# A None config means the checkpoint is loaded into the unconverted network.
configurations = {
    "cyclegan": {
        "baseline": (None,
                     "./checkpoints/citysgan/res/baseline/cityscape_cyclegan/latest_net_G_A.pth"),
        "p50": ("./configs/citysgan/res/a50_mag_top_loc_st.json",
                "./checkpoints/citysgan/res/a50_mag_top_loc_st/cityscape_cyclegan/latest_net_G_A.pth"),
        "p75": ("./configs/citysgan/res/a75_mag_top_loc_st.json",
                "./checkpoints/citysgan/res/a75_mag_top_loc_st/cityscape_cyclegan/latest_net_G_A.pth"),
        "p50q8": ("./configs/citysgan/res/a50_wa8_t_st.json",
                  "./checkpoints/citysgan/res/joint_pq/a50_wa8_t_st.G_A.pth"),
        "p50q4": ("./configs/citysgan/res/a50_wa4_t_st.json",
                  "./checkpoints/citysgan/res/joint_pq/a50_wa4_t_st.G_A.pth"),
        "p75q8": ("./configs/citysgan/res/a75_wa8_t_st.json",
                  "./checkpoints/citysgan/res/joint_pq/a75_wa8_t_st.G_A.pth"),
        "p75q4": ("./configs/citysgan/res/a75_wa4_t_st.json",
                  "./checkpoints/citysgan/res/joint_pq/a75_wa4_t_st.G_A.pth"),
    },
    "pix2pix": {
        "baseline": (None,
                     "./checkpoints/citysgan/unet/baseline/cityscapes_pix2pix/latest_net_G.pth"),
        "p50": ("./configs/citysgan/unet/a50_mag_top_loc_st.json",
                "./checkpoints/citysgan/unet/a50_mag_top_loc_st/cityscapes_pix2pix/latest_net_G.pth"),
        "p75": ("./configs/citysgan/unet/a75_mag_top_loc_st.json",
                "./checkpoints/citysgan/unet/a75_mag_top_loc_st/cityscapes_pix2pix/latest_net_G.pth"),
        "p50q8": ("./configs/citysgan/unet/a50_wa8_t_st.json",
                  "./checkpoints/citysgan/unet/joint_pq/a50_wa8_t_st.pth"),
        "p50q4": ("./configs/citysgan/unet/a50_wa4_t_st.json",
                  "./checkpoints/citysgan/unet/joint_pq/a50_wa4_t_st.pth"),
        # TODO: p75q8 ("./configs/citysgan/unet/a75_wa8_t_st.json") needs retraining —
        #       the previous checkpoint used the wrong pretrained model.
        # TODO: p75q4 ("./configs/citysgan/unet/a75_wa4_t_st.json") has no checkpoint yet.
    },
}

In [18]:
root = "./images"

# Pick the device once and use it for both the net and its inputs. The original
# called net.cuda() unconditionally but moved inputs only when has_gpu, which
# crashes on CPU-only hosts.
device = torch.device('cuda' if has_gpu else 'cpu')


def _load_net(net_name, json_path, params_path):
    """Build ``net_name``, apply the qsparse config, load the checkpoint; return it in eval mode on ``device``."""
    net = convert(create_net(net_name), json_path)
    # Dummy forward pass before loading weights, kept from the original.
    # NOTE(review): presumably needed so converted qsparse layers create their state — confirm.
    net(torch.rand(*get_shape(net_name)))
    states = torch.load(params_path, map_location=torch.device('cpu'))
    fix_shapes(net, states)
    net.load_state_dict(states)
    net.eval()
    return net.to(device)


def _save_predictions(net, loader, save_path):
    """Save each batch's "B" input and the net's output as ``{i}-in.jpg`` / ``{i}-out.jpg``."""
    with torch.no_grad():
        for i, batch in enumerate(loader):
            inputs = batch["B"]
            Image.fromarray(tensor2im(inputs)).save(osp.join(save_path, f"{i}-in.jpg"))
            outputs = net(inputs.to(device))
            Image.fromarray(tensor2im(outputs)).save(osp.join(save_path, f"{i}-out.jpg"))


# Derive the total from the config instead of a hard-coded 14 (there are 12 experiments).
n_experiments = sum(len(experiments) for experiments in configurations.values())
with tqdm(total=n_experiments) as bar:
    for net_name, experiments in configurations.items():
        for exp_name, (json_path, params_path) in experiments.items():
            save_path = osp.join(root, net_name, exp_name)
            ensure_dir(save_path)
            print(save_path)

            if params_path:
                net = _load_net(net_name, json_path, params_path)
                _save_predictions(net, create_test_loader(net_name), save_path)
            # Advance for skipped experiments too, so the bar always reaches its total.
            bar.update()

/workspace/code/experiments/MDPI/analysis/generate_images/cyclegan/baseline
initialize network with normal
/workspace/code/experiments/MDPI/analysis/generate_images/cyclegan/p50
initialize network with normal
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.1 weight
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.4 weight
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.7 weight
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.10.conv_block.1 weight
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.10.conv_block.5 weight
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.11.conv_block.1 weight
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.11.conv_block.5 weight
[Quantize] bits=0 channelwise=0 timeout=10
Apply `quantizebits=0` on the .model.12.conv_block.1 weight
[Quantize