In [None]:
# import statements

import importlib
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_dct as dct
import torchvision

import samplers
from network_unet import UNet
from randomized_svd_jacobian import randomized_svd as rsvd

importlib.reload(samplers)

In [None]:
# Select the compute device. GPU index 1 is preferred when at least two GPUs
# are visible; on a single-GPU machine "cuda:1" is an invalid ordinal, so fall
# back to GPU 0, and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Load CelebA
# Every image is resized to 256 x 256 and then converted to a tensor.

to_tensor = torchvision.transforms.ToTensor()
downsize = torchvision.transforms.Resize((256, 256))
# Order matters: resize first, then convert to a tensor.
composed_transform = torchvision.transforms.Compose([downsize, to_tensor])
root = "" # path to CelebA dataset
# download=True asks torchvision to fetch the dataset into `root` if needed.
trainset = torchvision.datasets.CelebA(root=root, split='train', download=True, transform=composed_transform)
# NOTE(review): trainset_abridged is not used by any visible cell; the loader
# below iterates over the full `trainset` -- confirm which one is intended.
trainset_abridged = torch.utils.data.Subset(trainset, range(2000)) # 2000 images
batch_size = 16
# shuffle=False: batches come out in dataset order, so the first batch is
# always the same 16 images.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
class MLP(nn.Module):
    """Fully connected network with ELU activations between its linear layers.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Number of features produced by the final linear layer.
        hidden_dims: Sequence of hidden-layer widths. May be empty, in which
            case the network is a single linear map.
        output_shape: Per-sample shape the final output is reshaped to.
            Defaults to ``(3, 80, 80)``, the shape that was previously
            hard-coded in ``forward``.

    The layers are stored in ``self.layers`` (an ``nn.ModuleList``), so the
    ``state_dict`` keys are ``layers.<i>.weight`` / ``layers.<i>.bias``.
    """

    def __init__(self, input_dim, output_dim, hidden_dims, output_shape=(3, 80, 80)):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.output_shape = tuple(output_shape)
        # Chain all widths so consecutive pairs define the linear layers; with
        # no hidden widths this yields one Linear(input_dim, output_dim).
        widths = [input_dim, *hidden_dims, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x):
        """Apply every layer (ELU after all but the last) and reshape.

        Returns a tensor of shape ``(-1, *output_shape)``; the leading
        dimension is inferred by ``reshape``.
        """
        for layer in self.layers[:-1]:
            x = F.elu(layer(x))
        return self.layers[-1](x).reshape(-1, *self.output_shape)

In [None]:
# Enc1 is the first encoder stage: a single linear input layer followed by an ELU activation

class Enc1(nn.Module):
    """First encoder stage: one linear layer followed by an ELU nonlinearity.

    Args:
        input_dim: Number of input features.
        hidden_dim: Number of output (hidden) features.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.linear = nn.Linear(in_features=input_dim, out_features=hidden_dim)

    def forward(self, x):
        """Return ``elu(linear(x))``."""
        pre_activation = self.linear(x)
        return F.elu(pre_activation)

# Enc2 is the second encoder stage: a single linear output layer with no activation

class Enc2(nn.Module):
    """Second encoder stage: a single linear layer with no activation.

    Args:
        hidden_dim: Number of input (hidden) features.
        output_dim: Number of output features.
    """

    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.hidden_dim, self.output_dim = hidden_dim, output_dim
        self.linear = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Return ``linear(x)``."""
        projected = self.linear(x)
        return projected

In [None]:
class TruncatedDCT(nn.Module):
    """2-D DCT followed by truncation to the leading coefficient block.

    Args:
        num_coeffs: Number of coefficients kept along each of the last two
            axes of the transformed tensor.
    """

    def __init__(self, num_coeffs):
        super().__init__()
        self.num_coeffs = num_coeffs

    def forward(self, x):
        """Return ``dct_2d(x)[:, :, :num_coeffs, :num_coeffs]``.

        The indexing requires a 4-D input.
        """
        keep = self.num_coeffs
        coeffs = dct.dct_2d(x)
        return coeffs[:, :, :keep, :keep]

