<a href="https://colab.research.google.com/github/fannix/timeseries_generation/blob/master/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import torch
# Demo: turn a column of random class ids into one-hot rows.
batch_size, nb_digits = 5, 10
# Random integers reduced modulo nb_digits; shape (batch_size, 1).
y = torch.LongTensor(batch_size, 1).random_().remainder(nb_digits)
y
# Start from an all-zero float matrix, then write a 1 into each row at the
# column named by the matching entry of y.
y_onehot = torch.FloatTensor(batch_size, nb_digits).zero_()
y_onehot.scatter_(1, y, 1)

tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [0]:
# Use the GPU when torch.cuda.is_available() reports one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
import torch
from torch import nn

from torch.utils.data import DataLoader


class DecoderRNN(nn.Module):
  """LSTM decoder: token ids -> embedding -> ReLU -> LSTM -> per-token logits.

  Args:
    hidden_size: width of the embedding and of the LSTM state.
    output_size: number of distinct token ids; also the width of the logits.
  """

  def __init__(self, hidden_size, output_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(output_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
    self.out = nn.Linear(hidden_size, output_size)

  def forward(self, input, hc):
    """Decode a batch of token ids.

    Args:
      input: integer tensor of token ids, batch first (the LSTM is built with
        batch_first=True).
      hc: tuple (h, c) holding the LSTM hidden and cell state.

    Returns:
      (logits, (hidden, cell)); logits have output_size entries in the last
      dimension, hidden/cell are the updated LSTM state.
    """
    embedded = torch.relu(self.embedding(input))
    output, (hidden, cell) = self.lstm(embedded, hc)
    return self.out(output), (hidden, cell)

  def init_hidden(self, batch_size):
    """Return an all-zero (h, c) pair, each of shape (1, batch_size, hidden_size).

    The tensors are created on the same device and with the same dtype as the
    module's own weights, so they stay valid after `.to(...)` and do not depend
    on a global `device` variable.
    """
    weight = self.embedding.weight
    shape = (1, batch_size, self.hidden_size)
    return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype))


class EncoderRNN(nn.Module):
  """LSTM encoder: token ids -> embedding -> LSTM.

  Args:
    input_size: number of distinct token ids (rows of the embedding table).
    hidden_size: width of the embedding and of the LSTM state.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(input_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)

  def forward(self, input, hc):
    """Encode a batch of token ids.

    Args:
      input: integer tensor of token ids, batch first (the LSTM is built with
        batch_first=True).
      hc: tuple (h0, c0) with the initial LSTM hidden and cell state.

    Returns:
      (output, (hidden, cell)): the per-step LSTM outputs and the final state.
    """
    embedded = self.embedding(input)
    output, (hidden, cell) = self.lstm(embedded, hc)
    return output, (hidden, cell)

  def init_hidden(self, batch_size):
    """Return an all-zero (h0, c0) pair, each of shape (1, batch_size, hidden_size).

    The tensors are created on the same device and with the same dtype as the
    module's own weights, so they stay valid after `.to(...)` and do not depend
    on a global `device` variable.
    """
    weight = self.embedding.weight
    shape = (1, batch_size, self.hidden_size)
    return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype))

class Seq2Seq(nn.Module):
  """Encoder/decoder pair without attention.

  The decoder is seeded with the encoder's final (hidden, cell) state and is
  stepped once per input position, always receiving the start symbol (id 0)
  as its input token.

  Args:
    input_size: number of distinct token ids; used for both the encoder
      vocabulary and the decoder output width.
    hidden_size: width of embeddings and LSTM state.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.encoder = EncoderRNN(input_size, hidden_size)
    self.decoder = DecoderRNN(hidden_size, input_size)

  def forward(self, input, expected=None):
    """Produce logits for as many output steps as `input` has time steps.

    Args:
      input: integer tensor of token ids; dim 0 is the batch, dim 1 is time.
      expected: reserved for teacher forcing, which is not implemented.

    Returns:
      The per-step decoder outputs (with dim 1 squeezed) stacked along a new
      dim 2; the training cells feed this directly to cross_entropy.

    Raises:
      NotImplementedError: if `expected` is given. Previously this path fell
        through to `pass` and crashed on an unbound local instead.
    """
    if expected is not None:
      raise NotImplementedError(
          "teacher forcing (expected=...) is not implemented")

    batch_size, n_steps = input.shape[0], input.shape[1]
    h0, c0 = self.encoder.init_hidden(batch_size)
    _, (h, c) = self.encoder(input, (h0, c0))

    # Start symbol (id 0) for every sequence, created next to the input
    # rather than on a global `device`.
    inp = torch.zeros(batch_size, 1, dtype=torch.long, device=input.device)

    output_list = []
    # A bounded loop: the old `while True` could never reach its exit test
    # when the time axis was empty.
    for _ in range(n_steps):
      out, (h, c) = self.decoder(inp, (h, c))
      output_list.append(out.squeeze(1))
    return torch.stack(output_list, 2)



