In [1]:
import math
from typing import Tuple
import time
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset, DataLoader

In [43]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("/home/jdli/TransSpectra/")

from transformer import TransformerReg,generate_square_subsequent_mask


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [141]:
from data import GaiaXPlabel_forcast
from torch.utils.data import DataLoader

data_dir = "/data/jdli/gaia/"
tr_file = "ap17_wise_xpcont_cut.npy"

device = torch.device('cuda:1')
TOTAL_NUM = 1000
BATCH_SIZE = 64

gdata  = GaiaXPlabel_forcast(data_dir+tr_file, total_num=TOTAL_NUM, part_train=True, device=device)

val_size = int(0.1*len(gdata))
A_size = int(0.5*(len(gdata)-val_size))
B_size = len(gdata) - A_size - val_size

A_dataset, B_dataset, val_dataset = torch.utils.data.random_split(gdata, [A_size, B_size, val_size], generator=torch.Generator().manual_seed(42))

print(len(A_dataset), len(B_dataset), len(val_dataset))

A_loader = DataLoader(A_dataset, batch_size=BATCH_SIZE)
B_loader = DataLoader(B_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=2048)

450 450 100


In [144]:
model = TransformerReg(
    dim_val=64, input_size=1,
    batch_first=True, 
    enc_seq_len=30, dec_seq_len=4,
    out_seq_len=4, 
    n_decoder_layers=2,
    n_encoder_layers=2, 
    n_heads=4,
    max_seq_len=30,
).to(device)


# dim_val = 64 # This can be any value divisible by n_heads. 512 is used in the original transformer paper.
# n_heads = 4 # The number of attention heads (aka parallel attention layers). dim_val must be divisible by this number
# n_decoder_layers = 2 # Number of times the decoder layer is stacked in the decoder
# n_encoder_layers = 2 # Number of times the encoder layer is stacked in the encoder
# input_size = 1 # The number of input variables. 1 if univariate forecasting.
# enc_seq_len = 8575 # length of input given to encoder. Can have any integer value.
# dec_seq_len = 2 # length of input given to decoder. Can have any integer value.
# output_sequence_length = 2 # Length of the target sequence, i.e. how many time steps should your forecast
# max_seq_len = 8575 # What's the longest sequence the model will encounter? Used to make the positional encoder
# model = TransformerReg(dim_val=dim_val, input_size=input_size, 
#                     batch_first=True, dec_seq_len=dec_seq_len, 
#                     out_seq_len=output_sequence_length, n_decoder_layers=n_decoder_layers,
#                     n_encoder_layers=n_encoder_layers, n_heads=n_heads,
#                     max_seq_len=max_seq_len,
#                     ).to(device)

In [90]:
def infer(model: nn.Module, src: torch.Tensor, forecast_window:int,
          device,) -> torch.Tensor:
    
    target_seq_dim = 1
    tgt = src[:,-1,0].view(-1, 1, 1) # [bs, 1, 1]
    print(tgt.size())
    # Iteratively concatenate tgt with the first element in the prediction
    for _ in range(forecast_window-1):
        
        dim_a = tgt.shape[1] #1,2,3,.. n
        dim_b = src.shape[1] #30
        
        src_mask = generate_square_subsequent_mask(dim1=dim_a, dim2=dim_b).to(device)
        tgt_mask = generate_square_subsequent_mask(dim1=dim_a, dim2=dim_a).to(device)
        
        prediction = model(src, tgt, src_mask, tgt_mask)

        # Obtain the predicted value at t+1 where t is the last step 
        # represented in tgt
        last_predicted_value = prediction[:,-1,:].view(-1,1,1) #[bs, 1]
        
        # Reshape from [batch_size, 1] --> [1, batch_size, 1]
        # last_predicted_value = last_predicted_value
        # print(last_predicted_value.detach(), last_predicted_value.size())
        print(tgt.size())
        # Detach the predicted element from the graph and concatenate with 
        # tgt in dimension 1 or 0
        tgt = torch.cat((tgt, last_predicted_value), dim=target_seq_dim)
    
    src_mask = generate_square_subsequent_mask(dim1=4, dim2=30).to(device)
    tgt_mask = generate_square_subsequent_mask(dim1=4, dim2=4).to(device)
    
    # Make final prediction
    return model(src, tgt, src_mask, tgt_mask)

