In [1]:
from tracemalloc import start
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
from datasets.dataset import *
# from models.Autoformer import *(我h'h'h)
from tqdm import tqdm
import argparse
import random

import scipy.stats
from scipy.signal import convolve2d
from scipy.signal import find_peaks
from sklearn.metrics import *

from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.metrics import metric
import os

from my_utils import *

from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
SEQ_LEN = 128
MAST_RATIO = 0.2
BATCH_SIZE = 10

data_train = Qiang_Anomaly(setting='train', window_size=SEQ_LEN, mask_ratio=MAST_RATIO, test_mode=False, is_abs=False)
data_valid = Qiang_Anomaly(setting='valid', window_size=SEQ_LEN, mask_ratio=MAST_RATIO, test_mode=False, is_abs=False)
data_test = Qiang_Anomaly(setting='test', window_size=SEQ_LEN, mask_ratio=MAST_RATIO, test_mode=False, is_abs=False)

data_train = ConcatDataset([data_train, data_valid])

dataloader_train = DataLoader(data_train, batch_size=64, shuffle=False, num_workers=1, collate_fn=collate_fn)
dataloader_test = DataLoader(data_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=1, collate_fn=collate_fn)
## Here batch size is equivelant to stride

In [3]:
dpath = '/data/rech/dingqian/data_das/patches_test.pt'
ppath = '/data/rech/dingqian/data_das/picks_test.pt'
ipath = '/data/rech/dingqian/data_das/index_test.npy'
mpath = '/data/rech/dingqian/data_das/meta_test.npy'

print(f'Loading data from: {dpath}')
# img_list = sorted(glob.glob(dpath + '/*.png') + glob.glob(dpath + '/*.jpg'))
tensors = torch.load(dpath)
# tensors = tensors.unsqueeze(1) # N, 1, 128, 128

print(f'Loading picks from: {ppath}')
picks = torch.load(ppath)
# picks = picks.unsqueeze(1) # N, 1, 128, 128

print(f'Loading indices_start from: {ipath}')
indices_start = np.load(ipath, allow_pickle=True)

print(f'Loading metas from: {mpath}') # [{'file': 'xxx'}]
metas = np.load(mpath, allow_pickle=True)

Loading data from: /data/rech/dingqian/data_das/patches_test.pt
Loading picks from: /data/rech/dingqian/data_das/picks_test.pt
Loading indices_start from: /data/rech/dingqian/data_das/index_test.npy
Loading metas from: /data/rech/dingqian/data_das/meta_test.npy


In [6]:
opath = '/data/rech/dingqian/data_das/output.npy'
opath_s1 = '/data/rech/dingqian/data_das/output_s1.npy'

print(f'Loading outputs from: {opath}') # [{'file': 'xxx'}]
outputs = np.load(opath, allow_pickle=True)
outputs_s1 = np.load(opath_s1, allow_pickle=True)

Loading outputs from: /data/rech/dingqian/data_das/output.npy


In [9]:
len(outputs_s1), len(outputs), len(picks)

(9792, 9792, 9792)

In [71]:
def load_npz(f):
    meta = dict(np.load(f))
    data = meta["data"]
    data = data.astype(np.float32)

    data = moving_average_fast(data.T, 10).T # 900 x 1250

    data -= np.median(data, axis=1, keepdims=True)
    data -= np.mean(data, axis=0)
    data /= np.std(data, axis=0)
    data = data.T # 1250 x 900

    meta["data"] = data
    meta["file"] = f

    return meta 

def map_emd(a, b):
    return scipy.stats.wasserstein_distance(a.flatten(), b.flatten())

def sliced_emd(a, b, slice_size=32, step_size=16):
    a_ts = torch.tensor(a)
    b_ts = torch.tensor(b)

    a_sliced = a_ts.unfold(0, slice_size, step_size).unfold(1, slice_size, step_size).flatten(0,1)
    b_sliced = b_ts.unfold(0, slice_size, step_size).unfold(1, slice_size, step_size).flatten(0,1)

    emd_list = list(map(map_emd, a_sliced, b_sliced))
    emd = sum(emd_list)

    return emd

def kl(a, b):
    a_norm = scipy.special.softmax(a.flatten())
    b_norm = scipy.special.softmax(b.flatten())

    ret = 0.5 * (scipy.stats.entropy(a_norm.flatten(), b_norm.flatten()) + scipy.stats.entropy(b_norm.flatten(), a_norm.flatten()))
    return ret

