In [43]:
import json

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn import Linear


def smape_loss(y_pred, target, eps=1e-8):
    """Symmetric mean absolute percentage error (SMAPE), averaged over all elements.

    Each element contributes ``2 * |y_pred - target| / (|y_pred| + |target| + eps)``,
    a value in [0, 2]; the mean of those terms is returned as a 0-d tensor.

    Args:
        y_pred: predicted values.
        target: ground-truth values; must have exactly the same shape as ``y_pred``.
        eps: small constant keeping the denominator non-zero, so elements where
            both prediction and target are 0 contribute 0 instead of NaN.

    Raises:
        ValueError: if the shapes differ. Without this check, e.g. ``(N, 1)`` vs
            ``(N,)`` would silently broadcast into an ``(N, N)`` loss matrix.
    """
    if y_pred.shape != target.shape:
        raise ValueError(
            f"smape_loss: shape mismatch {tuple(y_pred.shape)} vs {tuple(target.shape)}"
        )
    loss = 2 * (y_pred - target).abs() / (y_pred.abs() + target.abs() + eps)
    return loss.mean()


def gen_trg_mask(length, device):
    """Return an additive causal (look-ahead) mask of shape ``(length, length)``.

    The float32 result lives on ``device`` and holds ``0.0`` on and below the main
    diagonal and ``-inf`` strictly above it, i.e. row ``i`` only admits columns
    ``<= i``. Not used by the encoder-only model below (the decoder is commented out).
    """
    # Start fully blocked, then keep only the strictly-upper triangle as -inf;
    # triu zeroes everything on/below the diagonal.
    blocked = torch.full(
        (length, length), float("-inf"), dtype=torch.float32, device=device
    )
    return torch.triu(blocked, diagonal=1)


