In [None]:
%matplotlib inline
import os
import cv2
import time
import pickle

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from IPython.display import clear_output
from datetime import datetime

from lib.utils import make_seed, make_circle_masks
from lib.utils import get_living_mask, get_sobel, softmax, to_rgb

In [None]:
class VAE_encoder(nn.Module):
    """Convolutional encoder mapping an RGB image to a flat latent vector.

    Despite the name there is no variance head or sampling: ``forward`` is a
    deterministic conv stack followed by one linear projection.

    The stack halves the spatial size three times with max-pooling, then uses
    a stride-2 5x5 conv and an unpadded 3x3 conv. For 40x40 inputs this gives
    40 -> 20 -> 10 -> 5 -> 3 -> 1, so the flatten yields exactly 1024
    features per image. For other input sizes the ``reshape(-1, 1024)``
    yields several rows per image (TODO confirm the intended input size).

    Args:
        dim_out: size of the returned latent vector.
    """

    def __init__(self, dim_out):
        super().__init__()

        def conv_relu(c_in, c_out, kernel=3, stride=1, padding=1):
            # One convolution followed by its activation.
            return [nn.Conv2d(c_in, c_out, kernel, stride=stride, padding=padding), nn.ReLU()]

        def pool():
            return nn.MaxPool2d(2, stride=2)

        # Layers are created in the same order as before, so Sequential
        # indices (and therefore state_dict keys) are unchanged.
        layers = (
            conv_relu(3, 64, kernel=7, padding=3) + [pool()]
            + conv_relu(64, 64) + conv_relu(64, 64) + conv_relu(64, 64) + [pool()]
            + conv_relu(64, 128) + conv_relu(128, 128) + conv_relu(128, 128) + [pool()]
            + conv_relu(128, 256) + conv_relu(256, 256)
            + conv_relu(256, 512, kernel=5, stride=2, padding=2)
            + [nn.Conv2d(512, 1024, 3)]
        )
        self.conv_mu = nn.Sequential(*layers)
        self.lin_mu = nn.Linear(1024, dim_out)

    def forward(self, x):
        """Encode images ``x`` of shape ``(B, 3, H, W)`` into ``(-1, dim_out)`` latents."""
        features = self.conv_mu(x).reshape(-1, 1024)
        return self.lin_mu(features)
    
