In [2]:
import os
import os.path as osp
import json
from typing import Any
import torch
import pickle
import logging
import numpy as np
from model import Fourier_Model
from tqdm import tqdm
from API import dataloader
from utils import *
from timeit import default_timer
import torch
import os
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import argparse
import io

In [3]:
class Tester:
    """Run a trained model over a test loader, save its outputs and report RMSE.

    Args:
        model: ``torch.nn.Module`` whose forward pass returns a 2-tuple; the
            first element is used as the prediction, the second is ignored.
        test_loader: iterable yielding ``(batch_x, batch_y)`` tensor pairs.
        path: base directory under which the result arrays are written.
    """

    def __init__(self, model, test_loader, path):
        self.model = model
        self.test_loader = test_loader
        self.path = path
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def load_checkpoint(self, checkpoint_path):
        """Load a state dict from ``checkpoint_path`` into the model.

        The file is read fully into memory before deserialisation. The model is
        moved to ``self.device`` afterwards.

        NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
        from a trusted source.
        """
        with open(checkpoint_path, 'rb') as f:
            buffer = io.BytesIO(f.read())

        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa) instead of failing at load time.
        checkpoint = torch.load(buffer, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.to(self.device)

    def test(self, checkpoint_path):
        """Evaluate the model and persist inputs, targets and predictions.

        Args:
            checkpoint_path: path of the state-dict file to restore.

        Returns:
            float: RMSE between all predictions and all targets.

        Raises:
            ValueError: if ``test_loader`` yields no batches.
        """
        self.load_checkpoint(checkpoint_path)
        self.model.eval()

        inputs_lst, trues_lst, preds_lst = [], [], []
        # Inference only: disabling autograd avoids building a graph per batch.
        with torch.no_grad():
            for batch_x, batch_y in self.test_loader:
                pred_y, _ = self.model(batch_x.to(self.device))
                inputs_lst.append(batch_x.detach().cpu().numpy())
                trues_lst.append(batch_y.detach().cpu().numpy())
                preds_lst.append(pred_y.detach().cpu().numpy())

        if not preds_lst:
            raise ValueError("test_loader yielded no batches; nothing to evaluate")

        # Explicit name -> array mapping (replaces a lookup through vars()).
        arrays = {
            'inputs': np.concatenate(inputs_lst, axis=0),
            'trues': np.concatenate(trues_lst, axis=0),
            'preds': np.concatenate(preds_lst, axis=0),
        }

        # NOTE(review): the '{}' below is never formatted, so the directory is
        # literally named '{}'. Kept unchanged because the intended value is
        # not visible here -- confirm and fill in the experiment name.
        folder_path = self.path+'/results/{}/sv/'
        os.makedirs(folder_path, exist_ok=True)

        # Compute the RMSE over the whole test set.
        squared_error = (torch.from_numpy(arrays['preds']) - torch.from_numpy(arrays['trues'])) ** 2
        rmse = torch.sqrt(torch.mean(squared_error)).item()

        for name, array in arrays.items():
            np.save(osp.join(folder_path, name + '.npy'), array)
        return rmse

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class TaxiBJDataset(Dataset):
    """Dataset pairing input and output sequences held in two arrays.

    Each item is ``(input, output)`` as float32 tensors. The input is
    subsampled with a stride of 4 along its third and fourth axes; the output
    is returned unchanged. ``mean`` and ``std`` of the input array are stored
    as attributes but are not applied to the samples.
    """

    def __init__(self, input_data, output_data):
        self.input_data, self.output_data = input_data, output_data
        # Global statistics of the inputs (kept for callers; unused here).
        self.mean, self.std = np.mean(input_data), np.std(input_data)

    def __len__(self):
        """Number of samples, taken from the input array."""
        return len(self.input_data)

    def __getitem__(self, idx):
        """Return the ``idx``-th (strided input, full output) tensor pair."""
        raw_in, raw_out = self.input_data[idx], self.output_data[idx]

        frames_in = torch.from_numpy(raw_in).to(torch.float32)
        frames_out = torch.from_numpy(raw_out).to(torch.float32)

        # Keep every 4th element on axes 2 and 3 of the input only.
        return frames_in[:, :, ::4, ::4], frames_out
    

def _load_taxibj_split(data_root, split):
    """Load the ``(input, output)`` arrays of one TaxiBJ split from disk."""
    folder = osp.join(data_root, 'TaxiBJ')
    inputs = np.load(osp.join(folder, '{}_input_data_reshaped.npy'.format(split)))
    outputs = np.load(osp.join(folder, '{}_output_data_reshaped.npy'.format(split)))
    return inputs, outputs


def load_data(batch_size, val_batch_size, data_root, num_workers):
    """Build the train / validation / test DataLoaders for TaxiBJ.

    Args:
        batch_size: batch size of the (shuffled) training loader.
        val_batch_size: batch size of the validation and test loaders.
        data_root: directory containing the ``TaxiBJ`` folder. A trailing
            path separator is accepted but no longer required.
        num_workers: worker process count passed to every DataLoader.

    Returns:
        tuple: ``(dataloader_train, dataloader_validation, dataloader_test,
        mean, std)`` where ``mean`` and ``std`` are the constants 0 and 1.
    """
    # NOTE(review): commented-out code that used to live here reshaped the
    # arrays to (N, 12, 2, 128, 128) -- presumably the stored layout; confirm.
    train_set = TaxiBJDataset(*_load_taxibj_split(data_root, 'train'))
    test_set = TaxiBJDataset(*_load_taxibj_split(data_root, 'test'))
    val_set = TaxiBJDataset(*_load_taxibj_split(data_root, 'val'))

    dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                                  pin_memory=True, num_workers=num_workers)
    dataloader_test = DataLoader(test_set, batch_size=val_batch_size, shuffle=False,
                                 pin_memory=True, num_workers=num_workers)
    dataloader_validation = DataLoader(val_set, batch_size=val_batch_size, shuffle=False,
                                       pin_memory=True, num_workers=num_workers)

    # No normalisation is applied, so identity statistics are reported.
    mean, std = 0, 1

    return dataloader_train, dataloader_validation, dataloader_test, mean, std

