In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np

In [2]:
def get_x(path, width):
    """Gets the x value from the image filename"""
    return (float(int(path.split("_")[1])) - width/2) / (width/2)

def get_y(path, height):
    """Gets the y value from the image filename"""
    return (float(int(path.split("_")[2])) - height/2) / (height/2)

class XYDataset(torch.utils.data.Dataset):
    
    def __init__(self, directory, random_hflips=False):
        self.directory = directory
        self.random_hflips = random_hflips
        self.image_paths = glob.glob(os.path.join(self.directory, '*.jpg'))
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        
        image = PIL.Image.open(image_path)
        width, height = image.size
        x = float(get_x(os.path.basename(image_path), width))
        y = float(get_y(os.path.basename(image_path), height))
      
        if float(np.random.rand(1)) > 0.5:
            image = transforms.functional.hflip(image)
            x = -x
        
        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        image = image.numpy()[::-1].copy()
        image = torch.from_numpy(image)
        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        
        return image, torch.tensor([x, y]).float()
    
dataset = XYDataset('dataset_xy_new', random_hflips=False)

In [3]:
test_percent = 0.1
num_test = int(test_percent * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - num_test, num_test])

In [4]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

In [5]:
model = models.resnet18(pretrained=True)

In [6]:
model.fc = torch.nn.Linear(512, 2)
device = torch.device('cuda')
model = model.to(device)

In [7]:
NUM_EPOCHS = 70
BEST_MODEL_PATH = 'best_steering_model_xy_new_v2.pth'
best_loss = 1e9

optimizer = optim.Adam(model.parameters())

round = 1
for epoch in range(NUM_EPOCHS):
    
    model.train()
    train_loss = 0.0
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.mse_loss(outputs, labels)
        train_loss += float(loss)
        loss.backward()
        optimizer.step()
    train_loss /= len(train_loader)
    
    model.eval()
    test_loss = 0.0
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = F.mse_loss(outputs, labels)
        test_loss += float(loss)
    test_loss /= len(test_loader)
    print(round)
    round += 1
    print('%f, %f' % (train_loss, test_loss))
    if test_loss < best_loss:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_loss = test_loss
print("Finish!!")

1
0.162133, 0.054802
2
0.024632, 0.015023
3
0.018107, 0.041451
4
0.015714, 0.019242
5
0.015550, 0.021598
6
0.010245, 0.015614
7
0.008105, 0.020794
8
0.007742, 0.015450
9
0.007195, 0.013075
10
0.006996, 0.012671
11
0.005473, 0.009885
12
0.006741, 0.010676
13
0.005742, 0.009734
14
0.004528, 0.011392
15
0.004517, 0.013478
16
0.004048, 0.015308
17
0.004316, 0.009454
18
0.004105, 0.014004
19
0.003855, 0.009700
20
0.002997, 0.020603
21
0.004177, 0.011823
22
0.003857, 0.009080
23
0.002590, 0.010065
24
0.004390, 0.010153
25
0.003530, 0.008671
26
0.003061, 0.008547
27
0.003053, 0.017307
28
0.003354, 0.007661
29
0.002190, 0.007408
30
0.004152, 0.011062
31
0.002874, 0.009388
32
0.002069, 0.008641
33
0.002307, 0.009727
34
0.001826, 0.008477
35
0.002741, 0.009628
36
0.003296, 0.006090
37
0.002857, 0.007627
38
0.002293, 0.008683
39
0.003057, 0.009779
40
0.002227, 0.008292
41
0.003556, 0.018136
42
0.003321, 0.010097
43
0.001929, 0.014586
44
0.002435, 0.010203
45
0.001590, 0.009817
46
0.002206, 0.0078