class VAE_decoder(nn.Module):
    """Five-layer MLP mapping a latent vector to a flat parameter vector.

    Every hidden layer is ``2 * dim_in`` wide with a ReLU; the final linear
    layer has no activation.

    Args:
        dim_in: size of the input latent.
        dim_out: size of the produced vector.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        width = 2 * dim_in
        layers = [nn.Linear(dim_in, width), nn.ReLU()]
        # Three square hidden layers; order keeps Sequential indices 0..8.
        for _ in range(3):
            layers += [nn.Linear(width, width), nn.ReLU()]
        layers.append(nn.Linear(width, dim_out))
        self.lin = nn.Sequential(*layers)

    def forward(self, z):
        """Return the decoded vector for latent ``z`` (last dim ``dim_in``)."""
        return self.lin(z)
    
class GNCAModel(nn.Module):
    """Neural cellular automaton whose update-rule weights are passed in per call.

    The module owns no trainable parameters. The per-sample weights
    ``(w0, b0, w1)`` are supplied to :meth:`forward`; in this file they come
    from ``model_VAE.decode``. The state is channels-last, with the channel
    axis last (``x[..., :alpha_channel+1]`` are the visible channels).

    Args:
        sobels: kernel sizes passed to ``get_sobel``; each contributes two
            depthwise-filtered copies of the state to the perception vector.
        channel_n: number of state channels per cell.
        alpha_channel: channel index handed to ``get_living_mask``; channels
            ``0..alpha_channel`` are the visible output.
        fire_rate: default threshold for the stochastic update mask (a cell
            is updated where ``rand > fire_rate``).
        calibration: scale applied to the calibration correction.
        device: device used for intermediate tensors.
    """

    def __init__(self, sobels, channel_n, alpha_channel,
                 fire_rate=0.5, calibration=1.0, device=torch.device("cpu")):
        super(GNCAModel, self).__init__()

        self.sobels = sobels
        self.device = device
        self.channel_n = channel_n
        self.alpha_channel = alpha_channel

        self.pool = torch.nn.MaxPool2d(kernel_size=3, padding=1, stride=1)

        self.fire_rate = fire_rate
        self.calibration = calibration
        # Cache of depthwise sobel kernels keyed by (sobel, device). A plain
        # dict (not a registered buffer) so state_dict keys stay unchanged.
        self._sobel_kernels = {}
        self.to(self.device)

    def _sobel_weights(self, sobel):
        """Return the two depthwise conv kernels for `sobel`, built once and cached.

        Each array from ``get_sobel`` is L1-normalised (sum of ``|w|`` == 1),
        cast to float32 and repeated to ``(channel_n, 1, k, k)`` for a grouped
        convolution.
        """
        cache = getattr(self, "_sobel_kernels", None)
        if cache is None:  # e.g. an instance unpickled from an older version
            cache = self._sobel_kernels = {}
        key = (sobel, self.device)
        if key not in cache:
            wa_1, wa_2 = get_sobel(sobel)
            kernels = []
            for weight in (wa_1, wa_2):
                # Out-of-place divide: never mutate the array get_sobel returned.
                weight = weight / np.sum(np.abs(weight))
                size = weight.shape[0]
                conv_weights = torch.from_numpy(weight.astype(np.float32)).to(self.device)
                kernels.append(conv_weights.view(1, 1, size, size).repeat(self.channel_n, 1, 1, 1))
            cache[key] = tuple(kernels)
        return cache[key]

    def perceive(self, x, angle):
        """Build the perception tensor for channels-first state `x`.

        Concatenates along dim 1: the raw state, a 3x3 max-pooled copy, and
        for each entry of ``self.sobels`` two depthwise-filtered copies. The
        output therefore has ``channel_n * (2 + 2 * len(self.sobels))``
        channels.

        `angle` is accepted for interface compatibility but is not used.
        """
        ys = [x, self.pool(x)]
        for sobel in self.sobels:
            for kernel in self._sobel_weights(sobel):
                # 'same' padding for odd kernel sizes
                padding = (kernel.shape[-1] - 1) // 2
                ys.append(F.conv2d(x, kernel, padding=padding, groups=self.channel_n))
        return torch.cat(ys, 1)

    def linear(self, x, w, b=None):
        """Apply a per-sample linear map to the last dim of `x`.

        `x` is flattened to ``(batch, -1, in)`` and batch-multiplied with
        ``w`` of shape ``(batch, in, out)``; ``b`` (if given) is added with
        broadcasting. The leading dims of `x` are restored, so the result is
        ``x.shape[:-1] + (out,)``.
        """
        original_shape = x.size()
        batch = x.size(0)
        y = torch.reshape(x, [batch, -1, original_shape[-1]]).to(self.device)
        y = torch.bmm(y, w)
        if b is not None:
            y = y + b
        return torch.reshape(y, list(original_shape[:-1]) + [y.size(-1)])

    def update(self, x, params, fire_rate, angle):
        """Apply one stochastic CA step to channels-last state `x`.

        The step runs as follows:
        1. Perceive the state.
        2. Apply the per-cell MLP (``w0``/``b0``, ReLU, ``w1``).
        3. Keep each cell's update only where ``rand > fire_rate``.
        4. Add the update as a residual.
        5. Zero cells that are not alive both before and after the step, per
           ``get_living_mask``.
        """
        w0, b0, w1 = params

        # Swap dims 1 and 3 to get channels in dim 1 for conv/pool; undone below.
        x = x.transpose(1, 3)
        pre_life_mask = get_living_mask(x, self.alpha_channel, 3)

        dx = self.perceive(x, angle)
        dx = dx.transpose(1, 3)
        dx = self.linear(dx, w0, b0)
        dx = F.relu(dx)
        dx = self.linear(dx, w1)

        if fire_rate is None:
            fire_rate = self.fire_rate
        # Drawn from the CPU generator then moved, as before, so random streams are unchanged.
        stochastic = torch.rand([dx.size(0), dx.size(1), dx.size(2), 1]) > fire_rate
        stochastic = stochastic.float().to(self.device)
        dx = dx * stochastic
        dx = dx.transpose(1, 3)

        x = x + dx

        post_life_mask = get_living_mask(x, self.alpha_channel, 3)
        life_mask = (pre_life_mask & post_life_mask).float()

        x = x * life_mask
        return x.transpose(1, 3)

    def _calibrate(self, x, calibration_map):
        """Apply the calibration correction to the visible channels of `x`.

        For the first ``alpha_channel + 1`` channels this computes
        ``x - calibration * t * (x - 1)``, where ``t`` is the matching slice
        of `calibration_map`. Where ``t == 0`` the value is unchanged; for
        ``0 < calibration * t <= 1`` it moves toward 1. Remaining channels
        pass through untouched.
        """
        n_visible = self.alpha_channel + 1
        h = x[..., :n_visible]
        t = calibration_map[..., :n_visible]
        # Mask from the same slice as `t` so it broadcasts against `h` even
        # when calibration_map carries all channel_n channels.
        delta = t * (h - 1) * self.calibration * (t != 0).float()
        return torch.cat((h - delta, x[..., n_visible:]), -1)

    def forward(self, x, params, steps, calibration_map=None, fire_rate=None, angle=0.0):
        """Run `steps` CA updates from state `x` with weights `params`.

        Returns:
            ``(x, history)``: the final state (a tensor, not clamped) and a
            list of ``steps + 1`` numpy snapshots (initial state included),
            each clamped to ``[0, 1]``.
        """
        history = [x.detach().cpu().clamp(0.0, 1.0).numpy()]
        for _ in range(steps):
            x = self.update(x, params, fire_rate, angle)
            if calibration_map is not None:
                x = self._calibrate(x, calibration_map)
            history.append(x.detach().cpu().clamp(0.0, 1.0).numpy())
        return x, history
    
class model_VAE(nn.Module):
    """Image-conditioned neural CA: encode an image into CA update weights, then grow it.

    The pipeline has three stages:
    1. ``encoder`` maps images reshaped to
       ``(B, alpha_channel, size, size)`` to a latent ``z``.
    2. Three MLP decoders turn ``z`` into per-sample weights
       ``(w0, b0, w1)``.
    3. ``GNCA`` runs the cellular automaton with those weights from a seed
       state.

    Despite the name, the encoder is deterministic (no sampling or KL term).
    ``z`` is only regularised by an L2 penalty in :meth:`train_step`.

    NOTE(review): images are reshaped using ``alpha_channel`` as the channel
    count. This only matches the encoder's 3-channel input when
    ``alpha_channel == 3``.

    Args:
        sobels: sobel kernel sizes forwarded to ``GNCAModel``.
        hidden_encoder: latent size (encoder output / decoder input).
        hidden_channel_n: hidden width of the per-cell update MLP.
        channel_n: number of CA state channels.
        size: image / grid side length.
        alpha_channel: index of the alpha channel in the CA state.
        device: device for all submodules.
    """

    def __init__(self, sobels, hidden_encoder, hidden_channel_n, channel_n, size, alpha_channel,
                 device=torch.device("cpu")):
        super(model_VAE, self).__init__()
        self.sobels = sobels
        self.channel_n = channel_n
        self.hidden_channel_n = hidden_channel_n
        self.size = size
        self.alpha_channel = alpha_channel
        self.eps = 1e-3  # clamp margin: outputs are kept within [eps, 1 - eps]

        # Submodules are created in this fixed order; attribute names are the
        # state_dict prefixes used by saved checkpoints.
        self.encoder = VAE_encoder(hidden_encoder)
        self.decoder_w0 = VAE_decoder(hidden_encoder, channel_n*(len(self.sobels)+1)*2*self.hidden_channel_n)
        self.decoder_b0 = VAE_decoder(hidden_encoder, self.hidden_channel_n)
        self.decoder_w1 = VAE_decoder(hidden_encoder, self.hidden_channel_n*channel_n)
        self.GNCA = GNCAModel(self.sobels, self.channel_n, self.alpha_channel, device=device)

        self.device = device
        self.to(self.device)

    def encode(self, x):
        """Map images `x` to latent vectors with ``hidden_encoder`` features."""
        return self.encoder(x)

    def decode(self, z):
        """Decode latents into per-sample CA weights.

        Returns:
            ``(w0, b0, w1)`` with shapes ``(B, n_perception, hidden_channel_n)``,
            ``(B, 1, hidden_channel_n)`` and ``(B, hidden_channel_n, channel_n)``,
            where ``n_perception = channel_n * (len(sobels) + 1) * 2``.
        """
        n_perception = self.channel_n*(len(self.sobels)+1)*2
        w0 = self.decoder_w0(z)
        b0 = self.decoder_b0(z)
        w1 = self.decoder_w1(z)
        return (torch.reshape(w0, [-1, n_perception, self.hidden_channel_n]),
                torch.reshape(b0, [-1, 1, self.hidden_channel_n]),
                torch.reshape(w1, [-1, self.hidden_channel_n, self.channel_n]))

    def infer(self, x0, x, steps, calibration_map=None):
        """Grow the CA for images `x` from seed `x0` without tracking gradients.

        Returns:
            ``(y, history)``: ``y`` holds the first ``alpha_channel + 1``
            channels of the final state, clamped to ``[eps, 1 - eps]``;
            ``history`` is the per-step snapshot list from ``GNCAModel.forward``.
        """
        with torch.no_grad():
            x_ = torch.reshape(x, [-1, self.alpha_channel, self.size, self.size])
            z = self.encode(x_)
            params = self.decode(z)
            y, history = self.GNCA(x0, params, steps, calibration_map=calibration_map)
            y = y[..., :(self.alpha_channel+1)].clamp(self.eps, 1.0-self.eps)
        return y, history

    def multi_infer(self, x0, xs, steps, gamma=0.99, calibration_map=None):
        """Grow one CA while cycling through the weights decoded from each image in `xs`.

        Step ``i`` uses the weights of ``xs[i % len(xs)]``. `gamma` is accepted
        but unused.

        Returns:
            ``(y, history, zs)``. ``y`` is the clamped visible channels of the
            final state. ``zs`` is a list of numpy latents, one per element of
            `xs`. Note that ``history[0]`` is the full, unclamped `x0`, while
            later entries hold only the clamped visible channels.
        """
        with torch.no_grad():
            params_list = []
            zs = []
            for x in xs:
                x_ = torch.reshape(x, [-1, self.alpha_channel, self.size, self.size])
                z = self.encode(x_)
                zs.append(z.detach().cpu().numpy())
                params_list.append(self.decode(z))
            y = x0
            history = [x0.detach().cpu().numpy()]
            for i in range(steps):
                y, _ = self.GNCA(y, params_list[i % len(params_list)], 1, calibration_map=calibration_map)
                his = y[..., :(self.alpha_channel+1)].clamp(self.eps, 1.0-self.eps).detach().cpu().numpy()
                history.append(his)
            y = y[..., :(self.alpha_channel+1)].clamp(self.eps, 1.0-self.eps)
        return y, history, zs

    def train(self, x0=True, x=None, target=None, steps=None, beta=None, calibration_map=None):
        """Run one training forward pass, or set train/eval mode like ``nn.Module.train``.

        This name shadows ``nn.Module.train(mode)``, and ``nn.Module.eval()``
        calls ``self.train(False)``. A lone bool argument is therefore routed
        to the base implementation, so ``train()``, ``train(False)`` and
        ``eval()`` work. Any other call is forwarded to :meth:`train_step`.
        ``train(mode=...)`` by keyword is not supported.
        """
        if x is None and isinstance(x0, bool):
            return super(model_VAE, self).train(x0)
        return self.train_step(x0, x, target, steps, beta, calibration_map=calibration_map)

    def train_step(self, x0, x, target, steps, beta, calibration_map=None):
        """Encode `x`, grow the CA from `x0` for `steps` steps, and compute the loss.

        Args:
            x0: seed state passed to ``GNCAModel``.
            x: input images, reshaped to ``(-1, alpha_channel, size, size)``.
            target: expected visible channels (first ``alpha_channel + 1``).
            steps: number of CA steps.
            beta: weight of the latent L2 penalty.
            calibration_map: optional map forwarded to ``GNCAModel``.

        Returns:
            ``(y_raw, loss, (mse, l2))``:
            - ``y_raw``: the full CA state clamped to ``[eps, 1 - eps]``.
            - ``loss``: the differentiable scalar ``mse + beta * l2``.
            - ``(mse, l2)``: the two loss terms as Python floats.
        """
        x_ = torch.reshape(x, [-1, self.alpha_channel, self.size, self.size])
        z = self.encode(x_)
        params = self.decode(z)
        y_raw, _ = self.GNCA(x0, params, steps, calibration_map=calibration_map)
        y_raw = y_raw.clamp(self.eps, 1.0-self.eps)
        y = y_raw[..., :(self.alpha_channel+1)]

        mse = F.mse_loss(y, target)
        # Sum (not mean) of squared latent entries over the whole batch.
        l2 = torch.sum(torch.pow(z, 2))
        loss = mse + beta*l2

        return y_raw, loss, (mse.item(), l2.item())

def read_and_resize(path, size, min_side=64):
    """Load an image, rescale it and return a centred ``size x size`` crop.

    The scale factor is chosen so the shorter side would become `size`. Each
    resized dimension is then floored independently at
    ``max(min_side, size)``. When `size` < `min_side`, the resize therefore
    does not preserve the aspect ratio. ``min_side=64`` reproduces the
    previous hard-coded behaviour.

    Args:
        path: image file readable by ``matplotlib.image.imread``.
        size: side length of the returned square crop.
        min_side: lower bound on each resized dimension.

    Returns:
        Array of shape ``(size, size, C)``; grayscale files are expanded to
        3 channels. Pixel dtype and range follow ``mpimg.imread``: integer
        for JPEG, float 0-1 for PNG.
        NOTE(review): callers in this file divide by 255, so PNG inputs
        would be mis-scaled.
    """
    raw = mpimg.imread(path)
    if raw.ndim == 2:
        # Grayscale loads as (H, W); replicate so the 3-index crop below works.
        raw = np.stack([raw] * 3, axis=-1)
    scale = size / min(raw.shape[0], raw.shape[1])
    # Also floor at `size` so int() truncation can never leave a side shorter
    # than the crop, which previously produced negative crop offsets.
    floor = max(min_side, size)
    # cv2.resize takes dsize as (width, height).
    new_shape = (max(int(raw.shape[1] * scale), floor),
                 max(int(raw.shape[0] * scale), floor))
    img = cv2.resize(raw, new_shape)
    top = (img.shape[0] - size) // 2
    left = (img.shape[1] - size) // 2
    return img[top:top + size, left:left + size, :]

def plot_loss(loss_log):
    """Show the loss history on a log10 scale as a faint scatter plot.

    Args:
        loss_log: sequence of loss values, one per recorded iteration.
            Values are expected to be positive; log10 of non-positive
            entries gives -inf/nan.
    """
    plt.figure(figsize=(10, 4))
    plt.title('Loss history (log10)')
    log_values = np.log10(loss_log)
    plt.plot(log_values, '.', alpha=0.1)
    plt.show()

In [None]:
# Configuration, dataset index and pretrained-model loading.
ROOT = "/disk2/mingxiang_workDir/vggface2/test/"
SIZE = 40

# Fall back to CPU so the notebook can still run on a machine without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "models/gen_AE_vgg2.pth"
init_coord = (SIZE//2, SIZE//2)  # seed position at the grid centre

SOBEL_SIZES = [3,5,9]
ALPHA_CHANNEL = 3
HIDDEN_ENCODER = 1024
CHANNEL_N = 24
HIDDEN_CHANNEL_N = 256

BATCH_SIZE = 8
N_STEPS = 160

# One sub-directory per identity; hidden entries (e.g. .DS_Store) are skipped.
names = [x for x in os.listdir(ROOT) if not x.startswith('.')]
paths = {}
for name in names:
    paths[name] = [x for x in os.listdir(os.path.join(ROOT, name)) if not x.startswith('.')]
print("num_images", sum(len(paths[name]) for name in names))

my_model = model_VAE(SOBEL_SIZES, HIDDEN_ENCODER, HIDDEN_CHANNEL_N, CHANNEL_N,
                     SIZE, ALPHA_CHANNEL, device=DEVICE)
# map_location lets a checkpoint saved on another device load onto DEVICE.
my_model.load_state_dict(torch.load(model_path, map_location=DEVICE))

loss_log = []

In [None]:
# Visual check: reconstruct random faces with the pretrained model, n_batch times.
# Each batch takes one random image from each of BATCH_SIZE distinct random
# identities. Row 1 of each figure shows the RGBA targets; row 2 shows the
# CA output after N_STEPS steps.
n_batch = 8
for index in range(n_batch):
    # Distinct identity indices (replace=False needs len(names) >= BATCH_SIZE).
    name_batch = np.random.choice(len(names), BATCH_SIZE, replace=False)
    # One random image index per chosen identity.
    path_is = [np.random.randint(len(paths[names[names_i]])) for names_i in name_batch]
    x_np = []
    for i in range(len(name_batch)):
        name = names[name_batch[i]]
        path = ROOT+name+"/"+paths[name][path_is[i]]
        x_np.append(read_and_resize(path, SIZE))
    # Stack to (B, H, W, C), then move channels first and scale by 1/255.
    # Assumes read_and_resize returns 0-255 pixel values (JPEG) — verify for other formats.
    x_np = np.array(x_np).transpose([0,3,1,2])/255.0
    x_np = x_np.astype(np.float32)

    x = torch.from_numpy(x_np).to(DEVICE)
    # Back to channels-last, plus an all-ones alpha channel, giving RGBA targets.
    # ALPHA_CHANNEL (3) doubles as the RGB channel count here.
    target_np = x_np.reshape([-1, ALPHA_CHANNEL, SIZE, SIZE]).transpose([0,2,3,1])
    alpha_values = np.expand_dims(np.ones(target_np.shape[:-1]), -1)
    target_np = np.concatenate([target_np, alpha_values], -1)

    # The same seed state for every sample in the batch (make_seed is defined in lib.utils).
    seed = make_seed((SIZE,SIZE), CHANNEL_N, np.arange(CHANNEL_N-ALPHA_CHANNEL)+ALPHA_CHANNEL, init_coord)
    x0_np = np.repeat(seed[None, ...], len(name_batch), 0)
    x0 = torch.from_numpy(x0_np.astype(np.float32)).to(DEVICE)

    # infer already returns only the first ALPHA_CHANNEL+1 channels, clamped.
    y, history = my_model.infer(x0, x, N_STEPS)
    y = y.detach().cpu().numpy()

    # Row 1: targets; row 2: reconstructions.
    i_shows = [i for i in range(BATCH_SIZE)]
    plt.figure(figsize=(18,5))
    for i,ii in enumerate(i_shows):
        plt.subplot(2,len(i_shows),i+1)
        plt.imshow(to_rgb(target_np[ii]))
        plt.axis('off')
    for i,ii in enumerate(i_shows):
        plt.subplot(2,len(i_shows),i+len(i_shows)+1)
        plt.imshow(to_rgb(y[ii,...,:(ALPHA_CHANNEL+1)]))
        plt.axis('off')
    plt.show()
    print("----------")

In [None]:
# Robustness check: feed the encoder images with circular holes punched in,
# and see what the CA grows. Each figure has three rows: the original image,
# the damaged encoder input, and the CA output after N_STEPS steps.
n_batch = 1
for index in range(n_batch):
    # Distinct identity indices (replace=False needs len(names) >= BATCH_SIZE).
    name_batch = np.random.choice(len(names), BATCH_SIZE, replace=False)
    # One random image index per chosen identity.
    path_is = [np.random.randint(len(paths[names[names_i]])) for names_i in name_batch]
    x_np = []
    for i in range(len(name_batch)):
        name = names[name_batch[i]]
        path = ROOT+name+"/"+paths[name][path_is[i]]
        x_np.append(read_and_resize(path, SIZE))
    # (B, H, W, C) -> (B, C, H, W), scaled by 1/255; x_np_raw keeps the undamaged copy.
    x_np = np.array(x_np).transpose([0,3,1,2])/255.0
    x_np_raw = x_np.astype(np.float32)
    
    # One keep-mask per sample. The mask is True only where the summed
    # inverted circle masks reach n_damage — presumably pixels outside every
    # circle, assuming make_circle_masks returns n_damage binary (H, W) masks
    # that are 1 inside the circle (TODO confirm).
    damages = []
    for _ in range(BATCH_SIZE):
        n_damage = 8
        damage = 1.0-make_circle_masks(n_damage, SIZE, SIZE, rmin=0.02, rmax=0.05)
        damage = np.sum(damage, 0)>=n_damage
        damages.append(damage)
    # Add a singleton channel axis so the mask broadcasts over all image channels.
    damages = np.array(damages)[:, None, ...]
    x_np = x_np.astype(np.float32)*damages

    # The damaged images are what the encoder sees.
    x = torch.from_numpy(x_np).to(DEVICE)
    
    # RGBA view of the undamaged images (row 1); ALPHA_CHANNEL (3) doubles as RGB count.
    target_np_raw = x_np_raw.reshape([-1, ALPHA_CHANNEL, SIZE, SIZE]).transpose([0,2,3,1])
    alpha_values = np.expand_dims(np.ones(target_np_raw.shape[:-1]), -1)
    target_np_raw = np.concatenate([target_np_raw, alpha_values], -1)
    
    # RGBA view of the damaged images (row 2).
    target_np = x_np.reshape([-1, ALPHA_CHANNEL, SIZE, SIZE]).transpose([0,2,3,1])
    alpha_values = np.expand_dims(np.ones(target_np.shape[:-1]), -1)
    target_np = np.concatenate([target_np, alpha_values], -1)

    # The same seed state for every sample in the batch (make_seed is defined in lib.utils).
    seed = make_seed((SIZE,SIZE), CHANNEL_N, np.arange(CHANNEL_N-ALPHA_CHANNEL)+ALPHA_CHANNEL, init_coord)
    x0_np = np.repeat(seed[None, ...], len(name_batch), 0)
    x0 = torch.from_numpy(x0_np.astype(np.float32)).to(DEVICE)

    # infer already returns only the first ALPHA_CHANNEL+1 channels, clamped.
    y, history = my_model.infer(x0, x, N_STEPS)
    y = y.detach().cpu().numpy()

    # Row 1: original; row 2: damaged input; row 3: CA output.
    i_shows = [i for i in range(BATCH_SIZE)]
    plt.figure(figsize=(18,5))
    for i,ii in enumerate(i_shows):
        plt.subplot(3,len(i_shows),i+1)
        plt.imshow(to_rgb(target_np_raw[ii]))
        plt.axis('off')
    for i,ii in enumerate(i_shows):
        plt.subplot(3,len(i_shows),i+len(i_shows)+1)
        plt.imshow(to_rgb(target_np[ii]))
        plt.axis('off')
    for i,ii in enumerate(i_shows):
        plt.subplot(3,len(i_shows),i+len(i_shows)*2+1)
        plt.imshow(to_rgb(y[ii,...,:(ALPHA_CHANNEL+1)]))
        plt.axis('off')
    plt.show()
    print("----------")