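# 🐈 Cat Discriminator

Trains a binary image classifier that separates images where Captain is present (`captain_present_folders`) from control images (`control_folders`), mixing real (`training/`) and synthetic (`synthetic/`) data, and checkpoints the weights to `saved_model_path`.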

### 📝 Imports

In [7]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import albumentations as A

import torchvision.transforms as transforms

import os

from src.cat_discriminator_neural_net import CatDiscriminatorNeuralNet

from src.augmentation.data_augmenter import DataAugmenter
from src.cats_dataset import CatsDataset

### 🔧 Config

In [8]:
# All tunables for this run live in this cell so they can be changed in one place.

image_size = 512  # passed to DataAugmenter when building the transforms

learning_rate = 0.001
batches_between_saves = 1  # how often (in training batches) the model is saved
number_of_batches = 100

batch_size = 32

num_cores = 4  # number of DataLoader worker processes

# Seed torch's global RNG so DataLoader shuffling (and the per-worker seeds the
# DataLoader derives from it) repeat across Restart & Run All. Presumably this
# also fixes the network's weight initialisation, if CatDiscriminatorNeuralNet
# draws from torch's RNG — confirm. GPU kernels may still be nondeterministic.
random_seed = 42
torch.manual_seed(random_seed)

saved_model_path = "trained_networks/experiment-binary-classifier-with-synthetic-data.pth"

root_directory = 'data/'
control_folders=['training/bathroom-cat', 'synthetic/tortoiseshell', 'synthetic/control', 'training/control']
captain_present_folders=['training/captain', 'synthetic/tabby']

### 🌐 Create Transforms

In [9]:
# Image pipeline: project-specific augmentation, then conversion to a tensor.
# torchvision is resolved through its own alias because the assignment below
# rebinds the name `transforms`. Without the alias, a second run of this cell
# would look up `.Compose` on the previous Compose object and fail.
import torchvision.transforms as tv_transforms

transforms = tv_transforms.Compose([
    DataAugmenter(image_size, augment_images=True),
    tv_transforms.ToTensor(),
])

### 🚦 Load Training Data

In [10]:
# Dataset over the real and synthetic folders configured above. `transform`
# is presumably applied to each loaded image inside CatsDataset — confirm.
dataset = CatsDataset(
    root_dir=root_directory,
    control_folders=control_folders,
    captain_present_folders=captain_present_folders,
    transform=transforms,
)

# Reshuffle on every pass and load samples in `num_cores` worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_cores,
)

### 🥾 Initialize the Neural Net

In [11]:
# Instantiate the discriminator and move it to the default CUDA device (this
# cell requires a GPU). The trailing `;` stops Jupyter from echoing the value
# returned by `.cuda()`.
# NOTE(review): presumably the constructor uses saved_model_path to resume from
# an existing checkpoint — confirm in CatDiscriminatorNeuralNet.
net = CatDiscriminatorNeuralNet(learning_rate, saved_model_path)
net.cuda();

### 🏃‍♂️‍➡️ Train

In [12]:
# Run `number_of_batches` training calls. The weights are checkpointed every
# `batches_between_saves` completed calls and always after the final one.
model_directory = os.path.dirname(saved_model_path)
if model_directory:
    # torch.save does not create missing parent directories.
    os.makedirs(model_directory, exist_ok=True)

for batch_number in range(number_of_batches):
    print('Training batch: ', batch_number)
    net.run_training_batch(learning_rate, dataloader)

    # Count completed batches rather than the 0-based index, so the N-th,
    # 2N-th, ... batch triggers a save. Also save after the last batch so its
    # progress isn't lost when number_of_batches isn't a multiple of N.
    batches_completed = batch_number + 1
    if (batches_completed % batches_between_saves == 0
            or batches_completed == number_of_batches):
        print('Saving model...')
        # Write to a temporary file, then swap it in with os.replace, so an
        # interrupted save leaves the previous checkpoint intact.
        temporary_path = saved_model_path + '.tmp'
        torch.save(net.state_dict(), temporary_path)
        os.replace(temporary_path, saved_model_path)


Training batch:  0
Loss: 0.0877
Saving model...
Training batch:  1
Loss: 0.0955
Saving model...
Training batch:  2
Loss: 0.1016
Saving model...
Training batch:  3
Loss: 0.0859
Saving model...
Training batch:  4
Loss: 0.0986
Saving model...
Training batch:  5
Loss: 0.0931
Saving model...
Training batch:  6
Loss: 0.0915
Saving model...
Training batch:  7
Loss: 0.0882
Saving model...
Training batch:  8
Loss: 0.0827
Saving model...
Training batch:  9
Loss: 0.0865
Saving model...
Training batch:  10
Loss: 0.0949
Saving model...
Training batch:  11
Loss: 0.0931
Saving model...
Training batch:  12
Loss: 0.1038
Saving model...
Training batch:  13
Loss: 0.1083
Saving model...
Training batch:  14
Loss: 0.0817
Saving model...
Training batch:  15
Loss: 0.0886
Saving model...
Training batch:  16
Loss: 0.0797
Saving model...
Training batch:  17
Loss: 0.0754
Saving model...
Training batch:  18
Loss: 0.0800
Saving model...
Training batch:  19
Loss: 0.0921
Saving model...
Training batch:  20
Loss: 0.07