In [1]:
from __future__ import division
import gym
import torch
import random
import numpy as np
import torch
from PIL import Image
import torch
import torch.nn as nn
from collections import namedtuple
from collections import deque
import torch.nn.functional as F
from tqdm import tqdm_notebook as tqdm
from matplotlib.pyplot import imshow
from PIL import Image
from wrappers import make_atari, wrap_deepmind, wrap_pytorch
import queue
from torch import optim
import matplotlib.pyplot as plt

In [18]:
class SegmentSumTree:
    def __init__(self, capacity):
        self.size = capacity 
        #n leaves + n-1 internal = 2n-1, capacity of a tree
        self.tree = np.zeros((2*capacity-1), dtype = np.float32)
        #leaf nodes having actual values 
        self.transit_data_buffer =  np.array([None] * self.size) 
        self.is_full = False
        self.max_prior_val = 1
        self.num_entries = 0 
        self.idx = 0 
    
    
    def append(self, priority_val, transit_data):
        self.max_prior_val = max(priority_val, self.max_prior_val)
        self.transit_data_buffer[self.idx] = transit_data  
        self.update(self.idx + self.size - 1, priority_val) 
        self.idx = (self.idx + 1) % self.size 
        self.num_entries += 1
        self.is_full = self.is_full or self.idx == 0  
    
    
    def propagate(self, index, priority_val):
        parent_idx = self.getParentIdx(index)
        left_node_idx = 2 * parent_idx + 1
        right_node_idx = 2 * parent_idx + 2
        self.tree[parent_idx] = self.tree[left_node_idx] + self.tree[right_node_idx]
        if parent_idx != 0: 
            self.propagate(parent_idx, priority_val)
            
        
    def update(self, index, priority_val):
        self.tree[index] = priority_val
        self.max_prior_val = max(priority_val, self.max_prior_val)
        self.propagate(index, priority_val)
       
    
    def getParentIdx(self, index):
        return (index - 1)//2
        
            
        
    def search(self, value):
        idx = self.retrieve(0, value)  # Search for index of item from root
        data_index = idx - self.size + 1
        return (self.tree[idx], data_index, idx)
        
    
    def retrieve(self, idx, value):
        left_node_idx = 2 * idx + 1
        right_node_idx = 2 * idx + 2
        
        if left_node_idx >= len(self.tree):
            return idx
        
        elif value <= self.tree[left_node_idx]:
            return self.retrieve(left_node_idx, value)
            
        else:
            return self.retrieve(right_node_idx, value - self.tree[left_node_idx])
        
    
    def getNumEntries(self):
        return self.num_entries
     

    def getTotal(self):
        return self.tree[0]
        
    def getSize(self):
        return self.size
    
    
    def getMaxPriorValue(self):
        return self.max_prior_val
    
    
    def getPriorties(self):
        return self.priorities
    
    
    def getTree(self):
        return self.tree
    
    
    def getDataByIdx(self, idx):
        return self.transit_data_buffer[idx % self.size]
    

# Model

In [None]:
class NoisyNet(nn.module):
    def __init__(self, in_channels, out_channels, init_std = 0.5):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weights_mu = nn.Parameter(torch.empty(out_channels, in_channels))
        self.bias_mu = nn.Parameter(torch.empty(out_channels))
        self.weights_sigma = nn.Parameter(torch.empty(out_channels, in_channels))
        self.bias_sigma = nn.Parameter(torch.empty(out_channels))
        self.init_std = init_std
        self.resetNoise()
        self.resetWeights()
       
   
    def resetNoise(self):
        epsilon_i = torch.randn(self.in_channels)
        epsilon_i = epsilon_i.sign().mul_(epsilon_i.abs().sqrt_())
        epsilon_j = torch.randn(self.out_channels)
        epsilon_j = epsilon_j.sign().mul_(epsilon_j.abs().sqrt_())
        self.weight_epsilon.copy_(epsilon_j.ger(epsilon_i))
        self.bias_epsilon.copy_(epsilon_j)
       
   
    def resetWeights(self):
        mu_range = 1 / math.sqrt(self.in_channels)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.init_std / math.sqrt(self.in_channels))
        self.bias_sigma.data.fill_(self.init_std / math.sqrt(self.out_channels))
       
   
    def forward(self, input):
        if not self.training:
            return F.linear(input, self.weight_mu, self.bias_mu)
        else:
            weights = self.weight_mu + self.weight_sigma * self.weight_epsilon
            biases = self.bias_mu + self.bias_sigma * self.bias_epsilon
            return F.linear(input, weights, biases)

