Skip to content

Commit

Permalink
UNet tuple stride support
Browse files Browse the repository at this point in the history
  • Loading branch information
civodlu committed Mar 29, 2021
1 parent 0a6af07 commit e700a23
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/trw/layers/blocks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import warnings
from numbers import Number

import torch
from trw.basic_typing import TorchTensorNCX, Padding, KernelSize, Stride
Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(
assert stride is not None

ops = []
if stride != 1:
if (isinstance(stride, Number) and stride != 1) or (max(stride) != 1 or min(stride) != 1):
# if stride is 1, don't upsample!
ops.append(config.ops.upsample_fn(scale_factor=stride))

Expand Down
12 changes: 11 additions & 1 deletion src/trw/layers/unet_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from numbers import Number
from typing import Sequence, Optional, Union, Any, List

from trw.layers.convs import ModuleWithIntermediate
Expand All @@ -9,6 +10,7 @@
from trw.utils import upsample
from trw.layers.blocks import BlockConvNormActivation, BlockUpDeconvSkipConv, ConvBlockType
from trw.layers.layer_config import LayerConfig, default_layer_config
import numpy as np


class DownType(Protocol):
Expand Down Expand Up @@ -248,13 +250,21 @@ def _build(self, config, init_block_fn, down_block_fn, up_block_fn, middle_block
out_channels = skip_channels

stride = strides[len(strides) - i - 1]
if isinstance(stride, Number):
stride_minus_one = stride - 1
else:
assert len(stride) == config.ops.dim, f'expected dim={config.ops.dim}' \
f'for `stride` but got={len(stride)}'
stride_minus_one = tuple(np.asarray(stride))
stride = tuple(stride)

self.ups.append(up_block_fn(
config,
len(self.channels) - i - 1,
skip_channels=skip_channels,
input_channels=input_channels,
output_channels=out_channels,
output_padding=stride - 1,
output_padding=stride_minus_one,
stride=stride))
input_channels = skip_channels

Expand Down
2 changes: 1 addition & 1 deletion src/trw/train/outputs_trw.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def evaluate_batch(self, batch, is_training):
assert truth is not None, 'truth is `None` use `maybe_optional` to True'
assert len(truth) == len(self.output), f'expected len output ({len(self.output)}) == len truth ({len(truth)})!'
assert isinstance(truth, torch.Tensor), 'feature must be a torch.Tensor!'
assert truth.dtype == torch.long, 'the truth vector must be a `long` type feature'
assert truth.dtype == torch.long, f'the truth vector must be a `long` type feature, got={truth.dtype}'

# make sure the class is not out of bound. This is a very common mistake!
# max_index = int(torch.max(truth).cpu().numpy())
Expand Down

0 comments on commit e700a23

Please sign in to comment.