In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from PIL import Image
import torchvision
from torchvision import datasets, models, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from sklearn.metrics import *
import time
import os
from torch.utils import data
import random

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Root directory holding the split CSV. The per-user absolute paths below are
# only defaults: set HN_ROOT_DIR / HN_DATA_DIR in the environment to point at
# your own copy instead of editing the notebook.
andrea_dir = os.environ.get("HN_ROOT_DIR", "/home/andreasabo/Documents/HNProject/")

# Data directory on the current machine (abhishekmoturu, andreasabo,
# denizjafari, navidkorhani). It contains the image folders such as
# all_label_img/. Keep the trailing slash: later cells may build paths by
# plain string concatenation.
data_dir = os.environ.get("HN_DATA_DIR", "/home/navidkorhani/Documents/HNProject/")

# Read the target dataframe; only these four columns are loaded, and the
# image loops below use `image_ids` to locate each image file.
csv_path = os.path.join(andrea_dir, "all_splits_1000000.csv")
data_df = pd.read_csv(csv_path, usecols=['subj_id', 'image_ids', 'view_label', 'view_train'])

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened single-channel images.

    Encoder: ``input_dim -> hidden_dim -> (mu, logvar)``, each of size
    ``latent_dim``. Decoder: ``latent_dim -> hidden_dim -> input_dim``.
    Hidden layers use ReLU; the output uses a sigmoid, so reconstructions
    lie in (0, 1).

    Args:
        input_dim: number of values per flattened image. Defaults to
            65536 (= 256 * 256).
        hidden_dim: width of the hidden layer in encoder and decoder.
        latent_dim: size of the latent code.

    The defaults reproduce the previously hard-coded sizes and the layer
    attribute names are unchanged, so state dicts saved from the old class
    still load into ``VAE()``.
    """

    def __init__(self, input_dim=65536, hidden_dim=2000, latent_dim=400):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # encoder head: mean
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # encoder head: log-variance
        self.fc3 = nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)    # decoder output layer

    def encode(self, x):
        """Map flattened inputs ``(batch, input_dim)`` to ``(mu, logvar)``, each ``(batch, latent_dim)``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Noise is drawn regardless of train/eval mode, so results are
        stochastic unless the caller seeds the RNG.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``(batch, latent_dim)`` to values in (0, 1) of shape ``(batch, input_dim)``."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode ``x``; returns ``(reconstruction, mu, logvar)``.

        ``x`` is flattened to ``(-1, input_dim)``, so its element count must
        be a multiple of ``input_dim`` (e.g. ``(batch, 256, 256)`` for the
        defaults). ``reshape`` rather than ``view`` lets non-contiguous
        inputs through as well.
        """
        mu, logvar = self.encode(x.reshape(-1, self.fc1.in_features))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Rebuild the architecture and load the trained weights. map_location remaps
# the stored tensors onto `device`, so a checkpoint saved on a GPU machine
# also loads on a CPU-only one.
vae_model = VAE().to(device)
checkpoint = torch.load('results/h2000_l400_e100/vae_model.pt', map_location=device)
vae_model.load_state_dict(checkpoint)
vae_model.eval()

# Inference only: freeze every parameter so no gradients are tracked.
# NOTE(review): eval() does not stop reparameterize() from sampling noise, so
# reconstructions remain stochastic.
for params in vae_model.parameters():
    params.requires_grad = False

In [None]:
batch_size = 256

image_names = data_df['image_ids'].to_list()
num_of_images = len(image_names)

src_dir = os.path.join(data_dir, 'all_label_img')
recon_dir = os.path.join(data_dir, 'all_label_img_recon400')
os.makedirs(recon_dir, exist_ok=True)  # save_image does not create parent folders


def _load_gray_batch(names, folder):
    """Load ``<folder>/<name>.jpg`` for each name into one float32 array scaled to [0, 1].

    Each image is converted to single-channel ('L') so it flattens to the
    VAE's 65536 inputs, assuming 256x256 images. Files are closed after reading.
    """
    arrays = []
    for name in names:
        with Image.open(os.path.join(folder, name + '.jpg')) as img:
            arrays.append(np.asarray(img.convert('L'), dtype=np.float32))
    # float32 / Python float stays float32, matching the model's weights.
    return np.stack(arrays) / 255.0


with torch.no_grad():
    for batch_idx, start in enumerate(range(0, num_of_images, batch_size)):
        if batch_idx % 5 == 0:
            print(batch_idx)  # progress: batch counter
        end = min(start + batch_size, num_of_images)
        batch_names = image_names[start:end]

        # (batch, 256, 256) float32; Tensor.to is not in-place, so keep its result.
        images_tensor = torch.from_numpy(_load_gray_batch(batch_names, src_dir)).to(device)

        recon_batch, mu, logvar = vae_model(images_tensor.view(-1, 65536))  # (batch, 256^2)
        reshap_recon = recon_batch.view(-1, 1, 256, 256)

        # reshap_recon is indexed 0..batch-1, so pair it with this batch's names
        # rather than indexing it by the global image position.
        for name, recon_img in zip(batch_names, reshap_recon):
            save_image(recon_img.cpu(), os.path.join(recon_dir, name + '.jpg'))

In [None]:
# Walk every image listed in data_df, load it as a (1, H, W) tensor in [0, 1]
# and move it to the model's device (H = W = 256 is assumed by the
# view(-1, 65536) calls below).
# NOTE(review): both processing options below are commented out, so as written
# this cell only loads each image and prints progress. Uncomment one to use it.
with torch.no_grad():
    for ind, image_id in enumerate(data_df['image_ids']):
        if ind % 1000 == 0:
            print(ind)
        img_path = data_dir + 'all_label_img/' + image_id + '.jpg'
        with Image.open(img_path) as img:  # close the file once decoded
            image = ToTensor()(img.convert('L'))
        image = image.to(device)

        # Option A: save one sampled latent code per image as .npy.
        # NOTE(review): the folder name says "latent100" but the loaded model
        # uses latent_dim=400; confirm the intended output folder.
        #mu, logvar = vae_model.encode(image.view(-1, 65536))
        #z = vae_model.reparameterize(mu, logvar)
        #output_file = 'latent100_images/'+image_id+'.npy'
        #np.save(output_file, z.detach().cpu().numpy())

        # Option B: save the reconstruction, writing to the same folder as the
        # batched reconstruction loop.
        #recon_batch, mu, logvar = vae_model(image.view(-1, 65536))
        #recon_img = recon_batch.view(1, 256, 256)
        #save_image(recon_img.cpu(), data_dir + 'all_label_img_recon400/' + image_id + '.jpg')
    

In [None]:
# NOTE(review): `targets` is never defined in this notebook, so this cell
# raised NameError on a fresh kernel. Re-enable it once a cell creates `targets`.
# targets.dtype


In [None]:
# Peek at the first five rows of the loaded split table.
data_df.head(n=5)

In [None]:
import torch
import numpy as np

In [None]:
# Toy (2, 3, 3) integer tensor for checking indexing/shape semantics:
# the second 3x3 slab is the first one (values 1..9) shifted by 10.
a = torch.tensor(np.arange(1, 10).reshape(3, 3) + np.array([0, 10])[:, None, None])

In [None]:
# Dimensions of the toy tensor as a torch.Size.
a.size()

In [None]:
# First slab along dim 0 (a 3x3 view into `a`).
a[0, ...]