In [1]:
import numpy as np
from coins import generate_combinations
from itertools import product



In [2]:
# All 10^5 five-digit tuples, and the subset returned by generate_combinations
# (presumably the tuples whose digits sum to 25 -- confirm against coins.py).
uniform = np.asarray(list(product(range(10),range(10),range(10),range(10),range(10))))
combinations = np.asarray(generate_combinations(25, range(10), 5))
combinations_set = set(tuple(c) for c in combinations)

# Frequency of each digit, pooled over every position of every combination.
marginals = np.array([(combinations == digit).mean() for digit in range(10)])
print('Marginals ' + str(marginals))
print('Got {} combinations out of 10^5=100000'.format(len(combinations)))
print('Ratio is {}'.format(len(combinations) / 10.**5))

# np.log is the natural logarithm, so every entropy below is in nats, not bits.
print('Uniform entropy {:.2f} nats'.format(np.log(10.**5)))
# Entropy of five independent digits that each follow `marginals`.
marginal_entropy = -5 * np.sum(marginals * np.log(marginals))
print('Marginal entropy {:.2f} nats'.format(marginal_entropy))
# Entropy of the uniform distribution over the listed combinations.
joint_entropy = np.log(len(combinations))
print('Joint entropy {:.2f} nats'.format(joint_entropy))

print('Difference is {:.2f} nats'.format(marginal_entropy - joint_entropy))

Marginals [0.06180075 0.07369917 0.08524241 0.09589771 0.1051323  0.11241343
 0.11720831 0.11898419 0.11720831 0.11241343]
Got 5631 combinations out of 10^5=100000
Ratio is 0.05631
Uniform entropy 11.51 bits
Marginal entropy 11.41 bits
Joint entropy 8.64 bits
Difference is 2.78 bits


## Symbolic
The entropy is the lowest negative-log-likelihood one can expect (when the model matches the true distribution).
So, given the symbolic-sum-25 distribution, the best nll we can expect is 8.64 nats (the entropies above use the natural logarithm).

Another perspective: the marginal entropy is the best nll one can expect from an independent model (interestingly, it is almost the same as the uniform entropy).

N is the number of combinations (5631), i is the combination index (0 to 5630), j is the digit index (0 to 4).

NLL Joint: $$ E_p[-\log q(z)] = \frac 1 N \sum_i -\log q(z^i) \approx 8.64$$
NLL Marginal: $$ E_p[-\log q(z)] = E_p[-\sum_j \log q_j(z_j)] \approx 11.41$$

## Images
Now let's look into using images. We factor the model as $q(x, z) = q(z) \prod_j q(x_j|z_j)$. Then we need to marginalize over $z$ because it is an unobserved variable.

NLL Joint: $$ E_p[- \log \sum_z q(z) \prod_j q(x_j|z_j)] $$
NLL Marginal: $$ E_p[- \log \sum_z \prod_j q(z_j) q(x_j|z_j)]
= E_p[- \log \prod_j \sum_{z_j} q(z_j) q(x_j|z_j)] = 5 \times \text{NLL-Single-Model}$$

## Simplifying assumption of good and bad images
Let's make a simplifying assumption that:
- if $y_j=z_j$ then $q(x_j|z_j)=a$
- else if $y_j\neq z_j$ then $q(x_j|z_j)=b$.

Then we have

NLL Joint: $$ E_{y\sim p(y)}[- \log \sum_z q(z) \prod_j (a*1_{z_j=y_j} + b*1_{z_j\neq y_j})] $$

## Perfect conditional image model
Suppose we got the "perfect" conditional image model, which is uniform over 6000 images of given class (mixture of diracs).
Then $a=1/6000$ and $b=0$.

We have: 

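NLL Joint: $$ E_{y\sim p(y)}[- \log (q(y) a^5)] = -5 \log a + E_{y\sim p(y)}[-\log q(y)] $$
(only the term $z=y$ survives in the sum because $b=0$), which is about 43.5 + Symbolic_NLL, since $-5 \log a = 5 \ln 6000 \approx 43.5$.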

Indeed this result is confirmed by the numerical results with a=1/6000 and b=0.
We get a difference of 2.78 nats in that case, exactly like the symbolic case (since images are either assigned perfect score or none).


## Perfect marginalized image model 

Suppose we got the "perfect" marginal image model, which is uniform over 60000 images of 10 classes (mixture of diracs).
Then $a=b=1/60000$.

NLL Joint: $$ E_{y\sim p(y)}[- \log \sum_z q(z) \prod_j a] = -5 \log a + E_{y\sim p(y)}[- \log \sum_z q(z)] = -5 \log a \approx 55 \text{ nats}$$ since $q$ sums to one (so $\log \sum_z q(z) = 0$).

Therefore, if we ignore the marginal image model completely, we can get a score of about 55 nats, which corresponds to the independent/uniform case.


## Perfect symbolic model

NLL Joint: $$ E_{y\sim p(y)}[- \log \sum_z p(z) \prod_j (a*1_{z_j=y_j} + b*1_{z_j\neq y_j})] 
= \frac 1 N \sum_{k=1}^N [ - \log \frac 1 N \sum_{i=1}^N \prod_j (a*1_{z_j^i=y_j^k} + b*1_{z_j^i\neq y_j^k})] $$


## Todo: compare slightly flawed images respecting constraint with perfect images not respecting constraint

to do this, one could switch a and b for one class, which would correspond to ensuring that 25 is never respected.

In [134]:
# -5 * ln(a) for a = 1/60000; np.log is the natural log, so the value is in nats.
-5*np.log(1./60000)

55.010499206021194

In [7]:
# Perfect joint prior: equal weight on every tuple found in combinations_set,
# zero weight on every other row of `uniform`.
is_valid = np.array([tuple(candidate) in combinations_set for candidate in uniform])
q_join = np.zeros(10**5)
q_join[is_valid] = 1. / len(combinations_set)
print(q_join.sum())

1.0


In [8]:
# Independent ("marginal") prior: the weight of a tuple is the product of the
# marginal frequencies of its five digits.
q_ind = np.zeros(10**5)
q_ind[:] = [np.prod(marginals[digits]) for digits in uniform]
print(q_ind.sum())

1.0000000000000002


In [13]:
# Uniform prior: the same weight on each of the 10^5 tuples.
n_tuples = 10**5
q_uni = np.ones(n_tuples) / float(n_tuples)

In [5]:
# Fast
def get_nll(a, b, q, observations=None, support=None, verbose=True):
    """Average negative log-likelihood of the observed digit tuples under prior ``q``.

    For each observed tuple ``y`` the likelihood is
    ``sum_z q(z) * prod_j (a if z_j == y_j else b)``, where ``z`` runs over the
    rows of ``support``. The return value is the mean over observations of
    ``-np.log(likelihood)``; ``np.log`` is the natural log, so the unit is nats.

    Parameters
    ----------
    a : float
        Factor used for a position where ``z_j == y_j``.
    b : float
        Factor used for a position where ``z_j != y_j``.
    q : numpy.ndarray
        1-D weights, one per row of ``support``.
    observations : array-like, optional
        2-D array of observed tuples. Defaults to the notebook-level
        ``combinations``.
    support : array-like, optional
        2-D array of candidate tuples ``z``. Defaults to the notebook-level
        ``uniform``.
    verbose : bool, optional
        When true (default), print progress every 500 observations.

    Returns
    -------
    float
        The averaged negative log-likelihood (0.0 if there are no observations).
    """
    if observations is None:
        observations = combinations
    if support is None:
        support = uniform
    observations = np.asarray(observations)
    support = np.asarray(support)

    n_obs = len(observations)
    nll = 0.
    for i, y in enumerate(observations):
        # Per-candidate product over positions: a where the digit matches y, b elsewhere.
        per_candidate = np.prod(np.where(support == y, a, b), axis=1)
        likelihood = q.dot(per_candidate)

        nll += -np.log(likelihood) / float(n_obs)

        if verbose and i % 500 == 0:
            print('{}/{}'.format(i, n_obs))
            print('combination ' + str(y))
            print('nll ' + str(nll))
    return nll

# Perfect image model

In [None]:
a = 1./6000.  # ~diracs on training set: factor used by get_nll where z_j == y_j
b = 0.  # no likelihood for bad digits: factor used where z_j != y_j

In [123]:
# Average NLL under the joint prior with the current (a, b).
nll_join = get_nll(a, b, q_join)
print(nll_join)

0/5631
combination [0 0 7 9 9]
nll 0.00925832286739577
500/5631
combination [1 5 9 7 3]
nll 4.638419756565316
1000/5631
combination [2 6 7 3 7]
nll 9.267581190263275
1500/5631
combination [3 6 2 7 7]
nll 13.896742623961234
2000/5631
combination [4 4 8 9 0]
nll 18.525904057659194
2500/5631
combination [5 2 8 9 1]
nll 23.155065491357153
3000/5631
combination [5 9 7 4 0]
nll 27.784226925055112
3500/5631
combination [6 7 4 3 5]
nll 32.41338835875307
4000/5631
combination [7 4 9 3 2]
nll 37.04254979245103
4500/5631
combination [8 2 5 1 9]
nll 41.671711226148986
5000/5631
combination [9 0 0 9 7]
nll 46.300872659846945
5500/5631
combination [9 7 0 5 4]
nll 50.930034093544904
52.133616066306374


In [122]:
# Average NLL under the independent (product-of-marginals) prior with the current (a, b).
nll_ind = get_nll(a, b, q_ind)
print(nll_ind)

0/5631
combination [0 0 7 9 9]
nll 0.009867725421967987
500/5631
combination [1 5 9 7 3]
nll 4.901380586741098
1000/5631
combination [2 6 7 3 7]
nll 9.78055439383192
1500/5631
combination [3 6 2 7 7]
nll 14.65394538315527
2000/5631
combination [4 4 8 9 0]
nll 19.523288202264716
2500/5631
combination [5 2 8 9 1]
nll 24.39140705760259
3000/5631
combination [5 9 7 4 0]
nll 29.255284277965057
3500/5631
combination [6 7 4 3 5]
nll 34.12155291187345
4000/5631
combination [7 4 9 3 2]
nll 38.995544088353604
4500/5631
combination [8 2 5 1 9]
nll 43.87361336336889
5000/5631
combination [9 0 0 9 7]
nll 48.75141814475452
5500/5631
combination [9 7 0 5 4]
nll 53.63896196483282
54.91244098275452


In [126]:
# get_nll uses np.log (natural log), so the gap between the two priors is in nats.
print('Difference is {:.2f} nats'.format(nll_ind - nll_join))

Difference is 2.78 bits


# Slightly flawed image model

In [9]:
# a is the factor get_nll applies where z_j == y_j, b where z_j != y_j.
a = 0.9 / 6000.
b = 0.1 / 6000.

In [10]:
# Average NLL under the joint prior with the current (a, b).
nll_join = get_nll(a, b, q_join)
print(nll_join)

0/5631
combination [0 0 7 9 9]
nll 0.009189741736744677
500/5631
combination [1 5 9 7 3]
nll 4.599596732681784
1000/5631
combination [2 6 7 3 7]
nll 9.187858579046145
1500/5631
combination [3 6 2 7 7]
nll 13.77519999198052
2000/5631
combination [4 4 8 9 0]
nll 18.36192118751578
2500/5631
combination [5 2 8 9 1]
nll 22.948473742820912
3000/5631
combination [5 9 7 4 0]
nll 27.534152581164488
3500/5631
combination [6 7 4 3 5]
nll 32.12022143949179
4000/5631
combination [7 4 9 3 2]
nll 36.707587934856726
4500/5631
combination [8 2 5 1 9]
nll 41.29559541617429
5000/5631
combination [9 0 0 9 7]
nll 45.883654857266116
5500/5631
combination [9 7 0 5 4]
nll 50.47353558124406
51.66728615469974


In [11]:
# Average NLL under the independent (product-of-marginals) prior with the current (a, b).
nll_ind = get_nll(a, b, q_ind)
print(nll_ind)

0/5631
combination [0 0 7 9 9]
nll 0.009279927782088892
500/5631
combination [1 5 9 7 3]
nll 4.632917927024459
1000/5631
combination [2 6 7 3 7]
nll 9.252323731343532
1500/5631
combination [3 6 2 7 7]
nll 13.869551548164335
2000/5631
combination [4 4 8 9 0]
nll 18.485028172173273
2500/5631
combination [5 2 8 9 1]
nll 23.09972881367803
3000/5631
combination [5 9 7 4 0]
nll 27.712595578207438
3500/5631
combination [6 7 4 3 5]
nll 32.32625722007918
4000/5631
combination [7 4 9 3 2]
nll 36.943087623260006
4500/5631
combination [8 2 5 1 9]
nll 41.56144169261444
5000/5631
combination [9 0 0 9 7]
nll 46.17997230658643
5500/5631
combination [9 7 0 5 4]
nll 50.802807400563715
52.00580738980117


In [17]:
# Average NLL under the uniform prior with the current (a, b).
# NOTE(review): the recorded result 55.0105 equals -5*ln(1/60000), i.e. what a
# uniform prior yields when a = b = 1/60000, not the a, b set in this section.
# The execution counts are out of order, so this output looks stale -- re-run
# after the a, b cell above and confirm.
nll_uni = get_nll(a, b, q_uni)
print(nll_uni)

0/5631
combination [0 0 7 9 9]
nll 0.00976922379790821
500/5631
combination [1 5 9 7 3]
nll 4.89438112275202
1000/5631
combination [2 6 7 3 7]
nll 9.778993021706302
1500/5631
combination [3 6 2 7 7]
nll 14.663604920660585
2000/5631
combination [4 4 8 9 0]
nll 19.548216819614222
2500/5631
combination [5 2 8 9 1]
nll 24.432828718567617
3000/5631
combination [5 9 7 4 0]
nll 29.31744061752101
3500/5631
combination [6 7 4 3 5]
nll 34.202052516475206
4000/5631
combination [7 4 9 3 2]
nll 39.08666441543038
4500/5631
combination [8 2 5 1 9]
nll 43.97127631438555
5000/5631
combination [9 0 0 9 7]
nll 48.85588821334072
5500/5631
combination [9 7 0 5 4]
nll 53.74050011229589
55.010499206024235


In [12]:
# get_nll uses np.log (natural log), so the gap between the two priors is in nats.
print('Difference is {:.2f} nats'.format(nll_ind - nll_join))

Difference is 0.34 bits


# Generate perfect digits randomly

In [18]:
# Same factor whether or not z_j matches y_j, so the product in get_nll no longer depends on z.
a = b = 1. / 60000.

In [16]:
# Average NLL under the joint prior with the current (a, b).
nll_join = get_nll(a, b, q_join)
print(nll_join)

0/5631
combination [0 0 7 9 9]
nll 0.009769223797908221
500/5631
combination [1 5 9 7 3]
nll 4.894381122752021
1000/5631
combination [2 6 7 3 7]
nll 9.778993021706304
1500/5631
combination [3 6 2 7 7]
nll 14.663604920660587
2000/5631
combination [4 4 8 9 0]
nll 19.548216819614225
2500/5631
combination [5 2 8 9 1]
nll 24.43282871856762
3000/5631
combination [5 9 7 4 0]
nll 29.317440617521015
3500/5631
combination [6 7 4 3 5]
nll 34.20205251647521
4000/5631
combination [7 4 9 3 2]
nll 39.086664415430384
4500/5631
combination [8 2 5 1 9]
nll 43.971276314385555
5000/5631
combination [9 0 0 9 7]
nll 48.85588821334073
5500/5631
combination [9 7 0 5 4]
nll 53.7405001122959
55.01049920602424


In [19]:
# Average NLL under the independent (product-of-marginals) prior with the current (a, b).
nll_ind = get_nll(a, b, q_ind)
print(nll_ind)

0/5631
combination [0 0 7 9 9]
nll 0.009769223797908221
500/5631
combination [1 5 9 7 3]
nll 4.894381122752021
1000/5631
combination [2 6 7 3 7]
nll 9.778993021706304
1500/5631
combination [3 6 2 7 7]
nll 14.663604920660587
2000/5631
combination [4 4 8 9 0]
nll 19.548216819614225
2500/5631
combination [5 2 8 9 1]
nll 24.43282871856762
3000/5631
combination [5 9 7 4 0]
nll 29.317440617521015
3500/5631
combination [6 7 4 3 5]
nll 34.20205251647521
4000/5631
combination [7 4 9 3 2]
nll 39.086664415430384
4500/5631
combination [8 2 5 1 9]
nll 43.971276314385555
5000/5631
combination [9 0 0 9 7]
nll 48.85588821334073
5500/5631
combination [9 7 0 5 4]
nll 53.7405001122959
55.01049920602424
