# Train the RPS model

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kjy5/cv-rock-paper-scissors/blob/main/scripts/train.ipynb)

# TEMP: resources:
- For transfer learning, see the [PyTorch transfer learning tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)

- For saving and loading models, see the [PyTorch save/load/run tutorial](https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html)

## Setup
### 1. Download data
You only need to run this cell once.

In [None]:
!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=11IkCeaEsjysSaWgMEI3SwkSzJ1Tmxz1i&confirm=t' -O data.zip
!unzip data.zip
!rm -rf data.zip

### 2. Import libraries

In [None]:
import torch
from torchvision import transforms, datasets
from torchvision.models import resnet18, ResNet18_Weights
import cv2 as cv
import os
from collections import deque
from PIL import Image
import random
import numpy as np

### 3. Set up device

In [None]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU. (Apple Silicon's MPS backend technically works, but is
# confusingly slow, so it is deliberately not considered.)
d = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Loading data and model
### Define constants

In [None]:
# Root folder of the dataset, relative to the working directory.
DATA_DIR = "data"

# Class labels used by this notebook.
CLASSES = [
    "rock",
    "paper",
    "scissors",
    "clutter",
]

# Clutter images are addressed as f"{CLUTTER_IMAGE_PREFIX}{id}.JPEG", with ids
# drawn from range(CLUTTER_COUNT).
CLUTTER_COUNT = 10000
CLUTTER_IMAGE_PREFIX = "test_"

# Base frame count; the data-loading cell samples NUM_FRAMES * 2 clutter images.
NUM_FRAMES = 120

### Define transforms

In [None]:
# Per-channel ImageNet mean / standard deviation, shared by all three pipelines
# so the inputs match what the ImageNet-pretrained ResNet weights were trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Pipeline for raw video frames (arrays handed over by OpenCV).
preprocess = transforms.Compose([
    transforms.ToTensor(),
    # Resize(224) only scales the *shorter* side to 224 and keeps the aspect
    # ratio, so a non-square frame would stay non-square. The centre crop
    # guarantees a 3 x 224 x 224 result, which is the shape the frame buffer in
    # the training cell allocates. For square frames the crop is a no-op.
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Training pipeline for PIL images: random crop + flip for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline for PIL images: deterministic resize + centre crop.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

### Load data

In [11]:
def _load_clutter_images(image_ids, transform):
    """Load clutter images by id, force them to RGB, and apply ``transform``.

    Args:
        image_ids: Iterable of integer ids; each id maps to the file
            ``DATA_DIR/clutter/{CLUTTER_IMAGE_PREFIX}{id}.JPEG``.
        transform: Callable applied to every (RGB) PIL image.

    Returns:
        A list holding one transformed item per id, in the order given.
    """
    images = []
    for img_id in image_ids:
        path = os.path.join(DATA_DIR, "clutter", f"{CLUTTER_IMAGE_PREFIX}{img_id}.JPEG")
        # The context manager closes the underlying file even if the
        # conversion or the transform raises.
        with Image.open(path) as image:
            # Ensure using RGB (some images may be greyscale / CMYK)
            rgb_image = image if image.mode == "RGB" else image.convert("RGB")
            images.append(transform(rgb_image))
    return images


# Define dataloaders
train_dataloaders = {}
val_dataloaders = {}

# Sample NUM_FRAMES * 2 distinct clutter image ids (240 with NUM_FRAMES = 120)
clutter_frame_ids = random.sample(range(CLUTTER_COUNT), NUM_FRAMES * 2)
# 80/20 train/validation boundary, computed in integer arithmetic so it cannot
# be thrown off by float rounding (192 for 240 ids).
FRAME_SPLIT = len(clutter_frame_ids) * 4 // 5

# Load train images (80%). A list is used as the in-memory dataset because the
# DataLoader indexes it at random positions, which is O(1) for a list.
train_data = _load_clutter_images(clutter_frame_ids[:FRAME_SPLIT], train_transform)
train_dataloaders["clutter"] = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=4)

# Load validation images (20%)
val_data = _load_clutter_images(clutter_frame_ids[FRAME_SPLIT:], val_transform)
val_dataloaders["clutter"] = torch.utils.data.DataLoader(val_data, batch_size=4, shuffle=True, num_workers=4)

### Load Model

In [None]:
# Start from a ResNet-18 initialised with the ImageNet-1k (V1) pretrained weights
pretrained_weights = ResNet18_Weights.IMAGENET1K_V1
model = resnet18(weights=pretrained_weights)
# Move the model onto the selected device; as the last expression of the cell
# this also displays the model structure.
model.to(d)

In [None]:
# Load presaved model
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a CUDA machine still loads on a CPU-only one.
# NOTE(review): this unpickles a complete model object, which can execute
# arbitrary code - only load checkpoints you trust. Recent PyTorch releases
# default torch.load to weights_only=True; there, loading a whole pickled model
# needs weights_only=False (or switch both cells to state_dict save/load).
model = torch.load("../model.pth", map_location=d)
model.to(d)

# Tasks:
Read whole training videos into memory

Read clutter dataset into memory

## Training


### Load in training data
#### 1. Rock

In [None]:
video_path = "../data/rock/rock.mp4"
video_file = cv.VideoCapture(video_path)

# Decode and preprocess every frame in a single pass. Collecting the results in
# a list avoids decoding the whole video a first time just to count frames, and
# avoids the manual write index (which was never advanced, so every frame used
# to overwrite slot 0 of the buffer).
processed_frames = []
try:
    while True:
        ret, frame = video_file.read()
        if not ret:
            break

        # OpenCV decodes frames as BGR; reorder the channels to RGB
        image = cv.cvtColor(frame, cv.COLOR_BGR2RGB)

        # Run preprocess. The result stays on the CPU: the frame store is a CPU
        # tensor, so moving each frame to the device would only be copied back.
        processed_frames.append(preprocess(image))
finally:
    # Always hand the decoder / file handle back to OpenCV
    video_file.release()

num_frames = len(processed_frames)
if processed_frames:
    # Shape: (num_frames, C, H, W)
    frames = torch.stack(processed_frames)
else:
    # Unreadable or empty video: keep the documented empty buffer shape
    frames = torch.empty(0, 3, 224, 224)
# Drop the per-frame list so only the stacked tensor is kept in memory
del processed_frames

print(f'Loaded {num_frames} frames')

In [None]:
# Persist the whole model object (not just a state_dict) to disk; the
# "Load presaved model" cell reads this same path back.
torch.save(obj=model, f="../model.pth")