In [38]:
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs
from utils import download_url

def get_int(b):
    """Interpret the byte string ``b`` as a big-endian unsigned integer."""
    # Hex-encode the bytes first, then parse the digits as a base-16 number.
    hex_digits = codecs.encode(b, 'hex')
    return int(hex_digits, 16)


def read_label_file(path):
    """Parse an uncompressed IDX label file into a 1-D ``torch.LongTensor``.

    The layout read here is: a 4-byte big-endian magic number (must be 2049),
    a 4-byte big-endian item count, then one unsigned byte per label.

    Args:
        path (string): Path to the uncompressed label file.

    Returns:
        torch.LongTensor: Tensor of shape ``(length,)`` holding the labels.

    Raises:
        AssertionError: If the magic number is not 2049.
    """
    # ``raw`` rather than ``data`` so the module alias ``data``
    # (``torch.utils.data``) is not shadowed inside this function.
    with open(path, 'rb') as f:
        raw = f.read()
    assert get_int(raw[:4]) == 2049
    length = get_int(raw[4:8])
    # ``np.frombuffer`` over ``bytes`` yields a read-only view. Converting to
    # int64 makes a writable copy, so ``torch.from_numpy`` never wraps
    # immutable memory, and the result is already a LongTensor.
    parsed = np.frombuffer(raw, dtype=np.uint8, offset=8).astype(np.int64)
    return torch.from_numpy(parsed).view(length)


def read_image_file(path):
    """Parse an uncompressed IDX image file into a 3-D ``torch.ByteTensor``.

    The layout read here is: a 4-byte big-endian magic number (must be 2051),
    then three 4-byte big-endian integers (image count, rows, columns),
    followed by one unsigned byte per pixel.

    Args:
        path (string): Path to the uncompressed image file.

    Returns:
        torch.ByteTensor: Tensor of shape ``(length, num_rows, num_cols)``.

    Raises:
        AssertionError: If the magic number is not 2051.
    """
    # ``raw`` rather than ``data`` so the module alias ``data``
    # (``torch.utils.data``) is not shadowed inside this function.
    with open(path, 'rb') as f:
        raw = f.read()
    assert get_int(raw[:4]) == 2051
    length = get_int(raw[4:8])
    num_rows = get_int(raw[8:12])
    num_cols = get_int(raw[12:16])
    # ``np.frombuffer`` over ``bytes`` yields a read-only view; copy it so the
    # returned tensor owns writable memory instead of aliasing the immutable
    # file buffer.
    parsed = np.frombuffer(raw, dtype=np.uint8, offset=16).copy()
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)
    
