In [34]:
from __future__ import division
import gym
import torch
import random
import numpy as np
import torch
from PIL import Image
import torch
import torch.nn as nn
from collections import namedtuple
from collections import deque
import torch.nn.functional as F
from tqdm import tqdm_notebook as tqdm
from matplotlib.pyplot import imshow
from PIL import Image
from wrappers import make_atari, wrap_deepmind, wrap_pytorch
import queue
from torch import optim
import matplotlib.pyplot as plt
import math
import pandas as pd 
import os 
import pickle 

# Model

In [2]:
class NoisyNet(nn.Module):
    def __init__(self, in_channels, out_channels, init_std = 0.5):
        super(NoisyNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weights_mu = nn.Parameter(torch.empty(out_channels, in_channels)) 
        self.bias_mu = nn.Parameter(torch.empty(out_channels))
        self.weights_sigma = nn.Parameter(torch.empty(out_channels, in_channels))
        self.bias_sigma = nn.Parameter(torch.empty(out_channels))
        self.register_buffer('weight_epsilon',torch.empty(out_channels, in_channels))
        self.register_buffer('bias_epsilon',torch.empty(out_channels))
        self.init_std = init_std
        self.resetNoise()
        self.resetWeights()
       
   
    def resetNoise(self):
        epsilon_i = torch.randn(self.in_channels)
        epsilon_i = epsilon_i.sign().mul_(epsilon_i.abs().sqrt_())
        epsilon_j = torch.randn(self.out_channels)
        epsilon_j = epsilon_j.sign().mul_(epsilon_j.abs().sqrt_())
        self.weight_epsilon.copy_(epsilon_j.ger(epsilon_i))
        self.bias_epsilon.copy_(epsilon_j)
       
   
    def resetWeights(self):
        mu_range = 1 / math.sqrt(self.in_channels)
        self.weights_mu.data.uniform_(-mu_range, mu_range)
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.weights_sigma.data.fill_(self.init_std / math.sqrt(self.in_channels))
        self.bias_sigma.data.fill_(self.init_std / math.sqrt(self.out_channels))
       
   
    def forward(self, input):
        if not self.training:
            return F.linear(input, self.weights_mu, self.bias_mu)
        else:
            weights = self.weights_mu + self.weights_sigma * self.weight_epsilon
            biases = self.bias_mu + self.bias_sigma * self.bias_epsilon
            return F.linear(input, weights, biases)

In [3]:
class QNet(torch.nn.Module):
    def __init__(self,obs_shape,act_shape,atoms):
        super(QNet, self).__init__()
        self.atoms = atoms
        self.act_shape = act_shape

        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.relu = nn.ReLU()
#         self.fc1 = nn.Linear(7*7*64,512)
#         self.fc2 = nn.Linear(512,1)
#         self.fc3 = nn.Linear(7*7*64,512)
#         self.fc4 = nn.Linear(512,act_shape)
        self.fc1 = NoisyNet(7*7*64,512)
        self.fc2 = NoisyNet(512,atoms)
        self.fc3 = NoisyNet(7*7*64,512)
        self.fc4 = NoisyNet(512,act_shape*atoms)
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, x):
        x=x/255
        
        #Conv
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = x.view(x.shape[0],-1)
        
        #Fc
        x1 = x
        x = self.fc1(x)
        x = self.relu(x)
        v = self.fc2(x)
        
        x1 = self.fc3(x1)
        x1 = self.relu(x1)
        adv = self.fc4(x1)
        
        #Reshaping value and advantage functions to add probabilities of each atom for each action
        value = v.view(v.shape[0],1,self.atoms)
        adv = adv.view(adv.shape[0],self.act_shape,self.atoms)
        
        q_s_a = value + adv - adv.mean(1,keepdim=True)
        
        #probability of each atom for all actions
        q_s_a = self.softmax(q_s_a)
        
        return q_s_a
    
    
    def reset_noise(self):
        for name, module in self.named_children():
            if 'fc' in name:
                module.resetNoise()

In [4]:
# def eps_greedy(epsilon,state,net):
#     if(np.random.random()<epsilon):
#         action = np.random.randint(ACT_SHAPE)
#     else:
#         qvalues = net(state)
#         action = torch.argmax(qvalues).item()
#     return action    
def eps_greedy(epsilon,state,net,atoms):
    if(np.random.random()<epsilon):
        action = np.random.randint(ACT_SHAPE)
    else:
        #Finding the expected value of each action (sum(pi*zi))
        qvalues = net(state)
        expected_values = torch.matmul(qvalues,atoms)
        action = torch.argmax(expected_values).item()
    return action    

