<a href="https://colab.research.google.com/github/fannix/timeseries_generation/blob/master/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import torch
batch_size = 5
nb_digits = 10
y = torch.LongTensor(batch_size,1).random_() % nb_digits
y
y_onehot = torch.FloatTensor(batch_size, nb_digits)
y_onehot.zero_()
y_onehot.scatter_(1, y, 1)

tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
import torch
from torch import nn

from torch.utils.data import DataLoader


class DecoderRNN(nn.Module):
  def __init__(self, hidden_size, output_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(output_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
    self.out = nn.Linear(hidden_size, output_size)

  def forward(self, input, hc):
    output = self.embedding(input)
    output = torch.relu(output)
    output, (hidden, cell) = self.lstm(output, hc)
    #print(output.shape)
    output = self.out(output)
    return output, (hidden, cell)

  def init_hidden(self, batch_size):
    return (torch.zeros(1, batch_size, self.hidden_size, device=device), 
            torch.zeros(1, batch_size, self.hidden_size, device=device))


class EncoderRNN(nn.Module):
  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(input_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)

  def forward(self, input, hc):
    embedded = self.embedding(input)
    output = embedded
    output, (hidden, cell) = self.lstm(output, hc)
    return output, (hidden, cell)

  def init_hidden(self, batch_size):
    return (torch.zeros(1, batch_size, self.hidden_size, device=device),
            torch.zeros(1, batch_size, self.hidden_size, device=device))

class Seq2Seq(nn.Module):
  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.encoder = EncoderRNN(input_size, hidden_size)
    self.decoder = DecoderRNN(hidden_size, input_size)
  
  def forward(self, input, expected = None):
    batch_size = input.shape[0]
    h0, c0 = self.encoder.init_hidden(batch_size)
    encode_output, (encode_hidden, encode_cell) = self.encoder(input, (h0, c0))

    inp = torch.zeros(batch_size, dtype=torch.long, device=device)
    inp = inp.view(batch_size, 1)
    h, c = encode_hidden, encode_cell

    #print(inp.shape)
    output_list = []
    while True:
      if expected == None:
        out, (h, c) = self.decoder(inp, (h, c))
      # teacher forcing
      else:
        pass
      output_list.append(out.squeeze(1))
      if len(output_list) == input.shape[1]:
        break
    # print(output_list[0].shape)
    return torch.stack(output_list, 2)



Attention LSTM Seq2Seq


some helper functions

In [20]:
import torch

def mask_3d(inputs, seq_len, mask_value=0.):
    """
    Use the mask_value to make a 3d or 2d tensor
    inputs: tensor. N*T*D
    seq_len: N. length of valid seq
    """
    batches = inputs.size()[0]
    assert batches == len(seq_len)
    max_idx = max(seq_len)
    for n, idx in enumerate(seq_len):
        if idx <= max_idx.item():
            if len(inputs.size()) == 3:
                inputs[n, idx.int():, :] = mask_value
            else:
                assert len(inputs.size()) == 2, "The size of inputs must be 2 or 3, received {}".format(inputs.size())
                inputs[n, idx.int():] = mask_value
    return inputs

input = torch.randn(3, 4)
seq_len = torch.LongTensor([1, 2, 3])
print(seq_len)
mask_3d(input, seq_len)

tensor([1, 2, 3])


tensor([[-0.4719,  0.0000,  0.0000,  0.0000],
        [ 0.2437, -0.7672,  0.0000,  0.0000],
        [ 1.0179, -0.5958,  1.2337,  0.0000]])

In [0]:
def pad_collate(batch, values=(0, 0), dim=0):
    """
    args:
        batch - list of (tensor, label)
    reutrn:
        xs - a tensor of all examples in 'batch' after padding
        ys - a LongTensor of all labels in batch
        ws - a tensor of sequence lengths
    """

    sequence_lengths = torch.Tensor([int(x[0].shape[dim]) for x in batch])
    sequence_lengths, xids = sequence_lengths.sort(descending=True)
    target_lengths = torch.Tensor([int(x[1].shape[dim]) for x in batch])
    target_lengths, yids = target_lengths.sort(descending=True)
    batch_x, batch_y = list(zip(*batch))
    batch_x = torch.nn.utils.rnn.pad_sequence(batch_x, batch_first=True, padding_value=values[0])
    batch_y = torch.nn.utils.rnn.pad_sequence(batch_y, batch_first=True, padding_value=values[1])

    # stack all

    xs = batch_x[xids]
    ys = batch_y[yids]
    return xs, ys, sequence_lengths.int(), target_lengths.int()

In [17]:
zipped = [("a", 1), ("b", 2)]

unzipped_object = zip(*zipped)

unzipped_list = list(unzipped_object)
print(unzipped_list)

[('a', 'b'), (1, 2)]


In [5]:
# https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch
import torch
a = [torch.tensor([1,2,3]), torch.tensor([3, 4, 5, 6])]
b = torch.nn.utils.rnn.pad_sequence(a, batch_first=True)
print(a)
print(b)
c = torch.nn.utils.rnn.pack_padded_sequence(b, batch_first=True, lengths=[3,4], enforce_sorted=False)
torch.nn.utils.rnn.pad_packed_sequence(c)

[tensor([1, 2, 3]), tensor([3, 4, 5, 6])]
tensor([[1, 2, 3, 0],
        [3, 4, 5, 6]])


(tensor([[1, 3],
         [2, 4],
         [3, 5],
         [0, 6]]), tensor([3, 4]))

In [68]:
from torch.utils import data
from random import choice, randrange
import numpy as np
class ReverseDataset(data.Dataset):
    """
    Inspired from https://talbaumel.github.io/blog/attention/
    https://towardsdatascience.com/attention-seq2seq-with-pytorch-learning-to-invert-a-sequence-34faf4133e53
    """
    def __init__(self, min_length=5, max_length=20, type='train'):
        self.SOS = "<s>"  # id 1
        self.EOS = "</s>" # id 2, id of mask will be 0
        self.characters = list("abcd")
        self.int2char = list(self.characters)
        self.char2int = {c: i+3 for i, c in enumerate(self.characters)}
        self.VOCAB_SIZE = len(self.characters)
        self.min_length = min_length
        self.max_length = max_length
        if type=='train':
            self.dataset = [self._sample() for _ in range(3000)]
        else:
            self.dataset = [self._sample() for _ in range(300)]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        return self.dataset[item]

    def _sample(self):
        if self.min_length != self.max_length:
            random_length = randrange(self.min_length, self.max_length)# Pick a random length
        else:
            random_length = self.min_length
        random_char_list = [choice(self.characters[:-1]) for _ in range(random_length)]  # Pick random chars
        random_string = ''.join(random_char_list)
        a = torch.LongTensor([self.char2int.get(x) for x in random_string] + [2])
        b = torch.LongTensor([self.char2int.get(x) for x in random_string[::-1]] + [2]) # Return the random string and its reverse
        #x = np.zeros((random_length, self.VOCAB_SIZE))
        #x[np.arange(random_length), a-2] = 1
        return a, b

reverse_dataset = ReverseDataset(4, 4)
reverse_dataset[0]

(tensor([4, 3, 3, 5, 2]), tensor([5, 3, 3, 4, 2]))

In [0]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [0]:
import torch
from torch import nn
import math

from torch.utils.data import DataLoader

class AttenEncoderRNN(nn.Module):
  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(input_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)

  def forward(self, input, hc):
    embedded = self.embedding(input)
    output = embedded
    output, (hidden, cell) = self.lstm(output, hc)
    return output, (hidden, cell)

  def init_hidden(self, batch_size):
    return (torch.zeros(1, batch_size, self.hidden_size, device=device),
            torch.zeros(1, batch_size, self.hidden_size, device=device))

class AttenDecoderRNN(nn.Module):
  def __init__(self, hidden_size, output_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(output_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
    self.out = nn.Linear(hidden_size, output_size)
    self.combine = nn.Linear(2 * hidden_size, hidden_size)

  def forward(self, input, hc, encode_out):
    embed = self.embedding(input)
    attn = attention(embed, encode_out, encode_out)
    comb = self.combine(torch.cat([embed, attn[0]], -1))
    output = torch.relu(comb)
    output, (hidden, cell) = self.lstm(output, hc)
    #print(output.shape)
    output = self.out(output)
    return output, (hidden, cell)

  def init_hidden(self, batch_size):
    return (torch.zeros(1, batch_size, self.hidden_size, device=device), 
            torch.zeros(1, batch_size, self.hidden_size, device=device))


def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention
    query: N x 1 x D
    key: N x T x D
    value: N x T x D. key and value are the same. query, key and value are the same for self attention
    scores: N x 1 x T
    p_attn: N x 1 x T
    result: N x 1 x D
    
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = torch.nn.functional.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

class AttenSeq2Seq(nn.Module):
  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.encoder = AttenEncoderRNN(input_size, hidden_size)
    self.decoder = AttenDecoderRNN(hidden_size, input_size)
  
  def forward(self, input, expected = None):
    batch_size = input.shape[0]
    h0, c0 = self.encoder.init_hidden(batch_size)
    encode_output, (encode_hidden, encode_cell) = self.encoder(input, (h0, c0))
    # 0 is the start_symbol
    inp = torch.zeros(batch_size, dtype=torch.long, device=device)
    inp = inp.view(batch_size, 1)
    h, c = encode_hidden, encode_cell
    #print(inp.shape)
    output_list = []
    while True:
      if expected == None:
        out, (h, c) = self.decoder(inp, (h, c), encode_output)
      # teacher forcing
      else:
        pass
      output_list.append(out.squeeze(1))
      if len(output_list) == input.shape[1]:
        break
    # print(output_list[0].shape)
    return torch.stack(output_list, 2)



In [78]:
reverse_dataset = ReverseDataset(3, 15)

loader = DataLoader(reverse_dataset, 8, collate_fn=pad_collate)

model = AttenSeq2Seq(len(reverse_dataset.char2int) + 3, 128).to(device)

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
  sum_criterion = 0
  n_instance = 0
  for x, y, _, _ in loader:
    n_instance += x.shape[0]
    optimizer.zero_grad()
    res = model(x.to(device))
    loss = criterion(res, y.to(device), ignore_index=0)
    loss.backward()
    sum_criterion += loss.item()
    
    optimizer.step()

  print(f'{epoch}: {sum_criterion/n_instance}')

0: 0.10569670132795969
1: 0.04897627769162258
2: 0.029339559278140467
3: 0.02161625017412007
4: 0.017269706373723846
5: 0.014299613387634356
6: 0.011850742060380677
7: 0.010919193015123407
8: 0.009290929907467216
9: 0.008049904896334434


In [87]:
print(y)

res = model(x.to(device))

predict = res.argmax(dim=1)
print(predict)
y.masked_fill_(y==0, 2)

print(torch.sum(predict == y.to(device)) / float(res.shape[0] * res.shape[2]))

tensor([[4, 3, 3, 5, 4, 4, 5, 5, 4, 3, 2],
        [5, 5, 3, 3, 4, 4, 3, 3, 3, 3, 2],
        [5, 4, 3, 5, 5, 3, 3, 4, 5, 2, 0],
        [3, 3, 3, 4, 5, 4, 3, 5, 3, 2, 0],
        [5, 4, 5, 4, 4, 5, 2, 0, 0, 0, 0],
        [4, 5, 5, 5, 3, 2, 0, 0, 0, 0, 0],
        [4, 5, 3, 2, 0, 0, 0, 0, 0, 0, 0],
        [3, 4, 4, 2, 0, 0, 0, 0, 0, 0, 0]])
tensor([[4, 3, 3, 5, 4, 4, 5, 5, 4, 3, 2],
        [5, 5, 3, 3, 4, 4, 3, 3, 3, 3, 2],
        [5, 4, 3, 5, 5, 3, 3, 4, 5, 2, 2],
        [3, 3, 3, 4, 5, 4, 3, 5, 3, 2, 2],
        [5, 4, 5, 4, 4, 5, 2, 2, 2, 2, 2],
        [4, 5, 5, 5, 3, 2, 2, 2, 2, 2, 2],
        [4, 5, 3, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 4, 4, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')
tensor(1., device='cuda:0')


In [80]:
a = torch.LongTensor([[3, 3, 4, 5, 2], [5, 4, 3, 3, 2]]).to(device)
model(a).argmax(1)

tensor([[5, 4, 3, 3, 2],
        [3, 3, 4, 5, 2]], device='cuda:0')

In [88]:
reverse_dataset_val = ReverseDataset(3, 10, type="test")
val_loader = DataLoader(reverse_dataset_val, 1, collate_fn=pad_collate)
with torch.no_grad():
  ncorrect = 0
  nwrong = 0
  for x, y, _, _ in val_loader:
    y.masked_fill_(y==0, 2)
    predict = model(x.to(device)).argmax(1)
    if torch.equal(y.to(device), predict):
      ncorrect += 1
    else:
      nwrong += 1
print(f'{ncorrect}, {nwrong}')


297, 3


# Test PeriodicDataset

In [19]:
import numpy as np
import random
import torch
from random import choice, randrange

class PeriodicSeriesDataset(torch.utils.data.Dataset):
  def __init__(self, sequence, num, min_length=5, max_length=15):
    
    self.sequence = sequence
    self.start_symbol = "<s>"
    self.end_symbol = "</s>"
    self.min_length = min_length
    self.max_length = max_length

    self.x = []
    self.y = []
    self.id2word = {i+3: w for (i, w) in enumerate(sequence)}
    self.id2word[0] = self.start_symbol
    self.id2word[1] = self.end_symbol
    self.word2id = {w: i for (i, w) in self.id2word.items()}
    seq2id = np.array([self.word2id[w] for w in sequence])
    print(sequence)
    place_holder = np.zeros(len(sequence) + 2, dtype=np.int)
    place_holder[-1] = 1
    for i in range(num):
      xi, yi = self._sample(self.min_length, self.max_length)
      xi = np.array([self.word2id[e] for e in xi] + [2])
      yi = np.array([self.word2id[e] for e in yi] + [2])

      self.x.append(xi)
      self.y.append(yi)

      
  def __len__(self):
    return len(self.x)
  
  def _sample(self, min_length, max_length):
    random_length = randrange(min_length, max_length)                        # Pick a random length
    random_char_list = [choice(self.sequence) for _ in range(random_length)]  # Pick random chars
    random_string = ''.join(random_char_list) 
    return random_string, random_string[-1] + random_string[0:-1]  # Return the random string and its shift

  def onehot_seq(self, word_seq):
    num_seq = [self.word2id[w] for w in word_seq]
    return self.onehot_num(num_seq)
  
  def onehot_num(self, num_seq):
    y = torch.LongTensor(num_seq).view(-1, 1)
    onehot = torch.FloatTensor(len(num_seq), len(self.word2id))
    onehot.zero_()
    onehot.scatter_(1, y, 1)
    return onehot
  
  def onecold_num(self, tensor):
    dim_n = tensor.shape[0]
    dim_c = tensor.shape[1]
    onecold = tensor.argmax(dim=1)
    return onecold

  def onecold_seq(self, tensor):
    onecold = self.onecold_num(tensor)
    print(onecold)
    return [self.id2word[i.item()] for i in onecold]

  def __getitem__(self, index):
    return  self.x[index], self.y[index]

import string
sequence = list(string.ascii_letters[:6])
pseries = PeriodicSeriesDataset(sequence, 10)
for i in range(len(pseries)):
  print(pseries[i])


['a', 'b', 'c', 'd', 'e', 'f']
(array([7, 7, 2, 6, 2, 7, 6, 4, 7, 3, 1]), array([3, 7, 7, 2, 6, 2, 7, 6, 4, 7, 1]))
(array([4, 4, 6, 7, 5, 1]), array([5, 4, 4, 6, 7, 1]))
(array([2, 5, 5, 4, 4, 5, 1]), array([5, 2, 5, 5, 4, 4, 1]))
(array([2, 6, 6, 7, 2, 2, 1]), array([2, 2, 6, 6, 7, 2, 1]))
(array([5, 5, 4, 3, 6, 4, 6, 4, 4, 4, 6, 6, 4, 1]), array([4, 5, 5, 4, 3, 6, 4, 6, 4, 4, 4, 6, 6, 1]))
(array([3, 6, 6, 2, 4, 6, 1]), array([6, 3, 6, 6, 2, 4, 1]))
(array([7, 6, 6, 2, 7, 2, 1]), array([2, 7, 6, 6, 2, 7, 1]))
(array([2, 5, 3, 7, 4, 5, 1]), array([5, 2, 5, 3, 7, 4, 1]))
(array([4, 7, 6, 4, 6, 7, 6, 7, 5, 7, 5, 1]), array([5, 4, 7, 6, 4, 6, 7, 6, 7, 5, 7, 1]))
(array([6, 7, 4, 3, 5, 6, 3, 4, 6, 3, 4, 6, 5, 1]), array([5, 6, 7, 4, 3, 5, 6, 3, 4, 6, 3, 4, 6, 1]))


In [26]:
sequence = list(string.ascii_letters[:4])
pseries = PeriodicSeriesDataset(sequence, 3000, 5, 25)
loader = DataLoader(pseries, 1)

model = Seq2Seq(len(pseries.word2id), 20).to(device)

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.RMSprop(model.parameters())
for epoch in range(10):
  sum_criterion = 0
  n_instance = 0
  for x, y in loader:
    x = x.to(device)
    y = y.to(device)
    n_instance += x.shape[0]
    optimizer.zero_grad()
    res = model(x.to(device))
    loss = criterion(res, y)
    loss.backward()
    sum_criterion += loss.item()
    
    optimizer.step()

  print(f'{epoch}: {sum_criterion/n_instance}')
  #print(res.shape)



['a', 'b', 'c', 'd']
0: 1.1465233557124932
1: 1.084906201571226
2: 1.003747613226374
3: 0.9052569772948822
4: 0.8723418232897917
5: 0.8270123624453942
6: 0.8303878551271434
7: 0.8211065677969406
8: 0.7931944092025515
9: 0.8202313436104063


In [27]:
periodic_dataset_val = PeriodicSeriesDataset(list(string.ascii_letters[:4]), 300, 5, 10)
val_loader = DataLoader(periodic_dataset_val, 1)
with torch.no_grad():
  ncorrect = 0
  nwrong = 0
  for x, y in val_loader:
    predict = model(x.to(device)).argmax(1)
    if torch.equal(y, predict):
      ncorrect += 1
    else:
      nwrong += 1
print(f'{ncorrect}, {nwrong}')

['a', 'b', 'c', 'd']
37, 263


In [54]:
sequence = list(string.ascii_letters[:4])
pseries = PeriodicSeriesDataset(sequence, 3000, 5, 15)
loader = DataLoader(pseries, 1)

model = AttenSeq2Seq(len(pseries.word2id), 64).to(device)

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.Adadelta(model.parameters())
for epoch in range(50):
  sum_criterion = 0
  n_instance = 0
  for x, y in loader:
    n_instance += x.shape[0]
    optimizer.zero_grad()
    res = model(x.to(device))
    loss = criterion(res, y.to(device))
    loss.backward()
    sum_criterion += loss.item()
    
    optimizer.step()

  print(f'{epoch}: {sum_criterion/n_instance}')
  #print(res.shape)


['a', 'b', 'c', 'd']
0: 0.8852485108760496
1: 0.4773196799615398
2: 0.3441714658272443
3: 0.285378480357814
4: 0.23895165590881576
5: 0.20279428705900182
6: 0.18678215551349148
7: 0.17078311624676734
8: 0.1565253962141226
9: 0.14512069364028796
10: 0.1397529769297761
11: 0.1287329624556019
12: 0.13225862203542066
13: 0.11700433067537519
14: 0.10964543732545648
15: 0.10374334436085451
16: 0.0990335777516721
17: 0.10557876518105624
18: 0.09982724089555506
19: 0.09396270630399337
20: 0.09504352666219917
21: 0.08361534974674716
22: 0.08701855717638705
23: 0.07914577496590669
24: 0.08435461946243918
25: 0.07939490075371557
26: 0.07472918648230331
27: 0.073944263521053
28: 0.07448297275192384
29: 0.06334762630109375
30: 0.06152762790013714
31: 0.05677910757635192
32: 0.055703185642011516
33: 0.0519143622999841
34: 0.05534165273213729
35: 0.051158557703565824
36: 0.05709209493717126
37: 0.04981925260335557
38: 0.05458240390511961
39: 0.04463825441883562
40: 0.04654023700266573
41: 0.043485546

In [0]:
model(
    torch.LongTensor(
    [[2, 3, 4, 5, 6, 7, 1],
     [4, 5, 6, 7, 2, 3, 1]
     ]).to(device)
).argmax(dim=1)

In [0]:
model(
    torch.LongTensor(
    [[2, 3, 4, 5, 1],
     [3, 4, 5, 6, 1]
     ]).to(device)
).argmax(dim=1)

In [0]:
model(
    torch.LongTensor(
    [[3, 5, 6, 2, 7, 4, 1],
     [5, 6, 2, 3, 7, 4, 1]
     ]).to(device)
).argmax(dim=1)

In [58]:
periodic_dataset_val = PeriodicSeriesDataset(list(string.ascii_letters[:4]), 300, 10, 15)
val_loader = DataLoader(periodic_dataset_val, 1)
with torch.no_grad():
  ncorrect = 0
  nwrong = 0
  for x, y in val_loader:
    predict = model(x.to(device)).argmax(1)
    if torch.equal(y, predict):
      ncorrect += 1
    else:
      nwrong += 1
print(f'{ncorrect}, {nwrong}')

['a', 'b', 'c', 'd']
103, 197
