In [1]:
# copied from train_interpreter.py

import imageio
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152

import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
import scipy.misc
import json
from collections import OrderedDict
import numpy as np
import os
device_ids = [0]
from PIL import Image
import gc

import pickle
import copy
from numpy.random import choice
from torch.distributions import Categorical
import scipy.stats
import torch.optim as optim
import argparse
import glob
from torch.utils.data import Dataset, DataLoader
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import cv2

In [2]:
!unzip model_dir_17imgs.zip

Archive:  model_dir_17imgs.zip
   creating: model_dir/
   creating: model_dir/cat_16/
  inflating: model_dir/cat_16/model_20parts_iter45000_number_4.pth  
  inflating: model_dir/cat_16/model_20parts_iter10000_number_9.pth  
  inflating: model_dir/cat_16/model_20parts_iter5000_number_5.pth  
  inflating: model_dir/cat_16/model_20parts_iter30000_number_0.pth  
  inflating: model_dir/cat_16/model_9.pth  
  inflating: model_dir/cat_16/model_20parts_iter50000_number_0.pth  
  inflating: model_dir/cat_16/model_20parts_iter25000_number_0.pth  
  inflating: model_dir/cat_16/model_20parts_iter15000_number_9.pth  
  inflating: model_dir/cat_16/model_20parts_iter15000_number_5.pth  
  inflating: model_dir/cat_16/model_20parts_iter15000_number_0.pth  
  inflating: model_dir/cat_16/model_20parts_iter25000_number_5.pth  
  inflating: model_dir/cat_16/model_20parts_iter50000_number_9.pth  
  inflating: model_dir/cat_16/model_20parts_iter60000_number_6.pth  
  inflating: model_dir/cat_16/model_20parts

## Utils

In [2]:
class Interpolate(nn.Module):
    def __init__(self, size, mode):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, size=self.size, mode=self.mode, align_corners=False)
        return x



def multi_acc(y_pred, y_test):
    y_pred_softmax = torch.log_softmax(y_pred, dim=1)
    _, y_pred_tags = torch.max(y_pred_softmax, dim=1)

    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)

    acc = acc * 100

    return acc


def oht_to_scalar(y_pred):
    y_pred_softmax = torch.log_softmax(y_pred, dim=1)
    _, y_pred_tags = torch.max(y_pred_softmax, dim=1)

    return y_pred_tags

def latent_to_image(g_all, upsamplers, latents, return_upsampled_layers=False, use_style_latents=False,
                    style_latents=None, process_out=True, return_stylegan_latent=False, dim=512, return_only_im=False):
    '''Given a input latent code, generate corresponding image and concatenated feature maps'''

    # assert (len(latents) == 1)  # for GPU memory constraints
    if not use_style_latents:
        # generate style_latents from latents
        style_latents = g_all.module.truncation(g_all.module.g_mapping(latents))
        style_latents = style_latents.clone()  # make different layers non-alias

    else:
        style_latents = latents

        # style_latents = latents
    if return_stylegan_latent:

        return  style_latents
    img_list, affine_layers = g_all.module.g_synthesis(style_latents)

    if return_only_im:
        if process_out:
            if img_list.shape[-2] > 512:
                img_list = upsamplers[-1](img_list)

            img_list = img_list.cpu().detach().numpy()
            img_list = process_image(img_list)
            img_list = np.transpose(img_list, (0, 2, 3, 1)).astype(np.uint8)
        return img_list, style_latents

    number_feautre = 0

    for item in affine_layers:
        number_feautre += item.shape[1]


    affine_layers_upsamples = torch.FloatTensor(1, number_feautre, dim, dim)
    if torch.cuda.is_available():
        affine_layers_upsamples = affine_layers_upsamples.cuda()
    if return_upsampled_layers:

        start_channel_index = 0
        for i in range(len(affine_layers)):
            len_channel = affine_layers[i].shape[1]
            affine_layers_upsamples[:, start_channel_index:start_channel_index + len_channel] = upsamplers[i](
                affine_layers[i])
            start_channel_index += len_channel

    if img_list.shape[-2] != 512:
        img_list = upsamplers[-1](img_list)

    if process_out:
        img_list = img_list.cpu().detach().numpy()
        img_list = process_image(img_list)
        img_list = np.transpose(img_list, (0, 2, 3, 1)).astype(np.uint8)
        # print('start_channel_index',start_channel_index)


    return img_list, affine_layers_upsamples


def process_image(images):
    drange = [-1, 1]
    scale = 255 / (drange[1] - drange[0])
    images = images * scale + (0.5 - drange[0] * scale)

    images = images.astype(int)
    images[images > 255] = 255
    images[images < 0] = 0

    return images.astype(int)

def colorize_mask(mask, palette):
    # mask: numpy array of the mask

    new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
    new_mask.putpalette(palette)
    return np.array(new_mask.convert('RGB'))


def get_label_stas(data_loader):
    count_dict = {}
    for i in range(data_loader.__len__()):
        x, y = data_loader.__getitem__(i)
        if int(y.item()) not in count_dict:
            count_dict[int(y.item())] = 1
        else:
            count_dict[int(y.item())] += 1

    return count_dict


## StyleGAN nets

In [3]:
class MyLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_size, output_size, gain=2 ** (0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size ** (-0.5)  # He init
        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_size))
            self.b_mul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, bias)


class MyConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier."""

    def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2 ** (0.5), use_wscale=False,
                 lrmul=1, bias=True,
                 intermediate=None, upscale=False, downscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        if downscale:
            self.downscale = Downscale2d()
        else:
            self.downscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5)  # He init
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        self.weight = torch.nn.Parameter(
            torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul

        have_convolution = False
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
            # this really needs to be cleaned up and go into the conv...
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
            w = F.pad(w, (1, 1, 1, 1))
            w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
        elif self.upscale is not None:
            x = self.upscale(x)

        downscale = self.downscale
        intermediate = self.intermediate
        if downscale is not None and min(x.shape[2:]) >= 128:
            w = self.weight * self.w_mul
            w = F.pad(w, (1, 1, 1, 1))
            # in contrast to upscale, this is a mean...
            w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25  # avg_pool?
            x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
            have_convolution = True
            downscale = None
        elif downscale is not None:
            assert intermediate is None
            intermediate = downscale

        if not have_convolution and intermediate is None:
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)

        if intermediate is not None:
            x = intermediate(x)

        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x


class NoiseLayer(nn.Module):
    """adds noise. noise is per pixel (constant over channels) with per-channel weight"""

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        self.noise = None

    def forward(self, x, noise=None):
        if noise is None and self.noise is None:
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        elif noise is None:
            # here is a little trick: if you get all the noiselayers and set each
            # modules .noise attribute, you can have pre-defined noise.
            # Very useful for analysis
            noise = self.noise
        x = x + self.weight.view(1, -1, 1, 1) * noise
        return x


class StyleMod(nn.Module):
    def __init__(self, latent_size, channels, use_wscale):
        super(StyleMod, self).__init__()
        self.lin = MyLinear(latent_size,
                            channels * 2,
                            gain=1.0, use_wscale=use_wscale)
        self.x_param_backup = None

    def forward(self, x, latent, latent_after_trans=None):
        if x is not None:
            if latent_after_trans is None:
                style = self.lin(latent)  # style => [batch_size, n_channels*2]
                shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
                style = style.view(shape)  # [batch_size, 2, n_channels, ...]
            else:
                style = latent_after_trans

            self.x_param_backup = [x.size(1), x.dim()]
            x = x * (style[:, 0] + 1.) + style[:, 1]
            return x

        else:
            if self.x_param_backup is None:
                print('error: have intialize shape yet')
            # print('Generating latent_after_trans:')
            style = self.lin(latent)  # style => [batch_size, n_channels*2]
            shape = [-1, 2, self.x_param_backup[0]] + (self.x_param_backup[1] - 2) * [1]
            style = style.view(shape)  # [batch_size, 2, n_channels, ...]
            return style


class PixelNormLayer(nn.Module):
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)


# Upscale and blur layers


class BlurLayer(nn.Module):
    def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        kernel = torch.tensor(kernel, dtype=torch.float32)
        kernel = kernel[:, None] * kernel[None, :]
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            kernel = kernel[:, :, ::-1, ::-1]
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        # expand kernel channels
        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
        x = F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=int((self.kernel.size(2) - 1) / 2),
            groups=x.size(1)
        )
        return x


def upscale2d(x, factor=2, gain=1):
    assert x.dim() == 4
    if gain != 1:
        x = x * gain
    if factor != 1:
        shape = x.shape
        x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
        x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
    return x


class Upscale2d(nn.Module):
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain = gain
        self.factor = factor

    def forward(self, x):
        return upscale2d(x, factor=self.factor, gain=self.gain)


class G_mapping(nn.Sequential):
    def __init__(self, nonlinearity='lrelu', use_wscale=True):
        act, gain = {'relu': (torch.relu, np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        layers = [
            ('pixel_norm', PixelNormLayer()),
            ('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense0_act', act),
            ('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense1_act', act),
            ('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense2_act', act),
            ('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense3_act', act),
            ('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense4_act', act),
            ('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense5_act', act),
            ('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense6_act', act),
            ('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
            ('dense7_act', act)
        ]
        super().__init__(OrderedDict(layers))



    def make_mean_latent(self, n_latent):
        latent_in = torch.randn(
            n_latent, 512
        ).cuda()
        mean_latent = super().forward(latent_in).mean(0, keepdim=True)
        mean_latent = mean_latent.unsqueeze(1).expand(-1, 18, -1)
        return mean_latent

    def forward(self, x):
        x = super().forward(x)
        # Broadcast
        x = x.unsqueeze(1).expand(-1, 18, -1)
        return x


class Truncation(nn.Module):
    def __init__(self, avg_latent, device, max_layer=8, threshold=0.7):
        super().__init__()
        self.max_layer = max_layer
        self.threshold = threshold
        self.avg_latent = avg_latent
        self.device = device
        # self.register_buffer('avg_latent', avg_latent)

    def forward(self, x):
        assert x.dim() == 3
        interp = torch.lerp(self.avg_latent, x, self.threshold)
        do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1).to(self.device)
        return torch.where(do_trunc, interp, x)


class LayerEpilogue(nn.Module):
    """Things to do at the end of each layer."""

    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles,
                 activation_layer):
        super().__init__()
        layers = []
        if use_noise:
            layers.append(('noise', NoiseLayer(channels)))
        layers.append(('activation', activation_layer))
        if use_pixel_norm:
            layers.append(('pixel_norm', PixelNorm()))
        if use_instance_norm:
            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
        self.top_epi = nn.Sequential(OrderedDict(layers))
        if use_styles:
            self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
        else:
            self.style_mod = None

    def forward(self, x, dlatents_in_slice=None, latent_after_trans=None):
        x = self.top_epi(x)
        if self.style_mod is not None:
            if latent_after_trans is None:
                x = self.style_mod(x, dlatents_in_slice)
            else:
                x = self.style_mod(x, dlatents_in_slice, latent_after_trans)
        else:
            assert dlatents_in_slice is None
        return x


class InputBlock(nn.Module):
    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm,
                 use_instance_norm, use_styles, activation_layer):
        super().__init__()
        self.const_input_layer = const_input_layer
        self.nf = nf
        if self.const_input_layer:
            # called 'const' in tf
            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
            self.bias = nn.Parameter(torch.ones(nf))
        else:
            self.dense = MyLinear(dlatent_size, nf * 16, gain=gain / 4,
                                  use_wscale=use_wscale)  # tweak gain to match the official implementation of Progressing GAN
        self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
                                  use_styles, activation_layer)
        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
                                  use_styles, activation_layer)

    def forward(self, dlatents_in_range, latent_after_trans=None):
        batch_size = dlatents_in_range.size(0)
        if self.const_input_layer:
            x = self.const.expand(batch_size, -1, -1, -1)
            x = x + self.bias.view(1, -1, 1, 1)
        else:
            x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)

        if latent_after_trans is None:
            x = self.epi1(x, dlatents_in_range[:, 0])
        else:
            x = self.epi1(x, dlatents_in_range[:, 0], latent_after_trans[0])  # latent_after_trans is a list

        x = self.conv(x)

        if latent_after_trans is None:
            x1 = self.epi2(x, dlatents_in_range[:, 1])
        else:
            x1 = self.epi2(x, dlatents_in_range[:, 1], latent_after_trans[1])

        return x1, x


class GSynthesisBlock(nn.Module):
    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise,
                 use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        # 2**res x 2**res # res = 3..resolution_log2
        super().__init__()
        if blur_filter:
            blur = BlurLayer(blur_filter)
        else:
            blur = None
        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
                                 intermediate=blur, upscale=True)
        self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
                                  use_styles, activation_layer)
        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
                                  use_styles, activation_layer)

    def forward(self, x, dlatents_in_range, latent_after_trans=None):
        x = self.conv0_up(x)

        if latent_after_trans is None:
            x = self.epi1(x, dlatents_in_range[:, 0])
        else:
            x = self.epi1(x, dlatents_in_range[:, 0], latent_after_trans[0])  # latent_after_trans is a list
        x = self.conv1(x)

        if latent_after_trans is None:
            x1 = self.epi2(x, dlatents_in_range[:, 1])
        else:
            x1 = self.epi2(x, dlatents_in_range[:, 1], latent_after_trans[1])
        return x1, x


class SegSynthesisBlock(nn.Module):
    def __init__(self, prev_channel, current_channel, single_in=False):
        super().__init__()
        self.single_in = single_in
        # self.in_conv = nn.Sequential(
        #     nn.ReLU(),
        #     nn.Conv2d(current_channel, current_channel, 3, 1, 1),
        #     nn.BatchNorm2d(current_channel),
        #     nn.ReLU(),
        #     nn.Conv2d(current_channel, current_channel, 1),
        #     nn.BatchNorm2d(current_channel)
        # )

        if not single_in:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear")

            self.out_conv1 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(current_channel + prev_channel, current_channel, 1, 1, 0),
                nn.BatchNorm2d(current_channel)
            )

        self.out_conv2 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(current_channel + current_channel, current_channel, 1, 1, 0),
            nn.BatchNorm2d(current_channel)
        )

    def forward(self, x_curr, x_curr2, x_prev=None):

        # x_curr = self.in_conv(x_curr)

        if self.single_in:
            x_middle = x_curr
        else:
            x_prev = self.up(x_prev)
            x_concat = torch.cat([x_curr, x_prev], 1)

            x_middle = self.out_conv1(x_concat)

            x_middle = x_middle + x_curr

        x_concat2 = torch.cat([x_curr2, x_middle], 1)
        x_out = self.out_conv2(x_concat2)
        x_out = x_out + x_curr2
        return x_out


class G_synthesis(nn.Module):
    def __init__(self,
                 dlatent_size=512,  # Disentangled latent (W) dimensionality.
                 num_channels=3,  # Number of output color channels.
                 resolution=512,  # Output resolution.
                 fmap_base=8192,  # Overall multiplier for the number of feature maps.
                 fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
                 fmap_max=512,  # Maximum number of feature maps in any layer.
                 use_styles=True,  # Enable style inputs?
                 const_input_layer=True,  # First layer is a learned constant?
                 use_noise=True,  # Enable noise inputs?
                 randomize_noise=True,
                 # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
                 nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu'
                 use_wscale=True,  # Enable equalized learning rate?
                 use_pixel_norm=False,  # Enable pixelwise feature vector normalization?
                 use_instance_norm=True,  # Enable instance normalization?
                 dtype=torch.float32,  # Data type to use for activations and outputs.
                 fused_scale='auto',
                 # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically.
                 blur_filter=[1, 2, 1],  # Low-pass filter to apply when resampling activations. None = no filtering.
                 structure='auto',
                 # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
                 is_template_graph=False,
                 # True = template graph constructed by the Network class, False = actual evaluation.
                 force_clean_graph=False,
                 # True = construct a clean graph that looks nice in TensorBoard, False = default behavior.
                 seg_branch=False
                 ):

        super().__init__()

        def nf(stage):
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.dlatent_size = dlatent_size
        self.seg_branch = seg_branch
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2 ** resolution_log2 and resolution >= 4
        if is_template_graph: force_clean_graph = True
        if force_clean_graph: randomize_noise = False
        if structure == 'auto': structure = 'linear' if force_clean_graph else 'recursive'

        act, gain = {'relu': (torch.relu, np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        num_layers = resolution_log2 * 2 - 2
        num_styles = num_layers if use_styles else 1
        torgbs = []
        blocks = []
        if self.seg_branch:
            seg_block = []
        for res in range(2, resolution_log2 + 1):

            channels = nf(res - 1)
            name = '{s}x{s}'.format(s=2 ** res)
            if res == 2:
                blocks.append((name,
                               InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
                                          use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))



            else:
                blocks.append((name,
                               GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale,
                                               use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))

                if self.seg_branch:

                    name = '{s}x{s}_seg'.format(s=2 ** res)

                    if len(seg_block) == 0:
                        seg_block.append((name,
                                          SegSynthesisBlock(last_channels, channels, single_in=True)))
                    else:
                        seg_block.append((name,
                                          SegSynthesisBlock(last_channels, channels)))

            last_channels = channels
        self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)

        self.blocks = nn.ModuleDict(OrderedDict(blocks))
        if self.seg_branch:
            seg_block.append(("seg_out", nn.Conv2d(channels, 34, 1)))
            self.seg_block = nn.ModuleDict(OrderedDict(seg_block))

    def forward(self, dlatents_in, latent_after_trans=None):
        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
        # lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)
        batch_size = dlatents_in.size(0)
        result_list = []

        if self.seg_branch:
            seg_branch_feature = None
        for i, m in enumerate(self.blocks.values()):
            if i == 0:
                if latent_after_trans is None:
                    x, x2 = m(dlatents_in[:, 2 * i:2 * i + 2])
                else:
                    x, x2 = m(dlatents_in[:, 2 * i:2 * i + 2], latent_after_trans[2 * i:2 * i + 2])
            else:

                if latent_after_trans is None:
                    x, x2 = m(x, dlatents_in[:, 2 * i:2 * i + 2])
                else:
                    x, x2 = m(x, dlatents_in[:, 2 * i:2 * i + 2],
                              latent_after_trans[2 * i:2 * i + 2])  # latent_after_trans is a tensor list

                if self.seg_branch:

                    name = '{s}x{s}_seg'.format(s=2 ** (i + 2))

                    curr_seg_block = self.seg_block[name]
                    if seg_branch_feature is None:
                        seg_branch_feature = curr_seg_block(x2, x)
                    else:
                        seg_branch_feature = curr_seg_block(x2, x, x_prev=seg_branch_feature)

            result_list.append(x)
            result_list.append(x2)
        rgb = self.torgb(x)
        if self.seg_branch:
            seg = self.seg_block["seg_out"](seg_branch_feature)
            return rgb, seg, result_list
        return rgb, result_list


#### define discriminator

class StddevLayer(nn.Module):
    def __init__(self, group_size=4, num_new_features=1):
        super().__init__()
        self.group_size = 4
        self.num_new_features = 1

    def forward(self, x):
        b, c, h, w = x.shape
        group_size = min(self.group_size, b)
        y = x.reshape([group_size, -1, self.num_new_features,
                       c // self.num_new_features, h, w])
        y = y - y.mean(0, keepdim=True)
        y = (y ** 2).mean(0, keepdim=True)
        y = (y + 1e-8) ** 0.5
        y = y.mean([3, 4, 5], keepdim=True).squeeze(3)  # don't keep the meaned-out channels
        y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w)
        z = torch.cat([x, y], dim=1)
        return z


class Downscale2d(nn.Module):
    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.factor = factor
        self.gain = gain
        if factor == 2:
            f = [np.sqrt(gain) / factor] * factor
            self.blur = BlurLayer(kernel=f, normalize=False, stride=factor)
        else:
            self.blur = None

    def forward(self, x):
        assert x.dim() == 4
        # 2x2, float32 => downscale using _blur2d().
        if self.blur is not None and x.dtype == torch.float32:
            return self.blur(x)

        # Apply gain.
        if self.gain != 1:
            x = x * self.gain

        # No-op => early exit.
        if factor == 1:
            return x

        # Large factor => downscale using tf.nn.avg_pool().
        # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
        return F.avg_pool2d(x, self.factor)


class DiscriminatorBlock(nn.Sequential):
    def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer):
        super().__init__(OrderedDict([
            ('conv0', MyConv2d(in_channels, in_channels, 3, gain=gain, use_wscale=use_wscale)),
            # out channels nf(res-1)
            ('act0', activation_layer),
            ('blur', BlurLayer()),
            ('conv1_down', MyConv2d(in_channels, out_channels, 3, gain=gain, use_wscale=use_wscale, downscale=True)),
            ('act1', activation_layer)]))


class View(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(x.size(0), *self.shape)


class DiscriminatorTop(nn.Sequential):
    def __init__(self, mbstd_group_size, mbstd_num_features, in_channels, intermediate_channels, gain, use_wscale,
                 activation_layer, resolution=4, in_channels2=None, output_features=1, last_gain=1):
        layers = []
        if mbstd_group_size > 1:
            layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features)))
        if in_channels2 is None:
            in_channels2 = in_channels
        layers.append(
            ('conv', MyConv2d(in_channels + mbstd_num_features, in_channels2, 3, gain=gain, use_wscale=use_wscale)))
        layers.append(('act0', activation_layer))
        layers.append(('view', View(-1)))
        layers.append(('dense0', MyLinear(in_channels2 * resolution * resolution, intermediate_channels, gain=gain,
                                          use_wscale=use_wscale)))
        layers.append(('act1', activation_layer))
        layers.append(
            ('dense1', MyLinear(intermediate_channels, output_features, gain=last_gain, use_wscale=use_wscale)))
        super().__init__(OrderedDict(layers))


class D_basic(nn.Sequential):

    def __init__(self,
                 # images_in,                          # First input: Images [minibatch, channel, height, width].
                 # labels_in,                          # Second input: Labels [minibatch, label_size].
                 num_channels=3,  # Number of input color channels. Overridden based on dataloader.
                 resolution=512,  # Input resolution. Overridden based on dataloader.
                 fmap_base=8192,  # Overall multiplier for the number of feature maps.
                 fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
                 fmap_max=512,  # Maximum number of feature maps in any layer.
                 nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu',
                 use_wscale=True,  # Enable equalized learning rate?
                 mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, 0 = disable.
                 mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
                 # blur_filter         = [1,2,1],      # Low-pass filter to apply when resampling activations. None = no filtering.
                 ):
        self.mbstd_group_size = 4
        self.mbstd_num_features = 1
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2 ** resolution_log2 and resolution >= 4

        def nf(stage):
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        act, gain = {'relu': (torch.relu, np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        self.gain = gain
        self.use_wscale = use_wscale
        super().__init__(OrderedDict([
                                         ('fromrgb', MyConv2d(num_channels, nf(resolution_log2 - 1), 1, gain=gain,
                                                              use_wscale=use_wscale)),
                                         ('act', act)]
                                     + [('{s}x{s}'.format(s=2 ** res),
                                         DiscriminatorBlock(nf(res - 1), nf(res - 2), gain=gain, use_wscale=use_wscale,
                                                            activation_layer=act)) for res in
                                        range(resolution_log2, 2, -1)]
                                     + [('4x4',
                                         DiscriminatorTop(mbstd_group_size, mbstd_num_features, nf(2), nf(2), gain=gain,
                                                          use_wscale=use_wscale, activation_layer=act))]))

## Dataset

In [4]:
class trainData(Dataset):
    def __init__(self, X_data, y_data):
        print('initializing train data')
        self.X_data = X_data
        print('x_data set')
        self.y_data = y_data
        print('y_data set')
        print('trainData initialized')

    def __getitem__(self, index):
        return self.X_data[index], self.y_data[index]

    def __len__(self):
        return len(self.X_data)


class pixel_classifier(nn.Module):
    def __init__(self, numpy_class, dim):
        super(pixel_classifier, self).__init__()
        if numpy_class < 32:
            self.layers = nn.Sequential(
                nn.Linear(dim, 128),
                nn.ReLU(),
                nn.BatchNorm1d(num_features=128),
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.BatchNorm1d(num_features=32),
                nn.Linear(32, numpy_class),
                # nn.Sigmoid()
            )
        else:
            self.layers = nn.Sequential(
                nn.Linear(dim, 256),
                nn.ReLU(),
                nn.BatchNorm1d(num_features=256),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.BatchNorm1d(num_features=128),
                nn.Linear(128, numpy_class),
                # nn.Sigmoid()
            )

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/
        9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            classname = m.__class__.__name__
            if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)

                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

            elif classname.find('BatchNorm2d') != -1:
                nn.init.normal_(m.weight.data, 1.0, gain)
                nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

    def forward(self, x):
        return self.layers(x)

## prepare_stylegan

In [5]:
def prepare_stylegan(args):
    print('preparing stylegan...')
    if args['stylegan_ver'] == "1":
        if args['category'] == "car":
            resolution = 512
            max_layer = 8
        elif  args['category'] == "face":
            resolution = 1024
            max_layer = 8
        elif args['category'] == "bedroom":
            resolution = 256
            max_layer = 7
        elif args['category'] == "cat":
            resolution = 256
            max_layer = 7
        else:
            assert "Not implementated!"

        avg_latent = np.load(args['average_latent'])
        avg_latent = torch.from_numpy(avg_latent).type(torch.FloatTensor).to(device)

        g_all = nn.Sequential(OrderedDict([
            ('g_mapping', G_mapping()),
            ('truncation', Truncation(avg_latent,max_layer=max_layer, device=device, threshold=0.7)),
            ('g_synthesis', G_synthesis( resolution=resolution))
        ]))

        g_all.load_state_dict(torch.load(args['stylegan_checkpoint'], map_location=device))
        g_all.eval()
        g_all = nn.DataParallel(g_all, device_ids=device_ids)
        if torch.cuda.is_available():
            g_all = g_all.cuda()

    else:
        assert "Not implementated error"

    res  = args['dim'][1]
    mode = args['upsample_mode']
    upsamplers = [nn.Upsample(scale_factor=res / 4, mode=mode),
                  nn.Upsample(scale_factor=res / 4, mode=mode),
                  nn.Upsample(scale_factor=res / 8, mode=mode),
                  nn.Upsample(scale_factor=res / 8, mode=mode),
                  nn.Upsample(scale_factor=res / 16, mode=mode),
                  nn.Upsample(scale_factor=res / 16, mode=mode),
                  nn.Upsample(scale_factor=res / 32, mode=mode),
                  nn.Upsample(scale_factor=res / 32, mode=mode),
                  nn.Upsample(scale_factor=res / 64, mode=mode),
                  nn.Upsample(scale_factor=res / 64, mode=mode),
                  nn.Upsample(scale_factor=res / 128, mode=mode),
                  nn.Upsample(scale_factor=res / 128, mode=mode),
                  nn.Upsample(scale_factor=res / 256, mode=mode),
                  nn.Upsample(scale_factor=res / 256, mode=mode)
                  ]

    if resolution > 256:
        upsamplers.append(nn.Upsample(scale_factor=res / 512, mode=mode))
        upsamplers.append(nn.Upsample(scale_factor=res / 512, mode=mode))

    if resolution > 512:

        upsamplers.append(Interpolate(res, 'bilinear'))
        upsamplers.append(Interpolate(res, 'bilinear'))

    return g_all, avg_latent, upsamplers

In [6]:
car_20_palette =[ 255,  255,  255, # 0 background
  238,  229,  102,# 1 back_bumper
  0, 0, 0,# 2 bumper
  124,  99 , 34, # 3 car
  193 , 127,  15,# 4 car_lights
  248  ,213 , 42, # 5 door
  220  ,147 , 77, # 6 fender
  99 , 83  , 3, # 7 grilles
  116 , 116 , 138,  # 8 handles
  200  ,226 , 37, # 9 hoods
  225 , 184 , 161, # 10 licensePlate
  142 , 172  ,248, # 11 mirror
  153 , 112 , 146, # 12 roof
  38  ,112 , 254, # 13 running_boards
  229 , 30  ,141, # 14 tailLight
  52 , 83  ,84, # 15 tire
  194 , 87 , 125, # 16 trunk_lids
  225,  96  ,18,  # 17 wheelhub
  31 , 102 , 211, # 18 window
  104 , 131 , 101# 19 windshield
         ]



face_palette = [  1.0000,  1.0000 , 1.0000,
              0.4420,  0.5100 , 0.4234,
              0.8562,  0.9537 , 0.3188,
              0.2405,  0.4699 , 0.9918,
              0.8434,  0.9329  ,0.7544,
              0.3748,  0.7917 , 0.3256,
              0.0190,  0.4943 , 0.3782,
              0.7461 , 0.0137 , 0.5684,
              0.1644,  0.2402 , 0.7324,
              0.0200 , 0.4379 , 0.4100,
              0.5853 , 0.8880 , 0.6137,
              0.7991 , 0.9132 , 0.9720,
              0.6816 , 0.6237  ,0.8562,
              0.9981 , 0.4692 , 0.3849,
              0.5351 , 0.8242 , 0.2731,
              0.1747 , 0.3626 , 0.8345,
              0.5323 , 0.6668 , 0.4922,
              0.2122 , 0.3483 , 0.4707,
              0.6844,  0.1238 , 0.1452,
              0.3882 , 0.4664 , 0.1003,
              0.2296,  0.0401 , 0.3030,
              0.5751 , 0.5467 , 0.9835,
              0.1308 , 0.9628,  0.0777,
              0.2849  ,0.1846 , 0.2625,
              0.9764 , 0.9420 , 0.6628,
              0.3893 , 0.4456 , 0.6433,
              0.8705 , 0.3957 , 0.0963,
              0.6117 , 0.9702 , 0.0247,
              0.3668 , 0.6694 , 0.3117,
              0.6451 , 0.7302,  0.9542,
              0.6171 , 0.1097,  0.9053,
              0.3377 , 0.4950,  0.7284,
              0.1655,  0.9254,  0.6557,
              0.9450  ,0.6721,  0.6162]

face_palette = [int(item * 255) for item in face_palette]





car_12_palette =[ 255,  255,  255, # 0 background
         124,  99 , 34, # 3 car
         193 , 127,  15,# 4 car_lights
         229 , 30  ,141, # 14 tailLight
        225 , 184 , 161, # 10 licensePlate
        104 , 131 , 101,# 19 windshield
        52 , 83  ,84, # 15 tire
        248  ,213 , 42, # 5 door
         116 , 116 , 138,  # 8 handles
           225,  96  ,18,  # 17 wheelhub
         31 , 102 , 211, # 18 window
         142 , 172  ,248, # 11 mirror
         ]



car_32_palette =[ 255,  255,  255,
  238,  229,  102,
  0, 0, 0,
  124,  99 , 34,
  193 , 127,  15,
  106,  177,  21,
  248  ,213 , 42,
  252 , 155,  83,
  220  ,147 , 77,
  99 , 83  , 3,
  116 , 116 , 138,
  63  ,182 , 24,
  200  ,226 , 37,
  225 , 184 , 161,
  233 ,  5  ,219,
  142 , 172  ,248,
  153 , 112 , 146,
  38  ,112 , 254,
  229 , 30  ,141,
  115  ,208 , 131,
  52 , 83  ,84,
  229 , 63 , 110,
  194 , 87 , 125,
  225,  96  ,18,
  73  ,139,  226,
  172 , 143 , 16,
  169 , 101 , 111,
  31 , 102 , 211,
  104 , 131 , 101,
  70  ,168  ,156,
  183 , 242 , 209,
  72  ,184 , 226]

bedroom_palette =[ 255,  255,  255,
  238,  229,  102,
  255, 72, 69,
  124,  99 , 34,
  193 , 127,  15,
  106,  177,  21,
  248  ,213 , 42,
  252 , 155,  83,
  220  ,147 , 77,
  99 , 83  , 3,
  116 , 116 , 138,
  63  ,182 , 24,
  200  ,226 , 37,
  225 , 184 , 161,
  233 ,  5  ,219,
  142 , 172  ,248,
  153 , 112 , 146,
  38  ,112 , 254,
  229 , 30  ,141,
   238, 229, 12,
   255, 72, 6,
   124, 9, 34,
   193, 17, 15,
   106, 17, 21,
   28, 213, 2,
   252, 155, 3,
   20, 147, 77,
   9, 83, 3,
   11, 16, 138,
   6, 12, 24,
   20, 22, 37,
   225, 14, 16,
   23, 5, 29,
   14, 12, 28,
   15, 11, 16,
   3, 12, 24,
   22, 3, 11
   ]

cat_palette = [255,  255,  255,
            220, 220, 0,
           190, 153, 153,
            250, 170, 30,
           220, 220, 0,
           107, 142, 35,
           102, 102, 156,
           152, 251, 152,
           119, 11, 32,
           244, 35, 232,
           220, 20, 60,
           52 , 83  ,84,
          194 , 87 , 125,
          225,  96  ,18,
          31 , 102 , 211,
          104 , 131 , 101
          ]

## generate_data

In [16]:
def generate_data(args, checkpoint_path, num_sample, start_step=0, vis=True):
    if args['category'] == 'car':
        palette = car_20_palette
    elif args['category'] == 'face':
        palette = face_palette
    elif args['category'] == 'bedroom':
        palette = bedroom_palette
    elif args['category'] == 'cat':
        palette = cat_palette
    else:
        assert False
    if not vis:
        result_path = os.path.join(checkpoint_path, 'samples' )
    else:
        result_path = os.path.join(checkpoint_path, 'vis_%d'%num_sample)
    if os.path.exists(result_path):
        pass
    else:
        os.system('mkdir -p %s' % (result_path))
        print('Experiment folder created at: %s' % (result_path))


    g_all, avg_latent, upsamplers = prepare_stylegan(args)

    classifier_list = []
    for MODEL_NUMBER in range(args['model_num']):
        print('MODEL_NUMBER', MODEL_NUMBER)

        classifier = pixel_classifier(numpy_class=args['number_class']
                                      , dim=args['dim'][-1])
        classifier =  nn.DataParallel(classifier, device_ids=device_ids)
        if torch.cuda.is_available():
            classifier = classifier.cuda()

        checkpoint = torch.load(os.path.join(checkpoint_path, 'model_' + str(MODEL_NUMBER) + '.pth'))

        classifier.load_state_dict(checkpoint['model_state_dict'])


        classifier.eval()
        classifier_list.append(classifier)

    softmax_f = nn.Softmax(dim=1)
    with torch.no_grad():
        latent_cache = []
        image_cache = []
        seg_cache = []
        entropy_calculate = []
        results = []
        np.random.seed(start_step)
        count_step = start_step



        print( "num_sample: ", num_sample)

        for i in range(num_sample):
            if i % 20 == 0:
                print("Genearte", i, "Out of:", num_sample)

            curr_result = {}

            latent = np.random.randn(1, 512)

            curr_result['latent'] = latent


            latent = torch.from_numpy(latent).type(torch.FloatTensor).to(device)
            latent_cache.append(latent)

            img, affine_layers = latent_to_image(g_all, upsamplers, latent, dim=args['dim'][1],
                                                     return_upsampled_layers=True)

            if args['dim'][0] != args['dim'][1]:
                img = img[:, 64:448][0]
            else:
                img = img[0]

            image_cache.append(img)
            if args['dim'][0] != args['dim'][1]:
                affine_layers = affine_layers[:, :, 64:448]
            affine_layers = affine_layers[0]

            affine_layers = affine_layers.reshape(args['dim'][-1], -1).transpose(1, 0)

            all_seg = []
            all_entropy = []
            mean_seg = None

            seg_mode_ensemble = []
            for MODEL_NUMBER in range(args['model_num']):
                classifier = classifier_list[MODEL_NUMBER]

                img_seg = classifier(affine_layers)

                img_seg = img_seg.squeeze()


                entropy = Categorical(logits=img_seg).entropy()
                all_entropy.append(entropy)

                all_seg.append(img_seg)
                if mean_seg is None:
                    mean_seg = softmax_f(img_seg)
                else:
                    mean_seg += softmax_f(img_seg)

                img_seg_final = oht_to_scalar(img_seg)
                img_seg_final = img_seg_final.reshape(args['dim'][0], args['dim'][1], 1)
                img_seg_final = img_seg_final.cpu().detach().numpy()

                seg_mode_ensemble.append(img_seg_final)

            mean_seg = mean_seg / len(all_seg)

            full_entropy = Categorical(mean_seg).entropy()

            js = full_entropy - torch.mean(torch.stack(all_entropy), 0)

            top_k = js.sort()[0][- int(js.shape[0] / 10):].mean()
            entropy_calculate.append(top_k)


            img_seg_final = np.concatenate(seg_mode_ensemble, axis=-1)
            img_seg_final = scipy.stats.mode(img_seg_final, 2)[0].reshape(args['dim'][0], args['dim'][1])
            del (affine_layers)
            if vis:

                color_mask = 0.7 * colorize_mask(img_seg_final, palette) + 0.3 * img

                imageio.imwrite(os.path.join(result_path, "vis_" + str(i) + '.jpg'),
                                  color_mask.astype(np.uint8))
                imageio.imwrite(os.path.join(result_path, "vis_" + str(i) + '_image.jpg'),
                                  img.astype(np.uint8))
            else:
                seg_cache.append(img_seg_final)
                curr_result['uncertrainty_score'] = top_k.item()
                image_label_name = os.path.join(result_path, 'label_' + str(count_step) + '.png')
                image_name = os.path.join(result_path,  str(count_step) + '.png')

                js_name = os.path.join(result_path, str(count_step) + '.npy')
                img = Image.fromarray(img)
                img_seg = Image.fromarray(img_seg_final.astype('uint8'))
                js = js.cpu().numpy().reshape(args['dim'][0], args['dim'][1])
                img.save(image_name)
                img_seg.save(image_label_name)
                np.save(js_name, js)
                curr_result['image_name'] = image_name
                curr_result['image_label_name'] = image_label_name
                curr_result['js_name'] = js_name
                count_step += 1


                results.append(curr_result)
                if i % 1000 == 0 and i != 0:
                    with open(os.path.join(result_path, str(i) + "_" + str(start_step) + '.pickle'), 'wb') as f:
                        pickle.dump(results, f)

        with open(os.path.join(result_path, str(num_sample) + "_" + str(start_step) + '.pickle'), 'wb') as f:
            pickle.dump(results, f)

## prepare_data

In [8]:
def prepare_data(args, palette):
    print('preparing data...')
    g_all, avg_latent, upsamplers = prepare_stylegan(args)
    print('prepared stylegan')
    latent_all = np.load(args['annotation_image_latent_path'])
    latent_all = torch.from_numpy(latent_all)
    if torch.cuda.is_available():
        latent_all = latent_all.cuda()

    # load annotated mask
    print('loading annotated mask')
    mask_list = []
    im_list = []
    latent_all = latent_all[:args['max_training']]
    num_data = len(latent_all)

    for i in range(len(latent_all)):
        print(i)

        if i >= args['max_training']:
            break
        name = 'image_mask%0d.npy' % i

        im_frame = np.load(os.path.join( args['annotation_mask_path'] , name))
        mask = np.array(im_frame)
        mask =  cv2.resize(np.squeeze(mask), dsize=(args['dim'][1], args['dim'][0]), interpolation=cv2.INTER_NEAREST)

        mask_list.append(mask)

        im_name = os.path.join( args['annotation_mask_path'], 'image_%d.jpg' % i)
        img = Image.open(im_name)
        img = img.resize((args['dim'][1], args['dim'][0]))

        im_list.append(np.array(img))

    # delete small annotation error
    for i in range(len(mask_list)):  # clean up artifacts in the annotation, must do
        for target in range(1, 50):
            if (mask_list[i] == target).sum() < 30:
                mask_list[i][mask_list[i] == target] = 0


    all_mask = np.stack(mask_list)


    # 3. Generate ALL training data for training pixel classifier
    print('generating training data')
    all_feature_maps_train = np.zeros((args['dim'][0] * args['dim'][1] * len(latent_all), args['dim'][2]), dtype=np.float16)
    all_mask_train = np.zeros((args['dim'][0] * args['dim'][1] * len(latent_all),), dtype=np.float16)


    vis = []
    for i in range(len(latent_all) ):
        print(i)

        gc.collect()

        latent_input = latent_all[i].float()

        img, feature_maps = latent_to_image(g_all, upsamplers, latent_input.unsqueeze(0), dim=args['dim'][1],
                                            return_upsampled_layers=True, use_style_latents=args['annotation_data_from_w'])
        
        if args['dim'][0]  != args['dim'][1]:
            # only for car
            img = img[:, 64:448]
            feature_maps = feature_maps[:, :, 64:448]
        mask = all_mask[i:i + 1]
        feature_maps = feature_maps.permute(0, 2, 3, 1)

        feature_maps = feature_maps.reshape(-1, args['dim'][2])
        new_mask =  np.squeeze(mask)

        mask = mask.reshape(-1)

        all_feature_maps_train[args['dim'][0] * args['dim'][1] * i: args['dim'][0] * args['dim'][1] * i + args['dim'][0] * args['dim'][1]] = feature_maps.cpu().detach().numpy().astype(np.float16)
        all_mask_train[args['dim'][0] * args['dim'][1] * i:args['dim'][0] * args['dim'][1] * i + args['dim'][0] * args['dim'][1]] = mask.astype(np.float16)

        img_show =  cv2.resize(np.squeeze(img[0]), dsize=(args['dim'][1], args['dim'][1]), interpolation=cv2.INTER_NEAREST)

        curr_vis = np.concatenate( [im_list[i], img_show, colorize_mask(new_mask, palette)], 0 )

        vis.append( curr_vis )


    vis = np.concatenate(vis, 1)

    print('vis = np.concatenate(vis, 1)')
    
    imageio.imwrite(os.path.join(args['exp_dir'], "train_data.jpg"), vis)

    print('imageio.imwrite(os.path.join(args[\'exp_dir\'], \"train_data.jpg\"), vis)')

    return all_feature_maps_train, all_mask_train, num_data

## main training

In [9]:
def train_interpreter_main(args):

    if args['category'] == 'car':
        palette = car_20_palette
    elif args['category'] == 'face':
        palette = face_palette
    elif args['category'] == 'bedroom':
        palette = bedroom_palette
    elif args['category'] == 'cat':
        palette = cat_palette
    else:
        assert False


    all_feature_maps_train_all, all_mask_train_all, num_data = prepare_data(args, palette)
    print('data prepared')

    torch.FloatTensor(all_feature_maps_train_all)
    print('all_feature_maps_train_all float tensor success')
    torch.FloatTensor(all_mask_train_all)
    print('all_mask_train_all float tensor success')

    train_data = trainData(torch.FloatTensor(all_feature_maps_train_all),
                           torch.FloatTensor(all_mask_train_all))
    
    print('data floated')


    count_dict = get_label_stas(train_data)

    print('label stas got')

    max_label = args['number_class'] - 1 #max([*count_dict])
    
    print(" *********************** max_label " + str(max_label) + " ***********************")


    print(" *********************** Current number data " + str(num_data) + " ***********************")


    batch_size = args['batch_size']

    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

    print(" *********************** Current dataloader length " +  str(len(train_loader)) + " ***********************")

    for MODEL_NUMBER in range(args['model_num']):

        gc.collect()

        classifier = pixel_classifier(numpy_class=(max_label + 1), dim=args['dim'][-1])

        classifier.init_weights()

        classifier = nn.DataParallel(classifier, device_ids=device_ids)
        if torch.cuda.is_available():
            classifier = classifier.cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(classifier.parameters(), lr=0.001)
        classifier.train()


        iteration = 0
        break_count = 0
        best_loss = 10000000
        stop_sign = 0
        for epoch in range(100):
            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                y_batch = y_batch.type(torch.long)
                y_batch = y_batch.type(torch.long)

                optimizer.zero_grad()
                y_pred = classifier(X_batch)
                loss = criterion(y_pred, y_batch)
                acc = multi_acc(y_pred, y_batch)

                loss.backward()
                optimizer.step()

                iteration += 1
                if iteration % 1000 == 0:
                    print('Epoch : ', str(epoch), 'iteration', iteration, 'loss', loss.item(), 'acc', acc)
                    gc.collect()


                # if iteration % 5000 == 0:
                #     model_path = os.path.join(args['exp_dir'],
                #                               'model_20parts_iter' +  str(iteration) + '_number_' + str(MODEL_NUMBER) + '.pth')
                #     print('Save checkpoint, Epoch : ', str(epoch), ' Path: ', model_path)

                #     torch.save({'model_state_dict': classifier.state_dict()},
                #                model_path)

                if epoch > 3:
                    if loss.item() < best_loss:
                        best_loss = loss.item()
                        break_count = 0
                    else:
                        break_count += 1

                    if break_count > 50:
                        stop_sign = 1
                        print("*************** Break, Total iters,", iteration, ", at epoch", str(epoch), "***************")
                        break

            if stop_sign == 1:
                break

        gc.collect()
        model_path = os.path.join(args['exp_dir'],
                                  'model_' + str(MODEL_NUMBER) + '.pth')
        MODEL_NUMBER += 1
        print('save to:',model_path)
        torch.save({'model_state_dict': classifier.state_dict()},
                   model_path)
        gc.collect()


        gc.collect()
        torch.cuda.empty_cache()    # clear cache memory on GPU

## Run Interpreter

In [14]:
!unzip dataset_release.zip

Archive:  dataset_release.zip
   creating: dataset_release/
  inflating: __MACOSX/._dataset_release  
  inflating: dataset_release/LICENSE-STYLEGAN.txt  
  inflating: __MACOSX/dataset_release/._LICENSE-STYLEGAN.txt  
   creating: dataset_release/training_latent/
  inflating: __MACOSX/dataset_release/._training_latent  
   creating: dataset_release/annotation/
  inflating: __MACOSX/dataset_release/._annotation  
  inflating: dataset_release/giistr-cla.md  
  inflating: __MACOSX/dataset_release/._giistr-cla.md  
  inflating: dataset_release/LICENSE.txt  
  inflating: __MACOSX/dataset_release/._LICENSE.txt  
   creating: dataset_release/training_latent/car_20/
  inflating: __MACOSX/dataset_release/training_latent/._car_20  
   creating: dataset_release/training_latent/face_34/
  inflating: __MACOSX/dataset_release/training_latent/._face_34  
   creating: dataset_release/training_latent/cat_16/
  inflating: __MACOSX/dataset_release/training_latent/._cat_16  
   creating: dataset_release/an

In [10]:
args = {
      'exp': 'cat_16.json',
      'exp_dir': '',
      'generate_data': False,
      'save_vis': False,
      'start_step': 0,
      'resume': '',
      'num_sample': 1000
  }


opts = json.load(open(args['exp'], 'r'))
print("Opt", opts)

if args['exp_dir'] != "":
    opts['exp_dir'] = args['exp_dir']


path =opts['exp_dir']
if os.path.exists(path):
  pass
else:
  os.system('mkdir -p %s' % (path))
  print('Experiment folder created at: %s' % (path))

os.system('cp %s %s' % (args['exp'], opts['exp_dir']))

train_interpreter_main(opts)


Opt {'exp_dir': 'model_dir/cat_16', 'batch_size': 64, 'category': 'cat', 'debug': False, 'dim': [256, 256, 4992], 'deeplab_res': 256, 'number_class': 16, 'testing_data_number_class': 16, 'max_training': 17, 'stylegan_ver': '1', 'annotation_data_from_w': False, 'annotation_mask_path': './dataset_release/annotation/training_data/cat_processed', 'testing_path': './dataset_release/annotation/testing_data/cat_16_class', 'average_latent': './dataset_release/training_latent/cat_16/avg_latent_cat.npy', 'annotation_image_latent_path': './dataset_release/training_latent/cat_16/latent_stylegan1.npy', 'stylegan_checkpoint': 'karras2019stylegan-cats-256x256.for_g_all.pt', 'model_num': 5, 'upsample_mode': 'bilinear'}
preparing data...
preparing stylegan...
prepared stylegan
loading annotated mask
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
generating training data
0


  "See the documentation of nn.Upsample for details.".format(mode)


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
vis = np.concatenate(vis, 1)
imageio.imwrite(os.path.join(args['exp_dir'], "train_data.jpg"), vis)
data prepared
all_feature_maps_train_all float tensor success
all_mask_train_all float tensor success
initializing train data
x_data set
y_data set
trainData initialized
data floated
label stas got
 *********************** max_label 15 ***********************
 *********************** Current number data 17 ***********************
 *********************** Current dataloader length 17408 ***********************
Epoch :  0 iteration 1000 loss 0.24430996179580688 acc tensor(89.0625)
Epoch :  0 iteration 2000 loss 0.3771819472312927 acc tensor(85.9375)
Epoch :  0 iteration 3000 loss 0.22531837224960327 acc tensor(89.0625)
Epoch :  0 iteration 4000 loss 0.35952937602996826 acc tensor(84.3750)
Epoch :  0 iteration 5000 loss 0.3457278609275818 acc tensor(89.0625)
Epoch :  0 iteration 6000 loss 0.28472739458084106 acc tensor(89.0625)
Epoch :  0 iteration 7000

In [12]:
!zip -r 'model_dir_final_5model.zip' 'model_dir'

  adding: model_dir/ (stored 0%)
  adding: model_dir/cat_16/ (stored 0%)
  adding: model_dir/cat_16/model_2.pth (deflated 7%)
  adding: model_dir/cat_16/model_0.pth (deflated 7%)
  adding: model_dir/cat_16/train_data.jpg (deflated 2%)
  adding: model_dir/cat_16/model_4.pth (deflated 7%)
  adding: model_dir/cat_16/cat_16.json (deflated 55%)
  adding: model_dir/cat_16/model_3.pth (deflated 7%)
  adding: model_dir/cat_16/model_1.pth (deflated 7%)


## Run Generator

In [17]:
print(f'generating from model number {model_num}')
args = {
      'exp': 'model_dir/cat_16/cat_16.json',
      'exp_dir': '',
      'generate_data': True,
      'save_vis': True,
      'start_step': 0,
      'resume': f'model_dir/cat_16',
      'num_sample': 1000
  }


opts = json.load(open(args['exp'], 'r'))
print("Opt", opts)

if args['exp_dir'] != "":
  opts['exp_dir'] = args['exp_dir']


path =opts['exp_dir']
if os.path.exists(path):
      pass
else:
      os.system('mkdir -p %s' % (path))
      print('Experiment folder created at: %s' % (path))

os.system('cp %s %s' % (args['exp'], opts['exp_dir']))

generate_data(opts, args['resume'], args['num_sample'], vis=args['save_vis'], start_step=args['start_step'])


generating from model number 0
Opt {'exp_dir': 'model_dir/cat_16', 'batch_size': 64, 'category': 'cat', 'debug': False, 'dim': [256, 256, 4992], 'deeplab_res': 256, 'number_class': 16, 'testing_data_number_class': 16, 'max_training': 17, 'stylegan_ver': '1', 'annotation_data_from_w': False, 'annotation_mask_path': './dataset_release/annotation/training_data/cat_processed', 'testing_path': './dataset_release/annotation/testing_data/cat_16_class', 'average_latent': './dataset_release/training_latent/cat_16/avg_latent_cat.npy', 'annotation_image_latent_path': './dataset_release/training_latent/cat_16/latent_stylegan1.npy', 'stylegan_checkpoint': 'karras2019stylegan-cats-256x256.for_g_all.pt', 'model_num': 5, 'upsample_mode': 'bilinear'}
preparing stylegan...
MODEL_NUMBER 0
MODEL_NUMBER 1
MODEL_NUMBER 2
MODEL_NUMBER 3
MODEL_NUMBER 4
num_sample:  1000
Genearte 0 Out of: 1000


  "See the documentation of nn.Upsample for details.".format(mode)


Genearte 20 Out of: 1000
Genearte 40 Out of: 1000
Genearte 60 Out of: 1000
Genearte 80 Out of: 1000
Genearte 100 Out of: 1000
Genearte 120 Out of: 1000
Genearte 140 Out of: 1000
Genearte 160 Out of: 1000
Genearte 180 Out of: 1000
Genearte 200 Out of: 1000
Genearte 220 Out of: 1000
Genearte 240 Out of: 1000
Genearte 260 Out of: 1000
Genearte 280 Out of: 1000
Genearte 300 Out of: 1000
Genearte 320 Out of: 1000
Genearte 340 Out of: 1000
Genearte 360 Out of: 1000
Genearte 380 Out of: 1000
Genearte 400 Out of: 1000
Genearte 420 Out of: 1000
Genearte 440 Out of: 1000
Genearte 460 Out of: 1000
Genearte 480 Out of: 1000
Genearte 500 Out of: 1000
Genearte 520 Out of: 1000
Genearte 540 Out of: 1000
Genearte 560 Out of: 1000
Genearte 580 Out of: 1000
Genearte 600 Out of: 1000
Genearte 620 Out of: 1000
Genearte 640 Out of: 1000
Genearte 660 Out of: 1000
Genearte 680 Out of: 1000
Genearte 700 Out of: 1000
Genearte 720 Out of: 1000
Genearte 740 Out of: 1000
Genearte 760 Out of: 1000
Genearte 780 Out

In [18]:
!zip -r 'model_dir_final_5model_withsamples.zip' 'model_dir'

  adding: model_dir/ (stored 0%)
  adding: model_dir/cat_16/ (stored 0%)
  adding: model_dir/cat_16/model_2.pth (deflated 7%)
  adding: model_dir/cat_16/model_0.pth (deflated 7%)
  adding: model_dir/cat_16/samples/ (stored 0%)
  adding: model_dir/cat_16/samples/38.npy (deflated 10%)
  adding: model_dir/cat_16/samples/283.npy (deflated 10%)
  adding: model_dir/cat_16/samples/285.npy (deflated 10%)
  adding: model_dir/cat_16/samples/label_46.png (stored 0%)
  adding: model_dir/cat_16/samples/293.png (deflated 0%)
  adding: model_dir/cat_16/samples/278.npy (deflated 10%)
  adding: model_dir/cat_16/samples/205.npy (deflated 9%)
  adding: model_dir/cat_16/samples/56.png (deflated 0%)
  adding: model_dir/cat_16/samples/29.npy (deflated 10%)
  adding: model_dir/cat_16/samples/30.npy (deflated 10%)
  adding: model_dir/cat_16/samples/label_284.png (stored 0%)
  adding: model_dir/cat_16/samples/189.npy (deflated 9%)
  adding: model_dir/cat_16/samples/label_222.png (stored 0%)
  adding: model_dir