In [None]:
from IPython.display import display, HTML

# Inject a <style> element so the notebook container uses 90% of the page width.
# NOTE(review): the ';' characters between the rule sets are not valid CSS
# separators, so browsers may drop the .cell / .code_cell rules -- confirm.
notebook_css = "<style>.container { width:90% !important; }; .cell {width:100%} ; .code_cell{width:100%}</style>"
display(HTML(notebook_css))

In [None]:
%matplotlib inline
import matplotlib
import os
import glob
import datetime
import traceback
from obspy.core import read, UTCDateTime
from obspy import UTCDateTime, Stream, read
from obspy.geodetics.base import gps2dist_azimuth
from obspy.core.util import AttribDict
import matplotlib
import matplotlib as mpl
new_style = {'grid': False}
mpl.rc('axes', **new_style)
# mpl.rcParams['font.family'] = 'Helvetica'
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
sns.set_style("whitegrid")
sns.set_palette("dark")
# import tqdm

from obspy.signal.cross_correlation import xcorr_pick_correction
import warnings
from collections import defaultdict
from obspy.signal.cross_correlation import correlate,xcorr_max

from scipy.cluster.hierarchy import dendrogram, set_link_color_palette, linkage
from scipy.spatial.distance import squareform
import matplotlib.gridspec as gridspec
from matplotlib.colors import rgb2hex, colorConverter


In [None]:
station = "GRW0"
channel = "BHZ"

# Build one Stream holding a single trace (station/channel above) per event
# file. Each trace is tagged with the event class it was found under
# (events/<class>/*.mseed) through `tr.stats.event_type`.
traces = []
for typ in ["VTB", "MP", "gugu_long", "gugu_short", "NN", "ND"]:
    # sorted(): glob returns files in filesystem order, which is not
    # guaranteed to be stable. The position of a trace in `st` is the
    # "event ID" used by every later cell, so the order must be reproducible.
    for path in sorted(glob.glob("events/%s/*.mseed" % typ)):
        matching = read(path).select(station=station, channel=channel)
        if len(matching) == 0:
            # A file without the requested station/channel used to abort the
            # whole cell with an IndexError; report it and carry on instead.
            print("skipping %s: no %s.%s trace" % (path, station, channel))
            continue
        tr = matching[0]
        tr.stats.event_type = typ
        traces.append(tr)
st = Stream(traces=traces)

In [None]:
# Display the raw stream assembled above (one trace per event file).
st

In [None]:
# Band-pass corner frequencies (Hz); reused below as freqmin/freqmax.
freqlow = 1.0
freqhigh = 10.0

# Work on a copy so the raw stream `st` stays untouched, then:
# remove the mean, taper the ends (taper length capped at 0.5 s) and apply a
# zero-phase 8-corner band-pass between freqlow and freqhigh.
st2 = (
    st.copy()
    .detrend("demean")
    .taper(max_percentage=None, max_length=0.5)
    .filter(
        "bandpass",
        freqmin=freqlow,
        freqmax=freqhigh,
        corners=8,
        zerophase=True,
    )
)

# Quick visual check of the first pre-processed trace.
st2[0].plot(automerge=False)

In [None]:
# Display the pre-processed (demeaned, tapered, band-passed) stream.
st2

In [None]:
# Pairwise result matrices, one row/column per trace of st2.
# similarity: identity to start with (every trace is identical to itself).
# dts: lag returned by the cross-correlation for each pair, initially zero.
similarity = np.identity(len(st2))
dts = np.zeros_like(similarity)

In [None]:
# One nominal pick per trace, placed a fixed 5.0 s after the trace start time.
# `relpicks` is initialised here but not filled anywhere in this section.
relpicks = []
picks = [trace.stats.starttime + 5.0 for trace in st2]

In [None]:
# Correlation window limits, compared directly against `taxis` below, i.e.
# they are counted from the start of the trace (not from the pick).
before = 3
after = 20

# Filter band reported in the figure title (same values used for st2).
freqmin, freqmax = freqlow, freqhigh

# Reported in the figure title.
# NOTE(review): the correlate() call in the pairwise loop passes 0 as shift,
# so this value is not applied to the correlation itself -- confirm intent.
cc_maxlag = 5
phase = "*"

# Time axis built from the first trace's sample count and sample spacing;
# NOTE(review): assumes every trace shares st2[0]'s npts and delta -- confirm.
taxis = st2[0].stats.delta * np.arange(st2[0].stats.npts)

# Sample indices falling inside [before, after] on that axis.
sel = np.flatnonzero((taxis >= before) & (taxis <= after))

In [None]:


