In [41]:
%autoreload 9

import torch
import numpy as np
import matplotlib.pyplot as plt
import os

from torch.utils.tensorboard import SummaryWriter

import gmm
from example import plot
from gumbel_regression import *

def create_data():
    np.random.seed(0)
    n_per_cluster = 40
    means = np.array([[0,0], [-5,5], [5,5]])
    X = np.vstack([np.random.randn(n_per_cluster, 2) + mu for mu in means])
    true_y = np.array([0] * n_per_cluster + [1] * n_per_cluster + [2] * n_per_cluster)
    return torch.Tensor(X), true_y

def create_data2():
    np.random.seed(0)
    n_per_cluster = 50
    means = np.array([[0,0], [-5, -5]])
    X = np.vstack([np.random.randn(n_per_cluster, 2) + mu for mu in means])
    true_y = np.array([0] * n_per_cluster + [1] * n_per_cluster)
    return torch.Tensor(X), true_y

In [48]:
def gumbel_stable_loss(z, clip=None):
    if clip is not None:
        z = torch.clamp(z, max=clip)

    max_z = torch.max(z)
    max_z = torch.where(max_z < -1.0, torch.tensor(-1.0, dtype=torch.double, device=max_z.device), max_z)
    max_z = max_z.detach()  # Detach the gradients
    
    # scale by e^max_z
    loss = torch.exp(z - max_z) - z * torch.exp(-max_z) - torch.exp(-max_z)
    return loss

def loss_fn(V, logP):
    # loss = gumbel_stable_loss(logP - V, clip=10) # + logP

    loss = torch.mean(torch.exp(logP-V) - V - 1)
    z = logP - V

    max_z = torch.max(z)
    max_z = torch.where(max_z < -1.0, torch.tensor(-1.0, dtype=torch.double, device=max_z.device), max_z)
    max_z = max_z.detach()  # Detach the gradients
    loss = loss * torch.exp(-max_z)
    return loss


def main():
    data, true_y = create_data2()
    n = len(data)
    d = data.shape[1]
    k = 2
    g_fitter = gmm.GaussianMixture(n_components=k, n_features=d)
    max_iter = 100
    writer = SummaryWriter()
    

    V = torch.zeros(n, requires_grad=True) 
    V_mean = torch.mean(V)
    V_lr = 0.01
    V_optim = torch.optim.Adam([V], lr=V_lr)
    # loss_fn = gumbel_stable_loss

    logP = g_fitter._estimate_log_prob(data)
    loss = loss_fn(V, logP)  # TODO: Do we need to compute mean here?
    writer.add_scalar('loss', loss, 0)
    
    # Randomly initialize means
    g_fitter.mu = torch.nn.Parameter(torch.randn(1, k, d), requires_grad=False)

    for iter_ in range(max_iter):
        pred_y = g_fitter.predict(data)
        if iter_ % 10 == 0:
            plot(data, true_y, pred_y, iter_, g_fitter.mu)

        # Do V step
        logP = g_fitter._estimate_log_prob(data)

        V_optim.zero_grad()
        # loss = torch.mean(loss_fn(data-V))
        loss = loss_fn(V, logP)
        loss = loss.mean()
        writer.add_scalar('loss', loss, iter_+1)
        loss.backward(retain_graph=True)
        V_optim.step()
        print("Loss:", loss, "\n")
        
        # Do theta step
        weighted_log_prob = logP + np.log(1/k) * torch.ones_like(logP)
        log_prob_norm = V.unsqueeze(1).unsqueeze(2)
        log_resp = weighted_log_prob - log_prob_norm
        pi, mu, var = g_fitter._m_step(data, log_resp)
        g_fitter.update_pi(pi)
        g_fitter.update_mu(mu)
        g_fitter.update_var(var)

    writer.close()


main()


Parameter containing:
tensor([[[-1.1261,  0.0650],
         [-0.9060,  1.4437]]])
tensor([-1.1261,  0.0650])
Parameter containing:
tensor([[[-1.1261,  0.0650],
         [-0.9060,  1.4437]]])
tensor([-0.9060,  1.4437])
Loss: tensor(-2.6511, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-2.6165, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-1.7313, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-1.1519, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-0.7860, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-0.5231, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-0.2654, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-0.1349, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-0.0757, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Loss: tensor(-0.0309, dtype=torch.float64, grad_fn=<MeanBackward0>) 

Parameter containing:
tensor([[[0.0032, 0.3322],
         [0.0940, 0.3788]]])
tens