### This notebook shows how to train a RENO model using a single velocity model. Adjust the training data size (the `ndp` and `nsrc` arguments of the datasets) as needed.

In [1]:
import os
# Enumerate GPUs in PCI bus order and expose only physical device "1" to this process.
# NOTE(review): both variables are set before `import torch` below; presumably this ordering
# is deliberate so they are in place before CUDA is initialised -- keep it.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'  # target passed to `.to(device)` for the model and every batch
device_ids = [0]  # NOTE(review): not referenced in any visible cell -- confirm whether it is still needed
import torch
import numpy as np
import tqdm

from utils import lossfunc
from modules.reno import RENO
from torch.utils.data import Dataset

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# --- spatial grid -----------------------------------------------------------
# The model is stored on a raw grid with spacing `h_raw` and used here on a
# grid that is coarser by the factor `ds`.  Only the downsampled grid is kept;
# the raw-grid node counts/coordinates were never used and are not computed.
xrange = 85e3   # model extent along x in m
zrange = 20e3   # model extent along z in m
h_raw = 125     # raw grid spacing in m
ds = 2          # downsampling factor
h = h_raw * ds  # spacing in m
# int(1.5 + extent / h) == round(extent / h) + 1, i.e. the number of grid nodes
# including both end points.
nx = int(1.5 + xrange / h)
nz = int(1.5 + zrange / h)
xcoor = np.arange(nx) * h
zcoor = np.arange(nz) * h
assert xcoor[-1] == xrange  # the x-extent must be an exact multiple of h

# --- time / frequency axis ---------------------------------------------------
start_time_in_seconds = 0  # training actually starts from 0 s
end_time_in_seconds = 50.0
dt = 0.1
i0 = round(-start_time_in_seconds / dt)  # sample index of t = 0 s
fs = 1 / dt
fn = fs / 2  # Nyquist frequency
T = round((end_time_in_seconds - start_time_in_seconds) / dt + 1)  # samples in the window
nt = T - i0
NT = nt + nt % 2  # round the number of samples up to the next even value

