Merge 784ae13 into 3ea2346
ncapobianco committed Nov 12, 2020
2 parents 3ea2346 + 784ae13 commit f9abe63
Showing 6 changed files with 442 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/trw/train/__init__.py
@@ -13,6 +13,7 @@
 from .trainer import Trainer, create_losses_fn, epoch_train_eval, eval_loop, train_loop, \
     run_trainer_repeat, default_post_training_callbacks, default_per_epoch_callbacks, default_pre_training_callbacks, \
     default_sum_all_losses
+from .trainer_across_datasets import epoch_train_eval_across_datasets
 from .optimizers import create_sgd_optimizers_fn, create_sgd_optimizers_scheduler_step_lr_fn, \
     create_scheduler_step_lr, create_adam_optimizers_fn, \
     create_adam_optimizers_scheduler_step_lr_fn, create_optimizers_fn
@@ -64,6 +65,7 @@
 from .sequence_collate import SequenceCollate
 from .sequence_rebatch import SequenceReBatch
 from .sequence_sub_batch import SequenceSubBatch
+from .sequence_array_fixed_samples_per_epoch import SequenceArrayFixedSamplesPerEpoch

 from .metrics import Metric, MetricClassificationError, MetricClassificationBinarySensitivitySpecificity, MetricLoss, \
     MetricClassificationBinaryAUC, MetricClassificationF1
@@ -75,4 +77,4 @@
 from .meaningful_perturbation import MeaningfulPerturbation, default_information_removal_smoothing
 from .data_parallel_extented import DataParallelExtended

-from .compatibility import grid_sample
+from .compatibility import grid_sample, affine_grid
6 changes: 5 additions & 1 deletion src/trw/train/analysis_plots.py
@@ -259,8 +259,12 @@ def classification_report(

     cm = sklearn.metrics.confusion_matrix(y_pred=predictions, y_true=trues)
     labels = list_classes_from_mapping(class_mapping)
+    if labels is not None:
+        labels_range = np.arange(0, len(labels))
+    else:
+        labels_range = None
     try:
-        report_str = sklearn.metrics.classification_report(y_true=trues, y_pred=predictions, target_names=labels)
+        report_str = sklearn.metrics.classification_report(y_true=trues, y_pred=predictions, target_names=labels, labels=labels_range, zero_division=0)
     except ValueError as e:
         report_str = 'Report failed. Exception={}'.format(e)

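Note on the fix above: `sklearn.metrics.classification_report` infers the label set from the data when `labels` is not given, and raises a ValueError whenever fewer classes occur than there are entries in `target_names`. Passing the explicit range plus `zero_division=0` yields a report with zeroed metrics for absent classes instead. A minimal sketch of the failure mode, using synthetic data (the arrays and class names below are illustrative only):

import numpy as np
import sklearn.metrics

trues = np.array([0, 1, 1, 0])
predictions = np.array([0, 1, 0, 0])
target_names = ['cat', 'dog', 'bird']  # class 2 ('bird') never occurs in the data

# without labels=, sklearn infers the label set {0, 1}, which does not match
# the 3 entries of target_names and triggers a ValueError; the explicit range
# and zero_division=0 produce a report with 0.0 metrics for the absent class
report = sklearn.metrics.classification_report(
    y_true=trues,
    y_pred=predictions,
    target_names=target_names,
    labels=np.arange(0, len(target_names)),
    zero_division=0)
print(report)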
19 changes: 19 additions & 0 deletions src/trw/train/compatibility.py
@@ -21,3 +21,22 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corner
         mode=mode,
         padding_mode=padding_mode,
         align_corners=align_corners)
+
+
+def affine_grid(theta, size, align_corners=None):
+    """
+    Compatibility layer for the argument change between pytorch <= 1.2 and pytorch > 1.2.
+    See :func:`torch.nn.functional.affine_grid`.
+    Note: the default behavior is align_corners=True for pytorch <= 1.2 and align_corners=False for pytorch > 1.2.
+    """
+
+    version = torch.__version__[:3]
+    if version in ('1.0', '1.1', '1.2'):
+        return torch.nn.functional.affine_grid(
+            theta=theta,
+            size=size)
+    else:
+        return torch.nn.functional.affine_grid(
+            theta=theta,
+            size=size,
+            align_corners=align_corners)
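The two wrappers can be combined so that resampling code stays version-agnostic. A minimal usage sketch (the tensors are synthetic; the import path assumes the installed trw package layout):

import torch
from trw.train.compatibility import affine_grid, grid_sample

theta = torch.eye(2, 3).unsqueeze(0)  # identity 2D affine, shape (1, 2, 3)
image = torch.rand(1, 1, 8, 8)        # N, C, H, W

# the same call works on every supported pytorch version; align_corners is
# silently dropped for pytorch <= 1.2, where the argument did not exist yet
grid = affine_grid(theta, size=list(image.shape), align_corners=False)
resampled = grid_sample(image, grid, align_corners=False)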
94 changes: 94 additions & 0 deletions src/trw/train/sequence_array_fixed_samples_per_epoch.py
@@ -0,0 +1,94 @@
import trw.utils
from trw.train import sequence
from trw.train import sampler as sampler_trw
from trw.train import SequenceArray
from trw.utils import get_batch_n
import numpy as np
import collections.abc


# this is the name used for the sample UID
sample_uid_name = 'sample_uid'


class SequenceArrayFixedSamplesPerEpoch(SequenceArray):
    """
    Create a sequence of batches from numpy arrays, lists and :class:`torch.Tensor`.

    If `number_of_samples_per_epoch` is specified, each call to `iter()` only iterates
    through a chunk of that many samples. If `number_of_samples_per_epoch` is None,
    behave as :class:`SequenceArray`.

    Note: the sampler iterator is not copied; to get multiple independent iterators, it is
    necessary to create multiple instances of SequenceArrayFixedSamplesPerEpoch, each with
    its own Sampler.
    """
    def __init__(
            self,
            split,
            sampler=None,
            transforms=None,
            use_advanced_indexing=True,
            sample_uid_name=sample_uid_name,
            number_of_samples_per_epoch=None):
        """
        Args:
            split: a dictionary of tensors. Tensors may be `numpy.ndarray`, `torch.Tensor` or numeric
            sampler: the sampler used to iterate through the sequence. If `None`, a
                `SamplerRandom` is used
            transforms: a transform or list of transforms to be applied on each batch of data
            use_advanced_indexing: forwarded as-is to :class:`SequenceArray`
            sample_uid_name: if not `None`, create a unique UID per sample so that it is easy to track
                particular samples (e.g., during data augmentation)
            number_of_samples_per_epoch: if not `None`, the maximum number of samples generated
                per epoch before the iterator stops
        """
        if sampler is None:
            sampler = sampler_trw.SamplerRandom()
        super().__init__(
            split,
            sampler=sampler,
            transforms=transforms,
            use_advanced_indexing=use_advanced_indexing,
            sample_uid_name=sample_uid_name)
        self.number_of_samples_per_epoch = number_of_samples_per_epoch
        self.sampler.initializer(self.split)
        self.sampler_iterator = iter(self.sampler)

    def subsample(self, nb_samples):
        raise NotImplementedError()

    def subsample_uids(self, uids, uids_name, new_sampler=None):
        raise NotImplementedError()

    def __iter__(self):
        return SequenceArrayFixedSamplesPerEpochIterator(self)


class SequenceArrayFixedSamplesPerEpochIterator(sequence.SequenceIterator):
    """
    Iterate the elements of a :class:`trw.train.SequenceArray` sequence.

    Assumptions:
        - the underlying `base_sequence` does not change size while iterating
    """
    def __init__(self, base_sequence):
        super().__init__()
        self.base_sequence = base_sequence
        self.nb_samples = trw.utils.len_batch(self.base_sequence.split)
        self.number_samples_generated = 0

    def __next__(self):
        if self.base_sequence.number_of_samples_per_epoch is not None and \
                self.number_samples_generated >= self.base_sequence.number_of_samples_per_epoch:
            # we have reached the maximum number of samples, stop the sequence
            raise StopIteration()

        try:
            indices = next(self.base_sequence.sampler_iterator)
        except StopIteration:
            # the sampler is exhausted: re-initialize it so subsequent epochs
            # start from a freshly initialized sampler
            self.base_sequence.sampler.initializer(self.base_sequence.split)
            self.base_sequence.sampler_iterator = iter(self.base_sequence.sampler)
            if self.base_sequence.number_of_samples_per_epoch is not None:
                indices = next(self.base_sequence.sampler_iterator)
            else:
                # behave as SequenceArray: one full pass per epoch
                raise StopIteration()

        if not isinstance(indices, (np.ndarray, collections.abc.Sequence)):
            indices = [indices]

        self.number_samples_generated += len(indices)

        return get_batch_n(
            self.base_sequence.split,
            self.nb_samples,
            indices,
            self.base_sequence.transforms,
            self.base_sequence.use_advanced_indexing)
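A hypothetical usage sketch for the new sequence (the split contents are synthetic, and SamplerRandom is assumed to yield one sample per batch by default): each call to `iter()` consumes at most `number_of_samples_per_epoch` samples, and because the sampler iterator is shared across epochs, successive epochs walk through successive chunks of the data.

import numpy as np
from trw.train import SequenceArrayFixedSamplesPerEpoch

split = {
    'images': np.random.rand(100, 1, 8, 8),
    'targets': np.arange(100) % 2,
}

# each epoch sees at most 20 of the 100 samples; the sampler is only
# re-initialized once all 100 samples have been consumed
seq = SequenceArrayFixedSamplesPerEpoch(split, number_of_samples_per_epoch=20)
for epoch in range(5):
    for batch in seq:
        pass  # train on `batch` here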