In [None]:
!wget https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300-images.tgz
!wget https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/segbench/BSDS300-human.tgz
!tar zxvf *-images.tgz
!tar zxvf *-human.tgz

In [None]:
import re
import os.path
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

In [None]:
# Dataset locations, relative to the notebook's working directory.
root_dir = "./BSDS300"
# Directories of images passed to the dataset class as `image_files`.
train_data = os.path.join(root_dir, 'images/train')
test_data = os.path.join(root_dir, 'images/test')
# Root of the segmentation label tree; it is searched recursively for label
# files whose name (without extension) matches an image name.
labels = os.path.join(root_dir, 'human/color')

In [None]:
def extract_labels(label_files):
    """Converts segmentation data from a .seg file into an np.array

    The file is read line by line:

    * a line consisting only of whitespace-separated integers is a data row
      ``segment row first_col last_col`` (``last_col`` is inclusive);
    * a blank line, or a line that is exactly ``data``, is skipped;
    * any other line is a header entry ``key value`` (the value may be empty).

    :param label_files: Path to a single .seg file
    :return: Tuple ``(segmentation, seg_num)``. ``segmentation`` is a float
        array of shape ``(height, width)`` holding the segment id written for
        each pixel (pixels not covered by any data row stay 0). ``seg_num`` is
        the integer value of the ``segments`` header entry.
    :raises KeyError: If the ``height``, ``width`` or ``segments`` header
        entry is missing
    """
    # Only ASCII digits are accepted so that int() below can never fail.
    row_pattern = re.compile(r'[0-9]+(?:\s+[0-9]+)*')
    meta = {}
    data = []
    with open(label_files, 'r') as file:
        for line in file:
            text = line.strip()
            if not text or text == 'data':
                continue
            if row_pattern.fullmatch(text):
                # split() without arguments tolerates repeated or trailing
                # whitespace, which split(' ') turns into empty tokens.
                data.append([int(token) for token in text.split()])
            else:
                fields = text.split(None, 1)
                meta[fields[0]] = fields[1] if len(fields) > 1 else ''
    height, width = int(meta['height']), int(meta['width'])
    seg_num = int(meta['segments'])
    segmentation = np.zeros((height, width))
    for seg in data:
        # Columns first_col..last_col (inclusive) of one row share a segment id.
        segmentation[seg[1], seg[2]:(seg[3] + 1)] = seg[0]
    return segmentation, seg_num

In [None]:
def create_image(seg_val, seg_max):
    """Displays segmentation data as an image without axes

    :param seg_val: Segmentation values to draw (e.g. the array returned by
        extract_labels)
    :param seg_max: Divisor for the rescaling; values are multiplied by
        255 / seg_max before being drawn
    """
    # Rescale before opening the figure so a failing division leaves no
    # empty figure behind.
    scaled = (seg_val / seg_max) * 255
    plt.figure()
    plt.axis('off')
    plt.imshow(scaled)
    plt.show()

# Smoke test: parse one sample .seg file and display it, to check that
# extracting segmentations works.
# NOTE: the name `test` is rebound further down to the test dataset.
test = os.path.join(labels, '1105/15004.seg')
segmentation, seg_num = extract_labels(test)
create_image(segmentation, seg_num)

