In [None]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
# Execute the helper notebooks in this kernel's namespace so the project
# classes used below (DataLoader, ImageProcessor, DataExplorer,
# DuplicateDetector) become available. Nothing from DatasetStatistics.ipynb is
# referenced in the cells visible here — TODO confirm it is still needed.
# NOTE(review): DataLoader.ipynb presumably defines a class named `DataLoader`
# (it is called with `load_all_images()` below). That rebinds the name and
# shadows `torch.utils.data.DataLoader` imported in the previous cell, so any
# later bare `DataLoader(...)` refers to the project class.
%run DataLoader.ipynb
%run ImageProcessor.ipynb
%run DataExplorer.ipynb
%run DatasetStatistics.ipynb
%run DuplicateDetector.ipynb

In [None]:
# Dataset location: the training and testing splits are sibling folders under
# DATA_ROOT, each holding one sub-folder per class (the layout ImageFolder reads).
DATA_ROOT = "./data"
TRAIN_DIR = f"{DATA_ROOT}/Training"
TEST_DIR = f"{DATA_ROOT}/Testing"

# Batching and preprocessing settings.
BATCH_SIZE = 64
IMAGE_SIZE = (224, 224)  # (height, width) target for transforms.Resize

In [None]:
# Choose the compute device: the current accelerator (CUDA, MPS, XPU, ...) when
# one is available, otherwise the CPU. `torch.accelerator` was introduced in
# PyTorch 2.6; on older versions fall back to the per-backend checks instead of
# failing with AttributeError.
if hasattr(torch, "accelerator"):
    device = (
        torch.accelerator.current_accelerator().type
        if torch.accelerator.is_available()
        else "cpu"
    )
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Collect the image file paths under TRAIN_DIR and TEST_DIR with the project's
# DataLoader class (from `%run DataLoader.ipynb`; NOT torch's DataLoader).
loader = DataLoader(TRAIN_DIR, TEST_DIR)
all_files = loader.load_all_images()

# Fail fast with a clear message: later cells index the first, middle and last
# file, which cannot work on an empty list.
assert len(all_files) > 0, f"No images found under {TRAIN_DIR!r} or {TEST_DIR!r}"

print(f"Successfully loaded {len(all_files)} images")

In [None]:
# Detect duplicate images among `all_files`; the matching criterion is defined
# in DuplicateDetector.ipynb.
duplicate_detector = DuplicateDetector(all_files)
duplicate_detector.detect_duplicates()

n_duplicates = len(duplicate_detector.duplicates)
if n_duplicates > 0:
    # WARNING: destructive. Duplicates are deleted from disk, so the dataset
    # directory differs after the first run. Deleting (rather than only
    # filtering the list) is what makes the ImageFolder datasets built later
    # from TRAIN_DIR/TEST_DIR see the de-duplicated files.
    duplicate_detector.remove_duplicates_from_disk()
    # Keep the in-memory path list in sync with what remains on disk.
    all_files = duplicate_detector.get_unique_files()

In [None]:
# Print the per-class image counts reported by the project DataLoader.
# NOTE(review): `loader` was created before duplicate removal and is never given
# the de-duplicated `all_files`; confirm whether these counts still include
# files that were just deleted as duplicates.
loader.print_dataset_class_count()

In [None]:
# Visual check of the (de-duplicated) images via the project ImageProcessor.
processor = ImageProcessor(all_files)

# Presumably loads the files as grayscale images, then renders them as a grid —
# see ImageProcessor.ipynb for the exact behavior. Order matters: load first.
processor.load_grayscale_images()
processor.display_image_grid()  # TODO: test different batch sizes

In [None]:
# Inspect a small sample of the dataset and plot a histogram; what the
# histogram summarizes is defined in DataExplorer.ipynb.
explorator = DataExplorer(all_files)

first_middle_last = [0, len(all_files) // 2, -1]
explorator.retrieve_sample_of_images(first_middle_last)

explorator.plot_histogram()

In [None]:
# Preprocessing pipelines: convert to a tensor, then resize to IMAGE_SIZE
# (the constant from the config cell). The two pipelines are identical for now
# but kept separate so augmentation can be added to training only.
train_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(IMAGE_SIZE)]
)

# Deterministic pipeline for evaluation data (test and validation).
testval_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize(IMAGE_SIZE)]
)

In [None]:
# Build the datasets from the class-per-folder directories in the config cell.
full_trainset = datasets.ImageFolder(TRAIN_DIR, transform=train_transform)
testset = datasets.ImageFolder(TEST_DIR, transform=testval_transform)

# Hold out a fraction of the training images for validation. The fixed seed
# keeps the split identical across runs.
split_ratio = 0.15
trainset_len = len(full_trainset)
valset_len = int(split_ratio * trainset_len)
assert 0 < valset_len < trainset_len, (
    f"split_ratio={split_ratio} yields an empty train or validation split "
    f"for {trainset_len} training images"
)

trainset, validationset = torch.utils.data.random_split(
    full_trainset,
    [trainset_len - valset_len, valset_len],
    generator=torch.Generator().manual_seed(42),
)

# random_split returns Subsets of `full_trainset`, which carries the *training*
# transform. Re-point the validation indices at a second ImageFolder over the
# same directory that uses testval_transform, so validation never receives
# training-time augmentation. ImageFolder sorts classes and file names, so the
# indices line up between the two scans.
validationset = torch.utils.data.Subset(
    datasets.ImageFolder(TRAIN_DIR, transform=testval_transform),
    validationset.indices,
)

In [None]:
# The bare name `DataLoader` refers to the project's file loader here (it was
# rebound by `%run DataLoader.ipynb` and used with `load_all_images()` above),
# so PyTorch's DataLoader must be referenced explicitly.
TorchDataLoader = torch.utils.data.DataLoader
num_workers = 3  # background worker processes per loader

# Shuffle only the training data. Evaluation metrics don't depend on order,
# and a fixed order keeps test/validation runs reproducible.
train_dl = TorchDataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers
)
test_dl = TorchDataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers
)
validation_dl = TorchDataLoader(
    validationset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers
)

In [None]:
# Draw one training batch as a sanity check of the loading pipeline.
examples = iter(train_dl)
imgs, labels = next(examples)

# `trainset` is a Subset produced by random_split, so the class names live on
# the wrapped ImageFolder (`.dataset`); labels are indices into this list.
class_names = trainset.dataset.classes

# Display the batch shape and the class names of the first few labels.
imgs.shape, [class_names[i] for i in labels[:8].tolist()]