In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch import optim
import math
import gym

#from actor_critic import PolicyNetwork
from torch import optim
import dataset
import math

In [28]:
class PolicyNetwork(nn.Module):
    """Actor-critic network that scores a state with action probabilities and a value.

    Two shared hidden layers (``state_space -> 1024 -> 256``, ReLU) feed an actor
    head (softmax over ``n_actions``) and a critic head (one scalar value).
    Transitions are buffered with :meth:`select_action` / :meth:`addReward` and
    consumed by :meth:`train`.

    Args:
        state_space: size of the input state vector.
        n_actions: number of discrete actions.
        device: device on which the loss accumulators are allocated.
        retain_graph: forwarded to ``backward`` in :meth:`train`.
    """

    def __init__(self, state_space, n_actions, device, retain_graph=True):
        super(PolicyNetwork, self).__init__()
        self.base = nn.Linear(state_space, 1024)
        self.base2 = nn.Linear(1024, 256)
        self.actions = nn.Linear(256, n_actions)
        self.value = nn.Linear(256, 1)
        self.rewards = []
        self.action_pairs = []  # [(LOG_PROB, CRITIC_VALUE)], one entry per select_action call
        self.retain_graph = retain_graph
        self.device = device

    def forward(self, x):
        """Return ``(action_probs, state_value)`` for the state(s) ``x``."""
        x = F.relu(self.base(x))
        x = F.relu(self.base2(x))
        a = F.softmax(self.actions(x), dim=-1)
        v = self.value(x)
        return a, v

    def select_action(self, state):
        """Sample an action, buffer its ``(log_prob, value)`` and return it as an int.

        ``action.item()`` only works for a single sample, so ``state`` must hold
        exactly one state.
        """
        a, v = self.forward(state)
        m = Categorical(a)
        action = m.sample()
        self.action_pairs.append((m.log_prob(action), v))
        return action.item()

    def addReward(self, r):
        """Append reward ``r``; :meth:`train` pairs rewards with buffered actions by position."""
        self.rewards.append(r)

    def train(self, OPTIM=True, gamma=None):
        """Run one actor-critic update from the buffered episode, or set train/eval mode.

        Called as ``train(optimizer, gamma)``, it does the following. It computes
        discounted returns from ``self.rewards`` and normalizes them when there is
        more than one. It pairs them positionally with ``self.action_pairs`` (extra
        entries on either side are ignored by ``zip``). It then takes one optimizer
        step on ``actor_loss + critic_loss``. Both buffers are cleared afterwards,
        even if the update raises.

        Called with a single bool (or no arguments), it behaves like
        ``nn.Module.train(mode)``. This keeps ``eval()`` / ``train(mode)`` working
        on this module and on any parent module.

        Args:
            OPTIM: optimizer over this network's parameters, or a bool mode.
            gamma: discount factor; required when ``OPTIM`` is an optimizer.

        Returns:
            ``self`` in the mode-setting form, otherwise ``None``.

        Raises:
            TypeError: if an optimizer is given without ``gamma``.
        """
        if gamma is None:
            if isinstance(OPTIM, bool):
                return super(PolicyNetwork, self).train(OPTIM)
            raise TypeError("PolicyNetwork.train() needs `gamma` when given an optimizer")

        try:
            # Nothing to learn from: backward() on the zero accumulators would raise.
            if not self.action_pairs or not self.rewards:
                return None

            ##CONSTRUCT SAMPLED VALUES## (discounted returns, accumulated back to front)
            returns = []
            running = 0.0
            for r in reversed(self.rewards):
                running = r + gamma * running
                returns.append(running)
            returns.reverse()

            ##NORMALIZE SAMPLED STATE VALUES##
            returns = torch.tensor(returns, dtype=torch.float32)
            if returns.shape[0] > 1:
                returns = (returns - returns.mean()) / (returns.std() + 1e-4)

            ##GET ACTOR AND CRITIC LOSS##
            actor_loss = torch.zeros(1, dtype=torch.float32, device=self.device)
            critic_loss = torch.zeros(1, dtype=torch.float32, device=self.device)
            for (log_prob, val), ret in zip(self.action_pairs, returns):
                target = ret.item()
                # The advantage is a plain float, so the actor term does not train the critic.
                advantage = target - val.item()
                actor_loss = actor_loss - log_prob * advantage
                critic_loss = critic_loss + F.smooth_l1_loss(val, torch.full_like(val, target))
            total_loss = actor_loss + critic_loss

            ##OPTIMIZE##
            OPTIM.zero_grad()
            total_loss.backward(retain_graph=self.retain_graph)
            OPTIM.step()
            return None
        finally:
            ##CLEAR MEMORY##
            del self.rewards[:]
            del self.action_pairs[:]



