In [None]:
import os, sys
sys.path.append("..")

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import numpy as np
import torch
import torch.nn as nn
import torchvision
import gc

from src import distributions
import torch.nn.functional as F

from src.resnet2 import ResNet_G
from src.unet import UNet
from src.tools import h5py_to_dataset

from src.tools import freeze

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

In [None]:
# GPUs to use; the first entry becomes the default CUDA device.
DEVICE_IDS = [0, 1, 2, 3]

# Choose the experiment by leaving exactly one (DATASET, DATASET_PATH) pair uncommented.
# DATASET, DATASET_PATH = 'ave_celeba', '../../data/ave_celeba/'
# DATASET, DATASET_PATH = 'celeba', '../../data/celeba_aligned/' # 202k Celeba Images resized to 64x64
# DATASET, DATASET_PATH = 'mnist01', '../../data/'
# DATASET, DATASET_PATH = 'fashionmnist_all', '../../data/'
# DATASET, DATASET_PATH = 'fashionmnist017', '../../data/'
DATASET, DATASET_PATH = 'handbag_shoes_fruit360', '../../data/' # shoes, handbags, fruit360

# Mean-shifting preprocessing is enabled only for the 'ave_celeba' experiment.
SHIFT = (DATASET == 'ave_celeba')

# Directory that holds the pretrained checkpoints of the selected experiment.
INPUT_PATH = '../checkpoints/{}/'.format(DATASET)

In [None]:
# The notebook runs on GPU only; make the first configured device the default one.
assert torch.cuda.is_available()
torch.cuda.set_device(f'cuda:{DEVICE_IDS[0]}')

# Per-experiment settings, in the order
# (IMG_SIZE, SIZE, NC, Z_DIM, ALPHAS, CLASSES).
# ALPHAS are the weights of the marginals; CLASSES are the labels used to
# split a labelled dataset into marginals (None when every marginal is
# loaded from its own source).
_DATASET_PARAMS = {
    'ave_celeba':             (64, 200000, 3, 128, [0.25, 0.5, 0.25], [0, 1, 2]),
    'celeba':                 (64, 200000, 3, 128, [1.], [0]),
    'fashionmnist_all':       (32, 6000, 1, 16, [0.1] * 10, list(range(10))),
    'fashionmnist017':        (32, 6000, 1, 16, [1/3.] * 3, [0, 1, 7]),
    'mnist01':                (32, 6000, 1, 16, [0.5, .5], [0, 1]),
    'handbag_shoes_fruit360': (64, 6000, 3, 128, [1./3, 1./3, 1./3], None),
}

if DATASET not in _DATASET_PARAMS:
    raise Exception('Unknown dataset')

IMG_SIZE, SIZE, NC, Z_DIM, ALPHAS, CLASSES = _DATASET_PARAMS[DATASET]

# Number of marginal distributions.
K = len(ALPHAS)

## Prepare Samplers (Z, Y)

In [None]:
# Latent sampler feeding the generator, plus one data sampler per marginal.
Z_sampler = distributions.StandardNormalSampler(dim=Z_DIM)
Y_samplers = []

# Preprocessing applied to every image loaded through torchvision:
# white padding (handbag/shoes/fruit360 experiment only), resize to IMG_SIZE,
# conversion to a tensor and the affine rescaling x -> 2x - 1.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Pad(14, fill=(255,255,255)) if DATASET == 'handbag_shoes_fruit360' else torchvision.transforms.Lambda(lambda x:x),
    torchvision.transforms.Resize(IMG_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: 2 * x - 1)
])

if DATASET != 'handbag_shoes_fruit360':
    # Labelled datasets: the k-th marginal is the subset with label CLASSES[k].
    for k in range(K):
        if DATASET in ['ave_celeba', 'celeba']:
            dataset = torchvision.datasets.ImageFolder(DATASET_PATH, transform=transform)
        elif 'fashionmnist' in DATASET:
            dataset = torchvision.datasets.FashionMNIST(root=DATASET_PATH, download=True, transform=transform, train=True)
        elif 'mnist' in DATASET:
            dataset = torchvision.datasets.MNIST(root=DATASET_PATH, download=True, transform=transform)
        else:
            raise Exception('Unknown dataset')

        if hasattr(dataset, 'samples'):
            # Folder-style datasets keep (path, label) pairs in `samples`.
            dataset.samples = [s for s in dataset.samples if s[1] == CLASSES[k]]
        else:
            # In-memory datasets keep parallel `targets` / `data` containers;
            # build the class mask once and apply it to both.
            keep = np.array(dataset.targets) == CLASSES[k]
            dataset.targets = np.array(dataset.targets)[keep]
            if 'mnist' in DATASET:
                # as_tensor avoids the copy (and the warning) that
                # torch.tensor() produces when given an existing tensor.
                dataset.data = torch.as_tensor(dataset.data)[torch.from_numpy(keep)]
            else:
                dataset.data = np.array(dataset.data)[keep]

        Y_samplers.append(distributions.DatasetSampler(dataset))

