In [None]:
import random
from random import shuffle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

%matplotlib inline

In [None]:
# Number of (x, y) samples in the synthetic regression dataset.
N = 200

def build_linear_dataset(N, noise_std=0.5, slope=3, intercept=1, x_range=(-2, 2)):
    """Sample N noisy points from the line y = slope * x + intercept.

    x is an evenly spaced grid over ``x_range`` (endpoints included). Gaussian
    noise with standard deviation ``noise_std`` is drawn from NumPy's global
    RNG, so seed ``np.random`` beforehand for reproducible data.

    Args:
        N: number of samples.
        noise_std: standard deviation of the additive Gaussian noise.
        slope: slope of the underlying line (default 3).
        intercept: intercept of the underlying line (default 1).
        x_range: ``(low, high)`` bounds of the x grid (default (-2, 2)).

    Returns:
        (X, Y): two float arrays of shape (N, 1), the column layout consumed
        by ``nn.Linear`` downstream.
    """
    x_low, x_high = x_range
    X = np.linspace(x_low, x_high, num=N)
    Y = slope * X + intercept + np.random.normal(0, noise_std, size=N)
    return X.reshape((N, 1)), Y.reshape((N, 1))

def batch_generator(dataset, batch_size=5):
    """Yield shuffled minibatches from a sequence of (x, y) pairs.

    A fresh random order (``random.shuffle``) is drawn on every call and
    applied to a copy, so the caller's ``dataset`` is left unmodified. Only
    full batches are produced: the ``len(dataset) % batch_size`` pairs left
    over after shuffling are dropped.

    Args:
        dataset: sequence of ``(x, y)`` pairs.
        batch_size: number of pairs per batch.

    Yields:
        (xs, ys): tuples of length ``batch_size`` with the x and y halves of
        the batch's pairs.
    """
    # Shuffle a copy: mutating the argument would silently reorder the
    # caller's data.
    shuffled = list(dataset)
    shuffle(shuffled)
    n_full_batches = len(shuffled) // batch_size
    for start in range(0, n_full_batches * batch_size, batch_size):
        xs, ys = zip(*shuffled[start:start + batch_size])
        yield xs, ys

# Seed every RNG the notebook draws from (NumPy noise, random.shuffle used for
# minibatching, torch weight initialisation) so Restart & Run All reproduces
# the same data, training run and figures.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

X, Y = build_linear_dataset(N)

X_var = Variable(torch.Tensor(X))
Y_var = Variable(torch.Tensor(Y))

fig, ax = plt.subplots()
ax.plot(X, Y, '.')
ax.set(xlabel='X', ylabel='Y', title='Toy linear dataset');

In [None]:
class MDN(nn.Module):
    """Mixture density network with one tanh hidden layer and a scalar target.

    For every input row the network outputs the parameters of a mixture of
    ``n_components`` univariate Gaussians:

    * mixing weights ``alpha`` (softmax over components, each row sums to 1),
    * standard deviations ``sigma`` (exp of a linear head, hence positive),
    * means ``mu`` (linear head).

    Args:
        ndim_input: number of input features.
        ndim_output: currently unused — the mixture is over a scalar target.
        n_hidden: width of the hidden layer.
        n_components: number of Gaussian components.
    """

    def __init__(self, ndim_input=1, ndim_output=1, n_hidden=5, n_components=1):
        super(MDN, self).__init__()
        # Input width follows ndim_input instead of being fixed at 1.
        self.fc_in = nn.Linear(ndim_input, n_hidden)
        self.tanh = nn.Tanh()
        # Normalise over the component axis for each sample; an explicit dim
        # avoids PyTorch's deprecated implicit-dim Softmax.
        self.alpha_out = torch.nn.Sequential(
            nn.Linear(n_hidden, n_components),
            nn.Softmax(dim=1),
        )
        # The network predicts log(sigma); forward() exponentiates it so sigma > 0.
        self.logsigma_out = nn.Linear(n_hidden, n_components)
        self.mu_out = nn.Linear(n_hidden, n_components)

    def forward(self, x):
        """Return ``(alpha, sigma, mu)``, each of shape (batch, n_components).

        ``x`` is assumed to be (batch, ndim_input).
        """
        act = self.tanh(self.fc_in(x))
        out_alpha = self.alpha_out(act)
        out_sigma = torch.exp(self.logsigma_out(act))
        out_mu = self.mu_out(act)
        return (out_alpha, out_sigma, out_mu)

In [None]:
def gauss_pdf(y, mu, sigma, log=False):
    """Evaluate the normal density N(y; mu, sigma^2) element-wise.

    ``y`` is expanded to ``mu``'s shape with ``expand_as``, so a (batch, 1)
    target can be scored against (batch, K) component parameters.

    Args:
        y: observed values; must be expandable to ``mu``'s shape.
        mu: means.
        sigma: standard deviations, broadcastable against ``mu``.
        log: if True return the log-density, otherwise the density.

    Returns:
        Tensor of the broadcast shape of ``mu`` and ``sigma``.
    """
    variance = sigma**2
    squared_error = (y.expand_as(mu) - mu)**2
    log_density = -0.5*torch.log(2*np.pi*variance) - 1/(2*variance) * squared_error
    if not log:
        return torch.exp(log_density)
    return log_density
    