In [29]:
# Action indices for the blocks' 2-way actor-critic: REPEAT applies the block
# again; any other action (TERMINAL) ends the repetition loop.
REPEAT, TERMINAL = 0, 1
class RetraEncBlock(nn.Module):
    """Transformer encoder layer whose repetition count is chosen by an actor-critic.

    With ``actor_train`` False the layer is applied exactly once. Otherwise,
    after every pass the policy looks at position 0 (``x[:, 0]``) and samples
    REPEAT (apply the layer again) or TERMINAL (stop), for at most
    ``max_runtime`` passes.

    Args:
        device: device handed to the PolicyNetwork.
        d_model, nhead, dim_feedforward, dropout, batch_first: forwarded to
            ``nn.TransformerEncoderLayer``.
        actor_optim: optimizer class, instantiated as
            ``actor_optim(actor_critic.parameters(), 3e-4)``.
        actor_train: whether the policy controls repetition and gets trained.
        GAMMA: discount factor for the policy's returns.
        max_runtime: maximum number of layer passes per forward call.
    """
    def __init__(self,device, d_model, nhead, actor_optim, actor_train=False, GAMMA=0.99, max_runtime=3, dim_feedforward=2048, dropout=0.1, batch_first=True):
        super(RetraEncBlock, self).__init__()
        self.encoder_block = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward,
                                                        dropout=dropout, batch_first=batch_first)
        self.actor_critic = PolicyNetwork(d_model, 2, device)
        self.actor_optim = actor_optim(self.actor_critic.parameters(), 3e-4)
        self.GAMMA = GAMMA
        self.actor_train = actor_train
        self.max_runtime = max_runtime

    def propagateLoss(self, loss):
        """Reward the pending actions with ``-loss`` and run one policy update.

        Args:
            loss: scalar task loss as a Python number.
        """
        assert self.actor_train
        # Negate the loss so that maximising reward minimises the task loss.
        self.actor_critic.addReward(-loss)
        self.actor_critic.train(self.actor_optim, self.GAMMA)

    def forward(self, x, mask, pad_mask):
        """Apply the encoder layer one or more times and return the final output.

        ``x[:, 0]`` is used as the policy state. That assumes ``batch_first=True``
        and, because ``select_action`` calls ``.item()``, a batch size of 1.
        """
        action = REPEAT
        i = 0
        while action == REPEAT and i < self.max_runtime:
            x = self.encoder_block(x, mask, pad_mask)
            if not self.actor_train:
                break
            # Detached: the policy has its own optimizer, so its loss must not
            # backpropagate into the transformer layer.
            action = self.actor_critic.select_action(x[:, 0].detach())
            i += 1
            # Only a REPEAT that really triggers another pass earns an immediate 0.
            # The action that ends the loop (TERMINAL, or REPEAT on the last allowed
            # pass) receives the loss-based reward in propagateLoss, which keeps
            # rewards and buffered actions aligned one-to-one.
            if action == REPEAT and i < self.max_runtime:
                self.actor_critic.addReward(0)
        return x
            