else:
    # Each marginal comes from its own source: two HDF5 files and an image folder.
    dataset = h5py_to_dataset(os.path.join(DATASET_PATH, 'handbag_64.hdf5'))
    Y_samplers.append(distributions.DatasetSampler(dataset))
    dataset = h5py_to_dataset(os.path.join(DATASET_PATH, 'shoes_64.hdf5'))
    Y_samplers.append(distributions.DatasetSampler(dataset))
    dataset = torchvision.datasets.ImageFolder(os.path.join(DATASET_PATH, 'fruit360'), transform=transform)
    Y_samplers.append(distributions.DatasetSampler(dataset))

# NOTE(review): `Y_samplers[k].dataset` is used as a tensor below (mean over
# dim 0, in-place addition) -- presumably DatasetSampler materialises the
# images into one tensor; confirm in src.distributions.
with torch.no_grad():
    # Per-marginal means and their ALPHAS-weighted average.
    Y_bar_mean = 0.
    Y_means = []
    for k in range(K):
        Y_means.append(Y_samplers[k].dataset.mean(dim=0))
        Y_bar_mean += ALPHAS[k] * Y_means[-1]
    # Offset of every marginal mean from the weighted mean (kept on the GPU).
    Y_shifts = [(Y_means[k] - Y_bar_mean).cuda() for k in range(K)]

    if SHIFT:
        # Re-centre every marginal, in place, on the weighted mean.
        for k in range(K):
            Y_samplers[k].dataset += Y_bar_mean - Y_means[k]

torch.cuda.empty_cache(); gc.collect()
clear_output()

## Loading Networks

In [None]:
def _load_frozen(net, checkpoint_name):
    """Wrap `net` for multi-GPU use if needed, load its weights and freeze it.

    The network is wrapped in nn.DataParallel when more than one device is
    configured in DEVICE_IDS. This happens before loading, exactly as in the
    original cell, so the checkpoint keys are matched against the wrapped module.
    Weights are read from INPUT_PATH/checkpoint_name. They are deserialised on
    the CPU and copied into the network's parameters by load_state_dict, so
    loading does not depend on the device the checkpoint was saved from.
    """
    if len(DEVICE_IDS) > 1:
        net = nn.DataParallel(net, device_ids=DEVICE_IDS)
    state = torch.load(os.path.join(INPUT_PATH, checkpoint_name), map_location='cpu')
    net.load_state_dict(state)
    freeze(net)
    return net


def _make_unet():
    """Build a UNet with NC input and output channels on the GPU (wider when NC == 3)."""
    return UNet(NC, NC, base_factor=48 if NC == 3 else 16).cuda()


# Generator.
G = _load_frozen(ResNet_G(Z_DIM, IMG_SIZE, nc=NC).cuda(), 'G.pt')

# One map T_k per marginal and, except for 'celeba', one inverse map T_inv_k.
Ts, Ts_inv = [], []
for k in range(K):
    Ts.append(_load_frozen(_make_unet(), f'T_{k}.pt'))
    if DATASET != 'celeba':
        Ts_inv.append(_load_frozen(_make_unet(), f'T_inv_{k}.pt'))

## Samples

In [None]:
# Fix both RNGs so the displayed samples are reproducible.
torch.manual_seed(0xBADBEEF); np.random.seed(0xBADBEEF)

