In [2]:
import numpy as np
import pandas as pd
import math
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from model import *
from optimizer import *
from utils import *



In [83]:
def evaluation(pred, threshold_prob, threshold_trigger, threshold_type):
    # pred: 模型預測結果, (batch_size, wave_length)
    
    # 存每個測站是否 pick 到的結果 & pick 到的時間點
    pred_isTrigger = []
    pred_trigger_sample = []
    
    for i in range(pred.shape[0]):
        isTrigger = False
        
        if threshold_type == 'avg':
            a = pd.Series(pred[i])    
            win_avg = a.rolling(window=threshold_trigger).mean().to_numpy()

            c = np.where(win_avg >= threshold_prob, 1, 0)

            pred_trigger = 0
            if c.any():
                tri = np.where(c==1)
                pred_trigger = tri[0][0]-threshold_trigger+1
                isTrigger = True

        elif threshold_type == 'continue':
            pred = np.where(pred[i] >= threshold_prob, 1, 0)
           
            a = pd.Series(pred)    
            data = a.groupby(a.eq(0).cumsum()).cumsum().tolist()
          
            if threshold_trigger in data:
                pred_trigger = data.index(threshold_trigger)-threshold_trigger+1
                isTrigger = True
            else:
                pred_trigger = 0

        pred_isTrigger.append(isTrigger)
        pred_trigger_sample.append(pred_trigger)
        
    return pred_isTrigger, pred_trigger_sample

In [107]:
def z_score_standardize(data):
    new_wave = torch.empty((data.shape))

    new_wave = data - torch.mean(data, dim=1)[:, None, :]
    new_wave /= torch.std(data, dim=1)[:, None, :]
    
    if torch.any(torch.isinf(new_wave)):
        new_wave[torch.isinf(new_wave)] = 0

    if torch.any(torch.isnan(new_wave)):
        new_wave[torch.isnan(new_wave)] = 0
        
    return new_wave

In [104]:
def predict(wave, model, device, threshold_prob, threshold_trigger, threshold_type):
    # zscore normalization, wave=(batch_size, wave_length, 3)
    wave = wave.permute(1,0).unsqueeze(0).to(device)
    wave = z_score_standardize(wave)
    
    # model inference, wave=(batch_size, wave_length, 3)
    out = model(wave, -1, -1).squeeze()
    
    # if batch_size == 1
    out = out.unsqueeze(0)
    
    # 將 output 套入 threshold 做 picking
    res, pred_trigger = evaluation(out, threshold_prob, threshold_trigger, threshold_type)
    
    return res, pred_trigger

### 測試區

In [66]:
# load model
device = 'cpu'
model = SingleP_Conformer(8, 256, 4, 4, False, False).to(device)
model_path = os.path.join("/mnt/nas4/weiwei/picking_p/results/conformer/", 'model.pt')
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [105]:
# load waveform
import glob
import os
path = glob.glob('/mnt/nas3/earthquake_dataset_large/Palert/*.pt')
wave = torch.load(path[10])

In [108]:
# start predicting
out = predict(wave[:-1], model, device, 0.7, 15, 'continue')
out

([True], [1637])