# Train a model

In [1]:
#Load the necessary packages
import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding

import numpy as np
import matplotlib.pyplot as plt
import torch, torchvision
from torch.utils.data import DataLoader
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
import os
import tensorflow as tf
import torch.nn as nn
from tensorflow import keras
from tensorflow.keras import backend as K

import argparse
from pathlib import Path
import pandas as pd
from tqdm import tqdm

from sklearn.metrics import roc_curve, precision_recall_curve, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
from sklearn.metrics import confusion_matrix
import seaborn as sns
import logging

  from .autonotebook import tqdm as notebook_tqdm
2024-02-26 02:56:33.443417: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-26 02:56:33.599953: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2024-02-26 02:56:33.599980: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2024-02-26 02:56:33.637322: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-26

In [2]:
from seisbench.util import DuplicateEvent
from seisbench.util import trains, evalu, models, utils, export_models, results_summary, plots

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
torch.cuda.empty_cache()

In [3]:
#Configure the model parameters
#Will require you to already have hdf5 continuous data
config = {
"model": "GPD",
"data": "TAM_MF",
"model_args": {"lr":1e-3, "highpass": 2, "sigma": 50},
"trainer_args": {"max_epochs": 60, "gpus": 1},
"batch_size": 100,
"num_workers": 12,
}

In [5]:
from seisbench.util.trains import train, prepare_data
train(config=config, experiment_name="tammf_gpd_transfer", test_run=False)

Preloading waveforms: 100%|█████████████████████████████████████████████████████| 19822/19822 [00:04<00:00, 4807.39it/s]
Preloading waveforms: 100%|███████████████████████████████████████████████████████| 4247/4247 [00:00<00:00, 4521.83it/s]
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | GPD  | 1.7 M 
-------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.964     Total estimated model params size (MB)