class RetraDecBlock(nn.Module):
    """Transformer decoder layer whose repetition count is chosen by an actor-critic.

    With ``actor_train`` False the layer is applied exactly once. Otherwise,
    after every pass the policy looks at position 0 (``x[:, 0]``) and samples
    REPEAT (apply the layer again) or TERMINAL (stop), for at most
    ``max_runtime`` passes.

    Args:
        device: device handed to the PolicyNetwork.
        d_model, nhead, dim_feedforward, dropout, batch_first: forwarded to
            ``nn.TransformerDecoderLayer``.
        actor_optim: optimizer class, instantiated as
            ``actor_optim(actor_critic.parameters(), 1e-5)``.
        actor_train: whether the policy controls repetition and gets trained.
        GAMMA: discount factor for the policy's returns.
        max_runtime: maximum number of layer passes per forward call.
    """
    def __init__(self,device, d_model, nhead, actor_optim, actor_train=False, GAMMA=0.99, max_runtime=20, dim_feedforward=2048, dropout=0.1, batch_first=True):
        super(RetraDecBlock, self).__init__()
        self.decoder_block = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward=dim_feedforward,
                                                        dropout=dropout, batch_first=batch_first)
        self.actor_critic = PolicyNetwork(d_model, 2, device)
        self.actor_optim = actor_optim(self.actor_critic.parameters(), 1e-5)
        self.GAMMA = GAMMA
        self.actor_train = actor_train
        self.max_runtime = max_runtime

    def propagateLoss(self, loss):
        """Reward the pending actions with ``-loss`` and run one policy update.

        Args:
            loss: scalar task loss as a Python number.
        """
        assert self.actor_train
        # Negate the loss so that maximising reward minimises the task loss.
        self.actor_critic.addReward(-loss)
        self.actor_critic.train(self.actor_optim, self.GAMMA)

    def forward(self, x, mem, mask, pad_mask):
        """Apply the decoder layer one or more times and return the final output.

        Args:
            x: decoder input.
            mem: encoder output (cross-attention memory).
            mask: causal target mask (``tgt_mask``).
            pad_mask: target padding mask (``tgt_key_padding_mask``).

        ``x[:, 0]`` is used as the policy state. That assumes ``batch_first=True``
        and, because ``select_action`` calls ``.item()``, a batch size of 1.
        """
        action = REPEAT
        i = 0
        while action == REPEAT and i < self.max_runtime:
            x = self.decoder_block(x, mem, tgt_mask=mask, tgt_key_padding_mask=pad_mask)
            if not self.actor_train:
                break
            # Detached: the policy has its own optimizer, so its loss must not
            # backpropagate into the transformer layer.
            action = self.actor_critic.select_action(x[:, 0].detach())
            i += 1
            # Only a REPEAT that really triggers another pass earns an immediate 0.
            # The action that ends the loop (TERMINAL, or REPEAT on the last allowed
            # pass) receives the loss-based reward in propagateLoss, which keeps
            # rewards and buffered actions aligned one-to-one.
            if action == REPEAT and i < self.max_runtime:
                self.actor_critic.addReward(0)
        return x
                       

In [30]:
class RetraNet(nn.Module):
    """Encoder-decoder transformer built from RetraEncBlock / RetraDecBlock layers.

    Token ids are embedded and summed with learned position embeddings, then run
    through the stacked encoder blocks and decoder blocks. The result is
    projected to ``num_tokens`` logits.

    Args:
        device: device for the blocks and the position-index tensor.
        actor_optim: optimizer class handed to every block for its actor-critic.
        max_len: maximum sequence length (size of the position-embedding table).
        num_tokens: vocabulary size, shared by source and target embeddings.
        num_encoders, num_decoders: number of stacked blocks.
        dim: model width.
        nhead: attention heads per block.
        d_feedforward: hidden width of each block's feed-forward sublayer.
        batch_first: forwarded to the blocks; ``forward`` slices positions along
            dim 1, i.e. it expects ``(batch, seq)`` token tensors.
    """
    def __init__(self, device,
                    actor_optim,
                    max_len,
                    num_tokens,
                    num_encoders=1,
                    num_decoders=1,
                    dim=64,
                    nhead=8,
                    d_feedforward=1024,
                    batch_first=True):
        super(RetraNet, self).__init__()
        self.max_len = max_len
        self.device = device
        self.tokens = num_tokens
        ##Create embedding layers##
        self.src_emb = nn.Embedding(num_tokens, dim)
        self.tgt_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        ##Create Transformer##
        # nn.ModuleList (not a plain list) registers the blocks. That way their
        # parameters reach model.parameters() and the optimizer, and .to(),
        # state_dict() and train()/eval() reach them as well.
        self.encoders = nn.ModuleList([
            RetraEncBlock(device, dim, nhead, actor_optim,
                          dim_feedforward=d_feedforward, batch_first=batch_first).to(device)
            for _ in range(num_encoders)
        ])
        self.decoders = nn.ModuleList([
            RetraDecBlock(device, dim, nhead, actor_optim,
                          dim_feedforward=d_feedforward, batch_first=batch_first).to(device)
            for _ in range(num_decoders)
        ])
        ##Create Final Linear Layer##
        self.linear = nn.Linear(dim, num_tokens)
        ##Create TimeStep Input## position indices 0..max_len-1, shape (1, max_len)
        self.timesteps = torch.arange(max_len, device=device).unsqueeze(0)

    def forward(self, src, tgt, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask):
        """Return ``(batch, tgt_len, num_tokens)`` logits for the given token ids.

        Sequences may be shorter than ``max_len``; only the first ``seq_len``
        position embeddings are used.
        """
        src_in = self.src_emb(src) + self.pos_emb(self.timesteps[:, :src.shape[1]])
        tgt_in = self.tgt_emb(tgt) + self.pos_emb(self.timesteps[:, :tgt.shape[1]])
        # Chain the blocks: each one consumes the previous block's output.
        mem = src_in
        for enc in self.encoders:
            mem = enc(mem, src_mask, src_pad_mask)
        # NOTE(review): src_pad_mask is not forwarded to the decoders'
        # cross-attention (memory_key_padding_mask) — confirm this is intended.
        out = tgt_in
        for dec in self.decoders:
            out = dec(out, mem, tgt_mask, tgt_pad_mask)
        return self.linear(out)

    def train(self, mode=True):
        """Set train/eval mode on this module and every submodule; return ``self``.

        This sets ``training`` directly on each submodule instead of recursing
        through ``nn.Module.train``. The blocks' PolicyNetwork defines its own
        ``train`` for policy updates, which a plain ``train(bool)`` recursion
        would call.
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        for module in self.modules():
            module.training = mode
        return self

    def propagateLoss(self, loss):
        """Hand the scalar task ``loss`` to every block's actor-critic update."""
        for enc in self.encoders:
            enc.propagateLoss(loss)
        for dec in self.decoders:
            dec.propagateLoss(loss)

    def toggleActor(self):
        """Flip ``actor_train`` on every encoder and decoder block."""
        for enc in self.encoders:
            enc.actor_train = not enc.actor_train
        for dec in self.decoders:
            dec.actor_train = not dec.actor_train
        
        
        
        

