In [None]:
import os

import torch
import numpy as np
import scipy
import scipy.ndimage
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import lsqr
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from dict_layer import Model, dictloss
from utlis import survey_deisgn, slownessMap, patchSamp, getPatches, rmse, vd
from LsTomo import conventional_tomo2
from itkm import itkm
from omp_n import omp_n
####################################################################################
# Reproducibility: seed NumPy and PyTorch (CPU and every CUDA device).
np.random.seed(123)
# Restrict this process to GPU 0 (must happen before CUDA is first used).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# cudnn.benchmark=True auto-tunes kernels by timing and may pick different algorithms
# between runs, which undermines deterministic=True; keep it off for reproducibility.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
####################################################################################
cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
# get_device_properties(0) fails on a CPU-only machine, so only query it when a GPU exists.
if cuda:
    print(torch.cuda.device_count(), torch.cuda.get_device_properties(0), cuda, device)
else:
    print(torch.cuda.device_count(), cuda, device)
####################################################################################
model = Model()
model = model.to(device)

### Build the slowness model, set up the stations, and configure training

In [None]:
save_path = './sd/'
sTrue = slownessMap('sd')                          # true slowness map on a 2-D grid
Tarr, A, vb, vb2 = survey_deisgn(64, sTrue, save_path=save_path)
noiseFrac = 0.05                                   # noise STD as a fraction of the mean travel time (0 = noise-free case)
stdNoise = np.mean(Tarr) * noiseFrac
noise    = stdNoise * np.random.randn(*Tarr.shape) # identical draws to randn(rows, cols) for a 2-D Tarr
Tarr_n   = Tarr + noise
# Reference slowness: least-squares scalar s with Tarr_n ~= (row sums of A) * s.
Asum = np.sum(A, axis=1, keepdims=True)            # (nrays, 1)
invAsum = np.linalg.pinv(Asum)                     # (1, nrays)
sRef = invAsum @ Tarr_n                            # Tarr_n has nrays rows (presumably one column), so sRef is presumably 1x1
#######################################################################################
if noiseFrac == 0:
    eta  = 0.1                                     # regularization weight of the conventional inversion
    L    = 10                                      # smoothness length scale
else:
    eta  = 10
    L    = 20
#######################################################################################
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = dictloss()
# NOTE(review): lr_scheduler.step() is called once per epoch and epoch = 50 below, so the
# milestone at 100 is never reached and the learning rate stays at 1e-3 throughout.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=0.1)

w1, w2 = sTrue.shape
patch_size = 20
patches  = getPatches(w1, w2, patch_size)          # pixel indices of each patch, one column per patch
percZero = patchSamp(A, patches)                   # per-patch ray-coverage statistic (see utlis.patchSamp)
ss = np.zeros((vb.size, 1))                        # slowness perturbation, starts at zero
npp, npatches = patches.shape
nrays, npix   = A.shape
############################################################
epoch = 50          # training epochs for the dictionary network
natom = 150         # number of dictionary atoms
iters = 50
lam2  = 0           # weight of the conventional solution when averaging patches
D_sparse = 1        # passed to itkm when building the initial dictionary (presumably its sparsity)
C_sparse = 1        # passed to omp_n when coding patches (presumably its sparsity)
##########################################################

### Train the network to optimize the dictionary

In [None]:
# Initial dictionary: learned from patches of the conventional inversion `ds`.
ds = conventional_tomo2(eta, L, A, sRef, Tarr_n, sTrue, vb)
Y  = ds.flatten()[patches]                  # one column of pixel values per patch
meanY = np.mean(Y, axis=0, keepdims=True)   # per-patch mean, removed before coding
Yc = Y - meanY
D0 = itkm(Yc, natom, D_sparse)
X  = omp_n(D0, Yc, C_sparse)                # sparse codes, held fixed during training
D0 = torch.tensor(D0).to(device)
RMSE = []                                   # per-epoch training loss values (loss.item())
###################### Dictionary Learning ###############################
model.train()
for j in range(epoch):
    D = model(D0)
    loss = criterion(D, X, ss, vb, npatches, patches,
                     npp, sRef, A, Tarr_n, meanY, ds, lam2, device)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    print('epoch: {}/{}, loss: {:.8f}'.format(j+1, epoch, loss.item()))
    RMSE.append(loss.item())

# torch.save(model.state_dict(), os.path.join(save_path, ('dict_noise{}_eta{}_L{}_ps{}_natom{}_lam2{}_D_sparse{}_C_sparse{}.pth').\
#     format(noiseFrac, eta, L, patch_size, natom, lam2, D_sparse, C_sparse)))