class MNIST_Class_Selection(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset restricted to selected digit classes.

    The class filter is applied once, inside :meth:`download`, when the raw
    files are converted into ``processed/training.pt`` and ``processed/test.pt``.
    Instances created afterwards load whatever those files contain.

    Note:
        :meth:`download` returns immediately when the processed files already
        exist, and ``__init__`` loads them without looking at ``class_nums``.
        Delete the ``processed`` folder before asking for a different class
        selection, otherwise the previously written subset is reused.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        class_nums (collection of int, optional): Digit labels to keep when the
            processed files are written. Defaults to all ten digits. When it
            holds exactly two labels, stored targets are remapped to ``-1``
            (the first label in iteration order of ``class_nums``) and ``1``
            (the other label).
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
    class_to_idx = {_class: i for i, _class in enumerate(classes)}

    @property
    def targets(self):
        """Label tensor of the split this instance loaded."""
        if self.train:
            return self.train_labels
        else:
            return self.test_labels

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, class_nums=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        if class_nums is None:
            # ``range`` works on Python 2 and 3 alike; ``xrange`` raises
            # NameError on Python 3.
            self.class_nums = set(range(10))
        else:
            self.class_nums = class_nums

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # NOTE: ``torch.load`` unpickles the file, so ``root`` should only ever
        # point at processed files written by :meth:`download`.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        """Return True when both processed files are present under ``root``."""
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already.

        The raw archives are fetched and decompressed into ``raw_folder``, the
        samples whose label is in ``self.class_nums`` are selected, and the
        result is saved as ``(images, labels)`` tuples in ``processed_folder``.
        """
        import gzip

        if self._check_exists():
            return

        # Create each folder on its own: with both ``makedirs`` calls in one
        # ``try``, an already existing ``raw`` folder raised EEXIST before the
        # ``processed`` folder was ever created.
        for folder in (self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        # download files
        for url in self.urls:
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            download_url(url, root=os.path.join(self.root, self.raw_folder),
                         filename=filename, md5=None)
            # Drop only the trailing ``.gz`` extension; a plain ``replace``
            # would also rewrite a ``.gz`` occurring elsewhere in the path.
            with open(os.path.splitext(file_path)[0], 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )

        # Positions of the samples whose label is one of the requested classes
        # (vectorised membership test instead of a Python call per label).
        wanted = list(self.class_nums)
        indixes_train = np.argwhere(np.isin(training_set[1].numpy(), wanted)).reshape(-1)
        indixes_test = np.argwhere(np.isin(test_set[1].numpy(), wanted)).reshape(-1)

        if len(self.class_nums) == 2:
            # Binary case: the first class in iteration order of ``class_nums``
            # becomes -1, the other one becomes +1.
            nums = list(self.class_nums)
            training_set[1][indixes_train] = torch.LongTensor(np.where(training_set[1][indixes_train] == nums[0], -1, 1))
            test_set[1][indixes_test] = torch.LongTensor(np.where(test_set[1][indixes_test] == nums[0], -1, 1))

        training_set = (training_set[0][indixes_train], training_set[1][indixes_train])
        test_set = (test_set[0][indixes_test], test_set[1][indixes_test])

        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        """Multi-line summary: size, split, root and the configured transforms."""
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


In [39]:

from torchvision import transforms

In [40]:
# Convert each PIL image to a tensor, then normalise its single channel with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Two-class subset (digits 3 and 4). Only the training instance passes
# download=True; the test instance reads the processed files that call wrote.
dataset_m_train = MNIST_Class_Selection(
    '.', train=True, download=True, transform=transform, class_nums={3, 4})
dataset_m_test = MNIST_Class_Selection(
    '.', train=False, transform=transform, class_nums={3, 4})


# dataloader_m_train = DataLoader(dataset_m_train, batch_size=32, shuffle=True)
# dataloader_m_test = DataLoader(dataset_m_test, batch_size=32, shuffle=True)

 15%|█▌        | 1.53M/9.91M [00:00<00:00, 15.2MB/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./raw/train-images-idx3-ubyte.gz


 79%|███████▉  | 7.83M/9.91M [00:00<00:00, 20.7MB/s]Exception KeyError: KeyError(<weakref at 0x7fe4dd98c3c0; to 'tqdm' at 0x7fe4dbeeae90>,) in <bound method tqdm.__del__ of 9.92MB [00:00, 20.7MB/s]> ignored
0.00B [00:00, ?B/s]Exception KeyError: KeyError(<weakref at 0x7fe4dd98c470; to 'tqdm' at 0x7fe4dbef1b10>,) in <bound method tqdm.__del__ of 32.8kB [00:00, 3.47MB/s]> ignored
 88%|████████▊ | 1.46M/1.65M [00:00<00:00, 14.5MB/s]Exception KeyError: KeyError(<weakref at 0x7fe4dd98c470; to 'tqdm' at 0x7fe4dbef1510>,) in <bound method tqdm.__del__ of 1.65MB [00:00, 14.5MB/s]> ignored
0.00B [00:00, ?B/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./raw/t10k-labels-idx1-ubyte.gz
Processing...


Exception KeyError: KeyError(<weakref at 0x7fe4dd98c470; to 'tqdm' at 0x7fe4dda32d90>,) in <bound method tqdm.__del__ of 8.19kB [00:00, 1.40MB/s]> ignored


Done!
torch.Size([11973, 1, 28, 28])


In [41]:
# NOTE(review): as the class is currently written, ``classes`` is the
# class-level list of digit-name strings. The float tensor in the saved output
# below was presumably produced by an earlier revision that assigned a
# per-sample tensor to ``self.classes`` -- confirm by re-running on a fresh
# kernel; per-sample labels are exposed as ``dataset_m_test.targets``.
dataset_m_test.classes

tensor([ 1.,  1., -1.,  ...,  1., -1.,  1.])