# 🐈 Cat Discriminator

### 📝 Imports

In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import albumentations as A

import torchvision
import torchvision.transforms as transforms

import os

from PIL import Image

from cat_discriminator_neural_net import CatDiscriminatorNeuralNet

from tools.augmentation.data_augmenter import DataAugmenter


  check_for_updates()


### 🔧 Config

In [2]:
learning_rate = 0.01
epochs_between_saves = 5
total_epochs = 30

saved_model_path = "trained_networks/cat_discriminator_512x512.pth"
bathroom_cat_dir_path = 'data/bathroom-cat-512x512/'
captain_dir_path = 'data/captain-512x512/'
control_dir_path= 'data/control-512x512/'

### 🌐 Create Transforms

In [3]:
transforms = transforms.Compose([
    DataAugmenter(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

DataAugmenter initialized


### 🚦 Load Training Data

In [4]:
from cats_dataset import CatsDataset


dataset = CatsDataset(bathroom_cat_dir=bathroom_cat_dir_path, captain_dir=captain_dir_path, control_dir=control_dir_path, transform=transforms)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

### 🥾 Initialize the Neural Net

In [5]:
net = CatDiscriminatorNeuralNet()

if os.path.isfile(saved_model_path):
    net.load_state_dict(torch.load(saved_model_path))

net.cuda()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

  net.load_state_dict(torch.load(saved_model_path))


### 🏃‍♂️‍➡️ Train

In [6]:
from matplotlib import pyplot as plt

epoch = 1
while epoch <= total_epochs:
    print(f'Training epoch {epoch}')
    
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()

        outputs = net(inputs)
        
        loss = loss_function(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
        
    if epoch % epochs_between_saves == 0:
        print('Saving model...')
        torch.save(net.state_dict(), saved_model_path)
        
    print (f'Loss: {running_loss / len(dataloader):.4f}')
    epoch += 1

Training epoch 1
Loss: 0.1043
Training epoch 2
Loss: 0.1045
Training epoch 3
Loss: 0.0673
Training epoch 4
Loss: 0.0960
Training epoch 5
Saving model...
Loss: 0.0650
Training epoch 6
Loss: 0.0903
Training epoch 7
Loss: 0.0902
Training epoch 8
Loss: 0.0567
Training epoch 9
Loss: 0.0770
Training epoch 10
Saving model...
Loss: 0.0588
Training epoch 11
Loss: 0.0682
Training epoch 12
Loss: 0.0660
Training epoch 13
Loss: 0.0599
Training epoch 14
Loss: 0.0752
Training epoch 15
Saving model...
Loss: 0.0756
Training epoch 16
Loss: 0.0667
Training epoch 17
Loss: 0.0600
Training epoch 18
Loss: 0.0666
Training epoch 19
Loss: 0.0549
Training epoch 20
Saving model...
Loss: 0.0471
Training epoch 21
Loss: 0.0652
Training epoch 22
Loss: 0.0716
Training epoch 23
Loss: 0.0532
Training epoch 24
Loss: 0.0739
Training epoch 25
Saving model...
Loss: 0.0474
Training epoch 26
Loss: 0.0695
Training epoch 27
Loss: 0.0655
Training epoch 28
Loss: 0.0547
Training epoch 29
Loss: 0.0770
Training epoch 30
Saving model

### 💾 Save Progress

In [7]:
torch.save(net.state_dict(), saved_model_path)