In [None]:
# Input TFRecord of training examples; it is read twice below (labels only, then raw records).
TRAIN_TFRECORD = '../data/cifar10-train.tfrecord'
# How the per-class quotas of the unlabeled subset decay; only 'linear' is implemented.
SCALING = 'linear'
# Ratio between the largest and the smallest per-class quota (linspace from 1 down to 1/FRACTION).
FRACTION = 5
# Target total size of the long-tailed unlabeled subset (per-class quotas are truncated to ints,
# so the realised total can be slightly smaller).
TOTAL_SAMPLES = 30000
# Colon-separated total sizes of the labeled splits; each total is divided evenly across classes.
LABELED_SAMPLES = '10:20:30:40:100:250:1000:4000'
# Seed for NumPy's global RNG (class-to-quota shuffle and record write order; pandas .sample calls
# without an explicit random_state presumably draw from it too).
SEED = 12345
# Root output directory; all files are written into OUTPUT_DIR/SSL2.
# NOTE(review): absolute, user-specific path -- consider making this configurable per machine.
OUTPUT_DIR = '/home/users/daniel/tmp/'
# Prefix used in every generated TFRecord file name.
NAME = 'cifar10'

In [None]:
import pandas as pd
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as pl

from tqdm import tqdm
from itertools import count
from pathlib import Path

# Seed NumPy's global generator once, before any of the shuffling/sampling cells below run.
np.random.seed(seed=SEED)

In [None]:
def get_class(serialized_example):
    """Parse one serialized example and return only its scalar int64 ``label`` feature."""
    label_spec = {'label': tf.FixedLenFeature([], tf.int64)}
    parsed = tf.parse_single_example(serialized_example, features=label_spec)
    return parsed['label']

# Stream every record through get_class so that only the integer labels are collected.
dataset = tf.data.TFRecordDataset(TRAIN_TFRECORD).map(get_class)
# TF1 graph-mode iteration: running `it` in a session yields the next label.
it = dataset.make_one_shot_iterator().get_next()
class_ids = []
try:
    with tf.Session() as session:
        # count() is endless; tqdm only displays progress (the dataset length is unknown).
        for n in tqdm(count()):
            result = session.run(it)
            class_ids.append(result)
# The iterator signals exhaustion by raising OutOfRangeError, so this is the normal loop exit.
except tf.errors.OutOfRangeError:
    pass

In [None]:
# One row per training record (in file order) with its class id; plot the class distribution.
df = pd.DataFrame(dict(class_id=class_ids))
df['class_id'].hist()

In [None]:
# Derive a long-tailed per-class quota for the unlabeled subset.
n_classes = df.class_id.max() + 1
class_order = np.arange(n_classes)
# Randomly decide which class receives which quota (uses the seeded global RNG).
np.random.shuffle(class_order)

if SCALING == 'linear':
    # Quotas fall linearly from 1 down to 1/FRACTION, so after rescaling the most
    # frequent class gets roughly FRACTION times as many samples as the rarest one.
    num_samples = np.linspace(1, 1 / FRACTION, num=n_classes)
else:
    raise ValueError(f'Dont understand scaling == {SCALING}')

# Rescale so the quotas sum to TOTAL_SAMPLES, then permute them across classes.
# astype truncates, so the realised total may fall short by fewer than n_classes samples.
num_samples *= TOTAL_SAMPLES / np.sum(num_samples)
num_samples = num_samples[class_order].astype(np.int64)

print(num_samples)

# Per-class availability aligned to class ids 0..n_classes-1; a class id that never
# occurs counts as 0 available instead of shifting/misaligning the comparison.
available = (
    df.groupby('class_id').size()
    .reindex(np.arange(n_classes), fill_value=0)
    .values
)
if (available < num_samples).any():
    raise ValueError('Cannot fullfill samples')

In [None]:
def sample(group, counts=None):
    """Draw this class's quota of rows from ``group`` without replacement.

    Intended for ``df.groupby('class_id').apply(sample)``: every row of ``group``
    must share one ``class_id``, which selects the quota.

    Args:
        group: DataFrame slice whose ``class_id`` column holds a single value.
        counts: sequence indexed by class id giving how many rows to keep.
            Defaults to the module-level ``num_samples`` quota array.

    Returns:
        A DataFrame with ``counts[class_id]`` rows drawn from ``group``.

    Raises:
        ValueError: if ``group`` holds more or fewer than one distinct class id,
            or (raised by pandas) if the quota exceeds the number of rows.
    """
    if counts is None:
        counts = num_samples
    class_id, = group.class_id.unique()
    # No random_state is passed, so pandas presumably falls back to NumPy's
    # global RNG (seeded in the setup cell).
    return group.sample(n=counts[class_id], replace=False)


