In [None]:
import torch
import numpy as np
from dict_layer import Model, dictloss
import os
import NES
from utlis import survey_deisgn, patchSamp, getPatches, rmse
from LsTomo import conventional_tomo2
import matplotlib.pyplot as plt
from scipy.sparse.linalg import lsqr
from itkm import itkm
from omp_n import omp_n
from scipy.sparse import csc_matrix
from mpl_toolkits.axes_grid1 import make_axes_locatable
####################################################################################
# Reproducibility: seed every RNG used below and fix the default dtype before any
# tensor is created.
np.random.seed(123)
# Must be set before CUDA is initialised to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.set_default_dtype(torch.float64)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# cudnn.benchmark=True lets cuDNN auto-tune (and possibly change) the algorithm it
# picks from run to run, which works against deterministic=True; keep it off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
####################################################################################
cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
# get_device_properties(0) raises when no GPU is present, so only query it on CUDA.
gpu_props = torch.cuda.get_device_properties(0) if cuda else None
print(torch.cuda.device_count(), gpu_props, cuda, device)
####################################################################################
model = Model()
model = model.to(device)
####################################################################################
save_path = './Marmousi_smooth3_noise5/'
# makedirs creates missing parent directories and is a no-op if the folder exists.
os.makedirs(save_path, exist_ok=True)

#### Load Marmousi model

In [None]:
# Smoothed Marmousi velocity section from the NES package data, sampled on a
# regular nx-by-nz grid and converted to slowness (1/V). The transpose turns the
# (x, z)-indexed samples (meshgrid with indexing='ij') into a z-by-x image,
# assuming Vel2D returns one value per grid point.
Vel2D = NES.misc.Marmousi(smooth=3, section=[[600, 900], None])
vmin = Vel2D.min
vmax = Vel2D.max
(xmin, zmin), (xmax, zmax) = Vel2D.xmin, Vel2D.xmax
nx = nz = 100
x, z = np.linspace(xmin, xmax, nx), np.linspace(zmin, zmax, nz)
Xr_2d = np.stack(
    np.meshgrid(x, z, indexing='ij'),
    axis=-1,
)
V_2d = Vel2D(Xr_2d)
sTrue = np.transpose(1/V_2d)
plt.imshow(sTrue)
plt.show()

In [None]:
# Disabled: plots the full smoothed Marmousi slowness model on a 4000 x 2000 grid
# and saves it as Marmousi.pdf under save_path.
# NOTE: uncommenting this cell overwrites Vel2D, vmin, vmax, xmin, zmin, xmax,
# zmax, nx, nz, x, z, Xr_2d and V_2d, which the cells above/below rely on.
# Vel2D = NES.misc.Marmousi(smooth=3)
# vmin, vmax = Vel2D.min, Vel2D.max
# xmin, zmin = Vel2D.xmin
# xmax, zmax = Vel2D.xmax
# nx, nz = 4000, 2000
# x = np.linspace(xmin, xmax, nx)
# z = np.linspace(zmin, zmax, nz)
# Xr_2d = np.stack(np.meshgrid(x, z, indexing='ij'), axis=-1)
# V_2d = Vel2D(Xr_2d)
# fig = plt.figure()
# ax = fig.add_subplot()
# im = ax.imshow(np.transpose(1/V_2d))
# ax.set_xticklabels([])
# ax.set_yticklabels([])
# ax.set_xticks([])
# ax.set_yticks([])
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im, cax=cax, fraction=0.046)
# plt.savefig(save_path + 'Marmousi.pdf')  
# plt.show()

### Set up stations and trace rays

