In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def style_swap(content_feature, style_feature, kernel_size=5, stride=1):
    # content_feature and style_feature should have shape as (1, C, H, W)
    # kernel_size here is equivalent to extracted patch size

    # extract patches from style_feature with shape (1, C, H, W)
    kh, kw = kernel_size, kernel_size
    sh, sw = stride, stride

    patches = style_feature.unfold(2, kh, sh).unfold(3, kw, sw)
    
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    
    patches = patches.reshape(-1, *patches.shape[-3:]) # (patch_numbers, C, kh, kw)
    print(patches.shape)
    # calculate Frobenius norm and normalize the patches at each filter
    norm = torch.norm(patches.reshape(patches.shape[0], -1), dim=1).reshape(-1, 1, 1, 1)
    
    noramalized_patches = patches / norm

    conv_out = F.conv2d(content_feature, noramalized_patches)
    
    # calculate the argmax at each spatial location, which means at each (kh, kw),
    # there should exist a filter which provides the biggest value of the output
    one_hots = torch.zeros_like(conv_out)
    one_hots.scatter_(1, conv_out.argmax(dim=1, keepdim=True), 1)

    # deconv/transpose conv
    deconv_out = F.conv_transpose2d(one_hots, patches)

    # calculate the overlap from deconv/transpose conv
    overlap = F.conv_transpose2d(one_hots, torch.ones_like(patches))

    # average the deconv result
    res = deconv_out / overlap
    return res

In [3]:
a = torch.rand((1, 512, 32, 32))
b = torch.rand((1, 512, 64, 64))

In [4]:
res = style_swap(b, a)
print(res.shape)

torch.Size([784, 512, 5, 5])
torch.Size([1, 512, 64, 64])


In [154]:
def multi_scale_style_swap(content_feature, style_feature, kernel_size=5, stride=1):
    c_shape = content_feature.shape
    s_shape = style_feature.shape
    assert (c_shape[1] == s_shape[1])
    
    combined_feature_maps = []
    for beta in [1.0/2, 1.0/(2**0.5), 1.0]:
        new_height = int(float(s_shape[2]) * beta)
        new_width = int(float(s_shape[3]) * beta)
        tmp_style_features = F.interpolate(style_feature, \
            size=(new_height, new_width), mode='bilinear', align_corners=True)
        
        combined_feature = style_swap(content_feature, \
            style_feature, kernel_size=kernel_size, stride=stride)
        
        combined_feature_maps.append(combined_feature)
    combined_feature_maps.append(content_feature)
    return combined_feature_maps

In [155]:
res = multi_scale_style_swap(a, b)

In [157]:
for feature in res:
    print(feature.shape)

torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 32, 32])
torch.Size([1, 512, 32, 32])


In [243]:
# def KMeans(image, clusters_num):
# # KMeans for only h*w*1 tensors
#     image = image.squeeze()
#     _points = image.reshape(-1, 1)
#     # randomly permuate the points and select cluster_num centroids
#     idx = torch.randperm(_points.shape[0])
#     centroids = _points[idx][:clusters_num]
    
#     # expand the points dimension to the cluster
#     points_expanded = _points.repeat(1, clusters_num)
    
#     for i in range(80):
#         centroids_expanded  = centroids.permute(1,0).repeat(points_expanded.shape[0], 1)
#         distances = (points_expanded - centroids_expanded) ** 2
#         # same shape as distance, by expanding argmin
#         distances_min_expand = distances.min(1).values.unsqueeze(1).repeat(1, distances.shape[1])
#         # the mask that can be used to filter the value
#         mask = (distances == distances_min_expand).float()
#         centroids = (distances_min_expand * mask).sum(0).unsqueeze(1)
#         print(centroids)
        
#     return centroids

# https://www.kernel-operations.io/keops/_auto_tutorials/kmeans/plot_kmeans_torch.html
def KMeans(x, K=10, Niter=80, verbose=True):
    N, D = x.shape  # Number of samples, dimension of the ambient space

    # K-means loop:
    # - x  is the point cloud,
    # - cl is the vector of class labels
    # - c  is the cloud of cluster centroids

    c = x[:K, :].clone()  # Simplistic random initialization
    x_i = x.unsqueeze(1)  # (Npoints, 1, D)
    for i in range(Niter):
        c_j = c.unsqueeze(0)  # (1, Nclusters, D)
        D_ij = ((x_i - c_j) ** 2).sum(-1)  # (Npoints, Nclusters) symbolic matrix of squared distances
        cl = D_ij.argmin(dim=1).long().view(-1)  # Points -> Nearest cluster

        Ncl = torch.bincount(cl).type(torch.float64)  # Class weights
        for d in range(D):  # Compute the cluster centroids with torch.bincount:
            c[:, d] = torch.bincount(cl, weights=x[:, d]) / Ncl
            
    return cl, c

