In [4]:
%autoreload 9

import torch
import numpy as np
import matplotlib.pyplot as plt
import os

from torch.utils.tensorboard import SummaryWriter

import gmm
import gmm_gumbel

from example import plot
from gumbel_regression import gumbel_stable_loss

from example import create_data_1

def create_data():
    np.random.seed(0)
    n_per_cluster = 40
    means = np.array([[0,0], [-5,5], [5,5]])
    X = np.vstack([np.random.randn(n_per_cluster, 2) + mu for mu in means])
    true_y = np.array([0] * n_per_cluster + [1] * n_per_cluster + [2] * n_per_cluster)
    return torch.Tensor(X), true_y

def create_data2():
    np.random.seed(0)
    n_per_cluster = 50
    means = np.array([[0,0], [-5, -5]])
    X = np.vstack([np.random.randn(n_per_cluster, 2) + mu for mu in means])
    true_y = np.array([0] * n_per_cluster + [1] * n_per_cluster)
    return torch.Tensor(X), true_y

In [6]:
def loss_fn(V, logP):
    # loss = gumbel_stable_loss(logP - V, clip=10) # + logP

    loss = torch.mean(torch.exp(logP-V) - V - 1)
    z = logP - V

    max_z = torch.max(z)
    max_z = torch.where(max_z < -1.0, torch.tensor(-1.0, dtype=torch.double, device=max_z.device), max_z)
    max_z = max_z.detach()  # Detach the gradients
    loss = loss * torch.exp(-max_z)
    return loss


def main():
    N = 300
    K = 4
    D = 2
    data, true_y = create_data_1(N=N, K=K, D=D)
    N = len(data)
    g_fitter = gmm_gumbel.GaussianMixtureGumbel(n_components=K, n_features=D)
    max_iter = 100
    writer = SummaryWriter()
    

    V = torch.zeros(n, requires_grad=True) 
    V_mean = torch.mean(V)
    V_lr = 0.01
    V_optim = torch.optim.Adam([V], lr=V_lr)
    # loss_fn = gumbel_stable_loss

    logP = g_fitter._estimate_log_prob(data)
    loss = loss_fn(V, logP)  # TODO: Do we need to compute mean here?
    writer.add_scalar('loss', loss, 0)
    
    # Randomly initialize means
    g_fitter.mu = torch.nn.Parameter(torch.Tensor([[[-5, -5], [0,0]]]), requires_grad=False)

    for iter_ in range(max_iter):
        pred_y = g_fitter.predict(data)
        if iter_ % 1 == 0:
            plot(data, true_y, pred_y, iter_, g_fitter.mu)

        # Do V step
        logP = g_fitter._estimate_log_prob(data)

        V_optim.zero_grad()
        # loss = torch.mean(loss_fn(data-V))
        loss = loss_fn(V, logP)
        loss = loss.mean()
        writer.add_scalar('loss', loss, iter_+1)
        loss.backward(retain_graph=True)
        V_optim.step()
        print("Loss:", loss, "\n")
        
        # Do theta step
        weighted_log_prob = logP + np.log(1/k) * torch.ones_like(logP)
        log_prob_norm = V.unsqueeze(1).unsqueeze(2)
        log_resp = weighted_log_prob - log_prob_norm
        pi, mu, var = g_fitter._m_step(data, log_resp)
        g_fitter.update_pi(pi)
        g_fitter.update_mu(mu)
        g_fitter.update_var(var)

    writer.close()


main()


ValueError: too many values to unpack (expected 2)