In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import causal_convolution_layer
import Dataloader

In [2]:
import math
from torch.utils.data import DataLoader

In [3]:
class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    A table of shape ``(max_len, 1, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold
    ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: size of the last (feature) dimension of the inputs. Both even
            and odd values are supported.
        dropout: dropout probability applied after the encodings are added.
        max_len: number of positions in the precomputed table; inputs passed to
            ``forward`` must not be longer than this along dimension 0.
    """

    def __init__(self, d_model=128, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # One column of angles per sine column: (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns,
        # so drop the surplus angle column instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Stored as (max_len, 1, d_model) so it broadcasts over dimension 1 of the input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        Args:
            x: tensor whose dimension 0 indexes position and whose last
               dimension has size ``d_model``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [4]:
class TransformerTimeSeries(torch.nn.Module):
    """Encoder-decoder transformer that maps (covariate, value) sequences to one output per future step.

    The observed window is fed to the transformer encoder and the future
    window to the decoder. Both are first embedded by
    ``causal_convolution_layer.context_embedding`` and then given positional
    encodings computed over the concatenated (observed + future) sequence.

    Args:
        d_model: embedding width shared by the convolutional embedding, the
            positional encoding, the transformer and the output layer.
        kernel_size: third positional argument forwarded to
            ``context_embedding`` (presumably the convolution kernel width —
            confirm against ``causal_convolution_layer``).
        nhead: number of attention heads in ``torch.nn.Transformer``.
        num_encoder_layers: number of encoder layers in ``torch.nn.Transformer``.
        pos_dropout: dropout probability used by ``PositionalEncoding``.

    The defaults reproduce the previously hard-coded configuration.
    """

    def __init__(self, d_model=128, kernel_size=5, nhead=16, num_encoder_layers=12, pos_dropout=0.1):
        super(TransformerTimeSeries, self).__init__()
        # The embedding always receives 2 input channels: forward() stacks y and x.
        self.input_embedding = causal_convolution_layer.context_embedding(2, d_model, kernel_size)
        self.transformer_model = torch.nn.Transformer(nhead=nhead,
                                                      num_encoder_layers=num_encoder_layers,
                                                      d_model=d_model)
        self.positional_embedding = PositionalEncoding(d_model, pos_dropout)
        self.fc1 = torch.nn.Linear(d_model, 1)

    def forward(self, x_obs, x_future, y_obs, y_future, attention_masks):
        """Run the model on one batch.

        Args:
            x_obs, y_obs: covariate and value tensors for the observed window
                (assumed to be (batch, n_obs) — TODO confirm against the dataset).
            x_future, y_future: covariate and value tensors for the future window
                (assumed to be (batch, n_future) — TODO confirm).
            attention_masks: passed unchanged as ``tgt_mask`` to the transformer.

        Returns:
            ``fc1`` applied to the transformer output after swapping its first
            two dimensions; the last dimension has size 1.
        """
        # Stack value and covariate as two channels along dim 1.
        z_obs = torch.cat((y_obs.unsqueeze(1), x_obs.unsqueeze(1)), 1)
        z_future = torch.cat((y_future.unsqueeze(1), x_future.unsqueeze(1)), 1)

        # Per the original author's note the embedding returns
        # (batch, embedding, time); move time to the front for the transformer.
        z_obs_embedding = self.input_embedding(z_obs).permute(2, 0, 1)
        z_future_embedding = self.input_embedding(z_future).permute(2, 0, 1)

        # Encode positions over the concatenated sequence so future steps
        # continue the position index of the observed steps, then split again.
        n_obs = z_obs_embedding.shape[0]
        positional_embeddings = self.positional_embedding(
            torch.cat((z_obs_embedding, z_future_embedding), 0))
        z_obs_embedding = positional_embeddings[:n_obs]
        z_future_embedding = positional_embeddings[n_obs:]

        transformer_embedding = self.transformer_model(z_obs_embedding,
                                                       z_future_embedding,
                                                       tgt_mask=attention_masks)

        # Swap the first two dims back before projecting each step to a scalar.
        output = self.fc1(transformer_embedding.permute(1, 0, 2))

        return output
        

In [5]:
# Scratch tensor: a single standard-normal sample.
# NOTE(review): `x` does not appear to be used by any later cell shown here.
x = torch.randn(1)

In [6]:
# Build the training split first, then the validation split.
# Both use 96 as the first argument; the second argument differs
# (4500 vs 500) and matches the leading dimension printed by the loader.
train_dataset, validation_dataset = (
    Dataloader.time_series_paper(96, n_series) for n_series in (4500, 500)
)

x: 4500*120 fx: 4500*120
x: 500*120 fx: 500*120


In [7]:
# Mean-squared-error objective, averaged over all elements of a batch.
criterion = torch.nn.MSELoss(reduction='mean')

In [8]:
# Training batches are reshuffled every epoch; validation keeps dataset order.
train_dl = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
validation_dl = DataLoader(dataset=validation_dataset, batch_size=64, shuffle=False)

In [9]:
# Place the model on the GPU when one is available and fall back to the CPU
# otherwise, instead of raising on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerTimeSeries().to(device)

In [10]:
# Plain SGD; the scheduler multiplies the learning rate by 0.95 after every
# 10 calls to scheduler.step().
lr = 5e-5  # learning rate
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.95)

In [11]:
# Number of full passes over the training DataLoader made by the training loop.
epochs = 50

In [12]:
# Train for `epochs` passes over train_dl, evaluating on validation_dl after
# each pass and stepping the learning-rate scheduler once per epoch.

# Send each batch to wherever the model's parameters live.
model_device = next(model.parameters()).device


def batch_loss(batch):
    """Run the model on one DataLoader batch and return its MSE loss.

    The output at step t is compared with the target at step t + 1, so the
    last output and the first target are dropped.
    """
    x_obs, x_future, y_obs, y_future, attention_masks = (t.to(model_device) for t in batch)
    # Only the first mask of the batch is passed on, as a single tgt_mask.
    output = model(x_obs, x_future, y_obs, y_future, attention_masks[0])
    # squeeze(-1) removes only the trailing size-1 output dim, so a batch
    # containing a single sample keeps its batch dimension.
    return criterion(output.squeeze(-1)[:, :-1], y_future[:, 1:])


for epoch in range(epochs):
    model.train()
    train_loss = []
    for batch in train_dl:
        optimizer.zero_grad()
        loss = batch_loss(batch)
        loss.backward()
        optimizer.step()
        # Store a Python float: keeping the tensor would retain its autograd
        # graph (and device memory) for the whole epoch.
        train_loss.append(loss.item())

    model.eval()
    eval_loss = []
    with torch.no_grad():
        for batch in validation_dl:
            eval_loss.append(batch_loss(batch).item())

    # Reported values are the mean of the per-batch mean losses.
    print("Epoch {}: Train loss-{} \t Validation loss-{}".format(epoch,
                                                                 sum(train_loss) / len(train_dl),
                                                                 sum(eval_loss) / len(validation_dl)))

    scheduler.step()



Epoch 0: Train loss-1463.03955078125 	 Validation loss-965.7162475585938
Epoch 1: Train loss-531.7294311523438 	 Validation loss-104.82935333251953
Epoch 2: Train loss-81.66537475585938 	 Validation loss-67.85919952392578
Epoch 3: Train loss-48.09587860107422 	 Validation loss-21.851825714111328
Epoch 4: Train loss-35.00532531738281 	 Validation loss-43.44219970703125
Epoch 5: Train loss-30.521230697631836 	 Validation loss-32.069358825683594
Epoch 6: Train loss-25.60594367980957 	 Validation loss-14.210101127624512
Epoch 7: Train loss-22.631959915161133 	 Validation loss-39.948360443115234
Epoch 8: Train loss-20.195777893066406 	 Validation loss-38.40398025512695
Epoch 9: Train loss-18.737388610839844 	 Validation loss-17.837112426757812
Epoch 10: Train loss-16.97298240661621 	 Validation loss-7.598769187927246
Epoch 11: Train loss-15.455190658569336 	 Validation loss-8.019545555114746
Epoch 12: Train loss-15.016799926757812 	 Validation loss-42.82564163208008
Epoch 13: Train loss-14.

KeyboardInterrupt: 

## Visualize the predictions against the validation data

In [None]:
# Plot the first sample of up to `n_plots` validation batches: ground truth
# (observed + future) in green, model output shifted by one step on top.
n_plots = 10
plot_device = next(model.parameters()).device


def to_numpy(tensor):
    """Detach a tensor, move it to the CPU, drop size-1 dims and convert to NumPy."""
    return tensor.detach().cpu().squeeze().numpy()


model.eval()
with torch.no_grad():
    for step, (x_obs, x_future, y_obs, y_future, attention_masks) in enumerate(validation_dl):
        # Check the budget before the forward pass so exactly n_plots figures
        # are produced and no batch is evaluated only to be discarded.
        if step >= n_plots:
            break

        output = model(x_obs.to(plot_device), x_future.to(plot_device),
                       y_obs.to(plot_device), y_future.to(plot_device),
                       attention_masks[0].to(plot_device))

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(to_numpy(x_future[0]), to_numpy(y_future[0]), 'g-', label='ground truth')
        ax.plot(to_numpy(x_obs[0]), to_numpy(y_obs[0]), 'g-')

        # The output at step t is plotted at x[t + 1], matching the one-step
        # shift used when computing the loss.
        ax.plot(to_numpy(x_future[0])[1:], to_numpy(output[0])[:-1], label='prediction')

        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title('Validation batch {}, sample 0'.format(step))
        ax.legend()
        plt.show()
        # Release the figure so looping over many batches does not accumulate memory.
        plt.close(fig)