# if __name__ == '__main__':
#     dataloader_train, dataloader_validation, dataloader_test, mean, std = load_data(batch_size=1, 
#                                                                                     val_batch_size=1, 
#                                                                                     data_root='/data/workspace/yancheng/MM/earthfarseer/data/',
#                                                                                     num_workers=8)
#     for input_frames, output_frames in iter(dataloader_train):
#         print(input_frames.shape, output_frames.shape)

In [5]:
# Build the TaxiBJ loaders, one sample per batch for every split.
dataloader_train, dataloader_validation, dataloader_test, mean, std = load_data(
    batch_size=1,
    val_batch_size=1,
    data_root='/data/workspace/yancheng/MM/earthfarseer/data/',
    num_workers=8,
)


In [None]:
# Instantiate the network with the same hyper-parameters as the checkpoint
# that is restored below.
model = Fourier_Model(
    shape_in=[12, 2, 128, 128],
    hid_S=64,
    hid_T=256,
    N_S=2,
    N_T=4,
    groups=2,
)

test_loader = dataloader_test

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Base directory under which Tester writes its result arrays.
path = '/data/workspace/yancheng/MM/earthfarseer/results_super_res'

tester = Tester(model=model, test_loader=test_loader, path=path)
args = argparse.Namespace(
    checkpoint_path='/data/workspace/yancheng/MM/earthfarseer/results_super_res/results_super_res/checkpoint.pth'
)

# NOTE: despite its name, `mse` holds the RMSE returned by Tester.test.
mse = tester.test(args.checkpoint_path)
print(mse)