src = torch.rand(5, 30, 1).to(device)
infer(model=model, src=src, forecast_window=4, device=device)

torch.Size([5, 1, 1])
torch.Size([5, 1, 1])
torch.Size([5, 2, 1])
torch.Size([5, 3, 1])


tensor([[[ 0.4758],
         [ 0.1749],
         [ 0.6493],
         [ 0.1224]],

        [[-0.0355],
         [ 0.1663],
         [ 0.0221],
         [ 0.2551]],

        [[ 0.0329],
         [ 0.5259],
         [ 0.5337],
         [ 0.3079]],

        [[-0.0866],
         [ 0.2090],
         [ 0.2769],
         [ 0.4881]],

        [[ 0.9540],
         [ 0.4800],
         [ 0.6724],
         [ 0.2874]]], device='cuda:1', grad_fn=<ViewBackward0>)

In [147]:
enc_seq_len = 30
out_seq_len = 4

src_mask = generate_square_subsequent_mask(
dim1=4, dim2=30
).to(device)

tgt_mask = generate_square_subsequent_mask( 
    dim1=4, dim2=4
    ).to(device)


def train_epoch(tr_loader):
    # model.train()
    model.train()
    # model_mohaom.train()
    total_loss = 0.
    start = time.time()
    for batch, data in enumerate(tr_loader):

        output = model(data['x'].view(-1,enc_seq_len,1), 
        data['y'].view(-1,out_seq_len,1), src_mask, tgt_mask)
        loss = cost(output, data['y'].view(-1,out_seq_len,1))
        loss_value = loss.item()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        total_loss+=loss_value
        del data, output

    end = time.time()
    print(f"Epoch #%d tr loss:%.4f time:%.2f s"%(epoch, total_loss/(batch+1), (end-start)))


def eval(val_loader):
    model.eval()
    with torch.no_grad():
        total_val_loss = 0
        for bs, data in enumerate(val_loader):
            # output = model(data['x'], data['tgt'])
            output = infer(model=model, src=data['x'].view(-1,enc_seq_len,1), forecast_window=4, device=device)

            loss = cost(output, data['y'].view(-1,out_seq_len,1))
            total_val_loss+=loss.item()

    print("val loss:%.4f"%(total_val_loss/(bs+1)))

    # if epoch%5==0:
    #     eval(val_loader)
for epoch in range(num_epochs+1):
    train_epoch(A_loader)

Epoch #0 tr loss:0.9180 time:1.35 s
Epoch #1 tr loss:0.9523 time:1.35 s
Epoch #2 tr loss:0.9661 time:1.34 s
Epoch #3 tr loss:0.9332 time:1.34 s
Epoch #4 tr loss:0.9639 time:1.34 s
Epoch #5 tr loss:0.9463 time:1.34 s
Epoch #6 tr loss:0.9303 time:1.34 s
Epoch #7 tr loss:0.9430 time:1.36 s
Epoch #8 tr loss:0.9249 time:1.37 s
Epoch #9 tr loss:0.9350 time:1.38 s
Epoch #10 tr loss:0.9422 time:1.39 s
Epoch #11 tr loss:0.9454 time:1.37 s
Epoch #12 tr loss:0.9725 time:1.34 s
Epoch #13 tr loss:0.9236 time:1.36 s
Epoch #14 tr loss:0.9636 time:1.36 s
Epoch #15 tr loss:0.9746 time:1.34 s
Epoch #16 tr loss:0.9183 time:1.34 s
Epoch #17 tr loss:0.9343 time:1.36 s
Epoch #18 tr loss:0.9871 time:1.35 s
Epoch #19 tr loss:0.9474 time:1.36 s
Epoch #20 tr loss:0.9194 time:1.34 s


In [150]:
src = A_dataset[0]['x'].view(-1,enc_seq_len,1)

