# From IN22k, build IN6k (6000 classes not in IN1k) and its downscaled copy miniIN6k

Get all IN22k classes

In [1]:
import os
from glob import glob
# total classes in IN22k folder: 21841
# classes not in IN1k: 20842
# These 20842 classes have a total of 12.906.958 images
IMAGENET22K_DIR = '/datasets01_101/imagenet-22k/062717/'
# glob returns entries in filesystem listing order, which is not guaranteed to
# be stable; sort so every run sees the same order.
all_ = sorted(glob(IMAGENET22K_DIR + '*'))
classes_dir_22k = sorted(glob(IMAGENET22K_DIR + 'n*.tar'))
# Entries of the folder that are not per-class tars. Membership is tested
# against a set (O(1)) instead of the ~21k-element list (O(n) per lookup).
class_tar_set = set(classes_dir_22k)
other_file = [x for x in all_ if x not in class_tar_set]

# Class id = tar file name up to its first '.', e.g. '<dir>/nXXXX.tar' -> 'nXXXX'.
classes_22k = [os.path.basename(x).split('.')[0] for x in classes_dir_22k]
len(classes_22k)

21841

Load the clean image list (automatically de-duplicated) and the set of classes it covers

In [2]:
CLEAN_IMAGE_LIST = 'data/IN6k/clean_images.txt' # 11795291 images filtered by Matthijs
with open(CLEAN_IMAGE_LIST) as f:
    lines = f.readlines()

# Class id = prefix of each image name before the first '_'.
allclasses = [line.split('_')[0] for line in lines]
# sorted() rather than list(): iteration order of a set of strings depends on
# hash randomization (PYTHONHASHSEED). This order later decides how equally
# sized classes are ordered, so it must be deterministic.
clean_classes = sorted(set(allclasses)) # from 21841 keep only 21783
len(clean_classes)

21783

Load the IN1k classes so they can be excluded

Group the images to use later by clean class; images whose class is not a key of this dict are discarded.
- Classes from IN1k are excluded

In [7]:
import torch
# Load the IN1k class names. NOTE(review): presumably a list of class-id
# strings; torch.load unpickles, so only use it on this trusted local file.
IN1k_classes = torch.load('data/IN6k/IN1k_classes.pth') # load IN1k list of classes names
# Test membership against a set (O(1)) instead of scanning IN1k_classes for
# each of ~21.8k clean classes. IN1k_classes itself is left unchanged because
# later cells use it directly.
IN1k_class_set = set(IN1k_classes)
classes_not_in_IN1k = [x for x in clean_classes if x not in IN1k_class_set]
len(classes_not_in_IN1k)

21394

In [8]:
from tqdm import tqdm
# Map each non-IN1k clean class to the list of its clean image names.
# The loop runs once per clean image (~11.8M times), so IN1k membership is
# checked against a set rather than by scanning the IN1k class sequence.
in1k_lookup = set(IN1k_classes)
clean_class_to_im_dict = {c: [] for c in classes_not_in_IN1k}
for c, line in tqdm(zip(allclasses, lines), total=len(lines)):
    if c in in1k_lookup:
        continue
    # readlines() keeps the trailing newline; store the bare image name.
    clean_class_to_im_dict[c].append(line.rstrip('\n'))
# takes time

11795291it [01:08, 171232.20it/s]


Get tar dataset of each class not in IN1k

In [10]:
import os
import shutil
# Sanity probe on the first non-IN1k class only: check that its tar index
# exists *before* copying it, and make sure the destination folder exists
# (shutil.copyfile does not create parent directories).
os.makedirs('data/IN6k/tars', exist_ok=True)
for c in classes_not_in_IN1k[:1]:
    f = f'/checkpoint/matthijs/imagenet-22k/tarindex/{c}.tarlog'
    tarlog_exists = os.path.exists(f)
    print(tarlog_exists)
    if tarlog_exists:
        shutil.copyfile(f, f'data/IN6k/tars/{c}.tarlog')

True