# A batch of 12 latent codes; one batch of 12 samples is drawn per marginal below.
Z = Z_sampler.sample(12)
with torch.no_grad():
    # Generated images.
    X = G(Z)
    # Samples from every marginal.
    # (Alternative: take the first items instead, Y_samplers[k].dataset[:12].cuda())
    Ys = [sampler.sample(12) for sampler in Y_samplers]
    # Generated images pushed through every map T_k.
    Ts_X = [T(X) for T in Ts]
    if DATASET != 'celeba':
        # Marginal samples pushed through the corresponding inverse map.
        Ts_inv_Y = [T_inv(Y) for T_inv, Y in zip(Ts_inv, Ys)]

    if SHIFT:
        # Undo the mean-shift preprocessing (in place) for display.
        for shift, mapped, sampled in zip(Y_shifts, Ts_X, Ys):
            mapped += shift
            sampled += shift

    # ALPHAS-weighted average of the mapped images.
    X_avg = sum((Ts_X[k] * ALPHAS[k] for k in range(K)), 0.)

    # Baseline: marginal samples translated by the constant mean offset.
    Ts_inv_Y_base = [Y - shift for Y, shift in zip(Ys, Y_shifts)]

## Generated images and maps to marginals

In [None]:
# Grid layout: row 0 shows generated images, rows 1..K their images under
# T_1..T_K, and the last row the weighted average of those maps.
fig, axes = plt.subplots(K+2, 12, figsize=(26.5, 2 * (K+2)),dpi=200)

# Stack all rows, map the values through (x + 1) / 2 and move channels last for imshow.
stacked = torch.cat([X, *Ts_X, X_avg]).to('cpu')
imgs = stacked.add(1).mul(0.5).permute(0, 2, 3, 1).detach().numpy().clip(0,1)

# Single-channel images are drawn with a reversed gray colormap.
colormap = plt.get_cmap('gray').reversed() if NC == 1 else None
for ax, img in zip(axes.flatten(), imgs):
    ax.imshow(img, cmap=colormap)
    ax.get_xaxis().set_visible(False)
    ax.set_yticks([])

# Row labels, placed on the first column.
ylabel_style = dict(fontsize=38, rotation='horizontal', va="center", labelpad=95)
axes[0,0].set_ylabel(r'$\mathbb{P}_{\xi}\!=\!G_{\xi}\sharp\mathbb{S}$', **ylabel_style)
axes[-1,0].set_ylabel(r'$\approx\mathcal{H}(\mathbb{P}_{\xi})$', **ylabel_style)
for k in range(K):
    row_label = r'$\mathbb{P}_{\xi}\rightarrow\mathbb{P}_{' + str(k+1) + '}$'
    axes[k+1,0].set_ylabel(row_label, **ylabel_style)

fig.tight_layout(pad=0.01)

## Maps from marginals to the barycenter

In [None]:
# Inverse maps are not available for the 'celeba' experiment.
assert DATASET != 'celeba'

# One 3x3 figure per marginal: samples (row 0), the constant-shift
# baseline (row 1) and the learned inverse map (row 2).
for k in range(K):
    fig, axes = plt.subplots(3, 3, figsize=(8.3, 6),dpi=200)

    rows = [Ys[k][:3], Ts_inv_Y_base[k][:3], Ts_inv_Y[k][:3]]
    imgs = torch.cat(rows).to('cpu').add(1).mul(0.5).permute(0, 2, 3, 1).detach().numpy().clip(0,1)

    colormap = plt.get_cmap('gray').reversed() if NC == 1 else None
    for ax, img in zip(axes.flatten(), imgs):
        ax.imshow(img, cmap=colormap)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])

    # Label fragments shared by the three rows.
    marginal = r'$\mathbb{P}$' + f'$_{{{k+1}}}$'
    to_barycenter = r'$\rightarrow\mathbb{P}_{\xi}$'
    ylabel_style = dict(fontsize=38, rotation='horizontal', va="center", labelpad=95)

    axes[0,0].set_ylabel(marginal, **ylabel_style)
    axes[1,0].set_ylabel(r'$\lfloor$' + 'CS' + r'$\rceil$' + '\n' + marginal + to_barycenter, **ylabel_style)
    axes[2,0].set_ylabel('Our' + '\n' + marginal + to_barycenter, **ylabel_style)

    fig.tight_layout(pad=0.01)
    plt.show()

## Maps through the barycenter

In [None]:
# Inverse maps are not available for the 'celeba' experiment.
assert DATASET != 'celeba'
idx = 0 # Index of the image to push

