# In [2]:
import numpy as np
import torch
import torch.nn as nn
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.model_selection import train_test_split
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

# In [11]:
def activation(act='relu', **kwargs):
    """Build the ``torch.nn`` activation module selected by name.

    Args:
        act: One of ``'relu'``, ``'softplus'``, ``'sigmoid'`` or ``'tanh'``.
        **kwargs: Forwarded unchanged to the constructor of the selected module.

    Returns:
        A newly constructed activation module.

    Raises:
        NotImplementedError: If ``act`` matches none of the supported names.
    """
    supported = (
        ('relu', nn.ReLU),
        ('softplus', nn.Softplus),
        ('sigmoid', nn.Sigmoid),
        ('tanh', nn.Tanh),
    )

    # Compare with ``==`` (rather than hashing ``act``) so any value can be tested.
    for name, module_cls in supported:
        if act == name:
            return module_cls(**kwargs)

    raise NotImplementedError(f'Activation "{act}" not supported.')

def layer(input_dim, output_dim, bias=True, act='relu', batch_norm=False, dropout=0.):
    """Lazily yield the modules that make up one fully connected layer.

    The modules are produced in this order: linear map, optional batch
    normalisation, optional activation, optional dropout.

    Args:
        input_dim: Number of input features of the linear map.
        output_dim: Number of output features of the linear map.
        bias: Whether the linear map has a bias term.
        act: Activation name understood by ``activation``; ``None`` skips it.
        batch_norm: If truthy, a ``BatchNorm1d`` follows the linear map.
        dropout: Dropout probability; a ``Dropout`` module is emitted only when it is > 0.

    Yields:
        The ``torch.nn`` modules of the layer, in application order.
    """
    yield nn.Linear(in_features=input_dim, out_features=output_dim, bias=bias)

    if batch_norm:
        yield nn.BatchNorm1d(num_features=output_dim)

    if act is not None:
        yield activation(act)

    if dropout > 0.:
        yield nn.Dropout(p=dropout)

def mlp(layers, bias=True, act='relu', final_act=None, batch_norm=False, final_norm=False, dropout=0., final_drop=0.):
    """Lazily yield the modules of a multi-layer perceptron.

    Consecutive entries of ``layers`` give the input and output width of each
    fully connected layer. Every layer but the last is built with ``act``,
    ``batch_norm`` and ``dropout``; the last one is built with ``final_act``,
    ``final_norm`` and ``final_drop`` instead. Fewer than two sizes yields nothing.

    Args:
        layers: Sequence of layer widths, input width first.
        bias: Whether the linear maps have a bias term.
        act: Activation name for the hidden layers.
        final_act: Activation name for the last layer, or ``None`` for none.
        batch_norm: Whether hidden layers are batch-normalised.
        final_norm: Whether the last layer is batch-normalised.
        dropout: Dropout probability of the hidden layers.
        final_drop: Dropout probability of the last layer.

    Yields:
        The ``torch.nn`` modules of the network, in application order.
    """
    last = len(layers) - 1

    for pos in range(last):
        fan_in, fan_out = layers[pos], layers[pos + 1]

        if pos + 1 == last:
            settings = (final_act, final_norm, final_drop)
        else:
            settings = (act, batch_norm, dropout)

        yield from layer(fan_in, fan_out, bias, *settings)

class MLP(nn.Module):
    """Multi-layer perceptron wrapping the modules produced by ``mlp`` in an ``nn.Sequential``."""

    def __init__(self, layers, bias=True, act='relu', final_act=None, batch_norm=False, final_norm=False, dropout=0., final_drop=0.):
        """Build the network.

        Args:
            layers: Sequence of layer widths, input width first.
            bias: Whether the linear maps have a bias term.
            act: Activation name for the hidden layers.
            final_act: Activation name for the last layer, or ``None`` for none.
            batch_norm: Whether hidden layers are batch-normalised.
            final_norm: Whether the last layer is batch-normalised.
            dropout: Dropout probability of the hidden layers.
            final_drop: Dropout probability of the last layer.
        """
        super().__init__()

        # Every option is forwarded by keyword: ``mlp`` takes eight parameters, and a
        # positional call that leaves one out (e.g. ``final_norm``) silently shifts all
        # following values into the wrong slot.
        self.net = nn.Sequential(*mlp(
            layers,
            bias=bias,
            act=act,
            final_act=final_act,
            batch_norm=batch_norm,
            final_norm=final_norm,
            dropout=dropout,
            final_drop=final_drop,
        ))

    def forward(self, x):
        """Apply the sequential network to ``x`` and return its output."""
        return self.net(x)

