In [171]:
import sys
sys.path.append('/Users/ashwindesilva/research/ood-tl')

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold
from datahandlers.cifar import SplitCIFARHandler
import numpy as np
import random

In [21]:
def wif(id):
    """
    Re-seed NumPy's global RNG from the torch seed and a worker id.

    Meant to be passed as ``worker_init_fn`` to a PyTorch DataLoader to
    work around the dataloader + numpy randomization bug.
    Code from https://github.com/pytorch/pytorch/issues/5059
    """
    # Subtract the worker id from the process seed to recover the base
    # seed, so that both values contribute all of their bits.
    base_seed = torch.initial_seed() - id
    seed_sequence = np.random.SeedSequence([id, base_seed])
    # Four 32-bit words (128 bits) of seed material; more would be overkill.
    seed_words = seed_sequence.generate_state(4)
    np.random.seed(seed_words)

In [181]:
import torch
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold


class Sampler:
    """Abstract base class for samplers.

    A concrete sampler must override ``__iter__`` so that it yields the
    indices of dataset elements, and ``__len__`` so that it reports the
    length of the iterators it returns.
    """

    def __init__(self, data_source):
        # The base class stores nothing; ``data_source`` is accepted and ignored.
        return None

    def __iter__(self):
        # Must be provided by subclasses.
        raise NotImplementedError()

    def __len__(self):
        # Must be provided by subclasses.
        raise NotImplementedError()

class StratifiedSampler(Sampler):
    """Stratified Sampling
    Provides equal representation of the target (task id 0) and OOD
    (task id 1) samples in each batch.

    The sampler emits a flat stream of indices laid out in consecutive
    groups of exactly ``batch_size`` entries: the target indices of a batch
    first, followed by its OOD indices. Wrapping it in a ``BatchSampler``
    with the same ``batch_size`` therefore yields balanced batches.
    """
    def __init__(self, task_vector, batch_size):
        """
        Arguments
        ---------
        task_vector : torch tensor
            1-D vector of task ids, one per dataset element; entries equal
            to 0 are treated as target samples and entries equal to 1 as
            OOD samples
        batch_size : integer
            number of indices emitted per batch
        """
        # Number of full batches that fit in the dataset.
        self.n_splits = task_vector.size(0) // batch_size
        self.task_vector = task_vector
        self.batch_size = batch_size

        tasks = self.task_vector.numpy()
        self.target_indices = np.where(tasks == 0)[0].tolist()
        self.ood_indices = np.where(tasks == 1)[0].tolist()

    @staticmethod
    def _draw(pool, count, group):
        """Pick ``count`` indices from ``pool`` for one batch.

        Sampling is without replacement whenever the pool is large enough;
        if the pool holds fewer than ``count`` indices, sampling falls back
        to replacement so that a small group can still fill its share.

        Raises
        ------
        ValueError
            if indices are requested from an empty pool
        """
        if count == 0:
            return []
        if not pool:
            raise ValueError(
                "StratifiedSampler: no {} samples available to draw from".format(group)
            )
        return np.random.choice(pool, count, replace=len(pool) < count)

    def gen_sample_array(self):
        """Generate the index stream for one pass.

        Returns
        -------
        numpy integer array of length ``n_splits * batch_size``. Each
        consecutive run of ``batch_size`` entries holds ``batch_size // 2``
        target indices followed by the remaining OOD indices (for an odd
        ``batch_size`` the extra slot goes to the OOD group, so that runs
        stay aligned with the batch boundaries).
        """
        n_target = self.batch_size // 2
        n_ood = self.batch_size - n_target

        indices = []
        for _ in range(self.n_splits):
            indices.extend(self._draw(self.target_indices, n_target, 'target'))
            indices.extend(self._draw(self.ood_indices, n_ood, 'OOD'))

        return np.asarray(indices, dtype=int)

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        # Must agree with the number of indices produced by __iter__.
        return self.n_splits * self.batch_size

In [None]:
# Build the split-CIFAR data handler.
# NOTE(review): the nested list presumably gives the CIFAR class ids of each
# task ([0, 1] and [8, 9]) — confirm against datahandlers.cifar.SplitCIFARHandler.
dataset = SplitCIFARHandler([[0, 1], [8, 9]])

In [185]:
# Sample the combined training set.
# NOTE(review): n and m are presumably the target / OOD sample counts —
# confirm against SplitCIFARHandler.sample_data.
dataset.sample_data(n=20, m=100, randomly=False)
targets = dataset.comb_trainset.targets
# The first element of every target entry is used as the sample's task id.
task_vector = torch.tensor(
    [entry[0] for entry in targets],
    dtype=torch.int32,
)

In [186]:
# Wire the stratified sampler into a DataLoader through a BatchSampler that
# uses the same batch size.
batch_size = 8
strat_sampler = StratifiedSampler(task_vector, batch_size)
batch_sampler = torch.utils.data.BatchSampler(
    sampler=strat_sampler,
    batch_size=batch_size,
    drop_last=False,
)
data_loader = DataLoader(
    dataset.comb_trainset,
    batch_sampler=batch_sampler,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=wif,
)

In [187]:
# Sanity check: print the task ids of every batch and the fraction of the
# batch whose task id is 0 (the target samples).
# The loop variable is named `batch_targets` so that it does not overwrite the
# dataset-level `targets` defined in the cell that builds `task_vector`.
for data, batch_targets in data_loader:
    tasks, labels = batch_targets
    print(tasks)
    # Cast to float before averaging so the fraction never depends on
    # integer-tensor division semantics.
    print("Target Fraction : {:.3f}".format(1.0 - tasks.float().mean().item()))

[12, 1, 10, 19, 37, 92, 65, 80]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[0, 10, 7, 15, 99, 100, 87, 34]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[18, 1, 7, 5, 62, 65, 107, 76]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[8, 17, 18, 2, 64, 95, 36, 35]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[2, 7, 10, 15, 36, 74, 39, 108]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[9, 3, 1, 4, 104, 65, 86, 83]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[18, 8, 14, 1, 104, 97, 82, 56]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[17, 15, 7, 3, 64, 97, 67, 79]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[16, 0, 13, 19, 113, 64, 47, 111]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[11, 6, 5, 10, 51, 86, 108, 114]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[13, 11, 5, 1, 44, 75, 24, 90]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[9, 4, 1, 15, 94, 33, 11

In [167]:
# Inspect the task ids: StratifiedSampler treats the 0 entries as target
# samples and the 1 entries as OOD samples.
task_vector

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32)

In [168]:
# Convert the task ids to a NumPy array, as StratifiedSampler.__init__ does.
tasks = task_vector.numpy()

In [170]:
# Positions whose task id is 1 — the same expression StratifiedSampler uses
# to collect its OOD indices.
np.where(tasks==1)[0]

array([20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
       37, 38, 39])