<a href="https://colab.research.google.com/github/hafezgh/nested-cpt-actor-critic/blob/main/lottery.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from cpt import *
import numpy as np
import matplotlib.pyplot as plt
import copy

In [2]:
# Problem dimensions: S states, A actions per agent, N agents.
S, A, N = 5, 3, 4
seed = 2045
np.random.seed(seed)

# Random transition tensor indexed as P[state, joint_action, next_state].
# There are A**N joint actions; the 0.01 offset keeps every entry strictly
# positive before normalisation.
P = 0.01 + np.random.rand(S, A ** N, S)

# Normalise each (state, joint_action) row so it sums to one.
for s, a in np.ndindex(S, A ** N):
    P[s, a] /= P[s, a].sum()

state_trans = P

# N x N matrix of uniform(0, 1) draws (not used elsewhere in this cell).
w = np.random.uniform(size=(N, N))

In [3]:
class Agent:
    """Tabular actor-critic learner that scores returns with CPT helpers.

    The agent keeps a state-value table ``v`` (critic) and softmax logits
    ``theta`` with the induced policy table ``pi`` (actor).  Value targets
    and gradients are built with the helpers star-imported from ``cpt``
    (``get_next_state``, ``cpt_estimate_from_samples``, ``weight``,
    ``util_plus`` ...); their exact semantics are not visible here.

    Parameters
    ----------
    id : int
        Position of this agent inside the joint action vector ``acts``.
    S, A : int
        Number of states and number of actions available to this agent.
    num_neigh : int
        Number of other agents whose actions this agent sees; the joint
        action therefore has ``num_neigh + 1`` components.
    state_trans : array indexed ``[state, joint_action, next_state]``
        Transition model used both for sampling and in ``cpt_grad``.
    n_max : int
        Number of Monte-Carlo samples drawn per estimate.
    discount : float
        Discount factor applied in ``cpt_grad``.
    alpha_, beta_, lambda_, gamma_pos, gamma_neg : float
        Forwarded unchanged to the ``cpt`` helpers (presumably utility
        exponents, loss aversion and probability-weighting parameters --
        confirm against ``cpt``).
    """

    def __init__(self, id, S, A, num_neigh, state_trans, n_max=50, discount=0.99,
                 alpha_=0.65, beta_=0.65, lambda_=2.6, gamma_pos=0.69, gamma_neg=0.69):
        self.id = id
        self.S = S
        self.A = A
        self.num_visible_agents = num_neigh + 1
        self.discount = discount
        self.n_max = n_max
        self.v = np.zeros((S,))
        self.theta = np.random.rand(S, A)
        # NOTE(review): pi starts uniform although theta is random; pi[s] only
        # becomes softmax(theta[s]) after the first policy(s) call for that state.
        self.pi = (1 / A) * np.ones((S, A))
        self.state_trans = state_trans
        self.alpha_ = alpha_
        self.beta_ = beta_
        self.lambda_ = lambda_
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg

    def policy(self, s):
        """Refresh ``pi[s]`` as the softmax of ``theta[s]``."""
        # Subtracting the max leaves the softmax unchanged mathematically but
        # keeps exp() from overflowing to inf (which would yield nan).
        shifted = self.theta[s] - np.max(self.theta[s])
        exp_shifted = np.exp(shifted)
        self.pi[s] = exp_shifted / np.sum(exp_shifted)

    def sel_action(self, s):
        """Sample an action from ``pi[s]``, store it in ``self.a`` and return it."""
        # Use the agent's own action count, not a module-level global.
        self.a = np.random.choice(self.A, p=self.pi[s].ravel())
        return self.a

    def sel_action_next(self, sn):
        """Sample an action from ``pi[sn]``, store it in ``self.next_a`` and return it."""
        self.next_a = np.random.choice(self.A, p=self.pi[sn].ravel())
        return self.next_a

    def calculate_joint_a(self, acts):
        """Encode ``acts`` as a base-``A`` integer, ``acts[0]`` most significant."""
        joint_a = 0
        for i in range(self.num_visible_agents):
            joint_a += acts[i] * self.A ** (self.num_visible_agents - i - 1)
        return joint_a

    def lr_critic(self, it):
        """Critic step size ``0.1 / it**0.75``; ``it`` must be >= 1."""
        return 0.1 / (it ** 0.75)

    def lr_actor(self, it):
        """Actor step size ``0.1 / it**0.5``; ``it`` must be >= 1."""
        return 0.1 / (it ** 0.5)

    def cpt_delta(self, s, acts, Rp, Rsigma):
        """Return the critic error: sampled CPT value estimate minus ``v[s]``.

        Draws ``n_max`` one-step samples in which this agent re-samples its own
        action from ``pi[s]`` while the other entries of ``acts`` stay fixed.
        ``Rp`` is indexed ``[agent, state]`` and ``Rsigma`` by state.
        """
        acts_n_ = np.array(acts)  # private copy; the caller's array is untouched
        samples = np.zeros((self.n_max, 1))
        for i in range(self.n_max):
            a = self.sel_action(s)
            acts_n_[self.id] = a
            # Mean of (action + 1) over the *other* visible agents.
            sigma = (np.sum(acts_n_ + 1) - (a + 1)) / (self.num_visible_agents - 1)
            joint_a = self.calculate_joint_a(acts_n_)
            r = Rp[self.id, s] + Rsigma[s] * sigma
            sn = get_next_state(s, joint_a, self.state_trans)
            # NOTE(review): the bootstrap term is not multiplied by
            # self.discount here, whereas cpt_grad uses r + discount * v[sn].
            # Left unchanged -- confirm which is intended.
            samples[i] = r + self.v[sn]
        est_v_s = cpt_estimate_from_samples(samples, self.alpha_, self.beta_, self.lambda_,
                                            self.gamma_pos, self.gamma_neg)
        # .item() extracts the scalar without relying on the deprecated
        # implicit conversion of size-1, ndim>0 arrays.
        return float(np.asarray(est_v_s - self.v[s]).item())

    def mu_cpt(self, acts, Rp, Rsigma):
        """Return an ``(S, 1)`` state weighting ``mu`` normalised to sum to one.

        A matrix ``weights[sn, s]`` is built from rank-weighted sorted samples,
        then ``mu`` is obtained from ``inv(I - weights.T + eps) @ h`` where
        ``h`` puts (almost) all mass on state 0.
        """
        n = self.n_max
        # The rank weights depend only on the rank and n_max, so compute them
        # once instead of once per (sn, s) pair.
        rank_weights = [
            weight((n + 1 - i) / n, self.gamma_pos) - weight((n - i) / n, self.gamma_pos)
            for i in range(1, n + 1)
        ]
        weights = np.zeros((self.S, self.S))
        acts_ = np.array(acts)
        for sn in range(self.S):
            for s in range(self.S):
                samples = np.zeros((n, 1))
                for i in range(n):
                    a = self.sel_action(s)
                    acts_[self.id] = a
                    sigma = (np.sum(acts_ + 1) - (a + 1)) / (self.num_visible_agents - 1)
                    r = Rp[self.id, s] + Rsigma[s] * sigma
                    # NOTE(review): the branch is chosen by the sign of v[s],
                    # not of r -- confirm this is intended.
                    if self.v[s] >= 0:
                        samples[i] = self.pi[s, a] * util_plus_derivative(r, self.alpha_)
                    else:
                        samples[i] = self.pi[s, a] * util_minus_derivative(r, self.beta_, self.lambda_)
                # Sort ascending and flatten so each term is a true scalar.
                ordered = np.sort(samples, 0)[:, 0]
                for value, rank_w in zip(ordered, rank_weights):
                    weights[sn, s] += value * rank_w
        eps = 1e-16
        h = eps * np.ones((self.S, 1))
        h[0] = 1
        h = h / np.sum(h)
        I = np.eye(self.S)
        mu = np.linalg.inv((I - weights.T) + eps) @ h
        mu /= np.sum(mu)
        return mu

    def cpt_grad(self, s, acts, Rp, Rsigma):
        """Return the scalar actor gradient, summed over all states.

        ``s`` is accepted for interface symmetry with ``cpt_delta`` but is not
        used: the sum below always runs over every state, weighted by ``mu``.
        """
        mu = self.mu_cpt(acts, Rp, Rsigma)
        grad = 0
        acts_ = np.array(acts)
        # Loop variable deliberately not called ``s`` so the parameter is not
        # shadowed.
        for state in range(self.S):
            grad_s = 0
            for sn in range(self.S):
                for a in range(self.A):
                    acts_[self.id] = a
                    sigma = (np.sum(acts_ + 1) - (a + 1)) / (self.num_visible_agents - 1)
                    r = Rp[self.id, state] + Rsigma[state] * sigma
                    ret = r + self.discount * self.v[sn]
                    if ret >= 0:
                        u = util_plus(ret, self.alpha_)
                        phi_der = util_plus_derivative(u, self.alpha_)
                    else:
                        u = -util_minus_abs(ret, self.beta_, self.lambda_)
                        phi_der = util_minus_derivative(u, self.beta_, self.lambda_)
                    pi_grad = self.pi[state, a] * (1 - self.pi[state, a])
                    trans_prob = self.state_trans[state, self.calculate_joint_a(acts_), sn]
                    grad_s = grad_s + phi_der * trans_prob * pi_grad * u
            # mu has shape (S, 1); index the element explicitly rather than
            # calling float() on a length-1 array.
            grad = grad + grad_s * float(mu[state, 0])
        return grad

    def learn(self, s, acts, Rp, Rsigma, it):
        """Run one critic and one actor update for state ``s`` at iteration ``it``.

        Both the error and the gradient are computed before either table is
        modified; afterwards ``pi[s]`` is refreshed from the new ``theta[s]``.
        """
        delta = self.cpt_delta(s, acts, Rp, Rsigma)
        grad = self.cpt_grad(s, acts, Rp, Rsigma)
        self.v[s] = self.v[s] + delta * self.lr_critic(it)
        own_action = acts[self.id]
        self.theta[s, own_action] = self.theta[s, own_action] + self.lr_actor(it) * grad
        self.policy(s)


