In [171]:
import sys
sys.path.append('/Users/ashwindesilva/research/ood-tl')

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold
from datahandlers.cifar import SplitCIFARHandler
import numpy as np
import random

In [21]:
def wif(id):
    """
    Worker-init hook: reseed NumPy's global RNG from torch's per-process seed.

    Works around the PyTorch DataLoader + NumPy randomization bug; adapted
    from https://github.com/pytorch/pytorch/issues/5059

    Arguments
    ---------
    id : integer
        worker id handed to ``worker_init_fn`` by the DataLoader
    """
    # Subtracting the worker id recovers the shared base seed, so both values
    # can be fed to the SeedSequence and every bit of entropy is used.
    seed_sequence = np.random.SeedSequence([id, torch.initial_seed() - id])
    # Four 32-bit words (128 bits) of state are plenty; more would be overkill.
    np.random.seed(seed_sequence.generate_state(4))

In [181]:
import torch
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold


class Sampler:
    """Abstract base for every Sampler.

    A concrete subclass supplies two methods:

    * ``__iter__`` - iterates over indices of dataset elements.
    * ``__len__``  - reports the length of the iterators ``__iter__`` returns.
    """

    def __init__(self, data_source):
        """Accept ``data_source`` and ignore it; the base class keeps no state."""

    def __iter__(self):
        """Must be overridden; the base implementation always raises."""
        raise NotImplementedError()

    def __len__(self):
        """Must be overridden; the base implementation always raises."""
        raise NotImplementedError()

class StratifiedSampler(Sampler):
    """Stratified sampling over a binary task vector.

    Produces a flat stream of dataset indices laid out in consecutive groups:
    ``batch_size // 2`` indices whose task label is 0 (target) followed by
    ``batch_size // 2`` indices whose task label is 1 (OOD). Chunking that
    stream with the same, even ``batch_size`` therefore gives batches that
    are half target and half OOD.

    Each half-batch is drawn without replacement, but draws for different
    batches are independent, so an index can appear in several batches of
    the same pass.

    NOTE: with an odd ``batch_size`` each group holds ``batch_size - 1``
    indices, so chunking by ``batch_size`` no longer lines up with the
    groups. Use an even ``batch_size``.
    """

    def __init__(self, task_vector, batch_size):
        """
        Arguments
        ---------
        task_vector : torch tensor
            a vector of task labels; 0 marks a target sample, 1 an OOD sample
        batch_size : integer
            batch_size
        """
        # Only whole batches are generated; the remainder is dropped.
        self.n_splits = task_vector.size(0) // batch_size
        self.task_vector = task_vector
        self.batch_size = batch_size

        tasks = self.task_vector.numpy()
        self.target_indices = np.flatnonzero(tasks == 0).tolist()
        self.ood_indices = np.flatnonzero(tasks == 1).tolist()

    def gen_sample_array(self):
        """Draw one pass worth of indices.

        Returns
        -------
        numpy integer array of length ``n_splits * 2 * (batch_size // 2)``.

        ``np.random.choice`` raises ``ValueError`` when a group has fewer
        than ``batch_size // 2`` members, since sampling is without
        replacement.
        """
        per_group = self.batch_size // 2
        chunks = []
        for _ in range(self.n_splits):
            # Target half first, then OOD half: the order callers observe.
            chunks.append(np.random.choice(self.target_indices, per_group, replace=False))
            chunks.append(np.random.choice(self.ood_indices, per_group, replace=False))

        if not chunks:
            # No full batch fits: return an explicitly integer-typed empty array.
            return np.empty(0, dtype='int')
        return np.concatenate(chunks).astype('int')

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        # Must equal the number of indices __iter__ yields; otherwise
        # BatchSampler / DataLoader report the wrong number of batches.
        return self.n_splits * 2 * (self.batch_size // 2)

In [None]:
# Build the split-CIFAR data handler from two groups of class labels.
# NOTE(review): presumably [0, 1] is the target task and [8, 9] the OOD task;
# confirm against SplitCIFARHandler in datahandlers/cifar.
dataset = SplitCIFARHandler([[0, 1], [8, 9]])

In [188]:
# Populate dataset.comb_trainset.
# NOTE(review): presumably n is the target sample count and m the OOD sample
# count; confirm against SplitCIFARHandler.sample_data.
dataset.sample_data(randomly=False, m=2000, n=100)
targets = dataset.comb_trainset.targets

# The first field of every target entry is the task id; collect them into
# one int32 tensor for the stratified sampler.
task_vector = torch.tensor(
    [entry[0] for entry in targets],
    dtype=torch.int32,
)

In [189]:
batch_size = 16

# The sampler emits a flat index stream; BatchSampler cuts it into chunks of
# batch_size, and the DataLoader consumes those chunks as ready-made batches.
strat_sampler = StratifiedSampler(task_vector, batch_size)
batch_sampler = torch.utils.data.BatchSampler(
    sampler=strat_sampler,
    batch_size=batch_size,
    drop_last=False,
)
data_loader = DataLoader(
    dataset.comb_trainset,
    batch_sampler=batch_sampler,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=wif,
)

In [190]:
# Sanity check: print each batch's task ids and the share that are target
# samples (task id 0).
for data, targets in data_loader:
    tasks, labels = targets
    print(tasks)
    # tasks holds 0/1 ids, so sum/len is the OOD share; one minus that is the target share.
    target_fraction = 1 - tasks.sum() / len(tasks)
    print("Target Fraction : {:.3f}".format(target_fraction))

[23, 18, 64, 72, 21, 75, 8, 52, 1374, 1192, 1297, 1578, 1004, 952, 298, 1546]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.500
[2, 28, 63, 72, 42, 61, 70, 51, 1226, 1523, 1807, 915, 852, 971, 1143, 1583]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.500
[38, 97, 79, 66, 11, 26, 17, 83, 313, 1835, 490, 1492, 219, 544, 447, 102]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.500
[89, 79, 84, 35, 5, 8, 50, 58, 810, 620, 542, 588, 1800, 1957, 121, 168]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.500
[83, 12, 32, 54, 65, 27, 15, 7, 1979, 827, 799, 1745, 825, 1346, 1343, 1845]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.500
[42, 25, 26, 1, 12, 90, 7, 16, 1684, 1352, 1747, 1921, 1587, 1388, 168, 1729]
tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.500
[2, 47, 36, 72, 64, 6, 93, 21, 1651, 235, 1215, 889, 524