In [None]:
import kagglehub

# Download the latest version of the dataset through kagglehub. `path` holds the value
# returned by dataset_download; the next cell uses it to relocate the downloaded files.
path = kagglehub.dataset_download("sakshaymahna/cityscapes-depth-and-segmentation")

print("Path to dataset files:", path)

In [None]:
import os
import shutil

# Move the downloaded dataset directory into the current working directory.
# shutil.move is used instead of a shell `mv` string so that paths containing spaces
# or shell metacharacters are handled correctly and a failed move raises an exception
# instead of being silently ignored. Because "./" is an existing directory, the source
# is placed inside it under its own base name, exactly like `mv <path> ./`.
shutil.move(path, "./")

In [None]:
# Lift the `data` folder out of ./1 into the working directory.
# NOTE(review): ./1 is presumably the dataset-version folder moved here by the previous cell — confirm.
%mv ./1/data ./data

In [None]:
# Recursively delete ./1 (its `data` folder is moved out by the preceding cell).
%rm -r ./1

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Preview training image 0: scale by 255 and cast to uint8 before handing it to imshow.
# NOTE(review): presumably the array is stored as floats in [0, 1] — confirm.
x = (np.load("./data/train/image/0.npy") * 255).astype(np.uint8)
plt.imshow(x)

In [None]:
# Display the array stored under label/ with the same file name as training image 0.
x = np.load("./data/train/label/0.npy")
plt.imshow(x)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from natsort import natsorted
import os 

class CityScapes(Dataset):
    """Paired image / segmentation-mask dataset stored as ``.npy`` files.

    ``main_dir`` must contain an ``image`` and a ``label`` sub-directory. Each
    directory listing is natural-sorted and samples are paired by position.

    Args:
        main_dir: Directory holding the ``image`` and ``label`` folders.
        transforms: Optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries. When it
            is given, the returned mask is binarised with ``torch.where`` and must
            therefore be a ``torch.Tensor``.

    Raises:
        ValueError: If the two folders do not contain the same number of files.
    """

    def __init__(self, main_dir, transforms=None):
        self.main_dir = main_dir
        self.transforms = transforms

        image_dir = os.path.join(main_dir, 'image')
        label_dir = os.path.join(main_dir, 'label')
        self.images = list(natsorted(os.listdir(image_dir)))
        # The label list must come from the label folder itself; listing the image
        # folder here would hide missing or differently named mask files.
        self.labels = list(natsorted(os.listdir(label_dir)))

        # Samples are paired by position, so unequal counts mean misaligned pairs.
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"Found {len(self.images)} files in {image_dir} but "
                f"{len(self.labels)} files in {label_dir}"
            )

    def __len__(self):
        """Return the number of image/label pairs."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the ``index``-th pair and return it as ``(image, label)``.

        Without transforms the raw arrays from ``np.load`` are returned unchanged.
        With transforms, the transformed image is returned together with a binary
        mask that is 1 where the transformed label equals 13 and 0 elsewhere.
        """
        image_path = os.path.join(self.main_dir, 'image', self.images[index])
        label_path = os.path.join(self.main_dir, 'label', self.labels[index])
        image, label = np.load(image_path), np.load(label_path)
        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=label)
            image, label = transformed['image'], transformed['mask']
            # Binarise only on this path: torch.where needs the tensor mask that the
            # transform pipeline produces.
            # NOTE(review): class id 13 is presumably the target class (e.g. "car") —
            # confirm against the dataset's label map.
            label = torch.where(label == 13, 1, 0)
        return image, label

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import random_split

# Training pipeline: resize, light random augmentation, then conversion to tensors.
train_transforms = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(),
    A.Rotate(limit=10, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    ToTensorV2()
])
# Evaluation pipeline: deterministic resize and tensor conversion only.
test_transforms = A.Compose([
    A.Resize(224, 224),
    ToTensorV2()
])

