In [None]:
import gc
import glob
import os
import pickle
import time

os.environ['CUDA_VISIBLE_DEVICES'] = "3"

import h5py
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import obspy
import pandas as pd
import scipy
import torch
from obspy.core.utcdatetime import UTCDateTime
from obspy.signal.filter import bandpass, lowpass, highpass
from obspy.signal.invsim import cosine_taper
from scipy.signal import butter, filtfilt, detrend
from torch.nn import MSELoss
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

In [None]:
class DASDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing each input sample with its target.

    Parameters
    ----------
    inputs, outputs : torch.Tensor or array-like
        Samples are indexed along the first axis. Tensors are stored as
        given; anything else is converted to a float32 NumPy array. The two
        arguments are converted independently, so a tensor may be paired
        with an ndarray.
    """

    def __init__(self, inputs, outputs):
        self.inputs = self._as_float32(inputs)
        self.outputs = self._as_float32(outputs)

    @staticmethod
    def _as_float32(array):
        """Return tensors unchanged and everything else as a float32 ndarray."""
        if isinstance(array, torch.Tensor):
            return array
        return np.asarray(array).astype(np.float32)

    def __len__(self):
        """Number of samples, taken from the length of the targets."""
        return len(self.outputs)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        X = self.inputs[index, :]
        y = self.outputs[index, :]

        return X, y
    
class SHRED(torch.nn.Module):
    """LSTM encoder followed by a two-layer fully connected decoder.

    Parameters
    ----------
    input_size : int
        Number of features per time step fed to the LSTM.
    hidden_size : int
        Width of the LSTM hidden state.
    output_size : int
        Length of the decoded output vector; the intermediate dense layer
        has ``output_size // 2`` units.
    num_layers : int
        Number of stacked LSTM layers.
    dropout : float, optional
        Dropout probability between stacked LSTM layers (default 0.2, the
        previously hard-coded value). Ignored when ``num_layers == 1``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers, dropout=0.2):
        super().__init__()
        # LSTM dropout acts only between stacked layers; PyTorch warns when a
        # non-zero value is requested for a single-layer LSTM, so drop it then.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers,
                                  batch_first=True, dropout=lstm_dropout)
        self.sdn1 = torch.nn.Linear(hidden_size, output_size//2)
        self.sdn3 = torch.nn.Linear(output_size//2, output_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Decode the final hidden state of the last LSTM layer.

        ``x`` is laid out batch-first, i.e. ``(batch, time, input_size)``.
        Returns a tensor whose last dimension is ``output_size``.
        """
        _, (h_n, _) = self.lstm(x)
        x = h_n[-1]  # hidden state of the top (last) LSTM layer
        x = self.relu(self.sdn1(x))
        x = self.sdn3(x)
        return x

In [None]:
# Input files. glob returns entries in arbitrary, filesystem-dependent order,
# so sort them to get the same file order on every run.
# NOTE(review): "*" matches every entry in the directory, not only HDF5
# files - confirm the directory holds nothing else.
flist = sorted(glob.glob("../../datasets/earthquakes/*"))

# Number of random windows drawn per file for each subset.
nsample_train = 300
nsample_val = 300
nsample_test = 300

ncha = 201     # number of input (sensor) channels fed to the network
ntime = 200    # number of time samples in each input window

ncha_start = 1000   # first channel read from each file
ncha_end = 6000     # one past the last channel read from each file
noutput = 1000      # number of contiguous channels the network reconstructs
dcha = noutput // ncha   # integer ratio of output to input channels

In [None]:
# Zero-based positions of the `ncha` input channels, spread evenly over the
# `noutput` output channels (first and last output channel included).
cidx = np.linspace(1, noutput, ncha, dtype = 'int') - 1

# Sample buffers: X holds the input windows, Y the target channel snapshots.
# They are turned into float32 tensors later on, so allocate float32 here
# instead of the float64 default to halve the memory footprint.
X = np.zeros([len(flist)*nsample_train, ntime, ncha], dtype=np.float32)
Y = np.zeros([len(flist)*nsample_train, noutput], dtype=np.float32)

In [None]:
# Fill X / Y with `nsample_train` random windows from every file.
i = 0   # running index of the next free row in X / Y
for fname in tqdm(flist):
    # The context manager closes the file even if reading raises.
    with h5py.File(fname, 'r') as f:
        # Transposed so that axis 0 indexes channels and axis 1 indexes time.
        data = f['/Acquisition/Raw[0]/RawData'][:, ncha_start:ncha_end].T

    # Standardise every channel over time.
    data -= np.mean(data, axis=-1, keepdims=True)
    std = np.std(data, axis=-1, keepdims=True)
    std[std == 0] = 1   # a constant (dead) channel would otherwise become NaN/inf
    data /= std

    # 3000 is the original upper bound on the window end; clamp it to the
    # record length so shorter files cannot be indexed out of range.
    nt_max = min(3000, data.shape[1])

    for _ in range(nsample_train):
        idt = np.random.randint(ntime+1, nt_max)                      # last time index
        ic  = np.random.randint(0, ncha_end-ncha_start-noutput)       # first channel index
        # Input: the `ntime` samples ending at idt on the sparse channel subset.
        X[i, :, :] = data[ic+cidx, idt-(ntime-1):idt+1].T
        # Target: all `noutput` contiguous channels at the final time step.
        Y[i, :] = data[ic:ic+noutput, idt]
        i += 1
print(f"training set size: {i}")

In [None]:
# Quick look at the most recently loaded (standardised) record.
vmax = 5   # symmetric colour-scale limit

fig, ax = plt.subplots(figsize=(10, 8), dpi=300)
ax.imshow(data, vmin=-vmax, vmax=vmax, cmap='RdBu', aspect='auto', origin='lower')
ax.set_title("original", fontsize=20)
# Hide tick marks on both axes.
ax.set_xticks([])
ax.set_yticks([])

In [None]:
# Random permutation of all sample indices, cut into 70 % train,
# 5 % validation and 25 % test.
idx = np.arange(len(flist)*nsample_train)
np.random.shuffle(idx)

n_samples = len(idx)
train_stop = int(0.7*n_samples)
val_stop = int(0.75*n_samples)

idx_train = idx[:train_stop]
idx_val = idx[train_stop:val_stop]
idx_test = idx[val_stop:]

In [None]:
# Materialise each subset as float32 tensors.
X_train_ts = torch.Tensor(X[idx_train, :, :])
Y_train_ts = torch.Tensor(Y[idx_train, :])

X_val_ts = torch.Tensor(X[idx_val, :, :])
Y_val_ts = torch.Tensor(Y[idx_val, :])

X_test_ts = torch.Tensor(X[idx_test, :, :])
Y_test_ts = torch.Tensor(Y[idx_test, :])

# Only the training loader is shuffled.
dataset = DASDataset(X_train_ts, Y_train_ts)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

# Evaluation loaders keep a fixed order: the reported loss is a mean of
# per-batch means, so a stable batch composition keeps it comparable from
# epoch to epoch and avoids consuming random numbers for no benefit.
val_dataset = DASDataset(X_val_ts, Y_val_ts)
val_data_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=0)

test_dataset = DASDataset(X_test_ts, Y_test_ts)
test_data_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=0)

print("train: ", X_train_ts.shape, Y_train_ts.shape)
print("validate: ", X_val_ts.shape, Y_val_ts.shape)
print("test: ", X_test_ts.shape, Y_test_ts.shape)

In [None]:
# Force a garbage-collection pass; the bare call displays the number of
# unreachable objects that were found. (`gc` is imported with the other
# modules at the top of the notebook.)
gc.collect()

In [None]:
nhidden = 150   # LSTM hidden-state width
nlstm = 2       # number of stacked LSTM layers

model = SHRED(ncha, nhidden, noutput, nlstm)
# Fall back to the CPU when no CUDA device is visible instead of failing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device);

# Re-initialise the dense layers only: Kaiming-normal weights and a small
# constant bias. The LSTM keeps its default initialisation.
for m in model.modules():
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0.01)

# Total number of scalar parameters in the model.
n_weights = sum(p.numel() for p in model.parameters())
print(f"have total {n_weights} weights")

In [None]:
# Training configuration: mean-squared-error loss, Adam, and a learning rate
# that decays linearly from 100 % to 50 % of its base value over all epochs.
nepoch = 80

loss_fn = MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = lr_scheduler.LinearLR(
    optimizer,
    total_iters=nepoch,
    start_factor=1.0,
    end_factor=0.5,
)

In [None]:
def _batch_losses(loader):
    """Return the list of per-batch losses of `model` over `loader`.

    Runs without gradient tracking; the caller is responsible for putting
    the model in eval mode first.
    """
    losses = []
    with torch.no_grad():
        for inputs, targets in tqdm(loader, total=len(loader)):
            pred = model(inputs.to(device))
            losses.append(loss_fn(pred, targets.to(device)).item())
    return losses


t0 = time.time()
# Per-epoch mean losses (each is the mean of the per-batch losses).
train_loss_log = []
val_loss_log = []
test_loss_log = []

for t in range(nepoch):
    # ---- optimisation pass over the training set ----
    model.train()
    train_loss = []
    for inputs, targets in tqdm(data_loader, total=len(data_loader)):
        optimizer.zero_grad()   # clear gradients left from the previous step
        pred = model(inputs.to(device))
        loss = loss_fn(pred, targets.to(device))
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

    # ---- evaluation passes (dropout disabled, no gradients) ----
    # NOTE(review): the test set is evaluated every epoch for monitoring
    # only; it does not influence training in this loop.
    model.eval()
    val_loss = _batch_losses(val_data_loader)
    val_loss_log.append(np.mean(val_loss))

    test_loss = _batch_losses(test_data_loader)
    test_loss_log.append(np.mean(test_loss))

    train_loss_log.append(np.mean(train_loss))
    before_lr = optimizer.param_groups[0]["lr"]
    scheduler.step()
    after_lr = optimizer.param_groups[0]["lr"]
    # Scientific notation: with "%.4f" every rate between 5e-5 and 1e-4
    # printed as 0.0001, so the decay was invisible.
    print("Epoch %d: Adam lr %.2e -> %.2e" % (t, before_lr, after_lr))
    # Columns: epoch, train loss, test loss, validation loss.
    print("%d, %.4f, %.4f, %.4f" % (t, np.mean(train_loss), np.mean(test_loss), np.mean(val_loss)))

#     torch.save(model.state_dict(), 
#            f"/home/niyiyu/Research/DAS-NIR/gci-summary/results/weights/SHRED_KKFLS_25Hz_201i_1000o_200sp_epo{t}.pt")
print(time.time() - t0)

In [None]:
# Persist the per-epoch loss curves as a pickled dict of lists.
# (The ".pt" suffix notwithstanding, this is a plain pickle file.)
loss_history = {
    "train": train_loss_log,
    "validate": val_loss_log,
    "test": test_loss_log,
}
with open("../../datasets/loss.pt", "wb") as f:
    pickle.dump(loss_history, f)

In [None]:
# Loss curves of the three subsets against epoch, saved as a PDF figure.
fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
loss_curves = (
    (train_loss_log, 'training'),
    (val_loss_log, 'validation'),
    (test_loss_log, 'testing'),
)
for curve, curve_label in loss_curves:
    ax.plot(curve, ".-", label=curve_label)
ax.legend()
ax.set_xlabel("Epoch", fontsize=15)
ax.set_ylabel("Loss", fontsize=15)
ax.grid(True)
fig.savefig("../figures/manuscripts/FigS_loss.pdf", bbox_inches='tight', dpi=300)
# plt.yscale('log')

In [None]:
# Inspect the prediction for one random validation sample.
# Inference must run in eval mode: in train mode the LSTM dropout stays
# active, which perturbs the prediction and leaves the model in train mode
# for every later cell.
model.eval()
idx = np.random.randint(0, len(X_val_ts))   # note: rebinds `idx` to a single int

inputs = X_val_ts[idx, :, :]
label = Y_val_ts[idx, :]
with torch.no_grad():
    # Add an explicit batch axis for the batch-first LSTM, then drop it again.
    predict_ts = model(inputs.unsqueeze(0).to(device))[0].cpu()
# Mean-squared error computed tensor-to-tensor (no implicit NumPy mixing).
print(torch.mean((label-predict_ts)**2))

# NumPy views for plotting.
label_np = label.numpy()
predict = predict_ts.numpy()
residual = label_np - predict

plt.figure(figsize = (25, 10))
# Top: target (solid) and prediction (dashed); red crosses mark the channels
# that were available to the network as input.
plt.subplot(2,1,1)
plt.plot(label_np, linewidth = 2.5)
plt.plot(predict, '--', linewidth = 2.5)

plt.scatter(cidx, label_np[cidx], marker='+', color='r', zorder=90, s=400, linewidth=3)

# Bottom: residual (target minus prediction).
plt.subplot(2,1,2)
plt.plot(residual, linewidth = 2.5)
plt.ylim([-1, 1])

In [None]:
# Reconstruct a full record, one time step at a time.
# Sorted so that index 10 refers to the same file on every run (glob order
# is arbitrary).
flist2 = sorted(glob.glob("../../datasets/earthquakes/*.h5"))
fname = flist2[10]
with h5py.File(fname, 'r') as f:
    # The network emits `noutput` channels, so read exactly that many,
    # starting at `ncha_start`; transposed to (channel, time).
    data = f['/Acquisition/Raw[0]/RawData'][:, ncha_start:ncha_start+noutput].T
    starttime = UTCDateTime(f['/Acquisition/'].attrs['MeasurementStartTime'])

vmax = 3        # symmetric colour-scale limit for the figures
x_max = 2000    # last time sample shown in the figures

# Standardise every channel over time, as done for the training windows.
data -= np.mean(data, axis=-1, keepdims=True)
std = np.std(data, axis=-1, keepdims=True)
std[std == 0] = 1   # keep constant (dead) channels finite
data /= std

# Columns before `ntime` have no complete input window and stay zero.
dout = np.zeros(data.shape)
model.eval()    # disable dropout regardless of what earlier cells left behind
with torch.no_grad():
    for i in range(ntime, data.shape[1]):
        # `ntime` samples ending at i on the sparse input channels, laid out
        # as (time, channel) with a leading batch axis of size one.
        din = torch.Tensor(data[cidx, i-(ntime-1):i+1].copy().T).unsqueeze(0)
        dout[:, i] = model(din.to(device))[0].cpu().numpy()

res = dout - data

In [None]:
# Side-by-side view of the first 500 channels: input record, network
# reconstruction and their difference, all on the same colour scale.
panels = (
    ("original", data[:500, :]),
    ("reconstruction", dout[:500, :]),
    ("residual", dout[:500, :] - data[:500, :]),
)

plt.figure(figsize=(25, 8), dpi=300)
# plt.suptitle(f"{eid} M{mag}")
for ipanel, (panel_title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, ipanel)
    plt.imshow(image, vmin=-vmax, vmax=vmax, cmap='RdBu',
               aspect='auto', origin='lower')
    plt.title(panel_title, fontsize=20)
    plt.xticks([])
    plt.yticks([])
    # Skip the first `ntime` samples, for which no reconstruction exists.
    plt.xlim([ntime, x_max])
    plt.ylim([0, 500])
# plt.savefig(f"../figures/tmp/shred_epo{t}.png", bbox_inches = 'tight', dpi = 400)
# plt.close()