
Commit

Notes added on next steps as well as bug fix (layer norm applied in wrong order)
jerrodparker20 committed Mar 3, 2020
1 parent 223263b commit 0be63b0
Showing 8 changed files with 238 additions and 77 deletions.
26 changes: 26 additions & 0 deletions Implementations/learned_weighted_td.py
@@ -0,0 +1,26 @@

from transformerDqn import *
import gym
import torch
from dqn import DQN
from replayBuffer import ReplayBuffer  # ReplayBuffer is moved out of dqn.py into replayBuffer.py in this commit
from torch.optim import Adam
from torch.nn.functional import mse_loss

#creating DQN
embedding_size = 24
dropout = 0.1
B = 32
input_size = 3
dim_feedforward = 16
nhead = 1
num_actions = 4
num_encoder_layers = 1
embedder = CartPoleEmbedder
embedding_params = {'dropout': dropout, 'B': B, 'input_size': input_size, 'embedding_size': embedding_size}
encoder_layer_params = {'d_model':embedding_size, 'nhead':nhead, 'dim_feedforward':dim_feedforward, 'dropout':dropout}
dqn = TransformerDqn(embedder=embedder, embedder_params=embedding_params,
                     encoder_layer_params=encoder_layer_params, output_size=num_actions,
                     num_encoder_layers=num_encoder_layers)
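# A rough sketch (not part of this commit) of how the pieces imported above might be
# wired into a single TD update step; gamma, the learning rate, and calling dqn(...)
# directly on a batch of states are assumptions for illustration only.
def td_update_step(dqn, buffer, optimizer, gamma=0.99, batch_size=B):
    batch = buffer.sample(sample_size=batch_size)
    # Q-values of the actions actually taken
    q_taken = dqn(batch['cur_states']).gather(1, batch['actions'].long().unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        # one-step TD target: r + gamma * max_a' Q(s', a') * (1 - done)
        next_q = dqn(batch['next_states']).max(dim=1)[0]
        target = batch['rewards'] + gamma * next_q * (1 - batch['dones'])
    loss = mse_loss(q_taken, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()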



17 changes: 17 additions & 0 deletions StableTransformersReplication/implNotes.txt
@@ -0,0 +1,17 @@


Here are notes on the training procedure used in TXL.

1. How does adaptive embedding work? The complicated version (div_val==2) doesn't actually seem to be used (it must have been experimental).
This seems to just be a normal learned embedding layer.


How is what goes into a batch chosen? (Do we use consecutive segments so that the cache doesn't get too large?)
How does masking work?

What are tgt_len, mem_len and ext_len?


LMOrderedIterator:
They've just implemented batches coming sequentially, which may be suboptimal (semi-correlated updates).
This allows them, at each step, to just store the memory from the previous step. (Can make a better iterator than this***; rough sketch below.)
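
Rough sketch (ours, not the actual TXL code) of the ordered-iterator idea above: yield
consecutive segments of length tgt_len in order, so the memory cached from the previous
segment lines up with the next one. Variable names and shapes here are assumptions.

    # data: LongTensor of shape (total_len, batch_size), already laid out column-wise
    def ordered_segments(data, tgt_len):
        for i in range(0, data.size(0) - 1, tgt_len):
            seg_len = min(tgt_len, data.size(0) - 1 - i)
            inp = data[i:i + seg_len]          # current segment fed to the model
            tgt = data[i + 1:i + 1 + seg_len]  # next-step targets (language modelling)
            yield inp, tgt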
54 changes: 25 additions & 29 deletions StableTransformersReplication/transformer_xl.py
@@ -1,12 +1,22 @@

'''
This model can be run as the original transformer XL or the stable transformer XL. Much of this code came from
https://github.com/kimiyoung/transformer-xl
https://github.com/kimiyoung/transformer-xl but now has added functionality to use gating as well as different
orderings of the submodules as done in https://arxiv.org/pdf/1910.06764.pdf
TO DO:
1. Figure out how we'll actually run this (how is the cache kept, and so on?)
a) need to rewrite the training script (will assume functionality for batching previous examples exists)
2. CHECK THIS: Have I applied the layer norms in the correct order?
3. Initialize b_g (one of the biases in the GRU) to 2 (see the gating sketch right after this docstring)
4. They use a 256-dim embedding and a memory size of 512
5. Add in the action set from Table 5 of the paper (15 actions) (it is even simpler for some environments in Table 6)
Remember: the receptive field in transformer XL is linear in the number of layers and the segment size
'''
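# Rough sketch (not part of this commit) of the GRU-style gating from the stable
# transformer paper referenced above, with the gate bias b_g initialized to 2
# (TODO item 3) so each layer starts close to an identity/skip connection.
# Class and variable names here are ours, not the actual implementation.
class GRUGatingSketch(nn.Module):
    def __init__(self, d_model, bg_init=2.0):
        super(GRUGatingSketch, self).__init__()
        self.Wr = nn.Linear(d_model, d_model, bias=False)
        self.Ur = nn.Linear(d_model, d_model, bias=False)
        self.Wz = nn.Linear(d_model, d_model, bias=False)
        self.Uz = nn.Linear(d_model, d_model, bias=False)
        self.Wg = nn.Linear(d_model, d_model, bias=False)
        self.Ug = nn.Linear(d_model, d_model, bias=False)
        self.bg = nn.Parameter(torch.full((d_model,), bg_init))  # b_g initialized to 2

    def forward(self, x, y):
        # x: the residual/skip stream, y: output of the gated submodule (attention or FF)
        r = torch.sigmoid(self.Wr(y) + self.Ur(x))
        z = torch.sigmoid(self.Wz(y) + self.Uz(x) - self.bg)
        h = torch.tanh(self.Wg(y) + self.Ug(r * x))
        return (1.0 - z) * x + z * h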

@@ -310,7 +320,7 @@ def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
tgt_len=None, ext_len=None, mem_len=None,
cutoffs=[], adapt_inp=False,
same_length=False, clamp_len=-1,
sample_softmax=-1):
use_gate=True, use_stable_version=True):
super(MemTransformerLM, self).__init__()
self.n_token = n_token

@@ -336,11 +346,11 @@ def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
self.layers.append(
RelPartialLearnableDecoderLayer(
n_head, d_model, d_head, d_inner, dropout,
use_stable_version=use_stable_version, use_gate=use_gate,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
dropatt=dropatt)
)


#To do: Look into sample softmax and adaptive softmax in the future; not relevant here though
# (they are useful when a fast softmax over many classes is needed)

@@ -353,29 +363,15 @@ def backward_compatible(self):
self.sample_softmax = -1

def _create_params(self):
if self.attn_type == 0: # default attention
self.pos_emb = PositionalEmbedding(self.d_model)
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
elif self.attn_type == 1: # learnable
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.r_w_bias = nn.Parameter(torch.Tensor(
self.n_layer, self.n_head, self.d_head))
self.r_bias = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head))
elif self.attn_type == 2: # absolute standard
self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.pos_emb = PositionalEmbedding(self.d_model)
self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))

def reset_length(self, tgt_len, ext_len, mem_len):
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len

#QUESTION: What is happening here?
def init_mems(self):
if self.mem_len > 0:
mems = []
@@ -414,14 +410,14 @@ def _update_mems(self, hids, mems, qlen, mlen):
return new_mems

def _forward(self, dec_inp, mems=None):
qlen, bsz = dec_inp.size() #NOTE: qlen seems to be number of characters in input ex
qlen, bsz = dec_inp.size() #qlen is the number of tokens in the input example

word_emb = self.state_emb(dec_inp)
obs_emb = self.state_emb(dec_inp)

mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
if self.same_length: #DONT THINK WE WANT SAME LENGTH (I think this just makes each token have same attention span)
all_ones = word_emb.new_ones(qlen, klen)
all_ones = obs_emb.new_ones(qlen, klen)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
@@ -431,16 +427,16 @@ def _forward(self, dec_inp, mems=None):
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
else:
dec_attn_mask = torch.triu(
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
obs_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

hids = []
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
pos_seq = torch.arange(klen-1, -1, -1.0, device=obs_emb.device,
dtype=obs_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)

core_out = self.drop(word_emb)
core_out = self.drop(obs_emb)
pos_emb = self.drop(pos_emb)

hids.append(core_out)
@@ -470,4 +466,4 @@ def forward(self, data, *mems):

pred_hid = hidden[-tgt_len:]

return F.softmax(pred_hid) #NEED TO CHANGE THIS
return F.softmax(pred_hid) #NEED TO CHANGE THIS (add an MLP that maps to the correct # of actions; rough sketch below)
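# Rough sketch (not part of this commit) of the output head hinted at in the comment
# above: a small MLP mapping the d_model-dim hidden states to num_actions Q-values.
# Where it would be constructed and the names used are assumptions.
#
#   self.q_head = nn.Sequential(nn.Linear(d_model, d_inner), nn.ReLU(),
#                               nn.Linear(d_inner, num_actions))
#   ...
#   return self.q_head(pred_hid)  # raw Q-values; no softmax needed for Q-learning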
2 changes: 1 addition & 1 deletion StableTransformersReplication/vanillaTransformer.py
@@ -80,7 +80,7 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, use_gate =
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self. dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)

self.norm1 = LayerNorm(d_model)
47 changes: 0 additions & 47 deletions dqn.py
@@ -1,52 +1,5 @@
import torch.nn as nn
import torch
import numpy as np


# In[]:

class ReplayBuffer:
def __init__(self, max_size=2000):
self.max_size = max_size

self.cur_states = []
self.actions = []
self.next_states = []
self.rewards = []
self.dones = []

def __len__(self):
return len(self.cur_states)

def add(self, cur_state, action, next_state, reward, done):
self.cur_states.append(cur_state)
self.actions.append(action)
self.next_states.append(next_state)
self.rewards.append(reward)
self.dones.append(done)

def sample(self, sample_size=32):
sample_transitions = {}
if self.__len__() >= sample_size:
# pick up only random 32 events from the memory
# TODO : Replace np with torch functionality and remove import numpy from above
indices = np.random.choice(self.__len__(), size=sample_size)
sample_transitions['cur_states'] = torch.stack(self.cur_states)[indices]
sample_transitions['actions'] = torch.stack(self.actions)[indices]
sample_transitions['next_states'] = torch.stack(self.next_states)[indices]
sample_transitions['rewards'] = torch.Tensor(self.rewards)[indices]
sample_transitions['dones'] = torch.Tensor(self.dones)[indices]
else:
# if the current buffer size is not greater than 32 then pick up the entire memory
sample_transitions['cur_states'] = torch.stack(self.cur_states)
sample_transitions['actions'] = torch.stack(self.actions)
sample_transitions['next_states'] = torch.stack(self.next_states)
sample_transitions['rewards'] = torch.Tensor(self.rewards)
sample_transitions['dones'] = torch.Tensor(self.dones)

return sample_transitions

# In[]:


class DQN(nn.Module):
45 changes: 45 additions & 0 deletions replayBuffer.py
@@ -0,0 +1,45 @@
import torch
import numpy as np



class ReplayBuffer:
def __init__(self, max_size=2000):
self.max_size = max_size

self.cur_states = []
self.actions = []
self.next_states = []
self.rewards = []
self.dones = []

def __len__(self):
return len(self.cur_states)

def add(self, cur_state, action, next_state, reward, done):
self.cur_states.append(cur_state)
self.actions.append(action)
self.next_states.append(next_state)
self.rewards.append(reward)
self.dones.append(done)

def sample(self, sample_size=32):
sample_transitions = {}
if self.__len__() >= sample_size:
# pick sample_size random transitions from the memory (sampled with replacement)
# TODO : Replace np with torch functionality and remove import numpy from above
indices = np.random.choice(self.__len__(), size=sample_size)
sample_transitions['cur_states'] = torch.stack(self.cur_states)[indices]
sample_transitions['actions'] = torch.stack(self.actions)[indices]
sample_transitions['next_states'] = torch.stack(self.next_states)[indices]
sample_transitions['rewards'] = torch.Tensor(self.rewards)[indices]
sample_transitions['dones'] = torch.Tensor(self.dones)[indices]
else:
# if the buffer holds fewer than sample_size transitions, return the entire memory
sample_transitions['cur_states'] = torch.stack(self.cur_states)
sample_transitions['actions'] = torch.stack(self.actions)
sample_transitions['next_states'] = torch.stack(self.next_states)
sample_transitions['rewards'] = torch.Tensor(self.rewards)
sample_transitions['dones'] = torch.Tensor(self.dones)

return sample_transitions
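# Minimal usage sketch (not part of this commit); the state shape used here is an
# assumption for illustration.
if __name__ == '__main__':
    buffer = ReplayBuffer(max_size=2000)
    for _ in range(5):
        buffer.add(cur_state=torch.rand(4), action=torch.tensor(0),
                   next_state=torch.rand(4), reward=1.0, done=False)
    batch = buffer.sample(sample_size=32)  # fewer than 32 stored, so the whole memory is returned
    print(batch['cur_states'].shape)       # torch.Size([5, 4])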
23 changes: 23 additions & 0 deletions tester.py
@@ -0,0 +1,23 @@


import torch.nn as nn
import torch

class Tester(nn.Module):
def __init__(self):
super(Tester, self).__init__()

self.linear = nn.Linear(5,5)
self.dropout = nn.Dropout(p=0.5)

def forward(self,input):

output = self.linear(input)
print(self.dropout(output))
return output

test = Tester()
input = torch.rand(1,5)


test.forward(input)
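# Side note (not part of this commit): Dropout is only active in training mode;
# calling test.eval() before the forward pass would turn self.dropout into a no-op,
# so the printed tensor would equal `output` exactly.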