In [271]:
def multi_stroke_fusion(stylized_maps, attention_map, theta=50.0, mode='softmax'):
    stroke_num = len(stylized_maps)
    if stroke_num == 1:
        return stylized_maps[0]
    
    one_channel_attention = torch.mean(attention_map, 1).unsqueeze(1)
    origin_shape = one_channel_attention.shape
    one_channel_attention = one_channel_attention.reshape((-1, 1)) # stretch to tensor (hw)* 1 
    _, centroids = KMeans(one_channel_attention, stroke_num)
    
    one_channel_attention = one_channel_attention.reshape(origin_shape)
    
    saliency_distances = []
    for i in range(stroke_num):
        saliency_distances.append(torch.abs(one_channel_attention - centroids[i]))
    
    multi_channel_saliency = torch.cat(saliency_distances, 1)
    
    softmax = nn.Softmax(dim=1)
    
    multi_channel_saliency = softmax(theta*(1.0 - multi_channel_saliency))
    
    finial_stylized_map = 0
    for i in range(stroke_num):
        temp = multi_channel_saliency[0, i, :, :].unsqueeze(0).unsqueeze(0)
        finial_stylized_map += temp * stylized_maps[i]
    return finial_stylized_map, centroids
    

In [272]:
a = torch.rand((1, 512, 32, 32))
b = torch.rand((1, 512, 32, 32))

In [274]:
finial_stylized_map, _ = multi_stroke_fusion([a, a, a], b)
print(finial_stylized_map.shape)

torch.Size([1, 512, 32, 32])


In [None]:
def zca_normalization(features):
    # [b, c, h, w]
    shape = features.shape

    # reshape the features to orderless feature vectors
    mean_features = torch.mean(features, dim=(2, 3), keepdims=True)
    unbiased_features = (features - mean_features).view(shape[0], shape[1], -1) # [b, c, h*w]

    # get the convariance matrix
    gram = torch.bmm(unbiased_features, unbiased_features.permute(0, 2, 1)) # [b, c, c]
    gram = gram / (shape[1] * shape[2] * shape[3])

    # converting the feature spaces
    u, s, v = torch.svd(gram, compute_uv=True)
    # u: [b, c, c], s: [b, c], v: [b, c, c]
    s = torch.unsqueeze(s, dim=1)

    # get the effective singular values
    valid_index = (s > 0.00001).float()
    s_effective = torch.max(s, torch.empty(s.shape).fill_(0.00001))
    sqrt_s_effective = torch.sqrt(s_effective) * valid_index
    sqrt_inv_s_effective = torch.sqrt(1.0 / s_effective) * valid_index
    print(s_effective.shape)

    # colorization functions
    colorization_kernel = torch.bmm((u * sqrt_inv_s_effective), v.permute(0, 2, 1))

    # normalized features
    normalized_features = torch.bmm(unbiased_features.permute(0, 2, 1), u).permute(0, 2, 1)
    normalized_features = (normalized_features.permute(0, 2, 1) * sqrt_inv_s_effective).permute(0, 2, 1)
    normalized_features = torch.bmm(normalized_features.permute(0, 2, 1), v.permute(0, 2, 1)).permute(0, 2, 1)
    normalized_features = normalized_features.view(shape)

    return normalized_features, colorization_kernel, mean_features

In [None]:
def zca_colorization(normalized_features, colorization_kernel, mean_features):
    # broadcasting the tensors for matrix multiplication
    shape = normalized_features.shape
    normalized_features = normalized_features.view(shape[0], shape[1], -1) # [b, c, h*w]

    colorized_features = torch.bmm(normalized_features.permute(0, 2, 1), colorization_kernel).permute(0, 2, 1)
    colorized_features = colorized_features.view(shape) + mean_features


    # normalized_features = normalized_features.permute(0, 2, 3, 1).view(shape[0], -1, shape[1]) # [b, c, h*w]

    # colorized_features = torch.bmm(normalized_features, colorization_kernel)
    # colorized_features = colorized_features.permute(0, 2, 1).view(shape) + mean_features
    return colorized_features