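# 🐈 Cat Discriminator

Trains `CatDiscriminatorNeuralNet` on the images under `data/training/` and periodically checkpoints its weights to `trained_networks/cat_discriminator.pth`.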

### 📝 Imports

In [7]:
import os

import albumentations as A
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor

from src.augmentation.data_augmenter import DataAugmenter
from src.cat_discriminator_neural_net import CatDiscriminatorNeuralNet
from src.cats_dataset import CatsDataset

### 🔧 Config

In [8]:
# Hyperparameters, paths and seeding for this training run.
image_size = 512  # passed to DataAugmenter; presumably the target image side in pixels — TODO confirm

learning_rate = 0.01
# NOTE(review): each "batch" in the Train cell is one call to
# `net.run_training_batch(learning_rate, dataloader)`, which receives the whole
# DataLoader, so these two values presumably count full training passes rather
# than mini-batches (the mini-batch size is `batch_size`).
batches_between_saves = 5  # checkpoint cadence, in training-loop iterations
number_of_batches = 30     # total training-loop iterations

batch_size = 32  # samples per DataLoader mini-batch

num_cores = 2  # number of DataLoader worker processes

# Seed torch's global RNG so the DataLoader's shuffle order (shuffle=True draws
# from it) is reproducible between runs. This does not by itself make CUDA/cuDNN
# kernels deterministic.
random_seed = 42
torch.manual_seed(random_seed)

saved_model_path = "trained_networks/cat_discriminator.pth"

training_data_path = 'data/training/'

### 🌐 Create Transforms

In [9]:
# Preprocessing pipeline applied to every training image: the project's
# DataAugmenter (constructed with `image_size` and augment_images=True), followed
# by conversion to a tensor.
#
# `Compose` / `ToTensor` are referenced directly instead of through the
# `transforms` module alias: this cell rebinds the name `transforms` to the
# pipeline (the dataset cell passes it as `transform=transforms`), so going
# through the alias made a second run of this cell fail with AttributeError.
transforms = Compose([
    DataAugmenter(image_size, augment_images=True),
    ToTensor(),
])

### 🚦 Load Training Data

In [10]:
# Wrap the images under `training_data_path` in the project's dataset, applying
# the preprocessing pipeline from the transforms cell, then batch them.
dataset = CatsDataset(root_dir=training_data_path, transform=transforms)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,           # data is reshuffled on every pass over the loader
    num_workers=num_cores,  # number of worker processes used for loading
)

### 🥾 Initialize the Neural Net

In [11]:
# Build the discriminator from the configured learning rate and checkpoint path.
# NOTE(review): how the constructor uses `saved_model_path` (e.g. whether it
# loads existing weights) is not visible here — TODO confirm.
net = CatDiscriminatorNeuralNet(
    learning_rate,
    saved_model_path,
)
net.cuda();  # move to the default CUDA device; trailing `;` hides the repr

### 🏃‍♂️‍➡️ Train

In [12]:
# Training loop. Each iteration hands the whole `dataloader` to
# `net.run_training_batch`; the captured output shows several label batches
# printed within iteration 0, so one iteration is presumably a full pass over
# the loader — TODO confirm against run_training_batch.

# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before the first save.
checkpoint_dir = os.path.dirname(saved_model_path)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for batch_number in range(number_of_batches):
    print('Training batch: ', batch_number)
    net.run_training_batch(learning_rate, dataloader)

    # Checkpoint after every `batches_between_saves` *completed* iterations and
    # always after the last one. The previous `batch_number % n == 0` test saved
    # after iterations 0, 5, ..., 25 and left the final iterations unsaved.
    completed = batch_number + 1
    if completed % batches_between_saves == 0 or completed == number_of_batches:
        print('Saving model...')
        torch.save(net.state_dict(), saved_model_path)


Training batch:  0
labels
tensor([[0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.]], device='cuda:0')
labels
tensor([[0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
   

UnboundLocalError: cannot access local variable 'epoch' where it is not associated with a value

### 💾 Save Progress

In [34]:
# Write the network's current weights to the configured checkpoint path
# (useful after an interrupted training cell).
current_weights = net.state_dict()
torch.save(current_weights, saved_model_path)