In [133]:
from tracemalloc import start
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
from datasets.dataset import *
# from models.Autoformer import *(我h'h'h)
from tqdm import tqdm
import argparse
import random

import scipy.stats
from scipy.signal import convolve2d
from scipy.signal import find_peaks
from sklearn.metrics import *

from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.metrics import metric
import os

from my_utils import *

from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
SEQ_LEN = 128
MAST_RATIO = 0.2
BATCH_SIZE = 10

data_train = Qiang_Anomaly(setting='train', window_size=SEQ_LEN, mask_ratio=MAST_RATIO, test_mode=False, is_abs=False)
data_valid = Qiang_Anomaly(setting='valid', window_size=SEQ_LEN, mask_ratio=MAST_RATIO, test_mode=False, is_abs=False)
data_test = Qiang_Anomaly(setting='test', window_size=SEQ_LEN, mask_ratio=MAST_RATIO, test_mode=False, is_abs=False)

data_train = ConcatDataset([data_train, data_valid])

dataloader_train = DataLoader(data_train, batch_size=64, shuffle=False, num_workers=1, collate_fn=collate_fn)
dataloader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, collate_fn=collate_fn)
## Here batch size is equivelant to stride

In [44]:
for i, data in enumerate(tqdm(dataloader_train)): # (1) stastic on train set
    
    x, y, index_start, index_end, mask_size, meta = data # bs, channels, seq
    
    z = torch.concat([x, y], 2)[0].numpy()
    break

  0%|                                                                                                                                 | 0/602 [00:00<?, ?it/s]


In [3]:
dpath = '/data/rech/dingqian/data_das/patches_inference.pt'
ppath = '/data/rech/dingqian/data_das/picks_inference.pt'
ipath = '/data/rech/dingqian/data_das/index_inference.npy'
mpath = '/data/rech/dingqian/data_das/meta_inference.npy'

print(f'Loading data from: {dpath}')
# img_list = sorted(glob.glob(dpath + '/*.png') + glob.glob(dpath + '/*.jpg'))
tensors = torch.load(dpath)
# tensors = tensors.unsqueeze(1) # N, 1, 128, 128

print(f'Loading picks from: {ppath}')
picks = torch.load(ppath)
# picks = picks.unsqueeze(1) # N, 1, 128, 128

print(f'Loading indices_start from: {ipath}')
indices_start = np.load(ipath, allow_pickle=True)

print(f'Loading metas from: {mpath}') # [{'file': 'xxx'}]
metas = np.load(mpath, allow_pickle=True)

Loading data from: /data/rech/dingqian/data_das/patches_inference.pt
Loading picks from: /data/rech/dingqian/data_das/picks_inference.pt
Loading indices_start from: /data/rech/dingqian/data_das/index_inference.npy
Loading metas from: /data/rech/dingqian/data_das/meta_inference.npy


In [33]:
opath = '/data/rech/dingqian/data_das/output_right_10.npy'

print(f'Loading outputs from: {opath}') # [{'file': 'xxx'}]
outputs = np.load(opath, allow_pickle=True)

Loading outputs from: /data/rech/dingqian/data_das/output_right_10.npy


In [280]:
def load_npz(f):
    meta = dict(np.load(f))
    data = meta["data"]
    data = data.astype(np.float32)

    data = moving_average_fast(data.T, 10).T # 900 x 1250

    data -= np.median(data, axis=1, keepdims=True)
    data -= np.mean(data, axis=0)
    data /= np.std(data, axis=0)
    data = data.T # 1250 x 900

    meta["data"] = data
    meta["file"] = f

    return meta 

def map_emd(a, b):
    return scipy.stats.wasserstein_distance(a.flatten(), b.flatten())

def sliced_emd(a, b, slice_size=32, step_size=16):
    a_ts = torch.tensor(a)
    b_ts = torch.tensor(b)

    a_sliced = a_ts.unfold(0, slice_size, step_size).unfold(1, slice_size, step_size).flatten(0,1)
    b_sliced = b_ts.unfold(0, slice_size, step_size).unfold(1, slice_size, step_size).flatten(0,1)

    emd_list = list(map(map_emd, a_sliced, b_sliced))
    emd = sum(emd_list)

    return emd

def kl(a, b):
    a_norm = scipy.special.softmax(a.flatten())
    b_norm = scipy.special.softmax(b.flatten())

    ret = 0.5 * (scipy.stats.entropy(a_norm.flatten(), b_norm.flatten()) + scipy.stats.entropy(b_norm.flatten(), a_norm.flatten()))
    return ret

