In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np

In [2]:
def get_x(path, width):
    """Return the normalized x target encoded in an image filename.

    The second ``_``-separated field of ``path`` is parsed as an integer
    pixel column and rescaled so that 0 maps to -1.0, ``width / 2`` maps
    to 0.0 and ``width`` maps to 1.0.

    Args:
        path: Filename whose second ``_``-separated field is an integer.
        width: Image width in pixels used for the rescaling.

    Returns:
        The rescaled x value as a float.

    Raises:
        IndexError: If ``path`` contains no ``_``.
        ValueError: If the second field is not an integer literal.
    """
    half_width = width / 2
    x_pixel = int(path.split("_")[1])
    return (float(x_pixel) - half_width) / half_width

def get_y(path, height):
    """Return the normalized y target encoded in an image filename.

    The third ``_``-separated field of ``path`` is parsed as an integer
    pixel row and rescaled so that 0 maps to -1.0, ``height / 2`` maps to
    0.0 and ``height`` maps to 1.0.

    Args:
        path: Filename whose third ``_``-separated field is an integer.
        height: Image height in pixels used for the rescaling.

    Returns:
        The rescaled y value as a float.

    Raises:
        IndexError: If ``path`` has fewer than three ``_``-separated fields.
        ValueError: If the third field is not an integer literal.
    """
    half_height = height / 2
    y_pixel = int(path.split("_")[2])
    return (float(y_pixel) - half_height) / half_height

class XYDataset(torch.utils.data.Dataset):
    """Image-regression dataset whose (x, y) target is encoded in each filename.

    Every ``*.jpg`` file directly inside ``directory`` is one sample. The
    target is parsed from the file's basename with ``get_x`` / ``get_y``
    using the image's own width and height.

    Args:
        directory: Folder that is searched (non-recursively) for ``*.jpg``.
        random_hflips: When True, each fetched sample is mirrored
            horizontally with probability 0.5 and its x target is negated
            to match. When False (the default) samples are never flipped.
    """

    def __init__(self, directory, random_hflips=False):
        self.directory = directory
        self.random_hflips = random_hflips
        # glob returns files in arbitrary order; sort so that a given index
        # always refers to the same file across runs.
        self.image_paths = sorted(glob.glob(os.path.join(self.directory, '*.jpg')))
        # Positional arguments are brightness, contrast, saturation, hue.
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def __len__(self):
        """Return the number of image files found at construction time."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, augment and normalize sample ``idx``.

        Returns:
            A tuple ``(image, target)``: ``image`` is the color-jittered
            picture resized to 224x224 as a normalized tensor with its
            channel axis reversed, and ``target`` is a float tensor
            ``[x, y]``.
        """
        image_path = self.image_paths[idx]

        image = PIL.Image.open(image_path)
        width, height = image.size
        filename = os.path.basename(image_path)
        x = float(get_x(filename, width))
        y = float(get_y(filename, height))

        # Honor the constructor flag: previously the flip was applied
        # unconditionally, even with random_hflips=False.
        # np.random.rand() (no size) yields a plain float, avoiding the
        # deprecated array-to-scalar conversion.
        if self.random_hflips and np.random.rand() > 0.5:
            image = transforms.functional.hflip(image)
            x = -x  # mirroring the picture mirrors the horizontal target

        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        # Reverse the first (channel) axis of the tensor.
        # NOTE(review): this happens before normalize(), so the per-channel
        # constants below are applied to the reversed channel order --
        # presumably to match the inference pipeline; confirm.
        image = image.numpy()[::-1].copy()
        image = torch.from_numpy(image)
        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        return image, torch.tensor([x, y]).float()
    
# Index every '*.jpg' found directly inside the 'dataset_xxy' folder.
dataset = XYDataset(directory='dataset_xxy', random_hflips=False)

In [3]:
# Hold out a random fraction of the samples for evaluation; the rest is
# used for training. int() truncates, so the test split rounds down.
test_percent = 0.1
num_test = int(test_percent * len(dataset))
num_train = len(dataset) - num_test
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [num_train, num_test])

In [4]:
# Batches of 8 samples, reshuffled on every pass, loaded in the main
# process (num_workers=0).
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)

# The held-out split uses the same loader settings.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=True, num_workers=0)

In [5]:
# Build a ResNet-18 and ask torchvision for its pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor
# of `weights=`; left unchanged because the installed version is unknown.
model = models.resnet18(pretrained=True)

In [6]:
# Swap the final fully connected layer for a 2-output regression head (x, y).
# The input width is read from the existing layer instead of hard-coding 512,
# so the cell stays correct if the backbone is changed and is safe to re-run.
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Use the GPU when one is available; otherwise fall back to the CPU rather
# than failing with a CUDA initialization error.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [7]:
NUM_EPOCHS = 70
BEST_MODEL_PATH = 'best_steering_model_xxy.pth'
best_loss = 1e9  # sentinel larger than any loss expected in practice

optimizer = optim.Adam(model.parameters())

for epoch in range(NUM_EPOCHS):

    # ---- training pass: one optimizer step per mini-batch ----
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.mse_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Mean of the per-batch losses.
    train_loss /= len(train_loader)

    # ---- evaluation pass ----
    model.eval()
    test_loss = 0.0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building graphs that are never back-propagated.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            test_loss += loss.item()
    test_loss /= len(test_loader)

    print('%f, %f' % (train_loss, test_loss))
    # Checkpoint only when the held-out loss improves on the best so far.
    if test_loss < best_loss:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_loss = test_loss

0.623652, 1.022657
0.104185, 0.088921
0.081440, 0.067857
0.075172, 0.076926
0.065926, 0.061785
0.062497, 0.061930
0.068001, 0.074041
0.062002, 0.061792
0.059112, 0.126831
0.060180, 0.084006
0.059240, 0.063995
0.059821, 0.067992
0.053553, 0.075711
0.062036, 0.078732
0.058845, 0.068901
0.051857, 0.080259
0.054566, 0.078531
0.056491, 0.102886
0.052722, 0.081024
0.055101, 0.060498
0.051313, 0.067216
0.045847, 0.073449
0.046388, 0.064081
0.049086, 0.139089
0.058581, 0.084304
0.051848, 0.083457
0.044610, 0.080573
0.051045, 0.072639
0.049164, 0.074692
0.047109, 0.097032
0.054627, 0.081024
0.051911, 0.070059
0.046542, 0.070093
0.047217, 0.091532
0.048085, 0.068654
0.047322, 0.066774
0.055886, 0.077320
0.047178, 0.094664
0.051306, 0.078095
0.060027, 0.081251
0.047725, 0.072723
0.045672, 0.065756
0.045472, 0.067192
0.046524, 0.072172
0.041617, 0.095719
0.043779, 0.068636
0.045941, 0.083112
0.044877, 0.071470
0.047063, 0.081103
0.045150, 0.081841
0.040706, 0.082896
0.042143, 0.078370
0.046348, 0.