In [None]:
class CropArrayCentre(object):
    """Custom transform to crop the centre of arrays (both images and segmentations)

    The same window is cut from the image and from its label so the two stay
    aligned.

    :param output_size: Side length of a square crop (int) or a
        ``(height, width)`` tuple
    """
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Crop ``sample["image"]`` and ``sample["label"]`` around their centre.

        :param sample: Dict with an ``"image"`` array of shape
            ``(height, width)`` or ``(height, width, channels)`` and a
            ``"label"`` array with the same height and width
        :return: New dict with the cropped ``'image'`` and ``'label'``
        :raises ValueError: If the crop is larger than the image, or the label
            size differs from the image size
        """
        image, label = sample["image"], sample["label"]
        # Only the two leading axes are spatial, so 2-D images work as well.
        y, x = image.shape[:2]
        crop_y, crop_x = self.output_size
        if crop_y > y or crop_x > x:
            # A negative start index would wrap around and silently return a
            # slice of the wrong size.
            raise ValueError(
                "crop size {} is larger than the image size {}".format(
                    (crop_y, crop_x), (y, x)))
        if tuple(label.shape[:2]) != (y, x):
            raise ValueError(
                "label size {} does not match image size {}".format(
                    tuple(label.shape[:2]), (y, x)))
        start_x, start_y = x // 2 - (crop_x // 2), y // 2 - (crop_y // 2)
        image = image[start_y: start_y + crop_y, start_x: start_x + crop_x]
        label = label[start_y: start_y + crop_y, start_x: start_x + crop_x]
        return {'image': image, 'label': label}


class TwoTensor(object):
    """Custom transform to convert arrays (both images and segmentations) to tensors"""
    def __call__(self, sample):
        """Convert the arrays of one sample into tensors.

        :param sample: Dict with an ``"image"`` array of shape
            ``(height, width, channels)`` and a ``"label"`` array
        :return: Dict with ``'image'`` as a ``(channels, height, width)``
            tensor and ``'label'`` as a tensor; dtypes follow the input
            arrays. The tensors own their data and do not share memory with
            the input arrays.
        """
        image, label = sample["image"], sample["label"]
        # torch.from_numpy shares memory with its argument. Arrays made from
        # PIL images are read-only, and a tensor over read-only memory is not
        # supported by PyTorch (it warns, and writing to it is undefined), so
        # each array is copied into a fresh writable one first.
        image = torch.from_numpy(np.array(image)).permute(2, 0, 1)
        label = torch.from_numpy(np.array(label))
        return {'image': image, 'label': label}

In [None]:
class Berkeley(Dataset):
    """Custom dataset containing training/test data + their respective labels

    All images and labels are read into memory when the dataset is created.
    Each sample is a dict ``{"image": ..., "label": ...}`` that has been
    centre-cropped and converted to tensors.
    """
    def __init__(self, image_files, label_files, crop_size=321):
        """Images and labels are converted into np.arrays and listed in ascending index
        :param image_files: Path to images
        :param label_files: Path to segmentation labels
        :param crop_size: Size passed to CropArrayCentre (an int for a square
            crop or a (height, width) tuple); defaults to 321"""
        self.images, self.labels = self.array_from_path(image_files, label_files)
        self.transform = transforms.Compose([
            CropArrayCentre(crop_size),  # Crops image + segmentation to uniform size
            TwoTensor()  # Converts image + segmentation np.array to torch.tensor
        ])

    def __len__(self):
        """Number of (image, label) pairs."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return the transformed sample dict at ``index``."""
        sample = {"image": self.images[index], "label": self.labels[index]}
        if self.transform:
            sample = self.transform(sample)
        return sample

    @staticmethod
    def _paired_paths(image_files, label_files):
        """Match every image file with one .seg file of the same base name.

        :param image_files: Directory containing the image files
        :param label_files: Directory tree searched recursively for .seg files
        :return: List of ``(image_path, label_path)`` tuples ordered by image
            file name; images without a label file are left out
        """
        label_paths = {}
        for root, dirs, files in os.walk(label_files):
            # Sorting in place fixes the traversal order, so the label chosen
            # when several folders hold the same name does not depend on the
            # filesystem.
            dirs.sort()
            for file in sorted(files):
                stem, extension = os.path.splitext(file)
                if extension.lower() == '.seg':
                    # setdefault keeps the first file found for each name.
                    label_paths.setdefault(stem, os.path.join(root, file))

        pairs = []
        for img in sorted(os.listdir(image_files)):
            stem = os.path.splitext(img)[0]
            if stem in label_paths:
                pairs.append((os.path.join(image_files, img), label_paths[stem]))
        return pairs

    def array_from_path(self, image_files, label_files):
        """Load images and their segmentations as two index-aligned lists.

        Pairing is done by file name before anything is loaded, so
        ``images[i]`` and ``labels[i]`` always belong together, and only one
        .seg file is parsed per image.

        :param image_files: Path to images
        :param label_files: Path to segmentation labels
        :return: Tuple ``(images, labels)`` of equally long lists of np.arrays
        """
        images, labels = [], []
        for image_path, label_path in self._paired_paths(image_files, label_files):
            with Image.open(image_path) as img:
                images.append(np.asarray(img))
            seg, _ = extract_labels(label_path)
            labels.append(seg)
        return images, labels