In [5]:
import sys
sys.path.insert(0, '/private/home/sbaio/spicy-lorikeet/')

from imagenet.imagenet22k import TarDataset
IMAGENET22K_DIR = '/datasets01_101/imagenet-22k/062717/'
i22ktarlogs = '/checkpoint/matthijs/imagenet-22k/tarindex/'

# One preloaded TarDataset per non-IN1k class, keyed by class id.
# NOTE(review): this opens ~21k tar indexes with preload=True, yet `tar_dsets`
# is only referenced by commented-out lines further down.
tar_dsets = {}
for cls_id in classes_not_in_IN1k:
    tar_path = IMAGENET22K_DIR + cls_id + '.tar'
    tarlog_path = i22ktarlogs + cls_id + '.tarlog'
    tar_dsets[cls_id] = TarDataset(tar_path, tarlog_path, preload=True)

In [6]:
import matplotlib.pyplot as plt
# Distribution of clean-image counts over non-IN1k classes, sorted ascending.
class_sizes = sorted(len(clean_class_to_im_dict[c]) for c in classes_not_in_IN1k)
fig, ax = plt.subplots()
ax.plot(class_sizes)
ax.set(title='Clean images per non-IN1k class (sorted)',
       xlabel='class rank', ylabel='# clean images');

[<matplotlib.lines.Line2D at 0x7fe51d53be10>]

In [9]:
# (class, image list) pairs, ordered by ascending number of clean images.
sorted_classes_not_in_IN1k = sorted(
    clean_class_to_im_dict.items(),
    key=lambda item: len(item[1]),
)
# Number of clean images in the largest class.
len(sorted_classes_not_in_IN1k[-1][1])

2248

In [10]:
# Minimum number of clean images for a class to be a candidate. It is 900
# rather than 1000 so that enough classes qualify to pick 6k; the variable
# name below is kept because later cells use it.
MIN_IMAGES_PER_CLASS = 900
# Sort by descending image count and break ties by class id. Without the
# tie-break, equally sized classes keep whatever order the upstream class list
# had, and the 6000-class cutoff below may then differ between runs.
classes_with_more_than_1k_images = sorted(
    ((c, ims) for c, ims in clean_class_to_im_dict.items() if len(ims) >= MIN_IMAGES_PER_CLASS),
    key=lambda item: (-len(item[1]), item[0]),
)
len(classes_with_more_than_1k_images)

6056

In [11]:
# Image count of the last class kept by the 6000-class cutoff.
cutoff_class, cutoff_images = classes_with_more_than_1k_images[:6000][-1]
len(cutoff_images)

907

In [12]:
largest_6k_classes = classes_with_more_than_1k_images[:6000]
# Slicing silently returns fewer items when not enough classes passed the
# threshold; fail loudly instead.
assert len(largest_6k_classes) == 6000, len(largest_6k_classes)
# Sizes of the largest and smallest selected classes.
print(len(largest_6k_classes[0][1]), len(largest_6k_classes[-1][1]))

2248 907


Double-check that the selected classes do not intersect the IN1k classes

In [13]:
# The selected classes must not include any IN1k class: assert it, and still
# display the (expected empty) overlap.
overlap_with_IN1k = set(name for name, _ims in largest_6k_classes).intersection(IN1k_classes)
assert not overlap_with_IN1k, overlap_with_IN1k
overlap_with_IN1k

set()

In [14]:
# Expected size of IN6k excluding IN1k
# (compared to 7791368 before removing duplicates)
sum(len(image_names) for _cls, image_names in largest_6k_classes)

7135116

Build a dict mapping each class name to its image list. Do not use the TarDataset image list (`tar_dataset.names`), because it also contains non-clean images.

In [15]:
# class id -> list of clean image names, for the 6000 selected classes.
largest_6k_classes_dict = {cls_id: image_names for cls_id, image_names in largest_6k_classes}

In [16]:
# Module-level setup for the copy step. copy_class (defined below) re-imports
# and re-defines the names it uses inside its own body.
import sys
sys.path.insert(0, '/private/home/sbaio/spicy-lorikeet/')
from imagenet.imagenet22k import TarDataset

