<a href="https://colab.research.google.com/github/devashishbotre/Autonomous-Lane-Detector/blob/main/scripts/Lane_Detection_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Run this cell first: it fetches the Kaggle data sources this notebook uses.
# It can be deleted afterwards. This notebook environment is not Kaggle's
# Python image, so some libraries the notebook relies on may be missing.
# NOTE(review): the training code below reads from
# /kaggle/input/tusimple/TUSimple/train_set rather than from the path returned
# here; confirm the two locations agree when running outside Kaggle.
import kagglehub

manideep1108_tusimple_path = kagglehub.dataset_download(
    'manideep1108/tusimple',
)

print('Data source import complete.')


In [None]:
import os
import json
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from scipy.interpolate import CubicSpline
from sklearn.cluster import DBSCAN
import torchvision.models as models
from model_module import RESA, Decoder, LaneNet

# Constants
# Every input image, and its lane mask, is resized to this resolution (pixels).
IMG_HEIGHT, IMG_WIDTH = 720, 1280
# Segmentation classes: mask value 0 is background, 1 is a lane pixel.
NUM_CLASSES = 2

In [None]:
class TuSimpleDataset(Dataset):
    """TuSimple lane dataset yielding ``(image, lane_mask, annotation)`` triples.

    Each JSON file is read line by line. Every non-blank line is one
    annotation dict containing at least:

    * ``raw_file``: the image path relative to ``img_dir``;
    * ``lanes``: one list of x coordinates per lane, where ``-2`` entries
      are skipped;
    * ``h_samples``: the y coordinate paired with the i-th entry of every lane.

    Annotations whose image file does not exist are dropped at construction
    time, with a warning.

    Args:
        json_files: Paths of the label files. Missing files are skipped with
            a warning.
        img_dir: Directory that ``raw_file`` paths are resolved against.
        transform: Optional joint augmentation, called as
            ``image, mask = transform(image, mask)``. It receives the resized
            RGB image (HxWx3 uint8 numpy array) and its lane mask (HxW uint8
            numpy array). It must return arrays with the same layout.
    """

    def __init__(self, json_files, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.annotations = []
        total_annotations = 0
        for json_file in json_files:
            if not os.path.exists(json_file):
                print(f"Warning: JSON file not found: {json_file}")
                continue
            with open(json_file, 'r', encoding='utf-8') as f:
                for line in f:
                    # A blank line (e.g. a stray empty line) is not an
                    # annotation, and json.loads would raise on it.
                    if not line.strip():
                        continue
                    total_annotations += 1
                    ann = json.loads(line)
                    img_path = os.path.join(self.img_dir, ann['raw_file'])
                    if os.path.exists(img_path):
                        self.annotations.append(ann)
                    else:
                        print(f"Warning: Image not found: {img_path}")
        print(f"Total annotations in JSON: {total_annotations}")
        print(f"Valid images found: {len(self.annotations)}")

    def __len__(self):
        """Return the number of annotations whose image file was found."""
        return len(self.annotations)

    def _rasterize_lanes(self, ann, scale_x, scale_y):
        """Draw each lane of ``ann`` as a 1-pixel polyline of value 1 on a blank mask.

        Label coordinates are multiplied by ``scale_x`` / ``scale_y`` so they
        land in the resized image frame. Points outside the mask are dropped.
        A lane is drawn only if at least two of its points remain.
        """
        mask = np.zeros((IMG_HEIGHT, IMG_WIDTH), dtype=np.uint8)
        for lane in ann['lanes']:
            points = []
            for x, y in zip(lane, ann['h_samples']):
                if x == -2:  # no lane point on this row
                    continue
                px = int(round(x * scale_x))
                py = int(round(y * scale_y))
                if 0 <= py < IMG_HEIGHT and 0 <= px < IMG_WIDTH:
                    points.append((px, py))
            if len(points) > 1:
                cv2.polylines(mask, [np.array(points, dtype=np.int32)], False, 1, 1)
        return mask

    def __getitem__(self, idx):
        """Return ``(image, mask, ann)`` for sample ``idx``.

        ``image`` is a float tensor of shape (3, IMG_HEIGHT, IMG_WIDTH) in RGB
        order, scaled to [0, 1]. ``mask`` is a long tensor of shape
        (IMG_HEIGHT, IMG_WIDTH) that is 1 on lane polylines and 0 elsewhere.
        ``ann`` is the raw annotation dict. An unreadable image is replaced by
        a black frame after printing an error.
        """
        ann = self.annotations[idx]
        img_path = os.path.join(self.img_dir, ann['raw_file'])

        image = cv2.imread(img_path)
        if image is None:
            print(f"Error: Failed to load image: {img_path}")
            image = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3), dtype=np.uint8)
        # Assumes the label coordinates are in the source image's pixel frame.
        # Remember that frame so the lanes can follow the resize.
        src_h, src_w = image.shape[:2]
        image = cv2.resize(image, (IMG_WIDTH, IMG_HEIGHT))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = self._rasterize_lanes(ann, IMG_WIDTH / src_w, IMG_HEIGHT / src_h)

        if self.transform is not None:
            image, mask = self.transform(image, mask)
            # Augmentations such as flips can yield negative strides, which
            # torch.from_numpy rejects.
            image = np.ascontiguousarray(image)
            mask = np.ascontiguousarray(mask)

        image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0
        mask = torch.from_numpy(mask).long()

        return image, mask, ann