In [20]:
# def eps_greedy(epsilon,state,net):
#     if(np.random.random()<epsilon):
#         action = np.random.randint(ACT_SHAPE)
#     else:
#         qvalues = net(state)
#         action = torch.argmax(qvalues).item()
#     return action    
def eps_greedy(epsilon,state,net,atoms):
    if(np.random.random()<epsilon):
        action = np.random.randint(ACT_SHAPE)
    else:
    #Finding the expected value of each action (sum(pi*zi))
        qvalues = net(state)
        expected_values = torch.matmul(qvalues,atoms)
        action = torch.argmax(expected_values).item()
    return action    

In [21]:
class ReplayBuffer(object):
    def __init__(self,maxsize):
        self.q = deque(maxlen = maxsize)
        self.maxsize = maxsize
    def add(self,x):
        self.q.append(x)
        if(len(self.q)==self.maxsize):
            self.q.popleft()
    def getSize(self):
        return len(self.q)
    def sample(self,size):
        batch = random.sample(list(self.q),size)
        state,action,reward,next_state,done = map(list, zip(*batch))
        return state,action,reward,next_state,done
    def getelem(self,index):
        return self.q[index%(self.maxsize+1)]
           

In [2]:
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state','done'))
# blank_trans = Transition(0, torch.zeros(84, 84, dtype=torch.uint8), None, 0, torch.zeros(84, 84, dtype=torch.uint8),False)

class PrioritizedReplayBuffer:
    def __init__(self,capacity = 1000000, batch_size = 32, gamma = 0.99, multi_step = 3, 
                beta = 0.4, alpha = 0.5):
        self.capacity = capacity
        self.discount = gamma
        self.n = multi_step
        self.beta = beta   
        self.alpha = alpha
        self.transit_buffer = SegmentSumTree(capacity)  
        self.prev_step_length = 4 #history_length
        self.timestep = 0  #t
        

    def add(self,state,action,reward,next_state,done):
#         state = state[-1].mul(255)  
#         next_state = next_state[-1].mul(255)
#         #print('coming here')
#         print('Before : ',self.transit_buffer.max_prior_val)
        self.transit_buffer.append(self.transit_buffer.max_prior_val, Transition( state, action, reward, next_state, done))  
#         print('After : ',self.transit_buffer.max_prior_val)
        
#         print('Tree : ',self.transit_buffer.tree)
#         print('Transitions : ',self.transit_buffer.transit_data_buffer)
#         print('-'*100)
#         self.timestep = 0 if done else self.timestep + 1
        
        
    def sample(self,k): 
        batch, idxs, priorities = [], [], []
        root_total_priority = self.transit_buffer.getTotal()
        segment = root_total_priority / root_total_priority
#         print('Root total priority',root_total_priority)
        
        priority_exponent_weight = (1 - self.beta) / (TMAX - TMIN)
        self.beta = np.min([1.0, self.beta + priority_exponent_weight])

        for i in range(k):
            a = segment * i
            b = segment * (i + 1)
            samp = random.uniform(a, b)
            priority, data_idx, idx = self.transit_buffer.search(samp)
            data = self.transit_buffer.getDataByIdx(data_idx)
            
            priorities.append(priority)
            batch.append(data)
            idxs.append(idx)
        sampling_probabilities = priorities / sum(priorities)
