In [10]:
%autoreload 9

import torch
import numpy as np
import matplotlib.pyplot as plt
import os

from torch.utils.tensorboard import SummaryWriter

import gmm
import gmm_gumbel

from example import plot
from gumbel_regression import gumbel_stable_loss

from example import create_data_1

def create_data():
    """Build three isotropic 2-D Gaussian blobs of 40 points each.

    Blob centres are (0, 0), (-5, 5) and (5, 5). NumPy's global RNG is reseeded
    with 0 on every call, so the result is deterministic.

    Returns:
        (X, true_y): X is a ``torch.Tensor`` of shape (120, 2) (default float dtype);
        true_y is an integer NumPy array of labels 0, 1, 2 in consecutive runs of 40.
    """
    np.random.seed(0)
    count = 40
    centers = np.array([[0, 0], [-5, 5], [5, 5]])

    blobs = []
    for center in centers:
        # Draw each blob in turn so the RNG stream is consumed in the same order.
        blobs.append(np.random.randn(count, 2) + center)

    labels = np.repeat(np.arange(len(centers)), count)
    return torch.Tensor(np.vstack(blobs)), labels

def create_data2():
    """Build two isotropic 2-D Gaussian blobs of 50 points each.

    Blob centres are (0, 0) and (-5, -5). NumPy's global RNG is reseeded with 0
    on every call, so the result is deterministic.

    Returns:
        (X, true_y): X is a ``torch.Tensor`` of shape (100, 2) (default float dtype);
        true_y is an integer NumPy array: 50 zeros followed by 50 ones.
    """
    np.random.seed(0)
    size = 50
    centres = np.array([[0, 0], [-5, -5]])
    points = np.concatenate(
        [np.random.randn(size, 2) + centre for centre in centres], axis=0
    )
    labels = np.repeat(np.arange(centres.shape[0]), size)
    return torch.Tensor(points), labels

In [68]:
def loss_fn(V, logP, dim=1):
    """Numerically stable Gumbel-regression loss of V against the targets logP.

    With z = logP - V and m = max(max(z), -1) (a single detached scalar taken over
    every element of z), each element contributes
        exp(z - m) - z * exp(-m) - exp(-m) = exp(-m) * (exp(z) - z - 1),
    i.e. the Gumbel loss exp(z) - z - 1 rescaled by exp(-m) so that exp() cannot overflow.

    Args:
        V: tensor broadcastable against logP (``main`` passes shape (N, 1, 1)).
        logP: target log-probabilities. In ``main`` this is
            ``g_fitter.estimate_log_prob(data)``; shape presumably (N, K, 1) — TODO confirm.
        dim: dimension to average over (default 1, the original behaviour).

    Returns:
        The per-element loss averaged over ``dim`` (that dimension removed). dtype and
        device follow z.
    """
    z = logP - V

    # Floor the global max at -1 so exp(-m) stays bounded when all z are very negative.
    # clamp keeps z's dtype/device; the previous hard-coded torch.double constant
    # either failed in older torch.where or silently upcast max_z.
    # Detached: the rescaling factor is treated as a constant by autograd.
    max_z = torch.clamp(torch.max(z), min=-1.0).detach()
    scale = torch.exp(-max_z)

    loss = torch.exp(z - max_z) - z * scale - scale
    return torch.mean(loss, dim=dim)


def _reset_dir(path):
    """Remove ``path`` if it is an existing directory, then recreate it (with missing parents)."""
    import shutil

    # An unconditional rmtree raises FileNotFoundError on a fresh checkout, and
    # os.mkdir fails whenever the parent directory does not exist yet.
    if os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _fit_V(V, V_optim, logP, n_steps, log_every=50):
    """Take ``n_steps`` optimiser steps on V against a fixed logP and return the last loss.

    Prints the first 10 entries of V every ``log_every`` steps. Only V is optimised;
    logP is expected to be detached from the model's autograd graph.
    """
    loss = None
    for step in range(n_steps):
        if step % log_every == 0:
            print(V[:10, 0, 0].detach())
        V_optim.zero_grad()
        loss = loss_fn(V, logP).mean()
        # logP is detached, so each step builds a fresh small graph; nothing to retain.
        loss.backward()
        V_optim.step()
    return loss


def _theta_step(g_fitter, data, logP, V):
    """Run one M-step using V in place of the log normaliser; update and return mu.

    Follows the gmm.em() recipe log_resp = logP + log(pi) - log_prob_norm, with V
    standing in for log_prob_norm. Only mu is written back (pi and var updates are
    intentionally left disabled).
    """
    with torch.no_grad():
        # Without no_grad, log_resp (and therefore mu) would carry V's autograd graph.
        weighted_log_prob = logP + torch.log(g_fitter.pi)
        log_resp = weighted_log_prob - V
        _pi, mu, _var = g_fitter.m_step(data, log_resp)
        g_fitter.update_mu(mu)
    return mu


