In [1]:
%autoreload 9

import torch
import numpy as np
import matplotlib.pyplot as plt
import gmm
from example import plot
from gumbel_regression import *

def create_data():
    np.random.seed(0)
    n_per_cluster = 40
    means = np.array([[0,0], [-5,5], [5,5]])
    X = np.vstack([np.random.randn(n_per_cluster, 2) + mu for mu in means])
    for p_idx, point in enumerate(X):
            plt.scatter(X[p_idx, 0], X[p_idx, 1], color='blue', s = 10)
    return torch.Tensor(X)



In [15]:
def main():
    X = create_data()
    g_fitter = gmm.GaussianMixture(n_components=3, n_features=2)
    print(g_fitter.fit(torch.Tensor(X)))
    n = 120
    K = 3

    V = torch.zeros(n, requires_grad=True) 
    V_mean = torch.mean(V)

    def gumbel_stable_loss(z, clip=None):
        if clip is not None:
            z = torch.clamp(z, max=clip)

        max_z = torch.max(z)
        max_z = torch.where(max_z < -1.0, torch.tensor(-1.0, dtype=torch.double, device=max_z.device), max_z)
        max_z = max_z.detach()  # Detach the gradients
        loss = torch.exp(z - max_z) - z*torch.exp(-max_z) - torch.exp(-max_z)    # scale by e^max_z
        return loss

    def loss_fn(V, logP):
        # loss = gumbel_stable_loss(logP - V, clip=10) # + logP
        loss = torch.mean(torch.exp(logP-V) - V - 1)
        return loss

    V_lr = 0.001
    # loss_fn = gumbel_stable_loss
    # print(data)

    max_iter = 100
    V_optim = torch.optim.Adam([V], lr=V_lr)
    data = X
    logP = g_fitter._estimate_log_prob(data)
    loss = loss_fn(V, logP)
    print("Initial loss:", loss.mean())
    print("initial mu", g_fitter.mu)
    g_fitter.mu = torch.nn.Parameter(torch.randn(1, 3, 2), requires_grad=False)
    print("next mu", g_fitter.mu)
    # print(g_fitter._estimate_log_prob(data))
    print("First datapoint:", data[0])

    for i in range(max_iter):
        y = g_fitter.predict(data)
        if i % 10 == 0:
            print("I", i)
            plot(data, y, i)

        # Do V step

        data = X
        logP = g_fitter._estimate_log_prob(data)

        V_optim.zero_grad()
        # loss = torch.mean(loss_fn(data-V))
        loss = loss_fn(V, logP)
        loss = loss.mean()
        loss.backward(retain_graph=True)
        V_optim.step()
        print("V0", V[0])
        print("Loss:", loss)
        print()


        # Do theta step
    #    print(logP.shape)
        weighted_log_prob = logP + np.log(1/K) * torch.ones_like(logP)
        log_prob_norm = V.unsqueeze(1).unsqueeze(2)
        log_resp = weighted_log_prob - log_prob_norm
       # print(V.shape, log_resp.shape)
        pi, mu, var = g_fitter._m_step(data, log_resp)

        g_fitter.update_pi(pi)
        g_fitter.update_mu(mu)
        g_fitter.update_var(var)

main()


-inf
tensor(-3.8913)
None
Initial loss: tensor(-0.9725, grad_fn=<MeanBackward0>)
initial mu Parameter containing:
tensor([[[-4.9636,  5.4392],
         [ 4.9706,  4.9009],
         [-0.0089, -0.0522]]])
next mu Parameter containing:
tensor([[[-0.1483,  0.1326],
         [-2.0981, -0.8472],
         [ 0.6098, -1.6213]]])
First datapoint: tensor([1.7641, 0.4002])
I 0
V0 tensor(0.0010, grad_fn=<SelectBackward0>)
Loss: tensor(-0.9845, grad_fn=<MeanBackward0>)

V0 tensor(0.0020, grad_fn=<SelectBackward0>)
Loss: tensor(-0.9665, grad_fn=<MeanBackward0>)

V0 tensor(0.0030, grad_fn=<SelectBackward0>)
Loss: tensor(-0.9538, grad_fn=<MeanBackward0>)

V0 tensor(0.0040, grad_fn=<SelectBackward0>)
Loss: tensor(-0.9428, grad_fn=<MeanBackward0>)

V0 tensor(0.0050, grad_fn=<SelectBackward0>)
Loss: tensor(-0.9316, grad_fn=<MeanBackward0>)

V0 tensor(0.0060, grad_fn=<SelectBackward0>)
Loss: tensor(-0.9150, grad_fn=<MeanBackward0>)

V0 tensor(0.0070, grad_fn=<SelectBackward0>)
Loss: tensor(-0.8935, grad_fn