In [1]:
import pandas as pd
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
import torch
import torch.nn as nn

# Dataset metadata: entry k of "classes", "palette" and "cidx" all describe the
# same class id k (name, RGB display colour, integer index).
METAINFO = {
    "classes": (
        "unlabelled",        # 0
        "asphalt/concrete",  # 1
        "dirt",              # 2
        "mud",               # 3
        "water",             # 4
        "gravel",            # 5
        "other-terrain",     # 6
        "tree-trunk",        # 7
        "tree-foliage",      # 8
        "bush/shrub",        # 9
        "fence",             # 10
        "other-structure",   # 11
        "pole",              # 12
        "vehicle",           # 13
        "rock",              # 14
        "log",               # 15
        "other-object",      # 16
        "sky",               # 17
        "grass",             # 18
    ),
    "palette": [
        (0, 0, 0),        # 0
        (230, 25, 75),    # 1
        (60, 180, 75),    # 2
        (255, 225, 25),   # 3
        (0, 130, 200),    # 4
        (145, 30, 180),   # 5
        (70, 240, 240),   # 6
        (240, 50, 230),   # 7
        (210, 245, 60),   # 8
        (250, 190, 190),  # 9
        (0, 128, 128),    # 10
        (170, 110, 40),   # 11
        (255, 250, 200),  # 12
        (128, 0, 0),      # 13
        (170, 255, 195),  # 14
        (128, 128, 0),    # 15
        (255, 215, 180),  # 16
        (0, 0, 128),      # 17
        (128, 128, 128),  # 18
    ],
    "cidx": list(range(19)),
}

In [2]:
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BN, ReLU and dropout.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``in_channels`` ->
    ``out_channels`` in the first stage, ``out_channels`` -> ``out_channels``
    in the second).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both stages.
        dropout_prob: Drop probability handed to each ``nn.Dropout`` layer.
    """

    def __init__(self, in_channels, out_channels, dropout_prob=0.05):
        super().__init__()

        stages = []
        # The first stage maps in -> out, the second maps out -> out.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_prob),
            ])

        # Attribute name and layer order are kept so that parameter names
        # (conv.0.*, conv.1.*, conv.4.*, conv.5.*) stay the same.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages and return the result."""
        return self.conv(x)

class UNetPlusPlus(nn.Module):
    """Nested U-Net (UNet++) built from ``ConvBlock`` nodes.

    Nodes are named ``conv{i}_{j}``: ``i`` is the resolution level (level 0 is
    the input resolution, every further level is reached through a 2x2
    max-pool) and ``j`` is the column along the dense skip pathway. Column 0
    is the encoder; node ``(i, j)`` with ``j > 0`` consumes all earlier nodes
    of the same level plus the 2x-upsampled node ``(i + 1, j - 1)``.

    Args:
        num_classes: Number of output channels of every 1x1 head.
        deep_supervision: When true, a separate 1x1 head (``final1`` ..
            ``final4``) is attached to each of the level-0 nodes
            ``x0_1`` .. ``x0_4`` and ``forward`` returns a list of four maps;
            otherwise a single head ``final`` on ``x0_4`` is used.
        dropout_prob: Dropout probability forwarded to every ``ConvBlock``.

    The network expects a 3-channel input. Because it pools four times and
    then upsamples by exactly 2, height and width presumably need to be
    divisible by 16 for the concatenations to line up.
    """

    def __init__(self, num_classes, deep_supervision=True, dropout_prob=0.05):
        super().__init__()
        self.deep_supervision = deep_supervision

        widths = [32, 64, 128, 256, 512]
        levels = len(widths)

        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Column 0 (encoder): RGB -> widths[0] -> widths[1] -> ...
        encoder_inputs = [3] + widths[:-1]
        for i in range(levels):
            setattr(self, f"conv{i}_0",
                    ConvBlock(encoder_inputs[i], widths[i], dropout_prob))

        # Columns 1..4 (nested decoder). Node (i, j) sees j same-level maps of
        # widths[i] channels plus one upsampled map of widths[i + 1] channels.
        for j in range(1, levels):
            for i in range(levels - j):
                fused_channels = widths[i] * j + widths[i + 1]
                setattr(self, f"conv{i}_{j}",
                        ConvBlock(fused_channels, widths[i], dropout_prob))

        if self.deep_supervision:
            # One 1x1 classifier per level-0 decoder node.
            for j in range(1, levels):
                setattr(self, f"final{j}",
                        nn.Conv2d(widths[0], num_classes, kernel_size=1))
        else:
            self.final = nn.Conv2d(widths[0], num_classes, kernel_size=1)

    def forward(self, input):
        """Compute the segmentation logits.

        Returns a list ``[out1, out2, out3, out4]`` (heads on ``x0_1`` ..
        ``x0_4``) when deep supervision is enabled, else a single tensor from
        the head on ``x0_4``.
        """
        levels = 5
        feats = {}  # (i, j) -> feature map of node conv{i}_{j}

        # Walk the grid one anti-diagonal at a time: first the new encoder
        # node (depth, 0), then every decoder node it unlocks.
        for depth in range(levels):
            if depth == 0:
                feats[(0, 0)] = self.conv0_0(input)
            else:
                encoder = getattr(self, f"conv{depth}_0")
                feats[(depth, 0)] = encoder(self.pool(feats[(depth - 1, 0)]))

            for j in range(1, depth + 1):
                i = depth - j
                stacked = [feats[(i, k)] for k in range(j)]
                stacked.append(self.up(feats[(i + 1, j - 1)]))
                node = getattr(self, f"conv{i}_{j}")
                feats[(i, j)] = node(torch.cat(stacked, 1))

        if self.deep_supervision:
            return [
                getattr(self, f"final{j}")(feats[(0, j)])
                for j in range(1, levels)
            ]
        return self.final(feats[(0, levels - 1)])


