4 changes: 3 additions & 1 deletion models_pytorch/classification/__init__.py
@@ -5,11 +5,13 @@
def get_model(base='', **model_params):
base = base.lower()
model_cls = ClassificationModel ## right now only this one exists, so we can figure that out later
encoder_preprocessing = model_params.pop('preprocessing', True)
model = model_cls(**model_params)

encoder_weights = model_params['encoder_weights']

preprocessing = None
if encoder_weights is not None:
if encoder_weights is not None and encoder_preprocessing:
preprocessing = get_preprocessing_fn(model_params['encoder'], pretrained=encoder_weights)

return model, preprocessing
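
For reference, a minimal usage sketch of the new `preprocessing` flag in get_model; the encoder name, the weights identifier, and nclasses below are assumed example values rather than anything taken from this diff:

from models_pytorch.classification import get_model

# Default behaviour: a preprocessing function is returned whenever
# pretrained encoder weights are requested.
model, preprocessing = get_model(
    base='classification',
    encoder='resnet34',          # assumed encoder name
    encoder_weights='imagenet',  # assumed weights identifier
    nclasses=4,                  # assumed model parameter
)

# New behaviour: passing preprocessing=False keeps the second return
# value as None even though encoder_weights is set.
model, preprocessing = get_model(
    base='classification',
    encoder='resnet34',
    encoder_weights='imagenet',
    nclasses=4,
    preprocessing=False,
)
assert preprocessing is None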
2 changes: 1 addition & 1 deletion models_pytorch/classification/classification.py
@@ -46,7 +46,7 @@ def predict(self, x, **args):
if self.training:
self.eval()
with torch.no_grad():
features = self.encoder(x, **args)
features = self.encoder(x)
scalars = {k: args[k] for k in self.required_inputs}
output = self.classifier.predict(features, **scalars)
return output
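
With this change, keyword arguments no longer reach the encoder; only the classifier's declared required_inputs are filtered out of **args and passed along. A self-contained toy sketch of that routing, with all names below being illustrative assumptions:

import torch
import torch.nn as nn

class _ToyClassifier(nn.Module):
    # Stand-in classifier that declares one hypothetical scalar input.
    required_inputs = ['age']

    def predict(self, features, **scalars):
        return features[0].mean() + scalars['age'].mean()

def _toy_encoder(x):
    # Stand-in encoder: only the image tensor is passed in.
    return [nn.functional.adaptive_avg_pool2d(x, 1)]

classifier = _ToyClassifier()

def predict_like(x, **args):
    features = _toy_encoder(x)                                 # no **args here any more
    scalars = {k: args[k] for k in classifier.required_inputs} # filter declared scalars
    return classifier.predict(features, **scalars)

out = predict_like(torch.randn(2, 3, 32, 32), age=torch.tensor([[0.4], [0.7]]))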
4 changes: 2 additions & 2 deletions models_pytorch/classification/classifiers/__init__.py
@@ -1,11 +1,11 @@
from .basic import BasicClassifier
from .extra_inputs import ExtraScalarInputsClassifier
from .conv_lstm import ConvLSTMClassifier
from .srn import SpatialRegularizationClassifier

classifier_map = {
'basic': BasicClassifier,
'srn': SpatialRegularizationClassifier,
'extra_inputs': ExtraScalarInputsClassifier,
'conv_lstm': ConvLSTMClassifier,
}
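
With the new entry, the 'conv_lstm' key resolves to the ConvLSTMClassifier added later in this PR; a minimal lookup sketch using the mapping defined in this file:

from models_pytorch.classification.classifiers import classifier_map

classifier_cls = classifier_map['conv_lstm']   # -> ConvLSTMClassifier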


48 changes: 35 additions & 13 deletions models_pytorch/classification/classifiers/_classifiers.py
@@ -3,16 +3,17 @@

from models_pytorch.utils import Flatten, get_activation

pooling_types = {
'avg': nn.AdaptiveAvgPool2d,
'max': nn.AdaptiveMaxPool2d,
}

class BasicClassifier(nn.Module):

pooling_types = {
'avg': nn.AdaptiveAvgPool2d,
'max': nn.AdaptiveMaxPool2d,
}
class BasicClassifier(nn.Module):
pooling_types = pooling_types

def __init__(self, encoder_channels, nclasses, hidden_layers=(), pool_type='avg', channel_index=0,
activation='sigmoid'):
activation='sigmoid', extra_inputs=None):

self.nclasses = nclasses
self.activation_type = str(activation)
@@ -22,14 +23,26 @@ def __init__(self, encoder_channels, nclasses, hidden_layers=(), pool_type='avg'
self.channel_index = channel_index
input_shape = encoder_channels[self.channel_index]

super().__init__()

self.required_inputs = []
self.extra_input_fcs = None
if extra_inputs is not None:
self.required_inputs += list(extra_inputs.keys())
self.extra_input_fcs = nn.ModuleDict()
for key, nh in extra_inputs.items():
self.extra_input_fcs[key] = nn.Linear(1, nh)
input_shape += nh

input_shapes = [input_shape]
input_shapes.extend(hidden_layers)

output_shapes = list(hidden_layers)
output_shapes.append(nclasses)

super().__init__()
modules = [self.pooling_types[pool_type]((1, 1)), Flatten()]
self.pool = self.pooling_types[pool_type]((1, 1))
self.flatten = Flatten()
modules = [nn.Sequential(), nn.Sequential()] # this is for backwards compat
for ih, oh in zip(input_shapes, output_shapes):
modules.extend([nn.Linear(ih, oh),
nn.ReLU()])
@@ -38,14 +51,23 @@ def __init__(self, encoder_channels, nclasses, hidden_layers=(), pool_type='avg'
self.classifier = nn.Sequential(*modules)
self.activation = get_activation(activation)

def forward(self, features):
return self.classifier(features[self.channel_index])

def predict(self, features):
def forward(self, features, **inputs):
encoder_features = self.pool(features[self.channel_index])
outputs = [encoder_features]
if self.extra_input_fcs is not None:
inputs = {k: v for k, v in inputs.items() if v is not None}
assert all(k in inputs.keys() for k in self.extra_input_fcs.keys()), 'incorrect keys input into network'
for k, fc in self.extra_input_fcs.items():
h = self.flatten(fc(inputs[k]))
h = torch.unsqueeze(torch.unsqueeze(h, -1), -1)
outputs.append(h)
return self.classifier(self.flatten((torch.cat(outputs, dim=1))))

def predict(self, features, **inputs):
if self.training:
self.eval()
with torch.no_grad():
return self.activation(self.forward(features))
return self.activation(self.forward(features, **inputs))


classifier_types = {'basic': BasicClassifier}
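
A hedged usage sketch of the new extra_inputs path, assuming the BasicClassifier defined above in _classifiers.py is in scope; the scalar name 'age', the channel list, and the batch sizes are assumptions chosen for illustration:

import torch

clf = BasicClassifier(
    encoder_channels=(512, 256),        # assumed channel list; channel_index=0 selects 512
    nclasses=4,
    extra_inputs={'age': 16},           # each declared scalar gets a Linear(1, 16) embedding
)

features = [torch.randn(2, 512, 7, 7)]  # list of feature maps, as produced by the encoder
age = torch.randn(2, 1)                 # shape (batch, 1) so the Linear(1, 16) layer applies
logits = clf(features, age=age)         # pooled features and embedded scalar are concatenated
probs = clf.predict(features, age=age)  # same forward pass, followed by the sigmoid activation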
6 changes: 4 additions & 2 deletions models_pytorch/classification/classifiers/basic.py
@@ -26,6 +26,8 @@ def __init__(self, encoder_channels, tasks='cls', classifier_params=None, activa
d['encoder_channels'] = encoder_channels
d.update(**kwargs)
self.classifiers = nn.ModuleDict({name: get_classifier(**args) for name, args in task_params.items()})
self.required_inputs = list(set(i for c in self.classifiers.values() for i in c.required_inputs))

self.is_multi_task = len(self.classifiers) > 1

def output_info(self):
@@ -36,7 +38,7 @@ def tasks(self):
return list(self.output_info().keys())

def forward(self, features, **kwargs):
output = [(name, classifier(features)) for name, classifier in self.classifiers.items()]
output = [(name, classifier(features, **kwargs)) for name, classifier in self.classifiers.items()]
if not self.is_multi_task:
output = output[0][1]
else:
@@ -45,7 +47,7 @@
return output

def predict(self, features, **kwargs):
output = [(name, classifier.predict(features)) for name, classifier in self.classifiers.items()]
output = [(name, classifier.predict(features, **kwargs)) for name, classifier in self.classifiers.items()]

if not self.is_multi_task:
output = output[0][1]
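
The head now also collects the declared scalar inputs of its sub-classifiers and forwards keyword arguments to each of them. A toy sketch of that aggregation, with the task names and the 'age' key assumed for illustration:

import torch.nn as nn

class _Toy(nn.Module):
    # Stand-in for a per-task classifier that declares its required scalars.
    def __init__(self, required):
        super().__init__()
        self.required_inputs = required

classifiers = nn.ModuleDict({'cls': _Toy(['age']), 'aux': _Toy([])})

# Same set-union expression as in the diff above: the union of every
# sub-classifier's required_inputs, supplied as keywords at call time.
required_inputs = list(set(i for c in classifiers.values() for i in c.required_inputs))
# required_inputs == ['age']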
228 changes: 228 additions & 0 deletions models_pytorch/classification/classifiers/conv_lstm.py
@@ -0,0 +1,228 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from models_pytorch.utils import get_activation


# Reference: Modified from: (1) https://github.com/ndrplz/ConvLSTM_pytorch/blob/master/convlstm.py
# (2) https://github.com/automan000/Convolution_LSTM_PyTorch/blob/master/convolution_lstm.py

# ConvLSTM Cell
class ConvLSTMCell(nn.Module):
"""
Conv-LSTM Cell based on "Convolutional LSTM Network: A Machine Learning Approach
for Precipitation Nowcasting", arXiv: https://arxiv.org/pdf/1506.04214.pdf
"""

def __init__(self, feature_size, in_planes, hidden_planes, kernel_size, bias=True):
"""
feature_size: (int, int)
(height, width) of input feature (tensor)
in_planes: int
Number of channels in input feature (tensor)
hidden_planes: int
Number of channels in hidden state (tensor)
kernel_size: int
Size of the convolutional kernel
bias: bool
Whether to add the bias or not
"""
super(ConvLSTMCell, self).__init__()
self.height, self.width = feature_size
self.in_planes = in_planes
self.hidden_planes = hidden_planes
self.kernel = kernel_size
self.padding = kernel_size // 2
self.bias = bias

# Equation: 3 from paper
# Wxi * Xt
self.WiXconv = nn.Conv2d(in_channels=self.in_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

# Whi * Ht-1
self.WiHconv = nn.Conv2d(in_channels=self.hidden_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

# Wxf * Xt
self.WfXconv = nn.Conv2d(in_channels=self.in_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

# Whf * Ht-1
self.WfHconv = nn.Conv2d(in_channels=self.hidden_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

# Wxc * Xt
self.WcXconv = nn.Conv2d(in_channels=self.in_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

# Whc * Ht-1
self.WcHconv = nn.Conv2d(in_channels=self.hidden_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

# Wxo * Xt
self.WoXconv = nn.Conv2d(in_channels=self.in_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

# Who * Ht-1
self.WoHconv = nn.Conv2d(in_channels=self.hidden_planes, out_channels=self.hidden_planes,
kernel_size=self.kernel, stride=1, padding=self.padding, bias=self.bias)

self.Wci = nn.Parameter(torch.zeros(1, self.hidden_planes, self.height, self.width))
self.Wcf = nn.Parameter(torch.zeros(1, self.hidden_planes, self.height, self.width))
self.Wco = nn.Parameter(torch.zeros(1, self.hidden_planes, self.height, self.width))

def forward(self, x, h_prev, c_prev):
"""
x: Tensor from CNN
h_prev: Output from previous step
c_prev: Output from previous step
"""
# Ignoring bias in the equations below
it = torch.sigmoid(self.WiXconv(x) + self.WiHconv(h_prev) + c_prev * self.Wci)
ft = torch.sigmoid(self.WfXconv(x) + self.WfHconv(h_prev) + c_prev * self.Wcf)
ct = ft * c_prev + it * torch.tanh(self.WcXconv(x) + self.WcHconv(h_prev))
ot = torch.sigmoid(self.WoXconv(x) + self.WoHconv(h_prev) + ct * self.Wco)
ht = ot * torch.tanh(ct)
return ht, ct

def init_hidden(self, batch_size):
return (Variable(torch.zeros(batch_size, self.hidden_planes, self.height, self.width)).cuda(),
Variable(torch.zeros(batch_size, self.hidden_planes, self.height, self.width)).cuda())


# ConvLSTM
class ConvLSTM(nn.Module):
"""
ConvLSTM based on "Convolutional LSTM Network: A Machine Learning Approach
for Precipitation Nowcasting", arXiv: https://arxiv.org/pdf/1506.04214.pdf
"""

def __init__(self, feature_size, in_planes, hidden_planes, kernel_size, bias=True):
"""
feature_size: (int, int)
(height, width) of input feature (tensor)
in_planes: int
Number of channels in input feature (tensor)
hidden_planes: list of int
List of number of channels in hidden state (tensor)
kernel_size: int
Size of the convolutional kernel
bias: bool
Whether to add the bias or not
"""
super(ConvLSTM, self).__init__()
self.height, self.width = feature_size
self.in_planes = [in_planes] + hidden_planes
self.hidden_planes = hidden_planes
self.kernel_size = kernel_size
self.bias = bias
self.num_cells = len(hidden_planes)
cell_list = []
for i in range(self.num_cells):
cell = ConvLSTMCell(feature_size=(self.height, self.width),
in_planes=self.in_planes[i],
hidden_planes=self.hidden_planes[i],
kernel_size=self.kernel_size, bias=self.bias)
cell_list.append(cell)
self.cells = nn.ModuleList(cell_list)

def _init_hidden(self, batch_size):
init_states = []
for i in range(self.num_cells):
init_states.append(self.cells[i].init_hidden(batch_size))
return init_states

def forward(self, feature, hidden_state=None):
"""
feature: 5-D Tensor of shape (BS, T, C, H, W)
"""
if hidden_state is None:
hidden_state = self._init_hidden(batch_size=feature.size(0))

# Sequence length
T = feature.size(1)
curr_input = feature

cells_output = []
last_state_list = []
for cell_idx in range(self.num_cells):
h, c = hidden_state[cell_idx]
# Loop through sequence
seq_output = []
for t in range(T):
h, c = self.cells[cell_idx](x=curr_input[:, t, :, :, :],
h_prev=h, c_prev=c)
seq_output.append(h)

cell_output = torch.stack(seq_output, dim=1)
curr_input = cell_output

cells_output.append(cell_output)
last_state_list.append([h, c])

return cells_output[-1:], last_state_list[-1:]


class FullyConnected(nn.Module):
def __init__(self, in_features=64, nclasses=6, use_relu6=True):
super(FullyConnected, self).__init__()
self.in_features = in_features
self.nclasses = nclasses
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc1 = nn.Linear(self.in_features, self.in_features * 2)
if use_relu6:
self.relu = nn.ReLU6(inplace=True)
else:
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Linear(self.in_features * 2, self.nclasses)

def forward(self, x):
x = x.squeeze(0)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x


class ConvLSTMClassifier(nn.Module):

def __init__(self, encoder_channels, activation='sigmoid',
nclasses=6,
feature_shape=(7, 7),
hidden_planes=(128, 64, 64), kernel_size=5, bias=True,
encoder_index=0,
fc_relu6=True,
):
self.encoder_index = encoder_index
self.activation_type = str(activation)
self.hidden_planes = list(hidden_planes)
super(ConvLSTMClassifier, self).__init__()
self.activation = get_activation(activation)

self.is_multi_task = False

self.nclasses = nclasses
self.feature_shape = tuple(feature_shape)

self.conv_lstm = ConvLSTM(feature_shape, encoder_channels[self.encoder_index], self.hidden_planes,
kernel_size, bias=bias)

self.fc = FullyConnected(in_features=self.hidden_planes[-1], nclasses=self.nclasses,
use_relu6=fc_relu6)

def output_info(self):
return {'final': {'nclasses': self.nclasses, 'activation': self.activation_type}}

def forward(self, features, **kwargs):
features = features[self.encoder_index]
features = F.interpolate(features, size=self.feature_shape).unsqueeze(0)
cells_output, states_output = self.conv_lstm(features)
return self.fc(cells_output[0])

def predict(self, features, **kwargs):
return self.activation(self.forward(features, **kwargs))
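
A hedged end-to-end sketch of the new ConvLSTMClassifier; the encoder channel count, feature map size, and batch size are assumptions, and a GPU is required here because ConvLSTMCell.init_hidden allocates its hidden state with .cuda():

import torch
from models_pytorch.classification.classifiers import ConvLSTMClassifier

clf = ConvLSTMClassifier(
    encoder_channels=(512,),        # encoder_index=0 selects this entry
    nclasses=6,
    feature_shape=(7, 7),
    hidden_planes=(128, 64, 64),
    kernel_size=5,
).cuda()

# forward() resizes the selected feature map to feature_shape and adds a
# leading dimension, so the batch axis is treated as the LSTM sequence axis.
features = [torch.randn(4, 512, 14, 14).cuda()]
logits = clf(features)              # -> shape (4, 6)
probs = clf.predict(features)       # sigmoid-activated by default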