def main(N=300, K=4, D=2, max_iter=100, v_steps=1000, V_lr=0.001, plot_every=1):
    """Fit a K-component GMM whose log normaliser V is learned by Gumbel regression.

    Each of ``max_iter`` outer iterations:
      * optionally plots the current clustering (every ``plot_every`` iterations);
      * V step: ``v_steps`` Adam steps fitting a per-sample V to the current logP via loss_fn;
      * theta step: an M-step that uses V as the log normaliser (only mu is updated).
    The final V-step loss of each iteration is logged to TensorBoard. The directory
    'examples/gumbel' is wiped and recreated first.

    Args:
        N, K, D: number of samples, components and features passed to create_data_1.
        max_iter: number of outer iterations.
        v_steps: Adam steps on V per outer iteration; must be >= 1.
        V_lr: Adam learning rate for V.
        plot_every: plot every this many outer iterations; must be >= 1.

    Raises:
        ValueError: if ``v_steps`` or ``plot_every`` is < 1.
    """
    if v_steps < 1:
        raise ValueError("v_steps must be >= 1")
    if plot_every < 1:
        raise ValueError("plot_every must be >= 1")

    # NOTE(review): presumably plot(..., "gumbel") writes its figures here — confirm.
    _reset_dir('examples/gumbel')

    data, true_y, true_mus = create_data_1(N, K, D)
    true_y = np.array(true_y)
    true_mus = np.array(true_mus)  # unused below; kept for interactive inspection
    g_fitter = gmm_gumbel.GaussianMixtureGumbel(K, D)

    g_fitter.mu.data = g_fitter.get_kmeans_mu(x=data, n_centers=K)
    print(g_fitter.var)
    writer = SummaryWriter()

    # One scalar per sample; broadcasts against logP inside loss_fn (z = logP - V).
    V = torch.rand(N, 1, 1, requires_grad=True)
    V_optim = torch.optim.Adam([V], lr=V_lr)

    try:
        for iter_ in range(max_iter):
            if iter_ % plot_every == 0:
                pred_y = g_fitter.predict(data)
                plot(data, true_y, pred_y, iter_, g_fitter.mu[0], K, "gumbel")

            # theta is unchanged during the V step, so logP is computed once per outer
            # iteration rather than once per V step. Detaching it keeps the V optimisation
            # from building graphs through (or back-propagating into) g_fitter.
            logP = g_fitter.estimate_log_prob(data).detach()
            loss = _fit_V(V, V_optim, logP, v_steps)

            print(f"Loss at iteration {iter_}: {loss.item()}\n")
            writer.add_scalar('loss', loss.item(), iter_ + 1)
            print("Score:", g_fitter.score(data, as_average=True))

            # NOTE(review): the V objective uses logP without log(pi); if pi updates are
            # re-enabled, V may no longer approximate the normaliser used here — verify.
            mu = _theta_step(g_fitter, data, logP, V)
            print(mu)
    finally:
        # Flush TensorBoard events even if an iteration raises.
        writer.close()


main()


Parameter containing:
tensor([[[[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]],

         [[1., 0.],
          [0., 1.]]]])
tensor([0.1221, 0.6147, 0.4698, 0.5122, 0.9306, 0.4869, 0.4314, 0.7997, 0.3848,
        0.6266], grad_fn=<SelectBackward0>)
tensor([0.0721, 0.5647, 0.4198, 0.4622, 0.8806, 0.4369, 0.3814, 0.7497, 0.3348,
        0.5766], grad_fn=<SelectBackward0>)
tensor([0.0222, 0.5147, 0.3698, 0.4122, 0.8306, 0.3869, 0.3314, 0.6997, 0.2848,
        0.5266], grad_fn=<SelectBackward0>)
tensor([-0.0278,  0.4647,  0.3198,  0.3622,  0.7806,  0.3369,  0.2814,  0.6497,
         0.2348,  0.4766], grad_fn=<SelectBackward0>)
tensor([-0.0777,  0.4147,  0.2698,  0.3122,  0.7306,  0.2869,  0.2314,  0.5997,
         0.1849,  0.4266], grad_fn=<SelectBackward0>)
tensor([-0.1276,  0.3647,  0.2199,  0.2622,  0.6807,  0.2369,  0.1814,  0.5497,
         0.1349,  0.3767], grad_fn=<SelectBackward0>)
tensor([-0.1775,  0.3148,  0.1699, 