infer(model=model, src=src, forecast_window=4, device=device)

torch.Size([1, 1, 1])
torch.Size([1, 1, 1])
torch.Size([1, 2, 1])
torch.Size([1, 3, 1])


tensor([[[0.2787],
         [0.4630],
         [0.4972],
         [0.6162]]], device='cuda:1', grad_fn=<ViewBackward0>)

In [151]:
A_dataset[0]['y']

tensor([ 0.6115,  1.1501, -0.5797,  0.0771], device='cuda:1')

In [155]:
model(src, A_dataset[0]['y'].view(-1,4,1),
      src_mask, tgt_mask
      )

tensor([[[0.5737],
         [0.0361],
         [0.3946],
         [0.8979]]], device='cuda:1', grad_fn=<ViewBackward0>)

In [157]:
model(src, A_dataset[0]['y'].view(-1,4,1),
      )

tensor([[[ 0.3979],
         [-0.1271],
         [ 0.4941],
         [ 0.5222]]], device='cuda:1', grad_fn=<ViewBackward0>)

In [156]:
 A_dataset[0]['y'].view(-1,4,1)

tensor([[[ 0.6115],
         [ 1.1501],
         [-0.5797],
         [ 0.0771]]], device='cuda:1')

In [127]:
from data import APerr

data_dir = "/data/jdli/sdss/"
tr_file = "hogg19_spec_tr.npy"

device = torch.device('cuda:0')
TOTAL_NUM = 6000
BATCH_SIZE = 1

aspcap  = APerr(data_dir+tr_file, total_num=TOTAL_NUM,
part_train=True, device=device)

# aspcap_val = ASPCAP(data_dir+val_file, device=device)
train_size = int(0.5*len(aspcap))
val_size = len(aspcap) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(aspcap, [train_size, val_size], generator=torch.Generator().manual_seed(42))
print(len(train_dataset), len(val_dataset))

# tr_loader  = DataLoader(train_dataset, batch_size=BATCH_SIZE, )
# val_loader = DataLoader(val_dataset,  batch_size=BATCH_SIZE, )
tr_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, )

3000 3000


In [128]:
dim_val = 64 # This can be any value divisible by n_heads. 512 is used in the original transformer paper.
n_heads = 4 # The number of attention heads (aka parallel attention layers). dim_val must be divisible by this number
n_decoder_layers = 2 # Number of times the decoder layer is stacked in the decoder
n_encoder_layers = 2 # Number of times the encoder layer is stacked in the encoder
input_size = 1 # The number of input variables. 1 if univariate forecasting.
enc_seq_len = 8575 # length of input given to encoder. Can have any integer value.
dec_seq_len = 2 # length of input given to decoder. Can have any integer value.
output_sequence_length = 2 # Length of the target sequence, i.e. how many time steps should your forecast
max_seq_len = 8575 # What's the longest sequence the model will encounter? Used to make the positional encoder
model = TransformerReg(dim_val=dim_val, input_size=input_size, 
                    batch_first=True, dec_seq_len=dec_seq_len, 
                    out_seq_len=output_sequence_length, n_decoder_layers=n_decoder_layers,
                    n_encoder_layers=n_encoder_layers, n_heads=n_heads,
                    max_seq_len=max_seq_len,
                    ).to(device)

# criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
total_loss = 0.
num_epochs = 20
# num_batches = train_size//BATCH_SIZE
itr = 1
num_iters  = 50

src_mask = generate_square_subsequent_mask(
dim1=2, dim2=8575
).to(device)

tgt_mask = generate_square_subsequent_mask(
    dim1=2, dim2=2
).to(device)


for epoch in range(num_epochs):
    model.train()
    
    total_loss = 0
    for batch, (x, y) in enumerate(tr_loader):
        start = time.time()

        output = model(x, y, src_mask, tgt_mask)
        loss = cost(output, y)
        loss_value = loss.item()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        del x, y, output

    print(f"Epoch #%d tr loss:%.4f time:%.2f s"%(epoch, total_loss, (end-start)))