# Protein Folding Prioritized DQN

### Imports

In [None]:
import argparse
import math
import random
from copy import deepcopy

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn

import matplotlib.pyplot as plt
%matplotlib inline

### Use CUDA

In [None]:
# Pick tensor types once: CUDA variants when a GPU is visible, CPU variants otherwise.
# The rest of the notebook converts tensors with `.type(dtype)` / `.type(dtypelong)`.
USE_CUDA = torch.cuda.is_available()
if not USE_CUDA:
    print("NOT Using GPU: GPU not requested or not available.")
    dtype, dtypelong = torch.FloatTensor, torch.LongTensor
else:
    print("Using GPU: GPU requested and available.")
    dtype, dtypelong = torch.cuda.FloatTensor, torch.cuda.LongTensor

### Agent

In [None]:
class Agent:
    """Epsilon-greedy agent that selects actions with ``q_network``.

    ``target_q_network`` is only held here so training code can reach it through the agent
    (``compute_td_loss`` and the target-update helpers read ``agent.target_q_network``).
    """

    def __init__(self, env, q_network, target_q_network):
        self.env = env
        self.q_network = q_network
        self.target_q_network = target_q_network
        self.num_actions = env.action_space.n

    def act(self, state, epsilon):
        """Return an action index as a 0-dim tensor (DQN argmax with epsilon-greedy exploration).

        With probability ``1 - epsilon`` the action maximising Q(state, .) is returned;
        otherwise one is drawn uniformly from ``range(self.num_actions)``. Callers convert the
        result with ``int(action.cpu())``.
        """
        if random.random() > epsilon:
            # Pure inference: no autograd graph is needed for action selection.
            with torch.no_grad():
                state = torch.tensor(np.float32(state)).type(dtype).unsqueeze(0)
                q_value = self.q_network(state)
            return q_value.max(1)[1][0]
        return torch.tensor(random.randrange(self.num_actions))

### Prioritized Replay Buffer

Prioritized Experience Replay (Schaul et al., 2015): https://arxiv.org/abs/1511.05952