# NOTE(review): an NT-point real FFT has bin spacing fs / NT, whereas this axis
# uses fs / (NT - 1).  Left unchanged because the frequency indices `kf` used in
# the data file names are derived from this axis -- confirm against the code
# that generated the data before touching it.
freqs = torch.arange(NT // 2 + 1) * fs / (NT - 1)
freqmin = 0.1
freqmax = 0.5
in_band = (freqs >= freqmin) & (freqs <= freqmax)
freq_to_keep = torch.where(in_band)[0].tolist()  # indices of the retained frequencies
NF = len(freq_to_keep)

In [3]:
# Training hyperparameters (consumed by the DataLoader, optimizer and epoch loop).
nepoch = 100                      # number of training epochs
batch_size = 32                   # mini-batch size of the training loader
lr, weight_decay = 1e-4, 1e-2     # AdamW learning rate and weight decay

In [4]:
# Size of the latent grid, passed to both encoder and decoder as H x W.
latent_dim = [32, 16]

# Keys shared verbatim by the encoder and decoder configurations.
latent_grid = {"P": 1, "H": latent_dim[0], "W": latent_dim[1]}

# Encoder settings handed to RENO.
encoder_config = {
    "input_dim": 3,
    "enc_dim": 128,
    "enc_depth": 5,
    "enc_num_heads": 4,
    "radius": 0.033,
    **latent_grid,
}

# Decoder settings handed to RENO.
decoder_config = {
    "output_dim": 2,
    "dec_dim": 128,
    "dec_depth": 1,
    "dec_num_heads": 4,
    **latent_grid,
}

In [5]:
class MyDataset(Dataset):
    """Map-style dataset over (velocity model, source, frequency) triples stored as ``.npy`` files.

    Sample ``idx`` is decoded into a velocity-model index ``iv``, a source index
    ``js`` and a frequency index ``kf`` (frequency varies fastest), and the
    arrays are read from ``data_path`` using these file names:

    * ``v_iv{iv}.npy``, ``srcloc_iv{iv}.npy``, ``recloc_iv{iv}.npy``
    * ``y_iv{iv}_js{js}_kf{kf}.npy``

    where the ``iv`` used in file names is shifted by ``offset_vel``.

    Args:
        data_path: String prefix prepended to every file name (it is concatenated
            as-is, so a directory must end with a path separator).
        offset_vel: Offset added to ``iv`` when building file names.
        ndp: Number of samples reported by ``len()``; expected to equal
            ``n_vel * nsrc * len(norm_freq)``.
        input_pos: Positions of the input grid points, returned unchanged with
            every sample (``(ngp_in, 2)`` per the original comments).
        query_pos: Positions of the latent query points, returned unchanged with
            every sample (``(ngp_latent, 2)`` per the original comments).
        norm_freq: 1-D tensor of normalised frequencies, indexed by ``kf``.
        nsrc: Number of sources per velocity model.
    """

    def __init__(self, data_path, offset_vel, ndp, input_pos, query_pos, norm_freq, nsrc):
        self.data_path = data_path
        self.offset_vel = offset_vel
        self.ndp = ndp
        self.input_pos = input_pos  # (ngp_in, 2)
        self.query_pos = query_pos  # (ngp_latent, 2)
        self.norm_freq = norm_freq  # (nf,)
        self.nfreq = len(norm_freq)
        self.nsrc = nsrc

    def __len__(self):
        """Return the number of samples (``ndp``)."""
        return self.ndp

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk.

        Returns:
            Tuple ``(input_feat, input_pos, query_pos, src_pos, rec_pos, y)``.
            ``input_feat`` is the velocity array with the normalised frequency
            appended as an extra last column; ``src_pos`` is the selected source
            location repeated once per receiver.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Reject out-of-range indices up front instead of silently decoding them
        # into a velocity model that does not belong to this dataset; raising
        # IndexError also lets plain `for sample in dataset` terminate.
        if not -self.ndp <= idx < self.ndp:
            raise IndexError(
                "index {} is out of range for a dataset of size {}".format(idx, self.ndp)
            )
        if idx < 0:
            idx += self.ndp  # Python-style negative indexing

        iv, js, kf = self.get_iv_js_kf(idx)
        iv_file = iv + self.offset_vel  # velocity-model number used in file names

        v = torch.from_numpy(np.load(self.data_path + "v_iv{}.npy".format(iv_file))).float()  # (nx*nz, 2)
        srcloc = torch.from_numpy(np.load(self.data_path + "srcloc_iv{}.npy".format(iv_file))).float()  # (nsrc, 2)
        recloc = torch.from_numpy(np.load(self.data_path + "recloc_iv{}.npy".format(iv_file))).float()  # (nrec, 2)
        y = torch.from_numpy(
            np.load(self.data_path + "y_iv{}_js{}_kf{}.npy".format(iv_file, js, kf))
        ).float()  # (nrec, 2)

        # Broadcast the scalar normalised frequency to one value per grid point
        # and append it to the velocity features.
        f = self.norm_freq[kf].unsqueeze(0).repeat(len(v), 1)  # (nx*nz, 1)
        input_feat = torch.cat((v, f), dim=-1)  # (nx*nz, 3)

        rec_pos = recloc  # (nrec, 2)
        # Repeat the single source location so it pairs row-wise with the receivers.
        src_pos = srcloc[js].unsqueeze(0).repeat(len(rec_pos), 1)  # (nrec, 2)
        return input_feat, self.input_pos, self.query_pos, src_pos, rec_pos, y

    def get_iv_js_kf(self, idx):
        '''
        get vel index, source index, and frequency index from the flattened index
        '''
        # Layout is idx = (iv * nsrc + js) * nfreq + kf.
        iv, remainder = divmod(idx, self.nsrc * self.nfreq)
        js, kf = divmod(remainder, self.nfreq)
        return iv, js, kf


In [6]:
def make_2d_grid(dims, x1=None, x2=None, x1_min=0, x1_max=1, x2_min=0, x2_max=1):
    """Build the flattened list of points of a 2-D tensor-product grid.

    Args:
        dims: Pair ``(n1, n2)`` with the number of points per axis. Entry ``k``
            is only read when the corresponding axis ``x1`` / ``x2`` is not given.
        x1, x2: Optional 1-D tensors of axis coordinates. A missing axis is
            generated with ``torch.linspace`` between its min and max bounds.
        x1_min, x1_max, x2_min, x2_max: Allowed coordinate range per axis.

    Returns:
        Tensor of shape ``(len(x1) * len(x2), 2)`` whose rows are ``(x1, x2)``
        pairs in ``'ij'`` order, i.e. the second coordinate varies fastest.

    Raises:
        AssertionError: if any axis coordinate lies outside its allowed range.
    """
    axis1 = torch.linspace(x1_min, x1_max, dims[0]) if x1 is None else x1
    axis2 = torch.linspace(x2_min, x2_max, dims[1]) if x2 is None else x2

    assert axis1.min() >= x1_min and axis1.max() <= x1_max
    assert axis2.min() >= x2_min and axis2.max() <= x2_max

    mesh1, mesh2 = torch.meshgrid(axis1, axis2, indexing='ij')
    # Flatten each coordinate mesh into a column and place the columns side by side.
    columns = [mesh.reshape(-1, 1) for mesh in (mesh1, mesh2)]
    return torch.cat(columns, dim=1)

In [7]:
# Point coordinates on the default [0, 1] x [0, 1] range of make_2d_grid:
# one row per node of the (nx, nz) input grid and of the latent grid.
input_pos = make_2d_grid(dims=[nx, nz])    # (nx * nz, 2)
query_pos = make_2d_grid(dims=latent_dim)  # (latent_dim[0] * latent_dim[1], 2)

In [8]:
# Standardise the retained frequencies to zero mean and unit standard deviation
# (torch's default `std()`); the result is the frequency feature fed to the network.
norm_freq = freqs[freq_to_keep]
norm_freq = norm_freq.sub(norm_freq.mean()).div(norm_freq.std())

In [9]:
# Arguments that are identical for the training and validation datasets.
shared_dataset_args = dict(
    input_pos=input_pos,
    query_pos=query_pos,
    norm_freq=norm_freq,
    nsrc=1,
)

# Training uses velocity model 0, validation uses velocity model 1.
# ndp = nv * nsrc * nf
train_set = MyDataset("../data/", offset_vel=0, ndp=1 * 1 * NF, **shared_dataset_args)
valid_set = MyDataset("../data/", offset_vel=1, ndp=1 * 1 * NF, **shared_dataset_args)

# Worker / memory settings that are identical for both loaders.
shared_loader_args = dict(
    num_workers=4,
    pin_memory=True,
    prefetch_factor=2,
    persistent_workers=True,
)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, **shared_loader_args
)
valid_loader = torch.utils.data.DataLoader(
    valid_set, batch_size=128, shuffle=False, **shared_loader_args
)