In [None]:
Tarr, A, vb, vb2 = survey_deisgn(64, sTrue, save_path=save_path)
# Noise STD as a fraction of the mean travel time (0 = noise-free case).
noiseFrac = 0.05
stdNoise = np.mean(Tarr) * noiseFrac
# randn(*shape) draws exactly what randn(shape[0], shape[1]) drew for a 2-D
# Tarr, but also works for any other array rank.
noise    = stdNoise * np.random.randn(*Tarr.shape)
Tarr_n   = Tarr + noise
# Reference (homogeneous) slowness estimated from the noisy travel times.
# For a constant slowness s the predicted times are s * (A @ 1) = s * Asum, so
# pinv(Asum) @ Tarr_n is the least-squares constant slowness.
Asum = np.sum(A, axis=1, keepdims=True)     # (nrays, 1): row sums of A
invAsum = np.linalg.pinv(Asum)              # (1, nrays)
sRef = invAsum @ Tarr_n                     # (1, 1) when Tarr_n is (nrays, 1)
#######################################################################################
# Conventional-tomography hyper-parameters, selected per noise level.
if noiseFrac == 0:
    eta  = 0.1                              # conventional \eta regularization parameter
    L    = 10                               # smoothness length scale
else:
    eta  = 10
    L    = 20
#######################################################################################
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = dictloss()
# NOTE(review): milestone 100 is never reached with `epoch = 50` below, so this
# learning-rate decay is currently inactive -- confirm the intended milestone.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=0.1)

w1, w2 = sTrue.shape
patch_size = 20
patches  = getPatches(w1, w2, patch_size)   # pixel indices of each image patch (one column per patch)
percZero = patchSamp(A, patches)            # per-patch ray-coverage statistic
ss = np.zeros((vb.size, 1))                 # patch-based slowness perturbation, starts at zero
npp, npatches = patches.shape               # (pixels per patch, number of patches)
nrays, npix   = A.shape
normError = [np.linalg.norm(Tarr_n - A @ (ss + sRef))]
print('Trave time initinal error: {}'.format(normError))
############################################################
epoch = 50      # training steps of the dictionary model
natom = 150     # number of dictionary atoms
iters = 50      # outer iterations (used by the dictionary-learning loop further below)
lam2  = 0       # weight of the conventional estimate when averaging patches
D_sparse = 1    # sparsity argument passed to itkm
C_sparse = 1    # sparsity argument passed to omp_n
##########################################################

### Train (update) the dictionary

In [None]:
# Initial slowness perturbation from conventional (regularised) tomography.
ds = conventional_tomo2(eta, L, A, sRef, Tarr_n, sTrue, vb)
# Cut the perturbation into patches and remove each patch's mean.
Y  = ds.flatten()[patches]
meanY = np.mean(Y, axis=0, keepdims=True)
Yc = Y - meanY
# Initial dictionary (ITKM) and sparse codes (OMP) on the centred patches.
D0 = itkm(Yc, natom, D_sparse)
X  = omp_n(D0, Yc, C_sparse)
D0 = torch.tensor(D0).to(device)
RMSE = []  # NOTE: records the training loss of each epoch, not an RMSE
###################### Dictionary Learning ###############################
# Make re-running this cell safe after the model was switched to eval mode.
model.train()
for j in range(epoch):
    D = model(D0)
    loss = criterion(D, X, ss, vb, npatches, patches,
                     npp, sRef, A, Tarr_n, meanY, ds, lam2, device)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    print('epoch: {}/{}, loss: {:.8f}'.format(j+1, epoch, loss.item()))
    RMSE.append(loss.item())

# torch.save(model.state_dict(), os.path.join(save_path, ('dict_noise{}_eta{}_L{}_ps{}_natom{}_lam2{}_D_sparse{}_C_sparse{}.pth').\
#     format(noiseFrac, eta, L, patch_size, natom, lam2, D_sparse, C_sparse)))

# The D from the last loop iteration was produced *before* the final
# optimizer.step(), so evaluate the model once more to obtain the dictionary for
# the final weights (this also defines D when epoch == 0). eval() + no_grad match
# how the saved-weights path computes D.
model.eval()
with torch.no_grad():
    D = model(D0)
D = D.detach().squeeze().cpu().numpy()

### Load saved model weights (optional alternative to training)