Epoch 0:  82%|████████████████████████████████████████▎        | 198/241 [00:12<00:02, 16.35it/s, loss=0.699, v_num=3_3]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                | 0/43 [00:00<?, ?it/s][A
Epoch 0:  83%|████████████████████████████████████████▋        | 200/241 [00:14<00:02, 13.79it/s, loss=0.699, v_num=3_3][A
Epoch 0:  88%|██████████████████████████████████████████▉      | 211/241 [00:14<00:02, 14.44it/s, loss=0.699, v_num=3_3][A
Epoch 0:  92%|█████████████████████████████████████████████▏   | 222/241 [00:14<00:01, 15.04it/s, loss=0.699, v_num=3_3][A
Epoch 0:  97%|███████████████████████████████████████████████▎ | 233/241 [00:14<00:00, 15.57it/s, loss=0.699, v_num=3_3][A
Validating:  84%|███████████████████████████████████████████████████████████▍           | 36/43 [00:02<00:00, 25.25it/s][A
Epoch 0: 100%|█████████████████████████████████████████████████| 241/241 [00:15<00:00, 15.57it/s, los

Validating:   0%|                                                                                | 0/43 [00:00<?, ?it/s][A
Epoch 16:  87%|█████████████████████████████████████████▋      | 209/241 [00:16<00:02, 12.59it/s, loss=0.657, v_num=3_3][A
Epoch 16:  91%|███████████████████████████████████████████▊    | 220/241 [00:16<00:01, 13.07it/s, loss=0.657, v_num=3_3][A
Validating:  53%|█████████████████████████████████████▉                                 | 23/43 [00:03<00:01, 11.50it/s][A
Epoch 16:  96%|██████████████████████████████████████████████  | 231/241 [00:17<00:00, 13.55it/s, loss=0.657, v_num=3_3][A
Validating:  84%|███████████████████████████████████████████████████████████▍           | 36/43 [00:03<00:00, 18.97it/s][A
Epoch 16: 100%|████████████████████████████████████████████████| 241/241 [00:17<00:00, 13.74it/s, loss=0.657, v_num=3_3][A
Epoch 17:  82%|███████████████████████████████████████▍        | 198/241 [00:13<00:02, 15.12it/s, loss=0.665, v_num=3_3][A
Validati

Epoch 32:  91%|████████████████████████████████████████████▋    | 220/241 [00:16<00:01, 13.73it/s, loss=0.62, v_num=3_3][A
Validating:  51%|████████████████████████████████████▎                                  | 22/43 [00:02<00:01, 12.63it/s][A
Epoch 32:  96%|██████████████████████████████████████████████▉  | 231/241 [00:16<00:00, 14.24it/s, loss=0.62, v_num=3_3][A
Validating:  84%|███████████████████████████████████████████████████████████▍           | 36/43 [00:03<00:00, 22.24it/s][A
Epoch 32: 100%|█████████████████████████████████████████████████| 241/241 [00:16<00:00, 14.41it/s, loss=0.62, v_num=3_3][A
Epoch 33:  82%|███████████████████████████████████████▍        | 198/241 [00:13<00:02, 15.10it/s, loss=0.653, v_num=3_3][A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                | 0/43 [00:00<?, ?it/s][A
Epoch 33:  87%|█████████████████████████████████████████▋      | 209/241 [00:15<00:02, 13.43it/s, 

Epoch 48: 100%|████████████████████████████████████████████████| 241/241 [00:16<00:00, 14.90it/s, loss=0.609, v_num=3_3][A
Epoch 49:  82%|████████████████████████████████████████▎        | 198/241 [00:12<00:02, 15.61it/s, loss=0.64, v_num=3_3][A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                | 0/43 [00:00<?, ?it/s][A
Epoch 49:  87%|██████████████████████████████████████████▍      | 209/241 [00:15<00:02, 13.79it/s, loss=0.64, v_num=3_3][A
Epoch 49:  91%|████████████████████████████████████████████▋    | 220/241 [00:15<00:01, 14.32it/s, loss=0.64, v_num=3_3][A
Epoch 49:  96%|██████████████████████████████████████████████▉  | 231/241 [00:15<00:00, 14.83it/s, loss=0.64, v_num=3_3][A
Validating:  84%|███████████████████████████████████████████████████████████▍           | 36/43 [00:03<00:00, 20.65it/s][A
Epoch 49: 100%|█████████████████████████████████████████████████| 241/241 [00:16<00:00, 15.00it/s,

# Evaluate a model

In [5]:
#Generate targets for evaluation
from seisbench.util.generate_eval_targets import main, generate_task1, generate_task23, select_window_containing

main(dataset_name="TAM_MF", output="/home/lmho/train_test_models/tamnnet_mf_transfer_gpd/tam_mf", 
     tasks="1,2,3", sampling_rate=100, noise_before_events = True)

Preloading waveforms: 100%|███████████████████████████████████████████████████████| 8496/8496 [00:02<00:00, 3994.08it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 8496/8496 [00:04<00:00, 1717.26it/s]
100%|█████████████████████████████████████████████████████████████████████████████| 8496/8496 [00:05<00:00, 1668.07it/s]


In [6]:
#Evaluate the newly trained model
#If error, check evalu.py
from seisbench.util.evalu import main, _identify_instance_dataset_border
weights_path = Path("/home/lmho/train_test_models/tamnnet_mf_transfer_gpd/weights/tammf_gpd_transfer/")
targets_path = Path("/home/lmho/train_test_models/tamnnet_mf_transfer_gpd/tam_mf")
main(weights=weights_path, targets=targets_path, 
     sets="dev,test", batchsize=100, num_workers=24, sampling_rate=None)

Preloading waveforms: 100%|███████████████████████████████████████████████████████| 4249/4249 [00:00<00:00, 4976.16it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]



Predicting: 100%|███████████████████████████████████████████████████████████████████████| 85/85 [09:20<00:00,  6.60s/it]


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 100%|███████████████████████████████████████████████████████████████████████| 44/44 [04:56<00:00,  6.74s/it]


Preloading waveforms: 100%|███████████████████████████████████████████████████████| 4247/4247 [00:00<00:00, 4689.05it/s]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 100%|███████████████████████████████████████████████████████████████████████| 85/85 [09:20<00:00,  6.60s/it]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Predicting: 100%|███████████████████████████████████████████████████████████████████████| 44/44 [05:00<00:00,  6.83s/it]


# Export evaluation results

In [9]:
from seisbench.util import collect_results, results_summary
from seisbench.util.collect_results import *
from seisbench.util.results_summary import *

In [10]:
def evaluate_versions():
    versions_dir = Path("pred/tammf_gpd_transfer")
    stats_list = []
    for version in versions_dir.glob("version_*"):
        stats = {}
        version_stats = eval_task1(version)
        stats.update(version_stats)
        version_stats = eval_task23(version)
        stats.update(version_stats)
        
        # Extract version number from directory name and add it as a new column
        version_number = int(version.name.split("_")[-1])
        stats["version_number"] = version_number
        stats["experiment"] = "tammf_gpd"
        stats_list.append(stats)
    return pd.DataFrame(stats_list)

In [11]:
version_dir = Path("pred/tammf_gpd_transfer/version_0")
result_stats = eval_task1(version_dir)
print(result_stats)

version_dir = Path("pred/tammf_gpd_transfer/version_0")
result_stats = eval_task23(version_dir)
print(result_stats)

{'dev_det_precision': 0.7065196730659571, 'dev_det_recall': 0.8747940691927513, 'dev_det_f1': 0.7817034700315457, 'dev_det_auc': 0.8240598856729875, 'det_threshold': 0.46573508, 'test_det_precision': 0.6974853469464927, 'test_det_recall': 0.8686131386861314, 'test_det_f1': 0.7736996644295302, 'test_det_auc': 0.8180643930197714, 'test_det_TP': 3689, 'test_det_FP': 1600, 'test_det_FN': 558, 'test_det_TN': 2646}
{'dev_phase_precision': 0.9062784349408554, 'dev_phase_recall': 0.9252206223873665, 'dev_phase_f1': 0.9156515743507241, 'phase_threshold': 0.94521993, 'dev_phase_mcc': 0.8291121293069478, 'phase_threshold_mcc': 0.94688046, 'test_phase_precision': 0.9145613229214515, 'test_phase_recall': 0.9209065679925994, 'test_phase_f1': 0.9177229776446186, 'test_phase_mcc': 0.8347362779484178, 'test_phase_TP': 1991, 'test_phase_FP': 186, 'test_phase_FN': 171, 'test_phase_TN': 1972, 'dev_P_mean_s': 0.053107292150487685, 'dev_P_std_s': 1.9020964062066918, 'dev_P_mae_s': 1.2473107292150487, 'dev_S

In [12]:
stats_df = evaluate_versions()
results = pd.read_csv('results_gpd.csv')
results

Unnamed: 0.1,Unnamed: 0,dev_det_precision,dev_det_recall,dev_det_f1,dev_det_auc,det_threshold,test_det_precision,test_det_recall,test_det_f1,test_det_auc,...,dev_S_std_s,dev_S_mae_s,test_P_mean_s,test_P_std_s,test_P_mae_s,test_S_mean_s,test_S_std_s,test_S_mae_s,version_number,experiment
0,0,0.70652,0.874794,0.781703,0.82406,0.465735,0.697485,0.868613,0.7737,0.818064,...,2.131632,1.408121,0.040994,1.801449,1.167988,0.365593,2.067819,1.333026,0,tammf_gpd
1,1,0.707312,0.869616,0.780112,0.82218,0.494429,0.703576,0.866259,0.776488,0.822555,...,2.102846,1.391195,0.055703,1.835428,1.187988,0.338123,2.066533,1.347354,1,tammf_gpd
2,2,0.704606,0.87856,0.782026,0.824407,0.461117,0.696082,0.878502,0.776725,0.821999,...,2.032839,1.357716,0.059126,1.842121,1.186943,0.273591,2.073692,1.357567,2,tammf_gpd


# Export selected models and their associated weights

In [14]:
import yaml
import json
import copy

In [15]:
DATA_ALIASES = {"tam_mf": "TAM_MF",
                "tam_ml": "TAM_ML",
                "tam_og": "TAM_OG",}

json_base = {
    "docstring": "Model trained on DATASET for 100 epochs with a learning rate of LR.\n"
    "Threshold selected for optimal F1 score on in-domain evaluation. "
    "Depending on the target region, the thresholds might need to be adjusted.\n"
    "When using this model, please reference the SeisBench publications listed "
    "at https://github.com/seisbench/seisbench\n\n"
    "Jannes Münchmeyer, Jack Woollam (munchmej@gfz-potsdam.de, jack.woollam@kit.edu)",
    "model_args": {
        "component_order": "ZNE",
    },
    "seisbench_requirement": "0.3.0",
    "version": "1",
}

In [16]:
def main():
    full_res = pd.read_csv("results_transfer_learning.csv")
    full_res = full_res.dropna()

    # Create a new column 'model' with the text after the underscore
    full_res['model'] = full_res['experiment'].apply(lambda x: x.split('_')[1])

    # Create a new column 'data' with the text before the underscore
    full_res['data'] = full_res['experiment'].apply(lambda x: x.split('_')[0])

    # Add an underscore after 'tam' in the 'data' column
    full_res['data'] = full_res['data'].apply(lambda x: x[:3] + "_" + x[3:])
    
    # Add a new column 'lr' with a learning rate value of 0.001
    full_res['lr'] = 0.001
    
    # Replace 'eqt' with 'eqtransformer' in the 'model' column
    full_res['model'] = full_res['model'].replace('eqt', 'eqtransformer')

    #full_res = full_res[full_res["model"] != "gpd"].copy()
    #full_res["model"].replace("gpdpick", "gpd", inplace=True)

    for pair, subdf in tqdm(full_res.groupby(["data", "model"])):
        idx = get_optimal_model_idx(subdf)
        if idx is None:
            print(f"Skipping {pair}")
            continue

        export_model(subdf.iloc[idx])

In [17]:
def get_optimal_model_idx(subdf):
    """
    Identifies the optimal model among the candidates in subdf.
    The optimal model is determined as the model with the lowest average relative loss in the metrics.
    Example:
        Model 1: det_auc=1 phase_mcc=0.9
        Model 2: det_auc=0.98 phase_mcc = 1
        Here we will select model 2, because model 1 on average only achieves a performance of 0.95 compared to the
        optimum, but model 2 achieves 0.99.
    In contrast to the example, the model also takes P and S std into account.
    :param subdf:
    :return: idx or None if no model is valid
    """
    x = subdf[
        ["dev_det_auc", "dev_phase_mcc", "dev_P_std_s", "dev_S_std_s"]
    ].values.copy()
    x[:, 2:] = 1 / x[:, 2:]
    x /= np.max(x, axis=0, keepdims=True)
    means = np.nanmean(x, axis=1)
    if np.isnan(means).all():
        return None

    return np.nanargmax(means)

In [18]:
def generate_metadata(row):
    meta = copy.deepcopy(json_base)
    default_args = {}
    meta["docstring"] = meta["docstring"].replace("DATASET", DATA_ALIASES[row["data"]])
    meta["docstring"] = meta["docstring"].replace("LR", str(row["lr"]))
    if row["model"] in ["cred", "eqtransformer"]:
        det_threshold = row["det_threshold"]
        if np.isnan(det_threshold):
            det_threshold = (
                0.3  # Roughly the average detection threshold across datasets
            )
        default_args["detection_threshold"] = det_threshold
        if row["model"] == "eqtransformer":
            # As the outputs are independent, and the empirical phase_thresholds are usually close to 1,
            # we just suggest the detection threshold for each phase as well.
            default_args["P_threshold"] = det_threshold
            default_args["S_threshold"] = det_threshold

    elif row["model"] in ["dpppickerp", "dpppickers"]:
        pass

    elif row["model"] in ["phasenet", "basicphaseae", "dppdetect"]:
        meta["model_args"]["phases"] = "PSN"
        det_threshold = row["det_threshold"]
        if np.isnan(det_threshold):
            det_threshold = 0.4
        phase_threshold = row["phase_threshold"]
        if np.isnan(phase_threshold):
            phase_threshold = 1
        default_args["P_threshold"] = det_threshold * np.sqrt(phase_threshold)
        default_args["S_threshold"] = det_threshold / np.sqrt(phase_threshold)

    elif row["model"] == "gpd":
        meta["model_args"]["phases"] = "PSN"
        meta["model_args"]["filter_args"] = ["highpass"]
        meta["model_args"]["filter_kwargs"] = {"freq": 2}
        det_threshold = row["det_threshold"]
        if np.isnan(det_threshold):
            det_threshold = 0.8
        phase_threshold = row["phase_threshold"]
        if np.isnan(phase_threshold):
            phase_threshold = 1
        default_args["P_threshold"] = det_threshold * np.sqrt(phase_threshold)
        default_args["S_threshold"] = det_threshold / np.sqrt(phase_threshold)

    else:
        raise ValueError("Unknown model type")

    meta["default_args"] = default_args

    return meta


In [20]:
full_res = pd.read_csv("../results_transfer_learning.csv")
full_res = full_res.dropna()

# Create a new column 'model' with the text after the underscore
full_res['model'] = full_res['experiment'].apply(lambda x: x.split('_')[1])

# Create a new column 'data' with the text before the underscore
full_res['data'] = full_res['experiment'].apply(lambda x: x.split('_')[0])

# Add an underscore after 'tam' in the 'data' column
full_res['data'] = full_res['data'].apply(lambda x: x[:3] + "_" + x[3:])

# Add a new column 'lr' with a learning rate value of 0.001
full_res['lr'] = 0.001

full_res = full_res.reset_index(drop=True)

print(full_res)

    Unnamed: 0  dev_det_precision  dev_det_recall  dev_det_f1  dev_det_auc  \
0            0           0.500000        1.000000    0.666667     0.405578   
1            1           0.718122        0.673335    0.695008     0.765390   
2            0           0.514747        0.998117    0.679212     0.663730   
3            1           0.510226        0.986350    0.672551     0.634856   
4            2           0.545892        0.986820    0.702934     0.749867   
5            3           0.531030        0.994822    0.692440     0.730911   
6            4           0.526922        0.990351    0.687863     0.721670   
7            5           0.511162        0.996940    0.675814     0.679613   
8            6           0.531985        0.990351    0.692162     0.706629   
9            7           0.552471        0.978819    0.706292     0.759636   
10           8           0.530994        0.989880    0.691208     0.722686   
11           9           0.511945        0.998588    0.676876   

In [21]:
optimal_rows = []
for pair, subdf in tqdm(full_res.groupby(["data", "model"])):
    idx = get_optimal_model_idx(subdf)
    #print(idx)
    optimal_row = subdf.iloc[idx]
    print(optimal_row.name)
    #optimal_rows.append(optimal_row)
#optimal_dataframe = pd.concat(optimal_rows, axis=0)
#print(optimal_dataframe)

100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 516.23it/s]

1
9
18
19
20
21
23
24
25
26
27
29





In [28]:
indices = [1, 9, 18, 19, 20, 21, 23, 24, 25, 26, 27, 29]
selected_rows = full_res.iloc[indices]
new_dataframe = pd.DataFrame(selected_rows)
new_dataframe['model'] = new_dataframe['model'].replace('eqt', 'eqtransformer')
print(new_dataframe.iloc[2])

Unnamed: 0                      2
dev_det_precision        0.704606
dev_det_recall            0.87856
dev_det_f1               0.782026
dev_det_auc              0.824407
det_threshold            0.461117
test_det_precision       0.696082
test_det_recall          0.878502
test_det_f1              0.776725
test_det_auc             0.821999
test_det_TP                3731.0
test_det_FP                1629.0
test_det_FN                 516.0
test_det_TN                2617.0
dev_phase_precision      0.903604
dev_phase_recall         0.931723
dev_phase_f1             0.917448
phase_threshold          0.822797
dev_phase_mcc            0.830042
phase_threshold_mcc      0.841729
test_phase_precision     0.905754
test_phase_recall        0.924607
test_phase_f1            0.915084
test_phase_mcc           0.826919
test_phase_TP              1999.0
test_phase_FP               208.0
test_phase_FN               163.0
test_phase_TN              1950.0
dev_P_mean_s             0.025007
dev_P_std_s   

In [23]:
#This function select the best performing model versions and export their weights
#Need to define the model's associated weight path and the path to the weights of the best performing model

def export_model(row):
    output_base = Path("seisbench_models")
    #weights = Path("weights") / row["experiment"]
    #weights = Path("/home/lmho/train_test_models/tamnnet_mf_transfer_basicphaseAE/weights/tammf_basicphaseae_transfer")
    #weights = Path("/home/lmho/train_test_models/tamnnet_ml_transfer_gpd/weights/tamml_gpd_transfer")
    weights = Path("/home/lmho/train_test_models/tamnnet_mf_transfer_gpd/weights/tammf_gpd_transfer")
    #weights = Path("/home/lmho/train_test_models/tamnnet_og_transfer_gpd/weights/tamog_gpd_transfer")
    #weights = Path("/home/lmho/train_test_models/tamnnet_ml_transfer_eqt/weights/tamml_eqt_transfer")
    
    version = sorted(weights.iterdir())[-1]
    config_path = version / "hparams.yaml"
    with open(config_path, "r") as f:
        # config = yaml.safe_load(f)
        config = yaml.full_load(f)

    model_cls = models.__getattribute__(config["model"] + "Lit")
    #model = load_best_model(model_cls, weights, version.name)
    #model = models.GPDLit.load_from_checkpoint("/home/lmho/train_test_models/tamnnet_ml_transfer_gpd/weights/tamml_gpd_transfer_tamml_gpd_transfer/1_1/checkpoints/epoch=99-step=5699.ckpt")
    model = models.GPDLit.load_from_checkpoint("/home/lmho/train_test_models/tamnnet_mf_transfer_gpd/weights/tammf_gpd_transfer_tammf_gpd_transfer/2_2/checkpoints/epoch=99-step=9899.ckpt")
    #model = models.GPDLit.load_from_checkpoint("/home/lmho/train_test_models/tamnnet_og_transfer_gpd/weights/tamog_gpd_transfer_tamog_gpd_transfer/0_0/checkpoints/epoch=99-step=3599.ckpt")
    #model = models.BasicPhaseAELit.load_from_checkpoint("/home/lmho/train_test_models/tamnnet_ml_transfer_gpd/weights/tamml_gpd_transfer_tamml_gpd_transfer/1_1/checkpoints/epoch=99-step=5699.ckpt")
    #model = models.EQTransformerLit.load_from_checkpoint("/home/lmho/train_test_models/tamnnet_ml_transfer_eqt/weights/tamml_eqt_transfer_tamml_eqt_transfer/0_0/checkpoints/epoch=99-step=5699.ckpt")
    output_path = output_base / row["model"] / f"{row['data']}.pt.v2"
    json_path = output_base / row["model"] / f"{row['data']}.json.v2"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.model.state_dict(), output_path)

    meta = generate_metadata(row)
    with open(json_path, "w") as f:
        json.dump(meta, f, indent=4)

In [29]:
export_model(new_dataframe.iloc[2])