In [1]:
import math
from typing import Tuple
import time
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.utils.data import dataset, DataLoader

In [43]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("/home/jdli/TransSpectra/")

from transformer import TransformerReg,generate_square_subsequent_mask


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [50]:

from data import GaiaXPlabel_forcast
from torch.utils.data import DataLoader

data_dir = "/data/jdli/gaia/"
tr_file = "ap17_wise_xpcont_cut.npy"

device = torch.device('cuda:1')
TOTAL_NUM = 1000
BATCH_SIZE = 64

gdata  = GaiaXPlabel_forcast(data_dir+tr_file, total_num=TOTAL_NUM, part_train=True, device=device)

val_size = int(0.1*len(gdata))
A_size = int(0.5*(len(gdata)-val_size))
B_size = len(gdata) - A_size - val_size

A_dataset, B_dataset, val_dataset = torch.utils.data.random_split(gdata, [A_size, B_size, val_size], generator=torch.Generator().manual_seed(42))

print(len(A_dataset), len(B_dataset), len(val_dataset))

A_loader = DataLoader(A_dataset, batch_size=BATCH_SIZE)
B_loader = DataLoader(B_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=2048)

450 450 100


In [52]:
model = TransformerReg(
    dim_val=64, input_size=1,
    batch_first=True, 
    enc_seq_len=30, dec_seq_len=4,
    out_seq_len=4, 
    n_decoder_layers=2,
    n_encoder_layers=2, 
    n_heads=4,
    max_seq_len=30,
).to(device)

In [53]:
src = A_dataset[0]['x'].view(-1,30,1)
tgt = A_dataset[0]['tgt'].view(-1,4,1)

# Make src mask for decoder with size:
# [batch_size*n_heads, output_sequence_length, enc_seq_len]
src_mask = generate_square_subsequent_mask(
    dim1=4, dim2=30
    ).to(device)

# Make tgt mask for decoder with size:
# [batch_size*n_heads, output_sequence_length, output_sequence_length]
tgt_mask = generate_square_subsequent_mask( 
    dim1=4, dim2=4
    ).to(device)

In [90]:
def infer(model: nn.Module, src: torch.Tensor, forecast_window:int,
          device,) -> torch.Tensor:
    
    target_seq_dim = 1
    tgt = src[:,-1,0].view(-1, 1, 1) # [bs, 1, 1]
    print(tgt.size())
    # Iteratively concatenate tgt with the first element in the prediction
    for _ in range(forecast_window-1):
        
        dim_a = tgt.shape[1] #1,2,3,.. n
        dim_b = src.shape[1] #30
        
        src_mask = generate_square_subsequent_mask(dim1=dim_a, dim2=dim_b).to(device)
        tgt_mask = generate_square_subsequent_mask(dim1=dim_a, dim2=dim_a).to(device)
        
        prediction = model(src, tgt, src_mask, tgt_mask)

        # Obtain the predicted value at t+1 where t is the last step 
        # represented in tgt
        last_predicted_value = prediction[:,-1,:].view(-1,1,1) #[bs, 1]
        
        # Reshape from [batch_size, 1] --> [1, batch_size, 1]
        # last_predicted_value = last_predicted_value
        # print(last_predicted_value.detach(), last_predicted_value.size())
        print(tgt.size())
        # Detach the predicted element from the graph and concatenate with 
        # tgt in dimension 1 or 0
        tgt = torch.cat((tgt, last_predicted_value), dim=target_seq_dim)
    
    src_mask = generate_square_subsequent_mask(dim1=4, dim2=30).to(device)
    tgt_mask = generate_square_subsequent_mask(dim1=4, dim2=4).to(device)
    
    # Make final prediction
    return model(src, tgt, src_mask, tgt_mask)

src = torch.rand(5, 30, 1).to(device)
infer(model=model, src=src, forecast_window=4, device=device)

torch.Size([5, 1, 1])
torch.Size([5, 1, 1])
torch.Size([5, 2, 1])
torch.Size([5, 3, 1])


tensor([[[ 0.4758],
         [ 0.1749],
         [ 0.6493],
         [ 0.1224]],

        [[-0.0355],
         [ 0.1663],
         [ 0.0221],
         [ 0.2551]],

        [[ 0.0329],
         [ 0.5259],
         [ 0.5337],
         [ 0.3079]],

        [[-0.0866],
         [ 0.2090],
         [ 0.2769],
         [ 0.4881]],

        [[ 0.9540],
         [ 0.4800],
         [ 0.6724],
         [ 0.2874]]], device='cuda:1', grad_fn=<ViewBackward0>)

In [None]:
src_mask = generate_square_subsequent_mask(
dim1=4, dim2=30
).to(device)

tgt_mask = generate_square_subsequent_mask( 
    dim1=4, dim2=4
    ).to(device)

# cost = torch.nn.GaussianNLLLosss(full=True, reduction='mean')



def train_epoch(tr_loader):
    # model.train()
    model.train()
    # model_mohaom.train()
    total_loss = 0.
    start = time.time()
    for batch, data in enumerate(tr_loader):

        output = model(data['x'].view(-1,enc_seq_len,1), 
        data['y'].view(-1,out_seq_len,1), src_mask, tgt_mask)
        loss = cost(output, data['y'].view(-1,out_seq_len,1))
        loss_value = loss.item()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        total_loss+=loss_value
        del data, output

    end = time.time()
    print(f"Epoch #%d tr loss:%.4f time:%.2f s"%(epoch, total_loss/(batch+1), (end-start)))


def eval(val_loader):
    model.eval()
    with torch.no_grad():
        total_val_loss = 0
        for bs, data in enumerate(val_loader):
            # output = model(data['x'], data['tgt'])
            output = infer(model=model, src=data['x'].view(-1,enc_seq_len,1), forecast_window=4, device=device)

            loss = cost(output, data['y'].view(-1,out_seq_len,1))
            total_val_loss+=loss.item()

    print("val loss:%.4f"%(total_val_loss/(bs+1)))

    # if epoch%5==0:
    #     eval(val_loader)

In [101]:
torch.sin(X_tr)

In [119]:
from data import ASPCAP_error

data_dir = "/data/jdli/sdss/"
tr_file = "hogg19_spec_tr.npy"

device = torch.device('cuda:1')
TOTAL_NUM = 6000
BATCH_SIZE = 3

aspcap  = ASPCAP_error(data_dir+tr_file, total_num=TOTAL_NUM,
part_train=False, device=device)

# aspcap_val = ASPCAP(data_dir+val_file, device=device)
train_size = int(0.5*len(aspcap))
val_size = len(aspcap) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(aspcap, [train_size, val_size], generator=torch.Generator().manual_seed(42))
print(len(train_dataset), len(val_dataset))

# tr_loader  = DataLoader(train_dataset, batch_size=BATCH_SIZE, )
# val_loader = DataLoader(val_dataset,  batch_size=BATCH_SIZE, )
tr_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, )

ImportError: cannot import name 'ASPCAP_error' from 'data' (/home/jdli/TransSpectra/data.py)