In [31]:
 device = torch.device('cuda')
# m = RetraNet(device, optim.SGD , 10, 50, dim=64).to(device)
# x = torch.zeros((1, 10), requires_grad=True).type(torch.LongTensor).to(device)
# y = torch.zeros((1, 10), requires_grad=True).type(torch.LongTensor).to(device)
# m(x,y, None, None, None, None)
# m.propagateLoss(-1)

In [32]:
def generate_square_subsequent_mask(sz, device=None):
    """Return an additive causal attention mask of shape ``(sz, sz)``.

    Entries on or below the diagonal are ``0.0`` and entries above it are
    ``-inf``, so position ``i`` may attend only to positions ``j <= i``.

    Args:
        sz: sequence length.
        device: where to allocate the mask. When omitted, the notebook-global
            ``d`` is used, which must already be defined.

    Returns:
        A float32 tensor of shape ``(sz, sz)``.
    """
    if device is None:
        device = d  # legacy fallback: notebook-global device from a later cell
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)


def create_mask(src, tgt, dset):
    """Build the four masks RetraNet.forward expects.

    Args:
        src: source token ids; dim 1 is the sequence (assumed ``(batch, src_len)``).
        tgt: target token ids; dim 1 is the sequence (assumed ``(batch, tgt_len)``).
        dset: dataset object; only ``dset.TOKENS["<PAD>"]`` is read.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``, where:
        - ``src_mask`` is an all-False ``(src_len, src_len)`` bool mask;
        - ``tgt_mask`` is an additive causal ``(tgt_len, tgt_len)`` mask;
        - the two padding masks are bool and True wherever the token is ``<PAD>``.
    """
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    # Put every mask on the device of the tensors it will be used with, rather
    # than on the notebook-global `d`.
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len).to(tgt.device)
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=src.device)

    pad_id = dset.TOKENS["<PAD>"]
    src_padding_mask = src == pad_id
    tgt_padding_mask = tgt == pad_id
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [33]:
def train(model, optim, crit, device, iterations, dset, batch_size, print_freq, scheduler):
    """Train ``model`` for ``iterations`` steps on batches drawn from ``dset``.

    Each step runs these stages in order:
    - a forward pass with the masks from ``create_mask``;
    - ``crit`` on the flattened logits versus ``y_``;
    - if the first encoder block has ``actor_train`` set, an actor-critic update
      driven by the scalar loss;
    - ``backward``, an optimizer step and a scheduler step.

    The mean loss over the last ``print_freq`` steps is printed periodically.

    Args:
        model: a RetraNet.
        optim: optimizer for the model (shadows the ``torch.optim`` module here).
        crit: criterion taking ``(N, num_tokens)`` logits and ``(N,)`` targets.
        device: device the batches are moved to.
        iterations: number of training steps.
        dset: dataset whose ``get_batch(batch_size)`` returns ``(x, y, y_)``.
        batch_size: samples per step.
        print_freq: logging interval in iterations.
        scheduler: LR scheduler, stepped once per iteration.
    """
    # The model may have been left in eval mode (e.g. by convert()); re-enable
    # dropout and other training-only behaviour.
    model.train()
    running_loss = 0
    for it in range(iterations):
        x, y, y_ = dset.get_batch(batch_size)
        x = x.type(torch.LongTensor).to(device)
        y = y.type(torch.LongTensor).to(device)
        y_ = y_.type(torch.LongTensor).to(device)

        src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(x, y, dset)
        model_out = model(x, y, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask)

        optim.zero_grad()

        loss = crit(model_out.reshape(-1, model_out.shape[-1]), y_.reshape(-1))
        if model.encoders[0].actor_train:
            # The actor-critics receive only the scalar loss value as a reward.
            model.propagateLoss(loss.item())
        loss.backward()

        optim.step()
        running_loss += loss.item()
        scheduler.step()

        if (it + 1) % print_freq == 0:
            print("Iteration:", it + 1, "Loss:", running_loss / print_freq)
            running_loss = 0
            
