Editing dataloader

In [1]:
import csv
import tarfile
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader


def tar_loader(filename):
    "Load image from tar for YFCC100M"
    root_dir = filename.parents[1]
    tarname = filename.parent.name
    image = filename.name
    
    filename = root_dir / (tarname + '.tar')
    reader = tarfile.open(filename)
    fid = reader.extractfile(f'{tarname}/{image}')
    assert fid is not None
    img = Image.open(fid)
    return img.convert('RGB')


class ImageFromCSV(Dataset):
    """Load images from a CSV list.

    It is a replacement for ImageFolder when you are interested in a
    particular set of images. Indeed, the only different is the way the images
    and targets are gotten.

    Args:
        filename (str, optional): CSV file with list of images to read.
        root (str, optional) : files in filename are a relative path with
            respect to the dirname here. It reduces size of CSV but not in
            memory.
        fields (sequence, optional): sequence with field names associated for
            image paths and targets, respectively. If not provided, it uses
            the first two fields in the first row.
    """

    def __init__(self, filename, root='', fields=None, transform=None,
                 target_transform=None, loader=default_loader):
        self.root = Path(root)
        self.filename = filename
        self.fields = fields
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.imgs = self._make_dataset()

    def _make_dataset(self):
        with open(self.filename, 'r') as fid:
            reader = csv.DictReader(fid)

            if self.fields is None:
                self.fields = reader.fieldnames
            else:
                check = [i in reader.fieldnames for i in self.fields]
                if not all(check):
                    raise ValueError(f'Missing fields in {self.filename}')

            imgs = []
            for i, row in enumerate(reader):
                img_name = row[self.fields[0]]
                path = self.root / img_name

                target = 0
                if len(self.fields) > 1:
                    target = row[self.fields[1]]

                imgs.append((path, target))
            return imgs

    def __getitem__(self, index):
        """Return item

        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target
                   class.
        """
        path, target = self.imgs[index]
        target = target
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.imgs)
    
filename = 'tmp.csv'
root = '/mnt/ilcompf2d1/data/yfcc100m/image'
aja = ImageFromCSV(filename, root, ['adobe_cil'], loader=tar_loader)
aja[0]

(<PIL.Image.Image image mode=RGB size=500x334 at 0x7F298358D5F8>, 0)

In [6]:
# !head yfcc100m_images_intersect_didemo_val_nouns_counter_eq_1_sample.csv
# !grep "4-000/0009966.jpg" yfcc100m_images_intersect_didemo_val_nouns_counter_eq_1_sample.csv 
# !grep "4-002/0021772.jpg" yfcc100m_images_intersect_didemo_val_nouns_counter_eq_1_sample.csv 
# !grep "4-002/0029950.jpg" yfcc100m_images_intersect_didemo_val_nouns_counter_eq_1_sample.csv
# !grep "4-006/0069517.jpg" yfcc100m_images_intersect_didemo_val_nouns_counter_eq_1_sample.csv

628,4-006/0069517.jpg,http://farm2.staticflickr.com/1425/832852758_f7d540ed82.jpg,canada,barcello;canada;montreal,canada,barcellos;canada;montreal


In [7]:
class AverageMeter(object):
    """Computes and stores the average and current value

    Credits:
        @pytorch team - Imagenet example
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


import time
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

filename = 'yfcc100m_images_intersect_didemo_val_nouns_counter_eq_1_sample.csv'
root = '/mnt/ilcompf2d1/data/yfcc100m/image'
fields = ['adobe_cil']
batch_size = 256
num_workers = 8
PIN_MEMORY = True

resize_transform = [transforms.Resize((224, 224))]
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
img_transform = transforms.Compose(
        resize_transform + [transforms.ToTensor(), normalize])

img_loader = DataLoader(
    dataset=ImageFromCSV(filename, root, fields=fields,
                         transform=img_transform,
                         loader=tar_loader),
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=PIN_MEMORY)
in_time = AverageMeter()
end = time.time()
for i, (img, _) in enumerate(img_loader):
    img_d = img.to('cuda:0')
    in_time.update(time.time() - end)
    if i % 10:
        print(f'[{i + 1}/{len(img_loader)}]\t'
              f'Data-in {in_time.val:.3f} ({in_time.avg:.3f})\t')
    if i == 100:
        break

[2/1954]	Data-in 208.819 (208.812)	
[3/1954]	Data-in 216.845 (211.490)	


KeyError: 'Traceback (most recent call last):\n  File "/mnt/ilcompf9d1/user/escorcia/install/bin/miniconda3/envs/scientific/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in _worker_loop\n    samples = collate_fn([dataset[i] for i in batch_indices])\n  File "/mnt/ilcompf9d1/user/escorcia/install/bin/miniconda3/envs/scientific/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in <listcomp>\n    samples = collate_fn([dataset[i] for i in batch_indices])\n  File "<ipython-input-1-48a438122394>", line 84, in __getitem__\n    sample = self.loader(path)\n  File "<ipython-input-1-48a438122394>", line 17, in tar_loader\n    fid = reader.extractfile(f\'{tarname}/{image}\')\n  File "/mnt/ilcompf9d1/user/escorcia/install/bin/miniconda3/envs/scientific/lib/python3.6/tarfile.py", line 2074, in extractfile\n    tarinfo = self.getmember(member)\n  File "/mnt/ilcompf9d1/user/escorcia/install/bin/miniconda3/envs/scientific/lib/python3.6/tarfile.py", line 1750, in getmember\n    raise KeyError("filename %r not found" % name)\nKeyError: "filename \'4-009/0098454.jpg\' not found"\n'