In [23]:
import os
from pathlib import Path
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.io import read_image
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import pandas as pd

__file__ = Path(os.path.realpath("__file__"))
project_path = __file__.parent
dataset_path = Path.home() / "Dataset" / "CelebA"

In [24]:
from collections import namedtuple
import csv
import PIL
from typing import Any, Callable, List, Optional, Union, Tuple

CSV = namedtuple("CSV", ["header", "index", "data"])


class CelebA(Dataset):
    """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test', 'all'}.
            Accordingly dataset is selected.
        target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
            or ``landmarks``. Can also be a list to output a tuple with all specified target types.
            The targets represent:

                - ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
                - ``identity`` (int): label for each person (data points with the same identity are the same person)
                - ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
                - ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                  righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)

            Defaults to ``attr``. If empty, ``None`` will be returned as target.

        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    base_folder = "celeba"

    def __init__(
            self,
            dataset_path,
            split = "train",
            target_type = ["attr"],
            transform = None,
            target_transform = None,
    ) -> None:
        self.split = split
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        
        self.target_type = target_type
        self.transform = transform
        self.target_transform = target_transform
        
        dataset_path = Path(dataset_path)
        if not dataset_path.exists():
            raise RuntimeError("dataset path does not exist")
        self.dataset_path = dataset_path
        
        split_ = split_map[split.lower()]
        splits = self._load_csv(self.dataset_path / "Eval" / "list_eval_partition.txt")
        identity = self._load_csv(self.dataset_path / "Anno" / "identity_CelebA.txt")
        attr = self._load_csv(self.dataset_path / "Anno" / "list_attr_celeba.txt", header=1)
        bbox = self._load_csv(self.dataset_path / "Anno" / "list_bbox_celeba.txt", header=1)
        landmarks = self._load_csv(self.dataset_path / "Anno" / "list_landmarks_celeba.txt", header=1)
        landmarks_align = self._load_csv(self.dataset_path / "Anno" / "list_landmarks_align_celeba.txt", header=1)

        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()

        self.filename = splits.index
        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(self.attr + 1, 2, rounding_mode='floor')
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename,
        header: Optional[int] = None,
    ) -> CSV:
        data, indices, headers = [], [], []

        with open(filename) as csv_file:
            data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))

        if header is not None:
            headers = data[header]
            data = data[header + 1:]

        indices = [row[0] for row in data]
        data = [row[1:] for row in data]
        data_int = [list(map(int, i)) for i in data]

        return CSV(headers, indices, torch.tensor(data_int))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        X = PIL.Image.open(self.dataset_path / "img_align_celeba_png" / str(self.filename[index].split(".")[0] + ".png"))

        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError("Target type \"{}\" is not recognized.".format(t))

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None

        return X, target

    def __len__(self) -> int:
        return len(self.attr)

    def extra_repr(self) -> str:
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)

In [25]:
BATCH_SIZE = 64

train_data = CelebA(dataset_path=dataset_path, split="train", target_type=["attr"], transform=ToTensor())
valid_data = CelebA(dataset_path=dataset_path, split="valid", target_type=["attr"], transform=ToTensor())
test_data  = CelebA(dataset_path=dataset_path, split="test",  target_type=["attr"], transform=ToTensor())

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader  = DataLoader(test_data,  batch_size=BATCH_SIZE, shuffle=True)

In [52]:
# train_features, train_labels = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([64, 3, 218, 178])
Labels batch shape: torch.Size([64, 40])


In [57]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model.eval()

input_batch, train_labels = next(iter(train_dataloader))
with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
# print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)

Using cuda device


Using cache found in /home/sunmoon/.cache/torch/hub/pytorch_vision_v0.10.0


tensor([4.6437e-05, 2.7993e-04, 1.4022e-05, 4.3904e-05, 1.4072e-05, 2.1825e-05,
        1.8124e-05, 1.4898e-05, 2.2942e-06, 1.2285e-06, 3.4518e-05, 3.7220e-06,
        5.3221e-05, 2.1544e-05, 3.8609e-05, 5.4068e-06, 7.2105e-06, 2.4682e-05,
        1.5075e-05, 5.7484e-05, 7.5473e-06, 2.8291e-06, 3.4276e-06, 1.2852e-05,
        7.8343e-06, 4.2282e-06, 5.4779e-06, 1.4387e-05, 4.0712e-05, 1.1474e-04,
        4.0393e-06, 4.4407e-06, 1.6673e-05, 1.1810e-05, 2.0024e-05, 1.6985e-05,
        5.6110e-05, 8.9325e-06, 2.7367e-05, 4.2044e-06, 3.5585e-06, 5.9624e-06,
        1.7508e-05, 3.3779e-05, 1.0272e-05, 4.1349e-05, 5.2873e-06, 7.9776e-06,
        2.2164e-06, 7.4286e-06, 1.3204e-05, 6.9608e-05, 3.5796e-06, 3.4903e-06,
        1.6521e-05, 3.5552e-06, 3.7338e-05, 3.4693e-06, 5.0624e-06, 2.0486e-05,
        4.5229e-05, 6.2730e-05, 1.1617e-04, 2.0091e-05, 2.9312e-05, 6.8504e-06,
        1.5239e-05, 2.1021e-05, 2.9148e-05, 7.5893e-05, 2.0104e-05, 4.9530e-04,
        1.0213e-05, 2.2823e-05, 4.2519e-