def convert(expression:str, dset, model, device):
    """Greedily decode ``expression`` with ``model`` and return the decoded string.

    Decoding works as follows. The target starts as ``<SOS>`` followed by
    ``<PAD>`` up to ``dset.max_len``. At each step the argmax token is written
    into the next slot. Decoding stops early once ``<EOS>`` is produced.

    Inference runs under ``torch.no_grad()`` in eval mode. The model's previous
    train/eval mode is restored afterwards.

    Args:
        expression: input text, tokenized with ``dset.tokenize_expression``.
        dset: dataset providing the tokenizer, ``TOKENS``, ``max_len`` and ``get_str``.
        model: a RetraNet.
        device: device to run inference on.

    Returns:
        ``dset.get_str`` of the decoded token list.
    """
    ##Convert String to a tensor##
    src = torch.tensor([dset.tokenize_expression(expression)]).to(device)
    src = src[:, 1:]  # drop the first token (presumably <SOS>) — TODO confirm
    ##Set up output##
    y = torch.full((1, dset.max_len), dset.TOKENS["<PAD>"], dtype=torch.long, device=device)
    y[0, 0] = dset.TOKENS["<SOS>"]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in range(dset.max_len - 1):
                src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(src, y, dset)
                out = model(src, y, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask)
                next_token = torch.argmax(out[0, i], dim=0)
                y[0, i + 1] = next_token
                if next_token == dset.TOKENS["<EOS>"]:
                    break
    finally:
        model.train(was_training)
        # If actor_train is on, every forward above buffered (log_prob, value)
        # pairs that no propagateLoss call will consume. Drop them so the next
        # training update stays aligned.
        for block in list(model.encoders) + list(model.decoders):
            del block.actor_critic.rewards[:]
            del block.actor_critic.action_pairs[:]

    return dset.get_str(y.squeeze(0).tolist())

In [34]:
d = device
dset = dataset.Arithmetic(10)
model = RetraNet(d, optim.Adam, dset.max_len, dset.num_tokens, dim=256, nhead=32).to(d)
model.toggleActor()  # enable actor-critic control and training in every block
# Each block's actor-critic already has its own optimizer (created inside the
# block), so its parameters are excluded from the main optimizer.
op = optim.Adam(
    [p for name, p in model.named_parameters() if 'actor_critic' not in name.split('.')],
    lr=3e-8,  # NOTE(review): unusually small for Adam — confirm this is intended
)
scheduler = optim.lr_scheduler.StepLR(op, step_size=10000, gamma=.99)
crit = nn.CrossEntropyLoss(ignore_index=dset.TOKENS["<PAD>"])
#convert("1+1", dset, model, d)

In [35]:
#torch.autograd.set_detect_anomaly(True)
# Long run: one sample per step, loss reported every 100 steps.
train(model, op, crit, d,
      iterations=1_000_000,
      dset=dset,
      batch_size=1,
      print_freq=100,
      scheduler=scheduler)

Iteration: 100 Loss: 2.6867427909374237
Iteration: 200 Loss: 2.660132176876068
Iteration: 300 Loss: 2.6668092918396


KeyboardInterrupt: 

In [None]:
# Flip actor_train on every encoder/decoder block. When it is False, each block
# is applied once and no actor-critic update runs.
model.toggleActor()

In [None]:
# Display the first encoder block's actor_train flag; train() checks this same
# flag to decide whether to call model.propagateLoss.
model.encoders[0].actor_train