class Encoder(nn.Module):
    """Shared trunk followed by two parallel heads.

    ``e_net`` is an ``MLP`` over ``layers[:-1]``; ``m_net`` and ``v_net`` are two
    separate ``MLP`` instances over ``layers[-2:]`` that both read the trunk output.
    """

    def __init__(self, layers, bias=True, act='relu', final_act='relu', batch_norm=False, final_norm=False, dropout=0., final_drop=0.):
        """Build the trunk and the two heads.

        Args:
            layers: Sequence of widths; the last entry is the width of both head outputs.
            bias: Whether the linear maps have a bias term.
            act: Activation name for hidden layers.
            final_act: Activation name for the last layer of the trunk.
            batch_norm: Whether hidden layers are batch-normalised.
            final_norm: Whether the last layer of the trunk is batch-normalised.
            dropout: Dropout probability of hidden layers.
            final_drop: Dropout probability of the last layer of the trunk.
        """
        super().__init__()

        self.e_net = MLP(layers[:-1], bias, act, final_act, batch_norm, final_norm, dropout, final_drop)

        # Both heads are configured identically; ``final_act=None`` leaves their output
        # without an activation.
        head_sizes = layers[-2:]
        self.m_net = MLP(head_sizes, bias, act, None, batch_norm, dropout=dropout)
        self.v_net = MLP(head_sizes, bias, act, None, batch_norm, dropout=dropout)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A pair ``(m, v)``: ``m`` is the raw output of ``m_net`` and ``v`` is the
            exponential of the output of ``v_net``, hence strictly positive.
        """
        e = self.e_net(x)
        m = self.m_net(e)
        v = self.v_net(e).exp()

        return m, v

class Decoder(nn.Module):
    """Thin module that owns a single ``MLP`` (``d_net``) and applies it to its input."""

    def __init__(self, layers, bias=True, act='relu', final_act=None, batch_norm=False, final_norm=False, dropout=0., final_drop=0.):
        """Create ``d_net``; every argument is handed to ``MLP`` under the same name."""
        super().__init__()

        self.d_net = MLP(
            layers,
            bias=bias,
            act=act,
            final_act=final_act,
            batch_norm=batch_norm,
            final_norm=final_norm,
            dropout=dropout,
            final_drop=final_drop,
        )

    def forward(self, z):
        """Return ``d_net`` applied to ``z``."""
        return self.d_net(z)
    
class NTM(nn.Module, BaseEstimator, TransformerMixin):
    """Encoder/decoder pair between a vocabulary-sized input and ``n_topics`` latent units.

    NOTE(review): only construction and ``build`` are defined here; no ``forward``,
    ``fit`` or ``transform`` is implemented in this class.
    """

    def __init__(self, n_topics=2, doc_size=50, vocab_size=50, layers=(25,), bias=True, act='relu', encoder_act='relu', decoder_act=None, batch_norm=True, final_norm=False, dropout=.2, final_drop=0., divergence=.1):
        """Store the hyper-parameters and build the encoder and decoder.

        Args:
            n_topics: Width of the encoder output and of the decoder input.
            doc_size: Stored as given; not read anywhere else in this class.
            vocab_size: Width of the encoder input and of the decoder output.
            layers: Hidden widths placed between ``vocab_size`` and ``n_topics`` in the encoder.
            bias: Whether the linear maps have a bias term.
            act: Activation name for hidden layers.
            encoder_act: Activation name for the last layer of the encoder trunk.
            decoder_act: Activation name for the decoder output, or ``None``.
            batch_norm: Whether batch normalisation is used (also passed as the encoder's ``final_norm``).
            final_norm: Whether the decoder output is batch-normalised.
            dropout: Dropout probability of hidden layers.
            final_drop: Dropout probability of the final layers.
            divergence: Stored as given; not read anywhere else in this class.
        """
        super().__init__()

        self.n_topics = n_topics
        self.doc_size = doc_size
        self.vocab_size = vocab_size
        self.layers = layers
        self.bias = bias
        self.act = act
        self.encoder_act = encoder_act
        self.decoder_act = decoder_act
        self.batch_norm = batch_norm
        self.final_norm = final_norm
        self.dropout = dropout
        self.final_drop = final_drop
        self.divergence = divergence

        # Unpacking accepts any sequence of hidden widths (list or tuple alike).
        encoder_sizes = (vocab_size, *layers, n_topics)

        self.encoder = Encoder(encoder_sizes, bias, act, encoder_act, batch_norm, batch_norm, dropout, final_drop)
        self.decoder = Decoder((n_topics, vocab_size), bias, act, decoder_act, batch_norm, final_norm, dropout, final_drop)
        self.optimizer = None
        self.loss_log = []

    def build(self, X, learning_rate=1e-2, batch_size=None):
        """Create the optimizer and a data loader over ``X``.

        Args:
            X: Dataset handed to ``DataLoader``; must expose ``shape[0]``.
            learning_rate: Learning rate of the Adam optimizer.
            batch_size: Batch size; ``None`` means one batch holding all of ``X``.

        Returns:
            The ``DataLoader`` iterating over ``X``.
        """
        self.optimizer = Adam(self.parameters(), learning_rate)
        data_loader = DataLoader(X, X.shape[0] if batch_size is None else batch_size)

        return data_loader

class VAE(nn.Module, BaseEstimator, TransformerMixin):
    """Variational auto-encoder exposing ``fit`` / ``transform``.

    The encoder and decoder are created by ``build`` (called from ``fit``), once the
    input width is known. Per-epoch mean losses are appended to ``train_log`` and
    ``test_log``.
    """

    def __init__(self, layers=(100, 10), bias=True, act='relu', encoder_act='relu', decoder_act=None, batch_norm=True, final_norm=False, dropout=.2, final_drop=0., divergence=.1):
        """Store the hyper-parameters; no network is created yet.

        Args:
            layers: Encoder widths; the input width is prepended by ``build`` when
                the first entry differs from the data width. The decoder uses the reverse.
            bias: Whether the linear maps have a bias term.
            act: Activation name for hidden layers.
            encoder_act: Activation name for the last layer of the encoder trunk.
            decoder_act: Activation name for the decoder output, or ``None``.
            batch_norm: Whether batch normalisation is used (also passed as the encoder's ``final_norm``).
            final_norm: Whether the decoder output is batch-normalised.
            dropout: Dropout probability of hidden layers.
            final_drop: Dropout probability of the final layers.
            divergence: Weight of the KL term in the loss.
        """
        super().__init__()

        self.layers = layers
        self.bias = bias
        self.act = act
        self.encoder_act = encoder_act
        self.decoder_act = decoder_act
        self.batch_norm = batch_norm
        self.final_norm = final_norm
        self.dropout = dropout
        self.final_drop = final_drop
        self.divergence = divergence

        self.encoder = None
        self.decoder = None
        self.optimizer = None
        self.train_log = []
        self.test_log = []

    def build(self, X, learning_rate=1e-2, batch_size=None, test_size=.2):
        """Create the networks, the optimizer and the data loaders.

        Args:
            X: Training data; ``X.shape[-1]`` is the input width and ``X.shape[0]`` the sample count.
            learning_rate: Learning rate of the Adam optimizer.
            batch_size: Batch size; ``None`` picks roughly 1/16 of the training split.
            test_size: Fraction handed to ``train_test_split``; ``0`` disables the held-out split.

        Returns:
            ``(train_loader, test_loader)``; ``test_loader`` is ``None`` when ``test_size`` is not positive.
        """
        if self.layers[0] != X.shape[-1]:
            self.layers = (X.shape[-1], *self.layers)

        self.encoder = Encoder(self.layers, self.bias, self.act, self.encoder_act, self.batch_norm, self.batch_norm, self.dropout, self.final_drop)
        self.decoder = Decoder(self.layers[::-1], self.bias, self.act, self.decoder_act, self.batch_norm, self.final_norm, self.dropout, self.final_drop)
        self.optimizer = Adam(self.parameters(), learning_rate)

        if batch_size is None:
            # Floor at 1: with fewer than 16 training samples the integer division
            # gives 0, which DataLoader rejects.
            batch_size = max(1, int(X.shape[0]*(1. - test_size))//16)

        if test_size > 0.:
            X_train, X_test = train_test_split(X, test_size=test_size)
            train_loader, test_loader = DataLoader(X_train, batch_size), DataLoader(X_test, batch_size)
        else:
            train_loader, test_loader = DataLoader(X, batch_size), None

        return train_loader, test_loader

    def forward(self, x, return_divergence=False):
        """Encode ``x`` and draw a latent sample.

        Args:
            x: Input batch.
            return_divergence: If true, also return the weighted KL term.

        Returns:
            ``z``, or ``(z, divergence)`` when ``return_divergence`` is true.
        """
        m, v = self.encoder(x)
        # v multiplies unit Gaussian noise, i.e. it acts as a standard deviation.
        z = m + v*torch.randn_like(v)

        if return_divergence:
            # KL(N(m, v^2) || N(0, 1)) = 0.5*(m^2 + v^2 - 1) - log(v), summed over
            # all elements; it is exactly zero at m = 0, v = 1.
            kl = .5*(m**2 + v**2 - 1.) - v.log()
            divergence = self.divergence*kl.sum()

            return z, divergence
        return z

    def backward(self, loss):
        """Run one optimizer update for ``loss`` and return its Python float value."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def step(self, loader, grad=True):
        """Run one pass over ``loader`` and return the mean loss per sample.

        Args:
            loader: ``DataLoader`` whose ``dataset`` exposes ``shape[0]``.
            grad: If true, update the parameters after every batch (training pass).
                If false, only evaluate: the module is switched to eval mode and
                autograd is disabled, so dropout is inactive and batch-norm running
                statistics are not updated by held-out data.

        Returns:
            Summed batch losses divided by the number of samples in the dataset.
        """
        training_pass = bool(grad)
        previous_mode = self.training
        step_loss = 0.

        self.train(training_pass)
        try:
            with torch.set_grad_enabled(training_pass):
                for x in loader:
                    z, divergence = self(x, return_divergence=True)
                    y = self.decoder(z)
                    loss = (y - x).square().sum() + divergence
                    step_loss += self.backward(loss) if training_pass else loss.item()
        finally:
            # Leave the module in whatever mode the caller had it in.
            self.train(previous_mode)

        step_loss /= loader.dataset.shape[0]

        return step_loss

    def fit(self, X, n_steps=150, learning_rate=1e-2, batch_size=None, test_size=.2, test_rate=1, desc='VAE', verbosity=1):
        """Build the model for ``X`` and train it.

        Args:
            X: Training data.
            n_steps: Number of passes over the training loader.
            learning_rate: Learning rate of the Adam optimizer.
            batch_size: Batch size; see ``build``.
            test_size: Held-out fraction; see ``build``.
            test_rate: The held-out loss is evaluated every ``test_rate`` passes.
            desc: Progress-bar label.
            verbosity: ``1`` shows a ``tqdm`` progress bar; any other value runs silently.

        Returns:
            ``self``.
        """
        train_loader, test_loader = self.build(X, learning_rate, batch_size, test_size)

        for i in tqdm(range(n_steps), desc) if verbosity == 1 else range(n_steps):
            self.train_log.append(self.step(train_loader, grad=True))

            if test_loader is not None and i%test_rate == 0:
                self.test_log.append(self.step(test_loader, grad=False))

        return self

    def transform(self, X):
        """Return a latent sample for ``X``, detached from the autograd graph.

        Inference runs in eval mode with autograd disabled, so dropout is inactive
        and batch normalisation uses its running statistics. The output is still a
        random draw, because ``forward`` adds Gaussian noise.
        """
        previous_mode = self.training

        self.eval()
        try:
            with torch.no_grad():
                Z = self(X)
        finally:
            self.train(previous_mode)

        return Z.detach()