In [5]:
class ReplayBuffer(object):
    def __init__(self,maxsize):
        self.q = deque(maxlen = maxsize)
        self.maxsize = maxsize
    def add(self,x):
        if(len(self.q)==self.maxsize):
            self.q.popleft()
        self.q.append(x)
        
    def getSize(self):
        return len(self.q)
    def sample(self,size):
        batch = random.sample(list(self.q),size)
        state,action,reward,next_state,done = map(list, zip(*batch))
        return state,action,reward,next_state,done
    def getelem(self,idx):
#         print('index accessed in temporary buffer',idx,idx%self.maxsize)
#         print(len(self.q))
        return self.q[idx%self.maxsize]
    def reset(self):
        self.q.clear()
           

In [6]:
class SegmentSumTree:
    def __init__(self, capacity):
        self.size = capacity 
        #n leaves + n-1 internal = 2n-1, capacity of a tree
        self.tree = np.zeros((2*capacity-1), dtype = np.float32)
        #leaf nodes having actual values 
        self.transit_data_buffer =  np.array([None] * self.size) 
        self.is_full = False
        self.max_prior_val = 1
        self.num_entries = 0 
        self.idx = 0 
    
    
    def append(self, priority_val, transit_data):
        #self.max_prior_val = max(priority_val, self.max_prior_val)
        self.transit_data_buffer[self.idx] = transit_data  
        self.update(self.idx + self.size - 1, priority_val) 
        self.idx = (self.idx + 1) % self.size 
        self.num_entries += 1
        if self.num_entries > self.size:
            self.num_entries = self.size
        self.is_full = self.is_full or self.idx == 0  
    
    
    def propagate(self, index, priority_val):
        parent_idx = self.getParentIdx(index)
        left_node_idx = 2 * parent_idx + 1
        right_node_idx = 2 * parent_idx + 2
        self.tree[parent_idx] = self.tree[left_node_idx] + self.tree[right_node_idx]
        if parent_idx != 0: 
            self.propagate(parent_idx, priority_val)
            
        
    def update(self, index, priority_val):
#         print('priority',priority_val)
        self.tree[index] = priority_val
        self.max_prior_val = max(priority_val, self.max_prior_val)
        self.propagate(index, priority_val)
       
    
    def getParentIdx(self, index):
        return (index - 1)//2
        
            
        
    def search(self, value):
        idx = self.retrieve(0, value)  # Search for index of item from root
        data_index = idx - self.size + 1
        return (self.tree[idx], data_index, idx)
        
    
    def retrieve(self, idx, value):
        left_node_idx = 2 * idx + 1
        right_node_idx = 2 * idx + 2
        
        if left_node_idx >= len(self.tree):
            return idx
        
        elif value <= self.tree[left_node_idx]:
            return self.retrieve(left_node_idx, value)
            
        else:
            return self.retrieve(right_node_idx, value - self.tree[left_node_idx])
        
    
    def getNumEntries(self):
        return self.num_entries
     

    def getTotal(self):
        return self.tree[0]
        
    def getSize(self):
        return self.size
    
    
    def getMaxPriorValue(self):
        return self.max_prior_val
    
    
    def getPriorties(self):
        return self.priorities
    
    
    def getTree(self):
        return self.tree
    
    
    def getDataByIdx(self, idx):
        return self.transit_data_buffer[idx % self.size]
    
    
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state','done'))

class PrioritizedReplayBuffer:
    def __init__(self,capacity = 1000000, batch_size = 32, gamma = 0.99, multi_step = 3, 
                beta = 0.4, alpha = 0.5):
        self.capacity = capacity
        self.discount = gamma
        self.n = multi_step
        self.beta = beta   
        self.alpha = alpha
        self.transit_buffer = SegmentSumTree(capacity) 
               

    def add(self,state,action,reward,next_state,done):
        
        self.transit_buffer.append(self.transit_buffer.max_prior_val, Transition(state, action, reward, next_state, done))  
        
        
    def sample(self,k): 
        batch, idxs, priorities = [], [], []
        root_total_priority = self.transit_buffer.getTotal()
        segment = root_total_priority / k
        
        priority_exponent_weight = (1 - self.beta) / (ITERATIONS - REPLAY_SAMPLE) 
        #priority_exponent_weight = (1 - self.beta) / TMAX - TMIN  
        
        self.beta = np.min([1.0, self.beta + priority_exponent_weight])