Attention LSTM Seq2Seq

In [12]:
from torch.utils import data
from random import choice, randrange
import numpy as np
class ReverseDataset(data.Dataset):
    """
    Synthetic sequence-reversal task.

    Each item is a pair of integer arrays (sequence, reversed sequence); both
    end with the EOS id 1. Characters are mapped to ids starting at 2, ids 0
    and 1 being reserved for SOS and EOS.

    Inspired from https://talbaumel.github.io/blog/attention/

    Args:
        min_length: smallest number of characters in a sample.
        max_length: exclusive upper bound on the number of characters, except
            when it equals min_length, in which case every sample has exactly
            that many characters.
        type: 'train' builds 3000 samples, anything else builds 300.
    """
    def __init__(self, min_length=5, max_length=20, type='train'):
        self.SOS = "<s>"  # id 0
        self.EOS = "</s>" # id 1
        self.characters = list("abcd")
        # NOTE(review): int2char is indexed from 0 while char2int starts at 2,
        # so int2char[char2int[c] - 2] == c. Kept as-is for compatibility.
        self.int2char = list(self.characters)
        self.char2int = {c: i+2 for i, c in enumerate(self.characters)}
        # Counts the plain characters only, not the two special symbols.
        self.VOCAB_SIZE = len(self.characters)
        self.min_length = min_length
        self.max_length = max_length
        if type=='train':
            self.set = [self._sample() for _ in range(3000)]
        else:
            self.set = [self._sample() for _ in range(300)]

    def __len__(self):
        return len(self.set)

    def __getitem__(self, item):
        return self.set[item]

    def _sample(self):
        """Draw one (sequence, reversed sequence) pair of id arrays."""
        if self.min_length != self.max_length:
            random_length = randrange(self.min_length, self.max_length)# Pick a random length
        else:
            # randrange(n, n) would raise on the empty range.
            random_length = self.min_length
        # Draw from the full alphabet. The former `self.characters[:-1]`
        # silently excluded the last character, so its id never appeared.
        random_char_list = [choice(self.characters) for _ in range(random_length)]
        forward_ids = [self.char2int[c] for c in random_char_list]
        a = np.array(forward_ids + [1])
        b = np.array(forward_ids[::-1] + [1]) # the reverse, EOS still last
        return a, b

# Smoke test: with min_length == max_length every sample has exactly 4
# characters followed by the EOS id 1.
reverse_dataset = ReverseDataset(4, 4)
reverse_dataset[0]

(array([2, 2, 2, 2, 1]), array([2, 2, 2, 2, 1]))

In [0]:
import torch
# Use the GPU when torch.cuda.is_available() reports one, otherwise the CPU ...
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ... but this second assignment unconditionally overrides the choice above,
# so `device` is always the CPU from here on.
device = torch.device("cpu")

In [0]:
import torch
from torch import nn
import math

from torch.utils.data import DataLoader

