Skip to content

Commit

Permalink
Merge pull request #1154 from zalandoresearch/sampler-refactor
Browse files Browse the repository at this point in the history
Sampler refactor
  • Loading branch information
Alan Akbik committed Sep 25, 2019
2 parents 899827b + a0233bd commit 316332c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 34 deletions.
58 changes: 27 additions & 31 deletions flair/samplers.py
@@ -1,4 +1,5 @@
import logging
from abc import abstractmethod
from collections import defaultdict

from torch.utils.data.sampler import Sampler
Expand All @@ -9,19 +10,33 @@
log = logging.getLogger("flair")


class ImbalancedClassificationDatasetSampler(Sampler):
class FlairSampler(Sampler):
def set_dataset(self, data_source):
"""Initialize by passing a block_size and a plus_window parameter.
:param data_source: dataset to sample from
"""
self.data_source = data_source
self.num_samples = len(self.data_source)

def __len__(self):
return self.num_samples


class ImbalancedClassificationDatasetSampler(FlairSampler):
"""Use this to upsample rare classes and downsample common classes in your unbalanced classification dataset.
"""

def __init__(self, data_source: FlairDataset):
def __init__(self):
super(ImbalancedClassificationDatasetSampler, self).__init__(None)

def set_dataset(self, data_source: FlairDataset):
"""
Initialize by passing a classification dataset with labels, i.e. either TextClassificationDataSet or
:param data_source:
"""
super().__init__(data_source)

self.data_source = data_source
self.num_samples = len(self.data_source)
self.indices = list(range(len(data_source)))
self.num_samples = len(data_source)

# first determine the distribution of classes in the dataset
label_count = defaultdict(int)
Expand All @@ -44,27 +59,17 @@ def __iter__(self):
for i in torch.multinomial(self.weights, self.num_samples, replacement=True)
)

def __len__(self):
return self.num_samples


class ChunkSampler(Sampler):
class ChunkSampler(FlairSampler):
"""Splits data into blocks and randomizes them before sampling. This causes some order of the data to be preserved,
while still shuffling the data.
"""

def __init__(self, data_source, block_size=5, plus_window=5):
"""Initialize by passing a block_size and a plus_window parameter.
:param data_source: dataset to sample from
:param block_size: minimum size of each block
:param plus_window: randomly adds between 0 and this value to block size at each epoch
"""
super().__init__(data_source)
self.data_source = data_source
self.num_samples = len(self.data_source)

def __init__(self, block_size=5, plus_window=5):
super(ChunkSampler, self).__init__(None)
self.block_size = block_size
self.plus_window = plus_window
self.data_source = None

def __iter__(self):
data = list(range(len(self.data_source)))
Expand All @@ -83,23 +88,17 @@ def __iter__(self):
data[:] = [b for bs in blocks for b in bs]
return iter(data)

def __len__(self):
return self.num_samples


class ExpandingChunkSampler(Sampler):
class ExpandingChunkSampler(FlairSampler):
"""Splits data into blocks and randomizes them before sampling. Block size grows with each epoch.
This causes some order of the data to be preserved, while still shuffling the data.
"""

def __init__(self, data_source, step=3):
def __init__(self, step=3):
"""Initialize by passing a block_size and a plus_window parameter.
:param data_source: dataset to sample from
"""
super().__init__(data_source)
self.data_source = data_source
self.num_samples = len(self.data_source)

super(ExpandingChunkSampler, self).__init__(None)
self.block_size = 1
self.epoch_count = 0
self.step = step
Expand All @@ -124,6 +123,3 @@ def __iter__(self):
self.block_size += 1

return iter(data)

def __len__(self):
return self.num_samples
13 changes: 10 additions & 3 deletions flair/trainers/trainer.py
Expand Up @@ -2,15 +2,17 @@
from pathlib import Path
from typing import List, Union
import time
import sys

import datetime
import sys
import inspect

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.sgd import SGD
from torch.utils.data.dataset import ConcatDataset

from flair.samplers import FlairSampler

try:
from apex import amp
except ImportError:
Expand Down Expand Up @@ -203,8 +205,13 @@ def train(
if train_with_dev:
train_data = ConcatDataset([self.corpus.train, self.corpus.dev])

# initialize sampler if provided
if sampler is not None:
sampler = sampler(train_data)
# init with default values if only class is provided
if inspect.isclass(sampler):
sampler = sampler()
# set dataset to sample from
sampler.set_dataset(train_data)
shuffle = False

dev_score_history = []
Expand Down

0 comments on commit 316332c

Please sign in to comment.