# The D produced inside the loop predates the final optimizer.step(). Recompute it from the
# trained weights (exactly as the "load saved weights" cell does) so it reflects the final
# model; this also keeps D defined when epoch == 0.
model.eval()
with torch.no_grad():
    D = model(D0)
D = D.detach().squeeze().cpu().numpy()

### (Optional) Load saved weights instead of retraining

In [None]:
# Optional: restore weights written by the commented-out torch.save call in the training
# cell instead of retraining. The file name encodes the hyper-parameters used at save time
# and must match the current configuration.
# NOTE(review): consider torch.load(..., map_location=device) if the checkpoint was saved on a GPU.
# model.load_state_dict(torch.load(os.path.join(save_path, 'dict_noise0.05_eta10_L20_ps20_natom150_lam20_D_sparse1_C_sparse1.pth')))
# model.eval()
# D = model(D0)
# D = D.detach().squeeze().cpu().numpy() 

### Reconstruct the velocity model

In [None]:
# Sparsity of the final OMP coding, chosen per noise level; any other noiseFrac keeps the
# C_sparse value set in the configuration cell.
if noiseFrac == 0.02:
    C_sparse = 25
elif noiseFrac == 0.05:
    C_sparse = 5

X = omp_n(D, Yc, C_sparse)            # code the centred patches with the learned dictionary
ss_b = D @ X + meanY                  # reconstructed patches with their per-patch mean restored
ss_p_sum = np.zeros(ss.shape)
# Scatter-add each reconstructed patch back onto the pixel grid. Indexing with patches.T
# visits patch 0, 1, ... in order, matching an explicit per-patch loop
# (assumes each patch lists distinct pixel indices).
np.add.at(ss_p_sum[:, 0], patches.T, ss_b.T)

# NOTE(review): dividing by npp assumes every pixel is covered by npp patches -
# confirm against getPatches for border pixels.
ss_f = (lam2 * ds + ss_p_sum)/(lam2 + npp)
ss   = ss_f * vb
Tref = A @ (ss + sRef)

### Plot the updated dictionary

In [None]:
# Render the learned dictionary D via the project helper `vd`, tagged 'us' and saved under save_path.
# NOTE(review): the third argument (natom // patch_size) presumably controls the tile layout - confirm in utlis.vd.
n_grid = natom // patch_size
vd(D, patch_size, n_grid, 'us', save=save_path)

### Plot the slowness model and slices from the proposed method

In [None]:
# Final slowness model of the proposed method on the image grid.
# NOTE(review): 0.1*ds*vb adds 10% of the conventional update `ds` on top of the dictionary
# reconstruction `ss`; this hard-coded blend also affects the reported RMSE.
y  = np.reshape(ss + 0.1*ds*vb + sRef, (w1, w2))
vb1 = np.reshape(vb, sTrue.shape)           # mask `vb` on the image grid (used for RMSE / errors)

row_idx, col_idx = 45, 35                   # pixel row / column of the horizontal / vertical slices
hsT = sTrue[row_idx, :]
hsP = y[row_idx, :]
vsT = sTrue[:, col_idx]
vsP = y[:, col_idx]

hmask = vb1[row_idx, :]
vmask = vb1[:, col_idx]

# imshow extent is (left, right, bottom, top): x spans the w2 columns, y the w1 rows.
extent = 0, w2, 0, w1
###########################plot result#######################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax = fig.add_subplot()
im = ax.imshow(y, extent=extent)
# With the default origin='upper', row r is drawn centred at y = w1 - r - 0.5 and column c
# at x = c + 0.5; a line at y = row_idx would fall on a different row than the slice.
ax.hlines(w1 - row_idx - 0.5, 0, w2)
ax.vlines(col_idx + 0.5, 0, w1)
ax.set_xlabel("Range (km)")
ax.set_ylabel("Range (km)")
ax.text(80, 5, '{:.2f}'.format(rmse(sTrue, y, vb1)), fontsize=14)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, fraction=0.046)
# append_axes made `cax` the current axes, so plt.tick_params would style the colorbar;
# target the image axes explicitly.
ax.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
# Save before showing the figure.
fig.savefig(os.path.join(save_path, 'slowness_natom_{}_noise_{}_sp_{}.pdf'
                         .format(natom, noiseFrac, C_sparse)))
plt.show()