def collate_fn(batch):
    """Combine ``(image, mask, annotation)`` samples into one batch.

    Images and masks are stacked along a new leading batch dimension. The
    annotation objects are not collated; they are passed through as a tuple
    in sample order.
    """
    image_seq, mask_seq, ann_seq = zip(*batch)
    return (
        torch.stack(image_seq, dim=0),
        torch.stack(mask_seq, dim=0),
        ann_seq,
    )

In [None]:
def _total_variation_loss(x):
    """Total-variation term of a 4-D (N, C, H, W) tensor.

    Sums the absolute differences between vertically adjacent values and
    between horizontally adjacent values. Each sum is divided by N*H*W; the
    channel count is not divided out. Returns the mean of the two directions.
    """
    norm = x.size(0) * x.size(2) * x.size(3)
    dh = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]).sum() / norm
    dw = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]).sum() / norm
    return (dh + dw) / 2


def _narrowness_penalty(lane_prob):
    """Penalise spread-out lane probability mass.

    ``lane_prob`` is an (N, H, W) lane-class probability map. Returns the
    sum of two means: the squared 3x3 box-blurred map, and the squared
    difference between the map and its blur.
    """
    kernel = torch.ones(1, 1, 3, 3, device=lane_prob.device, dtype=lane_prob.dtype) / 9
    blurred = F.conv2d(lane_prob.unsqueeze(1), kernel, padding=1).squeeze(1)

    blurred = torch.clamp(blurred, min=1e-6, max=1.0 - 1e-6)
    lane_prob = torch.clamp(lane_prob, min=1e-6, max=1.0 - 1e-6)

    width_penalty = blurred.pow(2).mean()
    variance = (lane_prob - blurred).pow(2).mean()
    return width_penalty + variance


def _edge_enhancement_loss(lane_prob):
    """Return the negated mean Sobel gradient magnitude of an (N, H, W) map.

    The mean is capped at 10 before negation, so minimising this term
    rewards sharper lane boundaries only up to that cap.
    """
    # new_tensor copies the dtype and device of lane_prob.
    sobel_x = lane_prob.new_tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).view(1, 1, 3, 3)
    sobel_y = lane_prob.new_tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).view(1, 1, 3, 3)

    prob = lane_prob.unsqueeze(1)
    edge_x = F.conv2d(prob, sobel_x, padding=1)
    edge_y = F.conv2d(prob, sobel_y, padding=1)

    # The epsilon keeps the sqrt gradient finite where both responses are zero.
    edges = torch.sqrt(edge_x.pow(2) + edge_y.pow(2) + 1e-8)
    return -torch.clamp(edges.mean(), max=10.0)