In [None]:
# Optional: restore previously trained weights instead of running the training
# cell, then recompute the dictionary D from D0 (a tensor created in that cell).
# NOTE(review): this checkpoint name encodes noise0.02 / natom50, while the
# current configuration uses noiseFrac = 0.05 and natom = 150 -- load the file
# that matches the active settings.
# model.load_state_dict(torch.load(os.path.join(save_path, 'dict_noise0.02_eta10_L20_ps20_natom50_lam20_D_sparse1_C_sparse1.pth')))
# model.eval()
# D = model(D0)
# D = D.detach().squeeze().cpu().numpy() 

### Reconstruct the slowness model

In [None]:
import warnings

# Sparsity level for the final OMP coding, selected per noise level.
# NOTE(review): noiseFrac == 0 (supported by the eta/L selection earlier) has no
# entry here and falls through to the warning branch.
C_SPARSE_BY_NOISE = {0.02: 25, 0.05: 5}
if noiseFrac in C_SPARSE_BY_NOISE:
    C_sparse = C_SPARSE_BY_NOISE[noiseFrac]
else:
    # Unsupported noise level: keep the C_sparse already in the namespace, but
    # make that fallback explicit instead of printing and silently continuing.
    warnings.warn('noiseFrac setting is not correct! Keeping C_sparse = {}.'.format(C_sparse))

# Sparse-code the centred patches with the trained dictionary and rebuild them.
X = omp_n(D, Yc, C_sparse)
ss_b = D @ X + meanY                    # reconstructed patches, patch means restored
ss_p_sum = np.zeros(ss.shape)

# Scatter-add every reconstructed patch back onto the pixel grid.
for k in range(npatches):
    ss_p_sum[patches[:, k], 0] = ss_p_sum[patches[:, k], 0] + ss_b[:, k]

# Blend with the conventional estimate ds (weight lam2) and average the patch
# contributions. Dividing by npp assumes each pixel lies in exactly npp patches
# -- TODO confirm against getPatches.
ss_f = (lam2 * ds + ss_p_sum)/(lam2 + npp)
ss   = ss_f * vb
Tref = A @ (ss + sRef)
normError_new = np.linalg.norm(Tref - Tarr_n)
print('Inversion norm error: {}'.format(normError_new))

### Plot the updated dictionary

In [None]:
# Disabled: would plot the atoms of the updated dictionary D.
# NOTE(review): `vd` is not imported anywhere in this notebook; import it before
# uncommenting this line.
# vd(D, patch_size, natom//patch_size, 'us_ps20', save=save_path)

### Plot the reconstructed slowness model and a vertical slice

In [None]:
# Final slowness model of the proposed method: patch-based perturbation + sRef.
y  = np.reshape(ss + sRef, (w1, w2))
v_mask = np.reshape(vb, sTrue.shape)
slice_col = 35                          # column of the vertical slice compared below
slice_sTrue = sTrue[:, slice_col]
slice_Pre = y[:, slice_col]
slice_mask = v_mask[:, slice_col]
# imshow extent is (left, right, bottom, top): the width spans the w2 columns and
# the height spans the w1 rows (same as before for square grids).
extent = 0, w2, 0, w1
###########################plot result#######################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax = fig.add_subplot()
im = ax.imshow(y, extent=extent)
ax.set_xlabel("Range (km)")
ax.set_ylabel("Range (km)")
ax.text(80, 5, '{:.2f}'.format(rmse(sTrue, y, v_mask)), fontsize=14)   # masked RMSE annotation
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, fraction=0.046)
# append_axes makes `cax` the current axes, so plt.tick_params would style the
# colorbar instead of the image; target the image axes explicitly.
ax.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'slowness_natom_{}_noise_{}_sp_{}.pdf'\
#     .format( natom, noiseFrac, C_sparse), dpi=600)