###########################plot slice###############################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax3 = fig.add_subplot(121)
ax3.plot(hsT)
ax3.plot(hsP)
ax3.text(8, 0.39, '{:.2f}'.format(rmse(hsT, hsP, hmask)), fontsize=14)
# ax3.set_xlabel("Range (km)")
# ax3.set_ylabel("Slowness (s/km)")
# ax3.legend(['True','Predict'])
ax3.tick_params(top=True, bottom=True, left=True, right=True)
ax3.margins(0)


ax4 = fig.add_subplot(122)
# NOTE(review): row index increases upward here, whereas the map draws row 0 at the top;
# add ax4.invert_yaxis() if this profile should match the map's orientation.
ax4.plot(vsT, range(len(vsT)))
ax4.plot(vsP, range(len(vsP)))
ax4.text(0.35, 5, '{:.2f}'.format(rmse(vsT, vsP, vmask)), fontsize=14)
# ax4.set_ylabel("Range (km)")
# ax4.set_xlabel("Slowness (s/km)")
# ax4.set_title('Vertical slice RMSE:{:.4f}'.format(rmse(vsT, vsP, vmask)))
# ax4.legend(['True','Predict'])
ax4.tick_params(top=True, bottom=True, left=True, right=True)
ax4.margins(0)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'slice_natom_{}_noise_{}_sp_{}.pdf'\
#     .format(natom, noiseFrac, C_sparse))
################################plot errors####################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax = fig.add_subplot()
im = ax.imshow((sTrue - y)*vb1, extent=extent)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, fraction=0.046)
ax.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'errors_natom_{}_noise_{}_sp_{}.pdf'\
#     .format(natom, noiseFrac, C_sparse))


### Plot the slowness model and slices from the conventional (LSQR) inversion

In [None]:
############ Plot conventional (LSQR) results ####################
y2  = np.reshape(ds*vb + sRef, (w1, w2))        # conventional solution on the image grid
vb1 = np.reshape(vb, sTrue.shape)               # defined here so this cell does not rely on an earlier one
row_idx, col_idx = 45, 35                       # pixel row / column of the slices
hsT = sTrue[row_idx, :]
hsP = y2[row_idx, :]
vsT = sTrue[:, col_idx]
vsP = y2[:, col_idx]
hmask = vb1[row_idx, :]
vmask = vb1[:, col_idx]
extent = 0, w2, 0, w1                           # imshow extent: (left, right, bottom, top)
#############################plot results #######################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax1 = fig.add_subplot()
im1 = ax1.imshow(y2, extent=extent)
# Mark the slices at pixel centres (origin='upper': row r sits at y = w1 - r - 0.5,
# column c at x = c + 0.5), so the lines coincide with the plotted slices.
ax1.hlines(w1 - row_idx - 0.5, 0, w2)
ax1.vlines(col_idx + 0.5, 0, w1)
ax1.set_xlabel("Range (km)")
ax1.set_ylabel("Range (km)")
ax1.text(80, 5, '{:.2f}'.format(rmse(sTrue, y2, vb1)), fontsize=14)
divider = make_axes_locatable(ax1)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im1, cax=cax, fraction=0.046)
# `cax` is the current axes after append_axes; style the image axes explicitly.
ax1.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'LSslownessnoise_{}.pdf'.format(noiseFrac))
################################plot slice#######################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax3 = fig.add_subplot(121)
ax3.plot(hsT)
ax3.plot(hsP)
# ax3.set_xlabel("Range (km)")
# ax3.set_ylabel("Slowness (s/km)")
ax3.tick_params(top=True, bottom=True, left=True, right=True)
ax3.text(8, 0.39, '{:.2f}'.format(rmse(hsT, hsP, hmask)), fontsize=14)
ax3.legend(['True','Predicted'])
ax3.margins(0)

ax4 = fig.add_subplot(122)
# NOTE(review): row index increases upward here, whereas the map draws row 0 at the top.
ax4.plot(vsT, range(len(vsT)))
ax4.plot(vsP, range(len(vsP)))
# ax4.set_ylabel("Range (km)")
# ax4.set_xlabel("Slowness (s/km)")
ax4.text(0.35, 5, '{:.2f}'.format(rmse(vsT, vsP, vmask)), fontsize=14)
# ax4.legend(['True','Predict'])
ax4.margins(0)
ax4.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'LSslicenoise_{}.pdf'.format(noiseFrac))
################################plot errors####################################
fig = plt.figure()
ax = fig.add_subplot()
im = ax.imshow((sTrue - y2)*vb1, extent=extent)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, fraction=0.046)
ax.tick_params(top=True, bottom=True, left=True, right=True)
# plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'LSerrors_noise_{}.pdf'.format(noiseFrac))

### Plot the median-filtered LSQR slowness model