def js(a, b):
    a_norm = scipy.special.softmax(a.flatten())
    b_norm = scipy.special.softmax(b.flatten())

    ret = scipy.spatial.distance.jensenshannon(a_norm.flatten(), b_norm.flatten())
    return ret

def hist_kl(a, b):
    a_clip = np.clip(a, -2, 2)
    b_clip = np.clip(b, -2, 2)

    a_hist_prob = np.histogram(a_clip, 10)[0] / np.histogram(a_clip, 10)[0].sum()
    b_hist_prob = np.histogram(b_clip, 10)[0] / np.histogram(b_clip, 10)[0].sum()
    
    ret = 0.5 * (scipy.stats.entropy(a_hist_prob, b_hist_prob) + scipy.stats.entropy(b_hist_prob, a_hist_prob))
    return ret

def hist_js(a, b):
    a_clip = np.clip(a, -2, 2)
    b_clip = np.clip(b, -2, 2)

    a_hist_prob = np.histogram(a_clip, 10)[0] / np.histogram(a_clip, 10)[0].sum()
    b_hist_prob = np.histogram(b_clip, 10)[0] / np.histogram(b_clip, 10)[0].sum()
    
    ret = scipy.spatial.distance.jensenshannon(a_hist_prob, b_hist_prob)
    return ret

def get_metrics(y, pred):
    ret_auc = roc_auc_score(y, pred)
    fpr, tpr, thresholds = roc_curve(y, pred)
    
    maxindex = (tpr - fpr).tolist().index(max(tpr - fpr))
    threshold = thresholds[maxindex]
    
    y_pred = pred > threshold
    
    '''ret_f1 = 0.
    for threshold in tqdm(thresholds):
        f1_now = f1_score(y, pred > threshold)
        if f1_now > ret_f1:
            ret_f1 = f1_now
            ret_threshold = threshold'''
    
    ret_threshold = threshold
    ret_f1 = f1_score(y, pred > threshold)
    
    return ret_auc, ret_f1, ret_threshold
    

In [258]:
pred_dict = {}

diff_list = []
diff_abs_list = []
diff_emd_list = []
diff_sliced_emd_list = []

diff_kl_list = []
diff_js_list = []
diff_hist_kl_list = []
diff_hist_js_list = []

sum_list = []

has_pick_list = []

output_right_list = []
output_middle_list = []

mask_size = 10
mask_start = 128 - 23

for i in tqdm(range(len(metas))):
    tensor = tensors[i]
    pick_np = picks[i].squeeze().numpy()
    index_start = indices_start[i]
    meta = metas[i]
    input_np = tensor.squeeze().numpy()
    output_np = outputs[i]
    
    # print(meta['file'])
    
    if meta['file'] not in pred_dict:
        pred_dict[meta['file']] = load_npz(meta['file'])['data']
    
    data = pred_dict[meta['file']]
    
    # patch_origin = data[index_start[0]: index_start[0]+128, index_start[1]: index_start[1]+128]
    
    x_start = index_start[0] + mask_start
    
    data[index_start[1]: index_start[1]+128, x_start: x_start + mask_size] = output_np[:, mask_start: mask_start + mask_size]
    
    pred_dict[meta['file']] = data
    
    ########### calculate metrics
    
    input_mask = input_np[:, mask_start: mask_start + mask_size]
    output_mask = output_np[:, mask_start: mask_start + mask_size]
    pick_mask = pick_np[:, mask_start: mask_start + mask_size]
    
    has_pick = (pick_mask.sum() > 0)
    
    diff = (output_mask - input_mask).sum()
    diff_abs = (np.abs(output_mask) - np.abs(input_mask)).sum()
    diff_emd = scipy.stats.wasserstein_distance(input_mask.flatten(), output_mask.flatten())
    diff_sliced_emd = sliced_emd(input_mask, output_mask, 10, 5)

    diff_kl = kl(input_mask, output_mask)
    diff_js = js(input_mask, output_mask)
    diff_hist_kl = hist_kl(input_mask, output_mask)
    diff_hist_js = hist_js(input_mask, output_mask)

    summ = input_mask.sum()

    diff_list.append(diff)
    diff_abs_list.append(diff_abs)
    diff_emd_list.append(diff_emd)
    diff_sliced_emd_list.append(diff_sliced_emd)

    diff_kl_list.append(diff_kl)
    diff_js_list.append(diff_js)
    diff_hist_kl_list.append(diff_hist_kl)
    diff_hist_js_list.append(diff_hist_js)

    sum_list.append(summ)

    has_pick_list.append(has_pick)
    
    '''print(index_start[1], index_start[1]+128, x_end - 10, x_end)
    print(output)
    print(data[index_start[1]: index_start[1]+128, index_start[0]: x_end])
    print(tensor)'''

