In [75]:
import os
import numpy as np
import json
from skmultilearn.model_selection import iterative_train_test_split
import torch

from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image

import torch.nn as nn
import torch.nn.functional as F

In [76]:
dataset_directory = 'dataset'

images = []
labels = []

for recipe_folder in os.listdir(dataset_directory):
    recipe_path = os.path.join(dataset_directory, recipe_folder)
    allergens_file = os.path.join(recipe_path, 'allergens.json')

    if os.path.isfile(allergens_file):
        with open(allergens_file, 'r') as f:
            data = json.load(f)
            allergens = data['allergens']

            for image in data['images']:
                images.append(os.path.join(dataset_directory, recipe_folder, image))
                labels.append(allergens)

In [77]:
x = np.array(images).reshape(-1,1) # iterative_train_test_split expects 2d array
y = np.array(labels)

In [78]:
x_train, y_train, x_temp, y_temp = iterative_train_test_split(x, y, test_size=0.3) # http://scikit.ml/stratification.html

In [79]:
x_test, y_test, x_val, y_val = iterative_train_test_split(x_temp, y_temp, test_size=0.5)

In [80]:
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape, x_val.shape, y_val.shape)

(43415, 1) (43415, 14) (9304, 1) (9304, 14) (9303, 1) (9303, 14)


In [81]:
x_train

array([['dataset\\$25_pumpkin_pie\\images/image_2.jpg'],
       ["dataset\\'get_up_&_go'_bars\\images/image_2.jpg"],
       ['dataset\\(panera_bread)_black_bean_soup\\images/image_2.jpg'],
       ...,
       ["dataset\\zurie's_overnight_no-knead_bread\\images/image_3.jpg"],
       ["dataset\\zurie's_overnight_no-knead_bread\\images/image_4.jpg"],
       ["dataset\\zurie's_overnight_no-knead_bread\\images/image_5.jpg"]],
      dtype='<U93')

In [114]:
# https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files

class FoodAllergenDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None, target_transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx][0]
        image = Image.open(img_path).convert('RGB') # Convert all images to 3 channel RGB as dataset contains some 4 channel RGBA images 
        label = torch.tensor(self.labels[idx], dtype=torch.float32)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

In [117]:
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Test more transforms later in training - cropping, rotation, centering etc.

In [118]:
train_dataset = FoodAllergenDataset(x_train, y_train, transform=transform)
val_dataset = FoodAllergenDataset(x_val, y_val, transform=transform)
test_dataset = FoodAllergenDataset(x_test, y_test, transform=transform)

In [119]:
# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [120]:
for images, labels in train_dataloader:
    print(images.shape)  
    print(labels.shape)

torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])
torch.Size([64, 3, 128, 128])
torch.Size([64, 14])


KeyboardInterrupt: 

In [113]:
# Sample CNN model from pytorch

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()