class AttenEncoderRNN(nn.Module):
  """LSTM encoder whose per-step outputs are later attended over by the decoder.

  Args:
    input_size: number of distinct token ids (rows of the embedding table).
    hidden_size: width of the embedding and of the LSTM state.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(input_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)

  def forward(self, input, hc):
    """Encode a batch of token ids.

    Args:
      input: integer tensor of token ids, batch first (the LSTM is built with
        batch_first=True).
      hc: tuple (h0, c0) with the initial LSTM hidden and cell state.

    Returns:
      (output, (hidden, cell)): the per-step LSTM outputs and the final state.
    """
    embedded = self.embedding(input)
    output, (hidden, cell) = self.lstm(embedded, hc)
    return output, (hidden, cell)

  def init_hidden(self, batch_size):
    """Return an all-zero (h0, c0) pair, each of shape (1, batch_size, hidden_size).

    The tensors are created on the same device and with the same dtype as the
    module's own weights, so they stay valid after `.to(...)` and do not depend
    on a global `device` variable.
    """
    weight = self.embedding.weight
    shape = (1, batch_size, self.hidden_size)
    return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype))

class AttenDecoderRNN(nn.Module):
  """LSTM decoder that mixes an attention context into its input embedding.

  The embedded input token is used as the attention query over the encoder
  outputs; embedding and context are concatenated, projected back to
  hidden_size, passed through ReLU and fed to the LSTM.

  Args:
    hidden_size: width of the embedding, the LSTM state and the encoder outputs.
    output_size: number of distinct token ids; also the width of the logits.
  """

  def __init__(self, hidden_size, output_size):
    super().__init__()
    self.hidden_size = hidden_size

    self.embedding = nn.Embedding(output_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
    self.out = nn.Linear(hidden_size, output_size)
    self.combine = nn.Linear(2 * hidden_size, hidden_size)

  def forward(self, input, hc, encode_out):
    """Run one decoding call.

    Args:
      input: integer tensor of token ids, batch first.
      hc: tuple (h, c) holding the LSTM hidden and cell state.
      encode_out: encoder outputs, used as both keys and values of the
        attention.

    Returns:
      (logits, (hidden, cell)); logits have output_size entries in the last
      dimension.
    """
    embed = self.embedding(input)
    # attention() returns (context, weights); only the context is used.
    context, _ = attention(embed, encode_out, encode_out)
    comb = self.combine(torch.cat([embed, context], -1))
    output, (hidden, cell) = self.lstm(torch.relu(comb), hc)
    return self.out(output), (hidden, cell)

  def init_hidden(self, batch_size):
    """Return an all-zero (h, c) pair, each of shape (1, batch_size, hidden_size).

    The tensors are created on the same device and with the same dtype as the
    module's own weights, so they stay valid after `.to(...)` and do not depend
    on a global `device` variable.
    """
    weight = self.embedding.weight
    shape = (1, batch_size, self.hidden_size)
    return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype))


def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Shapes as used by the decoder in this file:
        query: N x 1 x D
        key:   N x T x D
        value: N x T x D (the same tensor as key; for self-attention all
               three inputs coincide)

    Positions where `mask == 0` receive a score of -1e9 before the softmax.
    `dropout`, when given, is a callable applied to the attention weights.

    Returns a tuple (result, p_attn):
        result: N x 1 x D, the weighted sum of `value`
        p_attn: N x 1 x T, the attention weights (after dropout, if any)
    """
    scale = math.sqrt(query.size(-1))
    logits = query.matmul(key.transpose(-2, -1)) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    context = weights.matmul(value)
    return context, weights

