In [131]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [135]:
class TreeNNLayer(nn.Module):
    def __init__(self, layer_class, out_features, *in_layers, **kwargs):
        super(__class__, self).__init__()
        
        self.in_layers = in_layers
        if len(in_layers):
            self.has_in_layers = True
            self.in_features = sum(map(lambda l: l.out_features, in_layers))
            self.in_key = None
        else:
            self.has_in_layers = False
            self.in_features = kwargs['in_features']
            self.in_key = kwargs['in_key']
            
        self.out_features = out_features
        
        layer_args = {}
        if 'layer_args' in kwargs:
            layer_args = kwargs['layer_args']

        # FIXME find a better way to know whether to pass in/out features
        if getattr(nn, layer_class).__module__.endswith('activation'):
            self.layer = getattr(nn, layer_class)(**layer_args)
        else:
            self.layer = getattr(nn, layer_class)(self.in_features, self.out_features, **layer_args)
        
    def forward(self, X):
        if not self.has_in_layers:
            assert isinstance(X, dict)
            return self.layer(X[self.in_key])
        
        inputs = []
        for layer in self.in_layers:
            inputs.append(layer(X))
        
        return self.layer(torch.cat(inputs, 1))


class TreeNN(nn.Module):
    def __init__(self, loss_class, optimizer_class, lr, out_layer, loss_args = {}, optimizer_args = {}):
        super(__class__, self).__init__()
        
        self.out_layer = out_layer
        
        self.loss = getattr(nn, loss_class)(**loss_args)
        
        optimizer_args['lr'] = lr
        self.optimizer = getattr(optim, optimizer_class)(self.parameters(), **optimizer_args)
        
    def forward(self, X):
        return self.out_layer(X)
        
    def fit(self, X, Y):
        Y_hat = self(X)
        
        loss = self.loss(Y_hat, Y)
        loss.backward()
        
        self.optimizer.step()
        self.zero_grad()
        
        return loss.item()
        

In [136]:
treenn = TreeNN(
    'MSELoss', 'Adagrad', 0.1, TreeNNLayer('Linear', 1,
        TreeNNLayer('ReLU', 8,
            TreeNNLayer('Linear', 8, in_features = 4, in_key = 'foo')
        ),
        TreeNNLayer('ReLU', 8,
            TreeNNLayer('Linear', 8, in_features = 4, in_key = 'bar')
        )
    )
)
treenn

TreeNN(
  (out_layer): TreeNNLayer(
    (layer): Linear(in_features=16, out_features=1, bias=True)
  )
  (loss): MSELoss()
)

In [142]:
treenn.fit({
    'foo': torch.tensor([1, 2, 3, 4], dtype=torch.float).view(1, -1),
    'bar': torch.tensor([10, 20, 30, 40], dtype=torch.float).view(1, -1)
}, torch.tensor([5], dtype=torch.float).view(1, -1))

0.0003173103032168001