In [137]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

In [144]:
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [143]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

In [32]:
class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                            tgt, tgt_mask)

    def encode(self, src, src_mask):
    
        return self.encoder(self.src_embed(src).to(device), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt).to(device), memory, src_mask, tgt_mask)

In [33]:
class Generator(nn.Module):
    "Define standard linear + softmax generation step."
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return self.proj(x) 

In [34]:
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

In [35]:
class Encoder(nn.Module):
    "Core encoder is a stack of N layers"
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [36]:
class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """ Return a2 * x_normalized + b2,
        where x_normalized is calculated by subtracting row-wise means from x and dividing the result by row-wise standard deviation + eps.
        standard deviation is calculated with Bessel's correction (the default in Pytorch)
        """

        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

In [37]:
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))

In [38]:
class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [39]:
class Decoder(nn.Module):
    "Generic N layer decoder with masking."
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

In [40]:
class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

In [41]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

In [42]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'
    """
    d_k = query.size(-1)
    # todo: compute the attention scores by using torch.matmul
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:

        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim = -1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # todo: compute the result as the values weighted by attention probabilities (again, using torch.matmul)
    result = torch.matmul(p_attn, value)
    return result, p_attn

In [43]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

In [44]:
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

In [118]:
class Embeddings(nn.Module):
    def __init__(self, d_model, d_vocab):
        super().__init__()
        
        self.transform = nn.Linear(d_vocab, d_model, dtype=float).to(device)
        self.d_vocab = d_vocab
        self.d_model = d_model

    def forward(self, x):
        len_ = int(len(x[0]) / self.d_vocab )
        trajs_emb = torch.zeros(len(x), len_ , self.d_model, requires_grad=True).to(device)
        
        for j, x_tmp in enumerate(x):
            traj_emb = torch.zeros(len_ , self.d_model, requires_grad=False).to(device)
            
            for i in range(self.d_vocab, len_ + 1, self.d_vocab):
                point = x_tmp[i - self.d_vocab:i]          
                traj_emb[i - self.d_vocab, :] =  self.transform(point)
            trajs_emb[j,:,:] =  traj_emb

        return trajs_emb * math.sqrt(self.d_model)

In [51]:
class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
        return self.dropout(x)

In [52]:
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy  # use it for attn, ffn, and position in the model layers
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    # todo: insert correct arguments into the EncoderDecoder constructor.
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform(p)
    return model

In [53]:

class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

In [54]:
def loss_backprop(generator, criterion, out, targets):
    """
    Memory optmization. Compute each timestep separately and sum grads.
    """
    assert out.size(1) == targets.size(1)
    targets = targets.to(torch.float32)
    
    total = 0.0
    out_grad = []
    for i in range(out.size(1)):
        out_column = Variable(out[:, i].data, requires_grad=True)
        gen = generator(out_column)
        loss = criterion(gen, targets[:, i]) 
        total += loss.item() #.data[0]
        loss.backward()
        out_grad.append(out_column.grad.data.clone())
    out_grad = torch.stack(out_grad, dim=1)
    out.backward(gradient=out_grad)
    return total


In [55]:
def loss_(generator, criterion, out, targets):
    """
    Memory optmization. Compute each timestep separately and sum grads.
    """
    assert out.size(1) == targets.size(1)
    targets = targets.to(torch.float32)

          
    total = 0.0
    out_grad = []
    for i in range(out.size(1)):
        out_column = Variable(out[:, i].data, requires_grad=True)
        gen = generator(out_column)
        loss = criterion(gen, targets[:, i]) 
        total += loss.item() #.data[0]
        
    return total


In [56]:
def make_std_mask(src, tgt, pad=-1):
    
    src_mask = (src[-1,:int(src.shape[1]/2)] != pad).unsqueeze(-2)
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return src_mask, tgt_mask

In [57]:
def train_epoch(train_iter, model, criterion, opt, transpose=False):
    model.train()
    for i, batch in enumerate(train_iter):
        src, trg, src_mask, trg_mask = \
                batch.src.to(device), batch.trg.to(device), batch.src_mask.to(device), batch.trg_mask.to(device)
        
        out = model.forward(src, trg, src_mask, trg_mask)
        loss = loss_backprop(model.generator, criterion, out, trg)

        model_opt.step()
        model_opt.optimizer.zero_grad()

    print('test loss: ', loss)

In [58]:
def valid_epoch(valid_iter, model, criterion, transpose=False):
    model.eval()
    total = 0
    for batch in valid_iter:
        src, trg, src_mask, trg_mask = \
            batch.src.to(device), batch.trg.to(device), batch.src_mask.to(device), batch.trg_mask.to(device)
        out = model.forward(src, trg, src_mask, trg_mask)
        loss = loss_(model.generator, criterion, out, trg)
    print('validation loss: ', loss)

# ____

In [59]:
class Batch:
    def __init__(self, src, trg, src_mask, trg_mask):
        self.src = src
        self.trg = trg
        self.src_mask = src_mask
        self.trg_mask = trg_mask


In [60]:
import random

def data_iterator(srcs, tgts, batch_size=128, shuffle=True):
    if shuffle:
        pairs = list(zip(srcs, tgts))
        random.shuffle(pairs)
        srcs, tgts = [list(t) for t in zip(*pairs)]

    for i in range(0, len(srcs), batch_size):
        x = torch.tensor(srcs[i: i + batch_size])
        y = torch.tensor(tgts[i: i + batch_size])
        src = Variable(x, requires_grad=False)
        tgt = Variable(y, requires_grad=False)
        src_mask, tgt_mask = make_std_mask(src, tgt, -1)
        yield Batch(src, tgt, src_mask, tgt_mask)

In [61]:
with open('trajs.csv', 'rb') as f:
    trajs_train = np.load(f)
    
with open('labels.csv', 'rb') as f:
    labels_train = np.load(f)
    
with open('trajs_test.csv', 'rb') as f:
    trajs_test = np.load(f)
    
with open('labels_test.csv', 'rb') as f:
    labels_test = np.load(f)


In [62]:

N, d, T = trajs_train.shape 

trajs_ = np.zeros((N, d * T + 2))

for i in range(N):
    for j in range(T):
        trajs_[i , d * j + 2] = trajs_train[i, 0, j]
        trajs_[i , d * j + 1 + 2] = trajs_train[i, 1, j]
        
trajs_train = trajs_
 

In [63]:

N, d, T = trajs_test.shape 

trajs_ = np.zeros((N, d * T + 2))

for i in range(N):
    for j in range(T):
        trajs_[i , d * j + 2] = trajs_test[i, 0, j]
        trajs_[i , d * j + 1 + 2] = trajs_test[i, 1, j]

trajs_test = trajs_

In [64]:

N, _, T = labels_train.shape

labels_ = np.zeros((N, T + 1))

labels_[:, 1:] =  labels_train[:,0, :]

labels_train = labels_


In [65]:

N, _, T = labels_test.shape

labels_ = np.zeros((N, T + 1))

labels_[:, 1:] =  labels_test[:,0, :]

labels_test = labels_


In [119]:

model = make_model(2, 1, N=6, d_model=256, d_ff=1023, h=4, dropout=0.1)
model_opt = get_std_opt(model)

criterion = nn.MSELoss()
criterion.to(device)

model.to(device)

  nn.init.xavier_uniform(p)


EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=256, out_features=256, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=256, out_features=1023, bias=True)
          (w_2): Linear(in_features=1023, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0-1): 2 x SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0-5): 6 x DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=256, out_

In [120]:
def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 11,055,605 trainable parameters


In [121]:
BATCH_SIZE = 2

In [122]:
from tqdm import tqdm

In [129]:
epoches = 100

model.to(device)
for epoch in tqdm(range(epoches)):
    print(epoch,' epoch')
    train_epoch(data_iterator(trajs_train, labels_train, batch_size=BATCH_SIZE), model, criterion, model_opt)
    valid_epoch(data_iterator(trajs_test, labels_test, batch_size=BATCH_SIZE), model, criterion)

  0%|          | 0/100 [00:00<?, ?it/s]

0  epoch
test loss:  12.353393045719713


  1%|          | 1/100 [02:06<3:28:56, 126.63s/it]

validation loss:  8.596023071557283
1  epoch
test loss:  11.056173251302198


  2%|▏         | 2/100 [04:14<3:27:38, 127.13s/it]

validation loss:  10.31643023295328
2  epoch
test loss:  11.220582989008108


  3%|▎         | 3/100 [06:22<3:26:29, 127.72s/it]

validation loss:  13.177885972614604
3  epoch
test loss:  3.457066885428503


  4%|▍         | 4/100 [08:31<3:24:59, 128.12s/it]

validation loss:  10.35844816826284
4  epoch
test loss:  8.790072169620544


  5%|▌         | 5/100 [10:40<3:23:15, 128.37s/it]

validation loss:  9.960880408063531
5  epoch
test loss:  10.214355662232265


  6%|▌         | 6/100 [12:50<3:22:12, 129.07s/it]

validation loss:  11.085173905827105
6  epoch
test loss:  3.305491986684501


  7%|▋         | 7/100 [15:00<3:20:29, 129.34s/it]

validation loss:  6.901904291473329
7  epoch
test loss:  10.644994601141661


  8%|▊         | 8/100 [17:10<3:18:27, 129.43s/it]

validation loss:  10.698151991004124
8  epoch
test loss:  14.68811470712535


  9%|▉         | 9/100 [19:20<3:16:43, 129.71s/it]

validation loss:  8.793492661206983
9  epoch
test loss:  8.018836976960301


 10%|█         | 10/100 [21:30<3:14:44, 129.83s/it]

validation loss:  16.426893593743443
10  epoch
test loss:  6.715057938359678


 11%|█         | 11/100 [23:40<3:12:31, 129.79s/it]

validation loss:  9.61343583650887
11  epoch
test loss:  11.097475216374733


 12%|█▏        | 12/100 [25:50<3:10:33, 129.92s/it]

validation loss:  10.20289945451077
12  epoch
test loss:  6.942453833937179


 13%|█▎        | 13/100 [28:00<3:08:21, 129.90s/it]

validation loss:  12.807464355486445
13  epoch
test loss:  11.553165655437624


 14%|█▍        | 14/100 [30:09<3:06:01, 129.78s/it]

validation loss:  7.777079862498795
14  epoch
test loss:  6.134802018976188


 15%|█▌        | 15/100 [32:19<3:03:40, 129.65s/it]

validation loss:  12.39368934907543
15  epoch
test loss:  10.255525258369744


 16%|█▌        | 16/100 [34:28<3:01:29, 129.63s/it]

validation loss:  11.353111563161292
16  epoch
test loss:  9.260426452587126


 17%|█▋        | 17/100 [36:38<2:59:26, 129.71s/it]

validation loss:  11.032842821383383
17  epoch
test loss:  10.977336845418904


 18%|█▊        | 18/100 [38:48<2:57:26, 129.84s/it]

validation loss:  9.933248282730347
18  epoch
test loss:  10.352165657590376


 19%|█▉        | 19/100 [40:58<2:55:17, 129.85s/it]

validation loss:  7.657032254708042
19  epoch
test loss:  15.292524384392891


 20%|██        | 20/100 [43:06<2:52:12, 129.16s/it]

validation loss:  11.458207691612188
20  epoch
test loss:  8.841270274395356


 21%|██        | 21/100 [45:10<2:48:03, 127.64s/it]

validation loss:  7.76948728644129
21  epoch
test loss:  6.728092382290924


 22%|██▏       | 22/100 [47:17<2:45:36, 127.39s/it]

validation loss:  7.232378176428028
22  epoch
test loss:  9.960133607391072


 23%|██▎       | 23/100 [49:26<2:44:25, 128.12s/it]

validation loss:  9.15230563381192
23  epoch
test loss:  8.097235742141493


 24%|██▍       | 24/100 [51:36<2:42:53, 128.59s/it]

validation loss:  9.446001133183017
24  epoch
test loss:  6.891320757858921


 25%|██▌       | 25/100 [53:46<2:41:04, 128.86s/it]

validation loss:  11.608550646164076
25  epoch
test loss:  7.3801003527478315


 26%|██▌       | 26/100 [55:56<2:39:24, 129.26s/it]

validation loss:  8.848377466645616
26  epoch
test loss:  8.29247495670279


 27%|██▋       | 27/100 [58:05<2:37:07, 129.14s/it]

validation loss:  9.833608186057972
27  epoch
test loss:  6.826167179387994


 28%|██▊       | 28/100 [1:00:15<2:35:15, 129.38s/it]

validation loss:  11.767677921583527
28  epoch
test loss:  6.966747706523165


 29%|██▉       | 29/100 [1:02:24<2:33:02, 129.32s/it]

validation loss:  7.334912696237552
29  epoch
test loss:  12.141112956368033


 30%|███       | 30/100 [1:04:32<2:30:27, 128.97s/it]

validation loss:  11.874406106924653
30  epoch
test loss:  11.161167355579892


 31%|███       | 31/100 [1:06:41<2:28:22, 129.03s/it]

validation loss:  6.8786172687341605
31  epoch
test loss:  6.3235927420755615


 32%|███▏      | 32/100 [1:08:51<2:26:31, 129.29s/it]

validation loss:  6.526527373876888
32  epoch
test loss:  8.828858579450753


 33%|███▎      | 33/100 [1:11:02<2:24:51, 129.73s/it]

validation loss:  8.091129322818233
33  epoch
test loss:  7.146202579140663


 34%|███▍      | 34/100 [1:13:11<2:22:38, 129.67s/it]

validation loss:  11.39460564305773
34  epoch
test loss:  6.702386062905134


 35%|███▌      | 35/100 [1:15:21<2:20:31, 129.71s/it]

validation loss:  11.66444567758299
35  epoch
test loss:  7.993433807481779


 36%|███▌      | 36/100 [1:17:30<2:18:13, 129.59s/it]

validation loss:  9.291380362919881
36  epoch
test loss:  12.793355400644941


 37%|███▋      | 37/100 [1:19:40<2:16:06, 129.63s/it]

validation loss:  13.501971273117988
37  epoch
test loss:  13.242958605871536


 38%|███▊      | 38/100 [1:21:50<2:13:59, 129.68s/it]

validation loss:  8.399518037040252
38  epoch
test loss:  11.41202436198364


 39%|███▉      | 39/100 [1:24:00<2:11:52, 129.71s/it]

validation loss:  9.260608546275762
39  epoch
test loss:  9.413220109809117


 40%|████      | 40/100 [1:26:09<2:09:42, 129.71s/it]

validation loss:  6.279893101297375
40  epoch
test loss:  5.643490587448468


 41%|████      | 41/100 [1:28:19<2:07:29, 129.65s/it]

validation loss:  9.110919773302157
41  epoch
test loss:  10.420517732040025


 42%|████▏     | 42/100 [1:30:28<2:05:08, 129.46s/it]

validation loss:  11.420408000325551
42  epoch
test loss:  8.154050631805148


 43%|████▎     | 43/100 [1:32:37<2:02:58, 129.45s/it]

validation loss:  6.852076141127327
43  epoch
test loss:  5.8432140086315485


 44%|████▍     | 44/100 [1:34:47<2:00:58, 129.62s/it]

validation loss:  10.492352018136444
44  epoch
test loss:  13.23743013391504


 45%|████▌     | 45/100 [1:36:57<1:58:44, 129.54s/it]

validation loss:  10.91836094937753
45  epoch
test loss:  7.87130886876389


 46%|████▌     | 46/100 [1:39:04<1:56:04, 128.97s/it]

validation loss:  7.929084602237708
46  epoch
test loss:  7.735694794188021


 47%|████▋     | 47/100 [1:41:14<1:54:12, 129.28s/it]

validation loss:  9.940298016706947
47  epoch
test loss:  8.977290027833078


 48%|████▊     | 48/100 [1:43:24<1:52:07, 129.37s/it]

validation loss:  9.898934865836054
48  epoch
test loss:  10.565220372947806


 49%|████▉     | 49/100 [1:45:33<1:50:00, 129.41s/it]

validation loss:  8.345117092012515
49  epoch
test loss:  9.828887719864724


 50%|█████     | 50/100 [1:47:43<1:47:53, 129.48s/it]

validation loss:  10.353778541306383
50  epoch
test loss:  7.587609370631185


 51%|█████     | 51/100 [1:49:53<1:45:50, 129.59s/it]

validation loss:  6.63427493666677
51  epoch
test loss:  9.610184907971416


 52%|█████▏    | 52/100 [1:52:03<1:43:45, 129.70s/it]

validation loss:  8.752376186457695
52  epoch
test loss:  10.998833653228303


 53%|█████▎    | 53/100 [1:54:13<1:41:46, 129.93s/it]

validation loss:  10.765328848560785
53  epoch
test loss:  6.187025467567764


 54%|█████▍    | 54/100 [1:56:23<1:39:36, 129.92s/it]

validation loss:  5.51660592877306
54  epoch
test loss:  7.949628948961617


 55%|█████▌    | 55/100 [1:58:33<1:37:28, 129.98s/it]

validation loss:  7.839237665757537
55  epoch
test loss:  8.97211130827418


 56%|█████▌    | 56/100 [2:00:43<1:35:16, 129.91s/it]

validation loss:  11.18483519590518
56  epoch
test loss:  12.71942451033101


 57%|█████▋    | 57/100 [2:02:52<1:32:55, 129.66s/it]

validation loss:  7.986456657657982
57  epoch
test loss:  7.086704564550018


 58%|█████▊    | 58/100 [2:05:02<1:30:48, 129.73s/it]

validation loss:  7.285875369521818
58  epoch
test loss:  10.991212853754405


 59%|█████▉    | 59/100 [2:07:12<1:28:40, 129.77s/it]

validation loss:  10.225920750002842
59  epoch
test loss:  12.693123579539133


 60%|██████    | 60/100 [2:09:21<1:26:24, 129.62s/it]

validation loss:  7.7961673803769145
60  epoch
test loss:  8.218163544079289


 61%|██████    | 61/100 [2:11:30<1:24:00, 129.25s/it]

validation loss:  8.287733358331025
61  epoch
test loss:  13.510279802925652


 62%|██████▏   | 62/100 [2:13:40<1:22:00, 129.47s/it]

validation loss:  6.036419651712507
62  epoch
test loss:  8.158580099101528


 63%|██████▎   | 63/100 [2:15:50<1:20:06, 129.90s/it]

validation loss:  8.772586571762076
63  epoch
test loss:  12.916507017310323


 64%|██████▍   | 64/100 [2:18:00<1:17:55, 129.87s/it]

validation loss:  11.44056102165814
64  epoch
test loss:  15.480038380832411


 65%|██████▌   | 65/100 [2:20:09<1:15:36, 129.61s/it]

validation loss:  8.769667258806294
65  epoch
test loss:  11.376756542875228


 66%|██████▌   | 66/100 [2:22:19<1:13:25, 129.56s/it]

validation loss:  9.264590805796615
66  epoch
test loss:  8.332988584967097


 67%|██████▋   | 67/100 [2:24:28<1:11:09, 129.37s/it]

validation loss:  6.907776487831143
67  epoch
test loss:  9.375053485545777


 68%|██████▊   | 68/100 [2:26:38<1:09:06, 129.59s/it]

validation loss:  9.369046247033111
68  epoch
test loss:  8.86476973623212


 69%|██████▉   | 69/100 [2:28:48<1:06:59, 129.67s/it]

validation loss:  12.818123650813504
69  epoch
test loss:  11.767028791036978


 70%|███████   | 70/100 [2:30:57<1:04:51, 129.72s/it]

validation loss:  8.742550191134796
70  epoch
test loss:  5.122460890739603


 71%|███████   | 71/100 [2:33:03<1:02:04, 128.43s/it]

validation loss:  11.45193360222811
71  epoch
test loss:  8.130373724694437


 72%|███████▏  | 72/100 [2:35:12<1:00:00, 128.58s/it]

validation loss:  10.218715263716149
72  epoch
test loss:  10.921623515285319


 73%|███████▎  | 73/100 [2:37:22<58:03, 129.03s/it]  

validation loss:  10.575481093677809
73  epoch
test loss:  9.83635111064359


 74%|███████▍  | 74/100 [2:39:32<56:01, 129.28s/it]

validation loss:  7.751011738408124
74  epoch
test loss:  8.648124794330215


 75%|███████▌  | 75/100 [2:41:41<53:55, 129.42s/it]

validation loss:  9.032896298165724
75  epoch
test loss:  11.442646903542482


 76%|███████▌  | 76/100 [2:43:52<51:51, 129.63s/it]

validation loss:  10.714606296816783
76  epoch
test loss:  10.889110377014731


 77%|███████▋  | 77/100 [2:46:02<49:45, 129.80s/it]

validation loss:  11.589744335462456
77  epoch
test loss:  11.717588716654518


 78%|███████▊  | 78/100 [2:48:11<47:34, 129.74s/it]

validation loss:  8.660244701544798
78  epoch
test loss:  8.421423095747741


 79%|███████▉  | 79/100 [2:50:21<45:26, 129.85s/it]

validation loss:  7.580859518318903
79  epoch
test loss:  6.202185493792058


 80%|████████  | 80/100 [2:52:32<43:22, 130.13s/it]

validation loss:  7.00258925277376
80  epoch
test loss:  11.171329832552146


 81%|████████  | 81/100 [2:54:42<41:11, 130.07s/it]

validation loss:  10.642682166955638
81  epoch
test loss:  11.043448154669022


 82%|████████▏ | 82/100 [2:56:52<38:59, 130.00s/it]

validation loss:  6.487750699103344
82  epoch
test loss:  10.77088086997901


 83%|████████▎ | 83/100 [2:59:02<36:50, 130.00s/it]

validation loss:  9.934893680529058
83  epoch
test loss:  9.545269100060978


 84%|████████▍ | 84/100 [3:01:12<34:38, 129.92s/it]

validation loss:  9.189189942708254
84  epoch
test loss:  11.132458502750978


 85%|████████▌ | 85/100 [3:03:20<32:23, 129.55s/it]

validation loss:  8.009979710772313
85  epoch
test loss:  8.755572279065746


 86%|████████▌ | 86/100 [3:05:30<30:12, 129.44s/it]

validation loss:  10.096502200649411
86  epoch
test loss:  11.182023413377465


 87%|████████▋ | 87/100 [3:07:40<28:04, 129.60s/it]

validation loss:  10.640798848629856
87  epoch
test loss:  6.529343548172619


 88%|████████▊ | 88/100 [3:09:49<25:55, 129.64s/it]

validation loss:  5.317972214639667
88  epoch
test loss:  12.782722738764278


 89%|████████▉ | 89/100 [3:11:59<23:45, 129.59s/it]

validation loss:  9.118426626093424
89  epoch
test loss:  10.586911951730144


 90%|█████████ | 90/100 [3:14:09<21:37, 129.78s/it]

validation loss:  7.149349801557037
90  epoch
test loss:  10.174420150280639


 91%|█████████ | 91/100 [3:16:19<19:28, 129.79s/it]

validation loss:  8.839564019504792
91  epoch
test loss:  11.656854083465078


 92%|█████████▏| 92/100 [3:18:28<17:16, 129.53s/it]

validation loss:  11.371104425838894
92  epoch
test loss:  4.430431308566767


 93%|█████████▎| 93/100 [3:20:37<15:05, 129.33s/it]

validation loss:  7.88724458944489
93  epoch
test loss:  11.693159442933393


 94%|█████████▍| 94/100 [3:22:46<12:56, 129.47s/it]

validation loss:  7.486525497598905
94  epoch
test loss:  11.9001499898568


 95%|█████████▌| 95/100 [3:24:56<10:47, 129.54s/it]

validation loss:  11.944157821017143
95  epoch
test loss:  13.51358822282782


 96%|█████████▌| 96/100 [3:27:06<08:38, 129.66s/it]

validation loss:  7.6344971217567945
96  epoch
test loss:  8.174571157563378


 97%|█████████▋| 97/100 [3:29:16<06:28, 129.61s/it]

validation loss:  10.510656470322829
97  epoch
test loss:  8.213754967611749


 98%|█████████▊| 98/100 [3:31:25<04:18, 129.41s/it]

validation loss:  10.066652244742727
98  epoch
test loss:  7.347303916298188


 99%|█████████▉| 99/100 [3:33:34<02:09, 129.48s/it]

validation loss:  5.1578006982226725
99  epoch
test loss:  11.085944295704394


100%|██████████| 100/100 [3:35:44<00:00, 129.45s/it]

validation loss:  11.106605242477599





# _______

In [130]:
def prediction(model, src, src_mask, max_len, start_symbol):
    
    memory = model.encode(src.to(device), src_mask.to(device))
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    
    pred = torch.zeros(src.shape)

    for i in range(max_len-1):
        out = model.decode(memory, src_mask.to(device),
                           Variable(ys).to(device),
                           Variable(subsequent_mask(ys.size(1))
                                    .type_as(src.data).to(device)))
        gen = model.generator(out[:, -1])
        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(*gen[0])], dim=1)
    return ys


In [135]:
n = 21

for batch in data_iterator(trajs_test[n:n+1], labels_test[n:n+1], batch_size=1):
    
    src, trg, src_mask, trg_mask = \
            batch.src.to(device), batch.trg.to(device), batch.src_mask.to(device), batch.trg_mask.to(device)


In [140]:
model.eval()
model.to(device)
max_len = 201
start_symbol = 0

pred = prediction(model, src, src_mask, max_len, start_symbol)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 10.90 GiB of which 1.31 MiB is free. Including non-PyTorch memory, this process has 10.90 GiB memory in use. Of the allocated memory 10.64 GiB is allocated by PyTorch, and 85.98 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

100 эпох

In [133]:
pred

tensor([[0.0000, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0041, 0.0041,
         0.0041, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0041, 0.0041, 0.0041,
         0.0041, 0.0041, 0.0041, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040, 0.0040,
         0.0040, 0.0040, 0.0

In [134]:
trg

tensor([[0.0000, 1.2000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000,
         0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000,
         0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000,
         0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000,
         0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 1.2000, 1.2000, 0.7000, 0.7000,
         0.7000, 0.7000, 0.7000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
         1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
         1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
         1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
         1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
         1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
         1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
         1.2000, 1.2000, 1.2

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

pred = pred.detach()

f, axes = plt.subplots(1, 1, sharex=True, figsize=(14, 5))



axes.plot(np.linspace(0, 200, 200), pred, label='prediction')
axes.plot(x_np, x.grad, label=function_name)

x.grad.zero_()

### END Solution (do not delete this comment)

axes[i, 0].legend()
axes[i, 1].legend()

plt.tight_layout()
plt.show()