In [2]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets

import os
import pickle
import numpy as np
from PIL import Image

from skimage import io

from matplotlib import pyplot as plt
%matplotlib inline

import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
mpl.rcParams['image.interpolation'] = 'nearest'
mpl.rcParams['figure.figsize'] = 15, 25

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
class EncodingsDataset(torch.utils.data.Dataset):
    """Pairs camera images with encoder outputs computed from parts images.

    Images are read from ``<artifact_dir>/data/camera_data/<kind>`` and
    ``<artifact_dir>/data/parts_data/<kind>``. Only files whose name (without
    extension) ends with ``view`` are used. After both listings are sorted by
    filename, the i-th camera image is paired with the i-th parts image.

    Each item is ``(camera_image, part_encoding)``:
      * ``camera_image``: float tensor of shape ``(1, C, H, W)`` from ``to_tensor``.
      * ``part_encoding``: output of ``torch.load(model_path).encoder`` applied to
        the parts image, which is also given a leading batch dimension of 1.

    Raises:
        ValueError: if the camera and parts directories contain a different
            number of images for ``view``.
    """

    def __init__(self, artifact_dir, model_file, view, kind):
        self.artifact_dir = artifact_dir
        self.view = view
        self.camera_images_dir = os.path.join(artifact_dir, 'data', 'camera_data', kind)
        self.parts_images_dir = os.path.join(artifact_dir, 'data', 'parts_data', kind)

        self.model_dir = os.path.join(artifact_dir, 'models')
        self.model_path = os.path.join(self.model_dir, model_file)

        # Full filenames (extension kept) so they can be joined back onto their directory.
        self.camera_image_names = self._get_camera_images_by_view(self.camera_images_dir, view)
        self.parts_image_names = self._get_camera_images_by_view(self.parts_images_dir, view)
        if len(self.camera_image_names) != len(self.parts_image_names):
            raise ValueError(
                f"found {len(self.camera_image_names)} camera images but "
                f"{len(self.parts_image_names)} parts images for view {view!r}"
            )

        # The encoder is loaded lazily on first use and then reused for every item.
        self._encoder = None

    def __getitem__(self, idx):
        camera_image_path = os.path.join(self.camera_images_dir, self.camera_image_names[idx])
        parts_image_path = os.path.join(self.parts_images_dir, self.parts_image_names[idx])

        camera_image = self._load_image_tensor(camera_image_path)
        parts_image = self._load_image_tensor(parts_image_path)

        # Inference only: without an autograd graph the result is a plain tensor,
        # which can also be sent back from DataLoader worker processes.
        with torch.no_grad():
            part_encoding = self._get_encoder()(parts_image)

        return camera_image, part_encoding

    def __len__(self):
        return len(self.camera_image_names)

    def _get_encoder(self):
        """Return the ``.encoder`` of the model at ``self.model_path``, loaded once on CPU in eval mode."""
        if getattr(self, '_encoder', None) is None:
            # torch.load unpickles arbitrary objects: only load model files you trust.
            # map_location='cpu' keeps the encoder on the same device as the CPU image tensors.
            # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects a
            # pickled full model; pass weights_only=False there if loading fails.
            model = torch.load(self.model_path, map_location='cpu')
            # eval() so dropout / batch-norm layers (if any) give deterministic encodings.
            self._encoder = model.encoder.eval()
        return self._encoder

    @staticmethod
    def _load_image_tensor(path):
        """Read the image at ``path`` as RGB and return a ``(1, C, H, W)`` float tensor."""
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(path) as image:
            tensor = torchvision.transforms.functional.to_tensor(image.convert('RGB'))
        return tensor.unsqueeze(0)

    @staticmethod
    def _get_camera_images_by_view(path, view):
        """Return the sorted filenames in ``path`` whose stem (name without extension) ends with ``view``."""
        return sorted(
            name for name in os.listdir(path)
            if os.path.splitext(name)[0].endswith(view)
        )
        