auc_diff = roc_auc_score(has_pick_list, diff_list)
auc_diff_abs = roc_auc_score(has_pick_list, diff_abs_list)
auc_emd = roc_auc_score(has_pick_list, diff_emd_list)
auc_sum = roc_auc_score(has_pick_list, sum_list)
auc_sliced_emd = roc_auc_score(has_pick_list, diff_sliced_emd_list)

auc_kl = roc_auc_score(has_pick_list, diff_kl_list)
auc_js = roc_auc_score(has_pick_list, diff_js_list)

auc_hist_kl = roc_auc_score(has_pick_list, np.nan_to_num(diff_hist_kl_list))
auc_hist_js = roc_auc_score(has_pick_list, np.nan_to_num(diff_hist_js_list))

print('AUC_DIFF: {}\nAUC_DIFF_ABS: {}\nAUC_EMD: {}\nAUC_SUM: {}\nAUC_SLICED_EMD: {}'.format(auc_diff, auc_diff_abs, auc_emd, auc_sum, auc_sliced_emd))
print('AUC_KL: {}\nAUC_JS: {}\nAUC_HIST_KL: {}\nAUC_HIST_JS: {}'.format(auc_kl, auc_js, auc_hist_kl, auc_hist_js))
# roc_auc_score(has_pick_list, diff_list)


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62622/62622 [03:58<00:00, 262.67it/s]


AUC_DIFF: 0.51635217635449
AUC_DIFF_ABS: 0.16063142675802528
AUC_EMD: 0.8649358148163835
AUC_SUM: 0.5782417695745816
AUC_SLICED_EMD: 0.8282211596400559
AUC_KL: 0.7240955464208712
AUC_JS: 0.7242464931790527
AUC_HIST_KL: 0.7061906496383542
AUC_HIST_JS: 0.7141182404802505


  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


In [281]:
ret_auc, ret_f1, ret_threshold = get_metrics(has_pick_list, diff_sliced_emd_list)

In [283]:
ret_auc, ret_f1, ret_threshold

(0.8282211596400559, 0.19660207966361726, 5.569320248818258)

In [326]:
pred_dict = {}
heat_dict = {}
mask_size = 10
mask_start = 128 - 23

for i in tqdm(range(len(metas))):
    tensor = tensors[i]
    pick_np = picks[i].squeeze().numpy()
    index_start = indices_start[i]
    meta = metas[i]
    input_np = tensor.squeeze().numpy()
    output_np = outputs[i]
    
    ### diff_emd_list should be replaced
    
    # pred_label = diff_emd_list[i] > (0.1)
    #pred_label = diff_sliced_emd_list[i] > (ret_threshold)
    
    pred_label = diff_emd_list[i]
    
    # print(meta['file'])
    
    #### Load pred, heat
    if meta['file'] not in pred_dict:
        pred_dict[meta['file']] = load_npz(meta['file'])['data']
        heat_dict[meta['file']] = np.zeros(pred_dict[meta['file']].shape)
    
    data = pred_dict[meta['file']]
    heat = heat_dict[meta['file']]
    
    # patch_origin = data[index_start[0]: index_start[0]+128, index_start[1]: index_start[1]+128]
    
    #### Calculate new data and new heat based on threshold
    x_start = index_start[0] + mask_start
    
    data[index_start[1]: index_start[1]+128, x_start: x_start + mask_size] = output_np[:, mask_start: mask_start + mask_size]
    
    heat[index_start[1]: index_start[1]+128, x_start: x_start + mask_size] = pred_label
    
    pred_dict[meta['file']] = data
    heat_dict[meta['file']] = heat
    
    
    

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62622/62622 [00:06<00:00, 10101.41it/s]


In [262]:
x, y = np.where(heat == 1)
np.stack([x,y], axis=1)

array([[   0,  333],
       [   0,  334],
       [   0,  335],
       ...,
       [1215,  880],
       [1215,  881],
       [1215,  882]])

In [329]:
def save_z(z, file_name, heat=None): # N x M
    plt.figure(figsize=(8,8), frameon=False)
    plt.axis('off')
    plt.imshow(z, aspect='auto', vmin=-2.0, vmax=2.0, cmap="seismic")
    
    heat = heat / heat.max()
    plt.imshow(heat, aspect='auto', vmin=0., vmax=1.0, cmap="seismic", alpha=0.8)
    ### Add heat
    '''if heat is not None:
        x, y = np.where(heat > 0)
        # np.stack([x,y], axis=1)
        # plt.plot(y, x, alpha=0.5, color='r')
        alpha = 1. if heat[x, y] > 2. else heat[x, y]
        plt.plot(y, x, alpha=heat[x, y], color='r')'''
    
    plt.tight_layout()
    plt.savefig(os.path.join('./samples/inference_10', os.path.split(file_name)[-1] + '.jpg'), pad_inches=0, dpi=64)
    
    plt.cla()
    plt.clf()
    plt.close()
    


