Skip to content

Commit

Permalink
added tests & deprecation in dense layers
Browse files Browse the repository at this point in the history
  • Loading branch information
civodlu committed Jan 31, 2020
1 parent 34fa1dc commit d129c6d
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 8 deletions.
32 changes: 26 additions & 6 deletions src/trw/layers/denses.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
import torch.nn as nn
from trw.layers.flatten import Flatten
import warnings


def denses(sizes, dropout_probability=None, with_batchnorm=False, batchnorm_momentum=0.1, activation=nn.ReLU, last_layer_is_output=False, with_flatten=True):
def denses(
sizes,
dropout_probability=None,
with_batchnorm=None,
batchnorm_momentum=0.1,
activation=nn.ReLU,
last_layer_is_output=False,
with_flatten=True,
batch_norm_kwargs=None):
"""
Args:
sizes:
dropout_probability:
with_batchnorm:
batchnorm_momentum:
activation:
with_batchnorm: deprecated. Use `batch_norm_kwargs`
batchnorm_momentum: deprecated Use `batch_norm_kwargs`
activation: the activation to be used
last_layer_is_output: This must be set to `True` if the last layer of dense is actually an output. If the last layer is an output,
we should not add batch norm, dropout or activation of the last `nn.Linear`
with_flatten: if True, the input will be flattened
batch_norm_kwargs: specify the arguments to be used by the batch normalization layer
Returns:
a nn.Module
"""
ops = []

if with_batchnorm is not None:
warnings.warn('trw.layers.denses `with_batchnorm` and `batchnorm_momentum` arguments '
'are deprecated. Use `batch_norm_kwargs` instead!')
assert batch_norm_kwargs is None


if with_flatten:
ops.append(Flatten())
Expand All @@ -34,8 +50,12 @@ def denses(sizes, dropout_probability=None, with_batchnorm=False, batchnorm_mome
else:
ops.append(activation())

if with_batchnorm:
if with_batchnorm is not None and with_batchnorm is True:
# deprecated. TODO remove in next releases!
ops.append(nn.BatchNorm1d(next, momentum=batchnorm_momentum))

if batch_norm_kwargs is not None:
ops.append(nn.BatchNorm1d(next, **batch_norm_kwargs))

if dropout_probability is not None:
ops.append(nn.Dropout(p=dropout_probability))
Expand Down
10 changes: 8 additions & 2 deletions src/trw/train/callback_export_classification_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,16 @@ class CallbackExportClassificationReport(callback.Callback):
"""
max_class_names = 40

def __init__(self, with_confusion_matrix=True, with_ROC=True, with_history=True, with_report=True):
def __init__(self, with_confusion_matrix=True, with_ROC=True, with_report=True):
"""
Args:
with_confusion_matrix: if True, the confusion matrix will be exported
with_ROC: if True, the ROC curve will be exported
with_report: if True, the sklearn report will be exported
"""
self.with_confusion_matrix = with_confusion_matrix
self.with_ROC = with_ROC
self.with_history = with_history
self.with_report = with_report

def __call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions src/trw/train/callback_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@ class CallbackClearTensorboardLog(CallbackTensorboardBased):
def __call__(self, options, history, model, losses, outputs, datasets, datasets_infos, callbacks_per_batch, **kwargs):
CallbackTensorboardBased.remove_tensorboard_logger()
logger.debug('CallbackTensorboardBased.remove_tensorboard_logger called!')

58 changes: 58 additions & 0 deletions tests/test_callback_export_classification_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn as nn
import torch
import os
import numpy as np


class Criterion:
Expand Down Expand Up @@ -109,5 +110,62 @@ def test_basic(self):
assert 'x_truth_str=str_' in lines[-1]


def test_classification_report(self):
output_mappings = {
'good': {
'mappinginv': {
0: 'str_0',
1: 'str_1',
}
}
}

datasets_infos = {
'dataset1': {
'split1': {
'output_mappings': output_mappings
}
}
}

output_raw = np.random.randn(10, 2)
truth = np.random.randint(0, 2, 10)
outputs = {
'dataset1': {
'split1': {
'output1': {
'output_ref': trw.train.OutputClassification(None, 'good'),
'output_raw': output_raw,
'output': np.argmax(output_raw, axis=1),
'output_truth': truth
}
}
}
}

callback = trw.train.CallbackExportClassificationReport()
options = trw.train.create_default_options(device=torch.device('cpu'))
options['workflow_options']['current_logging_directory'] = os.path.join(
options['workflow_options']['logging_directory'],
'test_classification_report')
root_output = options['workflow_options']['current_logging_directory']
trw.train.create_or_recreate_folder(options['workflow_options']['current_logging_directory'])
callback(options, None, None, None, outputs, None, datasets_infos, None)

path_report = os.path.join(root_output, 'output1-dataset1-split1-report.txt')
path_roc = os.path.join(root_output, 'output1-dataset1-split1-ROC.png')
path_cm = os.path.join(root_output, 'output1-dataset1-split1-cm.png')
assert os.path.exists(path_report)
assert os.path.exists(path_roc)
assert os.path.exists(path_cm)

with open(path_report, 'r') as f:
lines = ''.join(f.readlines())

# make sure the class mapping was correct
assert 'str_0' in lines
assert 'str_1' in lines


if __name__ == '__main__':
unittest.main()
53 changes: 53 additions & 0 deletions tests/test_callback_tensorboard_record_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest
import trw
import torch.nn as nn
import torch
import os


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.convs = trw.layers.ConvsBase(
cnn_dim=2,
input_channels=3,
channels=[4, 8],
strides=[2, 2],
with_flatten=True)
self.classifier = nn.Linear(32, 2)

def forward(self, batch):
x = batch['image']
x = self.convs(x)
x = self.classifier(x)

return {
'fake_symbols_2d': trw.train.OutputClassification(x, classes_name='classification')
}


def create_dataset():
return trw.datasets.create_fake_symbols_2d_datasset(
nb_samples=20,
global_scale_factor=0.5,
image_shape=[32, 32],
nb_classes_at_once=1,
max_classes=2)


class TestCallbackTensorboardRecordModel(unittest.TestCase):
def test_basic(self):
options = trw.train.create_default_options(device=torch.device('cpu'))
callback = trw.train.CallbackTensorboardRecordModel(onnx_folder='onnx_export')

onnx_root = os.path.join(options['workflow_options']['current_logging_directory'], 'onnx_export')
trw.train.create_or_recreate_folder(onnx_root)

model = Net()
datasets = create_dataset()
callback(options, None, model, None, None, datasets, None, None)
assert os.path.exists(os.path.join(onnx_root, 'model.onnx'))


if __name__ == '__main__':
unittest.main()

0 comments on commit d129c6d

Please sign in to comment.