# Row n: sample of marginal n, its image under T_inv_n, then that image
# pushed through every forward map T_1..T_K.
fig, axes = plt.subplots(K, K+2, figsize=(2*K+4+1.6, 2*K+0.7),dpi=200)
imgs = []
with torch.no_grad():
    for k1 in range(K):
        source = Ys[k1][idx][None]        # [None] keeps a batch dimension of 1
        in_bary = Ts_inv_Y[k1][idx][None]
        imgs.append(source)
        imgs.append(in_bary)
        for k2 in range(K):
            pushed = Ts[k2](in_bary)
            if SHIFT:
                # Undo the mean-shift preprocessing of the target marginal.
                pushed = pushed + Y_shifts[k2]
            imgs.append(pushed)

# Map the values through (x + 1) / 2 and move channels last for imshow.
imgs = torch.cat(imgs).to('cpu').add(1).mul(0.5).permute(0, 2, 3, 1).detach().numpy().clip(0,1)
for i, ax in enumerate(axes.flatten()):
    ax.imshow(imgs[i], cmap=plt.get_cmap('gray').reversed() if NC == 1 else None)
    ax.get_xaxis().set_visible(False)
    ax.set_yticks([])

axes[0, 0].set_title(r'$\mathbb{P}_{n}$', fontsize=37, rotation='horizontal', va="center", pad=25)
axes[0, 1].set_title(r'$\mathbb{P}_{n}\rightarrow\mathbb{P}_{\xi}$', fontsize=37, rotation='horizontal', va="center", pad=25)

for k in range(K):
    axes[k, 0].set_ylabel(r'$n=$' + str(k+1), fontsize=37, rotation='horizontal', va="center", labelpad=65)
    # The subscript group is closed exactly once; a second '}' leaves the
    # math expression unbalanced, which strict mathtext parsers reject.
    title_1 = r'$\mathbb{P}_{\xi}\rightarrow\mathbb{P}_{' + str(k+1) + '}$'
    axes[0, k+2].set_title(title_1, fontsize=37, rotation='horizontal', va="center", pad=25)
fig.tight_layout(pad=0.01)
plt.show()

## Maps through the barycenter v2

In [None]:
# Inverse maps are not available for the 'celeba' experiment.
assert DATASET != 'celeba'

# One figure per source marginal k, three samples per row:
# row 0 the samples, row 1 their images under T_inv_k, rows 2.. those
# images pushed through every forward map T_1..T_K.
for k in range(K):
    fig, axes = plt.subplots(2+K, 3, figsize=(8.3, 2*(2+K)),dpi=200)

    with torch.no_grad():
        in_bary = Ts_inv_Y[k][:3]
        pushed = []
        for k2 in range(K):
            out = Ts[k2](in_bary)
            if SHIFT:
                # The offset is added only when the data were mean-shifted,
                # matching how Ts_X is post-processed in the sampling cell.
                out = out + Y_shifts[k2]
            pushed.append(out)

    imgs = torch.cat([Ys[k][:3], in_bary, *pushed])
    # Map the values through (x + 1) / 2 and move channels last for imshow.
    imgs = imgs.to('cpu').add(1).mul(0.5).permute(0, 2, 3, 1).detach().numpy().clip(0,1)
    for i, ax in enumerate(axes.flatten()):
        ax.imshow(imgs[i], cmap=plt.get_cmap('gray').reversed() if NC == 1 else None)
        ax.get_xaxis().set_visible(False)
        ax.set_yticks([])

    axes[0,0].set_ylabel(r'$\mathbb{P}$' + f'$_{{{k+1}}}$', fontsize=38, rotation='horizontal', va="center", labelpad=95)
    axes[1,0].set_ylabel(r'$\mathbb{P}$' + f'$_{{{k+1}}}$' + r'$\rightarrow\mathbb{P}_{\xi}$', rotation='horizontal', fontsize=38, va="center", labelpad=95)
    for k2 in range(K):
        axes[k2+2,0].set_ylabel(
            r'$\mathbb{P}$' + f'$_{{{k+1}}}$' + r'$\rightarrow\mathbb{P}_{\xi}$' + '\n' +
            r'$\mathbb{P}_{\xi}$' + r'$\rightarrow\mathbb{P}$' + f'$_{{{k2+1}}}$',
            rotation='horizontal', fontsize=38, va="center", labelpad=95)

    fig.tight_layout(pad=0.01)
    plt.show()