#         print('total priority',root_total_priority)
        
        for i in range(k):
            a = segment * i
            b = segment * (i + 1)
            samp = random.uniform(a, b)
#             print('sample : ',samp)
            priority, data_idx, idx = self.transit_buffer.search(samp)
#             print('data idx',data_idx)
            data = self.transit_buffer.getDataByIdx(data_idx)
            priorities.append(priority)
            batch.append(data)
            idxs.append(idx)
        #print('*'*100)
        sum_priority = np.power(self.transit_buffer.tree[:self.transit_buffer.getNumEntries()],self.alpha)
        sum_priority = sum_priority.sum()
        sampling_probabilities = np.power(priorities, self.alpha) / sum_priority #root_total_priority
        #
        
        #print(sampling_probabilities)
        
        #compute importance sampling 
        weights = np.power(self.transit_buffer.getNumEntries() * sampling_probabilities, -self.beta)
        weights /= weights.max()
        
#         try:
        states, actions, rewards, next_states, dones = zip(*batch)
#         except:
#             print('error in ',batch)
        
        states,actions, rewards, next_states, dones = list(states), list(actions), list(rewards), \
                                                    list(next_states), list(dones) 
        states = torch.cat([x for x in states]).cuda()
        next_states = torch.cat([x for x in next_states]).cuda()
        actions = torch.Tensor(actions).long().cuda()
        dones = np.array(dones).astype(int)
        dones = torch.Tensor(dones).cuda()
        rewards = torch.Tensor(rewards).cuda()
        #priorities = torch.Tensor(priorities).cuda()
        weights = torch.Tensor(weights).cuda()
        
        return weights, idxs, states, actions, rewards, next_states, dones       
        
        
    def updatePriorities(self, indexes, priorities):
        #priorities = np.power(priorities, self.alpha)
        for idx, priority_val in zip(indexes, priorities):
            self.transit_buffer.update(idx, priority_val)
        
        
    def getSize(self):
        return self.transit_buffer.getNumEntries()
        

# Loss function

