In [35]:
import sys
sys.path.append('/Users/ashwindesilva/research/ood-tl')

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold
from datahandlers.cifar import SplitCIFARHandler
import numpy as np

In [21]:
def wif(id):
    """
    Used to fix randomization bug for pytorch dataloader + numpy
    Code from https://github.com/pytorch/pytorch/issues/5059
    """
    process_seed = torch.initial_seed()
    # Back out the base_seed so we can use all the bits.
    base_seed = process_seed - id
    ss = np.random.SeedSequence([id, base_seed])
    # More than 128 bits (4 32-bit words) would be overkill.
    np.random.seed(ss.generate_state(4))

In [108]:
class Sampler(object):
    """Base class for all Samplers.
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

class StratifiedSampler(Sampler):
    """Stratified Sampling
    Provides equal representation of target classes in each batch
    """
    def __init__(self, task_vector, batch_size):
        """
        Arguments
        ---------
        class_vector : torch tensor
            a vector of class labels
        batch_size : integer
            batch_size
        """
        self.n_splits = int(task_vector.size(0) / batch_size)
        self.task_vector = task_vector

    def gen_sample_array(self):
        try:
            from sklearn.model_selection import StratifiedShuffleSplit
        except:
            print('Need scikit-learn for this functionality')
        import numpy as np
        
        # s = StratifiedShuffleSplit(n_splits=self.n_splits, test_size=0.5)
        s = StratifiedKFold(n_splits=self.n_splits, shuffle=True)
        X = torch.randn(self.task_vector.size(0),2).numpy()
        y = self.task_vector.numpy()
        s.get_n_splits(X, y)

        indices = []
        for _, test_index in s.split(X, y):
            indices = np.hstack([indices, test_index])
        
        # print(task_vector[indices])
        print(indices.astype('int'))
        indices.sort()
        print(indices)

        # train_index, test_index = next(s.split(X, y))
        # print(task_vector[train_index])
        # indices = np.hstack([train_index, test_index])
        return indices.astype('int')

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        return len(self.task_vector)

In [None]:
dataset = SplitCIFARHandler([[0, 1], [8, 9]])

In [105]:
dataset.sample_data(n=20, m=40, randomly=False)
targets = dataset.comb_trainset.targets
task_vector = torch.tensor([targets[i][0] for i in range(len(targets))], dtype=torch.int32)

In [109]:
batch_size = 8
strat_sampler = StratifiedSampler(task_vector, batch_size)
batch_sampler = torch.utils.data.BatchSampler(strat_sampler, batch_size, True)
data_loader = DataLoader(dataset.comb_trainset, worker_init_fn=wif, pin_memory=True, num_workers=0, batch_sampler=batch_sampler)

In [110]:
for data, targets in data_loader:
    tasks, labels = targets
    print(tasks)
    print("Target Fraction : {:.3f}".format(1-tasks.sum()/len(tasks)))

[ 5  9 13 20 24 36 38 43 48  3 11 18 25 40 42 47 54 59  8 10 17 23 29 31
 44 45 56  0  4 12 21 22 34 50 52 55  1 15 16 28 32 37 39 49  2  7 19 27
 33 41 51 53  6 14 26 30 35 46 57 58]
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.
 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53.
 54. 55. 56. 57. 58. 59.]
[0, 1, 2, 3, 4, 5, 6, 7]
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Target Fraction : 1.000
[8, 9, 10, 11, 12, 13, 14, 15]
tensor([0, 0, 0, 0, 0, 0, 0, 0])
Target Fraction : 1.000
[16, 17, 18, 19, 20, 21, 22, 23]
tensor([0, 0, 0, 0, 1, 1, 1, 1])
Target Fraction : 0.500
[24, 25, 26, 27, 28, 29, 30, 31]
tensor([1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.000
[32, 33, 34, 35, 36, 37, 38, 39]
tensor([1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.000
[40, 41, 42, 43, 44, 45, 46, 47]
tensor([1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.000
[48, 49, 50, 51, 52, 53, 54, 55]
tensor([