In [2]:
import torch
import torch.nn as nn

In [6]:
class AE(nn.Module):
    def __init__(self, latent_dims=256):
        super(AE, self).__init__()
        self.linear1 = nn.Linear(50176, 4096)
        self.linear2 = nn.Linear(4096, 1024)
        self.linear3 = nn.Linear(1024, latent_dims)
        self.linear4 = nn.Linear(latent_dims, 1024)
        self.linear5 = nn.Linear(1024, 4096)
        self.linear6 = nn.Linear(4096, 50176)
        
    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        latent = F.relu(self.linear3(x))
        x = F.relu(self.linear4(latent))
        x = F.relu(self.linear5(x))
        x = F.relu(self.linear6(x))
        return latent, x

In [7]:
model = AE()
#Loss function
criterion = torch.nn.MSELoss()

#Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def get_device():
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'
    return device

device = get_device()
model.to(device)

AE(
  (linear1): Linear(in_features=50176, out_features=4096, bias=True)
  (linear2): Linear(in_features=4096, out_features=1024, bias=True)
  (linear3): Linear(in_features=1024, out_features=256, bias=True)
  (linear4): Linear(in_features=256, out_features=1024, bias=True)
  (linear5): Linear(in_features=1024, out_features=4096, bias=True)
  (linear6): Linear(in_features=4096, out_features=50176, bias=True)
)

In [8]:
import os
from PIL import Image
import natsort 

class CustomDataSet(torch.utils.data.Dataset):
    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        image = Image.open(img_loc)#.convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image

In [9]:
import torchvision.transforms as transforms 
image_encoding_path = "/Users/jameswang/Desktop/Robotics/depthEncodingData/"
transform = transforms.ToTensor()

image_dataset = CustomDataSet(image_encoding_path, transform)
train_loader = torch.utils.data.DataLoader(image_dataset, batch_size=32, num_workers=0)

In [13]:
import torchvision.datasets as datasets
import torch.nn.functional as F

#Converting data to torch.FloatTensor


# Download the training and test datasets
# train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)

# test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, num_workers=0)
# test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, num_workers=0)

train_log = []
n_epochs = 100

for epoch in range(1, n_epochs+1):
    # monitor training loss
    train_loss = 0.0

    #Training
    for data in train_loader:
        images = data
        images = images.to(device)
        optimizer.zero_grad()
        latent, outputs = model(images)
        
        loss = criterion(outputs, torch.flatten(images, start_dim=1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*images.size(0)
          
    train_loss = train_loss/len(train_loader)
    train_log.append(train_loss)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch, train_loss))


FileNotFoundError: [Errno 2] No such file or directory: '/Users/jameswang/Desktop/Robotics/depthEncodingData/.ipynb_checkpoints'

In [None]:

def plot_ae_outputs(encoder,decoder,n=5):
    plt.figure(figsize=(10,4.5))
    for i in range(n):
        ax = plt.subplot(2,n,i+1)
        img = test_data[i][0].unsqueeze(0).to(device)
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            rec_img  = decoder(encoder(img))
        plt.imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)  
        if i == n//2:
            ax.set_title('Original images')
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')  
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)  
        if i == n//2:
            ax.set_title('Reconstructed images')
    plt.show()