In [1]:
%matplotlib inline

In [2]:
from __future__ import print_function, division
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from skimage import io, transform
from skimage.transform import rotate

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import torch.optim as optim
import os

import wandb
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

In [3]:
wandb.init(project="kevin-behavior-cloning-training", entity="launchkart")
wandb.run.name = "kevin-train-ashley-luigi-raceway-time-trials"
wandb.run.save()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlaunchkart[0m (use `wandb login --relogin` to force relogin)




True

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [5]:
x = np.load("data/Ashley_Luigi_Raceway_Time_Trials_4_Races/X.npy")
y = np.load("data/Ashley_Luigi_Raceway_Time_Trials_4_Races/y.npy")
split_idx = int(0.8 * len(x))
x_train, x_val = x[:split_idx], x[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]
print(x_train.shape[0], 'train samples')
print(x_train.shape)
print(y_train.shape)
print(y_train)

3236 train samples
(3236, 66, 200, 3)
(3236, 5)
[[-0.07180786 -0.00982666  0.          0.          0.        ]
 [-0.07180786 -0.00982666  0.          0.          0.        ]
 [-0.07180786 -0.00982666  0.          0.          0.        ]
 ...
 [-0.30645752 -0.01312256  1.          0.          1.        ]
 [-0.30645752 -0.01312256  1.          0.          1.        ]
 [-0.30645752 -0.01312256  1.          0.          1.        ]]


In [6]:
class MarioKartDataset(Dataset):
    """Nose Keypoints dataset."""

    def __init__(self, x, y):
        self.samples = []
        for i in range(len(x)):
            x_sample, y_sample = x[i], y[i]
            sample = {'obs': x_sample, 'action': y_sample}

            self.samples.append(sample)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.samples[idx]
        return sample


In [7]:
mario_kart_train_dataset = MarioKartDataset(x_train, y_train)
mario_kart_val_dataset = MarioKartDataset(x_val, y_val)
print(len(mario_kart_train_dataset))

3236


In [8]:
batch_size=64
mario_kart_train_dataloader = DataLoader(mario_kart_train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
mario_kart_val_dataloader = DataLoader(mario_kart_val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MarioKartBCAgent(nn.Module):
    def __init__(self):
        super(MarioKartBCAgent, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2), nn.ReLU(), nn.MaxPool2d(kernel_size=2))
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), nn.ReLU(), nn.MaxPool2d(kernel_size=2))
        self.conv3 = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2), nn.ReLU(), nn.MaxPool2d(kernel_size=2))
        self.fc1 = nn.Linear(25600, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)      
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output    # return x for visualization


In [10]:
agent = MarioKartBCAgent()


trial_name = 'kevin_ashley_data_luigicircuit/'
save_path = 'saved_agents/'
checkpoint = 0
cwd = os.getcwd()
agent_dir = os.path.join(cwd,save_path)
trial_dir = os.path.join(agent_dir, trial_name)
if not os.path.exists(agent_dir):
    os.mkdir(agent_dir)
if not os.path.exists(trial_dir):
    os.mkdir(trial_dir)
# agent.load_state_dict(torch.load(save_path + trial_name + str(load_checkpoint_num)))

lr = 1e-3
epochs = 100

optimizer = optim.Adam(agent.parameters(), lr=1e-3)
criterion = nn.MSELoss()

wandb.config = {
  "learning_rate": lr,
  "checkpoint": checkpoint,
  "trial_name": trial_name,
  "epochs": epochs,
  "batch_size": batch_size
}

training_losses = []
validation_losses = []

for epoch in range(epochs):
    epoch_training_losses = []
    epoch_validation_losses = []
    for _, sample_batched in enumerate(mario_kart_train_dataloader):
        obs_batch, action_batch = np.transpose(sample_batched['obs'], (0, 3, 1, 2)), sample_batched['action']
        obs_batch, action_batch = obs_batch.float().to(device), action_batch.float().to(device)
        optimizer.zero_grad() 
        pred_action = agent(obs_batch)
        loss = criterion(pred_action, action_batch.reshape(pred_action.shape))
        loss.backward()
        optimizer.step()
        epoch_training_losses.append(loss.item())
    training_loss = np.mean(epoch_training_losses)
    print("Finished Epoch", epoch + 1, ", training loss:", training_loss)
    training_losses.append(training_loss)
    wandb.log({"training_loss": training_loss})

    #Validation per epoch
    with torch.no_grad():
        agent.eval()
        for _, sample_batched in enumerate(mario_kart_val_dataloader):
            obs_batch, action_batch = np.transpose(sample_batched['obs'], (0, 3, 1, 2)), sample_batched['action']
            obs_batch, action_batch = obs_batch.float().to(device), action_batch.float().to(device)
            pred_action = agent(obs_batch)
            val_loss = criterion(pred_action, action_batch.reshape(pred_action.shape))
            epoch_validation_losses.append(val_loss.item())
        validation_loss = np.mean(epoch_validation_losses)
        print("Epoch ", epoch +1, ", validation_loss: ", validation_loss)
        validation_losses.append(validation_loss)
        wandb.log({"validation_loss": validation_loss})
        agent.train()
    if epoch % 10 == 0:
        print("Saving agent for epoch " + str(epoch))
        torch.save(agent.state_dict(), save_path + trial_name + str(epoch))


Finished Epoch 1 , training loss: 0.06756405693058874
Epoch  1 , validation_loss:  0.13271054682823327
Finished Epoch 2 , training loss: 0.060839275299918415
Epoch  2 , validation_loss:  0.14011392627771085


KeyboardInterrupt: 

In [None]:
plt.plot(np.arange(1, epochs + 1), training_losses, label="Training Loss")
plt.plot(np.arange(1, epochs + 1), validation_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Training and validation loss for each epoch")
plt.show()
