In [1]:
%matplotlib inline
from IPython.display import clear_output
import os
from copy import deepcopy

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import re

# Single source of truth for reproducibility: seed every RNG library in use.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
import os

import torch
import numpy as np
from torchvision import datasets
import re

from torchvision.datasets.utils import download_url

class LFWPeopleAttribute(datasets.LFWPeople):
  """``LFWPeople`` variant whose samples also carry facial-attribute scores.

  Attribute scores are read from ``lfw_attributes.txt``, which ``download()``
  fetches into ``self.root`` after the parent's files. Samples are returned as
  ``(img, target, attribute)``, where ``attribute`` is the list of selected
  attribute scores (floats) recorded for the sample's identity. Images whose
  identity has no row in the attributes file are dropped.

  Args:
    root, split, image_set, transform, target_transform, download: forwarded
      unchanged to ``torchvision.datasets.LFWPeople``.
    attribute_classes: attribute column names to keep, in the order they appear
      in each ``attribute`` list. ``None`` keeps every column, in file order.
    target_class: optional task name, one of the keys of
      ``_TARGET_CLASS_ATTRIBUTES``. When given, ``self.targets`` becomes the
      task label (0/1 for 'gender'/'smile', index of the highest-scoring
      attribute of the group otherwise) and the identity labels remain
      available as ``self.identities``. Every attribute the task needs must be
      listed in ``attribute_classes``.

  Raises:
    ValueError: unknown ``target_class``, or a required attribute not selected.
  """

  # Task name -> attribute columns the task label is derived from.
  _TARGET_CLASS_ATTRIBUTES = {
    'gender': ['Male'],
    'smile': ['Smiling'],
    'race': ['Asian', 'White', 'Black'],
    'age': ['Baby', 'Child', 'Youth', 'Middle Aged', 'Senior'],
    'hair': ['Black Hair', 'Blond Hair', 'Brown Hair', 'Bald'],
    'eyewear': ['No Eyewear', 'Eyeglasses', 'Sunglasses'],
  }
  # Tasks labelled by thresholding a single score at 0; all others use argmax.
  _BINARY_TARGET_CLASSES = ('gender', 'smile')

  def __init__(
    self,
    root,
    split = '10fold',
    image_set = 'funneled',
    transform = None,
    target_transform = None,
    download = False,
    attribute_classes = None,
    target_class = None,
  ):
    self.attributes_file = "lfw_attributes.txt"
    # Fail fast, before calling the parent constructor, on a task we cannot label.
    if target_class is not None and target_class not in self._TARGET_CLASS_ATTRIBUTES:
      raise ValueError(
        f"Unknown target_class {target_class!r}; "
        f"expected one of {sorted(self._TARGET_CLASS_ATTRIBUTES)} or None"
      )
    super().__init__(root, split, image_set, transform, target_transform, download)
    # TODO: check integrity of the attributes file (no checksum is verified yet).
    self.target_class = target_class
    self.attribute_class_to_idx = self._get_attribute_class()
    # None means "keep every attribute column", in file order.
    if attribute_classes is None:
      attribute_classes = list(self.attribute_class_to_idx)
    self.attribute_class_selects = attribute_classes

    self.target_to_attribute, self.attribute_class_selects_idxs = self._get_target_to_attributes()
    self.data, self.targets, self.attributes = self._normalize_data()

    # Keep the identity labels even when targets are replaced by an attribute task.
    self.identities = self.targets
    if self.target_class is not None:
      # Positions inside each `self.attributes` row used to build the task label.
      self.attribute_target_index = self._column_positions(
        self._TARGET_CLASS_ATTRIBUTES[self.target_class]
      )
      self.targets = self._get_attribute_target()

  def download(self):
    """Download the parent's LFW files, then the attributes file into ``self.root``."""
    super().download()
    download_url("https://www.cs.columbia.edu/CAVE/databases/pubfig/download/lfw_attributes.txt", self.root)

  def _get_attribute_class(self):
    """Map each attribute column name in the attributes file to its column index.

    Names come from the file's second line (tab-separated); its first three
    fields are not attribute names and are skipped.

    Raises:
      RuntimeError: if the file has fewer than two lines.
    """
    with open(os.path.join(self.root, self.attributes_file)) as f:
      next(f, None)  # first line is not the header
      header = next(f, None)
    if header is None:
      raise RuntimeError(f"{self.attributes_file} has no attribute header line")
    attribute_classes = header.strip().split('\t')[3:]
    return {attribute_class: i for i, attribute_class in enumerate(attribute_classes)}

  def _get_target_to_attributes(self):
    """Read the selected attribute scores for each identity in the attributes file.

    Returns:
      target_to_attribute: identity index (from ``self.class_to_idx``) -> list of
        selected attribute scores as floats, ordered like
        ``attribute_class_selects_idxs``.
      attribute_class_selects_idxs: selected attribute name -> column index in
        the file (counted after the skipped leading fields).

    Raises:
      KeyError: if a selected attribute is not a column of the file, or an
        identity in the file is not a key of ``self.class_to_idx``.
    """
    attribute_class_selects_idxs = {
      attr: self.attribute_class_to_idx[attr] for attr in self.attribute_class_selects
    }
    with open(os.path.join(self.root, self.attributes_file)) as f:
      lines = f.readlines()

    target_to_attribute = {}
    for line in lines[2:]:
      if not line.strip():
        continue  # tolerate blank (e.g. trailing) lines
      fields = line.strip().split("\t")
      # Spaces are converted to underscores to match the class_to_idx keys.
      idx = self.class_to_idx[re.sub(" ", "_", fields[0])]
      scores = fields[2:]
      # NOTE(review): if the file holds several rows per identity (e.g. one per
      # image), later rows overwrite earlier ones and only the last is kept.
      target_to_attribute[idx] = [float(scores[i]) for i in attribute_class_selects_idxs.values()]
    return target_to_attribute, attribute_class_selects_idxs

  def _column_positions(self, attribute_names):
    """Return where each named attribute sits inside a row of ``self.attributes``.

    Rows hold only the selected attributes (in ``attribute_class_selects_idxs``
    order), so these positions differ from the file-wide column indices.

    Raises:
      ValueError: if a name was not selected via ``attribute_classes``.
    """
    selected = list(self.attribute_class_selects_idxs)
    missing = [name for name in attribute_names if name not in selected]
    if missing:
      raise ValueError(
        f"target_class {self.target_class!r} needs attributes {missing}, "
        f"which are missing from attribute_classes"
      )
    return [selected.index(name) for name in attribute_names]

  def _attribute_matrix(self):
    """Return ``self.attributes`` as a float array of shape (n_samples, n_selected)."""
    # reshape keeps the 2-D shape even when there are no samples.
    return np.asarray(self.attributes, dtype=float).reshape(-1, len(self.attribute_class_selects_idxs))

  def _get_binary_attribute_target(self):
    """Label each sample 1 if its single task attribute score is positive, else 0."""
    (column,) = self._column_positions(self._TARGET_CLASS_ATTRIBUTES[self.target_class])
    return (self._attribute_matrix()[:, column] > 0).astype(int).tolist()

  def _get_multiple_attribute_target(self):
    """Label each sample with the index of its highest-scoring attribute in the task group."""
    columns = self._column_positions(self._TARGET_CLASS_ATTRIBUTES[self.target_class])
    return self._attribute_matrix()[:, columns].argmax(axis=1).tolist()

  def _get_attribute_target(self):
    """Build the per-sample labels for ``self.target_class``."""
    if self.target_class in self._BINARY_TARGET_CLASSES:
      return self._get_binary_attribute_target()
    if self.target_class in self._TARGET_CLASS_ATTRIBUTES:
      return self._get_multiple_attribute_target()
    raise ValueError(f"Unknown target_class {self.target_class!r}")

  def _normalize_data(self):
    """Keep only samples whose identity has attributes, and attach those attributes.

    Returns:
      (data, targets, attributes) lists, aligned by sample.
    """
    normalized_datas = []
    normalized_targets = []
    normalized_attributes = []
    for data, target in zip(self.data, self.targets):
      if target in self.target_to_attribute:
        normalized_datas.append(data)
        normalized_targets.append(target)
        normalized_attributes.append(self.target_to_attribute[target])
    return normalized_datas, normalized_targets, normalized_attributes

  def __getitem__(self, index):
    """Return ``(img, target, attribute)`` for sample ``index``.

    ``transform`` / ``target_transform`` are applied to the image and target;
    ``attribute`` (the selected scores) is returned as-is.
    """
    img = self._loader(self.data[index])
    attribute = self.attributes[index]
    target = self.targets[index]

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    return img, target, attribute

