In [1]:
# import numpy as np
# %load_ext autoreload
# %autoreload 2
import time
import torch
from torch.utils.data import DataLoader
from transGaia.transgaia import Spec2label
from transGaia.data import GaiaXP_allcoefs_5label_cont_norm, GaiaXP_55coefs_5label_cont_ANDnorm

In [2]:
data_dir = "/data/jdli/gaia/"
tr_file = "ap17_wise_xpcont_cut.npy"

# device = torch.device('cuda:1')
device = torch.device('cpu')
TOTAL_NUM = 1000
BATCH_SIZE = 128

gdata  = GaiaXP_allcoefs_5label_cont_norm(
    data_dir+tr_file, total_num=TOTAL_NUM, 
    part_train=True,  device=device
)

val_size = int(0.1*len(gdata))
A_size = int(0.5*(len(gdata)-val_size))
B_size = len(gdata) - A_size - val_size

A_dataset, B_dataset, val_dataset = torch.utils.data.random_split(gdata, [A_size, B_size, val_size], generator=torch.Generator().manual_seed(42))

print(len(A_dataset), len(B_dataset), len(val_dataset))

A_loader = DataLoader(A_dataset, batch_size=BATCH_SIZE)
B_loader = DataLoader(B_dataset, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

450 450 100


In [3]:
##==================Model parameters============================
##==============================================================
#===============================================================
INPUT_LEN = 55*2+3

model = Spec2label(
    n_encoder_inputs=INPUT_LEN,
    n_outputs=1, channels=512, n_heads=8, n_layers=8,
).to(device)

# cost = torch.nn.GaussianNLLLoss(full=True, reduction='mean')
cost = torch.nn.MSELoss(reduction='mean')

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3, weight_decay=1e-6
)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

tr_select = "A"
model_dir = "/data/jdli/gaia/model/1205_teff/" + tr_select

if tr_select=="A":
    tr_loader = A_loader
    # check_point = model_dir +"/sp2_alpha_mse_A_ep200.pt"

elif tr_select=="B":
    tr_loader = B_loader
    # check_point = "/data/jdli/gaia/model/1119/B/sp2_4labels_mse_B_ep85.pt"
print("===================================")
# print("Loading checkpoint %s"%(check_point))
# model.load_state_dict(torch.load(check_point))

print("Traing %s begin"%tr_select)

Traing A begin


In [5]:
def train_epoch(tr_loader, epoch):
        # model.train()
    model.train()
    total_loss = 0.
    start_time = time.time()
    
    for batch, data in enumerate(tr_loader):
        output = model(data['x'].view(-1, INPUT_LEN))
        loss = cost(output, data['y'][:,0].view(-1,1))
        loss_value = loss.item()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss+=loss_value
        del data, output
    print("epoch %d train loss:%.4f | %.4f s"%(epoch, total_loss/(batch+1e-5), time.time()-start_time))
    
    
def eval(val_loader):
    model.eval()
    total_val_loss=0
    
    with torch.no_grad():
        for bs, data in enumerate(val_loader):
            output = model(data['x'].view(-1, INPUT_LEN))
            # loss = cost(output, data['y'], data['e_y'])
            loss = cost(output, data['y'][:,0].view(-1,1))
            total_val_loss+=loss.item()
            del data, output

    print("val loss:%.4f"%(total_val_loss/(bs+1e-5)))

num_epochs = 3
for epoch in range(num_epochs+1):
    train_epoch(tr_loader, epoch)

    if epoch%5==0:
        eval(val_loader)
    if epoch%50==0: 
        save_point =  "/sp2_teff_robustnorm_mse_%s_ep%d.pt"%(tr_select, epoch)
        torch.save(model.state_dict(), model_dir+save_point)

# torch.cuda.empty_cache()

epoch 0 train loss:42.9769 | 22.6586 s
val loss:85.4371
epoch 1 train loss:55.7195 | 22.4517 s
epoch 2 train loss:44.3606 | 22.3735 s
epoch 3 train loss:50.3621 | 22.2166 s
