In [1]:
import glob
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import signal
from sklearn.decomposition import PCA
from random import randint

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

%matplotlib inline

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
BATCH_SIZE = 32
NO_MAX_FRAMES = [50, 100, 200, 400]
NO_FRAMES_PREDICTED_AHEAD = [1, 6, 18, 36]
SIZE_FRAGMENT = [3]

#### Model

In [4]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        no_filters = 2048
        self.conv1 = nn.Conv2d(1, no_filters, kernel_size=(3, 3))
        self.linear1 = nn.Linear(no_filters, 32)
        self.linear2 = nn.Linear(32, 1)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))
        x = x.view(x.shape[0], -1)
        x = self.linear1(x)
        x = torch.sigmoid(self.linear2(x))
        return x

#### Data

In [5]:
def train(color, no_frames, no_frames_predicted, size):
    train_x = np.load(f'data/frames_cropped/train_x_{color}_noFrames{no_frames}_noFramesPredicedAhead{no_frames_predicted}_size{size}.npy')
    train_y = np.load(f'data/frames_cropped/train_y_{color}_noFrames{no_frames}_noFramesPredicedAhead{no_frames_predicted}_size{size}.npy')
    test_x = np.load(f'data/frames_cropped/test_x_{color}_noFrames{no_frames}_noFramesPredicedAhead{no_frames_predicted}_size{size}.npy')
    test_y = np.load(f'data/frames_cropped/test_y_{color}_noFrames{no_frames}_noFramesPredicedAhead{no_frames_predicted}_size{size}.npy')

    x_train = torch.tensor(train_x).float()
    y_train = torch.tensor(train_y).float()
    x_test = torch.tensor(train_x).float()
    y_test = torch.tensor(train_y).float()

    train_set = torch.utils.data.TensorDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    test_set = torch.utils.data.TensorDataset(x_train, y_train)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)
    
    model = CNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
    loss_func = torch.nn.MSELoss()

    loss_tr_s, loss_ts_s = [], []
    for epoch in range(20):
        # Train
        model.train()
        loss_tr = []
        for step, (raw, out) in enumerate(train_loader):        
            raw = raw.squeeze(dim=0).to(device)
            out = out.squeeze(dim=0).to(device)

            p = model(raw).reshape(-1)        

            batch_loss = loss_func(out, p)
            loss_tr.append(batch_loss.detach().item())
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
        
        # Test
        model.eval()
        loss_ts = []
        with torch.no_grad():
            for step, (raw, out) in enumerate(test_loader):        
                raw = raw.squeeze(dim=0).to(device)
                out = out.squeeze(dim=0).to(device)

                p = model(raw).reshape(-1)        
                batch_loss = loss_func(out, p)
                loss_ts.append(batch_loss.detach().item())

        loss_tr_s.append(np.mean(loss_tr))
        loss_ts_s.append(np.mean(loss_ts))
#         plt.plot(loss_tr_s, label='train')
#         plt.plot(loss_ts_s, label='test')
#         plt.legend()
#         plt.show()

    print(color, no_frames, no_frames_predicted, size, 
          round(loss_tr_s[-1], 4), round(loss_ts_s[-1], 4))

    torch.save(
        model.state_dict(),
        f'checkpoints/{color}/model_noFrames{no_frames}_noFramesPredicted{no_frames_predicted}_size{size}.pt'
    )

In [6]:
for no_frames in NO_MAX_FRAMES:
    for no_frames_predicted in NO_FRAMES_PREDICTED_AHEAD:
        for size in SIZE_FRAGMENT:
            train('red', no_frames, no_frames_predicted, size)
            train('green', no_frames, no_frames_predicted, size)
            train('blue', no_frames, no_frames_predicted, size)

red 50 1 3 0.0075 0.0074
green 50 1 3 0.0081 0.008
blue 50 1 3 0.0077 0.0077
red 50 6 3 0.0356 0.035
green 50 6 3 0.0394 0.0393
blue 50 6 3 0.0366 0.0362
red 50 18 3 0.0849 0.0861
green 50 18 3 0.093 0.0923
blue 50 18 3 0.0867 0.0862
red 50 36 3 0.1383 0.1384
green 50 36 3 0.1471 0.146
blue 50 36 3 0.138 0.1372
red 100 1 3 0.0055 0.0053
green 100 1 3 0.0061 0.006
blue 100 1 3 0.006 0.0064
red 100 6 3 0.0256 0.0253
green 100 6 3 0.0282 0.0282
blue 100 6 3 0.0275 0.0283
red 100 18 3 0.0581 0.0579
green 100 18 3 0.0624 0.0621
blue 100 18 3 0.0614 0.0615
red 100 36 3 0.0974 0.0968
green 100 36 3 0.103 0.1024
blue 100 36 3 0.1015 0.1006
red 200 1 3 0.0061 0.006
green 200 1 3 0.0056 0.0056
blue 200 1 3 0.0062 0.0061
red 200 6 3 0.0263 0.0264
green 200 6 3 0.0254 0.0251
blue 200 6 3 0.0264 0.0262
red 200 18 3 0.0595 0.0592
green 200 18 3 0.0581 0.0579
blue 200 18 3 0.0603 0.0596
red 200 36 3 0.0925 0.0919
green 200 36 3 0.0894 0.0927
blue 200 36 3 0.0959 0.0963
red 400 1 3 0.0069 0.0069
green