In [None]:
# Read every serialized record (unparsed) into `data`, in file order, so that later
# cells can copy records by position (write_selection indexes data[index]).
# Note: the whole file is held in memory.
dataset = tf.data.TFRecordDataset(TRAIN_TFRECORD)
it = dataset.make_one_shot_iterator().get_next()
data = []
try:
    with tf.Session() as session:
        # Endless counter: the loop only stops when the iterator is exhausted.
        for n in tqdm(count()):
            result = session.run(it)
            data.append(result)
# OutOfRangeError is how the TF1 iterator reports end-of-data; it is the normal exit.
except tf.errors.OutOfRangeError:
    pass

In [None]:
# Every generated TFRecord file goes into an 'SSL2' sub-directory of OUTPUT_DIR.
output_dir = Path(OUTPUT_DIR).joinpath('SSL2')
output_dir.mkdir(parents=True, exist_ok=True)

def write_selection(selection, path):
    """Write the records referenced by ``selection`` to a TFRecord file at ``path``.

    ``selection['index']`` holds positions into the module-level ``data`` list of
    serialized records. The records are written in a random order drawn from
    NumPy's global RNG; ``selection`` itself is not modified.
    """
    print(f'Writing to {path}')
    # permutation() returns a shuffled copy, leaving the frame's column untouched.
    shuffled_positions = np.random.permutation(selection['index'].values)

    with tf.python_io.TFRecordWriter(str(path)) as writer:
        for position in tqdm(shuffled_positions):
            writer.write(data[position])

In [None]:
def resample(df, size):
    """Tile the rows of ``df`` until exactly ``size`` rows are produced.

    The frame is repeated ``size // len(df)`` times and topped up with its first
    ``size % len(df)`` rows; no randomness is involved.

    Args:
        df: non-empty DataFrame to tile.
        size: number of rows wanted in the result.

    Returns:
        A DataFrame with ``size`` rows (index labels may repeat).

    Raises:
        ZeroDivisionError: if ``df`` is empty.
    """
    repeats, crops = divmod(size, len(df))
    new = pd.concat(repeats * [df] + [df.iloc[:crops]])
    assert len(new) == size
    return new


def resample_classes(df):
    """Oversample every class of ``df`` up to the size of its largest class.

    Args:
        df: DataFrame with a ``class_id`` column.

    Returns:
        A new DataFrame ordered by ``class_id`` with a fresh RangeIndex, in which
        every class has as many rows as the largest class of ``df``. An input with
        no classes yields an empty frame with the same columns.
    """
    groups = df.groupby('class_id')
    if groups.ngroups == 0:
        return df.iloc[:0].reset_index(drop=True)
    # The target size comes from the argument itself, not from any global frame.
    max_count = groups.size().max()
    # Concatenate per-class results explicitly (instead of groupby.apply) so the
    # class_id column is kept regardless of pandas' include_groups behaviour.
    return pd.concat(
        [resample(group, max_count) for _, group in groups],
        ignore_index=True,
    )

In [None]:
# Draw the long-tailed unlabeled subset; reset_index() keeps each record's position
# in the source file as an 'index' column for write_selection.
selection = (
    df.reset_index()
    .groupby('class_id')
    .apply(sample)
    .reset_index(drop=True)
)
selection['class_id'].hist()
pl.show()

# Oversample every class to the largest class size and plot the balanced counts.
selection_resampled = resample_classes(selection)
selection_resampled['class_id'].hist()

# Overlay (half transparent, on the current axes) how many distinct records each class contributes.
selection_resampled.drop_duplicates(subset=['index'], keep='first')['class_id'].hist(alpha=0.5)

write_selection(selection_resampled, output_dir / f'{NAME}-unlabel.tfrecord')

In [None]:
# Write class-balanced labeled splits for every requested label budget, one file per seed.
N_LABELED_SEEDS = 6
# Loop-invariant: expose each record's file position as an 'index' column once.
df_indexed = df.reset_index()
for labelled_samples in LABELED_SAMPLES.split(':'):
    budget = int(labelled_samples)
    n, remainder = divmod(budget, n_classes)
    if n == 0:
        # Otherwise every file for this budget would be empty yet named '@{budget}'.
        raise ValueError(f'{budget} labels cannot cover {n_classes} classes')
    if remainder:
        print(f'Warning: {budget} labels are not divisible by {n_classes} classes; '
              f'writing {n * n_classes} labels to the @{labelled_samples} files')
    for seed in range(N_LABELED_SEEDS):
        # A dedicated RandomState per seed keeps each split reproducible on its own.
        rgen = np.random.RandomState(seed)
        # Distinct name so the global `selection` (unlabeled subset) is not clobbered.
        labeled_selection = pd.concat(
            [group.sample(n=n, replace=False, random_state=rgen)
             for _, group in df_indexed.groupby('class_id')],
            ignore_index=True,
        )
        write_selection(labeled_selection, output_dir / f'{NAME}.{seed}@{labelled_samples}-label.tfrecord')