# Pairwise waveform similarity on the `sel` window.
# Only the upper triangle (j > i) is computed; results are mirrored so that
# `similarity` and `dts` stay symmetric.
#
# NOTE(review): correlate(..., 0) evaluates the cross-correlation at zero lag
# only, so `dt` is always 0 and `cc_maxlag` is never applied here. If a lag
# search was intended, the shift must be given in samples (cc_maxlag / delta).
n_traces = len(st2)
for i in range(n_traces):
    print(i, "vs all")
    for j in range(i + 1, n_traces):
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            try:
                # Fancy indexing with `sel` already yields fresh arrays, so the
                # traces themselves do not need to be copied for every pair.
                cc = correlate(st2[i].data[sel], st2[j].data[sel], 0)
                dt, coeff = xcorr_max(cc)
            except Exception:
                # `except Exception` (not a bare except) so that interrupting
                # the kernel still stops this O(n^2) loop. A failing pair is
                # reported and keeps its initial value of 0 in both matrices.
                traceback.print_exc()
                continue
        similarity[i, j] = coeff
        similarity[j, i] = coeff
        dts[i, j] = dt
        dts[j, i] = dt

In [None]:
# Keep only the magnitude of each coefficient, so a negative (anti-correlated)
# value counts as similar as a positive one of the same size.
similarity = np.abs(similarity)

In [None]:
# Similarity matrix as an image; the colour scale is clipped to [0, 0.7].
plt.figure(figsize=(10, 10))
matrix_ax = plt.gca()
matrix_img = matrix_ax.imshow(
    similarity,
    interpolation="none",
    cmap="viridis",
    vmin=-0.,
    vmax=0.7,
)
matrix_ax.set_xlabel("event ID")
matrix_ax.set_ylabel("event ID")

cb = plt.colorbar(matrix_img, shrink=0.5)
cb.ax.set_ylabel('similarity')

plt.show()

In [None]:
# Switch seaborn to the 20-colour "tab20" palette and read it back.
sns.set_palette(sns.color_palette("tab20", 20))
palette = sns.color_palette()

# Hand the same colours (as hex strings) to scipy for the dendrogram links.
set_link_color_palette([rgb2hex(colour) for colour in palette])

def get_cluster_classes(den, label='ivl'):
    """Map every dendrogram leaf label to the colour of the link it hangs from.

    Parameters
    ----------
    den : dict
        Dictionary returned by ``scipy.cluster.hierarchy.dendrogram``. Reads
        ``'color_list'``, ``'icoord'``, the entry named by ``label`` and, when
        present, ``'dcoord'``.
    label : str
        Key of ``den`` holding the leaf labels in plotting order.

    Returns
    -------
    collections.defaultdict
        ``{leaf label: link colour}`` (missing keys default to ``''``).

    Notes
    -----
    Leaves sit at positions 5, 15, 25, ... along ``icoord``. Position alone is
    not enough to recognise a leaf: the centre of an internal cluster can land
    exactly on a leaf slot (e.g. merging clusters centred on 20 and 50 gives
    35), and the parent link's colour would then overwrite that leaf's colour.
    When ``'dcoord'`` is available, a leg is only treated as a leaf if it also
    starts at height 0.
    """
    cluster_idxs = defaultdict(str)
    heights = den.get('dcoord')
    if heights is None:
        # Backward compatible: fall back to the position-only test.
        heights = [None] * len(den['icoord'])

    for colour, xs, ds in zip(den['color_list'], den['icoord'], heights):
        # xs = [left, left, right, right]; ds = [left base, top, top, right base]
        leg_bases = (None, None) if ds is None else (ds[0], ds[3])
        for pos, base in zip(xs[1:3], leg_bases):
            if base is not None and base > 0:
                # Leg descends to another link, not to a leaf.
                continue
            slot = (pos - 5.0) / 10.0
            idx = int(round(slot))
            if abs(slot - idx) < 1e-5:
                cluster_idxs[den[label][idx]] = colour
    return cluster_idxs

# Figure layout: wide waveform panel (gs[0]) on the left, narrow dendrogram
# panel (gs[1]) on the right.
fig = plt.figure(figsize=(12, 20))
gs = gridspec.GridSpec(1, 2, width_ratios=[4, 1])

linkage_method = "complete"
dissimilarity_threshold = 0.6

# Hierarchical clustering on dissimilarity = 1 - similarity, passed to
# linkage() in the condensed form produced by squareform().
distance = squareform(1 - similarity)
linkage_matrix = linkage(distance, method=linkage_method)

