In [31]:
import torch
from torch.utils.data.sampler import WeightedRandomSampler
from typing import Iterator, Sequence


class ExhaustiveWeightedRandomSampler(WeightedRandomSampler):
    """ExhaustiveWeightedRandomSampler behaves pretty much the same as WeightedRandomSampler
    except that it receives an extra parameter, exaustive_weight, which is the weight of the
    elements that should be sampled exhaustively over multiple iterations.
    This is useful when the dataset is very big and also very imbalanced, like the negative
    sample is way more than positive samples, we want to over sample positive ones, but also
    iterate over all the negative samples as much as we can.
    Args:
        weights (sequence): a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        exaustive_weight (int): which weight of samples should be sampled exhaustively,
            normally this is the one that should not been over sampled, like the lowest
            weight of samples in the dataset.
        generator (Generator): Generator used in sampling.
    """

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        exaustive_weight=1,
        generator=None,
    ) -> None:
        super().__init__(weights, num_samples, True, generator)
        self.all_indices = torch.tensor(list(range(num_samples)))
        self.exaustive_weight = exaustive_weight
        self.weights_mapping = torch.tensor(weights) == self.exaustive_weight
        self.remaining_indices = torch.tensor([], dtype=torch.long)

    def get_remaining_indices(self) -> torch.Tensor:
        remaining_indices = self.weights_mapping.nonzero().squeeze()
        return remaining_indices[torch.randperm(len(remaining_indices))]

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(
            self.weights, self.num_samples, self.replacement, generator=self.generator
        )
        print('rand_tensor:', rand_tensor, sep = '\n')
        exaustive_indices = rand_tensor[
            self.weights_mapping[rand_tensor].nonzero().squeeze()
        ]
        print('exaustive_indices:', exaustive_indices, sep = '\n')
        print('self.remaining_indices before:', self.remaining_indices, sep = '\n')
        while len(exaustive_indices) > len(self.remaining_indices):
            self.remaining_indices = torch.cat(
                [self.remaining_indices, self.get_remaining_indices()]
            )
        print('self.remaining_indices:', self.remaining_indices, sep = '\n')
        yield_indexes, self.remaining_indices = (
            self.remaining_indices[: len(exaustive_indices)],
            self.remaining_indices[len(exaustive_indices) :],
        )
#         print('temp2:', exaustive_indices)
#         print('temp:', (rand_tensor[..., None] == exaustive_indices).any(-1).nonzero().squeeze())
        rand_tensor[
            (rand_tensor[..., None] == exaustive_indices).any(-1).nonzero().squeeze()
        ] = yield_indexes
        yield from iter(rand_tensor.tolist())

In [33]:
sampler = ExhaustiveWeightedRandomSampler([1, 1, 1, 1, 1, 1, 1, 1, 1, 10], num_samples=30)
for i in range(5):
    idxs = list(sampler)
    print('\n\nFINAL RESULT:')
    print(idxs)
    print('\t-->', sorted(list(set(idxs))))

rand_tensor:
tensor([1, 8, 9, 5, 8, 7, 9, 9, 2, 6, 4, 9, 1, 7, 5, 2, 6, 7, 9, 9, 9, 5, 3, 9,
        5, 9, 7, 9, 9, 9])
exaustive_indices:
tensor([1, 8, 5, 8, 7, 2, 6, 4, 1, 7, 5, 2, 6, 7, 5, 3, 5, 7])
self.remaining_indices before:
tensor([], dtype=torch.int64)
self.remaining_indices:
tensor([5, 7, 0, 8, 4, 2, 1, 6, 3, 5, 1, 0, 8, 3, 2, 7, 4, 6])


FINAL RESULT:
[5, 7, 9, 0, 8, 4, 9, 9, 2, 1, 6, 9, 3, 5, 1, 0, 8, 3, 9, 9, 9, 2, 7, 9, 4, 9, 6, 9, 9, 9]
	--> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
rand_tensor:
tensor([7, 1, 9, 9, 5, 9, 9, 9, 9, 9, 0, 5, 9, 9, 9, 2, 9, 0, 9, 9, 9, 9, 4, 0,
        7, 5, 3, 7, 9, 9])
exaustive_indices:
tensor([7, 1, 5, 0, 5, 2, 0, 4, 0, 7, 5, 3, 7])
self.remaining_indices before:
tensor([], dtype=torch.int64)
self.remaining_indices:
tensor([7, 3, 6, 4, 0, 2, 8, 1, 5, 1, 3, 6, 4, 7, 0, 5, 8, 2])


FINAL RESULT:
[7, 3, 9, 9, 6, 9, 9, 9, 9, 9, 4, 0, 9, 9, 9, 2, 9, 8, 9, 9, 9, 9, 1, 5, 1, 3, 6, 4, 9, 9]
	--> [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
rand_tensor:
tensor([0, 9, 6,