def js(a, b):
    a_norm = scipy.special.softmax(a.flatten())
    b_norm = scipy.special.softmax(b.flatten())

    ret = scipy.spatial.distance.jensenshannon(a_norm.flatten(), b_norm.flatten())
    return ret

def hist_kl(a, b):
    a_clip = np.clip(a, -2, 2)
    b_clip = np.clip(b, -2, 2)

    a_hist_prob = np.histogram(a_clip, 10)[0] / np.histogram(a_clip, 10)[0].sum()
    b_hist_prob = np.histogram(b_clip, 10)[0] / np.histogram(b_clip, 10)[0].sum()
    
    ret = 0.5 * (scipy.stats.entropy(a_hist_prob, b_hist_prob) + scipy.stats.entropy(b_hist_prob, a_hist_prob))
    return ret

def hist_js(a, b):
    a_clip = np.clip(a, -2, 2)
    b_clip = np.clip(b, -2, 2)

    a_hist_prob = np.histogram(a_clip, 10)[0] / np.histogram(a_clip, 10)[0].sum()
    b_hist_prob = np.histogram(b_clip, 10)[0] / np.histogram(b_clip, 10)[0].sum()
    
    ret = scipy.spatial.distance.jensenshannon(a_hist_prob, b_hist_prob)
    return ret

def get_metrics(y, pred):
    ret_auc = roc_auc_score(y, pred)
    fpr, tpr, thresholds = roc_curve(y, pred)
    
    maxindex = (tpr - fpr).tolist().index(max(tpr - fpr))
    threshold = thresholds[maxindex]
    
    y_pred = pred > threshold
    
    '''ret_f1 = 0.
    for threshold in tqdm(thresholds):
        f1_now = f1_score(y, pred > threshold)
        if f1_now > ret_f1:
            ret_f1 = f1_now
            ret_threshold = threshold'''
    
    ret_threshold = threshold
    ret_f1 = f1_score(y, pred > threshold)
    
    return ret_auc, ret_f1, ret_threshold
    

In [11]:
i = 0
pick_np = picks[i].squeeze().numpy()
output_np = outputs[i]

In [17]:
(pick_np > 0.5).flatten()

array([False, False, False, ..., False, False, False])

In [35]:
thres = 0.5

pick_pred_list = []
pick_list = []

for i in tqdm(range(len(metas))):
    tensor = tensors[i]
    pick = picks[i].squeeze().numpy()
    output = torch.sigmoid(torch.tensor(outputs[i])).numpy()
    
    pick = (pick > 0.5).flatten()
    pick_pred = output.flatten()
    
    pick_list.append(pick)
    pick_pred_list.append(pick_pred)


 26%|██████████████████████████████████▋                                                                                                  | 2553/9792 [00:00<00:00, 13102.47it/s]

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
97
137
100
53
63
119
181
171
119
73
24
22
40
71
99
121
152
130
137
100
53
63
119
181
172
128
126
105
111
145
169
157
133
152
130
98
0
0
0
0
0
1
9
53
81
89
105
98
58
12
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
38
48
34
17
7
24
20
0
0
0
0
8
19
23
29
67
66
34
102
72
32
38
76
49
15
14
2
0
8
28
37
34
80
105
65
41
38
15
31
52
29
15
14
2
0
0
9
14
5
13
39
31
11
26
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


 56%|█████████████████████████████████████████████████████████████████████████▉                                                           | 5441/9792 [00:00<00:00, 13993.76it/s]

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
116
111
79
78
93
97
101
78
30
27
59
72
87
100
80
88
119
124
200
136
137
168
186
209
203
106
99
181
173
170
190
148
129
183
227
212
57
59
75
89
108
125
76
72
122
101
83
90
68
41
64
103
98
101
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
212
232
201
158
162
205
236
241
219
227
248
225
183
153
148
156
189
229
232
201
158
162
205
236
241
219
227
248
225
183
153
148
156
189
229
226
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0

 70%|█████████████████████████████████████████████████████████████████████████████████████████████                                        | 6848/9792 [00:00<00:00, 14018.26it/s]

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
100
124
118
115
111
114
123
121
87
44
65
56
24
29
65
145
183
169
248
219
185
198
226
240
246
184
130
180
159
126
144
142
168
186
205
198
101
70
87
112
117
125
97
86
115
103
102
115
77
23
3
36
72
99
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
195
232
211
142
148
221
195
148
105
96
115
114
120
124
112
137
177
218
232
211
142
148
226
229
228
189
190
240
235
218
187
137
140
177
218
203
0
0
0
5
34
80
84
94
125
121
98
63
25
3
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9792/9792 [00:00<00:00, 13163.73it/s]