# The dataset class is defined in this cell, so reference it directly; the
# previous `custom_datasets.` prefix is undefined on a fresh kernel.
data = LFWPeopleAttribute(
    root = '/content/',
    download = True,
    attribute_classes=[ # TODO: replace with a default attribute set on the dataset class
      'Male', # Gender
      'Smiling', # Smile
      'Asian', 'White', 'Black', # Race
      'Baby', 'Child', 'Youth', 'Middle Aged', 'Senior', # Age
      'Black Hair', 'Blond Hair', 'Brown Hair', 'Bald', # Hair
      'No Eyewear', 'Eyeglasses', 'Sunglasses', # Eyewear
    ],
    target_class = 'gender',
  )

Downloading http://vis-www.cs.umass.edu/lfw/lfw-funneled.tgz to ./lfw-py/lfw-funneled.tgz


  0%|          | 0/243346528 [00:00<?, ?it/s]

Extracting ./lfw-py/lfw-funneled.tgz to ./lfw-py
Downloading http://vis-www.cs.umass.edu/lfw/people.txt to ./lfw-py/people.txt


  0%|          | 0/94770 [00:00<?, ?it/s]

Downloading http://vis-www.cs.umass.edu/lfw/lfw-names.txt to ./lfw-py/lfw-names.txt


  0%|          | 0/94727 [00:00<?, ?it/s]

Downloading https://www.cs.columbia.edu/CAVE/databases/pubfig/download/lfw_attributes.txt to ./lfw-py/lfw_attributes.txt


  0%|          | 0/14879205 [00:00<?, ?it/s]

In [5]:
# Cast to float: attribute scores may arrive as strings parsed from the
# tab-separated attributes file, and argmax over strings would rank them
# lexicographically (e.g. "-0.5" vs "1.2") instead of numerically.
attributes = np.array(data.attributes, dtype=float)

In [9]:
# Positions within `attribute_classes` above: 1 -> 'Smiling', 2 -> 'Asian'.
# NOTE(review): pairing a smile score with a race score looks unintended — confirm the columns.
subset_columns = [1, 2]
attribute_subset = attributes[:, subset_columns]

In [11]:
# Per-sample index of the larger score within the 2-column subset.
max_subset = attribute_subset.argmax(axis=1)

In [13]:
np.shape(attribute_subset)  # (n_samples, n_subset_columns)

(13205, 2)

In [14]:
np.shape(max_subset)  # one argmax index per sample

(13205,)