# Source locations: per-class tars and their tar indexes.
IMAGENET22K_DIR = '/datasets01_101/imagenet-22k/062717/'
i22ktarlogs = '/checkpoint/matthijs/imagenet-22k/tarindex/'

import os
import time
from glob import glob

import cv2
import numpy as np

# In-memory byte buffer (Python 2 StringIO, else Python 3 BytesIO).
try:
    from StringIO import StringIO as DataIO
except ImportError:
    from io import BytesIO as DataIO

from PIL import Image
from torchvision import transforms as T

# Resize to shorter side 256 (full IN6k) and to a fixed 84x84 (miniIN6k).
IN_resize = T.Resize(256)
mini_resize = T.Resize((84, 84))

image_size = 84
# Output roots.
IN6k_dst = '/checkpoint/sbaio/IN6k2/'
miniIN6k_dst = '/checkpoint/sbaio/miniIN6k2/'

def copy_class(c, largest_6k_classes_dict):
    """Extract, resize and save every selected clean image of one class.

    Imports and constants are resolved inside the body, so the function does
    not depend on notebook globals.

    For each image name in ``largest_6k_classes_dict[c]``:
    - fetch its bytes via ``TarDataset.get_name`` from ``<IMAGENET22K_DIR>/<c>.tar``;
    - decode and convert to RGB (a black 256x256 image replaces undecodable data);
    - resize with ``T.Resize(256)`` and save under ``<IN6k_dst>/<c>/``;
    - resize that result to 84x84 and save under ``<miniIN6k_dst>/<c>/``.

    Images whose two outputs both already exist are skipped, so re-running the
    function resumes an interrupted class.

    Args:
        c: class id; also the tar/tarlog file stem and the output sub-folder name.
        largest_6k_classes_dict: mapping class id -> list of image names.

    Returns:
        Summary string with elapsed seconds, number of images processed and class id.
    """
    import os
    import sys
    import time

    from PIL import Image
    from torchvision import transforms as T

    try:
        from StringIO import StringIO as DataIO
    except ImportError:
        from io import BytesIO as DataIO

    sys.path.insert(0, '/private/home/sbaio/spicy-lorikeet/')
    from imagenet.imagenet22k import TarDataset

    IMAGENET22K_DIR = '/datasets01_101/imagenet-22k/062717/'
    i22ktarlogs = '/checkpoint/matthijs/imagenet-22k/tarindex/'
    IN6k_dst = '/checkpoint/sbaio/IN6k2/'
    miniIN6k_dst = '/checkpoint/sbaio/miniIN6k2/'

    # Built once per call instead of once per image.
    IN_resize = T.Resize(256)
    mini_resize = T.Resize((84, 84))

    def _save_atomically(im, path):
        # Write to a temp file, then rename into place. A job killed mid-write
        # (e.g. on timeout) then never leaves a truncated file at `path` that
        # the "already exists" check below would skip on every re-run.
        tmp_path = path + '.tmp'
        im.save(tmp_path, format='JPEG')
        os.replace(tmp_path, path)

    def process_im(tar_dset, c, imname, dst_dir, minidst_dir):
        # Output name: image name up to its first '.', with a .jpg extension.
        stem = imname.split('.')[0]
        dst_file = os.path.join(dst_dir, c, stem + '.jpg')
        minidst_file = os.path.join(minidst_dir, c, stem + '.jpg')
        if os.path.exists(dst_file) and os.path.exists(minidst_file):
            return
        data = tar_dset.get_name(imname)
        try:
            # Image.open only parses the header; convert() forces full
            # decoding, so it must be inside the try to catch corrupt data.
            im = Image.open(DataIO(data)).convert('RGB')
        except Exception as e:
            print("Error im %s, %s" % (imname, e))
            # Best effort: keep a placeholder so the sample still exists.
            im = Image.new('RGB', (256, 256))
        im = IN_resize(im)
        _save_atomically(im, dst_file)

        im = mini_resize(im)
        _save_atomically(im, minidst_file)

    start = time.time()
    tar_dataset = TarDataset(IMAGENET22K_DIR + c + '.tar',
                             i22ktarlogs + c + '.tarlog',
                             preload=True)
    for dst_root in [IN6k_dst, miniIN6k_dst]:
        # makedirs+exist_ok: idempotent on re-runs and creates missing parents.
        os.makedirs(os.path.join(dst_root, c), exist_ok=True)

    # Counted explicitly so an empty image list reports 0 instead of raising.
    n_images = 0
    for imname in largest_6k_classes_dict[c]:
        process_im(tar_dataset, c, imname, dst_dir=IN6k_dst, minidst_dir=miniIN6k_dst)
        n_images += 1
    msg = 'Took {:0.2f}, Copied {} images of class {}'.format(time.time() - start, n_images, c)
    return msg

