In [2]:
# %load A10_Reinforce.py
import torch
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(1)

pGT = torch.Tensor([1./12, 2./12, 3./12, 3./12, 2./12, 1./12])
y = torch.from_numpy(np.random.choice(list(range(6)), size=1000, p=pGT.numpy())).type(torch.int64).view(-1, 1)
delta = torch.zeros(y.numel(),6).scatter(1,y,torch.ones_like(y).float())

#maximum likelihood given dataset y encoded in delta
def MaxLik(delta):
    alpha = 1
    theta = torch.randn(6)
    for iter in range(100):
        p_theta = torch.nn.Softmax(dim=0)(theta)
        g = torch.mean(p_theta-delta,0)
        theta = theta - alpha*g
        print("Diff: %f" % torch.norm(p_theta - pGT))
    
    return theta

theta = MaxLik(delta)

#reinforce with reward R
def Reinforce(R, theta=None):
    alpha = 1
    if theta is None:
        theta = torch.randn(6)
    for iter in range(10000):
        #current distribution
        p_theta = torch.nn.Softmax(dim=0)(theta) #1x6

        #sample from current distribution and compute reward
        ##############################
        ## Sample from p_theta, find the assignment delta and compute the reward
        ## for each sample
        ## Dimensions: cPT (6); y (1000x1 -> 1000x1); delta (1000x6); curReward (1000x1)
        ############################## 
        y = torch.from_numpy(np.random.choice(list(range(6)), size=1000, p=R.numpy())).type(torch.int64).view(-1, 1)
        delta = torch.log(torch.exp(theta) / torch.sum(torch.exp(theta)))
        curReward = torch.sum(delta, dim=0) * R

        #compute gradient and update
        g = torch.mean(curReward*(delta - p_theta),0)
        theta = theta + alpha*g
        print("Diff: %f" % torch.norm(p_theta - pGT))
        print(p_theta)

R = pGT
Reinforce(R, theta)
    


Diff: 0.584986
Diff: 0.451345
Diff: 0.361978
Diff: 0.300051
Diff: 0.253475
Diff: 0.216126
Diff: 0.184878
Diff: 0.158056
Diff: 0.134754
Diff: 0.114482
Diff: 0.096952
Diff: 0.081945
Diff: 0.069245
Diff: 0.058618
Diff: 0.049818
Diff: 0.042596
Diff: 0.036714
Diff: 0.031955
Diff: 0.028128
Diff: 0.025068
Diff: 0.022635
Diff: 0.020711
Diff: 0.019200
Diff: 0.018019
Diff: 0.017103
Diff: 0.016396
Diff: 0.015856
Diff: 0.015445
Diff: 0.015136
Diff: 0.014906
Diff: 0.014737
Diff: 0.014616
Diff: 0.014530
Diff: 0.014473
Diff: 0.014436
Diff: 0.014416
Diff: 0.014407
Diff: 0.014407
Diff: 0.014414
Diff: 0.014425
Diff: 0.014440
Diff: 0.014457
Diff: 0.014475
Diff: 0.014495
Diff: 0.014514
Diff: 0.014533
Diff: 0.014552
Diff: 0.014571
Diff: 0.014588
Diff: 0.014605
Diff: 0.014622
Diff: 0.014637
Diff: 0.014652
Diff: 0.014665
Diff: 0.014678
Diff: 0.014690
Diff: 0.014702
Diff: 0.014712
Diff: 0.014722
Diff: 0.014732
Diff: 0.014740
Diff: 0.014748
Diff: 0.014756
Diff: 0.014763
Diff: 0.014769
Diff: 0.014775
Diff: 0.01

tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, na

tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, na

Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([nan, nan, nan, nan, nan, nan])
Diff: nan
tensor([na

KeyboardInterrupt: 