In [None]:
# Training set: the dataset loads its images and labels in the constructor;
# the loader then yields batches of 25 samples in dataset order (no shuffling).
train = Berkeley(train_data, labels)
train_loader = DataLoader(
    train,
    batch_size=25,
    shuffle=False)

In [None]:
# Test set, built the same way as the training set: batches of 25 samples in
# dataset order. NOTE: this rebinds `test`, which earlier held a file path.
test = Berkeley(test_data, labels)
test_loader = DataLoader(
    test,
    batch_size=25,
    shuffle=False)

In [None]:
def show_images(sample_batched, i, axarr=None):
    """Show images and segmentation side-by-side in a batch of samples

    Samples ``i`` and ``i + 1`` of the batch are drawn as four panels in the
    order image, segmentation, image, segmentation.

    :param sample_batched: Batch dict with an ``'image'`` tensor indexed as
        ``(sample, channel, row, column)`` and a ``'label'`` tensor indexed by
        sample
    :param i: Index of the first of the two samples to show
    :param axarr: Sequence of at least four matplotlib axes to draw on. When
        omitted, the axes of the current pyplot figure are used if it has at
        least four; otherwise a new 1x4 figure is created.
    """
    images_batch, labels_batch = sample_batched['image'], sample_batched['label']
    panels = []
    for offset in (0, 1):
        # imshow expects channels last, the batch stores them first.
        panels.append(images_batch[i + offset].permute(1, 2, 0))
        panels.append(labels_batch[i + offset])

    if axarr is None:
        # Resolve the axes here instead of reading a global variable, so the
        # function does not depend on which cells have been run before it.
        # get_fignums() is checked first because gcf() would create an empty
        # figure when none exists.
        current_axes = plt.gcf().get_axes() if plt.get_fignums() else []
        if len(current_axes) >= len(panels):
            axarr = current_axes
        else:
            _, axarr = plt.subplots(1, len(panels), figsize=(15, 15))

    for axis, panel in zip(axarr, panels):
        axis.imshow(panel)
        axis.axis('off')

# Print the tensor sizes of each batch up to batch 5, then display the first
# four samples of batch 5 and stop.
for i_batch, sample_batched in enumerate(train_loader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['label'].size())
    if i_batch != 5:
        continue
    # One 1x4 figure per pair of samples: (0, 1) first, then (2, 3).
    for ind in (0, 2):
        f, axarr = plt.subplots(1,4, figsize=(15,15))
        show_images(sample_batched, ind)
        plt.ioff()
        plt.show()
    break
        