class TruncatedIDCT(nn.Module):
    """Zero-pad truncated DCT coefficients back to full size, then invert.

    Args:
        num_coeffs: Size of the last two axes of the truncated coefficient
            block that ``forward`` receives.
        original_size: Size of the last two axes after zero padding, i.e. the
            size fed to the inverse 2-D DCT.
    """

    def __init__(self, num_coeffs, original_size):
        super(TruncatedIDCT, self).__init__()
        self.num_coeffs = num_coeffs
        self.original_size = original_size

    def forward(self, X_dct_trunc):
        """Return ``idct_2d`` of the zero-padded coefficients.

        ``X_dct_trunc`` is expected to have shape
        ``(batch, channels, num_coeffs, num_coeffs)``; the concatenations
        below fail on any other trailing shape.
        """
        batch, channels = X_dct_trunc.shape[0], X_dct_trunc.shape[1]
        pad = self.original_size - self.num_coeffs
        # new_zeros allocates on the input's own device and dtype, so the
        # module no longer depends on a notebook-level `device` global and
        # works for inputs on any device.
        pad_rows = X_dct_trunc.new_zeros(batch, channels, pad, self.num_coeffs)
        X_dct_full = torch.cat([X_dct_trunc, pad_rows], dim=2)
        pad_cols = X_dct_trunc.new_zeros(batch, channels, self.original_size, pad)
        X_dct_full = torch.cat([X_dct_full, pad_cols], dim=3)
        return dct.idct_2d(X_dct_full)

In [None]:
# Load enc1, enc2, dec from checkpoint

checkpt_fname = "" # fill in with checkpoint filename
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
checkpt = torch.load(checkpt_fname, map_location=device)

D = 3*80*80 # length of one flattened 3 x 80 x 80 block of truncated DCT coefficients
input_dim = D
latent_dim = 700
hidden_dim = 10000

num_coeffs = 80 # DCT coefficients kept per spatial axis (the 80s in D)

trunc_dct = TruncatedDCT(num_coeffs=num_coeffs)
# original_size=256 matches the Resize((256, 256)) applied when loading the data.
trunc_idct = TruncatedIDCT(num_coeffs=num_coeffs, original_size=256)

enc1 = Enc1(input_dim, hidden_dim).to(device)
enc2 = Enc2(hidden_dim, latent_dim).to(device)
output_unet = UNet(in_nc=3, out_nc=3).to(device)
# Decoder pipeline: latent -> MLP (reshaped to 3 x 80 x 80 coefficients)
# -> zero-padded inverse DCT -> U-Net.
dec = nn.Sequential(MLP(latent_dim,input_dim,[hidden_dim]), trunc_idct, output_unet).to(device)

# The architectures built above must match the ones that were saved, otherwise
# load_state_dict raises on mismatched keys or shapes.
enc1.load_state_dict(checkpt['enc1_state_dict'])
enc2.load_state_dict(checkpt['enc2_state_dict'])
dec.load_state_dict(checkpt['dec_state_dict'])
# NOTE(review): the modules are left in training mode (no .eval() call); if
# UNet contains dropout or batch-norm layers this affects the outputs below --
# confirm against network_unet.

losses = checkpt['losses']

In [None]:
# Drop the name bound to the loaded checkpoint dict now that the state dicts
# have been loaded into the modules; `losses` keeps its own reference to
# checkpt['losses'].
del checkpt

In [None]:
# Plot the natural log of the loss values stored in the checkpoint against
# their index (presumably one entry per training step or epoch -- confirm
# against the training script).
plt.plot(torch.log(torch.tensor(losses)))

In [None]:
# Concatenate enc1 and enc2 to get encoder
# (maps a flattened truncated-DCT vector of length D to a latent of length latent_dim)

enc = nn.Sequential(enc1, enc2)

# Concatenate enc and dec to get autoencoder
# (dec = MLP -> trunc_idct -> output_unet, as assembled in the checkpoint-loading cell)

autoencoder = nn.Sequential(enc, dec)

In [None]:
# Take the first batch (batch_size images) from the training loader.
# The loader is built with shuffle=False, so this is always the same batch,
# not a random draw. The labels returned alongside the images are discarded.

X, _ = next(iter(trainloader))

X = X.to(device)
# Take truncated DCT of X
X_dct = trunc_dct(X) # X.shape = (batch_size, 3, 256, 256) -> X_dct.shape = (batch_size, 3, num_coeffs, num_coeffs)
# Flatten each image's coefficient block into a vector of length D = 3*80*80.
X_dct = X_dct.reshape(X_dct.shape[0], D)

In [None]:
# Singular values of encoder Jacobian at training point

idx = 0 # index of the sample within the batch loaded above
x = X_dct[idx]
rank = 700 # same value as latent_dim
# NOTE(review): CUDA kernels run asynchronously, so on a GPU this wall-clock
# timing may under-report unless rsvd synchronizes internally -- confirm, or
# call torch.cuda.synchronize() before reading the clock.
start = time.time()
U, S, V = rsvd(enc, x, rank, oversampling_factor=10)
end = time.time()
print('Time for rsvd: %.4f' % (end - start))

