In [1]:
import sys
# NOTE(review): machine-specific absolute path added to the import search
# path; presumably this is where the `datahandlers` package imported below
# lives — replace with a path that is valid on your machine.
sys.path.append('/Users/ashwindesilva/research/ood-tl')

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold
from datahandlers.cifar import SplitCIFARHandler
import numpy as np
import random

In [2]:
def wif(id):
    """
    Reseed NumPy's global generator from torch's seed and a worker id.

    Written to be passed as ``worker_init_fn`` to a DataLoader so that NumPy
    randomness differs between workers.
    Approach taken from https://github.com/pytorch/pytorch/issues/5059

    Arguments
    ---------
    id : integer
        worker id the function is called with
    """
    # The seed torch reports is treated as base_seed + id; subtracting the id
    # recovers the base seed so that both values feed the seed sequence.
    worker_seed = torch.initial_seed()
    seed_seq = np.random.SeedSequence([id, worker_seed - id])
    # Four 32-bit words (128 bits) of state are plenty for seeding NumPy.
    np.random.seed(seed_seq.generate_state(4))

In [10]:
import torch
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold


class Sampler:
    """Abstract root of the sampler hierarchy.

    A concrete sampler must override ``__iter__`` so that it iterates over
    indices of dataset elements, and ``__len__`` so that it reports the length
    of the iterators it returns.
    """

    def __init__(self, data_source):
        """Accept a data source; the base class does not store or use it."""

    def __iter__(self):
        """Iterate over dataset indices; must be provided by subclasses."""
        raise NotImplementedError()

    def __len__(self):
        """Length of the returned iterators; must be provided by subclasses."""
        raise NotImplementedError()