# random_split returns subsets that share their parent dataset, so splitting a single
# augmented dataset would feed randomly augmented samples to validation as well.
# Instead, two views of ./data/train are created (augmented / non-augmented) and split
# with identically seeded generators. Same seed and same length give the same
# permutation, so the train and validation parts are disjoint and complementary,
# and the split is reproducible across kernel restarts.
SPLIT_SEED = 42
SPLIT_FRACTIONS = (0.9, 0.1)
train_source = CityScapes("./data/train", transforms=train_transforms)
valid_source = CityScapes("./data/train", transforms=test_transforms)
train_dataset, _ = random_split(
    train_source, SPLIT_FRACTIONS, generator=torch.Generator().manual_seed(SPLIT_SEED)
)
_, valid_dataset = random_split(
    valid_source, SPLIT_FRACTIONS, generator=torch.Generator().manual_seed(SPLIT_SEED)
)
test_dataset = CityScapes("./data/val", transforms=test_transforms)


In [None]:
# Visual sanity check: draw sample 10 of the training split and its mask side by side.
img, label = train_dataset[10]
# Move axis 0 to the end so the array has the (H, W, C) layout imshow expects;
# assumes the transform pipeline returned the image channels-first — TODO confirm.
img = img.permute(1,2,0)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10,10))
axes[0].imshow(img.cpu().numpy())
axes[1].imshow(label)

In [None]:
import torch.nn as nn