# Dendrogram panel: links above the threshold are drawn in black ('k'), and a
# dashed vertical line marks the threshold itself.
dendro_ax = plt.subplot(gs[1])
D = dendrogram(
    linkage_matrix,
    orientation="right",
    color_threshold=dissimilarity_threshold,
    above_threshold_color='k',
)
dendro_ax.yaxis.tick_right()
dendro_ax.yaxis.set_label_position("right")
dendro_ax.axvline(dissimilarity_threshold, c='k', ls="--")
dendro_ax.set_xlabel("Dissimilarity")
dendro_ax.set_ylabel("Event ID")
plt.setp(dendro_ax.get_xticklabels()[0], visible=False)

# Make the left-hand panel current for the waveform plotting that follows.
plt.subplot(gs[0])

# Leaf label -> link colour, as assigned by the dendrogram.
colors = get_cluster_classes(D)

# Walk the leaves in dendrogram order and start a new cluster every time the
# colour changes.
# NOTE(review): neighbouring groups that happen to share a colour (e.g.
# several black above-threshold leaves) are therefore counted as one cluster.
previous = "w"
previous_idx = 0
clusters = {}          # cluster id -> list of event indices
clusters_shifts = {}   # cluster id -> list of shifts w.r.t. the cluster's first event
cluster_color = {}     # cluster id -> colour
cluster_id = 0

for i, index in enumerate(D['ivl']):
    event_idx = int(index)
    tri = st2[event_idx]
    # NOTE(review): normalize() rescales the trace stored in st2 itself, so
    # st2 stays amplitude-normalised after this cell has run.
    tri.normalize()

    # Fall back to black when the leaf has no colour. The lookup table is a
    # defaultdict(str), which returns '' instead of raising for a missing key,
    # so the former try/except fallback could never trigger.
    c = colors.get(str(event_idx)) or 'k'

    if c != previous:
        cluster_id += 1
        previous = c
        previous_idx = event_idx
        shift = 0
        shift_to_first = 0
    else:
        shift = dts[previous_idx, event_idx]
        # NOTE(review): this reads the lag against event 0, not against the
        # first event of the current cluster -- confirm which one is intended.
        shift_to_first = dts[0, event_idx]

    if cluster_id not in clusters:
        clusters[cluster_id] = []
        clusters_shifts[cluster_id] = []
        cluster_color[cluster_id] = c

    clusters[cluster_id].append(event_idx)
    clusters_shifts[cluster_id].append(shift)

    # One waveform per row, scaled to 80% of the row height.
    plt.plot(taxis+shift-shift_to_first, tri.data*0.8+i, lw=1, c=c)
    

# NOTE(review): the x data plotted above is `taxis` (time since trace start),
# while this label and the title's "[-before:+after]" wording describe a
# pick-relative axis -- confirm which is meant.
plt.xlabel("Time relative to pick (s)")
# One row per dendrogram leaf; derived from the leaf count instead of relying
# on a loop counter left over from the plotting loop.
plt.ylim(-.5, len(D['ivl']) - 1 + .5)

# Shade the [before, after] window used for the cross-correlation.
plt.axvspan(before, after, zorder=-10, alpha=0.1, facecolor='silver')

# y tick labels: "<event type> <pick date/time>", in dendrogram order.
indexes = [int(_) for _ in D['ivl']]
dlabel = []
for event_idx in indexes:
    pick_time = picks[event_idx]
    et = st2[event_idx].stats.event_type
    label = "%s " % et + pick_time.strftime('%Y-%m-%d %H:%M:%S')
    dlabel.append(label)

# Peak-to-peak amplitude of the raw traces and colour per event, in dendrogram
# order. np.ptp() is used because the ndarray.ptp() method no longer exists in
# NumPy 2.0.
yEvents = [np.ptp(st[int(d)].data) for d in D['ivl']]
cEvents = [colors[str(int(d))] for d in D['ivl']]

plt.yticks(np.arange(len(D['ivl'])), dlabel)
plt.ylabel("%s pick date & time"%phase)

plt.suptitle("%s.%s - Based on %s pick. Cross-Correlation on [-%.2f:+%.2f] s - Bandpass: [%.1f:%.1f] Hz - Maxlag: %.2f s" % 
             (station, channel, phase, before, after, freqmin, freqmax, cc_maxlag))
plt.tight_layout()
plt.subplots_adjust(top=0.96, wspace=0)

# Hide the last x tick label of the current axes.
plt.setp(plt.gca().get_xticklabels()[-1], visible=False)    

plt.xlim(0,25)

plt.savefig('similarity %s.%s.png'%(station,channel), dpi=300)
plt.show()
