In [1]:
import seaborn as sns
from collections import defaultdict

import seisbench.data as sbd
import seisbench.generate as sbg
import numpy as np
import matplotlib.pyplot as plt
from seisbench.util import worker_seeding
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import einops
import tqdm
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
from collections import defaultdict
from sklearn.decomposition import PCA
from seisLM.utils import project_path

from seisLM.model.foundation.pretrained_models import LitMultiDimWav2Vec2, MultiDimWav2Vec2ForPreTraining

  from .autonotebook import tqdm as notebook_tqdm


## Model

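### Data loader

Build a loader over the dev split that:

1. extracts only the first 1000 samples of each class (earthquake vs. noise);
2. returns metadata for each sample (trace type and epicentral/hypocentral path distance).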

In [2]:
class MetaDataKeepingSteeredGenerator(sbg.SteeredGenerator):
  """SteeredGenerator whose samples keep a few metadata fields next to ``X``.

  Each emitted sample holds the array ``X``, the ``trace_type`` from the
  control dict, and the ``path_ep_distance_km`` / ``path_hyp_distance_km``
  metadata fields (``np.inf`` when a field is absent).
  """

  def _clean_state_dict(self, state_dict):
    """Reduce ``state_dict`` to ``X`` plus the selected metadata.

    Args:
      state_dict: must contain ``"_control_"`` (with a ``"trace_type"`` entry)
        and ``"X"`` as an ``(array, metadata)`` pair, where ``metadata``
        supports ``.get``.

    Returns:
      A new dict with keys ``X``, ``trace_type``, ``path_ep_distance_km`` and
      ``path_hyp_distance_km``; everything else (including ``_control_``) is
      dropped.
    """
    label = state_dict["_control_"]["trace_type"]
    array, meta_info = state_dict["X"]

    # Missing distances fall back to infinity rather than raising.
    missing = np.inf
    return {
      "X": array,
      "trace_type": label,
      "path_ep_distance_km": meta_info.get("path_ep_distance_km", missing),
      "path_hyp_distance_km": meta_info.get("path_hyp_distance_km", missing),
    }

def get_loader(
  dataset_name='InstanceCountsCombined',
  task='1',
  num_samples_per_trace_type=1000,
  eval_set='dev',
  batch_size=10,
  num_workers=2,
):
  """Build a DataLoader over a balanced earthquake/noise subset of a SeisBench split.

  Args:
    dataset_name: name of a dataset class in ``seisbench.data``.
    task: task id; targets are read from
      ``<gitdir>/data/targets/{dataset_name}/task{task}.csv``.
    num_samples_per_trace_type: number of target rows kept per trace type
      (the first ones in CSV order).
    eval_set: split to load; used both for ``dataset.get_split`` and for
      filtering the ``trace_split`` column of the targets.
    batch_size: DataLoader batch size.
    num_workers: number of DataLoader worker processes.

  Returns:
    A non-shuffled ``DataLoader`` over ``MetaDataKeepingSteeredGenerator``
    samples. Windows are 3001 samples long, float32, and demeaned and
    std-normalised along the last (time) axis. Earthquake targets come before
    noise targets.
  """
  # getattr is the idiomatic equivalent of sbd.__getattribute__(dataset_name).
  dataset = getattr(sbd, dataset_name)(
    sampling_rate=100,
    component_order="ZNE",
    dimension_order="NCW",
    missing_component="copy",
    cache=None,
  )
  split = dataset.get_split(eval_set)

  task_csv = project_path.gitdir() + f'/data/targets/{dataset_name}/task{task}.csv'
  task_targets = pd.read_csv(task_csv)
  task_targets = task_targets[task_targets["trace_split"] == eval_set]

  # Balanced subset: the first N targets of each class, earthquakes first.
  task_targets = pd.concat([
    task_targets[task_targets['trace_type'] == trace_type].head(num_samples_per_trace_type)
    for trace_type in ('earthquake', 'noise')
  ])

  generator = MetaDataKeepingSteeredGenerator(split, task_targets)
  generator.add_augmentations(
    [
      sbg.SteeredWindow(windowlen=3001, strategy="pad"),
      sbg.ChangeDtype(np.float32),
      sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="std"),
    ]
  )

  # shuffle=False keeps samples in target order, so features extracted by
  # different models line up row-for-row with the collected inputs.
  # worker_seeding gives each worker process a distinct numpy seed.
  return DataLoader(
    generator,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    worker_init_fn=worker_seeding,
  )


loader = get_loader()

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed for the randomly initialised baseline model.
SEED = 0

# Pretrained seisLM checkpoint, relative to the repository root.
checkpoint_path = project_path.gitdir() + (
  '/results/models/pretrained_seisLM/'
  'pretrain_config_std_norm_single_ax_8_datasets_sample_pick_false_42__2024-08-31-18h-41m-44s/'
  'checkpoints/epoch=35-step=1082700.ckpt'
)

# Load the checkpoint once; the random-init baseline only needs its config.
pretrained_model = LitMultiDimWav2Vec2.load_from_checkpoint(checkpoint_path).model

# Same architecture with fresh weights, as a baseline for the pretrained
# embeddings. The initial weights presumably come from torch's global RNG,
# so seeding makes the baseline reproducible — TODO confirm.
torch.manual_seed(SEED)
random_init_model = MultiDimWav2Vec2ForPreTraining(pretrained_model.config)

# Both models on the selected device, in inference mode.
model_collections = {
  model_type: model.to(device).eval()
  for model_type, model in (
    ('pretrained', pretrained_model),
    ('random_init', random_init_model),
  )
}
del pretrained_model, random_init_model

/home/liu0003/miniconda3/envs/seisbench/lib/python3.9/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.4.0, which is newer than your current Lightning version: v2.2.5
/home/liu0003/miniconda3/envs/seisbench/lib/python3.9/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.4.0, which is newer than your current Lightning version: v2.2.5


In [4]:
def _concat_batches(batches):
  """Concatenate per-batch values collected from the DataLoader along axis 0.

  Tensors are concatenated with torch and returned as numpy arrays. Numpy
  arrays and lists (e.g. collated strings such as ``trace_type``) are
  concatenated with numpy.

  Raises:
    ValueError: if the batch values are of any other type.
  """
  first = batches[0]
  if isinstance(first, torch.Tensor):
    return torch.cat(batches, dim=0).cpu().numpy()
  if isinstance(first, (np.ndarray, list)):
    return np.concatenate(batches, axis=0)
  raise ValueError(f'Cannot concatenate batches of type {type(first).__name__}')


def extract_features(model, loader, device):
  """Run ``model.wav2vec2`` over ``loader`` and mean-pool every feature layer.

  Returns:
    ``(features, inputs)``.

    * ``features`` maps ``'conv_features'`` and ``'hidden_states_{i}'`` to
      numpy arrays with one row per sample.
    * ``inputs`` maps every key of the loader's batches (``X``,
      ``trace_type``, ...) to the concatenated values, in the same sample
      order.
  """
  batch_inputs = defaultdict(list)
  batch_features = defaultdict(list)

  for batch in tqdm.tqdm(loader):
    for key, value in batch.items():
      batch_inputs[key].append(value)

    with torch.no_grad():
      # Use the configured device instead of hard-coding .cuda(), so the cell
      # also runs on CPU-only machines.
      output = model.wav2vec2(
        input_values=batch['X'].to(device),
        output_hidden_states=True,
      )

    # Average over axis 1 (presumably the frame/time axis — TODO confirm).
    # Move results to CPU per batch so device memory does not grow with the
    # number of batches.
    batch_features['conv_features'].append(
      output.extract_features.mean(dim=1).cpu()
    )
    for layer_idx, hidden_states in enumerate(output.hidden_states):
      batch_features[f'hidden_states_{layer_idx}'].append(
        hidden_states.mean(dim=1).cpu()
      )

  features = {
    key: torch.cat(chunks, dim=0).numpy() for key, chunks in batch_features.items()
  }
  inputs = {key: _concat_batches(chunks) for key, chunks in batch_inputs.items()}
  return features, inputs


def tsne_embed(features, random_state=0):
  """Project ``features`` (one row per sample) to 2-D with t-SNE."""
  tsne = TSNE(
    n_components=2,
    max_iter=500,
    n_iter_without_progress=150,
    n_jobs=2,
    random_state=random_state,
  )
  return tsne.fit_transform(features)


embeddings_of_models = {}

for model_type, model in model_collections.items():
  # `all_input_values` is deliberately left at module level because later
  # cells read it. The loader is not shuffled, so it is identical for every
  # model.
  all_features_dict, all_input_values = extract_features(model, loader, device)
  embeddings_of_models[model_type] = {
    key: tsne_embed(value) for key, value in tqdm.tqdm(all_features_dict.items())
  }

100%|██████████| 200/200 [00:17<00:00, 11.73it/s]
100%|██████████| 8/8 [00:42<00:00,  5.32s/it]
100%|██████████| 200/200 [00:17<00:00, 11.42it/s]
100%|██████████| 8/8 [00:40<00:00,  5.06s/it]


In [5]:
# Which model and which layer to visualise. Any key of `embeddings_of_models`
# and of its per-layer dicts ('conv_features', 'hidden_states_{i}') works.
MODEL_TO_PLOT = 'pretrained'
LAYER_TO_PLOT = 'hidden_states_6'

# The loader is not shuffled, so the inputs collected in the embedding loop
# line up row-for-row with every model's embeddings.
all_trace_types = all_input_values['trace_type']
embedding = embeddings_of_models[MODEL_TO_PLOT][LAYER_TO_PLOT]
raw_waveforms = all_input_values['X']

assert len(embedding) == len(raw_waveforms) == len(all_trace_types), (
  'embedding rows must align with the collected inputs'
)

In [6]:
# Marker colour for each trace type.
color_map = dict(
    earthquake="#9e0142",
    noise="#91bfdb",
)

# One colour per sample, in loader order (an unknown trace type raises KeyError).
colors = list(map(color_map.__getitem__, all_input_values['trace_type']))

In [7]:
import plotly.express as px
import numpy as np
from dash import Dash, dcc, html, Input, Output, no_update, callback
import plotly.graph_objects as go
from sklearn.manifold import TSNE



In [8]:

def make_embedding_figure(embedding, colors, plot_size=600, extra_space_on_left=300):
    """Scatter plot of a 2-D embedding, one marker per sample.

    Args:
      embedding: array indexed as ``embedding[:, 0]`` / ``embedding[:, 1]``,
        one row per sample.
      colors: one marker colour per row of ``embedding``.
      plot_size: height (px) of the figure; the width is this plus the left
        margin.
      extra_space_on_left: extra left margin (px); only the left side is padded.

    Returns:
      A ``plotly.graph_objects.Figure`` with Plotly's own hover label disabled,
      so the Dash tooltip is the only hover display.
    """
    fig = go.Figure(data=[go.Scatter(
        x=embedding[:, 0],
        y=embedding[:, 1],
        mode='markers',
        marker=dict(size=5, color=colors),
    )])
    fig.update_traces(hoverinfo="none", hovertemplate=None)
    fig.update_layout(
        width=plot_size + extra_space_on_left,
        height=plot_size,
        margin=dict(l=extra_space_on_left),
        # Lock the y scale to x so one unit has the same length on both axes.
        # Also anchoring x to y would form a loop, which Plotly ignores.
        yaxis=dict(scaleanchor="x", scaleratio=1),
    )
    return fig


def make_waveform_figure(waveform, title):
    """Line plot with one trace per entry along the first axis of ``waveform``."""
    ts_fig = go.Figure()
    for channel_idx, channel in enumerate(waveform):
        ts_fig.add_trace(
            go.Scatter(y=channel, mode='lines', name=f'Channel {channel_idx+1}')
        )
    ts_fig.update_layout(
        margin=dict(t=10, b=10, l=10, r=10),
        height=200,
        template="plotly_white",
        title=title,
        title_x=0.5,
        title_y=0.95,
    )
    return ts_fig


extra_space_on_left = 300
fig = make_embedding_figure(embedding, colors, extra_space_on_left=extra_space_on_left)

# Dash app layout
app = Dash(__name__)

app.layout = html.Div(
    className="container",
    children=[
        dcc.Graph(id="graph-5", figure=fig, clear_on_unhover=True),
        dcc.Tooltip(id="graph-tooltip-5", direction='bottom'),
    ],
)


# Register on this app instance (not the global `callback` registry) so that
# re-running the cell does not register the callback again globally.
@app.callback(
    Output("graph-tooltip-5", "show"),
    Output("graph-tooltip-5", "bbox"),
    Output("graph-tooltip-5", "children"),
    Input("graph-5", "hoverData"),
)
def display_hover(hoverData):
    """Show the hovered sample's waveform in the tooltip; hide it on unhover."""
    if hoverData is None:
        return False, no_update, no_update

    point = hoverData["points"][0]
    # The figure has a single scatter trace, so pointNumber is the sample index.
    num = point["pointNumber"]

    ts_fig = make_waveform_figure(
        raw_waveforms[num],
        title=f"Waveform {num} of type {all_trace_types[num]}",
    )
    children = dcc.Graph(
        figure=ts_fig,
        config={'displayModeBar': False},
        style={"width": "100%"}
    )
    return True, point["bbox"], children


if __name__ == "__main__":
    app.run(debug=True)

