# 🐈 Cat Discriminator

### 📝 Imports

In [15]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import albumentations as A

import torchvision
import torchvision.transforms as transforms

import os

from PIL import Image

from cat_discriminator_neural_net import CatDiscriminatorNeuralNet

from tools.augmentation.data_augmenter import DataAugmenter


### 🔧 Config

In [16]:
learning_rate = 0.01
epochs_between_saves = 5
total_epochs = 100

saved_model_path = "trained_networks/cat_discriminator_512x512.pth"
bathroom_cat_dir_path = 'data/bathroom-cat-512x512/'
captain_dir_path = 'data/captain-512x512/'
control_dir_path= 'data/control-512x512/'

### 🌐 Create Transforms

In [17]:
transforms = transforms.Compose([
    DataAugmenter(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

DataAugmenter initialized


### 🚦 Load Training Data

In [18]:
from cats_dataset import CatsDataset


dataset = CatsDataset(bathroom_cat_dir=bathroom_cat_dir_path, captain_dir=captain_dir_path, control_dir=control_dir_path, transform=transforms)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

### 🥾 Initialize the Neural Net

In [19]:
net = CatDiscriminatorNeuralNet()

if os.path.isfile(saved_model_path):
    net.load_state_dict(torch.load(saved_model_path))

net.cuda()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

  net.load_state_dict(torch.load(saved_model_path))


### 🏃‍♂️‍➡️ Train

In [20]:
from matplotlib import pyplot as plt

epoch = 1
while epoch <= total_epochs:
    print(f'Training epoch {epoch}')
    
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()

        outputs = net(inputs)
        
        loss = loss_function(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
        
    if epoch % epochs_between_saves == 0:
        print('Saving model...')
        torch.save(net.state_dict(), saved_model_path)
        
    print (f'Loss: {running_loss / len(dataloader):.4f}')
    epoch += 1

Training epoch 1
Loss: 0.1071
Training epoch 2
Loss: 0.0979
Training epoch 3
Loss: 0.1097
Training epoch 4
Loss: 0.0967
Training epoch 5
Saving model...
Loss: 0.1001
Training epoch 6
Loss: 0.1083
Training epoch 7
Loss: 0.0972
Training epoch 8
Loss: 0.0903
Training epoch 9
Loss: 0.0847
Training epoch 10
Saving model...
Loss: 0.0998
Training epoch 11
Loss: 0.1060
Training epoch 12
Loss: 0.1084
Training epoch 13
Loss: 0.0966
Training epoch 14
Loss: 0.1026
Training epoch 15
Saving model...
Loss: 0.0995
Training epoch 16
Loss: 0.1051
Training epoch 17
Loss: 0.0847
Training epoch 18
Loss: 0.1037
Training epoch 19
Loss: 0.1067
Training epoch 20
Saving model...
Loss: 0.0929
Training epoch 21
Loss: 0.0823
Training epoch 22
Loss: 0.1022
Training epoch 23
Loss: 0.0993
Training epoch 24
Loss: 0.0798
Training epoch 25
Saving model...
Loss: 0.1069
Training epoch 26
Loss: 0.1147
Training epoch 27
Loss: 0.1040
Training epoch 28
Loss: 0.1021
Training epoch 29
Loss: 0.1021
Training epoch 30
Saving model

### 💾 Save Progress

In [21]:
torch.save(net.state_dict(), saved_model_path)