#         print("Probabilities : ",sampling_probabilities.sum())
        
        #compute importance sampling 
        weights = np.power(self.transit_buffer.getNumEntries() * sampling_probabilities, -self.beta)
        weights /= weights.max()
    
        states, actions, rewards, next_states, dones = zip(*batch)
        
#         print('-'*30)
#         print('length of states',len(states))
#         print('length of next states',len(next_states))
#         print('shape of states',states[0].shape)
#         print('shape of next states',next_states[0].shape)
        
        
        states, actions, rewards, next_states, dones = list(states), list(actions), list(rewards), \
                                                    list(next_states), list(dones) 
        states = torch.cat([x for x in states]).cuda()
        next_states = torch.cat([x for x in next_states]).cuda()
        actions = torch.Tensor(actions).long().cuda()
        dones = np.array(dones).astype(int)
        dones = torch.Tensor(dones).cuda()
        rewards = torch.Tensor(rewards).cuda()
        priorities = torch.Tensor(priorities).cuda()
        
#         print(states.shape)
#         print(next_states.shape)
        
        return priorities, idxs, states, actions, rewards, next_states, dones       
        
    def updatePriorities(self, indexes, priorities):
        priorities = np.power(priorities, self.alpha)
        for idx, priority_val in zip(indexes, priorities):
            self.transit_buffer.update(idx, priority_val)
        
        
    def getSize(self):
        return self.transit_buffer.getNumEntries()
        

# Loss function

In [38]:
def compute_loss(size):
    weights,idx,current_state,action,reward,next_state,done = buffer.sample(size)
    