###########################plot slice###############################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax4 = fig.add_subplot()
ax4.plot(slice_sTrue, range(len(slice_sTrue)))
ax4.plot(slice_Pre, range(len(slice_Pre)))
ax4.text(0.6, 5, '{:.2f}'.format(rmse(slice_sTrue, slice_Pre, slice_mask)), fontsize=14)
ax4.set_ylabel("Range (km)")
ax4.set_xlabel("Slowness (s/km)")
# ax4.set_title('Vertical slice RMSE:{:.4f}'.format(rmse(vsT, vsP, vmask)))
ax4.legend(['True','Predict'])
plt.tick_params(top=True,bottom=True,left=True,right=True)
ax4.margins(0)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'slice_natom_{}_noise_{}_sp_{}.pdf'\
#     .format(natom, noiseFrac, C_sparse), dpi=600)
################################plot errors####################################
# fig = plt.figure()
# plt.rcParams['xtick.direction'] = 'in'
# plt.rcParams['ytick.direction'] = 'in'
# ax = fig.add_subplot()
# im = ax.imshow((sTrue - y)*vb1, extent=extent)
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im, cax=cax, fraction=0.046)
# plt.tick_params(top=True,bottom=True,left=True,right=True)
# plt.tight_layout()
# plt.show()
# fig.savefig(save_path + 'errors_natom_{}_noise_{}_sp_{}.pdf'\
#     .format(natom, noiseFrac, C_sparse), dpi=600)

## Inversion by LSQR

In [None]:
# Conventional-tomography (LSQR) result: the ds perturbation multiplied by vb,
# plus the reference slowness.
# NOTE: ds must still be the conventional_tomo2 output from the training cell;
# the dictionary-learning loop further below overwrites ds. extent, v_mask,
# slice_sTrue and slice_mask come from the previous plotting cell.
y_lsqr  = np.reshape(ds*vb + sRef, (w1, w2))
slice_lsqr = y_lsqr[:, 35]

fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax1 = fig.add_subplot()
im1 = ax1.imshow(y_lsqr, extent=extent)
ax1.set_xlabel("Range (km)")
ax1.set_ylabel("Range (km)")
ax1.text(80, 5, '{:.2f}'.format(rmse(sTrue, y_lsqr, v_mask)), fontsize=14)
divider = make_axes_locatable(ax1)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im1, cax=cax, fraction=0.046)
# append_axes makes `cax` the current axes; style the image axes explicitly.
ax1.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'LSQR_noise_{}.pdf'\
#     .format(noiseFrac), dpi=600)

################################plot slice#######################################
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
ax4 = fig.add_subplot()
ax4.plot(slice_sTrue, range(len(slice_sTrue)))
ax4.plot(slice_lsqr, range(len(slice_lsqr)))
ax4.set_ylabel("Range (km)")
ax4.set_xlabel("Slowness (s/km)")
ax4.text(0.6, 5, '{:.2f}'.format(rmse(slice_sTrue, slice_lsqr, slice_mask)), fontsize=14)
ax4.legend(['True','Predict'])
ax4.margins(0)
plt.tick_params(top=True,bottom=True,left=True,right=True)
plt.tight_layout()
plt.show()
# fig.savefig(save_path + 'LSQR_slice_noise_{}.pdf'\
#     .format(noiseFrac), dpi=600)
################################plot errors####################################
# fig = plt.figure()
# ax = fig.add_subplot()
# im = ax.imshow((sTrue - y2)*vb1, extent=extent)
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im, cax=cax, fraction=0.046)
# plt.tick_params(top=True,bottom=True,left=True,right=True)
# # plt.tight_layout()
# plt.show()
# fig.savefig(save_path + 'LSerrors_noise_{}.pdf'\
#     .format(noiseFrac), dpi=600)

## Inversion by dictionary learning

In [None]:
# Re-seed the NumPy and PyTorch RNGs with the same seeds as the setup cell, so
# the dictionary-learning baseline starts from the same RNG state.
_NP_SEED, _TORCH_SEED = 123, 0
np.random.seed(_NP_SEED)
torch.manual_seed(_TORCH_SEED)
torch.cuda.manual_seed_all(_TORCH_SEED)

