In [None]:
%matplotlib inline

# import scipy as sp
# import numpy as np
# np.set_printoptions(precision=4)

import matplotlib
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context("talk")

In [None]:
"""
1.
"""
import itertools
import torch
from torch.autograd import Variable
torch.set_printoptions(precision=4)
    
class P1:
    """Brute-force pairwise binary Markov random field over eight named nodes.

    Each node ``i`` takes a value ``y[i]`` in {0, 1}.  The unnormalised log
    potential of a full configuration ``y`` is

        theta_s . y  +  sum over (s, t) in edges of  theta_st * y[s] * y[t]

    Every quantity below is computed by enumerating all ``2 ** n``
    configurations, so this is only practical for small ``n``.

    Written against the PyTorch >= 0.4 API, where ``Variable`` is a thin
    alias for ``Tensor`` and Python scalars come from ``.item()``/``.tolist()``.
    """

    def __init__(self):
        self.dtype = torch.FloatTensor
        # Wrap a sequence of numbers as a constant (non-trainable) float tensor.
        self.torch_of = lambda y: Variable(torch.Tensor(y).type(self.dtype), requires_grad=False)

        # initial states
        self.names = ["SR", "YD", "MG", "ZH", "HS", "RS", "NZ", "YK"]
        self.indices = {name: i for i, name in enumerate(self.names)}
        self.n = len(self.names)

        # useful helper functions
        self.name_of = lambda i: self.names[i]
        self.index_of = lambda n: self.indices[n]
        # Fresh iterator over every 0/1 assignment of the n nodes, as tensors.
        self.all_ys = lambda: map(self.torch_of, itertools.product(range(2), repeat=self.n))

        # parameters of model
        self.edges = [(self.index_of(s), self.index_of(t)) for (s, t) in
                      [("MG", "RS"), ("YD", "RS"), ("RS", "ZH"),
                       ("ZH", "YK"), ("ZH", "HS"), ("ZH", "NZ")]
                      ]
        # Single coupling weight shared by every edge.
        self.theta_st = 2
        # Per-node weights; requires_grad so d(log Z)/d(theta_s) can be taken.
        self.theta_s = Variable(
            torch.Tensor([2, -2, -2, -8, -2, 3, -2, 1]).type(self.dtype),
            requires_grad=True
        )
        # Memoised log Z.  NOTE: not invalidated if theta_s / theta_st change.
        self.log_partition_cache = None

    def log_potential(self, y):
        """Return the unnormalised log potential of configuration ``y``.

        Args:
            y: float tensor of length ``n`` holding 0/1 node values.

        Returns:
            0-dim tensor: node term ``theta_s . y`` plus ``theta_st`` for each
            edge whose two endpoints are both 1.
        """
        edge_scores = (y[s] * self.theta_st * y[t] for s, t in self.edges)
        node_scores = self.theta_s.dot(y)
        return sum(edge_scores) + node_scores

    def log_partition(self):
        """Return log Z = log sum_y exp(log_potential(y)), memoised.

        ``torch.logsumexp`` is used instead of ``sum(exp(...)).log()`` so that
        large potentials (above ~88) do not overflow float32 in ``exp``.  The
        result stays attached to the autograd graph of ``theta_s``.
        """
        if self.log_partition_cache is None:
            scores = torch.stack([self.log_potential(y) for y in self.all_ys()])
            self.log_partition_cache = torch.logsumexp(scores, dim=0)
        return self.log_partition_cache

    def marginals(self):
        """Return P(y[i] = 1) for every node by explicit enumeration.

        Each configuration's probability is computed once.  Because node
        values are 0/1, weighting each configuration by ``y`` itself sums, for
        entry ``i``, exactly the probabilities of configurations with
        ``y[i] == 1``.

        Returns:
            Constant (no-grad) float tensor of length ``n``.
        """
        ys = list(self.all_ys())
        log_z = self.log_partition()
        probs = torch.stack([(self.log_potential(y) - log_z).exp() for y in ys])
        # (2**n, n)^T @ (2**n,) -> (n,): expected value of each 0/1 variable.
        expected = torch.stack(ys).t().mv(probs)
        return self.torch_of(expected.detach().tolist())

    def marginals_optim(self):
        """Return the node marginals as the gradient d(log Z)/d(theta_s).

        Differentiating log Z = log sum_y exp(theta_s . y + ...) w.r.t.
        ``theta_s`` gives sum_y p(y) * y, i.e. the node marginals.
        ``torch.autograd.grad`` is used rather than ``backward()`` so the
        result neither depends on nor accumulates into whatever is already in
        ``theta_s.grad``.  ``retain_graph=True`` keeps the memoised
        ``log_partition`` graph usable for later calls.

        Returns:
            Float tensor of length ``n``.
        """
        (grad,) = torch.autograd.grad(self.log_partition(), self.theta_s, retain_graph=True)
        return grad

    def get_test_y(self):
        """Return the configuration with only the "RS" and "SR" nodes set to 1."""
        return self.torch_of([
                int(i == self.index_of("RS") or i == self.index_of("SR"))
                for i in range(self.n)
            ])
    
p1 = P1()
y = p1.get_test_y()

# Log potential of the test configuration (SR and RS switched on).
score_of_y = p1.log_potential(y)
print(score_of_y)

# Log partition function log Z (memoised inside the model).
log_z = p1.log_partition()
print(log_z)

# Node marginals two ways: explicit enumeration vs. the gradient of log Z.
enumerated = p1.marginals()
from_gradient = p1.marginals_optim()
print(enumerated, from_gradient)