In [None]:
fs = 9                                           # median-filter window size (pixels)
xx = scipy.ndimage.median_filter(y2, (fs, fs))   # median-filtered conventional (LSQR) model
vb1 = np.reshape(vb, sTrue.shape)                # defined here so this cell does not rely on an earlier one
row_idx, col_idx = 45, 35                        # pixel row / column of the slices
hsT = sTrue[row_idx, :]
hsP = xx[row_idx, :]
vsT = sTrue[:, col_idx]
vsP = xx[:, col_idx]
hmask = vb1[row_idx, :]
vmask = vb1[:, col_idx]
extent = 0, w2, 0, w1                            # imshow extent: (left, right, bottom, top)
rmse_filtered = rmse(sTrue, xx, vb1)             # computed once; shown on the plot and in the file name

fig = plt.figure()
ax1  = fig.add_subplot()
im1 = ax1.imshow(xx, extent=extent)
ax1.set_xlabel("Range (km)")
ax1.set_ylabel("Range (km)")
# ax1.set_title('RMSE:{:.4f}'.format(rmse(sTrue, xx, vb1)))
divider = make_axes_locatable(ax1)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im1, cax=cax, fraction=0.046)
# `cax` is the current axes after append_axes; style the image axes explicitly.
ax1.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
ax1.text(80, 5, '{:.2f}'.format(rmse_filtered), fontsize=14)
ax1.margins(0)
fig.savefig(os.path.join(save_path, 'media_filter_rmse_{:.2f}_fs_{}.pdf'
                         .format(rmse_filtered, fs)))

# fig = plt.figure()
# ax2 = fig.add_subplot(121)
# ax2.plot(hsT)
# ax2.plot(hsP)
# ax2.set_xlabel("Range (km)")
# ax2.set_ylabel("Slowness (s/km)")
# # ax2.set_title('Horizone slice RMSE:{:.4f}'.format(rmse(hsT, hsP, hmask)))
# ax2.legend(['True','Predict'])
# ax2.margins(0)

# ax3 = fig.add_subplot(122)
# ax3.plot(vsT, range(len(vsT)))
# ax3.plot(vsP, range(len(vsP)))
# ax3.set_ylabel("Range (km)")
# ax3.set_xlabel("Slowness (s/km)")
# ax3.set_title('Vertical slice RMSE:{:.4f}'.format(rmse(vsT, vsP, vmask)))
# ax3.legend(['True','Predict'])
# ax3.margins(0)
# # fig.suptitle(' median_filter noise:{} eta:{} L:{}'.format(noiseFrac, eta, L), fontsize=16)
# plt.show()

### Plot the training loss

In [None]:
# Optional: compare training-loss curves for 50/100/150 atoms read from './loss.xlsx'
# (sheet '0.05'), a file this notebook does not produce. Left commented out so that
# Restart & Run All does not depend on it.
# import pandas as pd
# df = pd.read_excel('./loss.xlsx', sheet_name='0.05')
# a150=df['loss150']
# a100=df['loss100']
# a50=df['loss50']
# ########## Plot RMSE ####################
# fig = plt.figure()
# plt.rcParams['xtick.direction'] = 'in'
# plt.rcParams['ytick.direction'] = 'in'
# ax = fig.add_subplot()
# ax.plot(a50)
# ax.plot(a100)
# ax.plot(a150)
# ax.set_ylabel("Traveltime MSE loss")
# ax.set_xlabel("Epoch")
# plt.tick_params(top=True,bottom=True,left=True,right=True)
# plt.tight_layout()
# ax.legend(['50 atoms','100 atoms', '150 atoms'])
# plt.show()
# fig.savefig(save_path + 'mse_noise_{}.pdf'.format(0.05))

### Inversion by dictionary learning

In [None]:
patch_size = 10
patches  = getPatches(w1, w2, patch_size)        # pixel indices of each patch, one column per patch
percZero = patchSamp(A, patches)                 # per-patch ray-coverage statistic (see utlis.patchSamp)
ss = np.zeros((vb.size, 1))                      # slowness perturbation, starts at zero
npp, npatches = patches.shape
nrays, npix   = A.shape
normError = [np.linalg.norm(Tarr_n - A @ (ss + sRef))]   # travel-time misfit history
print('Trave time initinal error: {}'.format(normError))
natom = 150
iters = 50
lam2  = 0
lsqr_damp = 10          # damping of the per-iteration LSQR update
lsqr_iter_lim = 1000    # LSQR iteration cap
learn_sparsity = 2      # passed to itkm when learning D (presumably its sparsity)
itkm_iters = 50         # NOTE(review): 4th itkm argument, presumably its iteration count
code_sparsity = 2       # passed to omp_n when coding patches (presumably its sparsity)