0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
190
225
225
225
242
216
224
250
192
186
199
201
234
202
157
160
178
157
225
225
225
242
216
224
250
192
186
199
201
234
202
157
160
178
157
166
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
68
92
90
19
44
106
105
47
25
18
10
26
37
21
36
132
216
194
92
90
19
44
106
105
47
25
18
10
26
37
21
36
132
216
194
156
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0




In [36]:
pick = np.concatenate(pick_list)
pick_pred = np.concatenate(pick_pred_list)

In [92]:
pick_pred

array([1.4222283e-03, 2.0534967e-06, 9.1854270e-07, ..., 4.7105919e-03,
       5.6040441e-03, 3.6003070e-03], dtype=float32)

In [93]:
pick.sum() / pick.size

0.00136666516073389

In [72]:
pred_dict = {}
heat_dict = {}
mask_size = 128
mask_start = 0

for i in tqdm(range(len(metas))):
    tensor = tensors[i]
    pick_np = picks[i].squeeze().numpy()
    index_start = indices_start[i]
    meta = metas[i]
    input_np = tensor.squeeze().numpy()
    output_np = torch.sigmoid(torch.tensor(outputs[i])).numpy()
    
    ### diff_emd_list should be replaced
    
    # pred_label = diff_emd_list[i] > (0.1)
    #pred_label = diff_sliced_emd_list[i] > (ret_threshold)
    
   #  pred_label = diff_emd_list[i]
    
    # print(meta['file'])
    
    #### Load pred, heat
    if meta['file'] not in pred_dict:
        pred_dict[meta['file']] = load_npz(meta['file'])['data']
        heat_dict[meta['file']] = np.zeros(pred_dict[meta['file']].shape)
    
    data = pred_dict[meta['file']]
    heat = heat_dict[meta['file']]
    
    # patch_origin = data[index_start[0]: index_start[0]+128, index_start[1]: index_start[1]+128]
    
    #### Calculate new data and new heat based on threshold
    x_start = index_start[0] + mask_start
    
    data[index_start[1]: index_start[1]+128, x_start: x_start + mask_size] = output_np[:, mask_start: mask_start + mask_size]
    
    
    pred_dict[meta['file']] = data
    
    
    

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9792/9792 [00:04<00:00, 2305.57it/s]


In [74]:
x, y = np.where(data > 0.5)
np.stack([x,y], axis=1)

array([[   0,    1],
       [   0,    3],
       [   0,    5],
       ...,
       [1249,  877],
       [1249,  879],
       [1249,  882]])

In [75]:
data > 0.5

array([[False,  True, False, ...,  True,  True, False],
       [False, False,  True, ...,  True, False, False],
       [False,  True, False, ...,  True, False,  True],
       ...,
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])

In [None]:
######## Old

In [262]:
x, y = np.where(heat == 1)
np.stack([x,y], axis=1)

array([[   0,  333],
       [   0,  334],
       [   0,  335],
       ...,
       [1215,  880],
       [1215,  881],
       [1215,  882]])

In [76]:
def save_z(z, file_name, heat=None): # N x M
    plt.figure(figsize=(8,8), frameon=False)
    plt.axis('off')
    plt.imshow(z, aspect='auto', vmin=-2.0, vmax=2.0, cmap="seismic")
    
    heat = heat / heat.max()
    plt.imshow(heat, aspect='auto', vmin=0., vmax=1.0, cmap="seismic", alpha=0.8)
    ### Add heat
    '''if heat is not None:
        x, y = np.where(heat > 0)
        # np.stack([x,y], axis=1)
        # plt.plot(y, x, alpha=0.5, color='r')
        alpha = 1. if heat[x, y] > 2. else heat[x, y]
        plt.plot(y, x, alpha=heat[x, y], color='r')'''
    
    plt.tight_layout()
    plt.savefig(os.path.join('./samples/finetune', os.path.split(file_name)[-1] + '.jpg'), pad_inches=0, dpi=64)
    
    plt.cla()
    plt.clf()
    plt.close()
    


In [80]:
for file_name, data_pred in pred_dict.items():
    print(file_name)
    data = load_npz(file_name)['data']
    
    save_z(data_pred>0.05, file_name + '_pred', heat)
    save_z(data, file_name, heat)
    

./data/DAS/test_npz/2020-12-01T01-54-49-M3-24km.npz


  heat = heat / heat.max()