In [18]:
# Class ids of the selected classes, in selection order (largest first).
classes_to_copy = [cls_id for cls_id, _images in largest_6k_classes]
len(classes_to_copy)

6000

In [None]:
# NOTE(review): commented-out thread-pool alternative to the submitit launcher
# below, kept for reference only. It is stale: copy_class takes a second
# argument (largest_6k_classes_dict) that this executor.submit call does not pass.
# from concurrent import futures
# import warnings
# warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
# N = 0
# print('Copying {} classes'.format(len(classes_to_copy)))
# # We can use a with statement to ensure threads are cleaned up promptly
# with futures.ThreadPoolExecutor(max_workers=100) as executor:
#     # Start the load operations and mark each future with its arg
#     future_to_url = {executor.submit(copy_class, i): i for i in classes_to_copy}
#     for future in futures.as_completed(future_to_url):
#         url = future_to_url[future]
#         try:
#             ret = future.result()
#         except Exception as exc:
#             print('%r generated an exception: %s' % (url, exc))
#         else:
#             N += 1
#             print('{} : {}'.format(N, ret))

In [None]:
import time
from collections import Counter

import submitit

# Launch one copy_class job per selected class, keeping at most
# max_running_jobs active at once, and stop submitting once any job fails.
executor = submitit.AutoExecutor(folder='/checkpoint/sbaio/jobs/copy_miniIN/%j')
# NOTE(review): copy_class only decodes/resizes images; the GPU request looks
# unnecessary — confirm whether the partition requires it.
executor.update_parameters(timeout_min=100, partition='learnfair', constraint="volta",
                           tasks_per_node=1, gpus_per_node=1, mem_gb=100,
                           cpus_per_task=10, nodes=1, signal_delay_s=120)

# Keep the job list from a previous run of this cell (so those jobs can still
# be monitored/cancelled), but start empty on a fresh kernel instead of
# raising NameError.
if 'jobs' not in globals():
    jobs = []
max_running_jobs = 200
ACTIVE_STATES = ('RUNNING', 'PENDING', 'UNKNOWN')
POLL_INTERVAL_S = 0.1


def _job_state_counts(job_list):
    """Return a Counter of job.state values, reading each job's state once."""
    return Counter(job.state for job in job_list)


i = 0
for c in classes_to_copy[::-1]:
    counts = _job_state_counts(jobs)
    # Throttle until a slot frees up. States are re-read on every poll:
    # computing them only once before the wait would make it spin forever.
    while counts['FAILED'] == 0 and sum(counts[s] for s in ACTIVE_STATES) >= max_running_jobs:
        time.sleep(POLL_INTERVAL_S)
        i += 1
        if i % 10 == 0:
            print('Waiting for some jobs to finish, completed {}'.format(counts['COMPLETED']))
        counts = _job_state_counts(jobs)

    if counts['FAILED'] > 0:
        print('Some jobs failed')
        break

    job = executor.submit(copy_class, c, largest_6k_classes_dict)
    jobs.append(job)
    print('Launched job: {}'.format(job))