In [None]:
# Regenerate the survey for the dictionary-learning baseline. With the RNGs
# re-seeded in the previous cell this presumably reproduces the first
# experiment's noise draw (TODO confirm survey_deisgn does not consume np.random).
Tarr, A, vb, vb2 = survey_deisgn(64, sTrue, save_path=save_path)
stdNoise = np.mean(Tarr) * noiseFrac
# randn(*shape) equals randn(shape[0], shape[1]) for 2-D Tarr and works for any rank.
noise    = stdNoise * np.random.randn(*Tarr.shape)
Tarr_n   = Tarr + noise
# Least-squares constant reference slowness: pinv(A @ 1) @ Tarr_n.
Asum = np.sum(A, axis=1, keepdims=True)     # (nrays, 1): row sums of A
invAsum = np.linalg.pinv(Asum)              # (1, nrays)
sRef = invAsum @ Tarr_n                     # (1, 1) when Tarr_n is (nrays, 1)
#######################################################################################
w1, w2 = sTrue.shape
patch_size = 10
patches  = getPatches(w1, w2, patch_size)   # pixel indices of each image patch (one column per patch)
percZero = patchSamp(A, patches)            # per-patch ray-coverage statistic
ss = np.zeros((vb.size, 1))                 # slowness perturbation, starts at zero
npp, npatches = patches.shape               # (pixels per patch, number of patches)
nrays, npix   = A.shape
normError = [np.linalg.norm(Tarr_n - A @ (ss + sRef))]
print('Trave time initinal error: {}'.format(normError))
############################################################
epoch = 50
natom = 150     # number of dictionary atoms
iters = 50      # outer iterations of the dictionary-learning loop
lam2  = 0       # weight of the LSQR-updated estimate when averaging patches
# NOTE(review): the loop below hard-codes sparsity 2 for itkm/omp_n, so these two
# values are not used there -- confirm which setting is intended.
D_sparse = 1
C_sparse = 1

In [None]:
# Dictionary-learning baseline: alternate a damped LSQR update of the slowness
# perturbation with patch-wise sparse coding on a freshly learned dictionary.
# A does not change inside the loop, so build its sparse (CSC) copy only once.
spA = csc_matrix(A)
for i in range(iters):
    dt  = Tarr_n - A @ (ss + sRef)                   # travel-time residual
    # Damped least squares: min ||spA x - dt||^2 + 10^2 ||x||^2, at most 1000 iterations.
    ds  = lsqr(spA, dt, damp=10, iter_lim=1000)[0]
    ds  = np.expand_dims(ds, axis=1)
    sg = ds + ss
    Y   = sg.flatten()[patches]
    meanY = np.mean(Y, axis=0, keepdims=True)
    Yc = Y - meanY
    # Learn the dictionary only on patches whose percZero value is <= 0.1.
    # NOTE(review): sparsity 2 (and itkm's 4th argument 50) are hard-coded and
    # ignore D_sparse / C_sparse set in the previous cell.
    Yl = Yc[:, percZero <= 0.1]
    D = itkm(Yl, natom, 2, 50)
    X = omp_n(D, Yc, 2)
    ss_b = D @ X + meanY
    ss_p_sum = np.zeros(ss.shape)

    # Scatter-add every reconstructed patch back onto the pixel grid.
    for k in range(npatches):
        ss_p_sum[patches[:, k], 0] = ss_p_sum[patches[:, k], 0] + ss_b[:, k]

    ss_f = (lam2 * sg + ss_p_sum)/(lam2 + npp)
    # NOTE(review): weighted by vb2 here, while the result plots use vb as mask -- confirm.
    ss   = ss_f * vb2
    Tref = A @ (ss + sRef)
    normError_new = np.linalg.norm(Tref - Tarr_n)
    print('Iter: {}, inversion norm error: {}'.format(i+1, normError_new))

