In [1]:
import time
import torch, torchvision
from torch.autograd import Variable
import numpy as np
import pandas as pd
import os
import gc
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import datetime

from DBFNet.Model import *
import warnings

### 计算距离函数

In [2]:
def compute_distance_km(lat1, lon1, lat2, lon2):
    lat1, lon1, lat2, lon2 = lat1.astype('float'), lon1.astype('float'), lat2.astype('float'), lon2.astype('float')
    # 批量计算地球上两点间的球面距离
    R = 6371e3  # 地球半径（米）
    phi_1, phi_2 = np.radians(lat1), np.radians(lat2)
    delta_phi = np.radians(lat2 - lat1)
    delta_lambda = np.radians(lon2 - lon1)
    a = np.power(np.sin(delta_phi / 2), 2) + np.cos(phi_1) * np.cos(phi_2) * np.power(np.sin(delta_lambda / 2), 2)
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return R * c / 1000.

### 归一化函数

In [3]:
def data_scaler(x, feature_range=[0, 1]):
    scaler = MinMaxScaler(feature_range=feature_range)
    shape = x.shape
    x = x.reshape(shape[0], -1)
    x = scaler.fit_transform(x)
    x = x.reshape(shape)
    return x, scaler

### 数据集相关类

In [4]:
class TC_PreTrainLoader(Dataset):
    def __init__(self, pretrain_tc):
        self.tc = pretrain_tc

    def __getitem__(self, index):
        return self.tc[index, :, :-2], self.tc[index, :, -2:]

    def __len__(self):
        return self.tc.shape[0]
    
class GPH_PreTrainLoader(Dataset):
    def __init__(self, gph):
        self.g = gph

    def __getitem__(self, item):
        return self.g[item, :4], self.g[item, 4:]

    def __len__(self):
        return len(self.g)

class TrainLoader(Dataset):
    def __init__(self, x, y, g, standard):
        self.x = x
        self.y = y
        self.g = g
        self.s = standard

    def __getitem__(self, item):
        return self.x[item], self.g[item, :4], self.y[item], self.g[item, 4:], self.s[item]

    def __len__(self):
        return len(self.x)

### 训练函数