In [22]:
# Cancel every job submitted by the launcher cell.
for submitted_job in jobs:
    submitted_job.cancel()

In [42]:
# Marker that the copy stage is finished.
print("Done")

Done


### Compute mean and std of the resulting dataset

In [None]:
import time

import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T

from datasets import get_dataset

# Per-channel mean/std of miniIN6k train images after ToTensor only
# (ToTensor maps 8-bit images to [0, 1]).
dset = get_dataset('miniIN6k', 'train', no_transform=True)
transform = T.Compose([T.ToTensor()])
dset.transform = transform

loader = DataLoader(
    dset,
    batch_size=2000,
    num_workers=80,
    shuffle=False,
)

# Accumulate per-channel sums of x and x^2 over *all pixels*, in float64
# (billions of pixels would lose precision in float32). Averaging each
# image's own std, as before, ignores the variance between image means and
# therefore is not the dataset std.
channel_sum = None
channel_sq_sum = None
nb_samples = 0
nb_pixels = 0
print(len(loader))
start = time.time()
for i, (batch, _labels) in enumerate(loader):
    batch_samples = batch.size(0)
    # (B, C, H, W) -> (B, C, H*W), same flattening as before.
    flat = batch.view(batch_samples, batch.size(1), -1).double()
    if channel_sum is None:
        channel_sum = torch.zeros(flat.size(1), dtype=torch.float64)
        channel_sq_sum = torch.zeros(flat.size(1), dtype=torch.float64)
    channel_sum += flat.sum(dim=(0, 2))
    channel_sq_sum += flat.pow(2).sum(dim=(0, 2))
    nb_samples += batch_samples
    nb_pixels += batch_samples * flat.size(2)
    if i % 10 == 0:
        running_mean = channel_sum / nb_pixels
        running_std = (channel_sq_sum / nb_pixels - running_mean ** 2).clamp_min(0).sqrt()
        print(i, '{:0.2f}'.format(time.time() - start))
        print(running_mean, running_std)
        start = time.time()

if nb_pixels == 0:
    raise ValueError('Empty dataset: cannot compute mean/std')
# Population statistics over all pixels; E[x^2] - E[x]^2 is clamped at 0 to
# absorb floating-point rounding before the square root.
mean64 = channel_sum / nb_pixels
mean = mean64.float()
std = (channel_sq_sum / nb_pixels - mean64 ** 2).clamp_min(0).sqrt().float()
print(mean)
print(std)

In [2]:
# NOTE(review): the stored output below has negative means and stds > 1. A
# ToTensor-only pipeline (values in [0, 1]) cannot produce these, so it was
# presumably recorded from a run with a different transform; re-run to refresh.
print(mean, std, sep='\n')

tensor([-0.2158, -0.2202, -0.1855])
tensor([1.1501, 1.1323, 1.0899])


### Cache it

In [43]:
# Index the cleaned miniIN6k folder (one sub-folder per class) with ImageFolder.
# NOTE(review): the copy jobs above write to /checkpoint/sbaio/miniIN6k2/, not
# miniIN6k_clean/ — confirm the folder was renamed/filtered in between.
from torchvision.datasets import ImageFolder

MINIIN6K_ROOT = '/checkpoint/sbaio/miniIN6k_clean/'
miniIN6k = ImageFolder(MINIIN6K_ROOT)

In [44]:
# Number of images and number of classes found in the folder.
n_images, n_classes = len(miniIN6k), len(miniIN6k.classes)
print(n_images, n_classes)

7135116 6000


In [45]:
dset = miniIN6k
# ImageFolder index (class list, class->index map, sample list); presumably
# cached so the directory tree does not have to be rescanned.
tosave = dict(
    classes=dset.classes,
    class_to_idx=dset.class_to_idx,
    samples=dset.samples,
)
# NOTE(review): the save below is commented out, so the "Saved cache." output
# stored for this cell is stale.
# torch.save(tosave, '/private/home/sbaio/.cache/miniIN6k_clean.bin')
# print('Saved cache.')

Saved cache.