### Plot dictionary

In [None]:
# Disabled: would plot the atoms of the last dictionary learned by the loop above.
# NOTE(review): `vd` is not imported anywhere in this notebook; import it before
# uncommenting this line.
# vd(D, patch_size, natom//patch_size, 'DL', save=save_path)

### Plot inversion results by dictionary learning

In [None]:
# Slowness model from the dictionary-learning baseline.
y_DL  = np.reshape(ss + sRef, (w1, w2))
v_mask = np.reshape(vb, sTrue.shape)
slice_sTrue = sTrue[:, 35]
slice_mask = v_mask[:, 35]
slice_DL = y_DL[:, 35]
fig = plt.figure()
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
# imshow extent is (left, right, bottom, top): width = w2 columns, height = w1 rows.
extent = 0, w2, 0, w1
ax = fig.add_subplot()
im = ax.imshow(y_DL, extent=extent)
ax.set_xlabel("Range (km)")
ax.set_ylabel("Range (km)")
ax.text(80, 5, '{:.2f}'.format(rmse(sTrue, y_DL, v_mask)), fontsize=14)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, fraction=0.046)
# append_axes makes `cax` the current axes; style the image axes explicitly.
ax.tick_params(top=True, bottom=True, left=True, right=True)
plt.tight_layout()
plt.show()
fig.savefig(save_path + 'DL_iters_{}_natom_{}_noise_{}_ps_{}.pdf'
            .format(iters, natom, noiseFrac, patch_size))

###########################plot slice###############################################
fig = plt.figure()
ax4 = fig.add_subplot()
ax4.plot(slice_sTrue, range(len(slice_sTrue)))
ax4.plot(slice_DL, range(len(slice_DL)))
ax4.text(0.6, 5, '{:.2f}'.format(rmse(slice_sTrue, slice_DL, slice_mask)), fontsize=14)
ax4.set_ylabel("Range (km)")
ax4.set_xlabel("Slowness (s/km)")
# Legend label spelled the same as in the all-methods comparison figure.
ax4.legend(['True','Dictionary learning'])
plt.tick_params(top=True,bottom=True,left=True,right=True)
ax4.margins(0)
plt.tight_layout()
plt.show()
fig.savefig(save_path + 'DL_slice_noise_{}.pdf'.format(noiseFrac))
################################plot errors####################################
# fig = plt.figure()
# plt.rcParams['xtick.direction'] = 'in'
# plt.rcParams['ytick.direction'] = 'in'
# ax = fig.add_subplot()
# im = ax.imshow((sTrue - y)*vb1, extent=extent)
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im, cax=cax, fraction=0.046)
# plt.tick_params(top=True,bottom=True,left=True,right=True)
# plt.tight_layout()
# plt.show()
# fig.savefig(save_path + 'SDerrors_iters_{}_natom_{}_noise_{}_ps_{}.pdf'\
#     .format(iters, natom, noiseFrac, patch_size), dpi=600)

### Plot slices of all methods

In [None]:
# Overlay the vertical slice of every method against the true slowness; the
# black constant line is the reference slowness sRef.
fig = plt.figure()
ax = fig.add_subplot()
ax.plot(sRef[0, 0] * np.ones(len(slice_sTrue)), range(len(slice_sTrue)), 'k')
for profile in (slice_sTrue, slice_lsqr, slice_DL, slice_Pre):
    ax.plot(profile, range(len(profile)))
ax.set_ylabel("Range (km)")
ax.set_xlabel("Slowness (s/km)")
ax.legend(['Reference slowness', 'True', 'LSQR', 'Dictionary learning', 'The proposed method'])
plt.tick_params(top=True, bottom=True, left=True, right=True)
ax.margins(0)
plt.tight_layout()
plt.show()
fig.savefig(save_path + 'slice_of_all_methods_noise_{}.pdf'.format(noiseFrac))