In [2]:
%load_ext autoreload
%autoreload 2

In [7]:
import torch
import torch.nn as nn

from torch.nn import functional as F

In [20]:
n_hidden = 128
batch_size = 128
num_epochs = 26
lr = 0.005

tokens = ['<', '>', '#', '%', ')', '(', '+', '-', '/', '.', '1', '0', '3', '2', '5', '4', '7',
          '6', '9', '8', '=', 'A', '@', 'C', 'B', 'F', 'I', 'H', 'O', 'N', 'P', 'S', '[', ']',
          '\\', 'c', 'e', 'i', 'l', 'o', 'n', 'p', 's', 'r', '\n']
tokens = ''.join(tokens) + ' '

In [16]:
def identity(input):
    return input

In [17]:
class OpenChemMLP(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.hidden_size = self.params['hidden_size']
        self.input_size = [self.params['input_size']] + self.hidden_size[:-1]
        self.n_layers = self.params['n_layers']
        self.activation = self.params['activation']
        if type(self.activation) is list:
            assert len(self.activation) == self.n_layers
        else:
            self.activation = [self.activation] * self.n_layers
        if 'dropout' in self.params.keys():
            self.dropout = self.params['dropout']
        else:
            self.dropout = 0
        self.layers = nn.ModuleList([])
        self.bn = nn.ModuleList([])
        self.dropouts = nn.ModuleList([])
        for i in range(self.n_layers - 1):
            self.dropouts.append(nn.Dropout(self.dropout))
            self.bn.append(nn.BatchNorm1d(self.hidden_size[i]))
            self.layers.append(nn.Linear(in_features=self.input_size[i], out_features=self.hidden_size[i]))
        i = self.n_layers - 1
        self.dropouts.append(nn.Dropout(self.dropout))
        self.layers.append(nn.Linear(in_features=self.input_size[i], out_features=self.hidden_size[i]))

    @staticmethod
    def get_required_params():
        return {
            'input_size': int,
            'n_layers': int,
            'hidden_size': list,
            'activation': None,
        }

    @staticmethod
    def get_optional_params():
        return {'dropout': float}

    def forward(self, inp):
        output = inp
        for i in range(self.n_layers - 1):
            output = self.dropouts[i](output)
            output = self.layers[i](output)
            output = self.bn[i](output)
            output = self.activation[i](output)
        output = self.dropouts[-1](output)
        output = self.layers[-1](output)
        output = self.activation[-1](output)
        return output

In [None]:
class OpenChemEncoder(nn.Module):
    """Base class for embedding module"""
    def __init__(self, params, use_cuda=None):
        super(OpenChemEncoder, self).__init__()
        self.params = params
        if use_cuda is None:
            use_cuda = torch.cuda.is_available()
        self.use_cuda = use_cuda
        self.input_size = self.params['input_size']
        self.encoder_dim = self.params['encoder_dim']

    @staticmethod
    def get_required_params():
        return {'input_size': int, 'encoder_dim': int}

    @staticmethod
    def get_optional_params():
        return {}

    def forward(self, inp):
        raise NotImplementedError


class RNNEncoder(OpenChemEncoder):
    def __init__(self, params, use_cuda):
        super(RNNEncoder, self).__init__(params, use_cuda)
        check_params(params, self.get_required_params(), self.get_optional_params())
        self.layer = self.params['layer']
        layers = ['LSTM', 'GRU', 'RNN']
        if self.layer not in ['LSTM', 'GRU', 'RNN']:
            raise ValueError(self.layer + ' is invalid value for argument'
                             ' \'layer\'. Choose one from :' + ', '.join(layers))

        self.input_size = self.params['input_size']
        self.encoder_dim = self.params['encoder_dim']
        self.n_layers = self.params['n_layers']
        if self.n_layers > 1:
            self.dropout = self.params['dropout']
        else:
            UserWarning('dropout can be non zero only when n_layers > 1. ' 'Parameter dropout set to 0.')
            self.dropout = 0
        self.bidirectional = self.params['is_bidirectional']
        if self.bidirectional:
            self.n_directions = 2
        else:
            self.n_directions = 1
        if self.layer == 'LSTM':
            self.rnn = nn.LSTM(self.input_size,
                               self.encoder_dim,
                               self.n_layers,
                               bidirectional=self.bidirectional,
                               dropout=self.dropout,
                               batch_first=True)
        elif self.layer == 'GRU':
            self.rnn = nn.GRU(self.input_size,
                              self.encoder_dim,
                              self.n_layers,
                              bidirectional=self.bidirectional,
                              dropout=self.dropout,
                              batch_first=True)
        else:
            self.layer = nn.RNN(self.input_size,
                                self.encoder_dim,
                                self.n_layers,
                                bidirectional=self.bidirectional,
                                dropout=self.dropout,
                                batch_first=True)

    @staticmethod
    def get_required_params():
        return {
            'input_size': int,
            'encoder_dim': int,
        }

    @staticmethod
    def get_optional_params():
        return {'layer': str, 'n_layers': int, 'dropout': float, 'is_bidirectional': bool}

    def forward(self, inp, previous_hidden=None, pack=True):
        """
        inp: shape batch_size, seq_len, input_size
        previous_hidden: if given shape n_layers * num_directions,
        batch_size, embedding_dim.
               Initialized automatically if None
        return: embedded
        """
        input_tensor = inp[0]
        input_length = inp[1]
        batch_size = input_tensor.size(0)
        # TODO: warning: output shape is changed! (batch_first=True) Check hidden
        if pack:
            input_lengths_sorted, perm_idx = torch.sort(input_length, dim=0, descending=True)
            input_lengths_sorted = input_lengths_sorted.detach().to(device="cpu").tolist()
            input_tensor = torch.index_select(input_tensor, 0, perm_idx)
            rnn_input = pack_padded_sequence(input=input_tensor,
                                             lengths=input_lengths_sorted,
                                             batch_first=True)
        else:
            rnn_input = input_tensor
        if previous_hidden is None:
            previous_hidden = self.init_hidden(batch_size)
            if self.layer == 'LSTM':
                cell = self.init_cell(batch_size)
                previous_hidden = (previous_hidden, cell)
        else:
            if self.layer == 'LSTM':
                hidden = previous_hidden[0]
                cell = previous_hidden[1]
                hidden = torch.index_select(hidden, 1, perm_idx)
                cell = torch.index_select(cell, 1, perm_idx)
                previous_hidden = (hidden, cell)
            else:
                previous_hidden = torch.index_select(previous_hidden, 1, perm_idx)
        rnn_output, next_hidden = self.rnn(rnn_input)  # , previous_hidden)

        if pack:
            rnn_output, _ = pad_packed_sequence(rnn_output, batch_first=True)
            _, unperm_idx = perm_idx.sort(0)
            rnn_output = torch.index_select(rnn_output, 0, unperm_idx)
            if self.layer == 'LSTM':
                hidden = next_hidden[0]
                cell = next_hidden[1]
                hidden = torch.index_select(hidden, 1, unperm_idx)
                cell = torch.index_select(cell, 1, unperm_idx)
                next_hidden = (hidden, cell)
            else:
                next_hidden = torch.index_select(next_hidden, 1, unperm_idx)

        index_t = (input_length - 1).to(dtype=torch.long)
        index_t = index_t.view(-1, 1, 1).expand(-1, 1, rnn_output.size(2))

        embedded = torch.gather(rnn_output, dim=1, index=index_t).squeeze(1)

        return embedded, next_hidden

    def init_hidden(self, batch_size):
        if self.use_cuda:
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.encoder_dim, device=device)

    def init_cell(self, batch_size):
        if self.use_cuda:
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        return torch.zeros(self.n_layers * self.n_directions, batch_size, self.encoder_dim, device=device)