In [None]:
class Block(torch.nn.Module):
    """Convolution stage: Conv 3x3 -> ReLU -> Conv 3x3 -> BatchNorm -> ReLU.

    Both convolutions use the default padding of zero, so each one shortens
    height and width by 2 (4 in total for the stage).

    :param in_channel: Number of channels of the input tensor
    :param out_channel: Number of channels produced by both convolutions
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        kernel = (3, 3)
        layers = [
            torch.nn.Conv2d(in_channel, out_channel, kernel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channel, out_channel, kernel),
            torch.nn.BatchNorm2d(out_channel),
            torch.nn.ReLU(),
        ]
        self.stack = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.stack(x)

In [None]:
class UNet(torch.nn.Module):
    """U-Net style encoder/decoder assembled from ``Block`` stages.

    Every layer is created once in ``__init__`` and stored in a
    ``ModuleList``, so its parameters are registered with the module (visible
    to ``parameters()`` / ``state_dict()``) and are reused by every forward
    pass instead of being re-initialised.

    ``Block`` uses unpadded convolutions, so the output is spatially smaller
    than the input: with the default five levels a 321x321 input gives a
    132x132 output, and inputs below 188x188 are too small. The output has
    ``channels[0]`` channels; no classification layer is applied.

    :param in_channels: Channels of the input tensor. The default of 1 keeps
        ``UNet()`` as it was; NOTE(review): RGB images need ``in_channels=3``.
    :param channels: Output channels of each encoder level, shallowest first.
        The decoder mirrors this sequence in reverse.
    :raises ValueError: If ``channels`` is empty
    """

    def __init__(self, in_channels=1, channels=(64, 128, 256, 512, 1024)):
        super().__init__()
        channels = tuple(channels)
        if not channels:
            raise ValueError("channels must contain at least one entry")

        encoder_widths = (in_channels,) + channels
        self.down_blocks = torch.nn.ModuleList([
            Block(width_in, width_out)
            for width_in, width_out in zip(encoder_widths[:-1], encoder_widths[1:])
        ])
        self.pool = torch.nn.MaxPool2d((2, 2))

        decoder_widths = channels[::-1]
        stages = list(zip(decoder_widths[:-1], decoder_widths[1:]))
        # stride=2 makes the transposed convolution double height and width.
        self.up_convs = torch.nn.ModuleList([
            torch.nn.ConvTranspose2d(width_in, width_out, (2, 2), stride=2)
            for width_in, width_out in stages
        ])
        # Each decoder block sees the upsampled tensor concatenated with a
        # skip tensor of the same width, hence 2 * width_out input channels.
        self.up_blocks = torch.nn.ModuleList([
            Block(2 * width_out, width_out) for _, width_out in stages
        ])

    def forward(self, model):
        """Run the encoder and then the decoder.

        :param model: Float input tensor of shape
            ``(batch, in_channels, height, width)``
        :return: Tensor with ``channels[0]`` channels
        """
        x, features = self.contracting(model)
        model = self.expanding(x, features)
        return model

    def contracting(self, model):
        """Encoder path.

        :param model: Input tensor
        :return: Tuple ``(bottleneck, features)``: the output of the deepest
            block (not pooled) and the list of outputs of all shallower
            blocks, shallowest first, for use as skip connections
        """
        features = []
        deepest = len(self.down_blocks) - 1
        for depth, block in enumerate(self.down_blocks):
            model = block(model)
            if depth < deepest:
                features.append(model)
                model = self.pool(model)
        return model, features

    def expanding(self, model, features):
        """Decoder path.

        At every level the tensor is upsampled, concatenated along the channel
        axis with the centre crop of the matching skip tensor, and passed
        through a block.

        :param model: Bottleneck tensor returned by ``contracting``
        :param features: Skip tensors returned by ``contracting``
        :return: Decoded tensor with ``channels[0]`` channels
        """
        for up_conv, block, skip in zip(self.up_convs, self.up_blocks, reversed(features)):
            model = up_conv(model)
            skip = self._centre_crop(skip, model.shape[-2:])
            model = block(torch.cat([model, skip], dim=1))
        return model

    @staticmethod
    def _centre_crop(tensor, target_size):
        """Cut the central ``(height, width)`` window out of the last two axes."""
        target_h, target_w = target_size
        top = (tensor.shape[-2] - target_h) // 2
        left = (tensor.shape[-1] - target_w) // 2
        return tensor[..., top:top + target_h, left:left + target_w]