class StratifiedSampler(Sampler):
    """Stratified sampling over a binary task vector.

    Indices are emitted in consecutive chunks of exactly ``batch_size``
    elements. Every chunk holds target samples (task == 0) followed by OOD
    samples (task == 1), in approximately the same proportion as the whole
    task vector, so grouping the output into batches of ``batch_size`` gives
    stratified batches.
    """
    def __init__(self, task_vector, batch_size):
        """
        Arguments
        ---------
        task_vector : torch tensor
            a vector of task labels (0 = target, 1 = OOD)
        batch_size : integer
            number of indices emitted per batch

        Raises
        ------
        ValueError
            if batch_size is not positive
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got {}".format(batch_size))

        # Only full batches are generated; the remainder is dropped.
        self.n_splits = task_vector.size(0) // batch_size
        self.task_vector = task_vector
        self.batch_size = batch_size

        tasks = self.task_vector.numpy()
        self.target_indices = np.where(tasks == 0)[0].tolist()
        self.ood_indices = np.where(tasks == 1)[0].tolist()

        # Fraction of target samples; 0.0 for an empty vector avoids 0/0.
        self.beta = len(self.target_indices) / len(tasks) if len(tasks) else 0.0

    def _batch_counts(self):
        """Return (n_target, n_ood) per batch; the two always sum to batch_size."""
        n_target = round(self.batch_size * self.beta)
        # Derive the OOD count from the target count instead of rounding it
        # independently: two separate roundings can add up to batch_size +/- 1,
        # which shifts the strata across batch boundaries.
        return n_target, self.batch_size - n_target

    def gen_sample_array(self):
        """Draw the indices for one pass over the data.

        Returns
        -------
        numpy int array of length ``n_splits * batch_size``; within each
        batch-sized chunk, indices are drawn without replacement from the
        global NumPy random state (different chunks may repeat an index).
        """
        n_target, n_ood = self._batch_counts()
        target_pool = np.asarray(self.target_indices, dtype=int)
        ood_pool = np.asarray(self.ood_indices, dtype=int)

        chunks = []
        for _ in range(self.n_splits):
            # A stratum that contributes nothing is skipped so that an empty
            # pool is never handed to np.random.choice.
            if n_target:
                chunks.append(np.random.choice(target_pool, n_target, replace=False))
            if n_ood:
                chunks.append(np.random.choice(ood_pool, n_ood, replace=False))

        if not chunks:
            return np.empty(0, dtype='int')
        return np.concatenate(chunks).astype('int')

    def __iter__(self):
        """Iterate over a freshly drawn index array."""
        return iter(self.gen_sample_array())

    def __len__(self):
        """Number of indices one iteration yields (full batches only)."""
        # Must match what __iter__ produces, otherwise anything that derives
        # a batch count from len(sampler) reports more batches than exist.
        return self.n_splits * self.batch_size

In [4]:
# Build the split-CIFAR handler from two groups of class ids.
# NOTE(review): presumably the first group is the target task and the second
# the OOD task — confirm against SplitCIFARHandler.
dataset = SplitCIFARHandler([[0, 1], [8, 9]])

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/cifar-10-python.tar.gz


14.4%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

46.9%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

80.6%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Files already downloaded and verified


In [11]:
# Draw the training sample, then collect the first field of every target
# entry into an int32 tensor (it is used as the task vector for the sampler).
# NOTE(review): n and m are presumably the target / OOD sample counts —
# confirm against SplitCIFARHandler.sample_data.
dataset.sample_data(n=100, m=2000, randomly=False)
targets = dataset.comb_trainset.targets
task_vector = torch.tensor([entry[0] for entry in targets], dtype=torch.int32)

In [14]:
# Build a loader whose batches are dictated by the stratified sampler: the
# sampler emits indices and BatchSampler groups them batch_size at a time.
batch_size = 128
strat_sampler = StratifiedSampler(task_vector, batch_size)
batch_sampler = torch.utils.data.BatchSampler(
    sampler=strat_sampler,
    batch_size=batch_size,
    drop_last=False,
)
data_loader = DataLoader(
    dataset.comb_trainset,
    batch_sampler=batch_sampler,
    num_workers=0,
    pin_memory=True,
    worker_init_fn=wif,
)

In [15]:
# Go through the loader once and, for every batch, print its task ids and the
# share of entries whose task id is 0.
# NOTE(review): the loop rebinds the notebook-level name `targets`.
for data, targets in data_loader:
    # each batch's targets unpack into two parts: task ids and labels
    tasks, labels = targets
    print(tasks)
    # with 0/1 task ids, sum/len is the share of 1s, so 1 - that is the share of 0s
    print("Target Fraction : {:.3f}".format(1-tasks.sum()/len(tasks)))

tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.047
tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
Target Fraction : 0.047
tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
      

___

In [1]:
import sys
# NOTE(review): machine-specific absolute path added to the import search
# path; presumably this is where the `datahandlers` package imported below
# lives — replace with a path that is valid on your machine.
sys.path.append('/Users/ashwindesilva/research/ood-tl')

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import StratifiedKFold
from datahandlers.cifar import RotatedCIFAR10Handler
import numpy as np
import random

In [2]:
# Build the rotated-CIFAR10 handler.
# NOTE(review): presumably [2,3] are the class ids used and 45 is a rotation
# angle in degrees — TODO confirm against RotatedCIFAR10Handler.
dataset = RotatedCIFAR10Handler([2,3], 45)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
# Draw a sample with both size arguments set to 10; their exact meaning is
# defined by RotatedCIFAR10Handler.sample_data (not visible here).
dataset.sample_data(10, 10)

In [16]:
# Request a loader with train=False.
# NOTE(review): presumably 0 is the task id and 100 the batch size — TODO
# confirm the positional argument order of get_task_data_loader.
test_loader = dataset.get_task_data_loader(0, 100, train=False)

In [17]:
import matplotlib.pyplot as plt

In [1]:
import datetime


In [2]:
# Capture the current local date and time once so later cells reuse the same value.
x = datetime.datetime.now()

In [3]:
# Show the captured timestamp in its default text form.
print(x)

2022-07-20 23:06:53.121349


In [6]:
# str() of the timestamp gives 'YYYY-MM-DD HH:MM:SS.ffffff' (see the output);
# this string is what a later cell uses as a directory name.
str(x)

'2022-07-20 23:06:53.121349'

In [7]:
import os

In [None]:
# Base directory for experiment results, relative to the notebook's working directory.
# NOTE(review): the name `dir` shadows the Python builtin of the same name for
# the rest of the kernel session; it is kept because a later cell reads it.
dir = "../experiments/results/cifar10_rotated_tasks"

In [8]:
# Make sure the results directory exists and create a timestamped
# sub-directory for this execution. The sub-directory is created
# unconditionally, not only when the base directory is missing, so every run
# gets a directory of its own; exist_ok=True keeps the cell safe to re-run
# and avoids a separate existence check.
# NOTE(review): str(datetime.now()) contains spaces and colons, which are not
# valid in directory names on every platform.
os.makedirs(dir, exist_ok=True)
os.makedirs(os.path.join(dir, str(datetime.datetime.now())), exist_ok=True)

In [1]:
import logging

In [2]:
# Configure logging to write records of level DEBUG and above to the file
# exp_log.log (path relative to the working directory).
logging.basicConfig(filename="exp_log.log", level=logging.DEBUG)

In [3]:
# Emit ten INFO records, each the string form of a one-entry dict whose
# "risk" value is the loop counter.
for i in range(10):
    a = dict(risk=i)
    logging.info(str(a))