In [21]:
class OpenChemEmbedding(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.num_embeddings = self.params['num_embeddings']
        if 'padding_idx' in params.keys():
            self.padding_idx = self.params['padding_idx']
        else:
            self.padding_idx = None

    def forward(self, inp):
        raise NotImplementedError

    @staticmethod
    def get_required_params():
        return {
            'num_embeddings': int,
        }

    @staticmethod
    def get_optional_params():
        return {'padding_idx': int}


class Embedding(OpenChemEmbedding):
    def __init__(self, params):
        super().__init__(params)
        self.embedding_dim = self.params['embedding_dim']
        self.embedding = nn.Embedding(num_embeddings=self.num_embeddings,
                                      embedding_dim=self.embedding_dim,
                                      padding_idx=self.padding_idx)

    def forward(self, inp):
        embedded = self.embedding(inp)
        return embedded

    @staticmethod
    def get_required_params():
        return {
            'embedding_dim': int,
        }

In [22]:
emb = Embedding({
        'num_embeddings': len(tokens),
        'embedding_dim': n_hidden,
        'padding_idx': tokens.index(' ')
    })

enc = 

mlp = OpenChemMLP({
        'input_size': n_hidden,
        'n_layers': 2,
        'hidden_size': [n_hidden, 1],
        'activation': [F.relu, identity],
        'dropout': 0.0
    })

TypeError: __init__() got an unexpected keyword argument 'input_size'

In [23]:
emb

Embedding(
  (embedding): Embedding(46, 128, padding_idx=45)
)