./data/DAS/test_npz/2021-05-27T15-18-26-M3-30km.npz
./data/DAS/test_npz/2020-07-25T19-36-01-M2-15km.npz
./data/DAS/test_npz/2020-04-05T08-21-41-M2-55km.npz
./data/DAS/test_npz/2020-04-04T04-19-09-M2-17km.npz
./data/DAS/test_npz/2019-11-10T12-35-18-M2-24km.npz
./data/DAS/test_npz/2019-08-18T22-02-28-M3-22km.npz
./data/DAS/test_npz/2020-10-02T06-00-56-M2-24km.npz
./data/DAS/test_npz/2019-12-08T09-59-23-M1-12km.npz
./data/DAS/test_npz/2021-02-20T09-54-10-M2-33km.npz
./data/DAS/test_npz/2019-08-25T10-36-27-M3-19km.npz
./data/DAS/test_npz/2019-09-03T10-36-23-M2-22km.npz
./data/DAS/test_npz/2019-09-17T05-01-20-M2-16km.npz
./data/DAS/test_npz/2021-04-08T20-51-15-M4-15km.npz
./data/DAS/test_npz/2019-11-07T06-06-05-M2-17km.npz
./data/DAS/test_npz/2019-12-03T04-41-40-M2-19km.npz
./data/DAS/test_npz/2020-05-14T01-36-08-M2-19km.npz
./data/DAS/test_npz/2020-04-10T04-55-32-M2-18km.npz
./data/DAS/test_npz/2020-04-22T10-21-56-M2-18km.npz
./data/DAS/test_npz/2019-12-19T09-12-59-M2-16km.npz
./data/DAS/t

In [304]:
max(diff_emd_list)

3.6707442508025903

In [267]:
ret_f1

0.3843120070113935

In [245]:
###########Test 
from scipy.fft import rfft, rfftfreq

def get_fft(z):
    n = len(z)
    yf = rfft(z)
    xf = rfftfreq(n, 0.1)
    return np.abs(yf)
    

In [246]:
output = outputs[0]
inp = tensors[0].squeeze().numpy()

output_f = np.zeros((128, 65), dtype=float)
for i in range(128):
    output_f[i] = get_fft(output[i])

In [247]:
plt.imshow(output_f, aspect='auto', vmin=-5.0, vmax=5.0, cmap="seismic")
plt.savefig('freq_output.jpg')

<Figure size 640x480 with 1 Axes>


In [249]:
input_f = np.zeros((128, 65), dtype=float)
for i in range(128):
    input_f[i] = get_fft(inp[i])

plt.imshow(input_f, aspect='auto', vmin=-5.0, vmax=5.0, cmap="seismic")
plt.savefig('freq_input.jpg')

<Figure size 640x480 with 1 Axes>


In [252]:
plt.imshow(output_f - input_f, aspect='auto', vmin=-1, vmax=1, cmap="seismic")
plt.savefig('freq_diff.jpg')

<Figure size 640x480 with 1 Axes>


In [244]:
tensors[0].squeeze().numpy()

array([[ 0.1548734 , -0.24163167,  0.21370491, ..., -0.7274417 ,
        -0.22254851, -0.5854341 ],
       [ 1.6204568 , -0.7617863 ,  0.13278896, ..., -0.03637507,
         0.67147183, -0.13801157],
       [ 1.0656402 , -1.6293908 ,  1.0746189 , ..., -1.6335528 ,
         0.40482336,  1.041287  ],
       ...,
       [-0.14853552, -0.04680762,  0.27739307, ...,  0.00496372,
        -0.10391421, -0.03731848],
       [-0.27020475,  0.2362417 ,  0.43385735, ..., -0.08310726,
         0.08506312, -0.26480883],
       [ 0.23798843,  0.04592361, -0.4435884 , ...,  0.06911898,
         0.0487297 , -0.02855118]], dtype=float32)

In [None]:
###### Finetune analysis


In [82]:
import h5py

f = h5py.File('/Tmp/dingqian/Ocean/Cascadia-North-2021-11-01T215345Z.h5', 'r')

In [94]:
f['Data']

<HDF5 dataset "Data": shape (32600, 360000), type "<f4">

In [None]:
###### Convert data
from six.moves import cPickle as pickle

with open('/data/rech/dingqian/data_das/finetune.npy', 'rb') as f:
    x_ft_train, style_ft_train, skip_ft_train = pickle.load(f)
    
with open('/data/rech/dingqian/data_das/finetune_test.npy', 'rb') as f:
    x_ft_test, style_ft_test, skip_ft_test = pickle.load(f)


In [None]:
x_ft_train