In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import dataset
import math

In [2]:
d = torch.device('cuda')

In [3]:
class Baseline(nn.Module):
    def __init__(self, device,
                    max_len,
                    num_tokens,
                    dim=64,
                    nhead=8,
                    num_encoders=2,
                    num_decoders=2,
                    d_feedforward=1024,
                    batch_first=True):
        super(Baseline, self).__init__()
        self.max_len = max_len
        self.device = device
        self.tokens = num_tokens
        ##Create encoder layers##
        self.src_emb = nn.Embedding(num_tokens, dim)
        self.tgt_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_len, dim)
        ##Create Transformer##
        self.transformer = nn.Transformer(d_model=dim, nhead=nhead,num_encoder_layers=num_encoders,     \
                                         num_decoder_layers=num_decoders, dim_feedforward=d_feedforward, \
                                         batch_first=batch_first)
        ##Create Final Linear Layer##
        self.linear = nn.Linear(dim, num_tokens)
        ##Create TimeStep Input##
        self.timesteps = torch.Tensor([[i for i in range(max_len)]]).type(torch.LongTensor).to(device)
    def forward(self, src, tgt, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask):
        pos_emb = self.pos_emb(self.timesteps)
        src_emb = self.src_emb(src)
        tgt_emb = self.tgt_emb(tgt)
        src_in = pos_emb + src_emb
        tgt_in = pos_emb + tgt_emb
        trans_out = self.transformer(src_in, tgt_in, src_mask, tgt_mask, None, src_pad_mask, tgt_pad_mask)
        return (self.linear(trans_out))
        
        

In [4]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=d)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt, dset):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=d).type(torch.bool)

    src_padding_mask = (src == dset.TOKENS["<PAD>"])
    tgt_padding_mask = (tgt == dset.TOKENS["<PAD>"])
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


In [5]:
dset = dataset.Arithmetic(10)
x,y,y_ = dset.get_batch(4)
print(x,y)
create_mask(x,y,dset)

tensor([[11,  4, 11,  2, 13, 14, 14, 14, 14],
        [11,  5, 11, 11,  2, 13, 14, 14, 14],
        [ 3, 12, 11,  2, 13, 14, 14, 14, 14],
        [11,  2, 12,  4, 13, 14, 14, 14, 14]], dtype=torch.int32) tensor([[15, 11,  6, 13, 14, 14, 14, 14, 14],
        [15, 11,  3, 13, 14, 14, 14, 14, 14],
        [15, 11,  6, 13, 14, 14, 14, 14, 14],
        [15, 11,  8, 13, 14, 14, 14, 14, 14]], dtype=torch.int32)


(tensor([[False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False, False]],
        device='cuda:0'),
 tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
         [0.

In [6]:
def train(model, optim, crit, device, iterations, dset, batch_size, print_freq, scheduler):
    running_loss = 0
    for it in range(iterations):
        x,y,y_ = dset.get_batch(batch_size)
        x = x.type(torch.LongTensor).to(device)
        y = y.type(torch.LongTensor).to(device)
        y_ = y_.type(torch.LongTensor).to(device)
        
        src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(x,y,dset)
        model_out = model(x, y, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask)
        
        optim.zero_grad()
        
        loss = crit(model_out.reshape(-1, model_out.shape[-1]), y_.reshape(-1))
        loss.backward()
        
        optim.step()
        running_loss+=loss.item()
        scheduler.step()
        
        if (it+1) % print_freq == 0:
            print("Iteration:",it+1,"Loss:",running_loss/print_freq)
            running_loss=0
            
def convert(expression:str, dset, model, device):
    ##Convert String to a tensor##
    src = torch.tensor([dset.tokenize_expression(expression)]).to(device)
    src = src[:, 1:]
    print(dset.max_len)
    print(src.shape)
    ##Set up output##
    y = torch.ones(1, dset.max_len).fill_(dset.TOKENS["<PAD>"]).type(torch.long).to(device)
    y[0,0] = dset.TOKENS["<SOS>"]
    model.eval()
    for i in range(dset.max_len-1):
        src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(src,y,dset)
        out = model(src, y, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask)
        probs = out[0,i]
        next_token = torch.argmax(probs,dim=0)
        y[0, i+1] = next_token
        if next_token == dset.TOKENS["<EOS>"]:
            break
    y = y.squeeze(0)
    y = y.tolist()
    return dset.get_str(y)
        
        

In [7]:
dset = dataset.Arithmetic(100)
model = Baseline(d, dset.max_len,dset.num_tokens,num_encoders=1,
                    num_decoders=1, dim=128).to(d)
op = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(op, step_size=10000, gamma=.95)
crit = nn.CrossEntropyLoss(ignore_index=dset.TOKENS["<PAD>"])
convert("1+1", dset, model, d)

11
torch.Size([1, 11])


'<SOS><EOS><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD>'

In [21]:
train(model, op, crit, d, 100000, dset, 128, 1000, scheduler)

Iteration: 1000 Loss: 0.1666370017975569
Iteration: 2000 Loss: 0.13526613587886094
Iteration: 3000 Loss: 0.10743810536339879
Iteration: 4000 Loss: 0.09017679917812348
Iteration: 5000 Loss: 0.07490667941607534
Iteration: 6000 Loss: 0.06688066721521319
Iteration: 7000 Loss: 0.05998076453059912
Iteration: 8000 Loss: 0.05215799659304321


KeyboardInterrupt: 

In [34]:
convert("20*10", dset, model, d)

11
torch.Size([1, 11])


'<SOS>200<EOS><PAD><PAD><PAD><PAD><PAD><PAD>'