In [2]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.models as models
from torch import nn
from torch import optim

from tqdm import tqdm
import os
import numpy as np
from PIL import Image
import pickle as pkl
import matplotlib.pyplot as plt

import util
from DuckDataset import DuckDataset

%load_ext autoreload
%autoreload 2

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-9vk1ybjf because the default path (/home/jovyan/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [3]:
# run on the GPU whenever CUDA is usable, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [4]:
# Locations of the input and output data; adapt these to the local machine.
# The ImageNet train/val directories are expected to contain one sub-folder
# per class (the generation cell lists `<dir>/<class folder>/<image file>`).
imagenettrain_dir = '/mnt/qb/datasets/ImageNet2012/train/'
imagenetval_dir = '/mnt/qb/datasets/ImageNet2012/val/'
# directory (relative to the notebook's working directory) that receives all
# pickled background images and duck datasets produced below
duckdata_dir = 'data'

In [4]:
# generate background images from a subset of imagenet classes

def _load_rgb_images(root_dir, class_folders):
    """Load every image found in ``root_dir/<folder>`` as an RGB PIL image.

    Args:
        root_dir: directory that holds one sub-folder per class.
        class_folders: names of the class sub-folders to read.

    Returns:
        A list of fully decoded, in-memory ``PIL.Image`` objects in RGB mode,
        ordered by class folder and then by file name.
    """
    paths = []
    for folder in class_folders:
        class_dir = os.path.join(root_dir, folder)
        # os.listdir order depends on the file system; sort it so that the
        # pickled image order (and anything seeded on top of it) is stable
        for file in sorted(os.listdir(class_dir)):
            paths.append(os.path.join(class_dir, file))

    rgb_images = []
    for path in tqdm(paths):
        # convert() decodes the file and returns an independent copy, so the
        # file handle can be released right away and only one copy of each
        # image is kept in memory
        with Image.open(path) as im:
            rgb_images.append(im.convert('RGB'))
    return rgb_images


# pick 10 classes reproducibly: sorted folder list + fixed seed
folders = sorted(os.listdir(imagenettrain_dir))
np.random.seed(0)
folders = np.random.choice(folders, size=10, replace=False)

# make sure the output directory exists before writing into it
os.makedirs(duckdata_dir, exist_ok=True)

# train
pil_images = _load_rgb_images(imagenettrain_dir, folders)
with open(f'{duckdata_dir}/imagenet10_train.pkl', 'wb') as fh:
    pkl.dump(pil_images, fh)

# val (same class folders as for the training split)
pil_images = _load_rgb_images(imagenetval_dir, folders)
with open(f'{duckdata_dir}/imagenet10_val.pkl', 'wb') as fh:
    pkl.dump(pil_images, fh)

"\n\n# subset of classes\nfolders = []\nfor folder in os.listdir('/mnt/qb/datasets/ImageNet2012/train/'):\n    folders.append(folder)\nfolders.sort()\nnp.random.seed(0)\nfolders = np.random.choice(folders, size=10, replace=False)\n\n# train\nimages = []\nfor folder in folders:\n    for file in os.listdir(os.path.join('/mnt/qb/datasets/ImageNet2012/train/', folder)):\n        images.append(os.path.join('/mnt/qb/datasets/ImageNet2012/train/', folder, file))  \npil_images = []\nfor image in tqdm(images):\n    im = Image.open(image)\n    im.load()\n    pil_images.append(im)\npil_images = [img.convert('RGB') for img in pil_images]\npkl.dump(pil_images, open('/mnt/qb/luxburg/frieder/imagenet10_train.pkl' , 'wb+'))\n\n# val\nimages = []\nfor folder in folders:\n    for file in os.listdir(os.path.join('/mnt/qb/datasets/ImageNet2012/val/', folder)):\n        images.append(os.path.join('/mnt/qb/datasets/ImageNet2012/val/', folder, file))  \npil_images = []\nfor image in tqdm(images):\n    im = I

In [5]:
# dataset class for the background images
class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory sequence of background images.

    Every item is returned as ``(image, 0)``: the (optionally transformed)
    image together with the constant label 0.

    Args:
        images: indexable collection of images (anything supporting
            ``len()`` and ``[]``).
        transform: optional callable applied to an image before it is
            returned; ``None`` returns the stored image unchanged.
    """

    def __init__(self,
                 images,
                 transform=None,
                 ):
        super().__init__()
        self.images = images
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, 0)`` for ``index``.

        Raises:
            IndexError: if ``index`` is past the end of the dataset. An
                ``IndexError`` (rather than an ``assert``) is what Python's
                sequence iteration protocol expects, so ``for x in dataset``
                terminates cleanly, and the check also survives ``python -O``.
        """
        if index >= len(self.images):
            raise IndexError('Invalid index!')
        # get the image
        img = self.images[index]
        # apply transform
        if self.transform is not None:
            img = self.transform(img)
        return img, 0

    def __len__(self):
        """Number of stored images."""
        return len(self.images)

In [7]:
# load background images
# NOTE: unpickling can execute arbitrary code; only load files that were
# written by the generation cell of this notebook.
with open(f'{duckdata_dir}/imagenet10_train.pkl', 'rb') as fh:
    train_images = pkl.load(fh)
with open(f'{duckdata_dir}/imagenet10_val.pkl', 'rb') as fh:
    val_images = pkl.load(fh)

# both splits get the same random augmentation (random 224-pixel resized crop
# followed by a random horizontal flip); the validation backgrounds are
# therefore also randomly augmented
background_train = ImageDataset(
    train_images,
    transform=transforms.Compose([transforms.RandomResizedCrop(224),
                                  transforms.RandomHorizontalFlip()]))
background_val = ImageDataset(
    val_images,
    transform=transforms.Compose([transforms.RandomResizedCrop(224),
                                  transforms.RandomHorizontalFlip()]))


In [8]:
# generate training and validation set containing ducks

# channel-wise mean/std commonly used for ImageNet-pretrained torchvision models
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# passed straight through to DuckDataset; presumably the (min, max) size of
# the pasted duck -- confirm against DuckDataset
random_size = (50, 100)

# the same tensor conversion + normalisation is applied to both splits
to_normalized_tensor = transforms.Compose([transforms.ToTensor(), normalize])

trainset = DuckDataset(background_train,
                       random_size=random_size,
                       transform=to_normalized_tensor,
                       uniform_yellow=False
                       )
valset = DuckDataset(background_val,
                     random_size=random_size,
                     transform=to_normalized_tensor,
                     uniform_yellow=False
                     )

# context managers make sure the pickles are flushed and closed
with open(f'{duckdata_dir}/duck_train.pkl', 'wb') as fh:
    pkl.dump(trainset, fh)
with open(f'{duckdata_dir}/duck_val.pkl', 'wb') as fh:
    pkl.dump(valset, fh)


In [12]:
# generate a (reproducible) visualisation data set
# seed torch and numpy before the data set is built; the draws made from
# vis_set afterwards continue from this random state, so re-run this cell
# before re-running any cell that samples from vis_set
torch.manual_seed(1)
np.random.seed(1)

# no normalisation here: the images are only converted to tensors
vis_set = DuckDataset(background_val,
                      random_size=random_size,
                      transform=transforms.ToTensor(),
                      uniform_yellow=False
                      )

# context manager makes sure the pickle is flushed and closed
with open(f'{duckdata_dir}/duck_vis.pkl', 'wb') as fh:
    pkl.dump(vis_set, fh)

In [13]:
# extract images containing a duck

imgs, labels, duck_positions, masks = [], [], [], []
for _ in range(200):
    img, label, duck_position, bgr = vis_set.__draw_random__()
    # draws with label 0 are skipped, only the remaining ones are kept
    if label == 0:
        continue

    # binary mask of the pixels where the image differs from its background:
    # binarise every channel first so that differences of opposite sign in
    # different channels cannot cancel out when the channels are summed
    mask = img - bgr
    mask[mask != 0] = 1
    mask = mask.sum(axis=0)
    mask[mask != 0] = 1

    imgs.append(img)
    labels.append(label)
    duck_positions.append(duck_position)
    masks.append(mask)

# build the tuple once, after the loop, so it is always defined (even when no
# draw was kept) and never a stale value from an earlier run
examples = (imgs, labels, duck_positions, masks)
with open(f'{duckdata_dir}/duck_vis_examples.pkl', 'wb') as fh:
    pkl.dump(examples, fh)