class AttenSeq2Seq(nn.Module):
  """Encoder/decoder pair where the decoder attends over the encoder outputs.

  The decoder is seeded with the encoder's final (hidden, cell) state and is
  stepped once per input position, always receiving the start symbol (id 0)
  as its input token.

  Args:
    input_size: number of distinct token ids; used for both the encoder
      vocabulary and the decoder output width.
    hidden_size: width of embeddings and LSTM state.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.encoder = AttenEncoderRNN(input_size, hidden_size)
    self.decoder = AttenDecoderRNN(hidden_size, input_size)

  def forward(self, input, expected=None):
    """Produce logits for as many output steps as `input` has time steps.

    Args:
      input: integer tensor of token ids; dim 0 is the batch, dim 1 is time.
      expected: reserved for teacher forcing, which is not implemented.

    Returns:
      The per-step decoder outputs (with dim 1 squeezed) stacked along a new
      dim 2; the training cells feed this directly to cross_entropy.

    Raises:
      NotImplementedError: if `expected` is given. Previously this path fell
        through to `pass` and crashed on an unbound local instead.
    """
    if expected is not None:
      raise NotImplementedError(
          "teacher forcing (expected=...) is not implemented")

    batch_size, n_steps = input.shape[0], input.shape[1]
    h0, c0 = self.encoder.init_hidden(batch_size)
    encode_output, (h, c) = self.encoder(input, (h0, c0))

    # 0 is the start_symbol; create it next to the input rather than on a
    # global `device`.
    inp = torch.zeros(batch_size, 1, dtype=torch.long, device=input.device)

    output_list = []
    # A bounded loop: the old `while True` could never reach its exit test
    # when the time axis was empty.
    for _ in range(n_steps):
      out, (h, c) = self.decoder(inp, (h, c), encode_output)
      output_list.append(out.squeeze(1))
    return torch.stack(output_list, 2)



In [15]:
# Train the attention model on the sequence-reversal task.
reverse_dataset = ReverseDataset(3, 10)

# Batch size 1 — presumably because samples differ in length and no padding
# or collate function is defined; TODO confirm.
loader = DataLoader(reverse_dataset, 1)

# Vocabulary = the dataset's characters plus the two reserved ids (0 and 1).
model = AttenSeq2Seq(len(reverse_dataset.char2int) + 2, 128).to(device)

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
  sum_criterion = 0
  n_instance = 0
  for x, y in loader:
    n_instance += x.shape[0]
    optimizer.zero_grad()
    res = model(x.to(device))
    loss = criterion(res, y.to(device))
    loss.backward()
    sum_criterion += loss.item()
    
    optimizer.step()

  # Sum of per-batch losses divided by the number of instances seen.
  print(f'{epoch}: {sum_criterion/n_instance}')

0: 0.23522539879724597
1: 0.06295872043104948
2: 0.03875238106506943
3: 0.0279807736170986
4: 0.01984829056529664
5: 0.016359679112511988
6: 0.013531920917546092
7: 0.011737098647560382
8: 0.012188883421460503
9: 0.01168644904931591


In [16]:
# Sanity check on the x / y pair left over from the last training iteration.
print(y)

res = model(x.to(device))
predicted = res.argmax(dim=1)  # the logits' class axis is dim 1
print(predicted)

# Fraction of positions whose predicted id matches the target.
n_positions = float(res.shape[0] * res.shape[2])
print(torch.sum(predicted == y.to(device)) / n_positions)

tensor([[4, 4, 2, 3, 1]])
tensor([[4, 4, 2, 3, 1]])
tensor(1.)


In [17]:
# Run the model on a hand-written batch of two sequences, each ending with
# the EOS id 1; argmax over dim 1 turns the logits into token ids.
a = torch.LongTensor([[2, 3, 4, 2, 1], [4, 3, 2, 3, 1]]).to(device)
model(a).argmax(1)

tensor([[2, 4, 3, 2, 1],
        [3, 2, 3, 4, 1]])

In [18]:
# Exact-match evaluation on a fresh 'test' split of the reversal task: a
# sample counts as correct only if every predicted id equals its target.
reverse_dataset_val = ReverseDataset(3, 10, type="test")
val_loader = DataLoader(reverse_dataset_val, 1)
with torch.no_grad():
  ncorrect = 0
  nwrong = 0
  for x, y in val_loader:
    predict = model(x.to(device)).argmax(1)
    # Compare on the prediction's device: `y` comes off the loader on the
    # CPU, and comparing it directly with a prediction held on another
    # device would fail.
    if torch.equal(y.to(predict.device), predict):
      ncorrect += 1
    else:
      nwrong += 1
print(f'{ncorrect}, {nwrong}')


293, 7


Test PeriodicSeriesDataset

In [19]:
import numpy as np
import random
import torch
from random import choice, randrange

class PeriodicSeriesDataset(torch.utils.data.Dataset):
  """Synthetic "rotate right by one" task over a small alphabet.

  Each item is a pair of integer arrays (x, y): x encodes a random string
  drawn from `sequence`, y encodes the same string with its last character
  moved to the front. Both arrays end with the end-symbol id 1. Ids 0 and 1
  are reserved for the start and end symbols; the elements of `sequence` get
  ids 2, 3, ... in order.

  Args:
    sequence: list of single-character strings forming the alphabet (samples
      are joined into a str and re-split per character).
    num: number of samples to generate.
    min_length: smallest number of characters in a sample.
    max_length: exclusive upper bound on the number of characters, except
      when it equals min_length, in which case every sample has exactly that
      many characters.
  """

  def __init__(self, sequence, num, min_length=5, max_length=15):
    self.sequence = sequence
    self.start_symbol = "<s>"
    self.end_symbol = "</s>"
    self.min_length = min_length
    self.max_length = max_length

    self.x = []
    self.y = []
    self.id2word = {i+2: w for (i, w) in enumerate(sequence)}
    self.id2word[0] = self.start_symbol
    self.id2word[1] = self.end_symbol
    self.word2id = {w: i for (i, w) in self.id2word.items()}
    print(sequence)
    # The unused `seq2id` / `place_holder` locals were dropped; the latter
    # relied on `np.int`, which no longer exists in current NumPy.
    for _ in range(num):
      xi, yi = self._sample(self.min_length, self.max_length)
      # Encode each character and terminate with the end-symbol id.
      self.x.append(np.array([self.word2id[e] for e in xi] + [1]))
      self.y.append(np.array([self.word2id[e] for e in yi] + [1]))

  def __len__(self):
    return len(self.x)

  def _sample(self, min_length, max_length):
    """Return (random string, the same string rotated right by one)."""
    if min_length != max_length:
      random_length = randrange(min_length, max_length)  # Pick a random length
    else:
      # randrange(n, n) would raise on the empty range.
      random_length = min_length
    random_string = ''.join(choice(self.sequence) for _ in range(random_length))
    # Slices instead of [-1] so an empty string rotates to an empty string.
    return random_string, random_string[-1:] + random_string[:-1]

  def onehot_seq(self, word_seq):
    """One-hot encode a sequence of symbols; see onehot_num."""
    num_seq = [self.word2id[w] for w in word_seq]
    return self.onehot_num(num_seq)

  def onehot_num(self, num_seq):
    """Return a float tensor of shape (len(num_seq), len(word2id)), one 1 per row."""
    y = torch.LongTensor(num_seq).view(-1, 1)
    onehot = torch.zeros(len(num_seq), len(self.word2id), dtype=torch.float)
    onehot.scatter_(1, y, 1)
    return onehot

  def onecold_num(self, tensor):
    """Inverse of onehot_num: the index of the largest entry in each row."""
    return tensor.argmax(dim=1)

  def onecold_seq(self, tensor):
    """Inverse of onehot_seq: decode each row back to its symbol."""
    onecold = self.onecold_num(tensor)
    print(onecold)
    return [self.id2word[i.item()] for i in onecold]

  def __getitem__(self, index):
    return  self.x[index], self.y[index]

import string
# Build a tiny dataset over the first six ASCII letters and print every
# (x, y) pair to eyeball the rotate-by-one encoding.
sequence = list(string.ascii_letters[:6])
pseries = PeriodicSeriesDataset(sequence, 10)
for i in range(len(pseries)):
  print(pseries[i])


['a', 'b', 'c', 'd', 'e', 'f']
(array([7, 7, 2, 6, 2, 7, 6, 4, 7, 3, 1]), array([3, 7, 7, 2, 6, 2, 7, 6, 4, 7, 1]))
(array([4, 4, 6, 7, 5, 1]), array([5, 4, 4, 6, 7, 1]))
(array([2, 5, 5, 4, 4, 5, 1]), array([5, 2, 5, 5, 4, 4, 1]))
(array([2, 6, 6, 7, 2, 2, 1]), array([2, 2, 6, 6, 7, 2, 1]))
(array([5, 5, 4, 3, 6, 4, 6, 4, 4, 4, 6, 6, 4, 1]), array([4, 5, 5, 4, 3, 6, 4, 6, 4, 4, 4, 6, 6, 1]))
(array([3, 6, 6, 2, 4, 6, 1]), array([6, 3, 6, 6, 2, 4, 1]))
(array([7, 6, 6, 2, 7, 2, 1]), array([2, 7, 6, 6, 2, 7, 1]))
(array([2, 5, 3, 7, 4, 5, 1]), array([5, 2, 5, 3, 7, 4, 1]))
(array([4, 7, 6, 4, 6, 7, 6, 7, 5, 7, 5, 1]), array([5, 4, 7, 6, 4, 6, 7, 6, 7, 5, 7, 1]))
(array([6, 7, 4, 3, 5, 6, 3, 4, 6, 3, 4, 6, 5, 1]), array([5, 6, 7, 4, 3, 5, 6, 3, 4, 6, 3, 4, 6, 1]))


In [26]:
# Baseline: train the plain (no attention) Seq2Seq on the rotate-by-one task
# over a 4-letter alphabet.
sequence = list(string.ascii_letters[:4])
pseries = PeriodicSeriesDataset(sequence, 3000, 5, 25)
loader = DataLoader(pseries, 1)

# word2id already includes the start and end symbols, so its size is the
# full vocabulary.
model = Seq2Seq(len(pseries.word2id), 20).to(device)

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.RMSprop(model.parameters())
for epoch in range(10):
  sum_criterion = 0
  n_instance = 0
  for x, y in loader:
    x = x.to(device)
    y = y.to(device)
    n_instance += x.shape[0]
    optimizer.zero_grad()
    # x was already moved above; the extra .to(device) is redundant.
    res = model(x.to(device))
    loss = criterion(res, y)
    loss.backward()
    sum_criterion += loss.item()
    
    optimizer.step()

  # Sum of per-batch losses divided by the number of instances seen.
  print(f'{epoch}: {sum_criterion/n_instance}')
  #print(res.shape)
  #print(res.shape)



['a', 'b', 'c', 'd']
0: 1.1465233557124932
1: 1.084906201571226
2: 1.003747613226374
3: 0.9052569772948822
4: 0.8723418232897917
5: 0.8270123624453942
6: 0.8303878551271434
7: 0.8211065677969406
8: 0.7931944092025515
9: 0.8202313436104063


In [27]:
# Exact-match evaluation on freshly sampled rotate-by-one data: a sample
# counts as correct only if every predicted id equals its target.
periodic_dataset_val = PeriodicSeriesDataset(list(string.ascii_letters[:4]), 300, 5, 10)
val_loader = DataLoader(periodic_dataset_val, 1)
with torch.no_grad():
  ncorrect = 0
  nwrong = 0
  for x, y in val_loader:
    predict = model(x.to(device)).argmax(1)
    # Compare on the prediction's device: `y` comes off the loader on the
    # CPU, and comparing it directly with a prediction held on another
    # device would fail.
    if torch.equal(y.to(predict.device), predict):
      ncorrect += 1
    else:
      nwrong += 1
print(f'{ncorrect}, {nwrong}')

['a', 'b', 'c', 'd']
37, 263


In [41]:
# Train the attention model on the same rotate-by-one task, with a larger
# hidden size, a lower learning rate and more epochs than the baseline.
sequence = list(string.ascii_letters[:4])
pseries = PeriodicSeriesDataset(sequence, 3000, 5, 25)
loader = DataLoader(pseries, 1)

# word2id already includes the start and end symbols, so its size is the
# full vocabulary.
model = AttenSeq2Seq(len(pseries.word2id), 128).to(device)

criterion = torch.nn.functional.cross_entropy
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
for epoch in range(50):
  sum_criterion = 0
  n_instance = 0
  for x, y in loader:
    n_instance += x.shape[0]
    optimizer.zero_grad()
    res = model(x.to(device))
    loss = criterion(res, y.to(device))
    loss.backward()
    sum_criterion += loss.item()
    
    optimizer.step()

  # Sum of per-batch losses divided by the number of instances seen.
  print(f'{epoch}: {sum_criterion/n_instance}')
  #print(res.shape)
  #print(res.shape)


['a', 'b', 'c', 'd']
0: 1.0160587167243162
1: 0.6652138420602617
2: 0.5581542482101165
3: 0.48392349243724797
4: 0.4288947877885609
5: 0.3931740997308528
6: 0.3680704598834224
7: 0.343933614345459
8: 0.32918605117152766
9: 0.31000236260458164
10: 0.297092416047979
11: 0.28825675241621257
12: 0.27383098454618116
13: 0.25912842366799155
14: 0.2540340957442128
15: 0.24029351274432095
16: 0.23060962693099615
17: 0.22694099831506523
18: 0.21709496056560193
19: 0.21192820261729337
20: 0.201755728360938
21: 0.19749232836426153
22: 0.1906861965133634
23: 0.1832847593115035
24: 0.1780596475661238
25: 0.17259695795379934
26: 0.16722349216807833
27: 0.16349819274238397
28: 0.1589059991533127
29: 0.15151650747277157
30: 0.148468290570123
31: 0.14431009385613738
32: 0.14210995170683835
33: 0.13651415346922227
34: 0.1330155870912129
35: 0.12854019535057506
36: 0.12647350587811146
37: 0.12604638212449595
38: 0.12045280929109606
39: 0.11589173664028897
40: 0.11827887420723782
41: 0.11187306652833988
4

In [0]:
# NOTE(review): this batch uses ids up to 7. If `model` is the one built
# from the 4-letter dataset (vocabulary size len(pseries.word2id) == 6, ids
# 0-5), ids 6 and 7 are out of range for its embedding table. This cell
# looks like a leftover from a 6-letter configuration — confirm before running.
model(
    torch.LongTensor(
    [[2, 3, 4, 5, 6, 7, 1],
     [4, 5, 6, 7, 2, 3, 1]
     ]).to(device)
).argmax(dim=1)

In [0]:
# NOTE(review): the second row contains id 6. If `model` is the one built
# from the 4-letter dataset (ids 0-5), that id is out of range for its
# embedding table — confirm before running.
model(
    torch.LongTensor(
    [[2, 3, 4, 5, 1],
     [3, 4, 5, 6, 1]
     ]).to(device)
).argmax(dim=1)

In [0]:
# NOTE(review): this batch uses ids up to 7. If `model` is the one built
# from the 4-letter dataset (ids 0-5), ids 6 and 7 are out of range for its
# embedding table — confirm before running.
model(
    torch.LongTensor(
    [[3, 5, 6, 2, 7, 4, 1],
     [5, 6, 2, 3, 7, 4, 1]
     ]).to(device)
).argmax(dim=1)

In [45]:
# Exact-match evaluation of the attention model on freshly sampled
# rotate-by-one data: a sample counts as correct only if every predicted id
# equals its target.
periodic_dataset_val = PeriodicSeriesDataset(list(string.ascii_letters[:4]), 300, 5, 10)
val_loader = DataLoader(periodic_dataset_val, 1)
with torch.no_grad():
  ncorrect = 0
  nwrong = 0
  for x, y in val_loader:
    predict = model(x.to(device)).argmax(1)
    # Compare on the prediction's device: `y` comes off the loader on the
    # CPU, and comparing it directly with a prediction held on another
    # device would fail.
    if torch.equal(y.to(predict.device), predict):
      ncorrect += 1
    else:
      nwrong += 1
print(f'{ncorrect}, {nwrong}')

['a', 'b', 'c', 'd']
271, 29