def mog_pdf(y, mus, sigmas, alphas, log=False):
    """Density of ``y`` under a per-sample mixture of univariate Gaussians.

    Sample ``i``'s density is ``sum_k alphas[i, k] * N(y[i]; mus[i, k], sigmas[i, k]^2)``.

    Args:
        y: targets, assumed (batch, 1) — TODO confirm for any other caller.
        mus: component means, (batch, n_components).
        sigmas: component standard deviations, (batch, n_components).
        alphas: mixing weights, (batch, n_components); rows are expected to
            sum to 1.
        log: if True return the log-density, otherwise the density.

    Returns:
        Tensor of shape (batch, 1).
    """
    # Score every component at once: gauss_pdf expands y across the
    # component axis, giving (batch, n_components).
    component_pdfs = gauss_pdf(y, mus, sigmas, log=False)
    # Weight row-wise so each sample only mixes its own components; keepdim
    # preserves the (batch, 1) layout of y.
    result = torch.sum(alphas * component_pdfs, dim=1, keepdim=True)
    if log:
        return torch.log(result)
    return result
        
        
def mdn_loss_function(out_alpha, out_sigma, out_mu, y):
    """Negative mean log-likelihood of ``y`` under the MDN's predicted mixture.

    Args:
        out_alpha: mixing weights from ``MDN.forward``.
        out_sigma: component standard deviations from ``MDN.forward``.
        out_mu: component means from ``MDN.forward``.
        y: targets for the batch.

    Returns:
        Scalar tensor; minimising it maximises the likelihood of the data.
    """
    log_likelihood = mog_pdf(y, mus=out_mu, sigmas=out_sigma, alphas=out_alpha, log=True)
    return -torch.mean(log_likelihood)  # mean over batch

In [None]:
# A single-component MDN, i.e. Gaussian regression with input-dependent
# noise, optimised with Adam at a fixed learning rate.
model = MDN(n_components=1)
optim = torch.optim.Adam(params=model.parameters(), lr=0.01)

def train(X, Y, n_epochs=500, n_minibatch=50, print_every=50, net=None, optimizer=None):
    """Fit an MDN to (X, Y) by minibatch gradient descent on the MDN loss.

    Args:
        X: input samples, paired element-wise with ``Y`` (here (N, 1) arrays).
        Y: target samples.
        n_epochs: number of passes over the data.
        n_minibatch: minibatch size; ``batch_generator`` drops the trailing
            partial batch each epoch.
        print_every: print the loss every this many epochs (0/None disables).
        net: model to train; defaults to the module-level ``model``.
        optimizer: optimizer over ``net``'s parameters; defaults to the
            module-level ``optim``.

    Raises:
        ValueError: if there are fewer than ``n_minibatch`` samples, in which
            case no minibatch and therefore no training step would ever run.
    """
    net = model if net is None else net
    optimizer = optim if optimizer is None else optimizer

    dataset_train = list(zip(X, Y))
    if len(dataset_train) < n_minibatch:
        raise ValueError(
            "need at least n_minibatch=%d samples, got %d" % (n_minibatch, len(dataset_train))
        )

    for epoch in range(n_epochs):
        for x_batch, y_batch in batch_generator(dataset_train, n_minibatch):
            # Stack the per-sample arrays first: building a tensor from a
            # tuple of ndarrays is slow (and warns) in recent PyTorch.
            x_var = torch.Tensor(np.asarray(x_batch))
            y_var = torch.Tensor(np.asarray(y_batch))

            out_alpha, out_sigma, out_mu = net(x_var)
            loss = mdn_loss_function(out_alpha, out_sigma, out_mu, y_var)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if print_every and (epoch + 1) % print_every == 0:
            # Reports the loss of the epoch's last minibatch. .item() replaces
            # `loss.data[0]`, which raises on 0-dim tensors in PyTorch >= 0.5.
            print("[epoch %04d] loss: %.4f" % (epoch + 1, loss.item()))
        
train(X, Y, n_epochs=500, n_minibatch=50)

In [None]:
# Sanity check: inputs form an (N, 1) column, the layout nn.Linear expects.
assert X.shape == (N, 1), X.shape
X.shape

In [None]:
X_var = Variable(torch.Tensor(X))
Y_var = Variable(torch.Tensor(Y))

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    (out_alpha, out_sigma, out_mu) = model(X_var)

# Mean and standard deviation of the predicted mixture at each x. By the law
# of total variance, Var = sum_k alpha_k (sigma_k^2 + mu_k^2) - mean^2. An
# alpha-weighted average of the sigmas is only correct for one component.
mus = torch.sum(out_alpha * out_mu, 1)
second_moment = torch.sum(out_alpha * (out_sigma ** 2 + out_mu ** 2), 1)
# clamp guards against tiny negative values from floating-point cancellation
sigmas = torch.sqrt(torch.clamp(second_moment - mus ** 2, min=0.0))

mus_np = mus.numpy()
sigmas_np = sigmas.numpy()

fig, ax = plt.subplots()
ax.plot(X, Y, '.', label='data')
ax.plot(X, mus_np, '-r', label='mixture mean')
ax.plot(X, mus_np + sigmas_np, '-k', label='mean ± 1 std')
ax.plot(X, mus_np - sigmas_np, '-k')
ax.set(xlabel='X', ylabel='Y', title='MDN predictive distribution')
ax.legend();