# Import libraries

In [None]:
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.nn import ModuleList
from torchvision import transforms

# custom libs
from libs.TDContainer import TDContainer
from libs.PretrainedModels import PretrainedModelsCreator, AlexNet_cc, SqueezeNet_cc, InceptionV3_cc
from libs.utils import get_model_name, import_dataset

In [None]:
# Seed every random number generator the pipeline draws from. torchvision's
# random transforms (RandomApply, RandomHorizontalFlip, RandomAffine, ...)
# sample from torch's generator, so seeding only `random` and NumPy would
# leave the data augmentation non-reproducible.
SEED = 1996
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prepare dataset

In [None]:
# Dataset location: a local directory plus an optional Google Drive path.
path_dst = 'dataset'
path_gdrive = ''  # NOTE(review): empty presumably means "no Drive copy" — confirm in import_dataset

# Settings forwarded to dst.create_data_loader (presumably to torch DataLoaders):
# samples per batch, loader worker processes, and whether to drop a
# trailing incomplete batch.
batch_size, num_workers, drop_last = 32, 2, True

In [None]:
# Per-channel normalization statistics expected by torchvision's pretrained models.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Side length of the square crop fed to the networks.
# NOTE(review): torchvision's Inception v3 is documented to expect 299x299
# inputs; confirm that 224 works with InceptionV3_cc before training it.
crop_size = 224

# Training pipeline: deterministic resize + crop, then random augmentation.
# Every step must leave the image at crop_size x crop_size so the samples of
# a batch can be stacked into one tensor.
train_transform = transforms.Compose([
    transforms.Resize(256),
    # Same center crop as at evaluation time: the trash bin is centered in the image.
    transforms.CenterCrop(crop_size),

    # Add a black border, then resize back to the crop size. Pad alone would
    # produce (crop_size + 10)-pixel images that cannot be batched together
    # with the unpadded ones.
    transforms.RandomApply(ModuleList([
        transforms.Pad(padding=5, fill=0, padding_mode='constant'),
        transforms.Resize((crop_size, crop_size)),
    ]), p=0.3),

    transforms.RandomApply(ModuleList([
        transforms.ColorJitter(brightness=.6, hue=.4),
    ]), p=0.3),

    transforms.RandomApply(ModuleList([
        # Keep 3 channels so grayscale samples still match the RGB model input.
        transforms.Grayscale(num_output_channels=3),
    ]), p=0.2),

    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomPerspective(distortion_scale=0.5, p=0.4),
    transforms.RandomEqualize(p=0.4),
    # NOTE(review): degrees=(20, 50) always rotates by 20-50 degrees in the
    # same direction (never by a small or opposite angle); confirm this is
    # intended rather than a symmetric range such as (-20, 20).
    transforms.RandomAffine(degrees=(20, 50), translate=(0.2, 0.5), scale=(0.6, 0.75)),
    # AutoAugment operates on PIL / uint8 images, so it must run before ToTensor.
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.SVHN),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Evaluation pipeline: fully deterministic, no random augmentation, so that
# test metrics are reproducible.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

dst = import_dataset(path_dst=path_dst,
                     train_transform=train_transform,
                     test_transform=test_transform,
                     path_gdrive=path_gdrive)

dst.create_data_loader(batch_size=batch_size, num_workers=num_workers, drop_last=drop_last)

In [None]:
# Take a single training batch and preview its first few samples. Calling
# next(iter(loader)) once per image would rebuild the loader iterator (and,
# with num_workers > 0, its worker processes) every time.
# NOTE(review): assumes dst.training_loader is a torch DataLoader — confirm.
train_features, train_labels = next(iter(dst.training_loader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# Undo transforms.Normalize before plotting. imshow needs RGB floats in
# [0, 1]. Casting the normalized values straight to uint8 would wrap the
# negative values and truncate the rest to 0/1, which garbles the image.
mean_hwc = np.asarray(mean).reshape(1, 1, -1)
std_hwc = np.asarray(std).reshape(1, 1, -1)

n_show = min(5, train_features.size(0))
plt.figure(figsize=(15, 8))
for i in range(n_show):
    plt.subplot(1, n_show, i + 1)
    plt.title("Class: %s" % train_labels[i].numpy())
    img = np.transpose(train_features[i].numpy(), (1, 2, 0))  # CHW (ToTensor) -> HWC
    plt.imshow(np.clip(img * std_hwc + mean_hwc, 0.0, 1.0))
    plt.axis('off')
plt.show()

# Find the best learning rate (LR)

# Training