In [7]:
def compute_loss(size,atoms):
    weights,idx,current_state,action,reward,next_state,done = buffer.sample(size)
    
    #qvalues -> (batch,actions,atoms)
    # z(xt)
    dist_current = net(current_state)
    #p(xt,at)
    current_distribution = dist_current.gather(1,action.view(-1,1).unsqueeze(2).repeat(1,1,N_ATOMS))
    current_distribution = current_distribution.squeeze(1)
    
    with torch.no_grad():
        # z(xt+1)
        dist_next = net(next_state)
        
        #Selecting optimal action a*

        zvalues_next = torch.matmul(dist_next,atoms)
        optimal_action = zvalues_next.max(1)[1]

        # z'(xt+1)
        target_net.reset_noise()
        dist_target = target_net(next_state)
        
        
        #Distribution of target with optimal action z(xt+1,a*)
        dist_target_optimal = dist_target.gather(1,optimal_action.view(-1,1).unsqueeze(2).repeat(1,1,N_ATOMS))
        dist_target_optimal = dist_target_optimal.squeeze(1)

        #Finding target distribution values (Tzj = r + gamma*Z(x,a*)) (not aligned)
        done = done.squeeze(0)
        done = done.unsqueeze(1)
        Tz = reward.unsqueeze(1).repeat(1,N_ATOMS) + (1-done).repeat(1,N_ATOMS)*(GAMMA**N)*atoms.unsqueeze(0).repeat(size,1)

        #Clipping the values
        Tz = torch.clamp(Tz,min=VMIN,max=VMAX)

        #Aligning the values

        deltaz = (VMAX-VMIN)/(N_ATOMS-1)
        indices = (Tz - VMIN)/deltaz
        lower = indices.floor().long()
        upper = indices.ceil().long()
        lower[(upper > 0) * (lower == upper)] -= 1
        upper[(lower < (N_ATOMS - 1)) * (lower == upper)] += 1
    #     dm_l = (upper.float() + (lower == upper).float() - indices)*dist_target_optimal
    #     dm_u = (indices - lower.float())*dist_target_optimal

        #Finding target probabilities
        target_distribution = torch.zeros(size,N_ATOMS).cuda()

    #     for i in range(target_distribution.size(0)):
    #         target_distribution[i].index_add_(0,lower[i].long(),dm_l[i])
    #         target_distribution[i].index_add_(0,upper[i].long(),dm_u[i])

        offset = torch.linspace(0, ((size - 1) * N_ATOMS), size).unsqueeze(1).expand(size, N_ATOMS).to(action).cuda()

        target_distribution = target_distribution.view(-1).index_add_(0, (lower + offset).view(-1), (dist_target_optimal * (upper.float()-indices)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        target_distribution = target_distribution.view(-1).index_add_(0, (upper + offset).view(-1), (dist_target_optimal * (indices - lower.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l) 

        target_distribution = torch.reshape(target_distribution,(size,N_ATOMS))
        
    L = -(target_distribution*torch.log(current_distribution)).sum(dim=1)
    
    new_priorities =  torch.abs(L) + edge_epsilon
    
    L = L*weights
    L = L.mean()
    #print('Loss: ',L.item())
    optimizer.zero_grad()
    L.backward()
    optimizer.step()

    new_priorities = new_priorities.cpu().detach().numpy()
    buffer.updatePriorities(idx,new_priorities)
    
    return L

In [8]:
def update_target():
    target_net.load_state_dict(net.state_dict())

In [9]:
def epsilon_decay(ep):
    e = .01 + 0.99 * np.exp(-ep/30000)
    return e

In [10]:
def addreward(id,item,filename):
    f=open(filename,'a+')
    f.write(str(id)+' '+str(item)+' '+'\n')
    f.close()

In [11]:
def addloss(id,loss,filename):
    f=open(filename,'a+')
    f.write(str(id)+' '+str(loss.item())+' '+'\n')
    f.close()

# Training

In [12]:
# env = gym.make('PongNoFrameskip-v4')
env    = make_atari('PongNoFrameskip-v4')
# env    = make_atari('BankHeist-v0')
# env    = make_atari('Pong-v0')
env    = wrap_deepmind(env)
env    = wrap_pytorch(env)

print(env.observation_space)
print(env.action_space)


Box(4, 84, 84)
Discrete(6)


In [41]:
ITERATIONS = 10000000#5000000#1000000
epsilon = 0.0#0.99
OBS_SHAPE = env.observation_space.shape
ACT_SHAPE = env.action_space.n
REPLAY_SAMPLE = 20000
BATCH_SIZE = 32
GAMMA = 0.99
T_upd = 10000
TMAX = 50e6
TMIN = 20e3
edge_epsilon = 1e-5
N = 3
BLANK_FRAME = torch.zeros(1,4,84,84).cuda()
save_model_interval = 1e5
loss_log_interval = 1e3
episode_reward_interval = 10
N_ATOMS = 51
VMAX = 10
VMIN = -10
atoms = torch.linspace(VMIN,VMAX,N_ATOMS).cuda()

In [14]:
# seed = 123
# np.random.seed(seed)
# torch.manual_seed(np.random.randint(1, 10000)) 
# if torch.cuda.is_available():
#     torch.manual_seed(np.random.randint(1, 10000))
#net = QNet(env.observation_space.shape,env.action_space.n,N_ATOMS)
#net = net.cuda()
#target_net = QNet(env.observation_space.shape,env.action_space.n,N_ATOMS)
#target_net = target_net.cuda()
#update_target()
#optimizer = optim.Adam(net.parameters(), lr=0.0000625,eps=1.5e-4)

In [42]:
lossfile = './rainbow-logs/losses2.csv'
rewardsfile = './rainbow-logs/rewards2.csv'
# if os.path.exists(lossfile):
#     os.remove(lossfile)
# if os.path.exists(rewardsfile):
#     os.remove(rewardsfile)

In [None]:
# nsteps = ReplayBuffer(N)
#buffer = PrioritizedReplayBuffer(30000)#100000) 
# episode_reward = 0    
# state = env.reset()
# state = torch.Tensor(state).cuda()
# state = state.unsqueeze(0)
#count = 0
# losses = []
# rewards = []
# loss_dict = {}
# rewards_dict = {}
# t = 0
# T = np.inf
#loss_count = 0 
for i in tqdm(range(5789088,ITERATIONS+1)): 
    #If current state has not reached terminal
#     print('t:',t,end=' ')
#     print('T:',T,end=' ')
    if(t<T):
        #epsilon = epsilon_decay(i)
        action = eps_greedy(epsilon,state,net,atoms)
        next_state, reward, done, info = env.step(action)
        next_state = torch.Tensor(next_state).unsqueeze(0).cuda()
        nsteps.add((state,action,reward,next_state,done))
        state = next_state
        episode_reward+=reward
        if(done):
            T = t + 1
            count += 1
            rewards_dict[i] = {} 
            rewards_dict[i]['episode'] = count
            rewards_dict[i]['rewards'] = episode_reward
            if count%100 == 0:
                print(i,count,episode_reward) 
            if count%episode_reward_interval == 0: 
                rewards_df = pd.DataFrame.from_dict(data = rewards_dict, orient = 'index').reset_index()
                if not os.path.exists(rewardsfile):
                    rewards_df.to_csv(rewardsfile,index=None, header='column_names')
                else: # else it exists so append without writing the header
                    rewards_df.to_csv(rewardsfile, mode='a',index=None, header=False)
                rewards_dict = {}
            episode_reward = 0
    tau = t - N + 1
#     print('t-n',tau)
    # tau --> index of state for which transition is stored (nth previous state)
    if(tau>=0):
        steps = tau
        s_tau,a_tau,_,_,_ = nsteps.getelem(tau)
        G=0
        while(True):
            s,action,reward,ns,done = nsteps.getelem(steps)
            G+= np.power(GAMMA,steps-tau)*reward
            steps+=1
            if(steps==min(tau+N,T)): 
                if(steps==tau+N): #If nth state is reached
                    buffer.add(s_tau,a_tau,G,ns,done)
                elif(steps==T): #If terminal state is reached
                    buffer.add(s_tau,a_tau,G,BLANK_FRAME,done)
                break
            
    t+=1
    if(tau==T-1): #If nth previous state has reached termination, move environment to next episode
        t=0
        T=np.inf
        state = env.reset()
        state = torch.Tensor(state).cuda()
        state = state.unsqueeze(0)
        nsteps.reset()
            
            
    if(buffer.getSize()>REPLAY_SAMPLE and i%4==0):
        loss = compute_loss(BATCH_SIZE,atoms)
        net.reset_noise()
#         print('Loss',loss.item())
        if i % loss_log_interval == 0:
            loss_dict[i] = {}
            loss_dict[i]['loss'] = loss.item()
            loss_count += 1
            if loss_count == 10: 
                loss_df = pd.DataFrame.from_dict(data = loss_dict, orient = 'index').reset_index()
                if not os.path.exists(lossfile):
                    loss_df.to_csv(lossfile,index=None, header='column_names')
                else: # else it exists so append without writing the header
                    loss_df.to_csv(lossfile, mode='a',index=None, header=False)
                loss_dict = {}
                loss_count = 0
    
    if(i%T_upd==0):
        update_target()
        
    if(i%save_model_interval == 0):
        torch.save(net.state_dict(),'./rainbow-logs/rb-model' + str(i) + '.pth')
        torch.save(target_net.state_dict(),'./rainbow-logs/rb-model-target' + str(i) + '.pth')


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  from ipykernel import kernelapp as app


HBox(children=(IntProgress(value=0, max=4210913), HTML(value='')))

6952131 5300 -17.0
7621567 5400 -14.0
8364135 5500 15.0
8728933 5600 20.0
9152306 5700 4.0
9492049 5800 20.0
9828350 5900 19.0


# Testing

In [17]:
# metrics = {'steps': [], 'rewards': [], 'Qs': [], 'best_avg_reward': -float('inf')}

In [18]:
# evaluation_episodes = 10

In [19]:
# def test(args, T, net, val_mem, metrics, results_dir, evaluate=False):
#     env = Env(args)
#     env.eval()
#     metrics['steps'].append(T)
#     T_rewards, T_Qs = [], []

#     # Test performance over several episodes
#     done = True
#     for _ in range(evaluation_episodes):
#     while True:
#         if done:
#             state, reward_sum, done = env.reset(), 0, False

#         action = eps_greedy(epsilon,state,net)  # Choose an action ε-greedily
#         state, reward, done = env.step(action)  # Step
#         reward_sum += reward
# #         if args.render:
# #             env.render()

#         if done:
#             T_rewards.append(reward_sum)
#             break
#     env.close()

#     # Test Q-values over validation memory
#     for state in val_mem:  # Iterate over valid states
#         T_Qs.append(dqn.evaluate_q(state))

#     avg_reward, avg_Q = sum(T_rewards) / len(T_rewards), sum(T_Qs) / len(T_Qs)

In [32]:
# pickle.dump(buffer,open('buffer.pkl','wb'))
# buffer2 = pickle.load(open('buffer.pkl','rb'))
# pickle.dump(nsteps,open('nsteps.pkl','wb'))
# nsteps2 = pickle.load(open('nsteps.pkl','rb'))

In [36]:
#buffer.transit_buffer.transit_data_buffer[0]

In [44]:
i

5789088

In [45]:
loss_count

7

In [46]:
count

5209