In [None]:
from torch.utils import data
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
from PIL import Image
from torchvision.transforms import ToTensor


In [None]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder.

    Encoder: input_dim -> hidden_dim (ReLU) -> two linear heads of size
    latent_dim giving the posterior mean ``mu`` (fc21) and log-variance
    ``logvar`` (fc22).
    Decoder: latent_dim -> hidden_dim (ReLU) -> input_dim followed by a
    sigmoid, so reconstructions lie in [0, 1].

    The defaults (65536 = 256*256 inputs, 800 hidden units, 50 latent dims)
    reproduce the original hard-coded architecture. Layer names and shapes are
    unchanged, so existing state dicts load as before.

    Args:
        input_dim: Number of values in one flattened input sample.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Dimensionality of the latent code ``z``.
    """

    def __init__(self, input_dim=65536, hidden_dim=800, latent_dim=50):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map flattened inputs ``x`` (last dim = input_dim) to ``(mu, logvar)``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        ``eps`` is drawn from the global torch RNG, so the result is random
        unless a seed has been set.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` (last dim = latent_dim) to sigmoid outputs in [0, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        ``x`` is flattened to ``(-1, input_dim)``, so its total number of
        elements must be a multiple of input_dim (e.g. an (N, 256, 256) batch
        with the defaults). ``reshape`` is used instead of ``view`` so that
        non-contiguous inputs (e.g. permuted tensors) are accepted.

        Returns:
            ``(reconstruction, mu, logvar)``. The reconstruction has shape
            ``(N, input_dim)``, and ``mu``/``logvar`` have shape ``(N, latent_dim)``.
        """
        # Read the input size from the layer itself rather than a stored
        # attribute, so instances unpickled from the old class also work.
        mu, logvar = self.encode(x.reshape(-1, self.fc1.in_features))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
# Project root. Override with the HN_PROJECT_DIR environment variable instead
# of editing this absolute path. os.path.join(..., '') guarantees a trailing
# separator, because later cells build paths by string concatenation.
project_dir = os.path.join(
    os.environ.get('HN_PROJECT_DIR', '/home/navidkorhani/Documents/HNProject/'), ''
)

# Images to embed, relative to <project_dir>/all_label_img/.
image_names = ['1068_1_1.jpg', '1068_1_2.jpg']
image_paths = [os.path.join(project_dir, 'all_label_img', name) for name in image_names]


def load_grayscale(path):
    """Read an image file and return it as a 2-D uint8 array in PIL 'L' (8-bit grayscale) mode."""
    # The context manager closes the file handle even if decoding fails.
    with Image.open(path) as img:
        return np.array(img.convert('L'))


# np.stack fails loudly if the images differ in size (np.array would not
# necessarily do so). Result: array of shape (N, H, W).
ary = np.stack([load_grayscale(p) for p in image_paths])

# NOTE(review): the VAE decoder ends in a sigmoid, so its reconstructions lie in
# [0, 1]. The model was therefore most likely trained on pixels scaled the same
# way (e.g. via torchvision's ToTensor, which this notebook imports but never
# uses). Feeding raw 0-255 values would put the encoder far outside its training
# range. Confirm against the training code, and drop the division if training
# used raw pixels.
images = torch.tensor(ary, dtype=torch.float) / 255.0

# The encoder flattens each sample to 65536 = 256*256 values. Wrong-sized images
# would otherwise be silently split or merged across rows.
assert images.shape[1:] == (256, 256), f"expected N x 256 x 256 images, got {tuple(images.shape)}"

In [None]:
# Rebuild the architecture and restore the trained weights.
# NOTE(review): the filename suffix "h800_l50" presumably encodes
# hidden_dim=800, latent_dim=50, which are the sizes VAE() builds.
model = VAE()

# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine. The model and images in this notebook stay on the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# On torch >= 1.13, weights_only=True restricts loading to tensors.
checkpoint = torch.load(project_dir + 'HNUltra/saved models/vae_model_h800_l50.pt', map_location='cpu')

# Inference only. eval() is a no-op for this Linear/ReLU network today, but it
# stays correct if dropout or batch-norm is ever added.
model.eval()
model.load_state_dict(checkpoint)

In [None]:
# Seed the global RNG: reparameterize draws its noise with randn_like, so z is
# only reproducible across runs with a fixed seed. mu is the noise-free
# posterior mean; use it instead of z if a deterministic embedding is wanted.
SEED = 0
torch.manual_seed(SEED)

# Pure inference: no_grad skips building the autograd graph (less memory), and
# the outputs can be passed straight to .numpy() without .detach().
with torch.no_grad():
    # Flatten each image to the encoder's input size instead of a hard-coded 65536.
    mu, logvar = model.encode(images.reshape(-1, model.fc1.in_features))
    z = model.reparameterize(mu, logvar)

In [None]:
# Latent codes: one row per input image, one column per latent dimension.
z.size()