# Road Follower - Train Model

In this notebook we will train a neural network to take an input image and output a set of x, y values corresponding to a target.

We will use the PyTorch deep learning framework to train a ResNet-18 model for the road follower application.

In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np

### Download and extract data

Before you start, you should upload the ``road_following_<Date&Time>.zip`` file that you created in the ``data_collection.ipynb`` notebook on the robot. 

> If you're training on the JetBot you collected data on, you can skip this!

You should then extract this dataset by calling the command below:

In [2]:
!unzip -o -q road_following_dataset_big3.zip

You should see a folder named ``dataset_big3`` (the folder that ``XYDataset`` loads below) appear in the file browser.
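If the ``unzip`` command is not available on your training machine, Python's standard-library ``zipfile`` module can do the same extraction. This is a minimal sketch; the helper name ``extract_dataset`` is hypothetical, and the archive name is whatever you uploaded.

```python
import os
import zipfile


def extract_dataset(zip_path, dest_dir="."):
    """Extract a dataset archive (like `unzip -o -q`) and return the extracted .jpg paths."""
    with zipfile.ZipFile(zip_path) as archive:
        # extractall overwrites files that already exist, similar to `unzip -o`
        archive.extractall(dest_dir)
        names = archive.namelist()
    # Report only the image files that came out of the archive
    return sorted(os.path.join(dest_dir, n) for n in names if n.endswith(".jpg"))
```

For example, ``len(extract_dataset('road_following_dataset_big3.zip'))`` tells you how many images were unpacked.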

### Create Dataset Instance

Here we create a custom ``torch.utils.data.Dataset`` subclass that implements the ``__len__`` and ``__getitem__`` methods.  This class
is responsible for loading images and parsing the x, y values from the image filenames.  Because it subclasses ``torch.utils.data.Dataset``,
we can use all of the torch data utilities :)

We hard-coded some transformations (like color jitter) into the dataset.  Random horizontal flips are optional (``random_hflips``), in case you want to follow a non-symmetric path, like a road
where the robot needs to 'stay right'.  If your robot does not need to follow such a convention, you can enable flips to augment the dataset.

In [3]:
def get_x(path, width):
    """Gets the x value from the image filename"""
    return (float(int(path.split("_")[1])) - width/2) / (width/2)

def get_y(path, height):
    """Gets the y value from the image filename"""
    return (float(int(path.split("_")[2])) - height/2) / (height/2)

class XYDataset(torch.utils.data.Dataset):
    
    def __init__(self, directory, random_hflips=False):
        self.directory = directory
        self.random_hflips = random_hflips
        self.image_paths = glob.glob(os.path.join(self.directory, '*.jpg'))
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        
        image = PIL.Image.open(image_path)
        width, height = image.size
        x = float(get_x(os.path.basename(image_path), width))
        y = float(get_y(os.path.basename(image_path), height))
      
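        # Random horizontal flip, applied only when enabled; mirroring the image negates x
        if self.random_hflips and np.random.rand() > 0.5:
            image = transforms.functional.hflip(image)
            x = -x
        
        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        # Reverse the channel order (RGB -> BGR), presumably to match the camera frames used at inference
        image = image.numpy()[::-1].copy()
        image = torch.from_numpy(image)
        # Normalize with the ImageNet mean/std used to pretrain ResNet-18
        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])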
        
        return image, torch.tensor([x, y]).float()
    
dataset = XYDataset('dataset_big3', random_hflips=False)
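To make the label encoding concrete, here is a standard-library sketch of what ``get_x``/``get_y`` compute, using a filename from this dataset and assuming a 224x224 capture (the resolution is an assumption, and ``parse_xy``/``mirrored_x`` are hypothetical helpers). Each pixel coordinate is mapped to roughly [-1, 1], with 0 at the image center. The second function shows why the horizontal flip negates ``x``: mirroring pixel column ``px`` to ``width - px`` flips the sign of the normalized value.

```python
def parse_xy(filename, width, height):
    """Return normalized (x, y) from a name like 'xy_<px>_<py>_<uuid>.jpg'."""
    parts = filename.split("_")
    px, py = int(parts[1]), int(parts[2])
    # Center the pixel coordinate and scale by half the image size
    x = (px - width / 2) / (width / 2)
    y = (py - height / 2) / (height / 2)
    return x, y


def mirrored_x(px, width):
    """Normalized x after mirroring pixel column px to width - px."""
    return ((width - px) - width / 2) / (width / 2)


x, y = parse_xy("xy_067_043_be40b40c-4dd8-11f0-879e-a46bb606808d.jpg", 224, 224)
print(round(x, 4), round(y, 4))       # -0.4018 -0.6161
print(round(mirrored_x(67, 224), 4))  # 0.4018
```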

### Split dataset into train and test sets
Once the dataset is loaded, we split it into training and test sets, using a 90%/10% split in this example. The test set is used to measure how well the trained model performs on images it was not trained on.

In [4]:
test_percent = 0.1
num_test = int(test_percent * len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - num_test, num_test])
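The split arithmetic is simple: ``int()`` truncates, so the test set gets the floor of 10% and the training set gets the rest. The sketch below mirrors that computation in plain Python and, as an addition not in the original, shuffles indices with a fixed seed so the split is reproducible across runs (with ``random_split``, the equivalent is passing a seeded ``torch.Generator`` as its ``generator`` argument). The dataset size of 1234 and the helper name are hypothetical.

```python
import random


def split_indices(n_items, test_percent=0.1, seed=0):
    """Return (train_idx, test_idx) with the same sizes as the notebook's split."""
    num_test = int(test_percent * n_items)  # truncates, like the notebook
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # fixed seed -> same split every run
    return indices[num_test:], indices[:num_test]


train_idx, test_idx = split_indices(1234)
print(len(train_idx), len(test_idx))  # 1111 123
```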

### Create data loaders to load data in batches

We use the ``DataLoader`` class to load data in batches, shuffle it, and optionally load it in multiple worker subprocesses (``num_workers``; it is 0 here, so loading happens in the main process). In this example we use a batch size of 8. Choose the batch size based on the memory available on your GPU; it can also affect the accuracy of the model.

In [5]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,  # no need to shuffle the evaluation data
    num_workers=0
)
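Because the training loop divides the summed loss by ``len(train_loader)``, it helps to know how many batches that is. With the default ``drop_last=False``, a ``DataLoader`` over a map-style dataset yields ``ceil(N / batch_size)`` batches, the last one possibly smaller. A quick check with hypothetical sizes (1111 training and 123 test images); the helper name is illustrative:

```python
import math


def num_batches(n_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per epoch over n_samples."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


print(num_batches(1111, 8))  # 139 batches, the last holding 7 samples
print(num_batches(123, 8))   # 16
```

Averaging per-batch losses this way weights the smaller final batch the same as a full one, so the reported epoch loss is a close approximation of the per-sample mean rather than the exact value.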

### Define Neural Network Model 

We use the ResNet-18 model available in TorchVision.

In a process called transfer learning, we can repurpose a pre-trained model (trained on millions of images) for a new task that has possibly much less data available.

More details on ResNet-18: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

More details on transfer learning: https://www.youtube.com/watch?v=yofjFQddwHE

In [6]:
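# On torchvision 0.13 and later, pretrained=True is deprecated in favor of
# weights=models.ResNet18_Weights.DEFAULT, which loads the same ImageNet weights
model = models.resnet18(pretrained=True)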

The final fully connected (``fc``) layer of ResNet-18 has 512 ``in_features``. Since we are training for regression on two values (x and y), we replace it with a layer whose ``out_features`` is 2.

Finally, we move the model to the GPU for execution.

In [7]:
model.fc = torch.nn.Linear(512, 2)  # two regression outputs: x and y
device = torch.device('cuda')
model = model.to(device)
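Replacing the head shrinks the output layer considerably. A ``Linear(in_features, out_features)`` layer with bias has ``in_features * out_features + out_features`` parameters, so the arithmetic below compares ResNet-18's original 1000-class ImageNet head with the two-output regression head. The helper name is illustrative.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


imagenet_head = linear_params(512, 1000)  # original ResNet-18 classifier
xy_head = linear_params(512, 2)           # new (x, y) regression head
print(imagenet_head, xy_head)  # 513000 1026
```

All other ResNet-18 weights start from their pretrained ImageNet values; since the optimizer receives ``model.parameters()``, they are fine-tuned together with the new head rather than frozen.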

### Train Regression:

We train for 70 epochs (``NUM_EPOCHS``) and save the model whenever the test loss reaches a new minimum.

In [8]:
NUM_EPOCHS = 70
BEST_MODEL_PATH = 'best_steering_model_big3.pth'
best_loss = 1e9

optimizer = optim.Adam(model.parameters())

for epoch in range(NUM_EPOCHS):
    
    model.train()
    train_loss = 0.0
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.mse_loss(outputs, labels)
        train_loss += float(loss)
        loss.backward()
        optimizer.step()
    train_loss /= len(train_loader)
    
    model.eval()
    test_loss = 0.0
    with torch.no_grad():  # gradients are not needed for evaluation; saves GPU memory
        for images, labels in iter(test_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            test_loss += float(loss)
    test_loss /= len(test_loader)
    
    print('%f, %f' % (train_loss, test_loss))
    if test_loss < best_loss:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_loss = test_loss

0.314137, 0.081513
0.059544, 0.056107
0.048966, 0.023799
0.044848, 0.031278
0.039415, 0.021861
0.040538, 0.042709
0.063919, 0.071817
0.043979, 0.017437
0.033876, 0.021551
0.028916, 0.022814
0.028268, 0.018644
0.030030, 0.032980
0.036405, 0.021336
0.028095, 0.032037
0.025023, 0.018886
0.017964, 0.013048
0.017891, 0.018560
0.018345, 0.041573
0.018858, 0.020772
0.018322, 0.111703
0.013922, 0.020800
0.018567, 0.099899
0.018345, 0.032071
0.010991, 0.015399
0.011785, 0.085936
0.010664, 0.019307
0.008296, 0.021757
0.008489, 0.022266
0.007849, 0.019367
0.008162, 0.030640
0.005713, 0.025770
0.006052, 0.017802
0.005254, 0.012415
0.005677, 0.016905
0.004695, 0.015901
0.005439, 0.018588
0.004854, 0.014538
0.004646, 0.019742
0.007998, 0.028954
0.007586, 0.015701
0.006206, 0.016242
0.007435, 0.019836
0.012934, 0.018923
0.006483, 0.018122
0.007157, 0.014683
0.004074, 0.016398
0.004142, 0.014583
0.002962, 0.014948
0.004571, 0.016957
0.006277, 0.018138
0.003985, 0.014507
0.002987, 0.014761
0.004922, 0.
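The checkpoint logic saves whenever the test loss beats the best value so far, so the saved file holds the epoch with the lowest test loss, not the last epoch. Replaying that rule on the first eight test losses from the log shows which epochs trigger a save (the helper is a hypothetical illustration, not part of the notebook):

```python
def save_epochs(test_losses):
    """Return the 0-based epochs at which the training loop would save a checkpoint."""
    best_loss = 1e9  # same initial value as the notebook
    saved = []
    for epoch, loss in enumerate(test_losses):
        if loss < best_loss:
            saved.append(epoch)
            best_loss = loss
    return saved


first_eight = [0.081513, 0.056107, 0.023799, 0.031278,
               0.021861, 0.042709, 0.071817, 0.017437]
print(save_epochs(first_eight))  # [0, 1, 2, 4, 7]
```

Because the test loss in this log fluctuates considerably (it rises from 0.013048 to 0.111703 within four epochs), keeping the best checkpoint rather than the final one matters here.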

Once training finishes, the best model is saved as ``best_steering_model_big3.pth`` (``BEST_MODEL_PATH``), which you can use for inference in the live demo notebook.

If you trained on a machine other than the JetBot, upload this file to the ``road_following`` example folder on the JetBot.
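When checking the saved model's predictions, it can help to map the normalized outputs back to pixel coordinates so they can be drawn on the camera image. This is simply the inverse of ``get_x``/``get_y``; the function name and the 224x224 frame size are assumptions.

```python
def to_pixels(x, y, width=224, height=224):
    """Invert the dataset normalization: map (x, y) in about [-1, 1] back to pixels."""
    px = x * (width / 2) + width / 2
    py = y * (height / 2) + height / 2
    return int(round(px)), int(round(py))


print(to_pixels(0.0, 0.0))              # (112, 112), the image center
print(to_pixels(-45 / 112, -69 / 112))  # (67, 43)
```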

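The optional, commented-out cell below scans an image folder and deletes every ``.jpg`` that PIL cannot open (along with a matching ``.json`` file, if one exists). Its output lists the corrupt files it removed from ``dataset3_xy``.

In [9]: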
# import os
# from PIL import Image, UnidentifiedImageError

# image_dir = 'dataset3_xy'
# deleted_files = 0

# for filename in os.listdir(image_dir):
#     if filename.endswith('.jpg'):
#         image_path = os.path.join(image_dir, filename)
#         try:
#             # Open the image file (to check it)
#             with Image.open(image_path) as img:
#                 img.verify()  # check that the file is actually a valid image
#         except (UnidentifiedImageError, IOError) as e:
#             print(f"❌ Deleted: {filename} - reason: {e}")
#             os.remove(image_path)
#             # Also delete the related JSON file
#             json_path = image_path.replace('.jpg', '.json')
#             if os.path.exists(json_path):
#                 os.remove(json_path)
#             deleted_files += 1

# print(f"Deleted {deleted_files} corrupt images (.jpg) and labels (.json) in total.")

❌ Deleted: xy_067_043_be40b40c-4dd8-11f0-879e-a46bb606808d.jpg - reason: cannot identify image file 'dataset3_xy/xy_067_043_be40b40c-4dd8-11f0-879e-a46bb606808d.jpg'
❌ Deleted: xy_157_077_cf0efbf4-4dd8-11f0-879e-a46bb606808d.jpg - reason: cannot identify image file 'dataset3_xy/xy_157_077_cf0efbf4-4dd8-11f0-879e-a46bb606808d.jpg'
❌ Deleted: xy_201_108_d0d152ca-4dd8-11f0-879e-a46bb606808d.jpg - reason: cannot identify image file 'dataset3_xy/xy_201_108_d0d152ca-4dd8-11f0-879e-a46bb606808d.jpg'
❌ Deleted: xy_116_047_c0962962-4dd8-11f0-879e-a46bb606808d.jpg - reason: cannot identify image file 'dataset3_xy/xy_116_047_c0962962-4dd8-11f0-879e-a46bb606808d.jpg'
❌ Deleted: xy_140_069_5a1f0c26-4dd8-11f0-be5f-a46bb606808d.jpg - reason: cannot identify image file 'dataset3_xy/xy_140_069_5a1f0c26-4dd8-11f0-be5f-a46bb606808d.jpg'
❌ Deleted: xy_114_090_57f72776-4dd8-11f0-be5f-a46bb606808d.jpg - reason: cannot identify image file 'dataset3_xy/xy_114_090_57f72776-4dd8-11f0-be5f-a46bb606808d.jpg'
❌ Deleted: xy_110_055_c244ca8e-4dd8-11f0-879e-a46bb60680
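The cleanup cell opens every file with PIL's ``verify()``. As a cheaper first pass (a swapped-in heuristic, not what the cell does), you can flag files that are empty or do not start with the JPEG start-of-image marker ``FF D8``, which catches empty files and files that are not JPEG data at all. It will not detect every kind of corruption, so ``verify()`` remains the more thorough check. The function name is hypothetical, and this version only reports files rather than deleting them.

```python
import os


def suspicious_jpegs(image_dir):
    """List .jpg files in image_dir that are empty or lack the JPEG SOI marker."""
    flagged = []
    for filename in sorted(os.listdir(image_dir)):
        if not filename.endswith(".jpg"):
            continue
        path = os.path.join(image_dir, filename)
        with open(path, "rb") as f:
            header = f.read(2)  # empty files yield b""
        # Every JPEG stream starts with the start-of-image marker 0xFF 0xD8
        if header != b"\xff\xd8":
            flagged.append(filename)
    return flagged
```

For example, ``suspicious_jpegs('dataset_big3')`` returns the names worth inspecting before training.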