In [330]:
for file_name, data_pred in pred_dict.items():
    print(file_name)
    data = load_npz(file_name)['data']
    heat = heat_dict[file_name]
    
    save_z(data_pred, file_name + '_pred', heat)
    save_z(data, file_name, heat)
    

./data/DAS/test_npz/2020-12-01T01-54-49-M3-24km.npz
./data/DAS/test_npz/2021-05-27T15-18-26-M3-30km.npz
./data/DAS/test_npz/2020-07-25T19-36-01-M2-15km.npz
./data/DAS/test_npz/2020-04-05T08-21-41-M2-55km.npz
./data/DAS/test_npz/2020-04-04T04-19-09-M2-17km.npz
./data/DAS/test_npz/2019-11-10T12-35-18-M2-24km.npz
./data/DAS/test_npz/2019-08-18T22-02-28-M3-22km.npz
./data/DAS/test_npz/2020-10-02T06-00-56-M2-24km.npz
./data/DAS/test_npz/2019-12-08T09-59-23-M1-12km.npz
./data/DAS/test_npz/2021-02-20T09-54-10-M2-33km.npz
./data/DAS/test_npz/2019-08-25T10-36-27-M3-19km.npz
./data/DAS/test_npz/2019-09-03T10-36-23-M2-22km.npz
./data/DAS/test_npz/2019-09-17T05-01-20-M2-16km.npz
./data/DAS/test_npz/2021-04-08T20-51-15-M4-15km.npz
./data/DAS/test_npz/2019-11-07T06-06-05-M2-17km.npz
./data/DAS/test_npz/2019-12-03T04-41-40-M2-19km.npz
./data/DAS/test_npz/2020-05-14T01-36-08-M2-19km.npz
./data/DAS/test_npz/2020-04-10T04-55-32-M2-18km.npz
./data/DAS/test_npz/2020-04-22T10-21-56-M2-18km.npz
./data/DAS/t

In [304]:
max(diff_emd_list)

3.6707442508025903

In [267]:
ret_f1

0.3843120070113935

In [245]:
###########Test 
from scipy.fft import rfft, rfftfreq

def get_fft(z):
    n = len(z)
    yf = rfft(z)
    xf = rfftfreq(n, 0.1)
    return np.abs(yf)
    

In [246]:
output = outputs[0]
inp = tensors[0].squeeze().numpy()

output_f = np.zeros((128, 65), dtype=float)
for i in range(128):
    output_f[i] = get_fft(output[i])

In [247]:
plt.imshow(output_f, aspect='auto', vmin=-5.0, vmax=5.0, cmap="seismic")
plt.savefig('freq_output.jpg')

<Figure size 640x480 with 1 Axes>


In [249]:
input_f = np.zeros((128, 65), dtype=float)
for i in range(128):
    input_f[i] = get_fft(inp[i])

plt.imshow(input_f, aspect='auto', vmin=-5.0, vmax=5.0, cmap="seismic")
plt.savefig('freq_input.jpg')

<Figure size 640x480 with 1 Axes>


In [252]:
plt.imshow(output_f - input_f, aspect='auto', vmin=-1, vmax=1, cmap="seismic")
plt.savefig('freq_diff.jpg')

<Figure size 640x480 with 1 Axes>


In [244]:
tensors[0].squeeze().numpy()

array([[ 0.1548734 , -0.24163167,  0.21370491, ..., -0.7274417 ,
        -0.22254851, -0.5854341 ],
       [ 1.6204568 , -0.7617863 ,  0.13278896, ..., -0.03637507,
         0.67147183, -0.13801157],
       [ 1.0656402 , -1.6293908 ,  1.0746189 , ..., -1.6335528 ,
         0.40482336,  1.041287  ],
       ...,
       [-0.14853552, -0.04680762,  0.27739307, ...,  0.00496372,
        -0.10391421, -0.03731848],
       [-0.27020475,  0.2362417 ,  0.43385735, ..., -0.08310726,
         0.08506312, -0.26480883],
       [ 0.23798843,  0.04592361, -0.4435884 , ...,  0.06911898,
         0.0487297 , -0.02855118]], dtype=float32)