In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tarfile
import os
import pickle as pkl
import numpy as np
import skimage
import skimage.io
import skimage.transform
import tensorflow as tf
from sklearn.model_selection import train_test_split

In [2]:
# Location of the BSDS500 archive that supplies the background images.
BST_PATH = 'BSR_bsds500.tgz'

# Seeded generator used when picking a background image for each digit.
rand = np.random.RandomState(seed=42)

In [3]:
# Keep the archive handle open: the image-loading cell reads members through `f`.
f = tarfile.open(BST_PATH)

# Every archive member that lives under the BSDS500 training-image folder.
train_files = [
    member
    for member in f.getnames()
    if member.startswith('BSR/BSDS500/data/images/train/')
]

In [4]:
# Echo the collected member names to confirm the train-folder filter matched.
train_files

['BSR/BSDS500/data/images/train/97017.jpg',
 'BSR/BSDS500/data/images/train/124084.jpg',
 'BSR/BSDS500/data/images/train/170054.jpg',
 'BSR/BSDS500/data/images/train/249061.jpg',
 'BSR/BSDS500/data/images/train/216066.jpg',
 'BSR/BSDS500/data/images/train/22093.jpg',
 'BSR/BSDS500/data/images/train/43083.jpg',
 'BSR/BSDS500/data/images/train/254033.jpg',
 'BSR/BSDS500/data/images/train/94079.jpg',
 'BSR/BSDS500/data/images/train/159091.jpg',
 'BSR/BSDS500/data/images/train/166081.jpg',
 'BSR/BSDS500/data/images/train/130034.jpg',
 'BSR/BSDS500/data/images/train/55075.jpg',
 'BSR/BSDS500/data/images/train/106020.jpg',
 'BSR/BSDS500/data/images/train/189011.jpg',
 'BSR/BSDS500/data/images/train/61060.jpg',
 'BSR/BSDS500/data/images/train/41004.jpg',
 'BSR/BSDS500/data/images/train/95006.jpg',
 'BSR/BSDS500/data/images/train/198054.jpg',
 'BSR/BSDS500/data/images/train/67079.jpg',
 'BSR/BSDS500/data/images/train/246053.jpg',
 'BSR/BSDS500/data/images/train/370036.jpg',
 'BSR/BSDS500/data/

In [5]:
print('Loading BSR training images')
# Decode every training member into an array; members that cannot be read as
# an image are skipped (best effort), but the skip is reported instead of
# being silently swallowed.
background_data = []
for name in train_files:
    try:
        fp = f.extractfile(name)
        bg_img = skimage.io.imread(fp)
    except Exception as exc:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt
        # and SystemExit still stop the cell.
        print('Skipping', name, '-', exc)
        continue
    background_data.append(bg_img)

Loading BSR training images


In [8]:
# Echo the decoded background images.
# NOTE(review): this prints whole pixel arrays; a summary such as
# len(background_data) or the per-image shapes would be easier to read.
background_data

[array([[[ 41,  46,  39],
         [132, 137, 141],
         [215, 220, 240],
         ...,
         [159, 174, 197],
         [159, 174, 197],
         [157, 172, 193]],
 
        [[ 45,  52,  45],
         [162, 167, 173],
         [208, 213, 233],
         ...,
         [160, 175, 198],
         [159, 174, 197],
         [158, 173, 194]],
 
        [[ 57,  63,  59],
         [193, 200, 208],
         [204, 210, 232],
         ...,
         [161, 176, 199],
         [160, 175, 198],
         [157, 172, 193]],
 
        ...,
 
        [[  6,   9,  14],
         [  5,  14,   0],
         [ 36,  53,   9],
         ...,
         [ 64,  87,  45],
         [ 66,  89,  47],
         [ 60,  84,  36]],
 
        [[ 10,  10,  18],
         [  3,   9,   0],
         [  3,  18,   0],
         ...,
         [ 53,  78,  36],
         [ 44,  69,  27],
         [ 58,  82,  34]],
 
        [[ 10,  12,   7],
         [  7,  12,   8],
         [  3,  11,   0],
         ...,
         [ 41,  72,  28],
  

In [9]:
def compose_image(digit, background, rng=None):
    """Difference-blend a digit with a random patch from a background image.

    A patch with the digit's height and width is cut from a random position
    in ``background`` and the element-wise absolute difference
    ``|patch - digit|`` is returned.

    Parameters
    ----------
    digit : ndarray
        3-D array ``(rows, cols, channels)``.
    background : ndarray
        3-D array at least as large as ``digit`` along the first two axes.
    rng : object with ``randint(low, high)``, optional
        Source of the random patch offset (e.g. ``np.random.RandomState``).
        Defaults to the global ``np.random`` state.

    Returns
    -------
    ndarray of ``np.uint8`` with the same first two dimensions as ``digit``.

    Raises
    ------
    ValueError
        If ``background`` is smaller than ``digit`` along either axis.
    """
    if rng is None:
        rng = np.random

    bg_rows, bg_cols, _ = background.shape
    d_rows, d_cols, _ = digit.shape
    if d_rows > bg_rows or d_cols > bg_cols:
        raise ValueError(
            'background %r is smaller than digit %r'
            % (background.shape, digit.shape))

    # randint's upper bound is exclusive: the +1 keeps the last valid offset
    # reachable and allows a background exactly the size of the digit.
    top = rng.randint(0, bg_rows - d_rows + 1)
    left = rng.randint(0, bg_cols - d_cols + 1)
    patch = background[top:top + d_rows, left:left + d_cols]

    # uint8 - uint8 wraps modulo 256. int16 holds the full -255..255 range
    # and still promotes to float32 when the other operand is float32.
    if patch.dtype == np.uint8:
        patch = patch.astype(np.int16)
    if digit.dtype == np.uint8:
        digit = digit.astype(np.int16)

    return np.abs(patch - digit).astype(np.uint8)


def mnist_to_img(x):
    """Turn one MNIST digit into a binarized three-channel image.

    Every strictly positive pixel becomes 255 and every other pixel 0; the
    single 28x28 plane is then repeated along a new last axis, giving a
    float32 array of shape (28, 28, 3).
    """
    mask = (x > 0).astype(np.float32).reshape([28, 28, 1])
    plane = mask * 255
    return np.concatenate([plane] * 3, axis=2)


def create_mnistm(X):
    """
    Given an array of MNIST digits, blend random background patches to
    build the MNIST-M dataset as described in
    http://jmlr.org/papers/volume17/15-239/15-239.pdf

    Parameters
    ----------
    X : ndarray
        Digits indexed along the first axis; each ``X[i]`` is handed to
        ``mnist_to_img``.

    Returns
    -------
    ndarray of ``np.uint8`` with shape ``(X.shape[0], 28, 28, 3)``.

    Notes
    -----
    The background image is drawn with the module-level ``rand`` generator.
    ``compose_image`` draws the patch offset from the global ``np.random``
    state by default, which ``rand`` does not seed.
    """
    n_examples = X.shape[0]
    X_ = np.zeros([n_examples, 28, 28, 3], np.uint8)
    for i in range(n_examples):

        if i % 1000 == 0:
            print('Processing example', i)

        # Draw an index rather than an element: RandomState.choice converts
        # its argument to an array and requires it to be 1-D, which breaks
        # for a list of images (ragged shapes, or a stacked 4-D array).
        bg_img = background_data[rand.randint(len(background_data))]

        d = mnist_to_img(X[i])
        X_[i] = compose_image(d, bg_img)

    return X_

In [11]:
# Fetch the raw MNIST digits; the labels are loaded but not used in this cell.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

print('Building train set...')
train = create_mnistm(x_train)
print('Building test set...')
test = create_mnistm(x_test)

# Save dataset as pickle.
# Use a dedicated handle name: `f` is the open BSDS500 tarfile that the
# image-loading cell reads from, and rebinding it here would break a re-run
# of that cell.
with open('mnistm.pkl', 'wb') as pkl_file:
    pkl.dump({ 'train': train, 'test': test}, pkl_file, pkl.HIGHEST_PROTOCOL)


Building train set...
Processing example 0
Processing example 1000
Processing example 2000
Processing example 3000
Processing example 4000
Processing example 5000
Processing example 6000
Processing example 7000
Processing example 8000
Processing example 9000
Processing example 10000
Processing example 11000
Processing example 12000
Processing example 13000
Processing example 14000
Processing example 15000
Processing example 16000
Processing example 17000
Processing example 18000
Processing example 19000
Processing example 20000
Processing example 21000
Processing example 22000
Processing example 23000
Processing example 24000
Processing example 25000
Processing example 26000
Processing example 27000
Processing example 28000
Processing example 29000
Processing example 30000
Processing example 31000
Processing example 32000
Processing example 33000
Processing example 34000
Processing example 35000
Processing example 36000
Processing example 37000
Processing example 38000
Processing examp

In [12]:
# Shape of the raw MNIST training images returned by load_data().
x_train.shape

(60000, 28, 28)

In [13]:
# Shape of the raw MNIST test images returned by load_data().
x_test.shape

(10000, 28, 28)