In [10]:
# Sanity check: print the tensor shapes of the first training batch, in the order
# (input_feat, input_pos, query_pos, src_pos, rec_pos, y).
# The batch is bound to its own name so that the module-level `input_pos` /
# `query_pos` grids (which the datasets were built from) are not overwritten by
# their batched counterparts; re-running the dataset cell stays safe.
for sample_batch in train_loader:
    print(*(tensor.shape for tensor in sample_batch))
    break
    

torch.Size([20, 27621, 3]) torch.Size([20, 27621, 2]) torch.Size([20, 512, 2]) torch.Size([20, 339, 2]) torch.Size([20, 339, 2]) torch.Size([20, 339, 2])


In [11]:
# Build the network on the target device.
model = RENO(encoder_config, decoder_config).to(device)

# AdamW with a step schedule: the learning rate is halved every 30 scheduler steps
# (the training loop steps the scheduler once per epoch).
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=30, gamma=0.5)

In [12]:
# Train for `nepoch` epochs and record, per epoch, the sample-weighted mean of
#   * the training loss              -> rel_l2_train
#   * the validation loss            -> rel_l2_valid
#   * the source/receiver reciprocity mismatch on the validation set -> rel_l2_reci
#
# Batch tensors are unpacked into loop-local names (b_*) so the module-level
# `input_pos` / `query_pos` grids are never replaced by batched CUDA tensors.
L2 = lossfunc.LpLoss(p=2, size_average=True)
rel_l2_train = np.full(nepoch, np.nan)
rel_l2_valid = np.full(nepoch, np.nan)
rel_l2_reci = np.full(nepoch, np.nan)
for i in tqdm.trange(nepoch):
    # ---- training pass ----
    model.train()
    l2_train = 0
    count = 0
    for batch in train_loader:
        b_feat, b_input_pos, b_query_pos, b_src_pos, b_rec_pos, b_y = (t.to(device) for t in batch)
        out = model(b_feat, b_input_pos, b_query_pos, b_src_pos, b_rec_pos)
        loss = L2(out, b_y)
        optimizer.zero_grad()
        loss.backward()
        # The loss is scaled by the batch size so the epoch value is a per-sample
        # average even when the last batch is smaller.
        l2_train += loss.item() * len(b_y)
        count += len(b_y)
        optimizer.step()

    rel_l2_train[i] = l2_train / count

    # ---- validation pass ----
    model.eval()
    l2_valid = 0
    l2_reci = 0
    count = 0
    with torch.no_grad():
        for batch in valid_loader:
            b_feat, b_input_pos, b_query_pos, b_src_pos, b_rec_pos, b_y = (t.to(device) for t in batch)
            out = model(b_feat, b_input_pos, b_query_pos, b_src_pos, b_rec_pos)
            loss = L2(out, b_y)
            l2_valid += loss.item() * len(b_y)
            count += len(b_y)

            # evaluate reciprocity: pick one random receiver, then compare the
            # prediction for (source -> receiver) with the one obtained after
            # swapping the two positions.
            j = np.random.randint(0, b_rec_pos.shape[1])
            rec_pos_j = b_rec_pos[:, [j], :]  # (b, 1, 2)
            src_pos_j = b_src_pos[:, [j], :]  # (b, 1, 2)
            out_ori = model(b_feat, b_input_pos, b_query_pos, src_pos_j, rec_pos_j)  # (b, 1, 2)
            out_reci = model(b_feat, b_input_pos, b_query_pos, rec_pos_j, src_pos_j)  # (b, 1, 2)
            loss_reci = L2(out_reci, out_ori)
            l2_reci += loss_reci.item() * len(b_y)

        rel_l2_valid[i] = l2_valid / count
        rel_l2_reci[i] = l2_reci / count

    torch.cuda.empty_cache()

    scheduler.step()
    

100%|██████████| 100/100 [00:31<00:00,  3.17it/s]


In [13]:
# Create the output directory first so a finished training run is not lost to a
# missing "../model/" folder, then store the trained weights.
os.makedirs("../model", exist_ok=True)
torch.save(model.state_dict(), "../model/myreno.pth")