def train_model(data_dir='/kaggle/input/tusimple/TUSimple/train_set',
                batch_size=4, num_epochs=20, learning_rate=0.0002):
    """Train LaneNet on the TuSimple training labels and save checkpoints.

    The objective is class-weighted cross-entropy plus three regularisers on
    the softmax output: total variation (x0.3), narrowness (x0.5) and edge
    enhancement (x0.2). Gradients are clipped to norm 1.0.

    Checkpoints are written to the working directory:
    ``lane_model_epoch_{k}.pth`` every 5 epochs and ``lane_model_final.pth``
    at the end. If the loss becomes NaN or infinite, that batch's update is
    skipped and training stops early.

    Args:
        data_dir: Directory holding ``label_data_0313.json``,
            ``label_data_0531.json`` and ``label_data_0601.json``, plus the
            images they reference. Defaults to the Kaggle input mount.
        batch_size: Images per optimisation step.
        num_epochs: Maximum number of passes over the data.
        learning_rate: Adam learning rate.

    Returns:
        ``(model, device)``: the trained model and the device it lives on.

    Raises:
        ValueError: if no annotated images are found under ``data_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LaneNet(num_classes=NUM_CLASSES).to(device)

    # Up-weight the lane class (index 1), presumably to offset how few
    # pixels lie on lane polylines.
    weights = torch.tensor([1.0, 25.0], device=device)
    criterion = nn.CrossEntropyLoss(weight=weights)

    tv_weight = 0.3
    narrowness_weight = 0.5
    edge_weight = 0.2

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    json_files = [os.path.join(data_dir, name) for name in ('label_data_0313.json',
                                                            'label_data_0531.json',
                                                            'label_data_0601.json')]
    train_dataset = TuSimpleDataset(json_files, data_dir)
    if len(train_dataset) == 0:
        raise ValueError(f"No annotated images found under {data_dir}")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, collate_fn=collate_fn)

    diverged = False
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, masks, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            seg_loss = criterion(outputs, masks)

            # One softmax shared by every regulariser; channel 1 is the lane class.
            probs = torch.softmax(outputs, dim=1)
            lane_prob = probs[:, 1, :, :]
            tv_loss = _total_variation_loss(probs)
            narrowness_loss = _narrowness_penalty(lane_prob)
            edge_loss = _edge_enhancement_loss(lane_prob)

            total_loss = (seg_loss
                          + tv_weight * tv_loss
                          + narrowness_weight * narrowness_loss
                          + edge_weight * edge_loss)

            # Bail out before backward/step: a NaN or inf loss would produce
            # non-finite gradients, which clip_grad_norm_ does not repair.
            if not torch.isfinite(total_loss):
                print(f"Non-finite loss detected: seg_loss={seg_loss.item()}, tv_loss={tv_loss.item()}, "
                      f"narrowness_loss={narrowness_loss.item()}, edge_loss={edge_loss.item()}")
                diverged = True
                break

            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += total_loss.item()

        if diverged:
            break

        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"lane_model_epoch_{epoch+1}.pth")

    torch.save(model.state_dict(), "lane_model_final.pth")
    print("Training stopped early: non-finite loss." if diverged else "Training completed!")
    return model, device

In [None]:
# Train with the default settings; later cells use the model and its device.
model, device = train_model()

Total annotations in JSON: 3626
Valid images found: 3626


Epoch 1/20: 100%|██████████| 907/907 [06:13<00:00,  2.43it/s]


Epoch [1/20], Loss: 0.0869


Epoch 2/20: 100%|██████████| 907/907 [06:12<00:00,  2.43it/s]


Epoch [2/20], Loss: 0.0689


Epoch 3/20: 100%|██████████| 907/907 [06:12<00:00,  2.43it/s]


Epoch [3/20], Loss: 0.0657


Epoch 4/20: 100%|██████████| 907/907 [06:13<00:00,  2.43it/s]


Epoch [4/20], Loss: 0.0633


Epoch 5/20: 100%|██████████| 907/907 [06:14<00:00,  2.42it/s]


Epoch [5/20], Loss: 0.0605


Epoch 6/20: 100%|██████████| 907/907 [06:14<00:00,  2.42it/s]


Epoch [6/20], Loss: 0.0580


Epoch 7/20: 100%|██████████| 907/907 [06:13<00:00,  2.43it/s]


Epoch [7/20], Loss: 0.0550


Epoch 8/20: 100%|██████████| 907/907 [06:13<00:00,  2.43it/s]


Epoch [8/20], Loss: 0.0518


Epoch 9/20: 100%|██████████| 907/907 [06:11<00:00,  2.44it/s]


Epoch [9/20], Loss: 0.0486


Epoch 10/20:  77%|███████▋  | 699/907 [04:45<01:24,  2.46it/s]