# 🐈 Cat Discriminator

### 📝 Imports

In [14]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import albumentations as A

import torchvision
import torchvision.transforms as transforms

import os

from PIL import Image

from cat_discriminator_neural_net import CatDiscriminatorNeuralNet

from tools.augmentation.data_augmenter import DataAugmenter


### 🔧 Config

In [20]:
learning_rate = 0.0001
epochs_between_saves = 100

saved_model_path = "trained_networks/cat_discriminator.pth"
bathroom_cat_dir_path = 'data/bathroom-cat-128x128/'
captain_dir_path = 'data/captain-128x128/'
control_dir_path= 'data/control-128x128/'

### 🌐 Create Transforms

In [16]:
transforms = transforms.Compose([
    DataAugmenter(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

DataAugmenter initialized


### 🚦 Load Training Data

In [17]:
from cats_dataset import CatsDataset


dataset = CatsDataset(bathroom_cat_dir=bathroom_cat_dir_path, captain_dir=captain_dir_path, control_dir=control_dir_path, transform=transforms)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

### 🥾 Initialize the Neural Net

In [18]:
net = CatDiscriminatorNeuralNet()

if os.path.isfile(saved_model_path):
    net.load_state_dict(torch.load(saved_model_path))

net.cuda()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

  net.load_state_dict(torch.load(saved_model_path))


### 🏃‍♂️‍➡️ Train

In [19]:
from matplotlib import pyplot as plt

epoch = 1
while True:
    print(f'Training epoch {epoch}')
    
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        inputs, labels = data

        inputs, labels = data
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()

        outputs = net(inputs)
        
        loss = loss_function(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
        
        if epoch % epochs_between_saves == 0:
            print('Saving model...')
            torch.save(net.state_dict(), saved_model_path)
        
    print (f'Loss: {running_loss / len(dataloader):.4f}')
    epoch += 1

Training epoch 0
Loss: 1.0302
Training epoch 1
Loss: 1.0295
Training epoch 2
Loss: 1.0272
Training epoch 3
Loss: 1.0331
Training epoch 4
Loss: 1.0215
Training epoch 5
Loss: 1.0430
Training epoch 6
Loss: 1.0274
Training epoch 7
Loss: 1.0246
Training epoch 8
Loss: 1.0545
Training epoch 9
Loss: 1.0194
Training epoch 10
Loss: 1.0351
Training epoch 11
Loss: 1.0294
Training epoch 12
Loss: 1.0290
Training epoch 13
Loss: 1.0430
Training epoch 14
Loss: 1.0232
Training epoch 15
Loss: 1.0219
Training epoch 16
Loss: 1.0291
Training epoch 17
Loss: 1.0343
Training epoch 18
Loss: 1.0344
Training epoch 19
Loss: 1.0189
Training epoch 20
Loss: 1.0425
Training epoch 21
Loss: 1.0489
Training epoch 22
Loss: 1.0167
Training epoch 23
Loss: 1.0283
Training epoch 24
Loss: 1.0156
Training epoch 25
Loss: 1.0178
Training epoch 26
Loss: 1.0202
Training epoch 27
Loss: 1.0177
Training epoch 28
Loss: 1.0551
Training epoch 29
Loss: 1.0154
Training epoch 30
Loss: 1.0275
Training epoch 31
Loss: 1.0203
Training epoch 32


KeyboardInterrupt: 

### 💾 Save Progress

In [7]:
torch.save(net.state_dict(), saved_model_path)