In [1]:
import torch
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from helpers import train, evaluation, generate_confusion_matrix
from helpers import plot_accuracy, plot_loss, plot_confusion_matrix
import numpy as np

In [2]:
# Root folder of the dataset; the loading cell appends 'train' and 'validation' to it.
DATA_PATH = './15SceneData/'
# Presumably the folder where the plotting helpers write figures -- confirm against helpers.
PLOTS_PATH = './plots/'

# Pick the best available compute backend: CUDA GPU first, then Apple MPS, then CPU.
# The getattr guard keeps this cell working on PyTorch builds without torch.backends.mps.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

In [3]:
# Build the train / validation datasets from their folders under DATA_PATH.
# Both splits get the same preprocessing: resize to 224x224, then convert to a tensor.
resize_and_tensorize = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(f"{DATA_PATH}train", transform=resize_and_tensorize)
val_dataset = datasets.ImageFolder(f"{DATA_PATH}validation", transform=resize_and_tensorize)

# Calculate mean and std for train dataset

In [4]:
# Stream the training set in fixed-size batches instead of materialising every image
# in one giant tensor (batch_size=len(dataset)); the statistics below are plain sums,
# so the batch size (64 is arbitrary) and the sample order do not affect the result.
loader = DataLoader(train_dataset, batch_size=64, shuffle=False)

mean = 0.
std = 0.
for images, _ in loader:
    batch_samples = images.size(0)  # the last batch can be smaller than the others
    # Flatten the spatial dimensions so each image becomes (dim 1, pixels).
    images = images.view(batch_samples, images.size(1), -1)
    # Accumulate the per-image statistics, summed over the images of the batch.
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

# Average over all images. NOTE: `std` is therefore the mean of the per-image
# standard deviations, not the standard deviation of the pooled pixels.
mean /= len(loader.dataset)
std /= len(loader.dataset)

mean, std

(tensor([0.4559, 0.4559, 0.4559]), tensor([0.2199, 0.2199, 0.2199]))

# Mean and std for val dataset

In [5]:
# Stream the validation set in fixed-size batches instead of materialising every image
# in one giant tensor (batch_size=len(dataset)); the statistics below are plain sums,
# so the batch size (64 is arbitrary) and the sample order do not affect the result.
loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

mean = 0.
std = 0.
for images, _ in loader:
    batch_samples = images.size(0)  # the last batch can be smaller than the others
    # Flatten the spatial dimensions so each image becomes (dim 1, pixels).
    images = images.view(batch_samples, images.size(1), -1)
    # Accumulate the per-image statistics, summed over the images of the batch.
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

# Average over all images. NOTE: `std` is therefore the mean of the per-image
# standard deviations, not the standard deviation of the pooled pixels.
mean /= len(loader.dataset)
std /= len(loader.dataset)

mean, std

(tensor([0.4552, 0.4552, 0.4552]), tensor([0.2191, 0.2191, 0.2191]))

# Check whether all images are black-and-white (grayscale, i.e. identical values in all three channels)

In [6]:
# Scan the training set and print the index of every image whose three
# channels are not identical, i.e. every image that is not black-white.

for idx, (pixels, _label) in enumerate(train_dataset):
    first, second, third = pixels[0], pixels[1], pixels[2]
    if not (torch.equal(first, second) and torch.equal(first, third)):
        print(idx)

# Nothing was printed, so every training image is black-white (same values in all channels)

In [7]:
# Scan the validation set and print the index of every image whose three
# channels are not identical, i.e. every image that is not black-white.

for idx, (pixels, _label) in enumerate(val_dataset):
    first, second, third = pixels[0], pixels[1], pixels[2]
    if not (torch.equal(first, second) and torch.equal(first, third)):
        print(idx)

# Nothing was printed, so every validation image is black-white (same values in all channels)