def simple_block(in_channel, out_channel):
    """Build a 3x3 convolution -> batch norm -> ReLU stack.

    The convolution uses stride 1 and padding 1, so height and width are preserved.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    layers = [
        nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

In [None]:
import os

VGG16_BN_URL = "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth"
VGG16_BN_PATH = "./vgg16_bn-6c64b313.pth"

# This cell runs before the notebook's wget cell, so on a fresh kernel the checkpoint
# does not exist yet. Fetch it here when it is missing so Restart & Run All works.
if not os.path.exists(VGG16_BN_PATH):
    torch.hub.download_url_to_file(VGG16_BN_URL, VGG16_BN_PATH)

# map_location keeps the tensors on the CPU regardless of where they were saved;
# weights_only restricts unpickling of the downloaded file to tensors and plain containers.
weight = torch.load(VGG16_BN_PATH, map_location="cpu", weights_only=True)

# List every entry of the checkpoint with its tensor shape.
for name, tensor in weight.items():
    print(name, tensor.shape)

In [None]:
import torch.nn.functional as F

class Encoder(nn.Module):
    """Five-stage convolutional encoder that records max-pool indices.

    The stages contain 2, 2, 3, 3 and 3 ``simple_block`` units with 64, 128, 256,
    512 and 512 output channels. ``forward`` follows every stage with a 2x2,
    stride-2 max-pool and keeps the pooling indices.

    Args:
        input_channel: Number of channels of the input passed to ``forward``.
    """

    def __init__(self, input_channel=3):
        super().__init__()
        # Each call lists the channel widths a stage walks through, input first.
        self.encoder1 = self._stage(input_channel, 64, 64)
        self.encoder2 = self._stage(64, 128, 128)
        self.encoder3 = self._stage(128, 256, 256, 256)
        self.encoder4 = self._stage(256, 512, 512, 512)
        self.encoder5 = self._stage(512, 512, 512, 512)

    @staticmethod
    def _stage(*widths):
        """Chain one ``simple_block`` per consecutive pair of channel widths."""
        blocks = [simple_block(src, dst) for src, dst in zip(widths, widths[1:])]
        return nn.Sequential(*blocks)

    def forward(self, data):
        """Run all stages and return ``(features, pool_indices)``.

        ``pool_indices`` is a list of the five index tensors returned by the
        max-pools, ordered from the first stage to the last.
        """
        stages = (
            self.encoder1,
            self.encoder2,
            self.encoder3,
            self.encoder4,
            self.encoder5,
        )
        pool_indices = []
        x = data
        for stage in stages:
            x = stage(x)
            x, where = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
            pool_indices.append(where)

        return x, pool_indices

In [None]:
class Decoder(nn.Module):
    """Five-stage decoder that upsamples with max-unpooling.

    Each stage first unpools with the indices recorded by the matching pooling step
    and then applies its convolution blocks. Channel widths go
    512 -> 512 -> 256 -> 128 -> 64 -> ``output_channel``.

    The last layer is a plain convolution that emits raw per-class scores. It is
    deliberately not followed by BatchNorm/ReLU, which would clamp the scores to be
    non-negative before they reach the loss.

    Args:
        output_channel: Number of channels (classes) in the output map.
    """

    def __init__(self, output_channel=2):
        super().__init__()
        self.decoder1 = nn.Sequential(
            simple_block(512, 512),
            simple_block(512, 512),
            simple_block(512, 512),
        )
        self.decoder2 = nn.Sequential(
            simple_block(512, 512),
            simple_block(512, 512),
            simple_block(512, 256),
        )
        self.decoder3 = nn.Sequential(
            simple_block(256, 256),
            simple_block(256, 256),
            simple_block(256, 128),
        )
        self.decoder4 = nn.Sequential(
            simple_block(128, 128),
            simple_block(128, 64),
        )
        self.decoder5 = nn.Sequential(
            simple_block(64, 64),
            # Output head: unnormalised class scores, no activation.
            nn.Conv2d(64, output_channel, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, data, ids):
        """Decode ``data`` using the pooling indices in ``ids``.

        Args:
            data: Feature map to decode.
            ids: Sequence of at least five pooling-index tensors ordered from the
                first pooling step to the last; they are consumed in reverse order.

        Returns:
            The output of the final stage.

        Note:
            ``max_unpool2d`` is called without ``output_size``, so every unpooling
            step exactly doubles height and width. This only lines up with the
            recorded indices when each pooled map had even height and width.
        """
        reverted_ids = ids[::-1]
        stages = (
            self.decoder1,
            self.decoder2,
            self.decoder3,
            self.decoder4,
            self.decoder5,
        )
        x = data
        for step, stage in enumerate(stages):
            x = F.max_unpool2d(x, indices=reverted_ids[step], kernel_size=2, stride=2)
            x = stage(x)
        return x

In [None]:
!wget https://download.pytorch.org/models/vgg16_bn-6c64b313.pth

In [None]:
class SegNet(nn.Module):
    """Encoder-decoder segmentation network with a pretrained encoder.

    The encoder is an ``Encoder`` whose parameters and BatchNorm statistics are
    copied, in order, from a VGG16-BN checkpoint on disk. The decoder is a freshly
    initialised ``Decoder``.

    Args:
        in_channel: Number of input channels of the encoder.
            NOTE(review): the checkpoint's first convolution presumably expects 3
            input channels; any other value makes ``load_state_dict`` raise a
            size-mismatch ``RuntimeError`` — confirm.
        out_channel: Number of output channels of the decoder.
        weight_path: Location of the VGG16-BN checkpoint file.
    """

    def __init__(self, in_channel, out_channel, weight_path="./vgg16_bn-6c64b313.pth"):
        super().__init__()
        self.encoder = self.load_encoder(in_channel, weight_path)
        self.decoder = Decoder(out_channel)

    def load_encoder(self, in_channel, weight_path="./vgg16_bn-6c64b313.pth"):
        """Create an ``Encoder`` and fill it from the checkpoint at ``weight_path``.

        Checkpoint entries are matched to encoder entries by position, after
        dropping the ``classifier`` entries and any ``num_batches_tracked`` counters
        on both sides.

        Raises:
            RuntimeError: If the number of transferable tensors differs between the
                checkpoint and the encoder, or (from ``load_state_dict``) if a
                tensor shape does not match.
        """
        # map_location keeps tensors on the CPU; weights_only restricts unpickling of
        # the downloaded file to tensors and plain containers.
        checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)
        pretrained = [
            tensor
            for name, tensor in checkpoint.items()
            if not name.startswith('classifier') and 'num_batches_tracked' not in name
        ]

        encoder = Encoder(in_channel)
        # Start from the encoder's own state dict so entries that are not transferred
        # (the num_batches_tracked counters) are still present for the strict load.
        state = encoder.state_dict()
        target_names = [name for name in state if 'num_batches_tracked' not in name]

        # zip() would silently stop at the shorter sequence and leave layers
        # untouched or shifted, so refuse a count mismatch explicitly.
        if len(target_names) != len(pretrained):
            raise RuntimeError(
                f"Checkpoint provides {len(pretrained)} transferable tensors but the "
                f"encoder expects {len(target_names)}"
            )

        for name, tensor in zip(target_names, pretrained):
            state[name] = tensor

        encoder.load_state_dict(state)

        return encoder

    def forward(self, data):
        """Encode ``data``, then decode it with the recorded pooling indices."""
        x, ids = self.encoder(data)
        x = self.decoder(x, ids)
        return x

In [None]:
# Build the network with in_channel=3 and out_channel=2; constructing it also loads
# the VGG16-BN checkpoint from disk into the encoder.
model = SegNet(3, 2)