In [1]:
import time, os, gc
import torch, torchvision
from torch.autograd import Variable
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import xarray as xr
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import datetime

from other_models.Other_Models import *
import warnings

### 计算距离函数

In [2]:
def compute_distance_km(lat1, lon1, lat2, lon2):
    """Great-circle (haversine) distance in kilometres between pairs of points.

    Args:
        lat1, lon1: Latitude / longitude of the first point(s), in degrees.
        lat2, lon2: Latitude / longitude of the second point(s), in degrees.
            Objects exposing ``astype`` (e.g. NumPy arrays) are cast with it,
            exactly as before; anything else (Python scalars, lists, tuples)
            is converted through ``np.asarray``. NumPy broadcasting applies.

    Returns:
        Distance(s) in kilometres, with the broadcast shape of the inputs.
    """

    def _as_float(value):
        # Keep the original `.astype` path for array-likes, add a fallback for plain scalars/sequences.
        if hasattr(value, 'astype'):
            return value.astype('float')
        return np.asarray(value, dtype=float)

    lat1, lon1, lat2, lon2 = _as_float(lat1), _as_float(lon1), _as_float(lat2), _as_float(lon2)
    # Vectorised spherical distance between two points on Earth.
    R = 6371e3  # Earth radius (metres)
    phi_1, phi_2 = np.radians(lat1), np.radians(lat2)
    delta_phi = np.radians(lat2 - lat1)
    delta_lambda = np.radians(lon2 - lon1)
    a = np.power(np.sin(delta_phi / 2), 2) + np.cos(phi_1) * np.cos(phi_2) * np.power(np.sin(delta_lambda / 2), 2)
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would turn sqrt(1 - a) into NaN.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return R * c / 1000.  # metres -> kilometres

### 数据集相关类和函数

In [3]:
def data_scalar(x, scaler=None, inverse=False):
    """Min-max scale an array (or undo the scaling) while preserving its shape.

    The array is flattened to ``(x.shape[0], -1)`` before being handed to the
    scaler, so every trailing position is scaled as an independent column,
    and is reshaped back to the original shape afterwards.

    Args:
        x: Array exposing ``shape`` and ``reshape``; axis 0 is the sample axis.
        scaler: Optional already-fitted scaler. When given it is reused
            (``transform`` or ``inverse_transform``) instead of being discarded
            and re-fitted. When omitted a new ``MinMaxScaler`` is fitted on ``x``.
        inverse: If True, map scaled values back with ``scaler.inverse_transform``.

    Returns:
        ``(array, scaler)`` - the transformed array and the scaler that was used.

    Raises:
        ValueError: If ``inverse=True`` is requested without a scaler; there is
            nothing to invert in that case.
    """
    shape = x.shape
    flat = x.reshape(shape[0], -1)
    if inverse:
        if scaler is None:
            raise ValueError('inverse=True requires a fitted scaler')
        flat = scaler.inverse_transform(flat)
    elif scaler is not None:
        # Reuse the supplied statistics rather than silently re-fitting on `x`.
        flat = scaler.transform(flat)
    else:
        scaler = MinMaxScaler()
        flat = scaler.fit_transform(flat)
    return flat.reshape(shape), scaler


class RNN_TrainLoader(Dataset):
    """Map-style dataset serving aligned ``(x, y, standard)`` triples by index."""

    def __init__(self, x, y, standard):
        # The three containers are stored untouched and indexed with the same key later on.
        self.x, self.y, self.s = x, y, standard

    def __len__(self):
        # The dataset size is taken from `x` alone.
        return len(self.x)

    def __getitem__(self, item):
        # Same order as the constructor arguments: input, target, standard.
        return tuple(field[item] for field in (self.x, self.y, self.s))


class CNN_TrainLoader(Dataset):
    """Map-style dataset serving aligned ``(x, y, map, standard)`` 4-tuples by index."""

    def __init__(self, x, y, map, standard):
        # The four containers are stored untouched and indexed with the same key later on.
        self.x, self.y, self.m, self.s = x, y, map, standard

    def __len__(self):
        # The dataset size is taken from `x` alone.
        return len(self.x)

    def __getitem__(self, item):
        # Same order as the constructor arguments: input, target, map, standard.
        return tuple(field[item] for field in (self.x, self.y, self.m, self.s))

### RNN系列模型训练函数

