# Train the RPS model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kjy5/cv-rock-paper-scissors/blob/main/scripts/train.ipynb)

# TEMP: resources:
- For transfer learning see [transfer learning](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)

- For saving/loading: [save load run](https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html)

## Setup
### 1. Download data
Only need to run this once

In [1]:
!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=11IkCeaEsjysSaWgMEI3SwkSzJ1Tmxz1i&confirm=t' -O data.zip
!unzip data.zip
!rm -rf data.zip

--2023-03-02 15:43:00--  https://drive.google.com/uc?export=download&id=11IkCeaEsjysSaWgMEI3SwkSzJ1Tmxz1i&confirm=t
Loaded CA certificate '/etc/ssl/certs/ca-certificates.crt'
Resolving drive.google.com (drive.google.com)... 2607:f8b0:400a:804::200e, 142.251.33.110
Connecting to drive.google.com (drive.google.com)|2607:f8b0:400a:804::200e|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://doc-08-ag-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/ve5mu5vimk4fleqm9d2mni3mtqao9jbv/1677800550000/07078206192535146833/*/11IkCeaEsjysSaWgMEI3SwkSzJ1Tmxz1i?e=download&uuid=5c7bdc50-eb83-4212-9f31-a5f9e9f0e2c5 [following]
--2023-03-02 15:43:00--  https://doc-08-ag-docs.googleusercontent.com/docs/securesc/ha0ro937gcuc7l7deffksulhg5h7mbp1/ve5mu5vimk4fleqm9d2mni3mtqao9jbv/1677800550000/07078206192535146833/*/11IkCeaEsjysSaWgMEI3SwkSzJ1Tmxz1i?e=download&uuid=5c7bdc50-eb83-4212-9f31-a5f9e9f0e2c5
Resolving doc-08-ag-docs.googleusercontent

### 2. Import libraries

In [3]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets, models

### 3. Set up device

In [4]:
d = torch.device("cpu")
# Use a CUDA GPU if possible (Apple Silicon MPS backend technically works, but is confusingly slow)
if torch.cuda.is_available():
    d = torch.device("cuda:0")
d

device(type='cuda', index=0)

## Loading data and model
### Define constants

In [5]:
DATA_DIR = "data"
TEST_VAL_RATIO = 0.8
BATCH_SIZE = 16

### Define transformers

In [6]:
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


### Load data

In [7]:
dataset = datasets.ImageFolder(DATA_DIR, preprocess)

# Split into train and val
train_len = int(len(dataset) * TEST_VAL_RATIO)
val_len = int(len(dataset) - train_len)
train_dataset, val_dataset = random_split(dataset, [train_len, val_len])

train_loader = DataLoader(train_dataset, BATCH_SIZE, True, pin_memory=True)
val_loader = DataLoader(val_dataset, BATCH_SIZE, True, pin_memory=True)


### Load Model

In [14]:
# Load in a pretrained model
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
for param in model.parameters():
    param.requires_grad = True # Set to false to leave ResNet unchanged during training
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=4, bias=True)
model.to(d)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Load presaved model
model = torch.load("model.pth")
model.to(d)

## Train

In [17]:
# Usually need to do around 10 to get it trained
epochs = 1

# This usually needs to get shrunk in later epochs
optimizer = optim.Adam(model.parameters(), 0.0001)
loss = nn.CrossEntropyLoss()

for i in range(epochs):    
    # Do a training run and evaluate it after
    epoch_loss = 0
    epoch_cor = 0
    epoch_total = 0

    # Training mode
    model.train()

    for inputs, labels in train_loader:
        inputs = inputs.to(d)
        labels = labels.to(d)

        optimizer.zero_grad()
        # Run through model
        predictions = model(inputs)
        batch_loss = loss(predictions, labels)

        epoch_loss += batch_loss.item()
        epoch_total += len(predictions)
        epoch_cor += torch.sum(torch.argmax(predictions, 1) == labels).item()

        # Do a round of gradient descent
        batch_loss.backward()
        optimizer.step()
    epoch_acc = epoch_cor / epoch_total

    print(f"Epoch {i+1}/{epochs}: loss={epoch_loss}, acc={epoch_cor}/{epoch_total}")

    # See how well it does on the validation set
    model.eval() # Prediction mode
    val_cor = 0
    val_loss = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(d)
            labels = labels.to(d)

            optimizer.zero_grad()
            predictions = model(inputs)
            batch_loss = loss(predictions, labels)

            val_loss += batch_loss.item()
            val_total += len(predictions)
            val_cor += torch.sum(torch.argmax(predictions, 1) == labels).item()
    print(f"Validation loss={val_loss}, acc={val_cor}/{val_total}")

        

Epoch 1/1: loss=1.0248292600008426, acc=1696/1701
Validation loss=4.528443200870242, acc=408/426


## Save when finished

In [18]:
# Save model
torch.save(model, "model.pth")