#     print(current_state.shape)
#     print(action.shape)
#     print(reward.shape)
#     print(next_state.shape)
#     print(done.shape)
    
    dist_current = net(current_state)
    
    # z(xt+1)
    dist_next = net(next_state)
    
    # z'(xt+1)
    target_net.eval()
    target_net.reset_noise()
    dist_target = target_net(next_state)
    
    #Selecting optimal action a*
    
    zvalues_next = torch.matmul(dist_next,atoms)
    optimal_action = zvalues_next.max(1)[1]
    
    #Distribution of target with optimal action z(xt+1,a*)
    dist_target_optimal = dist_target.gather(1,optimal_action.view(-1,1).unsqueeze(2).repeat(1,1,N_ATOMS))
    dist_target_optimal = dist_target_optimal.squeeze(1)
    
    #Finding target distribution values (Tzj = r + gamma*Z(x,a*)) (not aligned)
    done = done.squeeze(0)
    done = done.unsqueeze(1)
    Tz = reward.unsqueeze(1).repeat(1,N_ATOMS) + (1-done).repeat(1,N_ATOMS)*GAMMA*atoms.unsqueeze(0).repeat(size,1)
    
    #Clipping the values
    Tz = torch.clamp(Tz,min=VMIN,max=VMAX)
    
    #Aligning the values
    
    deltaz = (VMAX-VMIN)/(N_ATOMS-1)
    indices = (Tz - VMIN)/deltaz
    lower = indices.floor().long()
    upper = indices.ceil().long()
    lower[(upper > 0) * (lower == upper)] -= 1
    upper[(lower < (N_ATOMS - 1)) * (lower == upper)] += 1
    
    #Finding target probabilities
    target_distribution = torch.zeros(size,N_ATOMS).cuda()
    
    offset = torch.linspace(0, ((size - 1) * N_ATOMS), size).unsqueeze(1).expand(size, N_ATOMS).to(action).cuda()
    
    target_distribution = target_distribution.view(-1).index_add_(0, (lower + offset).view(-1), (dist_target_optimal * (upper.float() - indices)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
    target_distribution = target_distribution.view(-1).index_add_(0, (upper + offset).view(-1), (dist_target_optimal* (indices - lower.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l) 
    
    #p(xt,at)
    current_distribution = dist_current.gather(1,action.view(-1,1).unsqueeze(2).repeat(1,1,N_ATOMS))
    current_distribution = current_distribution.squeeze(1)
    
    
    target_distribution = torch.reshape(target_distribution,(size,N_ATOMS))
#     L = (-target_distribution*torch.log(current_distribution)).sum(dim=1).mean()
    L = (-target_distribution*torch.log(current_distribution)).sum(dim=1)*weights
    new_priorities = L + edge_epsilon
    
    
    optimizer.zero_grad()
    
    L = L.mean()
    L.backward()
#     print('Loss : ',L.item())
    optimizer.step()
    new_priorities = new_priorities.cpu().detach().numpy()
    buffer.updatePriorities(idx,new_priorities)
    
    return L
    

In [29]:
def update_target():
    target_net.load_state_dict(net.state_dict())

In [30]:
def epsilon_decay(ep):
    e = .01 + .99*np.exp(-ep/30000)
    return e

In [31]:
def addreward(id,item,filename):
    f=open(filename,'a+')
    f.write(str(id)+' '+str(item)+' '+'\n')
    f.close()

In [32]:
def addloss(id,loss,filename):
    f=open(filename,'a+')
    f.write(str(id)+' '+str(loss.item())+' '+'\n')
    f.close()

# Training

In [33]:
# env = gym.make('PongNoFrameskip-v4')
# env    = make_atari('Pong-v0')
env    = make_atari('Pong-v0')
env    = wrap_deepmind(env)
env    = wrap_pytorch(env)

print(env.observation_space)
print(env.action_space)


Box(1, 84, 84)
Discrete(6)


In [1]:
ITERATIONS = 1000000
epsilon = .99
OBS_SHAPE = env.observation_space.shape
ACT_SHAPE = env.action_space.n
REPLAY_SAMPLE = 5000
BATCH_SIZE = 32
GAMMA = .99
T_upd = 1000
TMAX = 50e6
TMIN = 20e3
edge_epsilon = 1e-5
VMAX = 10
VMIN = -10
N_ATOMS = 51
N = 4
BLANK_FRAME = torch.zeros(1,4,84,84)
atoms = torch.linspace(VMIN,VMAX,N_ATOMS).cuda()
torch.manual_seed(1) 

NameError: name 'env' is not defined

In [41]:
net = QNet(env.observation_space.shape,env.action_space.n,N_ATOMS)
net = net.cuda()
target_net = QNet(env.observation_space.shape,env.action_space.n,N_ATOMS)
target_net = target_net.cuda()
update_target()
optimizer = optim.Adam(net.parameters(), lr=0.00001)

In [42]:
# net.load_state_dict(torch.load('dqn-model.pth'))
# target_net.load_state_dict(torch.load('dqn-model-target.pth'))
# optimizer = optim.Adam(net.parameters(), lr=0.00001)
# update_target()

In [None]:
nsteps = ReplayBuffer(N)
buffer = PrioritizedReplayBuffer(10000)
episode_reward = 0
state = env.reset()
state = torch.Tensor(state).cuda()
state = state.unsqueeze(0)
count = 0
# lossfile = './dueling-logs/losses.txt'
# rewardsfile = './dueling-logs/rewards.txt'
losses = []
rewards = []
t = 0
T = np.inf
count = 0
for i in tqdm(range(ITERATIONS)):
    #If current state has not reached terminal
    print('t:',t,end=' ')
    print('T:',T)
    if(t<T):
        
        epsilon = epsilon_decay(i)
        action = eps_greedy(epsilon,state,net,atoms)
        next_state, reward, done, info = env.step(action)
        next_state = torch.Tensor(next_state).unsqueeze(0).cuda()
        nsteps.add((state,action,reward,next_state,done))
        state = next_state
        episode_reward+=reward
        if(done):
            T = t+1
            count+=1
            print(count,episode_reward)
            episode_reward=0
    tau = t - N + 1
    # tau --> index of state for which transition is stored (nth previous state)
    if(tau>=0):
        steps = tau
        s_tau,a_tau,_,_,_ = nsteps.getelem(tau)
        G=0
        while(True):
            s,action,reward,ns,done = nsteps.getelem(steps)
            G+= np.power(GAMMA,steps-tau)*reward
            steps+=1
            if(steps==min(tau+N,T)): 
                if(steps==tau+N): #If nth state is reached
                    buffer.add(s_tau,a_tau,G,ns,done)
                elif(steps==T): #If terminal state is reached
                    buffer.add(s_tau,a_tau,G,BLANK_FRAME,done)
                break
            
    t+=1
    if(tau==T-1): #If nth previous state has reached termination move environment to next episode
        t=0
        T=np.inf
        state = env.reset()
        state = torch.Tensor(state).cuda()
        state = state.unsqueeze(0)
        nsteps.reset()
            
            
    if(buffer.getSize()<R)
            
    if(i%T_upd==0):
#         torch.save(net.state_dict(),'./dueling-logs/dqn-model.pth')
#         torch.save(target_net.state_dict(),'./dueling-logs/dqn-model-target.pth')
    update_target()


HBox(children=(IntProgress(value=0, max=1000000), HTML(value='')))

-14.0
-20.0
-18.0
-19.0
Loss :  24.54898452758789
Loss :  148.11819458007812
Loss :  83.00587463378906
Loss :  122.226806640625
Loss :  74.64559936523438
Loss :  72.04064178466797
Loss :  71.79872131347656
Loss :  68.47967529296875
Loss :  44.44623947143555
Loss :  63.78770065307617
Loss :  78.23565673828125
-19.0
Loss :  71.246337890625
Loss :  65.11727905273438
Loss :  93.02323150634766
Loss :  82.27151489257812
Loss :  46.90441131591797
Loss :  192.79493713378906
Loss :  169.25582885742188
Loss :  153.83792114257812
Loss :  90.00130462646484
Loss :  124.48818969726562
Loss :  100.0698013305664
Loss :  72.75244140625
Loss :  146.4843292236328
-20.0
Loss :  65.2945556640625
Loss :  39.594635009765625
Loss :  34.30528259277344
Loss :  26.623069763183594
Loss :  23.840866088867188
Loss :  276.481689453125
Loss :  35.17384719848633
Loss :  26.509403228759766
Loss :  59.310943603515625
Loss :  111.70928955078125
-21.0
Loss :  35.72512435913086
Loss :  46.280967712402344
Loss :  61.4233360

Loss :  7.138603687286377
Loss :  5.735165596008301
-20.0
Loss :  5.571308612823486
Loss :  10.091358184814453
Loss :  41.9439582824707
Loss :  140.57217407226562
Loss :  480.5096435546875
Loss :  661.5758056640625
Loss :  328.46160888671875
Loss :  156.6196746826172
Loss :  130.67547607421875
Loss :  6.224289417266846
Loss :  3.284229040145874
Loss :  15.952858924865723
-20.0
Loss :  38.093299865722656
Loss :  187.7059783935547
Loss :  91.81228637695312
Loss :  24.616989135742188
Loss :  46.03972625732422
Loss :  36.30195236206055
Loss :  29.88376235961914
Loss :  17.537235260009766
Loss :  19.23085594177246
Loss :  48.166751861572266
-21.0
Loss :  30.411100387573242
Loss :  13.362829208374023
Loss :  12.497520446777344
Loss :  44.75130081176758
Loss :  21.683937072753906
Loss :  21.5523681640625
Loss :  17.912731170654297
Loss :  75.01091003417969
Loss :  61.891815185546875
Loss :  66.40391540527344
-21.0
Loss :  247.36148071289062
Loss :  573.9458618164062
Loss :  34.214820861816406