In [None]:
# Plot the singular values S returned by rsvd (detached and moved to the CPU
# so matplotlib can consume them).
plt.plot(S.detach().cpu())

In [None]:
def show(img):
    """Display an image tensor in a new 20 x 5 inch figure without axis ticks.

    Args:
        img: Tensor whose axes are reordered with ``(1, 2, 0)`` before being
            passed to ``plt.imshow`` (i.e. channel-first input). It may live
            on any device and may require grad.
    """
    # detach() and cpu() are no-ops for a plain CPU tensor, but without them
    # .numpy() raises for CUDA tensors and for tensors that require grad.
    npimg = img.detach().cpu().numpy()
    plt.figure(figsize=(20,5))
    # no ticks
    plt.xticks([])
    plt.yticks([])
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

In [None]:
# Visualize the first 8 elements of X

X_viz = X[:8].cpu()

# Visualize X_viz on torch grid
# (nrow=10 is larger than the 8 images, so they all land in a single row)

grid = torchvision.utils.make_grid(X_viz.reshape(-1,3,256,256), nrow=10)
show(grid)

In [None]:
# Visualize outputs of decoder MLP
# (dec[0] followed by the zero-padded inverse DCT, skipping the final U-Net stage of dec)

# Inference only: no_grad avoids building an autograd graph that would be
# thrown away immediately, which saves memory.
with torch.no_grad():
    dec_mlp_out = dec[0](enc(X_dct[:5]))
    dec_mlp_out = trunc_idct(dec_mlp_out)
dec_mlp_out = dec_mlp_out.cpu()

# Visualize dec_mlp_out on torch grid

grid = torchvision.utils.make_grid(dec_mlp_out.reshape(-1,3,256,256), nrow=5)
show(grid)

In [None]:
# Visualize reconstructions of the first 8 samples

# Inference only: no_grad avoids building an autograd graph through the
# encoder and decoder that would be discarded immediately.
with torch.no_grad():
    X_rec = dec(enc(X_dct[:8]))
X_rec = X_rec.cpu()

# Visualize X_rec on torch grid

grid = torchvision.utils.make_grid(X_rec.reshape(-1,3,256,256), nrow=10)
show(grid)

In [None]:
# Interpolate between latents and visualize reconstructions

Z = enc(X_dct)
Z_interp = torch.zeros(10, latent_dim).to(device)
Z_interp[0] = Z[0]
Z_interp[-1] = Z[3]
for i in range(1, 9):
    Z_interp[i] = (i/9) * Z[3] + ((9-i)/9) * Z[0]
X_interp = dec(Z_interp)
X_interp = X_interp.detach().cpu()

# Visualize X_interp on torch grid

grid = torchvision.utils.make_grid(X_interp.reshape(-1,3,256,256), nrow=10)
show(grid)

In [None]:
# Take the top 5 left-singular vectors of the encoder Jacobian at a training point

idx = 0
x = X_dct[idx]
rank = 700
# rsvd is kept outside torch.no_grad(): it presumably relies on autograd to
# probe the Jacobian of enc -- confirm against randomized_svd_jacobian.
U, S, V = rsvd(enc, x, rank, oversampling_factor=10)
u = U[:,:5]

# Move along each of the top 5 left-singular vectors of enc at training point
# Visualize reconstructions along rows (one row of 8 steps per singular vector)

# Inference only: no_grad avoids building an autograd graph through the
# decoder for all 40 images, which would be discarded immediately.
with torch.no_grad():
    Z = enc(X_dct[:10])
    Z_interp = torch.zeros(u.shape[1]*8, latent_dim).to(device)
    t = torch.linspace(-1, 1, 8).to(device)
    scale = 20000 # 2k for sigma=0.5, 20k for sigma=0
    for i in range(u.shape[1]):
        for j in range(8):
            # Row i, column j: offset the latent by t[j] * scale along vector i.
            Z_interp[i*8+j] = Z[idx] + t[j] * scale * u[:,i]
    X_interp = dec(Z_interp)
X_interp = X_interp.cpu()

# Visualize X_interp on torch grid

grid = torchvision.utils.make_grid(X_interp.reshape(-1,3,256,256), nrow=8)
show(grid)

# Save the grid
# (create the output directory first so savefig does not fail on a fresh checkout)

out_path = 'results/sigma0_top5_left_singular_vecs_training_0.png'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path, bbox_inches='tight', dpi=300)