# What are Samplers in PyTorch?

Samplers determine the order in which each individual image (or subject) is fed to the network during a single epoch. There are two main samplers built into PyTorch (in `torch.utils.data`). They are called `SequentialSampler` and `RandomSampler` (simplified versions are reproduced below). They do exactly what you think they do: return your samples in order, or return them in a shuffled order.

```python
import torch
from torch.utils.data import Sampler  # base class for all samplers


class SequentialSampler(Sampler):
    def __init__(self, nb_samples):
        self.num_samples = nb_samples

    def __iter__(self):
        return iter(range(self.num_samples))

    def __len__(self):
        return self.num_samples


class RandomSampler(Sampler):
    def __init__(self, nb_samples):
        self.num_samples = nb_samples

    def __iter__(self):
        return iter(torch.randperm(self.num_samples).long())

    def __len__(self):
        return self.num_samples
```

Notice that they operate on integers. That's how sampling works in PyTorch: each individual image/subject/whatever is accessed using an integer index, and the sampler decides the order in which those indices are visited.
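
To make the integer-index idea concrete, here is a minimal sketch using the samplers that ship in `torch.utils.data`. Those take the data source itself rather than an `nb_samples` count. The `samples` list and its entries are made up for illustration.

```python
import torch
from torch.utils.data import RandomSampler, SequentialSampler

# A stand-in "dataset": any object that supports len() and integer indexing.
samples = ["scan_a", "scan_b", "scan_c"]  # hypothetical sample names

# Iterating over a sampler yields integer indices, not the samples themselves.
sequential_indices = list(SequentialSampler(samples))
random_indices = list(RandomSampler(samples))

print(sequential_indices)      # [0, 1, 2]
print(sorted(random_indices))  # [0, 1, 2], drawn in a random order

# The indices are then used to look up the actual samples.
shuffled_samples = [samples[i] for i in random_indices]
print(shuffled_samples)
```

The sampler never touches the data; it only decides which index comes next.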

Let's demonstrate their usage. Starting with the default (i.e. doing nothing):


In [22]:
import torch
import numpy as np
x = torch.ones(3,1,10,10) # three samples of (1,10,10) size
y = torch.from_numpy(np.arange(3)) # class labels = (0,1,2)

from torchsample import TensorDataset
data = TensorDataset(x,y)
for i,j in data:
    print(j[0])

0
1
2


OK, obvious: by default, the torchsample datasets take samples in the order they are given. We can change this by passing `shuffle=True` to the class:

In [23]:
data = TensorDataset(x,y, shuffle=True)
for i, j in data:
    print(j[0])

2
0
1


Cool, our data has now been shuffled. However, we can also use samplers explicitly to get more nuanced control over how our data is sampled. Let's start by recreating the two examples above using the actual samplers:

In [29]:
from torchsample.samplers import RandomSampler, SequentialSampler

rs = RandomSampler(nb_samples=3)
ss = SequentialSampler(nb_samples=3)

sequential_data = TensorDataset(x,y, sampler=ss)
print('Using Sequential Sampler:')
for i, j in sequential_data:
    print(j[0])
    
random_data = TensorDataset(x,y, sampler=rs)
print('Using Random Sampler:')
for i, j in random_data:
    print(j[0])
    

Using Sequential Sampler:
0
1
2
Using Random Sampler:
2
1
0
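
If you don't have `torchsample` installed, roughly the same experiment can be sketched with core PyTorch only. This is an approximation under my own assumptions, not the `torchsample` API: it uses `torch.utils.data.TensorDataset` (imported under the alias `TorchTensorDataset` to avoid confusion with the one above) and passes the sampler to a `DataLoader`.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset as TorchTensorDataset

x = torch.ones(3, 1, 10, 10)  # three samples of (1,10,10) size
y = torch.arange(3)           # class labels = (0,1,2)
dataset = TorchTensorDataset(x, y)

# batch_size=None turns off batching, so each iteration yields a single sample.
sequential_loader = DataLoader(
    dataset, sampler=SequentialSampler(dataset), batch_size=None
)
random_loader = DataLoader(
    dataset, sampler=RandomSampler(dataset), batch_size=None
)

sequential_labels = [int(label) for _, label in sequential_loader]
random_labels = [int(label) for _, label in random_loader]

print(sequential_labels)  # [0, 1, 2]
print(random_labels)      # the same labels in a random order
```

The core samplers take the dataset as their argument, whereas the `torchsample` ones above take `nb_samples`.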


That was straightforward. Now, let's look at the first (of many to come) enhanced samplers. This next sampler is called `MultiSampler`, because it lets you take more samples in an epoch than there actually are samples in the data. This is incredibly useful if you're training on 2D slices from 3D images, or on crops which are much smaller than the total image size.


In [30]:
from torchsample.samplers import MultiSampler

ms = MultiSampler(nb_samples=3, desired_samples=10) # we have 3 samples, but we want 10 samples per epoch
multi_data = TensorDataset(x, y, sampler=ms)
for i, j in multi_data:
    print(j[0])

0
1
2
0
1
2
0
1
2
1


Above, we got 10 samples out of our data even though there are only three actual samples. You can see the routine here: loop over the data in order `floor(desired_samples / nb_samples)` times (here, 3 full passes), then make up the difference (here, 1 sample) by randomly taking more samples from the entire pool.
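
The routine described above can be sketched in a few lines. This is my own illustration, not the actual `torchsample` implementation: `multi_sample_indices` is a hypothetical helper, and drawing the leftover samples without replacement is an assumption.

```python
import torch


def multi_sample_indices(nb_samples, desired_samples):
    """Hypothetical helper mimicking the routine described in the text."""
    full_passes = desired_samples // nb_samples  # floor(desired / nb)
    remainder = desired_samples % nb_samples     # what is still missing

    # Loop over the data in order for every full pass.
    indices = list(range(nb_samples)) * full_passes

    # Assumption: the leftover samples are drawn at random without replacement.
    indices += torch.randperm(nb_samples)[:remainder].tolist()
    return indices


indices = multi_sample_indices(nb_samples=3, desired_samples=10)
print(indices[:9])   # [0, 1, 2, 0, 1, 2, 0, 1, 2]
print(len(indices))  # 10
```

With `desired_samples` smaller than `nb_samples`, there are zero full passes and everything comes from the random part.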

You can also use this sampler to take *fewer* samples than exist in your data:

In [31]:
ms = MultiSampler(nb_samples=3, desired_samples=2) # we have 3 samples, but we only want 2 samples per epoch
multi_data = TensorDataset(x, y, sampler=ms)
for i, j in multi_data:
    print(j[0])

2
0


Here, we just took two samples at random.

Samplers allow you to implement any type of sampling procedure you like, while abstracting away any data handling. This is a nice thing. It would be nice to have stratified sampling like in scikit-learn. Maybe that's next on the sampler to-do list.
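
To give a feel for how little code a custom sampler needs, here is a rough sketch of what a stratified sampler could look like. Everything in it is hypothetical: `StratifiedSubsetSampler`, its `labels`, `fraction` and `seed` arguments, and the rule of keeping at least one sample per class are my own choices, not something from `torchsample` or scikit-learn.

```python
from collections import defaultdict

import torch
from torch.utils.data import Sampler


class StratifiedSubsetSampler(Sampler):
    """Hypothetical sampler: each epoch draws the same fraction of every class."""

    def __init__(self, labels, fraction, seed=0):
        self.generator = torch.Generator().manual_seed(seed)
        # Group the integer indices by class label.
        self.indices_by_class = defaultdict(list)
        for index, label in enumerate(labels):
            self.indices_by_class[int(label)].append(index)
        # Keep at least one sample per class (an arbitrary choice for this sketch).
        self.counts = {
            label: max(1, int(len(indices) * fraction))
            for label, indices in self.indices_by_class.items()
        }

    def __iter__(self):
        chosen = []
        for label, indices in self.indices_by_class.items():
            order = torch.randperm(len(indices), generator=self.generator)
            chosen += [indices[i] for i in order[: self.counts[label]].tolist()]
        # Shuffle so the classes are not visited in blocks.
        mix = torch.randperm(len(chosen), generator=self.generator).tolist()
        return iter([chosen[i] for i in mix])

    def __len__(self):
        return sum(self.counts.values())


labels = [0, 0, 0, 0, 1, 1]  # four samples of class 0, two of class 1
sampler = StratifiedSubsetSampler(labels, fraction=0.5)
epoch = list(sampler)
print(len(epoch))                        # 3
print(sorted(labels[i] for i in epoch))  # [0, 0, 1]
```

Like the samplers above, it only ever deals with integer indices, so it works with any dataset that can be indexed that way.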

Now, one last thing to note is that you can actually sample indefinitely from datasets in `torchsample` by calling the `next()` method on the data. This returns the next sample, just like one step of the for loops above, but automatically resets the iterator once you've reached the end of the data. This can be useful for someone, I'm sure:

In [34]:
data = TensorDataset(x, y, shuffle=False)
count = 0
while True:
    i, j = data.next()
    print(j[0])
    count += 1
    if count > 10:
        break # break so we don't loop forever

0
1
2
0
1
2
0
1
2
0
1
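
I haven't shown how `torchsample` resets the iterator internally, but the behaviour itself is easy to sketch in plain Python. `endless` below is a hypothetical helper, not part of any library.

```python
from itertools import islice


def endless(iterable):
    """Hypothetical helper: yield items forever, restarting when exhausted."""
    while True:
        yielded_any = False
        for item in iterable:  # a fresh iterator is created on every pass
            yielded_any = True
            yield item
        if not yielded_any:
            return  # stop on an empty iterable instead of spinning forever


labels = [0, 1, 2]
first_five = list(islice(endless(labels), 5))
print(first_five)  # [0, 1, 2, 0, 1]
```

Because a fresh iterator is created on each pass, a shuffling data source would get reshuffled every time around, which `itertools.cycle` would not do since it replays a saved copy of the first pass.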
