# 🐈 Cat Discriminator

### 📝 Imports

In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import albumentations as A

import torchvision
import torchvision.transforms as transforms

import os

from PIL import Image

from cat_discriminator_neural_net import CatDiscriminatorNeuralNet

from src.augmentation.data_augmenter import DataAugmenter


  check_for_updates()


### 🔧 Config

In [2]:
image_size = 512

learning_rate = 0.01
epochs_between_saves = 5
total_epochs = 200

saved_model_path = "trained_networks/cat_discriminator.pth"

training_data_path = 'data/staging'

bathroom_cat_path = '/bathroom-cat'
captain_dir_path = '/captain'
control_dir_path= '/control'

### 🌐 Create Transforms

In [3]:
transforms = transforms.Compose([
    DataAugmenter(image_size),
    transforms.ToTensor(),
])

DataAugmenter initialized


### 🚦 Load Training Data

In [4]:
from cats_dataset import CatsDataset


dataset = CatsDataset(
    bathroom_cat_dir=training_data_path + bathroom_cat_path, 
    captain_dir=training_data_path + captain_dir_path, 
    control_dir=training_data_path + control_dir_path, 
    transform=transforms)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)

### 🥾 Initialize the Neural Net

In [5]:
net = CatDiscriminatorNeuralNet()

if os.path.isfile(saved_model_path):
    net.load_state_dict(torch.load(saved_model_path))

net.cuda()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

  net.load_state_dict(torch.load(saved_model_path))


### 🏃‍♂️‍➡️ Train

In [6]:
epoch = 1
while epoch <= total_epochs:
    print(f'Training epoch {epoch}')
    
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()

        outputs = net(inputs)
        
        loss = loss_function(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
        
    if epoch % epochs_between_saves == 0:
        print('Saving model...')
        torch.save(net.state_dict(), saved_model_path)
        
    print (f'Loss: {running_loss / len(dataloader):.4f}')
    epoch += 1

Training epoch 1
Loss: 0.3562
Training epoch 2
Loss: 0.3644
Training epoch 3
Loss: 0.3536
Training epoch 4
Loss: 0.3038
Training epoch 5


KeyboardInterrupt: 

### 💾 Save Progress

In [7]:
torch.save(net.state_dict(), saved_model_path)