In this notebook, we generate data in a two step process: 

1) We generate cluster centers from a Gaussian distribution with a fixed mean and standard deviation

2) For each cluster, we generate observed values 

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch

## Parameters go here

In [37]:
prior_mn = 20.0
prior_std = 5

cluster_stds = [1.0]*500 # Number of entries here determine number of clusters
n_obs_per_cluster = 10

## Generate the data

In [38]:
n_clusters = len(cluster_stds)
ctrs = np.random.randn(n_clusters)*prior_std + prior_mn

In [39]:
data = [torch.tensor(np.random.randn(n_obs_per_cluster)*std_c + mn_c) for std_c, mn_c in zip(cluster_stds, ctrs)]

## Fit prior describing the data

In [78]:
fit_ctr = torch.tensor(30.0, requires_grad=True)
fit_std = torch.tensor(1.0, requires_grad=True)

In [79]:
optimizer = torch.optim.Adam(params=[fit_ctr, fit_std], lr=.1)

In [None]:
n_smps = 1000
for i in range(1000):
    
    ll = 0
    optimizer.zero_grad()
    
    for c_i in range(n_clusters):
        c_l = 0 
        data_i = data[c_i]
        
        for s_i in range(n_smps):
            ctr = torch.randn(1)*torch.abs(fit_std) + fit_ctr
            var_c = cluster_stds[c_i]**2


            diffs = torch.exp((-.5*(data_i - ctr)**2)/var_c)
        
            c_l+= torch.prod(diffs) # Ignoreing part that doeesn't depend on the ctr
        
        ll = -1*torch.log(c_l/n_smps)
        
    if i % 1 == 0:  
        print(str(i) + ': ' + str(ll.detach().numpy()))
        print('fit_ctr: ' + str(fit_ctr.detach().numpy()))
        print('fit_std: ' + str(abs(fit_std.detach().numpy())))
              
    ll.backward()
    optimizer.step()
        
    

0: 7.859594054091942
fit_ctr: 23.801092
fit_std: 5.512872
1: 7.870756718831328
fit_ctr: 23.796787
fit_std: 5.514128
2: 7.457318609142766
fit_ctr: 23.792242
fit_std: 5.515487
3: 7.559388767504556
fit_ctr: 23.787703
fit_std: 5.516854
4: 7.520243699507743
fit_ctr: 23.784636
fit_std: 5.517546
5: 7.6699290200973085
fit_ctr: 23.781336
fit_std: 5.518342
6: 7.799901724155332
fit_ctr: 23.777906
fit_std: 5.519206
7: 7.368849417752132
fit_ctr: 23.774815
fit_std: 5.519907
8: 7.360859436224478
fit_ctr: 23.771322
fit_std: 5.5207973
9: 7.775481355057464
fit_ctr: 23.768095
fit_std: 5.521578
10: 7.828404925463204
fit_ctr: 23.76528
fit_std: 5.5221853
11: 7.263677118875921
fit_ctr: 23.76271
fit_std: 5.5226564
12: 7.773770951790556
fit_ctr: 23.759829
fit_std: 5.5232816
13: 7.478767682146854
fit_ctr: 23.757387
fit_std: 5.5237164
14: 7.608935729777869
fit_ctr: 23.754047
fit_std: 5.5245776
15: 7.451043021856939
fit_ctr: 23.74945
fit_std: 5.526016
16: 7.698338722428656
fit_ctr: 23.745111
fit_std: 5.527341
17:

In [None]:
fit_ctr

In [None]:
fit_std

In [None]:
torch.prod(torch.tensor([1, 2, 4]))

In [59]:
np.mean(ctrs)

20.009747790112886