In [1]:
%matplotlib inline

In [2]:
#Import all our dependencies
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from skimage import io, transform
from skimage.transform import rotate

import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

import torch.optim as optim
import os

import wandb
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

<matplotlib.pyplot._IonContext at 0x1f89cd33310>

In [3]:
#Initialize weight and biases trial so we can monitor our trials
#Make sure to rerun this cell before every trial!
wandb.init(project="tony-behavior-cloning-training", entity="launchkart")
#TODO: Name this something descriptive, and make sure to rename the trial name for different trials!
wandb.run.name = "tony-trial-1"
wandb.run.save()

[34m[1mwandb[0m: Currently logged in as: [33mtonyxin[0m (use `wandb login --relogin` to force relogin)




True

In [4]:
#Create our torch device, which will be cuda if GPU and cpu if just using CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [6]:
x = np.load("data/X.npy")
y = np.load("data/y.npy")

#If you are loading more than one npy file, you will need to load all of them and then use vstack to concatenate them all together 
# x = np.vstack(x, ...)
# y = np.vstack(y, ...)

#Split our training data into a training set and validation set
split_idx = int(0.8 * len(x))
x_train, x_val = x[:split_idx], x[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]

#Look at our training set shapes
print(x_train.shape)
print(y_train.shape)

(385, 66, 200, 3)
(385, 5)


In [7]:
#Torch has a Datasets class that makes it easy to interface with datasets when training
#This is mostly filled in for you, although we will add data augmentation later
class MarioKartDataset(Dataset):
    """Nose Keypoints dataset."""
    
    #Create a list of samples, where each sample is a tensor with an observation (image) and an action (vector of controller input)
    def __init__(self, x, y):
        self.samples = []
        for i in range(len(x)):
            x_sample, y_sample = x[i], y[i]
            sample = {'obs': x_sample, 'action': y_sample}

            self.samples.append(sample)

    def __len__(self):
        return len(self.samples)
    
    #Gets the item at index idx from our samples
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.samples[idx]
        
        #Ignore this todo until we've covered it in project meeting: 
        #TODO: Apply data augmentation here
        
        return sample


In [8]:
#Create a training and validation dataset
mario_kart_train_dataset = MarioKartDataset(x_train, y_train)
mario_kart_val_dataset = MarioKartDataset(x_val, y_val)
print(len(mario_kart_train_dataset))

385


In [9]:
#Torch uses DataLoaders to handle shuffling datasets and loading batches of data
#You can experiment with different batch sizes here
#Num_workers can be set to 1 if using GPU
batch_size=64
mario_kart_train_dataloader = DataLoader(mario_kart_train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
mario_kart_val_dataloader = DataLoader(mario_kart_val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [23]:
#TODO: Define your agent architecture here
#Take a look at our MNIST tutorial for reference.
#Recommended architecture is covered in meeting slides
#Note: The in dimension for the first fully connected layer is very hard to calculate, can prob just run first and determine based on the error message
#Note: Use Sequential to keep your code organized
class MarioKartBCAgent(nn.Module):
    def __init__(self):
        super(MarioKartBCAgent, self).__init__()
        
        #Feature Extraction Module: Conv layers (check out Conv2d, BatchNorm, Relu, Maxpool)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=1, padding=2), nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=24, out_channels=36, kernel_size=5, stride=1, padding=2), nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(in_channels=36, out_channels=48, kernel_size=5, stride=1, padding=2), nn.ReLU())
        self.conv4 = nn.Sequential(nn.Conv2d(in_channels=48, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU())
        self.conv5 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2))
        
        #Inference Module: Fully connected layers (check out Linear, Dropout, Relu)
        self.fc1 = nn.Sequential(nn.Linear(211200, 100), nn.Dropout(0.2))
        self.fc2 = nn.Sequential(nn.Linear(100, 50), nn.Dropout(0.2))
        self.fc3 = nn.Sequential(nn.Linear(50, 10), nn.Dropout(0.2))
        self.fc4 = nn.Sequential(nn.Linear(10, 5))
        
        
    def forward(self, x):
        #Pass x through conv layers
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        
        #Flatten x to prepare for passing into linear
        x = torch.flatten(x, 1) 
        
        #Pass x through linear layers
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        
        #Name variable output on last layer
        output = x
        
        return output    # return x for visualization


In [None]:
#This is the cell where we define our training loop
#It is mostly implemented already, but make sure to read through all of it to understand what is happening

agent = MarioKartBCAgent()

#Change the trial_name to something descriptive, based on what data you are using for training
trial_name = 'kevin_luigicircuit_timetrials/'
save_path = 'saved_agents/'


cwd = os.getcwd()
agent_dir = os.path.join(cwd,save_path)
trial_dir = os.path.join(agent_dir, trial_name)
if not os.path.exists(agent_dir):
    os.mkdir(agent_dir)
if not os.path.exists(trial_dir):
    os.mkdir(trial_dir)
    
#Change the checkpoint number if you are loading a model from a checkpoint for further training
checkpoint = 0
# agent.load_state_dict(torch.load(save_path + trial_name + str(checkpoint)))

#Define your learning rate and number of epochs to train for
lr = 1e-3
epochs = 100

#Initialize your optimizer and loss
optimizer = optim.Adam(agent.parameters(), lr=1e-3)
criterion = nn.MSELoss()

#This config is for WandB to track the hyperparameters we used for this run
wandb.config = {
  "learning_rate": lr,
  "checkpoint": checkpoint,
  "trial_name": trial_name,
  "epochs": epochs,
  "batch_size": batch_size
}

training_losses = []
validation_losses = []

#Iterate through number of epochs
for epoch in range(epochs):
    epoch_training_losses = []
    epoch_validation_losses = []
    
    #Iterate through our dataloader batch by batch
    for _, sample_batched in enumerate(mario_kart_train_dataloader):
        
        #Creates an observation and action batch
        obs_batch, action_batch = np.transpose(sample_batched['obs'], (0, 3, 1, 2)), sample_batched['action']
        obs_batch, action_batch = obs_batch.float().to(device), action_batch.float().to(device)
        
        #Optimizer performs a parameter update step based on loss calculated between predicted and grond truth actions
        optimizer.zero_grad() 
        pred_action = agent(obs_batch)
        loss = criterion(pred_action, action_batch.reshape(pred_action.shape))
        loss.backward()
        optimizer.step()
        
        epoch_training_losses.append(loss.item())
    
    #Keep track of training losses every epoch and log with WandB
    training_loss = np.mean(epoch_training_losses)
    print("Finished Epoch", epoch + 1, ", training loss:", training_loss)
    training_losses.append(training_loss)
    wandb.log({"training_loss": training_loss})

    #After every epoch, test agent performance on validation set, which it has not trained on
    #This is to make sure we aren't overfitting too hard, and are generalizing well
    with torch.no_grad():
        #Set agent to evaluation mode so we aren't calculating gradients
        agent.eval()
        
        #Iterate through our validation dataloder batch by batch
        for _, sample_batched in enumerate(mario_kart_val_dataloader):
            obs_batch, action_batch = np.transpose(sample_batched['obs'], (0, 3, 1, 2)), sample_batched['action']
            obs_batch, action_batch = obs_batch.float().to(device), action_batch.float().to(device)
            pred_action = agent(obs_batch)
            val_loss = criterion(pred_action, action_batch.reshape(pred_action.shape))
            epoch_validation_losses.append(val_loss.item())
        
        #Keep track of validationlosses every epoch and log with WandB
        validation_loss = np.mean(epoch_validation_losses)
        print("Epoch ", epoch +1, ", validation_loss: ", validation_loss)
        validation_losses.append(validation_loss)
        wandb.log({"validation_loss": validation_loss})
        agent.train()
    
    #Save agent checkpoint every 10 epochs
    if epoch == 0 or epoch % 10 == 9:
        torch.save(agent.state_dict(), save_path + trial_name + str(epoch))


Finished Epoch 1 , training loss: 1085.3543853759766
Epoch  1 , validation_loss:  1299.8576049804688
Finished Epoch 2 , training loss: 977.1715785435268
Epoch  2 , validation_loss:  1009.7860717773438
Finished Epoch 3 , training loss: 946.4096494402204
Epoch  3 , validation_loss:  884.3128356933594
Finished Epoch 4 , training loss: 920.548713684082
Epoch  4 , validation_loss:  1096.4378967285156
Finished Epoch 5 , training loss: 1004.1973920549665
Epoch  5 , validation_loss:  1065.2937622070312
Finished Epoch 6 , training loss: 921.2420817783901
Epoch  6 , validation_loss:  1155.7324523925781
Finished Epoch 7 , training loss: 944.1373040335519
Epoch  7 , validation_loss:  1104.9945678710938
Finished Epoch 8 , training loss: 941.8939377920968
Epoch  8 , validation_loss:  1024.2730102539062
Finished Epoch 9 , training loss: 946.2451531546457
Epoch  9 , validation_loss:  1061.5072021484375
Finished Epoch 10 , training loss: 932.2612947736468
Epoch  10 , validation_loss:  948.0646362304688