class Spec2label(pl.LightningModule):
    """Transformer-encoder regressor mapping an input sequence to ``n_outputs`` values.

    ``forward`` takes a ``(batch, seq_len, n_encoder_inputs)`` tensor and returns
    ``(batch, n_outputs)``. The notebook feeds ``seq_len == 1``; the head pools over
    the sequence axis, so longer sequences (up to 1024, the size of the positional
    embedding table) are handled too.

    The ``*_step`` hooks expect dict batches with keys ``'x'`` (model input) and
    ``'y'`` (targets). Predictions and targets are flattened and scored with
    ``smape_loss``, logged as ``train_loss`` / ``valid_loss`` / ``test_loss``.

    Args:
        n_encoder_inputs: feature size of each sequence element.
        n_outputs: number of regression targets per sample.
        channels: transformer width (``d_model``); must be divisible by ``nhead=8``.
        dropout: dropout probability inside the transformer encoder layers.
        lr: Adam learning rate.
    """

    def __init__(
        self,
        n_encoder_inputs,
        n_outputs,
        channels=512,
        dropout=0.2,
        lr=1e-4,
    ):
        super().__init__()

        self.save_hyperparameters()
        self.channels = channels
        self.n_outputs = n_outputs
        self.lr = lr
        self.dropout = dropout

        # Learned positional embeddings for sequence positions 0..1023.
        self.input_pos_embedding = torch.nn.Embedding(1024, embedding_dim=channels)
        # Not used by forward (leftover from an encoder-decoder variant); kept so
        # parameter init order and existing checkpoints' state_dict keys still match.
        self.target_pos_embedding = torch.nn.Embedding(1024, embedding_dim=channels)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=channels,
            nhead=8,
            dropout=self.dropout,
            dim_feedforward=4 * channels,
        )
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=8)

        self.input_projection = Linear(n_encoder_inputs, channels)

        # Regression head: channels -> 64 -> n_outputs.
        self.fc1 = Linear(channels, 64)
        self.fc2 = Linear(64, n_outputs)
        # Not applied in forward; kept so the module layout is unchanged.
        self.do = nn.Dropout(p=self.dropout)

    def encode_src(self, src):
        """Run the transformer encoder over ``src``.

        Args:
            src: ``(batch, seq_len, n_encoder_inputs)`` tensor.

        Returns:
            ``(seq_len, batch, channels)`` tensor (sequence-first, the layout the
            encoder is built with): encoder output plus a residual connection to
            the projected input *without* positional embeddings.
        """
        # (batch, seq, feat) -> (batch, seq, channels) -> (seq, batch, channels)
        src_start = self.input_projection(src).permute(1, 0, 2)

        in_sequence_len, batch_size = src_start.size(0), src_start.size(1)
        pos_encoder = (
            torch.arange(0, in_sequence_len, device=src.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        # (batch, seq) -> (batch, seq, channels) -> (seq, batch, channels)
        pos_encoder = self.input_pos_embedding(pos_encoder).permute(1, 0, 2)

        src = src_start + pos_encoder
        src = self.encoder(src) + src_start

        return src

    def forward(self, x):
        """Map ``(batch, seq_len, n_encoder_inputs)`` input to ``(batch, n_outputs)``."""
        src = self.encode_src(x)  # (seq_len, batch, channels)
        src = F.relu(src)

        # Pool over the sequence axis. For seq_len == 1 this is exactly the old
        # permute(1, 0, 2).view(-1, channels); for seq_len > 1 that view either
        # raised (non-contiguous tensor) or, with batch 1, yielded seq_len rows
        # that no longer lined up with the per-sample targets.
        src = src.mean(dim=0)  # (batch, channels)

        src = F.relu(self.fc1(src))  # (batch, 64)
        return self.fc2(src)  # (batch, n_outputs)

    def _shared_step(self, batch, stage):
        """Compute the SMAPE loss of one ``{'x', 'y'}`` batch and log it as ``f"{stage}_loss"``."""
        src, trg_out = batch['x'], batch['y']

        y_hat = self(src).reshape(-1)
        # reshape (not view) so a non-contiguous target tensor is also accepted.
        y = trg_out.reshape(-1)

        loss = smape_loss(y_hat, y)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook; logs ``train_loss``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook; logs ``valid_loss`` (monitored by the LR scheduler)."""
        return self._shared_step(batch, "valid")

    def test_step(self, batch, batch_idx):
        """Lightning test hook; logs ``test_loss``."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Adam at ``self.lr``; LR cut 10x after 10 epochs without ``valid_loss`` improvement."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=10, factor=0.1
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "valid_loss",
        }

In [44]:
# Forward-pass smoke test: 5 samples, each a length-1 sequence of 343 features.
torch.manual_seed(0)

source = torch.rand(size=(5, 1, 343))

ts = Spec2label(n_encoder_inputs=343, n_outputs=2)
with torch.no_grad():
    pred = ts(source)

assert pred.shape == (5, 2), pred.shape
print(pred.size())

torch.Size([5, 2])


In [40]:
import sys
sys.path.append("/home/jdli/TransSpectra/")
from data import GaiaXPlabel_v2
from torch.utils.data import DataLoader


data_dir = "/data/jdli/gaia/"
tr_file = "ap17_xp.npy"

# NOTE(review): the dataset is presumably materialised on cuda:1, while the
# Trainer in `train` uses `devices=1` (the first visible GPU), so each batch
# may be copied across GPUs — confirm and align the two.
device = torch.device('cuda:1')

BATCH_SIZE = 64
SPLIT_SEED = 42

gdata = GaiaXPlabel_v2(data_dir + tr_file, total_num=1000, part_train=True, device=device)

# 10% validation; the remainder is split in half into two training subsets A and B.
val_size = int(0.1 * len(gdata))
A_size = int(0.5 * (len(gdata) - val_size))
B_size = len(gdata) - A_size - val_size

A_dataset, B_dataset, val_dataset = torch.utils.data.random_split(
    gdata, [A_size, B_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(len(A_dataset), len(B_dataset), len(val_dataset))

# Training loaders reshuffle every epoch (seeded for reproducibility); without
# shuffle=True every epoch saw the samples in the same fixed order.
A_loader = DataLoader(
    A_dataset, batch_size=BATCH_SIZE, shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
B_loader = DataLoader(
    B_dataset, batch_size=BATCH_SIZE, shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)



450 450 100


In [45]:
def train(
    train_loader, val_loader,
    output_json_path: str,
    log_dir: str = "ts_logs",
    model_dir: str = "ts_models",
    batch_size: int = 8,
    epochs: int = 2000,
    horizon_size: int = 30,
    n_encoder_inputs: int = 343,
    n_outputs: int = 2,
    lr: float = 1e-5,
    dropout: float = 0.1,
):
    """Fit a fresh ``Spec2label`` model, test its best checkpoint, and summarise the run.

    Args:
        train_loader: DataLoader of ``{'x', 'y'}`` training batches.
        val_loader: DataLoader used for validation during fitting and for the final test pass.
        output_json_path: where to write the summary JSON; ``None`` skips writing.
        log_dir: TensorBoard ``save_dir``.
        model_dir: directory the checkpoint callback writes into (filename ``ts``).
        batch_size: unused (the loaders already fix the batch size); kept for compatibility.
        epochs: ``max_epochs`` for the Trainer.
        horizon_size: unused; kept for compatibility.
        n_encoder_inputs, n_outputs, lr, dropout: forwarded to ``Spec2label``.

    Returns:
        dict with ``"val_loss"`` (``test_loss`` of the best checkpoint on
        ``val_loader``) and ``"best_model_path"``.
    """
    model = Spec2label(
        n_encoder_inputs=n_encoder_inputs,
        n_outputs=n_outputs,
        lr=lr,
        dropout=dropout,
    )

    logger = TensorBoardLogger(
        save_dir=log_dir,
    )

    # Keep only the checkpoint with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor="valid_loss",
        mode="min",
        dirpath=model_dir,
        filename="ts",
    )

    trainer = pl.Trainer(
        max_epochs=epochs,
        accelerator='gpu', devices=1,
        logger=logger,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model, train_loader, val_loader)

    # `test_dataloaders=` no longer exists on Trainer.test; the keyword is
    # `dataloaders=`. ckpt_path="best" evaluates the best checkpoint explicitly.
    result_val = trainer.test(model, dataloaders=val_loader, ckpt_path="best")

    output_json = {
        "val_loss": result_val[0]["test_loss"],
        "best_model_path": checkpoint_callback.best_model_path,
    }

    if output_json_path is not None:
        with open(output_json_path, "w") as f:
            json.dump(output_json, f, indent=4)

    return output_json

In [None]:
EPOCH = 10

# Output locations for the split-A run; `train` uses log_dir and model_dir as
# directories (TensorBoard save_dir / checkpoint dirpath) despite their suffixes.
run_paths = dict(
    output_json_path="/data/jdli/gaia/model/forcasting_1107A.json",
    log_dir="/data/jdli/gaia/model/forcasting_1107A.log",
    model_dir="/data/jdli/gaia/model/forcasting_1107A.pt",
)

train(A_loader, val_loader, epochs=EPOCH, **run_paths)

  rank_zero_deprecation(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name                 | Type               | Params
------------------------------------------------------------
0 | input_pos_embedding  | Embedding          | 524 K 
1 | target_pos_embedding | Embedding          | 524 K 
2 | encoder              | TransformerEncoder | 25.2 M
3 | input_projection     | Linear             | 176 K 
4 | fc1                  | Linear             | 32.8 K
5 | fc2                  | Linear             | 130   
6 | do                   | Dropout            | 0     
------------------------------------------------------------
26.5 M    Trainable params
0         Non-trainable params
26.5 M    Total params
105.907   Total estimated model params size (MB)