In [5]:
seed = 2045

np.random.seed(seed)

# Reward tables: Rp[agent, state] is a per-agent term, Rsigma[state] scales
# the neighbour-action statistic computed inside Agent.
Rp = np.random.rand(N, S)
Rsigma = (np.random.rand(S) - 0.5) * 5
pols_avg = np.zeros((N, S, A))
max_it = int(1e04)


discount = 0.99
# One hyper-parameter value per agent (identical for all agents here).
alpha_ = [0.65] * N
beta_ = [0.65] * N
lambda_ = [2.6] * N
gamma_pos = [0.69] * N
gamma_neg = [0.69] * N
s = 0
ep = 8
# Per-run, per-iteration snapshots of every agent's value and policy tables.
all_v_it = np.zeros((ep, max_it, N, S))
all_pols_it = np.zeros((ep, max_it, N, S, A))


for e in range(ep):
    # Seed with the run index so each run is reproducible *and* distinct
    # (seeding with the constant `ep` would make every run identical).
    np.random.seed(seed + e)
    s = 0
    agents = [
        Agent(n, S, A, N - 1, state_trans, n_max=50, discount=discount,
              alpha_=alpha_[n], beta_=beta_[n], lambda_=lambda_[n],
              gamma_pos=gamma_pos[n], gamma_neg=gamma_neg[n])
        for n in range(N)
    ]
    # Every agent samples its initial action, not only the last one.
    act = np.zeros((N,), dtype=int)
    for n in range(N):
        act[n] = agents[n].sel_action(s)
    for it in range(1, max_it + 1):
        # Encode the joint action exactly as the agents do internally so the
        # environment and the agents' model index the same rows of P.
        joint_a = agents[0].calculate_joint_a(act)
        sn = get_next_state(s, joint_a, P)
        for n in range(N):
            agents[n].learn(s, act, Rp, Rsigma, it)
            all_v_it[e, it - 1, n, :] = agents[n].v
            all_pols_it[e, it - 1, n, :] = agents[n].pi
        for n in range(N):
            act[n] = agents[n].sel_action_next(sn)
        s = sn

    # Accumulate this run's final policies, stacked to shape (N, S, A).
    pols_avg += np.stack([agent.pi for agent in agents])
# Final policies averaged over all runs.
pols = pols_avg / ep