In [4]:
def train_RNN(type, epoch=100, lr=0.01, batch_size=32, xtc_path='./CMA_dataset/xtc.npy',
               ytc_path='./CMA_dataset/ytc.npy', split_index_path='./CMA_dataset/split_index.npy',
               model_save_path='./other_models/checkpoints/', random_state=None):
    """Train one recurrent baseline, keep its best-validation weights and print test errors.

    The function loads the input/target arrays, min-max scales the inputs,
    splits the training range 80/20 into train/validation, trains with L1 loss
    and Adam, checkpoints whenever the summed validation loss improves, then
    reloads the best checkpoint and prints the mean great-circle error (km) for
    the first four predicted steps (labelled 6h, 12h, 18h, 24h).

    Args:
        type: One of 'LSTM', 'BiLSTM', 'GRU', 'BiGRU'.
        epoch: Number of training epochs.
        lr: Adam learning rate.
        batch_size: Mini-batch size for all three loaders.
        xtc_path: ``.npy`` file with the model inputs.
        ytc_path: ``.npy`` file with the targets.
        split_index_path: ``.npy`` file with split boundaries; indices
            ``[0]..[2]`` delimit the train+val range and ``[2]..[3]`` the test range.
        model_save_path: Prefix to which ``'<type>_demo.pth'`` is appended.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            train/validation split can be reproduced.

    Raises:
        ValueError: If ``type`` is not one of the supported model names.
    """
    if type == 'LSTM':
        model = LSTM_Model()
    elif type == 'BiLSTM':
        model = BiLSTM_Model()
    elif type == 'GRU':
        model = GRU_Model()
    elif type == 'BiGRU':
        model = BiGRU_Model()
    else:
        # Previously an unknown name left `model` unbound and failed later with an unrelated error.
        raise ValueError("Unknown RNN type %r; expected 'LSTM', 'BiLSTM', 'GRU' or 'BiGRU'" % (type,))
    model_save_path = model_save_path + type + '_demo.pth'
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        # torch.save does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # NOTE(review): allow_pickle=True can execute arbitrary code from a crafted file;
    # only load datasets from trusted locations.
    xtc = np.load(xtc_path, allow_pickle=True).astype(float)
    ytc = np.load(ytc_path, allow_pickle=True).astype(float)
    # Unscaled first two features of the last input step; added back to predictions and
    # targets before measuring distances (targets are presumably offsets from it - TODO confirm).
    standard = xtc[:, -1, :2].reshape(-1, 1, 2)
    # NOTE(review): the scaler is fitted on every sample, test range included.
    xtc, _ = data_scalar(xtc)

    split_index = np.load(split_index_path, allow_pickle=True)
    train_index = [*range(split_index[0], split_index[2])]
    test_index = [*range(split_index[2], split_index[3])]
    train_index, val_index = train_test_split(train_index, test_size=0.2, random_state=random_state)

    rnn_train_dataset = DataLoader(RNN_TrainLoader(xtc[train_index, :, :2], ytc[train_index], standard[train_index]),
                                   batch_size=batch_size, shuffle=True, drop_last=True)
    # Validation and test see every sample in a fixed order, so the per-epoch validation
    # losses used for checkpoint selection are computed on the same data each time.
    rnn_val_dataset = DataLoader(RNN_TrainLoader(xtc[val_index, :, :2], ytc[val_index], standard[val_index]),
                                 batch_size=batch_size, shuffle=False)
    rnn_test_dataset = DataLoader(RNN_TrainLoader(xtc[test_index, :, :2], ytc[test_index], standard[test_index]),
                                  batch_size=batch_size, shuffle=False)

    # torch.nn is referenced explicitly instead of relying on `nn` leaking in via a wildcard import.
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    min_val_loss = float('inf')
    checkpoint_written = False
    for i in range(epoch):
        model.train()
        train_loss = 0
        for x, y, _ in rnn_train_dataset:
            x = x.float().to(device)
            y = y.float().to(device)
            p = model(x)
            loss = criterion(p, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        model.eval()
        val_loss = 0
        with torch.no_grad():  # no autograd graph is needed for validation
            for x, y, _ in rnn_val_dataset:
                x = x.float().to(device)
                y = y.float().to(device)
                p = model(x)
                val_loss += criterion(p, y).item()
        if val_loss < min_val_loss:
            torch.save(model.state_dict(), model_save_path)
            min_val_loss = val_loss
            checkpoint_written = True
        print('[%d/%d] Train Loss: %.5f / Val Loss: %.5f' % (i + 1, epoch, train_loss, val_loss))

    if checkpoint_written:
        model.load_state_dict(torch.load(model_save_path, map_location=device))
    # Otherwise (epoch == 0, or the loss never became comparable) evaluate the in-memory
    # weights rather than loading a stale file left behind by an earlier run.

    model.eval()
    with torch.no_grad():
        P, T, S = [], [], []
        for x, y, s in rnn_test_dataset:
            p = model(x.float().to(device))
            P.append(p.detach().cpu().numpy())
            S.append(s.float().numpy())
            T.append(y.float().numpy())
    S = np.concatenate(S, axis=0)
    P = np.concatenate(P, axis=0) + S
    T = np.concatenate(T, axis=0) + S

    # Feature 0 is passed as latitude and feature 1 as longitude.
    for i in range(4):
        d = compute_distance_km(P[:, i, 0], P[:, i, 1], T[:, i, 0], T[:, i, 1])
        d = d.mean()
        print('%dh: %.5f' % ((i + 1) * 6, d))

### CNN模型训练函数

In [9]:
def train_CNN(epoch=100, lr=0.01, batch_size=32, xtc_path='./CMA_dataset/xtc.npy',
               ytc_path='./CMA_dataset/ytc.npy', split_index_path='./CMA_dataset/split_index.npy',
               map_path='./gph.npy', model_save_path='./other_models/checkpoints/', random_state=None):
    """Train the CNN baseline, keep its best-validation weights and print test errors.

    Mirrors ``train_RNN`` but feeds the model two inputs: the scaled sequence
    array and a scaled map array loaded from ``map_path``. After training, the
    best checkpoint is reloaded and the mean great-circle error (km) of the
    first four predicted steps (labelled 6h, 12h, 18h, 24h) is printed, together
    with the first 20 prediction/target rows before and after the reference
    position is added back.

    Args:
        epoch: Number of training epochs.
        lr: Adam learning rate.
        batch_size: Mini-batch size for all three loaders.
        xtc_path: ``.npy`` file with the sequence inputs.
        ytc_path: ``.npy`` file with the targets.
        split_index_path: ``.npy`` file with split boundaries; indices
            ``[0]..[2]`` delimit the train+val range and ``[2]..[3]`` the test range.
        map_path: ``.npy`` file with the map inputs (same sample order as ``xtc``).
        model_save_path: Prefix to which ``'CNN_demo.pth'`` is appended.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            train/validation split can be reproduced.
    """
    model = CNN_Model()
    model_save_path = model_save_path + 'CNN_demo.pth'
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        # torch.save does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # NOTE(review): allow_pickle=True can execute arbitrary code from a crafted file;
    # only load datasets from trusted locations.
    xtc = np.load(xtc_path, allow_pickle=True).astype(float)
    ytc = np.load(ytc_path, allow_pickle=True).astype(float)
    # Unscaled first two features of the last input step; added back to predictions and
    # targets before measuring distances (targets are presumably offsets from it - TODO confirm).
    standard = xtc[:, -1, :2].reshape(-1, 1, 2)
    # NOTE(review): both scalers are fitted on every sample, test range included.
    xtc, _ = data_scalar(xtc)
    xmap = np.load(map_path, allow_pickle=True).astype(float)
    xmap, _ = data_scalar(xmap)

    split_index = np.load(split_index_path, allow_pickle=True)
    train_index = [*range(split_index[0], split_index[2])]
    test_index = [*range(split_index[2], split_index[3])]
    train_index, val_index = train_test_split(train_index, test_size=0.2, random_state=random_state)

    # The sequence input drops its first step (1:) and the map input its first four
    # entries along axis 1 (4:); the reason is not visible here - TODO confirm with CNN_Model.
    cnn_train_dataset = DataLoader(
        CNN_TrainLoader(xtc[train_index, 1:, :2], ytc[train_index], xmap[train_index, 4:], standard[train_index]),
        batch_size=batch_size, shuffle=True, drop_last=True)
    # Validation and test see every sample in a fixed order, so the per-epoch validation
    # losses used for checkpoint selection are computed on the same data each time.
    cnn_val_dataset = DataLoader(
        CNN_TrainLoader(xtc[val_index, 1:, :2], ytc[val_index], xmap[val_index, 4:], standard[val_index]),
        batch_size=batch_size, shuffle=False)
    cnn_test_dataset = DataLoader(
        CNN_TrainLoader(xtc[test_index, 1:, :2], ytc[test_index], xmap[test_index, 4:], standard[test_index]),
        batch_size=batch_size, shuffle=False)
    # The loaders hold their own indexed copies; release the full arrays.
    del xtc, ytc, xmap

    # torch.nn is referenced explicitly instead of relying on `nn` leaking in via a wildcard import.
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    min_val_loss = float('inf')
    checkpoint_written = False
    for i in range(epoch):
        model.train()
        train_loss = 0
        for x, y, m, _ in cnn_train_dataset:
            x = x.float().to(device)
            y = y.float().to(device)
            m = m.float().to(device)
            p = model(x, m)
            loss = criterion(p, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        model.eval()
        val_loss = 0
        with torch.no_grad():  # no autograd graph is needed for validation
            for x, y, m, _ in cnn_val_dataset:
                x = x.float().to(device)
                y = y.float().to(device)
                m = m.float().to(device)
                p = model(x, m)
                val_loss += criterion(p, y).item()
        if val_loss < min_val_loss:
            torch.save(model.state_dict(), model_save_path)
            min_val_loss = val_loss
            checkpoint_written = True
        print('[%d/%d] Train Loss: %.5f / Val Loss: %.5f' % (i + 1, epoch, train_loss, val_loss))

    if checkpoint_written:
        model.load_state_dict(torch.load(model_save_path, map_location=device))
    # Otherwise (epoch == 0, or the loss never became comparable) evaluate the in-memory
    # weights rather than loading a stale file left behind by an earlier run.

    model.eval()
    with torch.no_grad():
        P, S, T = [], [], []
        for x, y, m, s in cnn_test_dataset:
            p = model(x.float().to(device), m.float().to(device))
            P.append(p.detach().cpu().numpy())
            S.append(s.float().numpy())
            T.append(y.float().numpy())
    P = np.concatenate(P, axis=0)
    T = np.concatenate(T, axis=0)
    S = np.concatenate(S, axis=0)
    # Raw model outputs next to raw targets (first 20 samples).
    print(np.concatenate([P, T], axis=2)[:20])
    P = P + S
    T = T + S
    # Same rows after adding the reference position back.
    print(np.concatenate([P, T], axis=2)[:20])

    # Feature 0 is passed as latitude and feature 1 as longitude.
    for i in range(4):
        d = compute_distance_km(P[:, i, 0], P[:, i, 1], T[:, i, 0], T[:, i, 1])
        d = d.mean()
        print('%dh: %.5f' % ((i + 1) * 6, d))

In [6]:
# Train and evaluate the LSTM baseline with the default hyper-parameters and paths.
train_RNN(type='LSTM')
# Force a garbage-collection pass after the run; as the cell's last expression,
# its return value is what the notebook displays.
gc.collect()

[1/100] Train Loss: 788.64592 / Val Loss: 188.56736
[2/100] Train Loss: 756.33360 / Val Loss: 184.05788
[3/100] Train Loss: 743.71184 / Val Loss: 182.59880
[4/100] Train Loss: 741.13760 / Val Loss: 182.59105
[5/100] Train Loss: 739.12292 / Val Loss: 182.12083
[6/100] Train Loss: 737.27419 / Val Loss: 183.26221
[7/100] Train Loss: 735.97147 / Val Loss: 181.33246
[8/100] Train Loss: 734.94627 / Val Loss: 181.63880
[9/100] Train Loss: 734.39018 / Val Loss: 181.09516
[10/100] Train Loss: 733.66704 / Val Loss: 181.27636
[11/100] Train Loss: 732.08484 / Val Loss: 180.79805
[12/100] Train Loss: 732.00015 / Val Loss: 180.70709
[13/100] Train Loss: 732.03595 / Val Loss: 180.54508
[14/100] Train Loss: 731.44854 / Val Loss: 180.34020
[15/100] Train Loss: 730.86449 / Val Loss: 180.52571
[16/100] Train Loss: 730.72565 / Val Loss: 180.51903
[17/100] Train Loss: 729.95131 / Val Loss: 181.65755
[18/100] Train Loss: 729.95904 / Val Loss: 180.74573
[19/100] Train Loss: 730.55177 / Val Loss: 181.03675
[2

0

In [7]:
# Train and evaluate the BiLSTM baseline with the default hyper-parameters and paths.
train_RNN(type='BiLSTM')
# Force a garbage-collection pass after the run; as the cell's last expression,
# its return value is what the notebook displays.
gc.collect()

[1/100] Train Loss: 770.78941 / Val Loss: 182.61424
[2/100] Train Loss: 739.75265 / Val Loss: 181.57361
[3/100] Train Loss: 730.11693 / Val Loss: 180.69195
[4/100] Train Loss: 729.70079 / Val Loss: 179.27196
[5/100] Train Loss: 715.84140 / Val Loss: 176.00143
[6/100] Train Loss: 707.05977 / Val Loss: 175.28598
[7/100] Train Loss: 705.95195 / Val Loss: 175.31580
[8/100] Train Loss: 704.10952 / Val Loss: 174.97755
[9/100] Train Loss: 702.83133 / Val Loss: 175.87199
[10/100] Train Loss: 702.08747 / Val Loss: 174.25660
[11/100] Train Loss: 702.37961 / Val Loss: 175.61442
[12/100] Train Loss: 698.38702 / Val Loss: 175.99973
[13/100] Train Loss: 697.62822 / Val Loss: 175.39034
[14/100] Train Loss: 697.30223 / Val Loss: 174.99462
[15/100] Train Loss: 697.15369 / Val Loss: 174.39792
[16/100] Train Loss: 696.94816 / Val Loss: 173.50287
[17/100] Train Loss: 694.11364 / Val Loss: 175.14944
[18/100] Train Loss: 693.90701 / Val Loss: 173.77728
[19/100] Train Loss: 705.02688 / Val Loss: 172.24169
[2

21

In [8]:
# Train and evaluate the GRU baseline with the default hyper-parameters and paths.
train_RNN(type='GRU')
# Force a garbage-collection pass after the run; as the cell's last expression,
# its return value is what the notebook displays.
gc.collect()

[1/100] Train Loss: 786.24531 / Val Loss: 187.48085
[2/100] Train Loss: 747.26189 / Val Loss: 179.49012
[3/100] Train Loss: 752.25426 / Val Loss: 179.84745
[4/100] Train Loss: 744.31030 / Val Loss: 183.58123
[5/100] Train Loss: 748.29380 / Val Loss: 186.42392
[6/100] Train Loss: 740.12223 / Val Loss: 179.19080
[7/100] Train Loss: 759.20620 / Val Loss: 181.84983
[8/100] Train Loss: 782.40994 / Val Loss: 193.12387
[9/100] Train Loss: 750.33961 / Val Loss: 181.47858
[10/100] Train Loss: 733.72164 / Val Loss: 179.63826
[11/100] Train Loss: 733.25728 / Val Loss: 179.53201
[12/100] Train Loss: 730.90103 / Val Loss: 179.28404
[13/100] Train Loss: 738.33643 / Val Loss: 178.29329
[14/100] Train Loss: 730.56379 / Val Loss: 178.49702
[15/100] Train Loss: 728.83941 / Val Loss: 178.81122
[16/100] Train Loss: 738.06842 / Val Loss: 178.44021
[17/100] Train Loss: 729.47479 / Val Loss: 177.92941
[18/100] Train Loss: 729.18426 / Val Loss: 178.56592
[19/100] Train Loss: 728.04100 / Val Loss: 177.59852
[2

21

In [9]:
# Train and evaluate the BiGRU baseline with the default hyper-parameters and paths.
train_RNN(type='BiGRU')
# Force a garbage-collection pass after the run; as the cell's last expression,
# its return value is what the notebook displays.
gc.collect()

[1/100] Train Loss: 751.74101 / Val Loss: 180.75455
[2/100] Train Loss: 732.12414 / Val Loss: 179.71873
[3/100] Train Loss: 727.02035 / Val Loss: 178.04756
[4/100] Train Loss: 718.80516 / Val Loss: 175.85080
[5/100] Train Loss: 715.33475 / Val Loss: 175.34014
[6/100] Train Loss: 708.07768 / Val Loss: 176.45782
[7/100] Train Loss: 703.36910 / Val Loss: 174.68357
[8/100] Train Loss: 704.61016 / Val Loss: 172.36951
[9/100] Train Loss: 702.48134 / Val Loss: 172.06246
[10/100] Train Loss: 700.65669 / Val Loss: 173.50328
[11/100] Train Loss: 703.97978 / Val Loss: 173.89260
[12/100] Train Loss: 704.50710 / Val Loss: 171.86448
[13/100] Train Loss: 705.02984 / Val Loss: 172.19056
[14/100] Train Loss: 702.48981 / Val Loss: 172.05721
[15/100] Train Loss: 705.30167 / Val Loss: 175.85554
[16/100] Train Loss: 708.71713 / Val Loss: 171.87095
[17/100] Train Loss: 704.43779 / Val Loss: 172.10667
[18/100] Train Loss: 703.84274 / Val Loss: 171.62364
[19/100] Train Loss: 708.47512 / Val Loss: 184.82715
[2

21