In [4]:
import json
import os
import sys
import cv2
import copy
import pandas as pd
import numpy as np

import torch.utils.data.sampler

In [5]:
class BalancedSubsetSampler(torch.utils.data.Sampler):
    """Sampler that alternates between 'scenting' and 'non_scenting' examples.

    Indices are drawn with replacement, restricted to the subset ``idxs``.
    Even positions of the stream draw a 'scenting' index and odd positions
    draw a 'non_scenting' index. One pass yields ``len(classifications)``
    indices, and ``__len__`` reports that same count.

    Args:
        classifications: per-example labels (list, numpy array or pandas
            Series), compared elementwise against the strings 'scenting' and
            'non_scenting'. The returned indices are positional.
        batch_size: stored on the instance; not used within this class.
        shuffle: if True, each class's index array is shuffled at the start of
            every iteration (see ``init_lookup``).
        idxs: indices of the subset (for example a data split) to draw from.
    """

    def __init__(self, classifications, batch_size, shuffle, idxs):
        self.classifications = classifications
        self.num_examples = len(classifications)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.target_idxs = idxs

        self._setup()

    def _setup(self):
        """Build ``self.lookup_src``, mapping each class to its in-subset indices."""
        # np.asarray makes list input work: `some_list == 'scenting'` is a single
        # False rather than an elementwise mask, which silently yielded no indices.
        labels = np.asarray(self.classifications)
        target = np.asarray(self.target_idxs)

        scenting_idxs = np.where(labels == 'scenting')[0]
        non_scenting_idxs = np.where(labels == 'non_scenting')[0]

        # np.isin supersedes np.in1d, which is deprecated in NumPy 2.0.
        scenting_idxs = scenting_idxs[np.isin(scenting_idxs, target)]
        non_scenting_idxs = non_scenting_idxs[np.isin(non_scenting_idxs, target)]

        self.lookup_src = {
            "scenting"     : scenting_idxs,
            "non_scenting" : non_scenting_idxs
        }

    def init_lookup(self):
        """Return a fresh copy of the per-class index arrays, shuffled if requested.

        The copy means the shuffle never mutates ``self.lookup_src``. Because
        ``sample`` draws with ``np.random.choice``, shuffling does not change the
        sampling distribution; it only consumes global RNG state.
        """
        lookup = copy.deepcopy(self.lookup_src)

        if self.shuffle:
            for key, val in lookup.items():
                np.random.shuffle(val)

        return lookup

    def sample(self, lookup, key):
        """Draw one index uniformly, with replacement, from ``lookup[key]``.

        Raises:
            ValueError: if the target subset contains no examples of ``key``.
        """
        key_idxs = lookup[key]
        if len(key_idxs) == 0:
            raise ValueError(
                f"BalancedSubsetSampler: no '{key}' examples in the target subset"
            )
        return np.random.choice(key_idxs)

    def __len__(self):
        return self.num_examples

    def __iter__(self):
        lookup = self.init_lookup()
        for i in range(self.num_examples):
            # Alternate classes so consecutive draws are balanced 1:1.
            if i % 2 == 0:
                sample = self.sample(lookup, 'scenting')
            else:
                sample = self.sample(lookup, 'non_scenting')
            yield sample


In [7]:
class SubsetIdentitySampler(torch.utils.data.Sampler):
    """Sampler that yields each of the given indices once, in their stored order.

    No shuffling or resampling is applied; the length equals ``len(indices)``.

    Args:
        indices: an iterable of dataset indices that supports ``len``.
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        # Hand back a plain iterator over the stored indices, untouched.
        return iter(self.indices)
        