spA = csc_matrix(A)     # A is constant inside the loop, so convert it to sparse once
for i in range(iters):
    dt  = Tarr_n - A @ (ss + sRef)               # travel-time residual of the current model
    ds  = lsqr(spA, dt, damp=lsqr_damp, iter_lim=lsqr_iter_lim)[0]
    ds  = np.expand_dims(ds, axis=1)
    sg = ds + ss
    Y   = sg.flatten()[patches]
    meanY = np.mean(Y, axis=0, keepdims=True)
    Yc = Y - meanY
    # NOTE(review): learn D only from patches with percZero <= 0.1 - confirm the meaning
    # of the patchSamp statistic.
    Yl = Yc[:, percZero <= 0.1]
    D = itkm(Yl, natom, learn_sparsity, itkm_iters)
    X = omp_n(D, Yc, code_sparsity)
    ss_b = D @ X + meanY                          # reconstructed patches with per-patch mean restored
    ss_p_sum = np.zeros(ss.shape)
    # Scatter-add each patch back onto the pixel grid (patch order as in a per-patch loop;
    # assumes each patch lists distinct pixel indices).
    np.add.at(ss_p_sum[:, 0], patches.T, ss_b.T)

    ss_f = (lam2 * sg + ss_p_sum)/(lam2 + npp)
    ss   = ss_f * vb2
    Tref = A @ (ss + sRef)
    normError_new = np.linalg.norm(Tref - Tarr_n)
    normError.append(normError_new)
    print('Iter: {}, inversion norm error: {}'.format(i+1, normError_new))

### Plot dictionary

In [None]:
# Render the dictionary learned by the iterative inversion via `vd`, tagged 'DL' and saved under save_path.
# NOTE(review): the third argument (natom // patch_size) presumably controls the tile layout - confirm in utlis.vd.
n_grid = natom // patch_size
vd(D, patch_size, n_grid, 'DL', save=save_path)

### Plot inversion results by dictionary learning

In [None]:
y_DL  = np.reshape(ss + sRef, (w1, w2))      # dictionary-learning inversion on the image grid
# Recompute the inputs this cell needs instead of relying on earlier plotting cells.
vb1 = np.reshape(vb, sTrue.shape)            # mask `vb` on the image grid
col_idx = 35                                 # pixel column of the vertical slice
vsT = sTrue[:, col_idx]
vmask = vb1[:, col_idx]
slice_DL = y_DL[:, col_idx]
extent = 0, w2, 0, w1                        # imshow extent: (left, right, bottom, top)
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax = fig.add_subplot()
im = ax.imshow(y_DL, extent=extent)
ax.set_xlabel("Range (km)")
ax.set_ylabel("Range (km)")
ax.text(80, 5, '{:.2f}'.format(rmse(sTrue, y_DL, vb1)), fontsize=14)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, fraction=0.046)
# `cax` is the current axes after append_axes; style the image axes explicitly.
ax.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
# Save before showing the figure.
fig.savefig(os.path.join(save_path, 'DL_iters_{}_natom_{}_noise_{}_ps_{}.pdf'
                         .format(iters, natom, noiseFrac, patch_size)))
plt.show()

###########################plot slice###############################################
fig = plt.figure()
ax4 = fig.add_subplot()
# NOTE(review): row index increases upward here, whereas the map draws row 0 at the top.
ax4.plot(vsT, range(len(vsT)))
ax4.plot(slice_DL, range(len(slice_DL)))
ax4.text(0.6, 5, '{:.2f}'.format(rmse(vsT, slice_DL, vmask)), fontsize=14)
ax4.set_ylabel("Range (km)")
ax4.set_xlabel("Slowness (s/km)")
ax4.legend(['True','Dictionary learning'])
ax4.tick_params(top=True, bottom=True, left=True, right=True)
ax4.margins(0)
plt.tight_layout()
plt.show()
################################plot errors####################################
# fig = plt.figure()
# plt.rcParams['xtick.direction'] = 'in'
# plt.rcParams['ytick.direction'] = 'in'
# ax = fig.add_subplot()
# im = ax.imshow((sTrue - y)*vb1, extent=extent)
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im, cax=cax, fraction=0.046)
# plt.tick_params(top=True,bottom=True,left=True,right=True)
# plt.tight_layout()
# plt.show()
# fig.savefig(save_path + 'SDerrors_iters_{}_natom_{}_noise_{}_ps_{}.pdf'.format(iters, natom, noiseFrac, patch_size))