# Model Retraining

## Preliminaries

In [1]:
# general
import os
from tqdm import tqdm

# wandb - hyperparameter sweep and Train monitoring
import wandb
#torch - computing and machine learning libraries
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn.functional as F
# seisbench
import seisbench.models as sbm

#plotting
import matplotlib.pyplot as plt
# seisynth
from utils.common import load_dataset_and_labels, load_pretrained_model

In [2]:
# Possible values
# Dataset origins this notebook supports (validated by the assert in the next cell).
DATASETS_ORIGINS = ['ethz', 'geofon']
# SeisBench model classes this notebook supports.
SBM_CLASSES= [sbm.PhaseNet, sbm.EQTransformer]
# Number of samples per trace for each model class. The chosen value (NUM_SAMPLES) selects the
# dataset folder and the length of the smoothed label curves.
MODEL_TO_NUM_SAMPLES = {sbm.EQTransformer:6000, sbm.PhaseNet: 3001}

In [3]:
dataset_origin = 'geofon'
assert dataset_origin in DATASETS_ORIGINS, f'Expected dataset one of {DATASETS_ORIGINS}. Got {dataset_origin}.'

In [4]:
SBM_CLASS= sbm.EQTransformer
assert SBM_CLASS in SBM_CLASSES
SBM_CLASS

seisbench.models.eqtransformer.EQTransformer

In [5]:
NUM_SAMPLES=MODEL_TO_NUM_SAMPLES[SBM_CLASS]
NUM_SAMPLES

6000

In [6]:
# NOTE(review): NUM_SHIFTS is not referenced in any visible cell; confirm whether it is still needed.
NUM_SHIFTS=6
# Trace sampling rate in Hz (assumed to match the stored datasets; TODO confirm).
SAMPLE_RATE=100
# A pick residual above this many seconds counts as a "large error".
LARGE_ERROR_THRESHOLD_SECONDS=1
# The same threshold expressed in samples (1 s * 100 Hz = 100 samples), as used by the train/test loops.
LARGE_ERROR_THRESHOLD_SAMPLES=LARGE_ERROR_THRESHOLD_SECONDS*SAMPLE_RATE
# SNR (dB) of the synthesized noisy dataset to load; selects the noisy_dataset_snr_<N> folder.
SYNTHESIZED_SNR=10

In [7]:
# Root folder of the synthesized noisy datasets. Override it with the SEISYNTH_DATA_ROOT
# environment variable instead of editing this machine-specific default.
DATA_ROOT = os.environ.get('SEISYNTH_DATA_ROOT', '/home/moshe/datasets/GFZ/noisy_datasets')
# The trailing '' keeps the trailing path separator that the original hard-coded string had.
DATASET_PATH = os.path.join(DATA_ROOT, f'{dataset_origin}_{NUM_SAMPLES}_sample_joachim_noises_energy_ratio_snr', '')
NOISY_DATA_PATH = os.path.join(DATASET_PATH, f'noisy_dataset_snr_{SYNTHESIZED_SNR}')
DATASET_PATH, NOISY_DATA_PATH

('/home/moshe/datasets/GFZ/noisy_datasets/geofon_6000_sample_joachim_noises_energy_ratio_snr/',
 '/home/moshe/datasets/GFZ/noisy_datasets/geofon_6000_sample_joachim_noises_energy_ratio_snr/noisy_dataset_snr_10')

In [8]:
def assert_path_exists(path_str: str, name: str = '') -> None:
    """Fail fast when a required file or directory is missing.

    Args:
        path_str: Filesystem path that must exist.
        name: Optional human-readable label (e.g. ``'DATASET_PATH'``) prefixed to the error message.

    Raises:
        FileNotFoundError: if ``path_str`` does not exist. The check uses an explicit raise rather
            than ``assert``, so it still runs when Python is started with ``-O``.
    """
    if not os.path.exists(path_str):
        # Only add the label (and its separating space) when one was given.
        label = f'{name} ' if name else ''
        raise FileNotFoundError(f'{label}{path_str} does not exist')

In [9]:
assert_path_exists(path_str=DATASET_PATH, name='DATASET_PATH')
assert_path_exists(path_str=NOISY_DATA_PATH, name='NOISY_DATA_PATH')

## Load Pretrained Model

Load the model with the pretrained weights

In [10]:
pretrained_model = load_pretrained_model(model_class=SBM_CLASS, dataset_trained_on=dataset_origin)

Working with <class 'seisbench.models.eqtransformer.EQTransformer'> on GEOFON
Load <class 'seisbench.models.eqtransformer.EQTransformer'> pretrained weights
<class 'seisbench.models.eqtransformer.EQTransformer'> pretrained keys ['ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic', 'obs', 'original', 'original_nonconservative', 'scedc', 'stead']


Load a second copy for retraining: one model is fine-tuned, while the other keeps the pretrained weights as a benchmark on specific examples.

In [11]:
# reloading because I cannot torch clone. Seisbench models are not nn.Module :(
# Load a second, independent copy of the pretrained weights to be fine-tuned;
# `pretrained_model` stays untouched as the benchmark.
# NOTE(review): these models expose nn.Module methods (.eval()/.train()/.parameters()/.double()
# are used below), so copy.deepcopy(pretrained_model) may work instead of reloading; confirm.
retraining_model = load_pretrained_model(model_class=SBM_CLASS, dataset_trained_on=dataset_origin)

Working with <class 'seisbench.models.eqtransformer.EQTransformer'> on GEOFON
Load <class 'seisbench.models.eqtransformer.EQTransformer'> pretrained weights
<class 'seisbench.models.eqtransformer.EQTransformer'> pretrained keys ['ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic', 'obs', 'original', 'original_nonconservative', 'scedc', 'stead']


In [12]:
pretrained_model.eval()
retraining_model.train()

EQTransformer(
  (encoder): Encoder(
    (convs): ModuleList(
      (0): Conv1d(3, 8, kernel_size=(11,), stride=(1,), padding=(5,))
      (1): Conv1d(8, 16, kernel_size=(9,), stride=(1,), padding=(4,))
      (2): Conv1d(16, 16, kernel_size=(7,), stride=(1,), padding=(3,))
      (3): Conv1d(16, 32, kernel_size=(7,), stride=(1,), padding=(3,))
      (4): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,))
      (5): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=(2,))
      (6): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    )
    (pools): ModuleList(
      (0): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1,

## Load Datasets


We have 4 datasets:
1. original_dataset - ETHZ/GEOFON original traces, filtered to those with a high estimated SNR (more than 20 dB).
2. le_original_dataset - The subset of original_dataset (high-SNR traces) on which the pretrained model made a large picking error.
3. noised_dataset - Traces from the original dataset merged with noise traces so that each resulting trace has an SNR of 10 dB.
4. le_noised_dataset - The subset of noised_dataset on which the pretrained model made a large picking error.

In [13]:
# Paths of the clean (high-SNR) traces and their pick labels; fail early with a named error if missing.
original_dataset_path = os.path.join(DATASET_PATH, 'original_dataset.pt')
assert_path_exists(path_str=original_dataset_path, name='original_dataset_path')
original_labels_path = os.path.join(DATASET_PATH, 'original_labels.pt')
assert_path_exists(path_str=original_labels_path, name='original_labels_path')

In [14]:
original_dataset, original_labels = load_dataset_and_labels(dataset_path=original_dataset_path, labels_path=original_labels_path)

In [15]:
print(f'Loaded {original_dataset.shape[0]} traces')

Loaded 11146 traces


In [16]:
# Paths of the high-SNR traces on which the pretrained model made large picking errors.
le_original_dataset_path = os.path.join(DATASET_PATH, 'le_original_dataset.pt')
assert_path_exists(path_str=le_original_dataset_path, name='le_original_dataset_path')
le_original_labels_path = os.path.join(DATASET_PATH, 'le_original_labels.pt')
assert_path_exists(path_str=le_original_labels_path, name='le_original_labels_path')

In [17]:
le_original_dataset, le_original_labels = load_dataset_and_labels(dataset_path=le_original_dataset_path, labels_path=le_original_labels_path)

In [18]:
print(f'Loaded {le_original_dataset.shape[0]} traces')

Loaded 297 traces


In [19]:
# Paths of the synthesized noisy traces (SNR = SYNTHESIZED_SNR) and their labels.
noised_dataset_path = os.path.join(NOISY_DATA_PATH, 'traces.pt')
assert_path_exists(path_str=noised_dataset_path, name='noised_dataset_path')
noised_labels_path = os.path.join(NOISY_DATA_PATH, 'labels.pt')
assert_path_exists(path_str=noised_labels_path, name='noised_labels_path')
noised_dataset_path, noised_labels_path

('/home/moshe/datasets/GFZ/noisy_datasets/geofon_6000_sample_joachim_noises_energy_ratio_snr/noisy_dataset_snr_10/traces.pt',
 '/home/moshe/datasets/GFZ/noisy_datasets/geofon_6000_sample_joachim_noises_energy_ratio_snr/noisy_dataset_snr_10/labels.pt')

In [20]:
noised_dataset, noised_labels = load_dataset_and_labels(dataset_path=noised_dataset_path, labels_path=noised_labels_path)

In [21]:
print(f'Loaded {noised_dataset.shape[0]} traces')

Loaded 8000 traces


In [22]:
# Paths of the noisy traces on which the pretrained model made large picking errors.
# NOTE: the file names embed str(SBM_CLASS), e.g. "<class '...EQTransformer'>", because that is
# how the files were written on disk; do not "clean up" this name without renaming the files.
le_noised_dataset_path = os.path.join(NOISY_DATA_PATH, f'le_{str(SBM_CLASS)}_dataset.pt')
assert_path_exists(path_str=le_noised_dataset_path, name='le_noised_dataset_path')
le_noised_labels_path = os.path.join(NOISY_DATA_PATH, f'le_{str(SBM_CLASS)}_labels.pt')
assert_path_exists(path_str=le_noised_labels_path, name='le_noised_labels_path')
le_noised_dataset_path, le_noised_labels_path

("/home/moshe/datasets/GFZ/noisy_datasets/geofon_6000_sample_joachim_noises_energy_ratio_snr/noisy_dataset_snr_10/le_<class 'seisbench.models.eqtransformer.EQTransformer'>_dataset.pt",
 "/home/moshe/datasets/GFZ/noisy_datasets/geofon_6000_sample_joachim_noises_energy_ratio_snr/noisy_dataset_snr_10/le_<class 'seisbench.models.eqtransformer.EQTransformer'>_labels.pt")

In [23]:
le_noised_dataset, le_noised_labels = load_dataset_and_labels(dataset_path=le_noised_dataset_path, labels_path=le_noised_labels_path)

In [24]:
print(f'Loaded {le_noised_dataset.shape[0]} traces')

Loaded 1687 traces


## Arrange Train/Validation/Test Sets

In [25]:
# Split the trace *indices* (not the traces) 80/10/10 into train/validation/test.
# The fixed generator seed (42) makes the split reproducible across kernel restarts.
train_dataset_inds, val_dataset_inds, test_dataset_inds = random_split(range(noised_dataset.shape[0]), [0.8,0.1,0.1], generator=torch.Generator().manual_seed(42))

In [26]:
# random_split returns Subset objects; index the tensors with their explicit index lists
# (Subset.indices) instead of relying on the tensor treating a Subset as a generic sequence.
train_idx, val_idx, test_idx = train_dataset_inds.indices, val_dataset_inds.indices, test_dataset_inds.indices
assert len(train_idx) + len(val_idx) + len(test_idx) == noised_dataset.shape[0], 'split does not cover the dataset'
train_dataset, val_dataset, test_dataset = noised_dataset[train_idx], noised_dataset[val_idx], noised_dataset[test_idx]
train_labels, val_labels, test_labels = noised_labels[train_idx], noised_labels[val_idx], noised_labels[test_idx]

In [27]:
print(f'Created train set with {train_dataset.shape[0]} traces, validation set with {val_dataset.shape[0]} traces and test set with {test_dataset.shape[0]} traces.')

Created train set with 6400 traces, validation set with 800 traces and test set with 800 traces.


## Define Custom Datasets/DataLoader

In [28]:
@torch.no_grad()
def label_normal_smooth(label, num_samples=None, sigma=1000.0):
    """Turn a scalar pick position into a Gaussian-shaped target curve.

    Args:
        label: Pick position in samples (Python number or 0-dim tensor).
        num_samples: Length of the returned curve. Defaults to the notebook's ``NUM_SAMPLES``.
        sigma: Standard deviation of the Gaussian, in samples. The default of 1000 keeps the
            previous behaviour. NOTE(review): that is 10 s at SAMPLE_RATE=100, very wide for a
            pick label; confirm it is intended.

    Returns:
        1-D float64 tensor of length ``num_samples`` containing the normal pdf N(label, sigma^2)
        evaluated at sample positions 0, 1, ..., num_samples - 1.
    """
    if num_samples is None:
        num_samples = NUM_SAMPLES
    positions = torch.arange(num_samples, dtype=torch.float64)
    # Compute the normalisation in Python floats (double precision) rather than a float32 tensor.
    norm = 1.0 / (sigma * (2.0 * torch.pi) ** 0.5)
    return norm * torch.exp(-0.5 * torch.square((positions - label) / sigma))

In [29]:
class CustomDataset(Dataset):
    """Map-style dataset over pre-loaded trace and label tensors.

    Args:
        dataset: Tensor of shape (#traces, #channels, #samples).
        labels: Tensor whose first dimension matches ``dataset`` (one label per trace).
        transform: Optional callable applied to each trace when it is accessed.
        target_transform: Optional callable applied to each label when it is accessed.
    """

    def __init__(self, dataset: torch.Tensor, labels: torch.Tensor, transform=None, target_transform=None):
        # Validate before storing anything, so an invalid input never leaves a half-built object.
        assert dataset.dim() == 3, f'Expected 3 dim dataset tensor (#traces,#channels,#samples). Got {dataset.dim()} dims. Shape {dataset.shape} '
        assert labels.shape[0] == dataset.shape[0], f'Expected 1 label per trace. Got {labels.shape[0]} for {dataset.shape[0]} traces'
        self._dataset = dataset
        self._labels = labels
        self._len = int(dataset.shape[0])
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of traces."""
        return self._len

    def __getitem__(self, idx):
        """Return ``(trace, label)`` at ``idx`` after applying the optional transforms."""
        trace = self._dataset[idx]
        label = self._labels[idx]
        # Check against None explicitly: a callable object may legitimately be falsy.
        if self.transform is not None:
            trace = self.transform(trace)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return trace, label

In [30]:
def _with_smoothed_label(label):
    """Pair a raw pick label with its Gaussian-smoothed target curve."""
    return (label, label_normal_smooth(label))


# Wrap each split so every item yields (trace, (label, smoothed_label)).
trainset, valset, testset = (
    CustomDataset(dataset=split_traces, labels=split_labels, target_transform=_with_smoothed_label)
    for split_traces, split_labels in (
        (train_dataset, train_labels),
        (val_dataset, val_labels),
        (test_dataset, test_labels),
    )
)

In [31]:
print(f'Created train set with {len(trainset)} traces, validation set with {len(valset)} traces and test set with {len(testset)} traces.')

Created train set with 6400 traces, validation set with 800 traces and test set with 800 traces.


## Train

### Define Loss Function

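Vector cross-entropy loss taken from the SeisBench tutorial notebook "03a_training_phasenet". It is currently used only to compute the validation loss in `test_loop`. `train_loop` optimizes a soft-argmax L1 pick loss instead, and its cross-entropy call is commented out, so the reported train and validation losses are not directly comparable.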

In [32]:
def loss_fn(y_pred, y_true, eps=1e-5):
    """Vector cross-entropy between predicted and target pick-probability curves.

    Adapted from the SeisBench tutorial notebook "03a_training_phasenet".

    Args:
        y_pred: Predicted probabilities, shape (batch, picks, samples) or (batch, samples).
        y_true: Target curves, broadcastable to ``y_pred``.
        eps: Small constant added inside the log to avoid log(0).

    Returns:
        Scalar tensor: minus the cross-entropy, averaged over samples, summed over pick channels
        and averaged over the batch.
    """
    if y_pred.dim() == 2:
        # 2-D input has no pick axis. Insert one, so the sum below runs over picks instead of
        # silently summing over the batch axis.
        y_pred = y_pred.unsqueeze(-2)
        y_true = y_true.unsqueeze(-2)
    h = y_true * torch.log(y_pred + eps)
    h = h.mean(-1).sum(-1)  # Mean along sample dimension and sum along pick dimension
    return -h.mean()  # Mean over batch axis

### Train Loop

In [33]:
def train_loop(dataloader, model, loss_fn, optimizer, large_error_threshold=100):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    The optimised objective is the mean absolute difference (in samples) between a soft-argmax
    of the channel-0 probability curve and the pick label. ``loss_fn`` is accepted for signature
    parity with ``test_loop`` but is currently NOT used: the cross-entropy objective is disabled
    in favour of the soft-argmax L1 loss.

    Args:
        dataloader: Yields ``(trace, (label, label_smoothed))`` batches.
        model: SeisBench model exposing ``.device``; called on each trace batch.
        loss_fn: Unused (see above).
        optimizer: Optimizer over ``model``'s parameters.
        large_error_threshold: Residual (in samples) above which a pick counts as a large error.

    Returns:
        ``(train_loss, mean_residual, large_errors_counter)``:
        ``train_loss`` is the soft-argmax L1 loss averaged over batches; ``mean_residual`` is the
        mean ``|argmax - label|`` averaged over traces; ``large_errors_counter`` counts traces
        whose residual exceeds ``large_error_threshold``.

    Raises:
        ValueError: if ``dataloader`` is empty.
    """
    model.train()
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('train_loop received an empty dataloader')
    beta = 100.0  # soft-argmax sharpness: larger values approach the hard argmax
    train_loss = 0.0
    residual_sum = 0.0
    num_traces = 0
    large_errors_counter = 0
    for trace, (label, _label_smoothed) in dataloader:
        # Keep labels on the model's device so the arithmetic below also works on GPU.
        label = label.to(model.device)

        # Forward pass: per-sample pick probabilities.
        pred_probs = model(trace.to(model.device))
        if SBM_CLASS == sbm.EQTransformer:
            # EQTransformer returns a tuple of three curves instead of one (batch, 3, samples)
            # tensor. Reorder as (out[1], out[0], out[2]) and stack, so the curve used below sits
            # at channel 0. NOTE(review): per the SeisBench docs the tuple is presumably
            # (detection, P, S), which makes channel 0 the P-pick curve; confirm.
            pred_probs = torch.stack((pred_probs[1], pred_probs[0], pred_probs[2]), dim=0).swapaxes(0, 1)
        pick_curve = pred_probs[:, 0, :]

        # Differentiable soft-argmax: the expected sample index under softmax(beta * curve).
        weights = torch.nn.functional.softmax(beta * pick_curve, dim=-1)
        # Build the index ramp on the predictions' device (avoids a CPU/GPU mismatch).
        indices = torch.arange(pick_curve.shape[-1], device=pick_curve.device)
        softargmax_preds = torch.sum(indices * weights, dim=-1)
        loss = torch.abs(softargmax_preds - label).mean()

        # Monitoring metrics from the hard argmax; no gradient needed.
        with torch.no_grad():
            prediction = torch.argmax(pick_curve, dim=-1)
            # Cast to float64 so integer labels/predictions can be summed and compared safely.
            residual = torch.abs(prediction - label).double()
            residual_sum += float(residual.sum())
            num_traces += residual.numel()
            large_errors_counter += int((residual > large_error_threshold).sum())

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # The residual is averaged per trace, so a smaller final batch is not over-weighted.
    return train_loss / num_batches, residual_sum / num_traces, large_errors_counter

### Test Loop

In [34]:

@torch.no_grad()
def test_loop(dataloader, model, loss_fn, large_error_threshold=100):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Works with any batch size. Per-trace metrics are accumulated from whole-batch tensors, so the
    function is no longer limited to ``batch_size=1``.

    Args:
        dataloader: Yields ``(trace, (label, label_smoothed))`` batches.
        model: SeisBench model exposing ``.device``.
        loss_fn: Called as ``loss_fn(channel0_probs, label_smoothed)``.
        large_error_threshold: Residual (in samples) above which a pick counts as a large error.

    Returns:
        ``(test_loss, mean_residual, large_errors_counter)``:
        ``test_loss`` is ``loss_fn`` averaged over batches; ``mean_residual`` is the mean
        ``|argmax - label|`` averaged over traces; ``large_errors_counter`` counts traces whose
        residual exceeds ``large_error_threshold``.

    Raises:
        ValueError: if ``dataloader`` is empty.
    """
    model.eval()
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('test_loop received an empty dataloader')
    test_loss = 0.0
    residual_sum = 0.0
    num_traces = 0
    large_errors_counter = 0
    for trace, (label, label_smoothed) in dataloader:
        label = label.to(model.device)
        pred_probs = model(trace.to(model.device))
        if SBM_CLASS == sbm.EQTransformer:
            # EQTransformer returns a tuple of three curves; reorder as (out[1], out[0], out[2])
            # and stack so the curve evaluated below is at channel 0. NOTE(review): presumably
            # the tuple is (detection, P, S), making channel 0 the P-pick curve; confirm.
            pred_probs = torch.stack((pred_probs[1], pred_probs[0], pred_probs[2]), dim=0).swapaxes(0, 1)
        pick_curve = pred_probs[:, 0, :]

        test_loss += loss_fn(pick_curve, label_smoothed.to(model.device)).item()

        # Hard pick = argmax of the channel-0 curve; residual in samples, one value per trace.
        prediction = torch.argmax(pick_curve, dim=-1)
        residual = torch.abs(prediction - label).double().reshape(-1)
        residual_sum += float(residual.sum())
        num_traces += residual.numel()
        large_errors_counter += int((residual > large_error_threshold).sum())

    return test_loss / num_batches, residual_sum / num_traces, large_errors_counter

In [35]:
def train(trainset, valset, trained_model, benchmark_model, epochs, learning_rate, batch_size, num_workers=8, checkpoint_path=None):
    """Fine-tune ``trained_model`` and log per-epoch metrics to Weights & Biases.

    ``benchmark_model`` is not trained; it is evaluated once on ``valset`` as a baseline.
    NOTE: the train loss (soft-argmax L1, in samples) and the validation loss (cross-entropy from
    ``loss_fn``) are different objectives. Compare residuals and error counts, not the two losses.

    Args:
        trainset: Training dataset yielding ``(trace, (label, label_smoothed))``.
        valset: Validation dataset with the same item structure.
        trained_model: Model whose weights are updated.
        benchmark_model: Frozen reference model.
        epochs: Number of training epochs.
        learning_rate: Adam learning rate.
        batch_size: Training batch size.
        num_workers: DataLoader worker processes for the training set.
        checkpoint_path: If given, ``trained_model.state_dict()`` is saved here whenever the
            validation mean residual improves.
    """
    # Pass the config to wandb.init so it is attached to the run. Previously a plain dict was
    # assigned to wandb.config after init, and batch_size was always logged as 1.
    wandb.init(project="seisynth", entity="moshebeutel",
               config={"learning_rate": learning_rate, "epochs": epochs, "batch_size": batch_size})
    try:
        train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        # One trace per validation batch.
        val_dataloader = DataLoader(valset, batch_size=1, shuffle=False)

        optimizer = torch.optim.Adam(trained_model.parameters(), lr=learning_rate)
        criterion = loss_fn

        # The benchmark model is not training, so it is evaluated once.
        benchmark_loss, benchmark_mean_residual, benchmark_large_errors_counter = test_loop(dataloader=val_dataloader, model=benchmark_model, loss_fn=criterion, large_error_threshold=LARGE_ERROR_THRESHOLD_SAMPLES)
        print(f'Benchmark results: loss {benchmark_loss}, Mean Residual {benchmark_mean_residual}, Errors Above {LARGE_ERROR_THRESHOLD_SECONDS} sec. {benchmark_large_errors_counter}')

        best_val_mean_residual = float('inf')
        pbar = tqdm(range(epochs))
        for t in pbar:
            epoch_train_loss, epoch_train_mean_residual, epoch_train_large_errors_counter = train_loop(dataloader=train_dataloader, model=trained_model, loss_fn=criterion, optimizer=optimizer, large_error_threshold=LARGE_ERROR_THRESHOLD_SAMPLES)
            epoch_val_loss, epoch_val_mean_residual, epoch_val_large_errors_counter = test_loop(dataloader=val_dataloader, model=trained_model, loss_fn=criterion, large_error_threshold=LARGE_ERROR_THRESHOLD_SAMPLES)

            wandb.log({'epoch train loss': epoch_train_loss,
                       'epoch_train_mean_residual': epoch_train_mean_residual,
                       'epoch_train_large_errors_counter': epoch_train_large_errors_counter,
                       'epoch validation loss': epoch_val_loss,
                       'epoch_val_mean_residual': epoch_val_mean_residual,
                       'epoch_val_large_errors_counter': epoch_val_large_errors_counter})

            pbar.set_description(f'Epoch {t}, train loss {epoch_train_loss}, validation loss {epoch_val_loss}, epoch_val_large_errors_counter {epoch_val_large_errors_counter}')

            # Persist the best weights so a long run (hours per few epochs) is not lost.
            if checkpoint_path is not None and epoch_val_mean_residual < best_val_mean_residual:
                best_val_mean_residual = epoch_val_mean_residual
                torch.save(trained_model.state_dict(), checkpoint_path)
    finally:
        # Always close the wandb run, even when training is interrupted or fails.
        wandb.finish()


In [36]:
# import torch.nn.functional as F
#
# def normal_smooth(label):
#     num_samples = 3001
#     sigma = 1000
#     v = torch.arange(num_samples).float()
#     return (1.0/(sigma*torch.sqrt(2.0* torch.tensor(torch.pi)))) * torch.exp(-0.5*torch.square((v-label)/sigma))
#
# a = normal_smooth(1000) * 1e6
# b = normal_smooth(1200) * 1e6
#
# fig, (ax_a, ax_b) = plt.subplots(1,2, sharey='all');
# ax_a.plot(a);
# ax_b.plot(b);
#
# torch.nn.MSELoss()(a,b)

### Training Hyperparameters

In [37]:
# Fine-tuning hyperparameters passed to `train` below.
EPOCHS = 300
# Very small learning rate; presumably chosen to fine-tune without destroying the pretrained
# weights (TODO confirm).
LEARNING_RATE = 1e-6
BATCH_SIZE = 32

## Call Train Entry Point

In [38]:
# test_dataloader = DataLoader(testset, batch_size=1, shuffle=False)

In [None]:
# nn.Module.double() casts parameters to float64 in place and returns the same module, so both
# models remain float64 for the rest of the notebook.
# The logged run took ~358 s/epoch, i.e. roughly 30 h for EPOCHS=300.
train(trainset=trainset, valset=valset, trained_model=retraining_model.double(), benchmark_model=pretrained_model.double(), epochs=EPOCHS, learning_rate=LEARNING_RATE, batch_size=BATCH_SIZE)

[34m[1mwandb[0m: Currently logged in as: [33mmoshebeutel[0m. Use [1m`wandb login --relogin`[0m to force relogin


Benchmark results: loss 0.0013749178744094418, Mean Residual 92.25378262500013, Errors Above 1 sec. 149


Epoch 13, train loss 90.38634143490681, validation loss 0.0014182093281871258, epoch_val_large_errors_counter 138:   5%|▍         | 14/300 [1:23:01<28:26:55, 358.10s/it]

In [None]:
# Re-evaluate the frozen pretrained model on the validation set outside of `train`.
# With one trace per batch, len(benchmark_dataloader) equals the number of validation traces,
# which the next cell uses to turn the error count into a rate.
benchmark_dataloader = DataLoader(valset, batch_size=1, shuffle=False)
# train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)

criterion = loss_fn
benchmark_loss, benchmark_mean_residual, benchmark_large_errors_counter =  test_loop(dataloader=benchmark_dataloader, model=pretrained_model.double(), loss_fn=criterion, large_error_threshold=LARGE_ERROR_THRESHOLD_SAMPLES)

In [None]:
# Benchmark summary: (loss, mean residual in samples, #large errors, fraction of validation batches with a large error).
benchmark_loss, benchmark_mean_residual, benchmark_large_errors_counter, float(benchmark_large_errors_counter) / float(len(benchmark_dataloader))

In [None]:
# NOTE(review): this cell is identical to the previous summary cell; one of them can be removed.
benchmark_loss, benchmark_mean_residual, benchmark_large_errors_counter, float(benchmark_large_errors_counter) / float(len(benchmark_dataloader))

In [None]:
# softargmax_demo.py

# import torch
#
# def softargmax(x):
#   # crude: assumes max value is unique
#   beta = 100.0
#   xx = beta  * x
#   sm = torch.nn.functional.softmax(xx, dim=-1)
#   indices = torch.arange(x.shape[-1])
#   y = torch.mul(indices, sm)
#   result = torch.sum(y, dim=-1)
#   return result
#
# print("\nBegin PyTorch softargmax demo ")
#
# t = torch.randint(low=1, high=100, size=(2,10))
# print("\nSource tensor: ")
# print(t)
#
# sam = softargmax(t)
# print("\nValue of softargmax(): ")
# print(sam)
#
# print("\nEnd ")