In [5]:
def train(cfg1=TC_Encoder_config, cfg2=Map_Encoder_config, checkpoint=None,
          tc_pretrain_path='./CMA_dataset/CMA_csv/xtc_pretrain_data.npy', xtc_path='./CMA_dataset/xtc.npy',
          ytc_path='./CMA_dataset/ytc.npy', split_index_path='./CMA_dataset/split_index.npy', gph_path='./gph.npy',
          tc_pretrain_epoch=50, tc_pretrain_lr=0.01, gph_pretrain_epoch=50, gph_pretrain_lr=0.01, batch_size=32,
          epoch=50, lr=0.001, model_save_path='./model/model001.pth'):
    warnings.filterwarnings("ignore") # 关闭warning

    model = DBFNet(cfg1, cfg2)
    if checkpoint is not None:
        state_dict = torch.load(checkpoint)
        model.load_state_dict(state_dict)
    criterion = nn.L1Loss()
    if torch.cuda.is_available():
        model = model.cuda()


    # 载入TC预训练数据
    print('处理TC Encoder预训练数据：', end='')
    pre_tc = np.load(tc_pretrain_path)
    pre_tc, pre_tc_scaler = data_scaler(pre_tc)
    pre_tc = DataLoader(TC_PreTrainLoader(pre_tc), batch_size=batch_size, shuffle=True)
    print('完成')

    # TC Encoder预训练
    print('TC Encoder 预训练：')
    model.train()
    for i in range(tc_pretrain_epoch):
        optimizer = torch.optim.Adam(model.parameters(), lr=tc_pretrain_lr * (0.1 ** (i // 20)))
        # print(tc_pretrain_lr * (0.1 ** (i // 20)))
        time1 = datetime.datetime.now()
        pretrain_loss = 0
        for X, Y in pre_tc:
            if torch.cuda.is_available():
                x = X.float().cuda()
                y = Y.float().cuda()
            else:
                x = X.float()
                y = Y.float()
            optimizer.zero_grad()
            pred = model.pretrain_tc_forward(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
            pretrain_loss += loss.detach().item()
        time2 = datetime.datetime.now()
        print('Pretrain TC [%d/%d] cost:%ds train_loss: %.5f' % (
        i + 1, tc_pretrain_epoch, (time2 - time1).seconds, pretrain_loss))
    del pre_tc # 降低内存占用


    # 载入完整数据集
    print('载入完整数据集...')
    xtc = np.load(xtc_path, allow_pickle=True).astype(float)
    ytc = np.load(ytc_path, allow_pickle=True).astype(float)
    split_index = np.load(split_index_path, allow_pickle=True).astype(int)
    gph = np.load(gph_path, allow_pickle=True).astype(float)
    full_train_index = [*range(split_index[0], split_index[2])]
    test_index = [*range(split_index[2], split_index[3])]
    standard = xtc[:, -1, :2].reshape(-1, 1, 2)


    # 归一化
    xtc, _ = data_scaler(xtc)
    # _ , y_scaler = data_scaler(ytc)
    gph, _ = data_scaler(gph)
    # gph, _ = data_scaler(gph)

    # 打包GPH预训练数据
    pre_gph = DataLoader(GPH_PreTrainLoader(gph), batch_size=batch_size, shuffle=True, drop_last=True)

    # GPH Encoder-Decoder预训练
    print('GPH Encoder-Decoder 预训练：')
    model.train()
    for i in range(gph_pretrain_epoch):
        optimizer = torch.optim.Adam(model.parameters(), lr=gph_pretrain_lr * (0.1 ** (i // 20)))
        time1 = datetime.datetime.now()
        pretrain_loss = 0
        for X, Y in tqdm(pre_gph):
            if torch.cuda.is_available():
                x = X.float().cuda()
                y = Y.float().cuda()
            else:
                x = X.float()
                y = Y.float()
            optimizer.zero_grad()
            pred = model.pretrain_map_forward(x)
            loss = criterion(pred, y)
            loss.backward()
            # print([x.grad for x in optimizer.param_groups[0]['params']])
            optimizer.step()
            pretrain_loss += loss.detach().item()
        time2 = datetime.datetime.now()
        print('Pretrain GPH [%d/%d] cost:%ds train_loss: %.5f' % (i + 1, gph_pretrain_epoch, (time2 - time1).seconds,
                                                                  pretrain_loss))
    del pre_gph # 降低内存占用


    # 从训练数据中划分训练集和验证集
    train_index, val_index, _, _ = train_test_split(full_train_index, full_train_index, test_size=0.2)
    # 打包数据集
    train_dataset = DataLoader(TrainLoader(xtc[train_index], ytc[train_index], gph[train_index], standard[train_index]),
                               batch_size=batch_size, shuffle=True)
    val_dataset = DataLoader(TrainLoader(xtc[val_index], ytc[val_index], gph[val_index], standard[val_index]),
                             batch_size=batch_size, shuffle=True)
    test_dataset = DataLoader(TrainLoader(xtc[test_index], ytc[test_index], gph[test_index], standard[test_index]),
                              batch_size=batch_size, shuffle=False)
    del xtc, ytc, gph, split_index # 降低内存占用

    # 整个模型训练
    print('完整模型训练：')
    for i in range(epoch):
        optimizer = torch.optim.Adam(model.parameters(), lr=lr * (0.1 ** (i // 20)))
        # 训练集
        time1 = datetime.datetime.now()
        model.train()
        train_loss1, train_loss2 = 0, 0
        for X1, X2, Y1, Y2, _ in tqdm(train_dataset):
            if torch.cuda.is_available():
                x1 = X1.float().cuda()
                x2 = X2.float().cuda()
                y1 = Y1.float().cuda()
                y2 = Y2.float().cuda()
            else:
                x1 = X1.float()
                x2 = X2.float()
                y1 = Y1.float()
                y2 = Y2.float()
            optimizer.zero_grad()
            pred1, pred2 = model(x1, x2)
            loss1 = criterion(pred1, y1)
            loss2 = criterion(pred2, y2)
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            train_loss1 += loss1.detach().item()
            train_loss2 += loss2.detach().item()
        torch.save(model.state_dict(), model_save_path)
        time2 = datetime.datetime.now()
        print('[%d/%d] Train cost:%ds loss: %.5f/%.5f' % (i + 1, epoch, (time2 - time1).seconds, train_loss1, train_loss2), end=' ')
        # 验证集
        model.eval()
        with torch.no_grad():
            time1 = datetime.datetime.now()
            val_loss1, val_loss2 = 0, 0
            for X1, X2, Y1, Y2, _ in val_dataset:
                if torch.cuda.is_available():
                    x1 = X1.float().cuda()
                    x2 = X2.float().cuda()
                    y1 = Y1.float().cuda()
                    y2 = Y2.float().cuda()
                else:
                    x1 = X1.float()
                    x2 = X2.float()
                    y1 = Y1.float()
                    y2 = Y2.float()
                pred1, pred2 = model(x1, x2)
                loss1 = criterion(pred1, y1)
                loss2 = criterion(pred2, y2)
                val_loss1 += loss1.detach().item()
                val_loss2 += loss2.detach().item()
            time2 = datetime.datetime.now()
        print('  ||  Valid cost:%ds loss: %.5f/%.5f' % ((time2 - time1).seconds, val_loss1, val_loss2))
    print('训练完成')

    # 在测试集上预测
    print('测试集上预测：')
    model.eval()
    pred_save, truth_save, standard_save = [], [], []
    for X1, X2, Y1, _, S in tqdm(test_dataset):
        if torch.cuda.is_available():
            x1 = X1.float().cuda()
            x2 = X2.float().cuda()
            y1 = Y1.float().cuda()
            s = S.float().cuda()
        else:
            x1 = X1.float()
            x2 = X2.float()
            y1 = Y1.float()
            s = S.float()
        pred, _ = model(x1, x2)
        pred_save.append(pred.detach().cpu().numpy())
        truth_save.append(y1.detach().cpu().numpy())
        standard_save.append(s.detach().cpu().numpy())
    # 做反归一化
    pred_save = np.concatenate(pred_save, axis=0)
    truth_save = np.concatenate(truth_save, axis=0)
    standard_save = np.concatenate(standard_save, axis=0)
    # pred_save = data_scaler_inverse(pred_save, y_scaler)
    # 计算最终预测结果
    pred = pred_save + standard_save
    truth = truth_save + standard_save
    for i in range(1, 5):
        pred_lat, pred_lon, truth_lat, truth_lon = pred[:, i - 1, 0], pred[:, i - 1, 1], truth[:, i - 1, 0], truth[:, i - 1, 1]
        distance = compute_distance_km(pred_lat, pred_lon, truth_lat, truth_lon)
        print(str(i * 6) + 'h APE：', distance.mean())

### 训练

In [6]:
train(model_save_path='./model/model_demo.pth')
gc.collect()

处理TC Encoder预训练数据：完成
TC Encoder 预训练：
Pretrain TC [1/50] cost:1s train_loss: 16.89595
Pretrain TC [2/50] cost:1s train_loss: 5.94226
Pretrain TC [3/50] cost:1s train_loss: 5.21437
Pretrain TC [4/50] cost:1s train_loss: 4.73386
Pretrain TC [5/50] cost:1s train_loss: 4.44967
Pretrain TC [6/50] cost:1s train_loss: 4.24450
Pretrain TC [7/50] cost:1s train_loss: 3.96177
Pretrain TC [8/50] cost:1s train_loss: 4.01570
Pretrain TC [9/50] cost:1s train_loss: 3.90488
Pretrain TC [10/50] cost:1s train_loss: 3.81908
Pretrain TC [11/50] cost:1s train_loss: 3.81404
Pretrain TC [12/50] cost:1s train_loss: 3.70426
Pretrain TC [13/50] cost:1s train_loss: 3.70496
Pretrain TC [14/50] cost:1s train_loss: 3.61890
Pretrain TC [15/50] cost:1s train_loss: 3.60796
Pretrain TC [16/50] cost:1s train_loss: 3.51278
Pretrain TC [17/50] cost:1s train_loss: 3.60853
Pretrain TC [18/50] cost:1s train_loss: 3.36293
Pretrain TC [19/50] cost:1s train_loss: 3.43408
Pretrain TC [20/50] cost:1s train_loss: 3.26364
Pretrain TC

100%|██████████| 781/781 [00:30<00:00, 25.86it/s]


Pretrain GPH [1/50] cost:30s train_loss: 49.09777


100%|██████████| 781/781 [00:31<00:00, 24.43it/s]


Pretrain GPH [2/50] cost:31s train_loss: 38.30277


100%|██████████| 781/781 [00:31<00:00, 24.77it/s]


Pretrain GPH [3/50] cost:31s train_loss: 34.52732


100%|██████████| 781/781 [00:32<00:00, 23.90it/s]


Pretrain GPH [4/50] cost:32s train_loss: 33.38972


100%|██████████| 781/781 [00:32<00:00, 24.31it/s]


Pretrain GPH [5/50] cost:32s train_loss: 31.90945


100%|██████████| 781/781 [00:31<00:00, 24.57it/s]


Pretrain GPH [6/50] cost:31s train_loss: 32.94148


100%|██████████| 781/781 [00:32<00:00, 23.78it/s]


Pretrain GPH [7/50] cost:32s train_loss: 31.21187


100%|██████████| 781/781 [00:31<00:00, 25.03it/s]


Pretrain GPH [8/50] cost:31s train_loss: 30.64280


100%|██████████| 781/781 [00:31<00:00, 24.61it/s]


Pretrain GPH [9/50] cost:31s train_loss: 30.27454


100%|██████████| 781/781 [00:30<00:00, 25.35it/s]


Pretrain GPH [10/50] cost:30s train_loss: 33.84336


100%|██████████| 781/781 [00:31<00:00, 24.99it/s]


Pretrain GPH [11/50] cost:31s train_loss: 31.39143


100%|██████████| 781/781 [00:32<00:00, 24.27it/s]


Pretrain GPH [12/50] cost:32s train_loss: 33.74748


100%|██████████| 781/781 [00:30<00:00, 25.64it/s]


Pretrain GPH [13/50] cost:30s train_loss: 31.01718


100%|██████████| 781/781 [00:30<00:00, 25.46it/s]


Pretrain GPH [14/50] cost:30s train_loss: 31.38599


100%|██████████| 781/781 [00:32<00:00, 24.39it/s]


Pretrain GPH [15/50] cost:32s train_loss: 29.78042


100%|██████████| 781/781 [00:30<00:00, 26.00it/s]


Pretrain GPH [16/50] cost:30s train_loss: 29.12757


100%|██████████| 781/781 [00:31<00:00, 24.53it/s]


Pretrain GPH [17/50] cost:31s train_loss: 29.01195


100%|██████████| 781/781 [00:31<00:00, 24.56it/s]


Pretrain GPH [18/50] cost:31s train_loss: 28.63931


100%|██████████| 781/781 [00:30<00:00, 25.21it/s]


Pretrain GPH [19/50] cost:30s train_loss: 28.35382


100%|██████████| 781/781 [00:31<00:00, 24.45it/s]


Pretrain GPH [20/50] cost:31s train_loss: 28.14558


100%|██████████| 781/781 [00:31<00:00, 25.01it/s]


Pretrain GPH [21/50] cost:31s train_loss: 26.23186


100%|██████████| 781/781 [00:31<00:00, 24.56it/s]


Pretrain GPH [22/50] cost:31s train_loss: 26.07834


100%|██████████| 781/781 [00:31<00:00, 24.75it/s]


Pretrain GPH [23/50] cost:31s train_loss: 26.02798


100%|██████████| 781/781 [00:33<00:00, 23.51it/s]


Pretrain GPH [24/50] cost:33s train_loss: 26.00646


100%|██████████| 781/781 [00:31<00:00, 24.46it/s]


Pretrain GPH [25/50] cost:31s train_loss: 25.93066


100%|██████████| 781/781 [00:32<00:00, 24.12it/s]


Pretrain GPH [26/50] cost:32s train_loss: 25.88453


100%|██████████| 781/781 [00:31<00:00, 25.04it/s]


Pretrain GPH [27/50] cost:31s train_loss: 25.82839


100%|██████████| 781/781 [00:30<00:00, 25.81it/s]


Pretrain GPH [28/50] cost:30s train_loss: 25.79395


100%|██████████| 781/781 [00:30<00:00, 25.62it/s]


Pretrain GPH [29/50] cost:30s train_loss: 25.89359


100%|██████████| 781/781 [00:31<00:00, 24.50it/s]


Pretrain GPH [30/50] cost:31s train_loss: 25.75082


100%|██████████| 781/781 [00:30<00:00, 25.54it/s]


Pretrain GPH [31/50] cost:30s train_loss: 25.71268


100%|██████████| 781/781 [00:31<00:00, 24.58it/s]


Pretrain GPH [32/50] cost:31s train_loss: 25.73016


100%|██████████| 781/781 [00:32<00:00, 24.25it/s]


Pretrain GPH [33/50] cost:32s train_loss: 25.66151


100%|██████████| 781/781 [00:31<00:00, 24.41it/s]


Pretrain GPH [34/50] cost:32s train_loss: 25.60192


100%|██████████| 781/781 [00:31<00:00, 25.06it/s]


Pretrain GPH [35/50] cost:31s train_loss: 25.57385


100%|██████████| 781/781 [00:30<00:00, 25.27it/s]


Pretrain GPH [36/50] cost:30s train_loss: 25.58441


100%|██████████| 781/781 [00:32<00:00, 24.33it/s]


Pretrain GPH [37/50] cost:32s train_loss: 25.60092


100%|██████████| 781/781 [00:31<00:00, 25.11it/s]


Pretrain GPH [38/50] cost:31s train_loss: 25.53858


100%|██████████| 781/781 [00:32<00:00, 23.98it/s]


Pretrain GPH [39/50] cost:32s train_loss: 25.47760


100%|██████████| 781/781 [00:29<00:00, 26.65it/s]


Pretrain GPH [40/50] cost:29s train_loss: 25.46589


100%|██████████| 781/781 [00:32<00:00, 24.11it/s]


Pretrain GPH [41/50] cost:32s train_loss: 25.17490


100%|██████████| 781/781 [00:31<00:00, 24.81it/s]


Pretrain GPH [42/50] cost:31s train_loss: 25.18003


100%|██████████| 781/781 [00:32<00:00, 24.37it/s]


Pretrain GPH [43/50] cost:32s train_loss: 25.16581


100%|██████████| 781/781 [00:32<00:00, 24.08it/s]


Pretrain GPH [44/50] cost:32s train_loss: 25.16347


100%|██████████| 781/781 [00:32<00:00, 24.31it/s]


Pretrain GPH [45/50] cost:32s train_loss: 25.15883


100%|██████████| 781/781 [00:28<00:00, 26.93it/s]


Pretrain GPH [46/50] cost:29s train_loss: 25.14436


100%|██████████| 781/781 [00:30<00:00, 25.24it/s]


Pretrain GPH [47/50] cost:30s train_loss: 25.14160


100%|██████████| 781/781 [00:31<00:00, 24.89it/s]


Pretrain GPH [48/50] cost:31s train_loss: 25.13789


100%|██████████| 781/781 [00:28<00:00, 27.03it/s]


Pretrain GPH [49/50] cost:28s train_loss: 25.14022


100%|██████████| 781/781 [00:31<00:00, 24.59it/s]


Pretrain GPH [50/50] cost:31s train_loss: 25.13566
完整模型训练：


100%|██████████| 533/533 [00:25<00:00, 20.80it/s]


[1/50] Train cost:25s loss: 559.11410/18.65412   ||  Valid cost:2s loss: 119.34774/5.14872


100%|██████████| 533/533 [00:25<00:00, 20.52it/s]


[2/50] Train cost:26s loss: 435.81019/18.94388   ||  Valid cost:2s loss: 108.99391/4.75368


100%|██████████| 533/533 [00:24<00:00, 21.63it/s]


[3/50] Train cost:24s loss: 402.93908/18.74506   ||  Valid cost:2s loss: 98.77228/5.18930


100%|██████████| 533/533 [00:25<00:00, 20.87it/s]


[4/50] Train cost:25s loss: 377.90703/18.86163   ||  Valid cost:2s loss: 118.97264/5.43147


100%|██████████| 533/533 [00:26<00:00, 20.09it/s]


[5/50] Train cost:26s loss: 371.51568/19.00432   ||  Valid cost:2s loss: 98.25122/4.82728


100%|██████████| 533/533 [00:24<00:00, 21.42it/s]


[6/50] Train cost:24s loss: 379.53400/19.80814   ||  Valid cost:2s loss: 96.81511/5.65936


100%|██████████| 533/533 [00:24<00:00, 21.72it/s]


[7/50] Train cost:24s loss: 381.56546/21.34414   ||  Valid cost:2s loss: 96.08084/5.43859


100%|██████████| 533/533 [00:25<00:00, 21.12it/s]


[8/50] Train cost:25s loss: 362.40679/20.07224   ||  Valid cost:2s loss: 86.93305/4.96185


100%|██████████| 533/533 [00:24<00:00, 21.72it/s]


[9/50] Train cost:24s loss: 353.57080/19.87502   ||  Valid cost:2s loss: 94.49540/5.59558


100%|██████████| 533/533 [00:26<00:00, 20.10it/s]


[10/50] Train cost:26s loss: 360.26518/20.09311   ||  Valid cost:2s loss: 98.34340/5.43283


100%|██████████| 533/533 [00:26<00:00, 20.33it/s]


[11/50] Train cost:26s loss: 356.20980/19.99230   ||  Valid cost:2s loss: 90.48602/5.24404


100%|██████████| 533/533 [00:26<00:00, 20.13it/s]


[12/50] Train cost:26s loss: 369.92601/20.52871   ||  Valid cost:2s loss: 90.97051/5.08867


100%|██████████| 533/533 [00:26<00:00, 20.14it/s]


[13/50] Train cost:26s loss: 362.17993/19.98575   ||  Valid cost:2s loss: 93.06680/5.32694


100%|██████████| 533/533 [00:26<00:00, 20.06it/s]


[14/50] Train cost:26s loss: 350.55437/19.51361   ||  Valid cost:2s loss: 88.33535/4.87196


100%|██████████| 533/533 [00:28<00:00, 18.70it/s]


[15/50] Train cost:28s loss: 356.23071/19.65060   ||  Valid cost:2s loss: 91.45022/5.04342


100%|██████████| 533/533 [00:26<00:00, 20.35it/s]


[16/50] Train cost:26s loss: 346.34349/19.23603   ||  Valid cost:2s loss: 88.24498/5.08484


100%|██████████| 533/533 [00:26<00:00, 20.00it/s]


[17/50] Train cost:26s loss: 346.52052/19.11763   ||  Valid cost:2s loss: 87.49853/4.83207


100%|██████████| 533/533 [00:25<00:00, 21.16it/s]


[18/50] Train cost:25s loss: 345.69718/19.66203   ||  Valid cost:2s loss: 90.86036/5.43894


100%|██████████| 533/533 [00:27<00:00, 19.61it/s]


[19/50] Train cost:27s loss: 355.68341/19.76899   ||  Valid cost:2s loss: 88.71165/5.07538


100%|██████████| 533/533 [00:27<00:00, 19.12it/s]


[20/50] Train cost:27s loss: 354.34732/19.63757   ||  Valid cost:2s loss: 87.20139/4.87938


100%|██████████| 533/533 [00:25<00:00, 20.99it/s]


[21/50] Train cost:25s loss: 342.30375/19.31422   ||  Valid cost:2s loss: 86.27868/4.81980


100%|██████████| 533/533 [00:26<00:00, 20.20it/s]


[22/50] Train cost:26s loss: 341.18865/19.19796   ||  Valid cost:2s loss: 86.66946/4.79352


100%|██████████| 533/533 [00:26<00:00, 20.09it/s]


[23/50] Train cost:26s loss: 340.49559/19.10856   ||  Valid cost:2s loss: 85.67345/4.76023


100%|██████████| 533/533 [00:24<00:00, 21.33it/s]


[24/50] Train cost:25s loss: 338.90223/19.03979   ||  Valid cost:2s loss: 85.05439/4.75034


100%|██████████| 533/533 [00:24<00:00, 21.36it/s]


[25/50] Train cost:24s loss: 336.85323/18.93943   ||  Valid cost:2s loss: 85.27398/4.73420


100%|██████████| 533/533 [00:23<00:00, 22.41it/s]


[26/50] Train cost:23s loss: 335.45129/18.89836   ||  Valid cost:2s loss: 84.69013/4.72048


100%|██████████| 533/533 [00:25<00:00, 20.65it/s]


[27/50] Train cost:25s loss: 334.75251/18.88223   ||  Valid cost:2s loss: 84.86921/4.71979


100%|██████████| 533/533 [00:29<00:00, 18.27it/s]


[28/50] Train cost:29s loss: 336.02459/18.85852   ||  Valid cost:2s loss: 85.53596/4.71882


100%|██████████| 533/533 [00:25<00:00, 20.51it/s]


[29/50] Train cost:26s loss: 335.08689/18.82129   ||  Valid cost:2s loss: 84.47006/4.69919


100%|██████████| 533/533 [00:24<00:00, 21.54it/s]


[30/50] Train cost:24s loss: 333.87880/18.73396   ||  Valid cost:2s loss: 84.03113/4.67814


100%|██████████| 533/533 [00:26<00:00, 19.89it/s]


[31/50] Train cost:26s loss: 333.47178/18.68336   ||  Valid cost:2s loss: 84.07872/4.67331


100%|██████████| 533/533 [00:26<00:00, 19.99it/s]


[32/50] Train cost:26s loss: 333.41050/18.67204   ||  Valid cost:2s loss: 83.87197/4.66498


100%|██████████| 533/533 [00:23<00:00, 22.52it/s]


[33/50] Train cost:23s loss: 334.22247/18.63151   ||  Valid cost:2s loss: 84.92620/4.65854


100%|██████████| 533/533 [00:27<00:00, 19.66it/s]


[34/50] Train cost:27s loss: 334.74379/18.62947   ||  Valid cost:2s loss: 85.38857/4.66197


100%|██████████| 533/533 [00:27<00:00, 19.56it/s]


[35/50] Train cost:27s loss: 336.23806/18.62579   ||  Valid cost:2s loss: 85.69870/4.65805


100%|██████████| 533/533 [00:26<00:00, 20.16it/s]


[36/50] Train cost:26s loss: 335.38086/18.64099   ||  Valid cost:2s loss: 84.51762/4.64698


100%|██████████| 533/533 [00:27<00:00, 19.51it/s]


[37/50] Train cost:27s loss: 334.03922/18.58272   ||  Valid cost:2s loss: 83.93797/4.63326


100%|██████████| 533/533 [00:25<00:00, 20.58it/s]


[38/50] Train cost:25s loss: 332.56251/18.53450   ||  Valid cost:2s loss: 83.59858/4.62524


100%|██████████| 533/533 [00:25<00:00, 20.89it/s]


[39/50] Train cost:25s loss: 331.72054/18.51644   ||  Valid cost:2s loss: 83.59698/4.62269


100%|██████████| 533/533 [00:26<00:00, 19.78it/s]


[40/50] Train cost:26s loss: 330.64805/18.49096   ||  Valid cost:2s loss: 83.49191/4.61527


100%|██████████| 533/533 [00:25<00:00, 20.65it/s]


[41/50] Train cost:25s loss: 328.81931/18.45371   ||  Valid cost:2s loss: 83.55405/4.61037


100%|██████████| 533/533 [00:26<00:00, 20.14it/s]


[42/50] Train cost:26s loss: 329.47150/18.45405   ||  Valid cost:2s loss: 83.74137/4.61280


100%|██████████| 533/533 [00:25<00:00, 21.27it/s]


[43/50] Train cost:25s loss: 328.37705/18.43438   ||  Valid cost:2s loss: 83.68357/4.60503


100%|██████████| 533/533 [00:26<00:00, 20.07it/s]


[44/50] Train cost:26s loss: 328.64273/18.45502   ||  Valid cost:2s loss: 83.38757/4.61322


100%|██████████| 533/533 [00:26<00:00, 20.33it/s]


[45/50] Train cost:26s loss: 329.24910/18.43787   ||  Valid cost:2s loss: 83.43726/4.61370


100%|██████████| 533/533 [00:24<00:00, 21.65it/s]


[46/50] Train cost:24s loss: 328.52354/18.44124   ||  Valid cost:2s loss: 83.76014/4.60869


100%|██████████| 533/533 [00:27<00:00, 19.65it/s]


[47/50] Train cost:27s loss: 328.17828/18.43557   ||  Valid cost:2s loss: 83.10671/4.60627


100%|██████████| 533/533 [00:27<00:00, 19.30it/s]


[48/50] Train cost:27s loss: 328.86490/18.43950   ||  Valid cost:2s loss: 83.21167/4.60438


100%|██████████| 533/533 [00:26<00:00, 20.35it/s]


[49/50] Train cost:26s loss: 328.76638/18.43690   ||  Valid cost:2s loss: 83.04593/4.60224


100%|██████████| 533/533 [00:27<00:00, 19.71it/s]


[50/50] Train cost:27s loss: 328.11462/18.43408   ||  Valid cost:2s loss: 83.54688/4.60808
训练完成
测试集上预测：


100%|██████████| 86/86 [00:01<00:00, 53.18it/s]


6h APE： 38.589675380111544
12h APE： 81.67666828407347
18h APE： 131.1015203232273
24h APE： 188.08918189553893


66

In [7]:
def test_2018_2021(cfg1=TC_Encoder_config, cfg2=Map_Encoder_config, checkpoint_folder='./model/',
                   xtc_path='./CMA_dataset/xtc.npy', ytc_path='./CMA_dataset/ytc.npy', gph_path='./gph.npy',
                   batch_size=32, is_save=True, checkpoint='./model/model_demo.pth'):
    warnings.filterwarnings("ignore")  # 关闭warning

    # 载入完整数据集
    print('载入完整数据集...')
    xtc = np.load(xtc_path, allow_pickle=True).astype(float)
    ytc = np.load(ytc_path, allow_pickle=True).astype(float)
    gph = np.load(gph_path, allow_pickle=True).astype(float)
    test_index = [*range(22227, 24614)] # 2018年至2021年的台风数据索引
    standard = xtc[:, -1, :2].reshape(-1, 1, 2)

    # 归一化
    xtc, _ = data_scaler(xtc)
    gph, _ = data_scaler(gph)

    # 打包数据集
    test_dataset = DataLoader(TrainLoader(xtc[test_index], ytc[test_index], gph[test_index], standard[test_index]),
                              batch_size=batch_size, shuffle=False)
    del xtc, ytc, standard, gph

    model = DBFNet(cfg1, cfg2)
    if torch.cuda.is_available():
        model = model.cuda()
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # 在测试集上预测
    print('测试集上预测：')
    model.eval()
    pred_save, truth_save, standard_save = [], [], []
    for X1, X2, Y1, _, S in tqdm(test_dataset):
        if torch.cuda.is_available():
            x1 = X1.float().cuda()
            x2 = X2.float().cuda()
            y1 = Y1.float().cuda()
            s = S.float().cuda()
        else:
            x1 = X1.float()
            x2 = X2.float()
            y1 = Y1.float()
            s = S.float()
        pred, _ = model(x1, x2)
        pred_save.append(pred.detach().cpu().numpy())
        truth_save.append(y1.detach().cpu().numpy())
        standard_save.append(s.detach().cpu().numpy())
    # 做反归一化
    pred_save = np.concatenate(pred_save, axis=0)
    truth_save = np.concatenate(truth_save, axis=0)
    standard_save = np.concatenate(standard_save, axis=0)
    # pred_save = data_scaler_inverse(pred_save, y_scaler)
    # 计算最终预测结果
    pred = pred_save + standard_save
    truth = truth_save + standard_save
    for i in range(1, 5):
        pred_lat, pred_lon, truth_lat, truth_lon = pred[:, i - 1, 0], pred[:, i - 1, 1], truth[:, i - 1, 0], truth[:,
                                                                                                             i - 1, 1]
        distance = compute_distance_km(pred_lat, pred_lon, truth_lat, truth_lon)
        print(str(i * 6) + 'h APE：', distance.mean())

In [8]:
test_2018_2021()

载入完整数据集...
测试集上预测：


100%|██████████| 75/75 [00:01<00:00, 66.78it/s]

6h APE： 39.73501210493949
12h APE： 83.73567743841966
18h APE： 133.86184621412266
24h APE： 190.3670649653396



