In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.optim as optim
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt

class KITTIDataset(Dataset):
    """KITTI road/lane dataset yielding ``(image, mask)`` or ``(image, file_name)``.

    With ``masks_dir`` set, each image named ``<prefix>_<rest>`` (for example
    ``um_000000.png``) is paired with the mask ``<prefix>_lane_<rest>``
    (``um_lane_000000.png``). Images without such a mask file, or whose name
    contains no underscore, are skipped. Without ``masks_dir`` every entry of
    ``images_dir`` is used, and ``__getitem__`` returns the file name in place
    of a mask.

    File names are sorted so that indices are stable across runs and
    filesystems (``os.listdir`` order is arbitrary).

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing the lane masks, or None for unlabeled data.
        transform: optional callable applied to the RGB PIL image.
        mask_transform: optional callable applied to the grayscale PIL mask.
    """

    def __init__(self, images_dir, masks_dir=None, transform=None, mask_transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.mask_transform = mask_transform
        self.image_names = []
        self.mask_names = []

        all_images = sorted(os.listdir(images_dir))

        if masks_dir:
            for img_name in all_images:
                # Split at the FIRST underscore only; the old two-way unpack
                # raised ValueError for names with zero or several underscores.
                prefix, sep, rest = img_name.partition("_")
                if not sep:
                    continue  # no "<prefix>_" part, so no mask name can be derived
                mask_name = f"{prefix}_lane_{rest}"
                if os.path.exists(os.path.join(masks_dir, mask_name)):
                    self.image_names.append(img_name)
                    self.mask_names.append(mask_name)
        else:
            self.image_names = all_images

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        img_path = os.path.join(self.images_dir, img_name)
        # `with` closes the file handle; convert() returns a separate, fully
        # loaded image, so it remains usable afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if self.masks_dir:
            mask_name = self.mask_names[idx]
            mask_path = os.path.join(self.masks_dir, mask_name)
            # Single-channel luminance; binarisation is left to the caller.
            with Image.open(mask_path) as mask_img:
                mask = mask_img.convert("L")
            if self.mask_transform:
                mask = self.mask_transform(mask)
            return image, mask
        return image, img_name


# Images: resize to the network's working resolution (H=256, W=512), convert
# to a float tensor in [0, 1], then normalise with ImageNet channel statistics.
image_transforms = transforms.Compose([
    transforms.Resize((256, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Masks: resize with nearest-neighbour so label boundaries are not blended into
# intermediate grey levels (the default bilinear filter would do that), then
# convert to a [0, 1] float tensor. The training loop binarises at 100/255.
mask_transforms = transforms.Compose([
    transforms.Resize((256, 512), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor()
])


train_dataset = KITTIDataset(
    images_dir="dataset/data_road/training/image_2",
    masks_dir="dataset/data_road/training/gt_image_2",
    transform=image_transforms,
    mask_transform=mask_transforms
)

# The testing split ships without masks, so items are (image, file_name).
test_dataset = KITTIDataset(
    images_dir="dataset/data_road/testing/image_2",
    masks_dir=None,
    transform=image_transforms
)


train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)


In [None]:
class UNet(nn.Module):
    """Four-level U-Net mapping a 3-channel image to ``n_classes`` logit maps.

    Every encoder/decoder stage is a single 3x3 Conv -> BatchNorm -> ReLU block.
    The encoder downsamples four times with 2x2 max-pooling (to H/16 at the
    1024-channel bottleneck). The decoder upsamples with stride-2 transposed
    convolutions and concatenates the encoder feature map of the same depth
    (skip connection). The output is raw logits; no activation is applied.

    When H or W is not divisible by 16, pooling floors odd sizes, so an
    upsampled map can be one pixel smaller than its skip tensor. ``_align``
    resizes it to match. For divisible sizes (e.g. 256x512) it is a no-op.
    Sub-module names are unchanged, so existing checkpoints still load.

    Args:
        n_classes: number of output channels.
    """

    def __init__(self, n_classes=1):
        super(UNet, self).__init__()

        def CBR(in_channels, out_channels):
            # Conv -> BatchNorm -> ReLU; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        self.enc1 = CBR(3, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)

        self.pool = nn.MaxPool2d(2, 2)

        self.center = CBR(512, 1024)

        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = CBR(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = CBR(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = CBR(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = CBR(128, 64)

        self.final = nn.Conv2d(64, n_classes, 1)

    @staticmethod
    def _align(upsampled, skip):
        """Resize ``upsampled`` to ``skip``'s spatial size when they differ (no-op otherwise)."""
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(upsampled, size=skip.shape[-2:],
                                      mode="bilinear", align_corners=False)
        return upsampled

    def forward(self, x):
        enc1 = self.enc1(x)                    # [B, 64, H, W]
        enc2 = self.enc2(self.pool(enc1))      # [B, 128, H/2, W/2]
        enc3 = self.enc3(self.pool(enc2))      # [B, 256, H/4, W/4]
        enc4 = self.enc4(self.pool(enc3))      # [B, 512, H/8, W/8]

        center = self.center(self.pool(enc4))  # [B, 1024, H/16, W/16]

        dec4 = self._align(self.up4(center), enc4)          # [B, 512, H/8, W/8]
        dec4 = self.dec4(torch.cat([dec4, enc4], dim=1))    # [B, 512, H/8, W/8]

        dec3 = self._align(self.up3(dec4), enc3)            # [B, 256, H/4, W/4]
        dec3 = self.dec3(torch.cat([dec3, enc3], dim=1))    # [B, 256, H/4, W/4]

        dec2 = self._align(self.up2(dec3), enc2)            # [B, 128, H/2, W/2]
        dec2 = self.dec2(torch.cat([dec2, enc2], dim=1))    # [B, 128, H/2, W/2]

        dec1 = self._align(self.up1(dec2), enc1)            # [B, 64, H, W]
        dec1 = self.dec1(torch.cat([dec1, enc1], dim=1))    # [B, 64, H, W]

        return self.final(dec1)                             # [B, n_classes, H, W]


# Single-channel output: one road/lane logit per pixel.
model = UNet(n_classes=1)


In [None]:
# Use the CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Adam over every U-Net parameter, fixed learning rate, 200 passes over the data.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
num_epochs = 200

class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2*sum(p*t) + s) / (sum(p) + sum(t) + s)``.

    All elements of the batch are pooled into a single Dice score.

    Args:
        smooth: additive smoothing ``s``; keeps the ratio defined when both
            inputs are all zeros.
        from_logits: if True, apply a sigmoid to ``preds`` first. Use this when
            feeding raw network outputs. Defaults to False, which keeps the
            original behaviour of treating ``preds`` as probabilities.
    """

    def __init__(self, smooth=1., from_logits=False):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.from_logits = from_logits

    def forward(self, preds, targets):
        if self.from_logits:
            preds = torch.sigmoid(preds)
        # reshape (not view) also accepts non-contiguous tensors, e.g. after
        # a transpose/permute or squeeze of a sliced tensor.
        preds = preds.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (preds * targets).sum()
        dice = (2. * intersection + self.smooth) / (preds.sum() + targets.sum() + self.smooth)
        return 1 - dice


# Training optimises BCE on raw logits; DiceLoss above is not used by the
# training loop in this notebook.
criterion = nn.BCEWithLogitsLoss()



In [None]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0
    progress_label = f"Epoch {epoch+1}/{num_epochs}"
    with tqdm(total=len(train_loader), desc=progress_label) as pbar:
        for batch_images, batch_masks in train_loader:
            batch_images = batch_images.to(device)
            # Binarise the [0, 1] grayscale masks: values above 100/255 are foreground.
            targets = (batch_masks.to(device).float() > 100/255).float()

            optimizer.zero_grad()
            # criterion is BCEWithLogitsLoss, so the raw logits go in directly;
            # drop the channel axis on both sides -> [B, H, W].
            logits = model(batch_images).squeeze(1)
            loss = criterion(logits, targets.squeeze(1))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # pbar.n is the number of batches finished before this one.
            pbar.set_postfix({'loss': running_loss / (pbar.n + 1)})
            pbar.update(1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")

    # Checkpoint every 10 epochs and always after the final epoch.
    is_last_epoch = epoch == num_epochs - 1
    if is_last_epoch or (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), f"unet_epoch_{epoch+1}.pth")


Epoch 1/200: 100%|██████████| 6/6 [00:21<00:00,  3.50s/it, loss=0.656]


Epoch 1/200, Loss: 0.6565


Epoch 2/200: 100%|██████████| 6/6 [00:17<00:00,  2.92s/it, loss=0.591]


Epoch 2/200, Loss: 0.5912


Epoch 3/200: 100%|██████████| 6/6 [00:17<00:00,  2.92s/it, loss=0.521]


Epoch 3/200, Loss: 0.5208


Epoch 4/200: 100%|██████████| 6/6 [00:17<00:00,  2.94s/it, loss=0.464]


Epoch 4/200, Loss: 0.4640


Epoch 5/200: 100%|██████████| 6/6 [00:17<00:00,  2.92s/it, loss=0.407]


Epoch 5/200, Loss: 0.4074


Epoch 6/200: 100%|██████████| 6/6 [00:18<00:00,  3.06s/it, loss=0.368]


Epoch 6/200, Loss: 0.3682


Epoch 7/200: 100%|██████████| 6/6 [00:18<00:00,  3.14s/it, loss=0.336]


Epoch 7/200, Loss: 0.3362


Epoch 8/200: 100%|██████████| 6/6 [00:18<00:00,  3.15s/it, loss=0.318]


Epoch 8/200, Loss: 0.3185


Epoch 9/200: 100%|██████████| 6/6 [00:19<00:00,  3.28s/it, loss=0.31] 


Epoch 9/200, Loss: 0.3096


Epoch 10/200: 100%|██████████| 6/6 [00:19<00:00,  3.22s/it, loss=0.302]


Epoch 10/200, Loss: 0.3023


Epoch 11/200: 100%|██████████| 6/6 [00:19<00:00,  3.29s/it, loss=0.297]


Epoch 11/200, Loss: 0.2966


Epoch 12/200: 100%|██████████| 6/6 [00:19<00:00,  3.33s/it, loss=0.29] 


Epoch 12/200, Loss: 0.2898


Epoch 13/200: 100%|██████████| 6/6 [00:19<00:00,  3.31s/it, loss=0.285]


Epoch 13/200, Loss: 0.2852


Epoch 14/200: 100%|██████████| 6/6 [00:19<00:00,  3.32s/it, loss=0.281]


Epoch 14/200, Loss: 0.2810


Epoch 15/200: 100%|██████████| 6/6 [00:19<00:00,  3.24s/it, loss=0.279]


Epoch 15/200, Loss: 0.2787


Epoch 16/200: 100%|██████████| 6/6 [00:19<00:00,  3.27s/it, loss=0.273]


Epoch 16/200, Loss: 0.2735


Epoch 17/200: 100%|██████████| 6/6 [00:19<00:00,  3.26s/it, loss=0.267]


Epoch 17/200, Loss: 0.2674


Epoch 18/200: 100%|██████████| 6/6 [00:19<00:00,  3.26s/it, loss=0.264]


Epoch 18/200, Loss: 0.2638


Epoch 19/200: 100%|██████████| 6/6 [00:19<00:00,  3.29s/it, loss=0.261]


Epoch 19/200, Loss: 0.2605


Epoch 20/200: 100%|██████████| 6/6 [00:19<00:00,  3.31s/it, loss=0.256]


Epoch 20/200, Loss: 0.2560


Epoch 21/200: 100%|██████████| 6/6 [00:19<00:00,  3.27s/it, loss=0.253]


Epoch 21/200, Loss: 0.2531


Epoch 22/200: 100%|██████████| 6/6 [00:20<00:00,  3.36s/it, loss=0.249]


Epoch 22/200, Loss: 0.2485


Epoch 23/200: 100%|██████████| 6/6 [00:19<00:00,  3.29s/it, loss=0.244]


Epoch 23/200, Loss: 0.2437


Epoch 24/200: 100%|██████████| 6/6 [00:19<00:00,  3.33s/it, loss=0.241]


Epoch 24/200, Loss: 0.2409


Epoch 25/200: 100%|██████████| 6/6 [00:20<00:00,  3.41s/it, loss=0.237]


Epoch 25/200, Loss: 0.2372


Epoch 26/200: 100%|██████████| 6/6 [00:20<00:00,  3.40s/it, loss=0.234]


Epoch 26/200, Loss: 0.2340


Epoch 27/200: 100%|██████████| 6/6 [00:20<00:00,  3.40s/it, loss=0.231]


Epoch 27/200, Loss: 0.2311


Epoch 28/200: 100%|██████████| 6/6 [00:20<00:00,  3.44s/it, loss=0.227]


Epoch 28/200, Loss: 0.2275


Epoch 29/200: 100%|██████████| 6/6 [00:20<00:00,  3.43s/it, loss=0.225]


Epoch 29/200, Loss: 0.2247


Epoch 30/200: 100%|██████████| 6/6 [00:20<00:00,  3.44s/it, loss=0.221]


Epoch 30/200, Loss: 0.2211


Epoch 31/200: 100%|██████████| 6/6 [00:21<00:00,  3.50s/it, loss=0.219]


Epoch 31/200, Loss: 0.2186


Epoch 32/200: 100%|██████████| 6/6 [00:20<00:00,  3.47s/it, loss=0.216]


Epoch 32/200, Loss: 0.2156


Epoch 33/200: 100%|██████████| 6/6 [00:20<00:00,  3.48s/it, loss=0.213]


Epoch 33/200, Loss: 0.2127


Epoch 34/200: 100%|██████████| 6/6 [00:21<00:00,  3.52s/it, loss=0.21] 


Epoch 34/200, Loss: 0.2100


Epoch 35/200: 100%|██████████| 6/6 [00:20<00:00,  3.39s/it, loss=0.209]


Epoch 35/200, Loss: 0.2088


Epoch 36/200: 100%|██████████| 6/6 [00:25<00:00,  4.20s/it, loss=0.207]


Epoch 36/200, Loss: 0.2072


Epoch 37/200: 100%|██████████| 6/6 [00:26<00:00,  4.46s/it, loss=0.205]


Epoch 37/200, Loss: 0.2055


Epoch 38/200: 100%|██████████| 6/6 [00:23<00:00,  3.85s/it, loss=0.204]


Epoch 38/200, Loss: 0.2043


Epoch 39/200: 100%|██████████| 6/6 [00:26<00:00,  4.38s/it, loss=0.2]  


Epoch 39/200, Loss: 0.2000


Epoch 40/200: 100%|██████████| 6/6 [00:23<00:00,  3.84s/it, loss=0.196]


Epoch 40/200, Loss: 0.1964


Epoch 41/200: 100%|██████████| 6/6 [00:25<00:00,  4.20s/it, loss=0.193]


Epoch 41/200, Loss: 0.1933


Epoch 42/200: 100%|██████████| 6/6 [00:24<00:00,  4.13s/it, loss=0.19] 


Epoch 42/200, Loss: 0.1902


Epoch 43/200: 100%|██████████| 6/6 [00:25<00:00,  4.18s/it, loss=0.188]


Epoch 43/200, Loss: 0.1876


Epoch 44/200: 100%|██████████| 6/6 [00:24<00:00,  4.13s/it, loss=0.185]


Epoch 44/200, Loss: 0.1853


Epoch 45/200: 100%|██████████| 6/6 [00:24<00:00,  4.10s/it, loss=0.183]


Epoch 45/200, Loss: 0.1828


Epoch 46/200: 100%|██████████| 6/6 [00:24<00:00,  4.04s/it, loss=0.181]


Epoch 46/200, Loss: 0.1805


Epoch 47/200: 100%|██████████| 6/6 [00:26<00:00,  4.41s/it, loss=0.178]


Epoch 47/200, Loss: 0.1782


Epoch 48/200: 100%|██████████| 6/6 [00:25<00:00,  4.19s/it, loss=0.176]


Epoch 48/200, Loss: 0.1762


Epoch 49/200: 100%|██████████| 6/6 [00:24<00:00,  4.16s/it, loss=0.174]


Epoch 49/200, Loss: 0.1742


Epoch 50/200: 100%|██████████| 6/6 [00:27<00:00,  4.53s/it, loss=0.172]


Epoch 50/200, Loss: 0.1719


Epoch 51/200: 100%|██████████| 6/6 [00:28<00:00,  4.78s/it, loss=0.17] 


Epoch 51/200, Loss: 0.1699


Epoch 52/200: 100%|██████████| 6/6 [00:26<00:00,  4.49s/it, loss=0.168]


Epoch 52/200, Loss: 0.1681


Epoch 53/200: 100%|██████████| 6/6 [00:25<00:00,  4.18s/it, loss=0.166]


Epoch 53/200, Loss: 0.1662


Epoch 54/200: 100%|██████████| 6/6 [00:30<00:00,  5.03s/it, loss=0.164]


Epoch 54/200, Loss: 0.1644


Epoch 55/200: 100%|██████████| 6/6 [00:31<00:00,  5.33s/it, loss=0.162]


Epoch 55/200, Loss: 0.1623


Epoch 56/200: 100%|██████████| 6/6 [00:29<00:00,  4.97s/it, loss=0.161]


Epoch 56/200, Loss: 0.1607


Epoch 57/200: 100%|██████████| 6/6 [00:26<00:00,  4.45s/it, loss=0.159]


Epoch 57/200, Loss: 0.1586


Epoch 58/200: 100%|██████████| 6/6 [00:26<00:00,  4.34s/it, loss=0.157]


Epoch 58/200, Loss: 0.1572


Epoch 59/200: 100%|██████████| 6/6 [00:26<00:00,  4.36s/it, loss=0.155]


Epoch 59/200, Loss: 0.1553


Epoch 60/200: 100%|██████████| 6/6 [00:27<00:00,  4.52s/it, loss=0.154]


Epoch 60/200, Loss: 0.1535


Epoch 61/200: 100%|██████████| 6/6 [00:28<00:00,  4.82s/it, loss=0.152]


Epoch 61/200, Loss: 0.1519


Epoch 62/200: 100%|██████████| 6/6 [00:28<00:00,  4.79s/it, loss=0.15] 


Epoch 62/200, Loss: 0.1500


Epoch 63/200: 100%|██████████| 6/6 [00:33<00:00,  5.63s/it, loss=0.148]


Epoch 63/200, Loss: 0.1485


Epoch 64/200: 100%|██████████| 6/6 [00:32<00:00,  5.39s/it, loss=0.147]


Epoch 64/200, Loss: 0.1467


Epoch 65/200: 100%|██████████| 6/6 [00:28<00:00,  4.73s/it, loss=0.145]


Epoch 65/200, Loss: 0.1450


Epoch 66/200: 100%|██████████| 6/6 [00:31<00:00,  5.18s/it, loss=0.144]


Epoch 66/200, Loss: 0.1436


Epoch 67/200: 100%|██████████| 6/6 [00:34<00:00,  5.69s/it, loss=0.142]


Epoch 67/200, Loss: 0.1423


Epoch 68/200: 100%|██████████| 6/6 [00:31<00:00,  5.29s/it, loss=0.141]


Epoch 68/200, Loss: 0.1407


Epoch 69/200: 100%|██████████| 6/6 [00:29<00:00,  5.00s/it, loss=0.139]


Epoch 69/200, Loss: 0.1392


Epoch 70/200: 100%|██████████| 6/6 [00:30<00:00,  5.13s/it, loss=0.138]


Epoch 70/200, Loss: 0.1378


Epoch 71/200: 100%|██████████| 6/6 [00:28<00:00,  4.71s/it, loss=0.136]


Epoch 71/200, Loss: 0.1362


Epoch 72/200: 100%|██████████| 6/6 [00:31<00:00,  5.18s/it, loss=0.135]


Epoch 72/200, Loss: 0.1347


Epoch 73/200: 100%|██████████| 6/6 [00:32<00:00,  5.34s/it, loss=0.134]


Epoch 73/200, Loss: 0.1335


Epoch 74/200: 100%|██████████| 6/6 [00:28<00:00,  4.75s/it, loss=0.132]


Epoch 74/200, Loss: 0.1322


Epoch 75/200: 100%|██████████| 6/6 [00:30<00:00,  5.15s/it, loss=0.131]


Epoch 75/200, Loss: 0.1306


Epoch 76/200: 100%|██████████| 6/6 [00:34<00:00,  5.70s/it, loss=0.129]


Epoch 76/200, Loss: 0.1294


Epoch 77/200: 100%|██████████| 6/6 [00:31<00:00,  5.26s/it, loss=0.128]


Epoch 77/200, Loss: 0.1284


Epoch 78/200: 100%|██████████| 6/6 [00:28<00:00,  4.80s/it, loss=0.127]


Epoch 78/200, Loss: 0.1270


Epoch 79/200: 100%|██████████| 6/6 [00:29<00:00,  4.84s/it, loss=0.126]


Epoch 79/200, Loss: 0.1257


Epoch 80/200: 100%|██████████| 6/6 [00:28<00:00,  4.82s/it, loss=0.124]


Epoch 80/200, Loss: 0.1240


Epoch 81/200: 100%|██████████| 6/6 [00:29<00:00,  4.89s/it, loss=0.122]


Epoch 81/200, Loss: 0.1225


Epoch 82/200: 100%|██████████| 6/6 [00:27<00:00,  4.55s/it, loss=0.121]


Epoch 82/200, Loss: 0.1213


Epoch 83/200: 100%|██████████| 6/6 [00:28<00:00,  4.73s/it, loss=0.12] 


Epoch 83/200, Loss: 0.1198


Epoch 84/200: 100%|██████████| 6/6 [00:28<00:00,  4.81s/it, loss=0.118]


Epoch 84/200, Loss: 0.1185


Epoch 85/200: 100%|██████████| 6/6 [00:28<00:00,  4.83s/it, loss=0.117]


Epoch 85/200, Loss: 0.1172


Epoch 86/200: 100%|██████████| 6/6 [00:28<00:00,  4.73s/it, loss=0.116]


Epoch 86/200, Loss: 0.1161


Epoch 87/200: 100%|██████████| 6/6 [00:28<00:00,  4.69s/it, loss=0.115]


Epoch 87/200, Loss: 0.1153


Epoch 88/200: 100%|██████████| 6/6 [00:29<00:00,  5.00s/it, loss=0.114]


Epoch 88/200, Loss: 0.1140


Epoch 89/200: 100%|██████████| 6/6 [00:29<00:00,  4.93s/it, loss=0.113]


Epoch 89/200, Loss: 0.1129


Epoch 90/200: 100%|██████████| 6/6 [00:30<00:00,  5.03s/it, loss=0.112]


Epoch 90/200, Loss: 0.1116


Epoch 91/200: 100%|██████████| 6/6 [00:31<00:00,  5.28s/it, loss=0.11] 


Epoch 91/200, Loss: 0.1104


Epoch 92/200: 100%|██████████| 6/6 [00:28<00:00,  4.79s/it, loss=0.109]


Epoch 92/200, Loss: 0.1093


Epoch 93/200: 100%|██████████| 6/6 [00:27<00:00,  4.62s/it, loss=0.108]


Epoch 93/200, Loss: 0.1083


Epoch 94/200: 100%|██████████| 6/6 [00:28<00:00,  4.68s/it, loss=0.107]


Epoch 94/200, Loss: 0.1072


Epoch 95/200: 100%|██████████| 6/6 [00:33<00:00,  5.50s/it, loss=0.106]


Epoch 95/200, Loss: 0.1062


Epoch 96/200: 100%|██████████| 6/6 [00:37<00:00,  6.28s/it, loss=0.105]


Epoch 96/200, Loss: 0.1050


Epoch 97/200: 100%|██████████| 6/6 [00:31<00:00,  5.23s/it, loss=0.104]


Epoch 97/200, Loss: 0.1038


Epoch 98/200: 100%|██████████| 6/6 [00:33<00:00,  5.52s/it, loss=0.103]


Epoch 98/200, Loss: 0.1029


Epoch 99/200: 100%|██████████| 6/6 [00:32<00:00,  5.40s/it, loss=0.102]


Epoch 99/200, Loss: 0.1022


Epoch 100/200: 100%|██████████| 6/6 [00:29<00:00,  4.93s/it, loss=0.101]


Epoch 100/200, Loss: 0.1011


Epoch 101/200: 100%|██████████| 6/6 [00:29<00:00,  4.98s/it, loss=0.1]  


Epoch 101/200, Loss: 0.1001


Epoch 102/200: 100%|██████████| 6/6 [00:28<00:00,  4.73s/it, loss=0.0992]


Epoch 102/200, Loss: 0.0992


Epoch 103/200: 100%|██████████| 6/6 [00:28<00:00,  4.72s/it, loss=0.0981]


Epoch 103/200, Loss: 0.0981


Epoch 104/200: 100%|██████████| 6/6 [00:27<00:00,  4.52s/it, loss=0.0969]


Epoch 104/200, Loss: 0.0969


Epoch 105/200: 100%|██████████| 6/6 [00:26<00:00,  4.42s/it, loss=0.096] 


Epoch 105/200, Loss: 0.0960


Epoch 106/200: 100%|██████████| 6/6 [00:25<00:00,  4.26s/it, loss=0.0952]


Epoch 106/200, Loss: 0.0952


Epoch 107/200: 100%|██████████| 6/6 [00:34<00:00,  5.69s/it, loss=0.0943]


Epoch 107/200, Loss: 0.0943


Epoch 108/200: 100%|██████████| 6/6 [00:32<00:00,  5.46s/it, loss=0.0934]


Epoch 108/200, Loss: 0.0934


Epoch 109/200: 100%|██████████| 6/6 [00:30<00:00,  5.05s/it, loss=0.0929]


Epoch 109/200, Loss: 0.0929


Epoch 110/200: 100%|██████████| 6/6 [00:34<00:00,  5.70s/it, loss=0.092] 


Epoch 110/200, Loss: 0.0920


Epoch 111/200: 100%|██████████| 6/6 [00:34<00:00,  5.73s/it, loss=0.0911]


Epoch 111/200, Loss: 0.0911


Epoch 112/200: 100%|██████████| 6/6 [00:32<00:00,  5.36s/it, loss=0.0899]


Epoch 112/200, Loss: 0.0899


Epoch 113/200: 100%|██████████| 6/6 [00:32<00:00,  5.48s/it, loss=0.0889]


Epoch 113/200, Loss: 0.0889


Epoch 114/200: 100%|██████████| 6/6 [00:34<00:00,  5.73s/it, loss=0.0881]


Epoch 114/200, Loss: 0.0881


Epoch 115/200: 100%|██████████| 6/6 [00:39<00:00,  6.59s/it, loss=0.0875]


Epoch 115/200, Loss: 0.0875


Epoch 116/200: 100%|██████████| 6/6 [00:30<00:00,  5.02s/it, loss=0.087] 


Epoch 116/200, Loss: 0.0870


Epoch 117/200: 100%|██████████| 6/6 [00:35<00:00,  5.88s/it, loss=0.0861]


Epoch 117/200, Loss: 0.0861


Epoch 118/200: 100%|██████████| 6/6 [00:34<00:00,  5.70s/it, loss=0.0851]


Epoch 118/200, Loss: 0.0851


Epoch 119/200: 100%|██████████| 6/6 [00:31<00:00,  5.27s/it, loss=0.0845]


Epoch 119/200, Loss: 0.0845


Epoch 120/200: 100%|██████████| 6/6 [00:32<00:00,  5.49s/it, loss=0.0835]


Epoch 120/200, Loss: 0.0835


Epoch 121/200: 100%|██████████| 6/6 [00:29<00:00,  4.98s/it, loss=0.0826]


Epoch 121/200, Loss: 0.0826


Epoch 122/200: 100%|██████████| 6/6 [00:28<00:00,  4.80s/it, loss=0.0819]


Epoch 122/200, Loss: 0.0819


Epoch 123/200: 100%|██████████| 6/6 [00:26<00:00,  4.49s/it, loss=0.081] 


Epoch 123/200, Loss: 0.0810


Epoch 124/200: 100%|██████████| 6/6 [00:29<00:00,  4.93s/it, loss=0.0802]


Epoch 124/200, Loss: 0.0802


Epoch 125/200: 100%|██████████| 6/6 [00:26<00:00,  4.34s/it, loss=0.0795]


Epoch 125/200, Loss: 0.0795


Epoch 126/200: 100%|██████████| 6/6 [00:29<00:00,  4.89s/it, loss=0.0788]


Epoch 126/200, Loss: 0.0788


Epoch 127/200: 100%|██████████| 6/6 [00:26<00:00,  4.37s/it, loss=0.0779]


Epoch 127/200, Loss: 0.0779


Epoch 128/200: 100%|██████████| 6/6 [00:25<00:00,  4.25s/it, loss=0.0772]


Epoch 128/200, Loss: 0.0772


Epoch 129/200: 100%|██████████| 6/6 [00:26<00:00,  4.50s/it, loss=0.0765]


Epoch 129/200, Loss: 0.0765


Epoch 130/200: 100%|██████████| 6/6 [00:27<00:00,  4.55s/it, loss=0.0758]


Epoch 130/200, Loss: 0.0758


Epoch 131/200: 100%|██████████| 6/6 [00:27<00:00,  4.54s/it, loss=0.0752]


Epoch 131/200, Loss: 0.0752


Epoch 132/200: 100%|██████████| 6/6 [00:27<00:00,  4.61s/it, loss=0.0747]


Epoch 132/200, Loss: 0.0747


Epoch 133/200: 100%|██████████| 6/6 [00:27<00:00,  4.53s/it, loss=0.074] 


Epoch 133/200, Loss: 0.0740


Epoch 134/200: 100%|██████████| 6/6 [00:27<00:00,  4.54s/it, loss=0.0735]


Epoch 134/200, Loss: 0.0735


Epoch 135/200: 100%|██████████| 6/6 [00:27<00:00,  4.56s/it, loss=0.0727]


Epoch 135/200, Loss: 0.0727


Epoch 136/200: 100%|██████████| 6/6 [00:25<00:00,  4.30s/it, loss=0.0719]


Epoch 136/200, Loss: 0.0719


Epoch 137/200: 100%|██████████| 6/6 [00:25<00:00,  4.26s/it, loss=0.0712]


Epoch 137/200, Loss: 0.0712


Epoch 138/200: 100%|██████████| 6/6 [00:27<00:00,  4.52s/it, loss=0.0707]


Epoch 138/200, Loss: 0.0707


Epoch 139/200: 100%|██████████| 6/6 [00:26<00:00,  4.47s/it, loss=0.07]  


Epoch 139/200, Loss: 0.0700


Epoch 140/200: 100%|██████████| 6/6 [00:25<00:00,  4.22s/it, loss=0.0692]


Epoch 140/200, Loss: 0.0692


Epoch 141/200: 100%|██████████| 6/6 [00:27<00:00,  4.51s/it, loss=0.0685]


Epoch 141/200, Loss: 0.0685


Epoch 142/200: 100%|██████████| 6/6 [00:27<00:00,  4.64s/it, loss=0.068] 


Epoch 142/200, Loss: 0.0680


Epoch 143/200: 100%|██████████| 6/6 [00:25<00:00,  4.27s/it, loss=0.0674]


Epoch 143/200, Loss: 0.0674


Epoch 144/200: 100%|██████████| 6/6 [00:25<00:00,  4.29s/it, loss=0.067] 


Epoch 144/200, Loss: 0.0670


Epoch 145/200: 100%|██████████| 6/6 [00:25<00:00,  4.21s/it, loss=0.0664]


Epoch 145/200, Loss: 0.0664


Epoch 146/200: 100%|██████████| 6/6 [00:26<00:00,  4.39s/it, loss=0.0661]


Epoch 146/200, Loss: 0.0661


Epoch 147/200: 100%|██████████| 6/6 [00:23<00:00,  3.97s/it, loss=0.0653]


Epoch 147/200, Loss: 0.0653


Epoch 148/200: 100%|██████████| 6/6 [00:24<00:00,  4.16s/it, loss=0.0646]


Epoch 148/200, Loss: 0.0646


Epoch 149/200: 100%|██████████| 6/6 [00:24<00:00,  4.11s/it, loss=0.0643]


Epoch 149/200, Loss: 0.0643


Epoch 150/200: 100%|██████████| 6/6 [00:25<00:00,  4.18s/it, loss=0.0635]


Epoch 150/200, Loss: 0.0635


Epoch 151/200: 100%|██████████| 6/6 [00:26<00:00,  4.34s/it, loss=0.063] 


Epoch 151/200, Loss: 0.0630


Epoch 152/200: 100%|██████████| 6/6 [00:24<00:00,  4.13s/it, loss=0.0625]


Epoch 152/200, Loss: 0.0625


Epoch 153/200: 100%|██████████| 6/6 [00:26<00:00,  4.45s/it, loss=0.0618]


Epoch 153/200, Loss: 0.0618


Epoch 154/200: 100%|██████████| 6/6 [00:25<00:00,  4.22s/it, loss=0.0614]


Epoch 154/200, Loss: 0.0614


Epoch 155/200: 100%|██████████| 6/6 [00:26<00:00,  4.36s/it, loss=0.0609]


Epoch 155/200, Loss: 0.0609


Epoch 156/200: 100%|██████████| 6/6 [00:26<00:00,  4.37s/it, loss=0.0602]


Epoch 156/200, Loss: 0.0602


Epoch 157/200: 100%|██████████| 6/6 [00:26<00:00,  4.50s/it, loss=0.0596]


Epoch 157/200, Loss: 0.0596


Epoch 158/200: 100%|██████████| 6/6 [00:25<00:00,  4.19s/it, loss=0.059] 


Epoch 158/200, Loss: 0.0590


Epoch 159/200: 100%|██████████| 6/6 [00:24<00:00,  4.07s/it, loss=0.0585]


Epoch 159/200, Loss: 0.0585


Epoch 160/200: 100%|██████████| 6/6 [00:25<00:00,  4.26s/it, loss=0.0581]


Epoch 160/200, Loss: 0.0581


Epoch 161/200: 100%|██████████| 6/6 [00:24<00:00,  4.14s/it, loss=0.0577]


Epoch 161/200, Loss: 0.0577


Epoch 162/200: 100%|██████████| 6/6 [00:22<00:00,  3.80s/it, loss=0.0572]


Epoch 162/200, Loss: 0.0572


Epoch 163/200: 100%|██████████| 6/6 [00:24<00:00,  4.16s/it, loss=0.0565]


Epoch 163/200, Loss: 0.0565


Epoch 164/200: 100%|██████████| 6/6 [00:24<00:00,  4.12s/it, loss=0.056] 


Epoch 164/200, Loss: 0.0560


Epoch 165/200: 100%|██████████| 6/6 [00:26<00:00,  4.47s/it, loss=0.0556]


Epoch 165/200, Loss: 0.0556


Epoch 166/200: 100%|██████████| 6/6 [00:24<00:00,  4.07s/it, loss=0.0551]


Epoch 166/200, Loss: 0.0551


Epoch 167/200: 100%|██████████| 6/6 [00:25<00:00,  4.26s/it, loss=0.0546]


Epoch 167/200, Loss: 0.0546


Epoch 168/200: 100%|██████████| 6/6 [00:25<00:00,  4.33s/it, loss=0.0541]


Epoch 168/200, Loss: 0.0541


Epoch 169/200: 100%|██████████| 6/6 [00:26<00:00,  4.47s/it, loss=0.0537]


Epoch 169/200, Loss: 0.0537


Epoch 170/200: 100%|██████████| 6/6 [00:26<00:00,  4.44s/it, loss=0.0533]


Epoch 170/200, Loss: 0.0533


Epoch 171/200: 100%|██████████| 6/6 [00:26<00:00,  4.44s/it, loss=0.0529]


Epoch 171/200, Loss: 0.0529


Epoch 172/200: 100%|██████████| 6/6 [00:25<00:00,  4.27s/it, loss=0.0526]


Epoch 172/200, Loss: 0.0526


Epoch 173/200: 100%|██████████| 6/6 [00:25<00:00,  4.26s/it, loss=0.0524]


Epoch 173/200, Loss: 0.0524


Epoch 174/200: 100%|██████████| 6/6 [00:25<00:00,  4.24s/it, loss=0.0521]


Epoch 174/200, Loss: 0.0521


Epoch 175/200: 100%|██████████| 6/6 [00:25<00:00,  4.26s/it, loss=0.0514]


Epoch 175/200, Loss: 0.0514


Epoch 176/200: 100%|██████████| 6/6 [00:25<00:00,  4.23s/it, loss=0.051] 


Epoch 176/200, Loss: 0.0510


Epoch 177/200: 100%|██████████| 6/6 [00:26<00:00,  4.38s/it, loss=0.0507]


Epoch 177/200, Loss: 0.0507


Epoch 178/200: 100%|██████████| 6/6 [00:25<00:00,  4.23s/it, loss=0.0503]


Epoch 178/200, Loss: 0.0503


Epoch 179/200: 100%|██████████| 6/6 [00:26<00:00,  4.39s/it, loss=0.0497]


Epoch 179/200, Loss: 0.0497


Epoch 180/200: 100%|██████████| 6/6 [00:25<00:00,  4.27s/it, loss=0.0493]


Epoch 180/200, Loss: 0.0493


Epoch 181/200: 100%|██████████| 6/6 [00:25<00:00,  4.28s/it, loss=0.0488]


Epoch 181/200, Loss: 0.0488


Epoch 182/200: 100%|██████████| 6/6 [00:26<00:00,  4.49s/it, loss=0.0485]


Epoch 182/200, Loss: 0.0485


Epoch 183/200: 100%|██████████| 6/6 [00:27<00:00,  4.59s/it, loss=0.0481]


Epoch 183/200, Loss: 0.0481


Epoch 184/200: 100%|██████████| 6/6 [00:25<00:00,  4.28s/it, loss=0.0477]


Epoch 184/200, Loss: 0.0477


Epoch 185/200: 100%|██████████| 6/6 [00:26<00:00,  4.35s/it, loss=0.0472]


Epoch 185/200, Loss: 0.0472


Epoch 186/200: 100%|██████████| 6/6 [00:30<00:00,  5.14s/it, loss=0.0468]


Epoch 186/200, Loss: 0.0468


Epoch 187/200: 100%|██████████| 6/6 [00:28<00:00,  4.77s/it, loss=0.0465]


Epoch 187/200, Loss: 0.0465


Epoch 188/200: 100%|██████████| 6/6 [00:26<00:00,  4.47s/it, loss=0.0462]


Epoch 188/200, Loss: 0.0462


Epoch 189/200: 100%|██████████| 6/6 [00:27<00:00,  4.60s/it, loss=0.0458]


Epoch 189/200, Loss: 0.0458


Epoch 190/200: 100%|██████████| 6/6 [00:26<00:00,  4.35s/it, loss=0.0455]


Epoch 190/200, Loss: 0.0455


Epoch 191/200: 100%|██████████| 6/6 [00:27<00:00,  4.59s/it, loss=0.0452]


Epoch 191/200, Loss: 0.0452


Epoch 192/200: 100%|██████████| 6/6 [00:25<00:00,  4.31s/it, loss=0.0447]


Epoch 192/200, Loss: 0.0447


Epoch 193/200: 100%|██████████| 6/6 [00:27<00:00,  4.60s/it, loss=0.0443]


Epoch 193/200, Loss: 0.0443


Epoch 194/200: 100%|██████████| 6/6 [00:25<00:00,  4.17s/it, loss=0.044] 


Epoch 194/200, Loss: 0.0440


Epoch 195/200: 100%|██████████| 6/6 [00:29<00:00,  4.84s/it, loss=0.0437]


Epoch 195/200, Loss: 0.0437


Epoch 196/200: 100%|██████████| 6/6 [00:28<00:00,  4.82s/it, loss=0.0433]


Epoch 196/200, Loss: 0.0433


Epoch 197/200: 100%|██████████| 6/6 [00:29<00:00,  4.99s/it, loss=0.043] 


Epoch 197/200, Loss: 0.0430


Epoch 198/200: 100%|██████████| 6/6 [00:30<00:00,  5.06s/it, loss=0.0429]


Epoch 198/200, Loss: 0.0429


Epoch 199/200: 100%|██████████| 6/6 [00:29<00:00,  4.88s/it, loss=0.0424]


Epoch 199/200, Loss: 0.0424


Epoch 200/200: 100%|██████████| 6/6 [00:28<00:00,  4.76s/it, loss=0.042] 

Epoch 200/200, Loss: 0.0420





In [None]:
def sigmoid(x):
    """Numerically stable element-wise logistic function ``1 / (1 + exp(-x))``.

    The naive form overflows in ``np.exp(-x)`` for large negative ``x`` and
    emits a RuntimeWarning. Here ``exp`` is only evaluated on ``-|x|``, so it
    always lies in (0, 1].

    Args:
        x: scalar or array-like of real numbers.

    Returns:
        Array with the shape of ``x`` (a NumPy scalar for scalar input) and
        values in [0, 1]; float32 input stays float32.
    """
    x = np.asarray(x)
    e = np.exp(-np.abs(x))
    out = np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # Unwrap 0-d arrays so scalar input gives a scalar, as before.
    return out[()] if out.ndim == 0 else out

def preprocess_image(image_path, transform):
    """Load an image file and turn it into a batched model input on ``device``.

    Args:
        image_path: path to an image file readable by PIL.
        transform: callable mapping an RGB PIL image to a CHW tensor
            (e.g. ``image_transforms``).

    Returns:
        ``(tensor, original_size)``: ``tensor`` has a leading batch dimension
        of 1 and is moved to the module-level ``device``; ``original_size`` is
        the PIL ``(width, height)`` of the file.
    """
    # The context manager closes the file handle; convert() returns a separate,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    original_size = image.size  # PIL order: (width, height)
    image = transform(image).unsqueeze(0).to(device)
    return image, original_size

def postprocess_output(output, original_size, threshold=0.5):
    """Convert raw single-image logits into a binary uint8 mask at source resolution.

    Uses ``torch.sigmoid``, which is overflow-free and is the same decision
    rule ``evaluate`` applies, then thresholds and resizes with nearest
    neighbour so the mask stays strictly 0/255.

    Args:
        output: logits tensor for ONE image, e.g. ``(1, 1, H, W)``. All
            singleton dims are squeezed and a 2-D map must remain.
        original_size: target ``(width, height)`` in PIL order.
        threshold: probability cut-off applied after the sigmoid.

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(height, width)`` with values 0 or 255.

    Raises:
        ValueError: if ``output`` does not squeeze to 2-D (e.g. batch size > 1).
    """
    probs = torch.sigmoid(output.detach()).squeeze().cpu().numpy()
    if probs.ndim != 2:
        raise ValueError(f"expected a single-image output, got shape {tuple(output.shape)}")
    mask = (probs > threshold).astype(np.uint8) * 255
    mask = Image.fromarray(mask).resize(original_size, Image.NEAREST)
    return np.array(mask)

def hough_transform(image, threshold=50, min_line_length=50, max_line_gap=150):
    """Detect straight segments with the probabilistic Hough transform and rasterise them.

    Canny edges (hysteresis thresholds 50/150, 3x3 aperture) are fed to
    ``cv2.HoughLinesP`` with 1-pixel / 1-degree resolution.

    Args:
        image: 8-bit single-channel image.
        threshold: accumulator votes needed to accept a line.
        min_line_length: shortest segment kept, in pixels.
        max_line_gap: largest gap bridged between collinear points, in pixels.

    Returns:
        Array shaped like ``image``: zeros, with every detected segment drawn
        as a 2-px line of value 255.
    """
    edge_map = cv2.Canny(image, 50, 150, apertureSize=3)
    segments = cv2.HoughLinesP(
        edge_map, 1, np.pi / 180, threshold,
        minLineLength=min_line_length,
        maxLineGap=max_line_gap,
    )
    canvas = np.zeros_like(image)
    if segments is None:
        return canvas
    # HoughLinesP yields an (N, 1, 4) array of (x1, y1, x2, y2) rows.
    for x1, y1, x2, y2 in segments.reshape(-1, 4):
        cv2.line(canvas, (x1, y1), (x2, y2), 255, 2)
    return canvas

def combine_segmentation_hough(segmentation, hough):
    """Return the pixel-wise AND of the segmentation mask and the Hough-derived mask.

    Thin wrapper around ``cv2.bitwise_and``; both inputs are presumably
    same-shaped uint8 masks (as produced elsewhere in this notebook).
    """
    return cv2.bitwise_and(segmentation, hough)

def hough_transform_area(image, threshold=50, min_line_length=50, max_line_gap=150):
    """Fill the convex hull of all Hough-segment endpoints found in ``image``.

    Canny edges (50/150, 3x3 aperture) are passed to ``cv2.HoughLinesP``. The
    endpoints of every detected segment are gathered and their convex hull is
    filled with 255. The previous version also drew the segments into a
    separate image that was never returned; that dead work is removed.

    Args:
        image: 8-bit single-channel image (e.g. a 0/255 segmentation mask).
        threshold: accumulator votes needed to accept a line.
        min_line_length: shortest segment kept, in pixels.
        max_line_gap: largest gap bridged between collinear points, in pixels.

    Returns:
        Array shaped like ``image``: 255 inside the hull, 0 elsewhere; all
        zeros when no segment is found.
    """
    edges = cv2.Canny(image, 50, 150, apertureSize=3)
    lines = cv2.HoughLinesP(edges, 1, np.pi/180, threshold,
                            minLineLength=min_line_length,
                            maxLineGap=max_line_gap)
    area_mask = np.zeros_like(image)
    if lines is None:
        return area_mask

    # (N, 1, 4) rows of (x1, y1, x2, y2) -> (2N, 2) endpoints, in the same
    # order the old per-segment loop produced them.
    points = lines.reshape(-1, 2)
    hull = cv2.convexHull(points)
    cv2.fillConvexPoly(area_mask, hull, 255)
    return area_mask

def visualize_result(original_image_path, segmentation, combined, save_path=None):
    """Blend a green segmentation overlay onto the source image, then save or show it.

    The image is resized to the mask's size; foreground pixels
    (``segmentation > 0``) are painted green and blended 70/30 with the
    original.

    Args:
        original_image_path: path of the source image.
        segmentation: 2-D mask; non-zero pixels are highlighted.
        combined: accepted for API compatibility but NOT used.
            NOTE(review): the overlay is drawn from ``segmentation`` only;
            confirm whether the Hough-refined mask was meant to be shown.
        save_path: if truthy, write the overlay there; otherwise display it in
            a window and wait for a key press.

    Raises:
        FileNotFoundError: if the source image cannot be read.
        IOError: if writing ``save_path`` fails.
    """
    original = cv2.imread(original_image_path)
    if original is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {original_image_path}")
    # cv2.resize takes (width, height); mask shape is (height, width).
    original = cv2.resize(original, (segmentation.shape[1], segmentation.shape[0]))

    overlay = original.copy()
    overlay[segmentation > 0] = [0, 255, 0]  # green in OpenCV's BGR order

    # In-place blend: overlay = 0.7 * original + 0.3 * overlay.
    cv2.addWeighted(original, 0.7, overlay, 0.3, 0, overlay)

    if save_path:
        if not cv2.imwrite(save_path, overlay):
            raise IOError(f"Could not write image: {save_path}")
    else:
        cv2.imshow("Lane Detection", overlay)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        
def keep_largest_connected_area(mask):
    """Keep only the external contour of largest area in a mask, filled solid.

    Args:
        mask: 2-D array. Non-uint8 input is cast with ``astype(np.uint8)``
            first (``cv2.findContours`` needs 8-bit input); non-zero pixels of
            the uint8 mask count as foreground.

    Returns:
        uint8 array shaped like ``mask``: 255 inside the external contour with
        the largest polygon area (holes inside it are filled), 0 elsewhere.
        All zeros when no contour is found.
    """
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    # OpenCV 3.x returns (image, contours, hierarchy); 4.x returns
    # (contours, hierarchy). Contours is second-from-last in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return np.zeros_like(mask)

    largest_contour = max(contours, key=cv2.contourArea)

    largest_area_mask = np.zeros_like(mask)
    cv2.drawContours(largest_area_mask, [largest_contour], -1, 255, thickness=cv2.FILLED)
    return largest_area_mask


In [None]:
def process_video(input_video_path, output_video_path, transform, model, device, fallback_fps=30.0):
    """Run lane segmentation on every frame of a video and write an overlay video.

    Each frame is converted BGR->RGB in memory, passed through ``transform``
    and ``model``, thresholded by ``postprocess_output``, then blended with a
    green overlay (30 % weight) and written with the XVID codec at the input
    resolution. The previous version wrote each frame to a shared
    ``temp_frame.png`` and read it back. That PNG round trip is lossless, so
    the in-memory conversion yields the same pixels without the disk I/O or
    temp-file clashes.

    Args:
        input_video_path: path of the video to read.
        output_video_path: path of the video to write.
        transform: callable mapping an RGB PIL image to a CHW tensor.
        model: segmentation network returning per-pixel logits.
        device: device the input tensor is moved to.
        fallback_fps: frame rate used when the container reports none
            (OpenCV reports 0 in that case).

    Raises:
        IOError: if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video: {input_video_path}")

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:  # also catches NaN
        fps = fallback_fps

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))

    model.eval()  # once, not per frame
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV frames are BGR; the transform expects an RGB PIL image.
            pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            original_size = pil_frame.size  # (width, height)
            image = transform(pil_frame).unsqueeze(0).to(device)

            with torch.no_grad():
                output = model(image)
            segmentation = postprocess_output(output, original_size)

            original_image = cv2.resize(frame, (segmentation.shape[1], segmentation.shape[0]))
            overlay = original_image.copy()
            overlay[segmentation > 0] = [0, 255, 0]  # green in BGR
            # In-place blend: overlay = 0.7 * frame + 0.3 * overlay.
            cv2.addWeighted(original_image, 0.7, overlay, 0.3, 0, overlay)

            out.write(cv2.resize(overlay, (frame_width, frame_height)))
    finally:
        # Release both handles even if inference raises mid-video.
        cap.release()
        out.release()



In [None]:
# Restore the final training checkpoint when it is present on disk, then switch
# to inference mode. map_location lets a GPU-saved checkpoint load on a
# CPU-only machine; weights_only=True refuses to unpickle arbitrary objects.
checkpoint_path = f"unet_epoch_{num_epochs}.pth"
if os.path.exists(checkpoint_path):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()



  model.load_state_dict(torch.load(f"unet_epoch_{200}.pth"))


UNet(
  (enc1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (enc2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (enc3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (enc4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (center): Sequential(
    (0

In [None]:

# Segment every KITTI test image and save a green-overlay visualisation under results/.
test_images_dir = "dataset/data_road/testing/image_2"
os.makedirs("results", exist_ok=True)  # create once, up front

model.eval()  # explicit, so this cell does not rely on the previous one
for test_img in sorted(os.listdir(test_images_dir)):
    test_image_path = os.path.join(test_images_dir, test_img)

    image, original_size = preprocess_image(test_image_path, image_transforms)
    with torch.no_grad():
        output = model(image.to(device))
    segmentation = postprocess_output(output, original_size)
    segmentation = keep_largest_connected_area(segmentation)
    hough = hough_transform_area(segmentation)
    combined = combine_segmentation_hough(segmentation, hough)
    visualize_result(test_image_path, segmentation, combined, save_path=f"results/{test_img}")


In [None]:
# Apply the trained model frame by frame to the project video (XVID .avi output).
input_video_path = "project_video.mp4"
output_video_path = "project_video_detected.avi"
process_video(
    input_video_path,
    output_video_path,
    transform=image_transforms,
    model=model,
    device=device,
)

In [None]:
def evaluate(model, dataloader, device, mask_threshold=100/255, pred_threshold=0.5):
    """Compute dataset-level Dice, IoU and pixel accuracy of a binary segmenter.

    Exact integer counts are accumulated over the whole loader, and the metrics
    are computed once at the end. The counts are intersection, predicted
    positives, mask positives, correct pixels and total pixels. Averaging
    per-batch scores instead would weight a short final batch like a full one.
    It would also make the result depend on how images are grouped into
    batches, which changes from run to run with ``shuffle=True``.

    Args:
        model: network returning per-pixel logits with as many elements per
            batch as the masks (presumably ``(B, 1, H, W)``).
        dataloader: iterable of ``(images, masks)`` batches; masks in [0, 1].
        device: device to run inference on.
        mask_threshold: mask pixels strictly above this are positive
            (default 100/255, the threshold used in training).
        pred_threshold: sigmoid probability above which a pixel is predicted positive.

    Returns:
        ``(dice, iou, pixel_accuracy)`` as Python floats.

    Raises:
        ValueError: if the loader yields no pixels, or a batch's prediction and
            mask have different numbers of elements.
    """
    eps = 1e-6
    intersection = pred_pos = mask_pos = correct = total = 0

    model.eval()
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            targets = (masks.to(device) > mask_threshold).reshape(-1)
            preds = (torch.sigmoid(model(images)) > pred_threshold).reshape(-1)
            if preds.numel() != targets.numel():
                raise ValueError(
                    f"prediction has {preds.numel()} elements but mask has {targets.numel()}"
                )

            # Bool tensors sum to int64, so the counts are exact at any dataset size.
            intersection += (preds & targets).sum().item()
            pred_pos += preds.sum().item()
            mask_pos += targets.sum().item()
            correct += (preds == targets).sum().item()
            total += preds.numel()

    if total == 0:
        raise ValueError("dataloader yielded no pixels to evaluate")

    dice = (2. * intersection + eps) / (pred_pos + mask_pos + eps)
    iou = (intersection + eps) / (pred_pos + mask_pos - intersection + eps)
    accuracy = correct / total

    print(f"Dice Coefficient: {dice:.4f}")
    print(f"IoU: {iou:.4f}")
    print(f"Pixel Accuracy: {accuracy:.4f}")

    return dice, iou, accuracy

# NOTE(review): these metrics use the training loader (the testing split above
# has no masks), so they measure fit to the training data, not generalisation.
# Hold out a validation split for an honest estimate.
evaluate(model=model, dataloader=train_loader, device=device)

Average Dice Coefficient: 0.9981
Average IoU: 0.9962
Average Pixel Accuracy: 0.9997


(np.float64(0.9980734884738922),
 np.float64(0.9961545169353485),
 np.float64(0.9996948877970379))