# 🐈 Cat Discriminator

### 📝 Imports

In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import albumentations as A

import torchvision.transforms as transforms

import os

from src.cat_discriminator_neural_net import CatDiscriminatorNeuralNet

from src.augmentation.data_augmenter import DataAugmenter
from src.cats_dataset import CatsDataset

  check_for_updates()


### 🔧 Config

In [2]:
image_size = 512

learning_rate = 0.01
epochs_between_saves = 5
total_epochs = 200

batch_size = 32

num_cores = 2

saved_model_path = "trained_networks/cat_discriminator.pth"

training_data_path = 'data/training/'

### 🌐 Create Transforms

In [3]:
transforms = transforms.Compose([
    DataAugmenter(image_size),
    transforms.ToTensor(),
])

DataAugmenter initialized


### 🚦 Load Training Data

In [4]:
dataset = CatsDataset(
    root_dir=training_data_path,
    transform=transforms
)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_cores)

### 🥾 Initialize the Neural Net

In [5]:
net = CatDiscriminatorNeuralNet()

if os.path.isfile(saved_model_path):
    net.load_state_dict(torch.load(saved_model_path))

net.cuda()
loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

  net.load_state_dict(torch.load(saved_model_path))


### 🏃‍♂️‍➡️ Train

In [None]:
epoch = 1
while epoch <= total_epochs:
    print(f'Training epoch {epoch}')
    
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        inputs, labels = data
        inputs, labels = inputs.to('cuda'), labels.to('cuda')
        optimizer.zero_grad()

        outputs = net(inputs)
        
        loss = loss_function(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
        
    if epoch % epochs_between_saves == 0:
        print('Saving model...')
        torch.save(net.state_dict(), saved_model_path)
        
    print (f'Loss: {running_loss / len(dataloader):.4f}')
    epoch += 1

Training epoch 1


### 💾 Save Progress

In [7]:
torch.save(net.state_dict(), saved_model_path)