<a href="https://colab.research.google.com/github/ashura1234/deeplabv3-Segmentation/blob/main/indoornyu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from __future__ import print_function

import import_ipynb
import torch
import torch.utils.data as data
import os
import random
import glob
from PIL import Image
from utils import preprocess

importing Jupyter notebook from utils.ipynb


In [None]:
_FOLDERS_MAP = {
    'image': 'images',
    'label': 'target',
}

_DATA_FORMAT_MAP = {
    'image': 'png',
    'label': 'png',
}

In [None]:
class IndoorNYU(data.Dataset):
    """Indoor NYU segmentation dataset with classes ceiling / floor / wall.

    Files are discovered with ``glob`` under
    ``<root>/<folder>/<split>/*/*.<ext>``, where ``<folder>`` and ``<ext>``
    come from ``_FOLDERS_MAP`` / ``_DATA_FORMAT_MAP`` and ``<split>`` is
    ``'train'`` or ``'val'``. Image and mask lists are sorted independently
    and paired by position, so the two trees must sort into the same order.

    Args:
        root: Dataset root directory.
        train: If truthy, use the ``'train'`` split with flip/scale/crop
            augmentation; otherwise use the ``'val'`` split.
        transform: Optional callable applied to the RGB image after
            ``preprocess``.
        target_transform: Optional callable applied to the mask after
            ``preprocess``.
        download: If True, ``download()`` is called (it is not implemented
            and raises ``NotImplementedError``).
        crop_size: Side length of the square crop passed to ``preprocess``
            during training; ignored when ``train`` is falsy.

    Raises:
        NotImplementedError: if ``download`` is True.
        ValueError: if the numbers of image and mask files differ.
    """

    CLASSES = [
        'ceiling', 'floor', 'wall'
    ]

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, crop_size=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.crop_size = crop_size

        if download:
            self.download()

        dataset_split = 'train' if self.train else 'val'
        self.images = self._get_files('image', dataset_split)
        self.masks = self._get_files('label', dataset_split)

        # Images and masks are paired only by sorted position, so a count
        # mismatch would silently pair the wrong files with each other.
        if len(self.images) != len(self.masks):
            raise ValueError(
                'Found %d image files but %d mask files for split %r under %r'
                % (len(self.images), len(self.masks), dataset_split, root))

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair at position ``index``.

        Both files are run through ``preprocess``. During training it gets
        flip, a (0.5, 2.0) scale range and a square ``crop_size`` crop;
        otherwise no flip, no scaling and a fixed (480, 640) crop. The
        optional ``transform`` / ``target_transform`` are applied afterwards.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])

        img, target = preprocess(img, target,
                                 flip=bool(self.train),
                                 scale=(0.5, 2.0) if self.train else None,
                                 crop=(self.crop_size, self.crop_size) if self.train else (480, 640))

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        # Keep only index 1 along the third axis, then drop singleton dims.
        # NOTE(review): this presumably assumes the transformed mask is laid
        # out HxWxC with the label in channel 1 — confirm against the
        # target_transform actually used. The guard matters because with
        # the default target_transform=None the target is still a PIL image,
        # and indexing it with [:, :, 1] used to raise TypeError.
        if isinstance(target, torch.Tensor) and target.dim() >= 3:
            target = torch.squeeze(target[:, :, 1])
        return img, target

    def _get_files(self, kind, dataset_split):
        """Return the sorted paths of all ``kind`` files in ``dataset_split``.

        ``kind`` is a key of ``_FOLDERS_MAP`` / ``_DATA_FORMAT_MAP``
        (``'image'`` or ``'label'``). The search descends exactly one
        sub-directory level below the split folder. The parameter is not
        named ``data``, which would shadow the ``torch.utils.data`` import.
        """
        pattern = '*.%s' % (_DATA_FORMAT_MAP[kind])
        search_files = os.path.join(
            self.root, _FOLDERS_MAP[kind], dataset_split, '*', pattern)
        filenames = sorted(glob.glob(search_files))
        print("Read", len(filenames), "files")
        return filenames

    def __len__(self):
        """Return the number of image files found for the selected split."""
        return len(self.images)

    def download(self):
        """Placeholder; automatic download is not supported.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('Automatic download not yet implemented.')


NameError: ignored