In [None]:
def load_and_sample_data(file_path, sample_fraction=1):
    """Read a CSV file and return a reproducible random sample of its rows.

    Args:
        file_path: Path of the CSV file passed to ``pd.read_csv``.
        sample_fraction: Fraction of rows to keep. With the default of 1 every
            row is returned, in shuffled order.

    Returns:
        A DataFrame with the sampled rows; the original index labels are kept.
        The seed is fixed (42), so repeated calls give the same sample.
    """
    return pd.read_csv(file_path).sample(frac=sample_fraction, random_state=42)

# Load the test split CSV. sample_fraction is left at its default of 1, so all
# rows are kept and only their order is shuffled (fixed seed). WildScene below
# reads the 'im_path' and 'label_path' columns from this frame.
test_df = load_and_sample_data('splits/test.csv')

class WildScene(Dataset):
    """Dataset yielding ``(image, label)`` tensor pairs from a DataFrame of paths.

    Each row of ``df`` must provide an ``'im_path'`` and a ``'label_path'``
    entry. Images are loaded with OpenCV, converted BGR -> RGB, resized to a
    square of side ``img_size`` and scaled to ``[0, 1]``. Labels are loaded as
    single-channel images and resized to the same square.

    Args:
        df: DataFrame with ``'im_path'`` and ``'label_path'`` columns.
        img_size: Side length (pixels) of the square output.
        num_classes: Stored on the instance; not used by this class itself.
        transform: Optional callable applied to the scaled HxWx3 image array.
            Its result must still be an HxWx3 NumPy array, because it is
            transposed and wrapped with ``torch.from_numpy`` afterwards.
    """

    def __init__(self, df, img_size, num_classes, transform=None):
        self.df = df
        self.img_size = img_size
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. number of rows in ``df``."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` (positional row index).

        Returns:
            ``(img, label)`` where ``img`` is a float32 tensor of shape
            ``(3, img_size, img_size)`` and ``label`` is an int64 tensor of
            shape ``(img_size, img_size)``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or label file.
        """
        row = self.df.iloc[idx]
        target_size = (self.img_size, self.img_size)

        img = cv2.imread(row['im_path'])
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {row['im_path']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, target_size)
        img = img / 255.0

        label = cv2.imread(row['label_path'], cv2.IMREAD_GRAYSCALE)
        if label is None:
            raise FileNotFoundError(f"Could not read label: {row['label_path']}")
        # Label pixels are used as class ids (converted to long below), so they
        # must be resized with nearest-neighbour. The default bilinear
        # interpolation would blend neighbouring ids into values that belong
        # to unrelated classes along every region boundary.
        label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)

        if self.transform:
            img = self.transform(img)

        # HWC -> CHW, as expected by the convolutional model.
        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        label = torch.from_numpy(label).long()

        return img, label


def label_to_rgb(label, palette):
    """Colourise an integer label map.

    Args:
        label: Array of class ids.
        palette: Iterable of RGB triples; the k-th entry is the colour used
            for class id ``k``.

    Returns:
        A ``uint8`` array of shape ``label.shape + (3,)``. Pixels whose id has
        no palette entry stay black.
    """
    canvas = np.zeros(tuple(label.shape) + (3,), dtype=np.uint8)
    for class_id, rgb in enumerate(palette):
        mask = label == class_id
        canvas[mask] = rgb
    return canvas


def visualize_prediction(image, label, prediction, palette):
    """Show the input image, ground truth and prediction side by side.

    Args:
        image: Tensor in channel-first layout; moved to CPU and transposed to
            channel-last for display.
        label: Tensor of ground-truth class ids.
        prediction: Tensor of predicted class ids.
        palette: Colour table handed to ``label_to_rgb`` for both label maps.
    """
    image_hwc = image.cpu().numpy().transpose((1, 2, 0))
    truth_rgb = label_to_rgb(label.cpu().numpy(), palette)
    predicted_rgb = label_to_rgb(prediction.cpu().numpy(), palette)

    panels = (
        (image_hwc, "Image"),
        (truth_rgb, "Ground Truth"),
        (predicted_rgb, "Prediction"),
    )

    plt.figure(figsize=(15, 15))
    # One row, three columns; subplot positions are 1-based.
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')
    plt.show()

# Single source of truth for the class count: the dataset and the model must
# agree, so derive it from the class list instead of repeating a literal.
num_classes = len(METAINFO['classes'])

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook does not crash on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

test_dataset = WildScene(test_df, 512, num_classes)
test_loader = DataLoader(test_dataset, batch_size=8)

model_path = 'unetplusplus_best_model.pth'
# Initialize the model (weights are loaded from model_path in a later cell)
model = UNetPlusPlus(num_classes=num_classes, deep_supervision=True).to(device)
test_df.head(16)

In [None]:
# Take the device from the model itself so inputs always land where the
# weights are, whether that is a GPU or the CPU.
device = next(model.parameters()).device

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=device))

model.eval()
with torch.no_grad():
    for data in test_loader:
        # Presumably a guard for a collate_fn that can drop a whole batch; the
        # DataLoader built in this notebook uses the default collate.
        if data is None:
            continue

        images, labels = data
        images = images.to(device)

        outputs = model(images)
        # With deep supervision the model returns one map per head; the last
        # entry is the head on the deepest node (x0_4).
        if isinstance(outputs, (list, tuple)):
            outputs = outputs[-1]

        # Per-pixel class id = index of the largest logit along the channel axis.
        outputs = outputs.argmax(1).cpu()

        for image, label, prediction in zip(images, labels, outputs):
            visualize_prediction(image, label, prediction, METAINFO['palette'])
