In [None]:
# Load the CelebA attribute table (list_attr_celeba.csv) and collect the image ids of interest

import os
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Folder holding the aligned CelebA images; image ids are appended to it later.
img_path = "data/img_align_celeba"

# Only the first 10000 rows of the attribute table are kept.
df = pd.read_csv("data/list_attr_celeba.csv").head(10000)
print(df.columns)

# Image ids of rows with Male == 1, split by the Eyeglasses flag (1 vs. -1).
with_glasses = df.loc[(df['Eyeglasses'] == 1) & (df['Male'] == 1), 'image_id'].tolist()
without_glasses = df.loc[(df['Eyeglasses'] == -1) & (df['Male'] == 1), 'image_id'].tolist()

# Group sizes: first with glasses, then without.
for id_list in (with_glasses, without_glasses):
    print(len(id_list))

# Spot-check one of the selected ids.
print(with_glasses[-5])




In [None]:
# Five randomly drawn (with replacement) images from the with-glasses group.
glass_images = [
    Image.open(f'{img_path}/{with_glasses[np.random.randint(0, len(with_glasses))]}')
    for _ in range(5)
]
f = plt.figure()
for position, picture in enumerate(glass_images, start=1):
    # One row of five panels, axes hidden.
    panel = f.add_subplot(1, 5, position)
    panel.axis('off')
    panel.imshow(picture)

plt.show(block=True)

# Same again for the without-glasses group.
no_glass_images = [
    Image.open(f'{img_path}/{without_glasses[np.random.randint(0, len(without_glasses))]}')
    for _ in range(5)
]
f = plt.figure()
for position, picture in enumerate(no_glass_images, start=1):
    panel = f.add_subplot(1, 5, position)
    panel.axis('off')
    panel.imshow(picture)

plt.show(block=True)


In [None]:
from VanillaVAE import VanillaVAE
import torch
from torchvision import transforms


INPUT_DIM = 3
Z_DIM = 500
PATH = "model_wandb.pt"
device = torch.device('cpu')

# Rebuild the network and restore the trained weights from the checkpoint.
model = VanillaVAE(INPUT_DIM, Z_DIM)
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Keep the model on the same device as the inputs that are sent to it below.
model.to(device)

# The model is used for inference only from here on: switch layers such as
# batch-norm/dropout (if the model has any) to evaluation behaviour and
# freeze all parameters.
model.eval()
model.requires_grad_(False)

transform = transforms.ToTensor()

# Reconstruct every sampled with-glasses image and show the results in one row.
f = plt.figure()
for i, image in enumerate(glass_images):
    with torch.no_grad():
        # NOTE(review): the keyword is spelled `without_varinace` here but
        # `without_variance` in the generate_with_delta call elsewhere in this
        # notebook -- confirm both against the VanillaVAE signatures.
        out = model(transform(image).unsqueeze(0).to(device), without_varinace=False)[0]

    # (N, 3, 224, 192) -> (N, 224, 192, 3) so matplotlib gets channels last;
    # permute on the tensor, then hand a NumPy array to imshow.
    out = out.view(-1, 3, 224, 192).permute(0, 2, 3, 1).cpu().numpy()
    ax = f.add_subplot(1, len(glass_images), i + 1)
    ax.axis('off')
    ax.imshow(out[0])



In [None]:
from CustomDataset import CustomDataset
from torchvision import transforms
from torch.utils.data import DataLoader


def get_avg_mu(labels, img_dir="data/img_align_celeba/", batch_size=1):
    """Average the first output of ``model.encode`` over a set of images.

    Uses the notebook-level ``model``, ``device`` and ``Z_DIM``.

    Args:
        labels: sequence of image identifiers passed to ``CustomDataset``.
        img_dir: image folder passed to ``CustomDataset``; defaults to the
            folder the notebook has always used.
        batch_size: number of images encoded per step (default 1, as before).

    Returns:
        Tensor of shape ``(Z_DIM,)`` holding the mean of ``mu`` over all
        encoded images (assumes ``encode`` returns ``mu`` shaped
        ``(batch, Z_DIM)``).

    Raises:
        ValueError: if ``labels`` is empty (the mean would otherwise be a
            silent NaN tensor from a 0/0 division).
    """
    if len(labels) == 0:
        raise ValueError("labels is empty; cannot average latent means over zero images")

    dataset = CustomDataset(img_dir, labels, transforms.ToTensor())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Accumulate on the same device the encoder outputs live on.
    mu_sum = torch.zeros(Z_DIM, device=device)
    count = 0
    # Pure inference: no autograd graph should be built while accumulating.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            mu, _ = model.encode(batch)
            # Sum over the batch axis, leaving one value per latent dimension.
            mu_sum += mu.sum(dim=0)
            # Count what was actually encoded so any batch size is handled.
            count += batch.shape[0]

    print(count)
    mu_avg = mu_sum / count

    print(mu_avg.shape)
    return mu_avg
 

In [None]:
# Mean latent vector of the with-glasses group.
mu_avg_glasses = get_avg_mu(with_glasses)

# For the other group, use only as many images as the with-glasses group has.
matched_without = without_glasses[:len(with_glasses)]
mu_avg_no_glasses = get_avg_mu(matched_without)

# Difference of the two group means in latent space.
delta_mu = mu_avg_glasses - mu_avg_no_glasses

# Quick look: shape, one coordinate, and the sum over all coordinates.
for summary in (delta_mu.shape, delta_mu[4], delta_mu.sum()):
    print(summary)


In [None]:
# Persist the latent difference vector for later use.
torch.save(obj=delta_mu, f='delta_mu.pt')

In [None]:
# How far each latent code is pushed along the glasses direction.
DELTA_SCALE = 1.5
scaled_delta = delta_mu * DELTA_SCALE  # loop-invariant, computed once

# Run every sampled without-glasses image through generate_with_delta and
# show the results in one row.
f = plt.figure()
for i, image in enumerate(no_glass_images):
    with torch.no_grad():
        x = transform(image).unsqueeze(0).to(device)
        out = model.generate_with_delta(x, scaled_delta, without_variance=False)

    # (N, 3, 224, 192) -> (N, 224, 192, 3) so matplotlib gets channels last;
    # permute on the tensor, then hand a NumPy array to imshow.
    out = out.view(-1, 3, 224, 192).permute(0, 2, 3, 1).cpu().numpy()
    ax = f.add_subplot(1, len(no_glass_images), i + 1)
    ax.axis('off')
    ax.imshow(out[0])

In [None]:
# Decode the two group-mean latent vectors and show them side by side.
panels = (
    ("with glasses", mu_avg_glasses),
    ("without glasses", mu_avg_no_glasses),
)

f = plt.figure()
for i, (title, latent) in enumerate(panels):
    with torch.no_grad():
        out = model.decode(latent)

    # (N, 3, 224, 192) -> (N, 224, 192, 3) so matplotlib gets channels last;
    # permute on the tensor, then hand a NumPy array to imshow.
    out = out.view(-1, 3, 224, 192).permute(0, 2, 3, 1).cpu().numpy()
    ax = f.add_subplot(1, len(panels), i + 1)
    ax.axis('off')
    # Label each panel so the figure can be read without the code.
    ax.set_title(title)
    ax.imshow(out[0])

In [None]:
# Spot-check a single coordinate of the with-glasses mean latent vector.
probe_glasses = mu_avg_glasses[4]
print(probe_glasses)

In [None]:
# Spot-check the same coordinate of the without-glasses mean latent vector.
probe_no_glasses = mu_avg_no_glasses[4]
print(probe_no_glasses)