In [None]:
class NaivePrioritizedBuffer(object):
    """Proportional prioritized experience replay buffer (Schaul et al., 2015).

    "Naive" because priorities live in a flat array that every ``sample`` call renormalises
    in O(capacity), instead of a sum-tree. Transition ``i`` is drawn with probability
    ``P(i) = p_i**prob_alpha / sum_k p_k**prob_alpha``.

    Args:
        capacity: maximum number of stored transitions; once full, the oldest slot is
            overwritten first.
        prob_alpha: exponent applied to priorities (0 gives uniform sampling).
    """

    def __init__(self, capacity, prob_alpha=0.6):
        self.prob_alpha = prob_alpha
        self.capacity = capacity
        self.buffer = []
        self.pos = 0  # index of the next slot to (over)write; wraps around at capacity
        self.priorities = np.zeros((capacity,), dtype=np.float32)

    def push(self, state, action, reward, next_state, done):
        """Store one transition with the current maximum priority (1.0 for an empty buffer).

        ``state`` and ``next_state`` must be numpy arrays with the same number of dimensions;
        a leading axis of size 1 is added to each so batches can be concatenated later.

        Raises:
            ValueError: if ``state.ndim != next_state.ndim``.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if state.ndim != next_state.ndim:
            raise ValueError(
                "state.ndim ({}) != next_state.ndim ({})".format(state.ndim, next_state.ndim)
            )
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)

        # New transitions get the largest priority seen so far, so they are likely to be
        # sampled soon after insertion.
        max_prio = self.priorities.max() if self.buffer else 1.0

        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.pos] = transition

        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw ``batch_size`` transitions (with replacement) in proportion to their priority.

        Returns:
            ``(states, actions, rewards, next_states, dones, indices, weights)``: states and
            next_states are concatenated along axis 0; actions, rewards and dones are tuples;
            ``indices`` identify the samples for ``update_priorities``; ``weights`` are float32
            importance-sampling weights ``(N * P(i))**(-beta)`` scaled so the largest is 1.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty replay buffer")

        total = len(self.buffer)
        # Only the first `total` slots hold transitions (all of them once the buffer is full).
        prios = self.priorities[:total]

        probs = prios ** self.prob_alpha
        probs /= probs.sum()

        indices = np.random.choice(total, batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        weights = (total * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)

        states, actions, rewards, next_states, dones = zip(*samples)
        return (
            np.concatenate(states),
            actions,
            rewards,
            np.concatenate(next_states),
            dones,
            indices,
            weights,
        )

    def update_priorities(self, batch_indices, batch_priorities):
        """Overwrite the stored priority of each index in ``batch_indices``."""
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio

    def __len__(self):
        """Number of transitions currently stored (at most ``capacity``)."""
        return len(self.buffer)

In [None]:
# Global prioritized replay buffer; train() pushes into and samples from this object.
replay_size = 100_000
replay_buffer = NaivePrioritizedBuffer(capacity=replay_size)

### Beta Calculation

In [None]:
beta_start = 0.4
beta_frames = 1000


def beta_by_frame(frame_idx):
    """Importance-sampling exponent beta at step ``frame_idx``.

    Grows linearly from ``beta_start`` at frame 0 to 1.0 at ``beta_frames``, then stays at 1.0.
    """
    annealed = beta_start + frame_idx * (1.0 - beta_start) / beta_frames
    return min(1.0, annealed)

In [None]:
# Visualise the beta annealing schedule over the first 10,000 frames.
plt.plot(list(map(beta_by_frame, range(10000))))

### Epsilon Greedy Exploration

In [None]:
epsilon_start = 1.0
epsilon_final = 0.01
epsilon_decay = 500


def epsilon_by_frame(frame_idx):
    """Exploration rate at step ``frame_idx``.

    Decays exponentially from ``epsilon_start`` towards ``epsilon_final`` with time constant
    ``epsilon_decay`` frames.
    """
    decay = math.exp(-1. * frame_idx / epsilon_decay)
    return epsilon_final + (epsilon_start - epsilon_final) * decay

In [None]:
# Visualise the epsilon decay schedule over the first 10,000 frames.
plt.plot(list(map(epsilon_by_frame, range(10000))))

### Computing the Temporal-Difference Loss

In [None]:
def compute_td_loss(agent, batch_size, replay_buffer, optimizer, gamma, beta):
    """Sample a prioritized batch, take one optimizer step, and refresh the batch priorities.

    The loss is the importance-weighted mean squared TD error against the one-step target
    ``reward + gamma * max_a' Q_target(next_state, a') * (1 - done)``. Sampled priorities are
    reset to ``|TD error| + 1e-5`` (proportional PER), without the IS weights folded in.

    Args:
        agent: object exposing ``q_network`` and ``target_q_network``.
        batch_size: number of transitions to sample.
        replay_buffer: buffer providing ``sample(batch_size, beta)`` and ``update_priorities``.
        optimizer: optimizer over ``agent.q_network``'s parameters.
        gamma: discount factor.
        beta: importance-sampling exponent passed to ``replay_buffer.sample``.

    Returns:
        The scalar loss tensor (``backward`` has already been called on it).
    """
    state, action, reward, next_state, done, indices, weights = replay_buffer.sample(batch_size, beta)

    state = torch.tensor(np.float32(state)).type(dtype)
    next_state = torch.tensor(np.float32(next_state)).type(dtype)
    # Stored actions may be ints or 0-dim tensors (possibly on the GPU); normalise to ints.
    action = torch.tensor([int(a) for a in action]).type(dtypelong)
    reward = torch.tensor(reward).type(dtype)
    done = torch.tensor(done).type(dtype)
    weights = torch.tensor(weights).type(dtype)

    q_values = agent.q_network(state)
    # The bootstrap target is treated as a constant: no gradients through the target net.
    with torch.no_grad():
        next_q_values = agent.target_q_network(next_state)

    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    td_error = q_value - expected_q_value
    loss = (td_error.pow(2) * weights).mean()
    # Proportional PER priority p_i = |delta_i| + eps; the IS weights only scale the loss.
    prios = td_error.detach().abs() + 1e-5

    optimizer.zero_grad()
    loss.backward()
    replay_buffer.update_priorities(indices, prios.cpu().numpy())
    optimizer.step()

    return loss

### Update the Target Network

In [None]:
def soft_update(q_network, target_q_network, tau):
    """Polyak-average parameters in place: target <- tau * online + (1 - tau) * target.

    Parameter pairs that are the very same object in both networks are left untouched.
    """
    pairs = zip(target_q_network.parameters(), q_network.parameters())
    for target_param, online_param in pairs:
        if target_param is not online_param:
            blended = tau * online_param.data + (1.0 - tau) * target_param.data
            target_param.data.copy_(blended)

def hard_update(q_network, target_q_network):
    """Copy each parameter of ``q_network`` into ``target_q_network`` in place.

    Only ``parameters()`` are copied (not registered buffers); pairs that are the same object
    in both networks are skipped.
    """
    for dst, src in zip(target_q_network.parameters(), q_network.parameters()):
        if dst is not src:
            dst.data.copy_(src.data)
        
def update_target(q_network, target_q_network):
    """Overwrite ``target_q_network``'s full state dict with ``q_network``'s."""
    online_state = q_network.state_dict()
    target_q_network.load_state_dict(online_state)

### Training

In [None]:
# Optimisation
learning_rate = 1e-3
gamma = 0.99  # discount factor
batch_size = 32

# Target network
target_network_update_f = 1000  # timesteps between target-network syncs in train()
target_update_rate = 0.1  # tau for soft_update (only used if that update path is enabled)

# Schedule / logging
num_timesteps = 10_000
start_train = 32  # train() starts updating once the buffer holds more than this many items
log_every = 200

def train(env):
    """Train the notebook-level ``q_network`` on ``env`` with prioritized DQN.

    Relies on notebook globals: ``q_network``, ``target_q_network``, ``replay_buffer``,
    the epsilon/beta schedules and the hyperparameters defined in this cell.

    Returns:
        ``(losses, all_rewards, agent)``: the TD loss of every update as a Python float, the
        total reward of every finished episode, and the ``Agent`` wrapping the networks.
    """
    agent = Agent(env, q_network, target_q_network)
    optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
    # Start the target network as an exact copy of the online network; otherwise, until the
    # first periodic sync, TD targets are bootstrapped from an independently initialised net.
    update_target(agent.q_network, agent.target_q_network)

    losses, all_rewards = [], []
    episode_reward = 0
    state = env.reset()

    for ts in range(1, num_timesteps + 1):
        epsilon = epsilon_by_frame(ts)
        # Keep the action as a plain int: the replay buffer then holds no (possibly GPU)
        # tensors, and env.step receives the same int it did before.
        action = int(agent.act(state, epsilon).cpu())

        next_state, reward, done, _ = env.step(action)
        replay_buffer.push(state, action, reward, next_state, done)

        state = next_state
        episode_reward += reward

        if done:
            state = env.reset()
            all_rewards.append(episode_reward)
            episode_reward = 0

        if len(replay_buffer) > start_train:
            beta = beta_by_frame(ts)
            loss = compute_td_loss(agent, batch_size, replay_buffer, optimizer, gamma, beta)
            # .item() yields a detached host-side float instead of keeping a device tensor.
            losses.append(loss.item())

            if ts % target_network_update_f == 0:
                # Hard sync; soft_update / hard_update are the alternative update rules.
                update_target(agent.q_network, agent.target_q_network)

        if ts % log_every == 0:
            out_str = "Timestep {}".format(ts)
            if all_rewards:
                out_str += ", Reward: {}".format(all_rewards[-1])
            if losses:
                out_str += ", TD Loss: {}".format(losses[-1])
            print(out_str)

    return losses, all_rewards, agent

### Plot Losses and Rewards

In [None]:
def plot(losses, rewards):
    """Show episode rewards (top) and TD losses (bottom) in one 20x20-inch figure.

    ``losses`` may hold Python floats or 0-dim tensors (e.g. ``loss.data`` on the GPU); every
    entry is converted with ``float()`` first, because matplotlib cannot plot CUDA tensors.
    """
    losses = [float(loss) for loss in losses]
    plt.figure(figsize=(20, 20))
    plt.subplot(211)
    plt.title("Rewards")
    plt.plot(rewards)
    plt.subplot(212)
    plt.title("Loss")
    plt.plot(losses)
    plt.show()

### Run the Trained Agent on the Environment

In [None]:
def run_agent(env, max_steps=None):
    """Roll out one greedy episode (epsilon = 0) with the notebook-level Q-network.

    Calls ``env.render()`` after every step and, when the episode ends, prints the last
    step's reward and ``info['actions']``.

    Args:
        env: environment to run in.
        max_steps: optional cap on the number of steps; ``None`` (default) runs until the
            environment reports ``done``. Nothing is printed if the cap is hit first.
    """
    agent = Agent(env, q_network, target_q_network)
    state = env.reset()
    steps = 0
    while max_steps is None or steps < max_steps:
        action = agent.act(state, 0)
        state, reward, done, info = env.step(int(action.cpu()))
        env.render()
        steps += 1
        if done:
            print("Reward: {} | Actions: {}".format(reward, info['actions']))
            break

## Prioritized DQN with Linear Model

In [None]:
from lattice2d_linear_env import Lattice2DLinearEnv

class DQN(nn.Module):
    """Fully connected Q-network: ``num_inputs[0]`` -> 128 -> 128 -> ``num_actions``.

    ReLU follows each hidden layer; the output layer is linear (one Q-value per action).
    """

    def __init__(self, num_inputs, num_actions):
        super().__init__()
        in_features = num_inputs[0]
        hidden = 128
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, num_actions),
        )

    def forward(self, x):
        """Map a batch of flat observations to per-action Q-values."""
        q_values = self.layers(x)
        return q_values

In [None]:
# Build the online and target Q-networks from the "H" environment's space sizes.
# NOTE(review): confirm these sizes match the environments later passed to train().
env = Lattice2DLinearEnv("H")
q_network, target_q_network = (
    DQN(env.observation_space.shape, env.action_space.n) for _ in range(2)
)

if USE_CUDA:
    q_network, target_q_network = q_network.cuda(), target_q_network.cuda()

### Train on single sequence

In [None]:
env = Lattice2DLinearEnv("HPPHPHPH")
# Bind the results instead of leaving `train(env)` as the cell's last expression, which
# would dump the whole (losses, rewards, agent) tuple into the notebook output.
losses, rewards, agent = train(env)

### Results of training

In [None]:
# Greedy rollout of the trained network on the sequence trained above.
run_agent(env=env)

### Train on multiple sequences

In [None]:
# Env params
collision_penalty = -2
trap_penalty = 0.5

max_seq_length = 5
seq_dict = {}

# NOTE(review): `seqs_list` is not defined anywhere in this notebook; it must be created
# (an iterable of HP sequence strings) before this cell runs.
# Record a result for every sequence with length <= max_seq_length. Longer sequences are
# skipped with `continue` (not `break`), so the filter is correct even if seqs_list is not
# sorted by length.
for seq in seqs_list:
    if len(seq) > max_seq_length:
        continue
    env = Lattice2DLinearEnv(seq, collision_penalty, trap_penalty)
    if len(seq) <= 4:
        # Brute force sequences of length at most 4 instead of training.
        reward, actions = env.all_combs()
        seq_dict[seq] = reward
    else:
        losses, rewards, agent = train(env)
        # Reward of the last finished episode; None if no episode finished during training.
        seq_dict[seq] = rewards[-1] if rewards else None

## Prioritized DQN with CNN Model

In [None]:
from lattice2d_cnn_env import Lattice2DCNNEnv

class CnnDQN(nn.Module):
    """Convolutional Q-network: 3 conv layers (8x8/4, 4x4/2, 3x3/1) + 2 fully connected layers.

    ``input_shape`` is the per-sample shape ``(channels, height, width)`` (unpacked after a
    batch dim of 1 in ``feature_size``). With these kernels/strides the spatial input must be
    at least 36x36 for the last conv layer to produce any output.
    """

    def __init__(self, input_shape, num_actions):
        super().__init__()
        self.input_shape = input_shape
        self.num_actions = num_actions
        in_channels = input_shape[0]
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
        )
        flat_size = self.feature_size()
        self.fc = nn.Sequential(
            nn.Linear(flat_size, 512), nn.ReLU(),
            nn.Linear(512, self.num_actions),
        )

    def forward(self, x):
        """Return per-action Q-values for a batch of image observations."""
        conv_out = self.features(x)
        flat = conv_out.view(conv_out.size(0), -1)
        return self.fc(flat)

    def feature_size(self):
        """Length of the flattened conv output for one sample, found by a dummy forward pass."""
        probe = torch.zeros(1, *self.input_shape)
        return self.features(probe).view(1, -1).size(1)

### Train on single sequence

In [None]:
# NOTE(review): this section is titled "Prioritized DQN with CNN Model", but the cell still
# builds a Lattice2DLinearEnv and trains the global q_network built earlier from DQN;
# CnnDQN and Lattice2DCNNEnv are never used. Confirm whether that is intended.
env = Lattice2DLinearEnv("HPPHPHPH")
# Bind the results instead of leaving `train(env)` as the cell's last expression, which
# would dump the whole (losses, rewards, agent) tuple into the notebook output.
losses, rewards, agent = train(env)

### Results of training

In [None]:
# Greedy rollout of the network trained in the cell above.
run_agent(env=env)

### Train on multiple sequences

In [None]:
# Env params
collision_penalty = -2
trap_penalty = 0.5

max_seq_length = 5
seq_dict = {}

# NOTE(review): despite the "CNN Model" section title, this loop still uses
# Lattice2DLinearEnv and the global q_network; CnnDQN / Lattice2DCNNEnv are never used.
# NOTE(review): `seqs_list` is not defined anywhere in this notebook; it must be created
# (an iterable of HP sequence strings) before this cell runs.
# Record a result for every sequence with length <= max_seq_length. Longer sequences are
# skipped with `continue` (not `break`), so the filter is correct even if seqs_list is not
# sorted by length.
for seq in seqs_list:
    if len(seq) > max_seq_length:
        continue
    env = Lattice2DLinearEnv(seq, collision_penalty, trap_penalty)
    if len(seq) <= 4:
        # Brute force sequences of length at most 4 instead of training.
        reward, actions = env.all_combs()
        seq_dict[seq] = reward
    else:
        losses, rewards, agent = train(env)
        # Reward of the last finished episode; None if no episode finished during training.
        seq_dict[seq] = rewards[-1] if rewards else None