In [1]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

  warn(f"Failed to load image Python extension: {e}")


In [63]:
def smape_loss(y_pred, target):
    loss = 2 * (y_pred - target).abs() / (y_pred.abs() + target.abs() + 1e-8)
    return loss.mean()


def gen_trg_mask(length, device):
    mask = torch.tril(torch.ones(length, length, device=device)) == 1

    mask = (
        mask.float()
        .masked_fill(mask==0, float("-inf"))
        .masked_fill(mask==1, float(0.0))
    )
    return mask


class Spec2HRd(pl.LightningModule):
    def __init__(
        self,
        n_encoder_inputs,
        n_decoder_inputs,
        n_outputs,
        channels=512,
        dropout=0.2,
        lr=1e-4,
        nhead=4,
    ):
        super().__init__()
        
        self.n_encoder_inputs = n_encoder_inputs
        self.n_decoder_inputs = n_decoder_inputs
        self.save_hyperparameters()
        self.channels = channels
        self.n_outputs = n_outputs
        self.lr = lr
        self.dropout = dropout


        encoder_layer = nn.TransformerEncoderLayer(
            d_model=channels,
            nhead=nhead,
            dropout=self.dropout,
            dim_feedforward=4*channels,
        )
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=channels,
            nhead=nhead,
            dropout=self.dropout,
            dim_feedforward=4 * channels,
        )

        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=8)
        self.decoder = torch.nn.TransformerDecoder(decoder_layer, num_layers=8)

        self.input_projection = Linear(n_encoder_inputs, channels)
        self.output_projection = Linear(n_decoder_inputs, channels)
        self.input_pos_embedding = torch.nn.Embedding(1024, embedding_dim=channels)
        self.target_pos_embedding = torch.nn.Embedding(1024, embedding_dim=channels)

        # self.linear = Linear(channels, 2)
        self.fc1 = Linear(channels, 64)
        self.fc2 = Linear(64, n_outputs)
        self.do = nn.Dropout(p=self.dropout)

    def encode_src(self, src):
        src_start = self.input_projection(src).permute(1, 0, 2)

        in_sequence_len, batch_size = src_start.size(0), src_start.size(1)
        pos_encoder = (
            torch.arange(0, in_sequence_len, device=src.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        
        pos_encoder = self.input_pos_embedding(pos_encoder).permute(1, 0, 2)

        src = src_start + pos_encoder
        src = self.encoder(src) + src_start

        return src
    
    def decode_trg(self, trg, memory):

        trg_start = self.output_projection(trg).permute(1, 0, 2)

        out_sequence_len, batch_size = trg_start.size(1), trg_start.size(0)

        pos_decoder = (
            torch.arange(0, out_sequence_len, device=trg.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        pos_decoder = self.target_pos_embedding(pos_decoder).permute(1, 0, 2)

        trg = pos_decoder + trg_start
        
        trg_mask = gen_trg_mask(out_sequence_len, trg.device)
        out = self.decoder(tgt=trg, memory=memory, tgt_mask=trg_mask) + trg_start
        out = out.permute(1, 0, 2)

        return out

    def forward(self, x):
        src = x
        
        enc_ouput = self.encode_src(src) # (1, bs, 512)
        src = F.relu(enc_ouput) # (1, bs, 512)
        
        src = src.permute(1, 0, 2) #(bs, 1, 512)
        src = src.view(-1, self.channels) # (bs, 512)

        src = self.fc1(src) # (bs, 64)
        src = F.relu(src)
        
        tgt_a = self.fc2(src) # (bs, 2)

        dec_input = torch.concat((x.view(-1, self.n_encoder_inputs), tgt_a), dim=1).view(-1, 1, self.n_decoder_inputs)
        
        
        out = self.decode_trg(trg=dec_input, 
                              memory=enc_ouput)# (1, bs, 512)
        out = F.relu(out)
        out = out.permute(1, 0, 2) #(bs, 1, 512)
        out = out.view(-1, self.channels) # (bs, 512)
        
        out = self.fc1(out) # (bs, 64)
        out = F.relu(out)
        
        out = self.fc2(out) # (bs, 2)
        print(tgt_a.shape, out.shape)
        return torch.concat((tgt_a, out), dim=1)

In [64]:

source = torch.rand(size=(5, 1, 343))
target_in = torch.rand(size=(5, 1, 343))
target_out = torch.rand(size=(5, 1, 4))

# source = torch.rand(size=(32, 16, 9))
# target_in = torch.rand(size=(32, 16, 8))
# target_out = torch.rand(size=(32, 16, 1))
ts = Spec2HRd(n_encoder_inputs=343, n_decoder_inputs=343+2, n_outputs=2, channels=345, nhead=3)
pred = ts((source))

# print(pred.size())

# ts.training_step((source, target_in, target_out), batch_idx=1)

torch.Size([5, 2]) torch.Size([25, 2])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 25 for tensor number 1 in the list.

In [4]:
import sys
sys.path.append("/home/jdli/TransSpectra/")
from data import GaiaXPlabel_v2
from torch.utils.data import DataLoader


data_dir = "/data/jdli/gaia/"
tr_file = "ap17_xp.npy"

device = torch.device('cuda:0')

BATCH_SIZE = 16

gdata  = GaiaXPlabel_v2(data_dir+tr_file, total_num=1000, part_train=True, device=device)

val_size = int(0.1*len(gdata))
A_size = int(0.5*(len(gdata)-val_size))
B_size = len(gdata) - A_size - val_size

A_dataset, B_dataset, val_dataset = torch.utils.data.random_split(
    gdata, [A_size, B_size, val_size], 
    generator=torch.Generator().manual_seed(42)
)
print(len(A_dataset), len(B_dataset), len(val_dataset))

A_loader = DataLoader(A_dataset, batch_size=BATCH_SIZE)
B_loader = DataLoader(B_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)



450 450 100
