## Example: Extract U-Trans Latent Representation and Decoder Features

This example shows how to:

- Load the full **U-Trans foundation model** (UNET + transformer backbone)
- Extract the latent token representation **(75, 80)**
- Generate the decoder features **(6000, 1)** that are ready to concatenate with downstream tasks
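The decoder features keep the same batch and time axes as the input waveform, so "ready to concatenate" most plausibly means stacking them as an extra channel along the last axis. The sketch below assumes that interpretation. It uses random NumPy arrays as placeholders for real waveforms and for the decoder output, and it does not call U-Trans itself.

```python
import numpy as np

# Placeholder arrays (random values, illustration only).
batch_size = 4
waveforms = np.random.randn(batch_size, 6000, 3).astype("float32")         # (B, 6000, 3) model input
decoder_features = np.random.randn(batch_size, 6000, 1).astype("float32")  # stand-in for (B, 6000, 1) decoder output

# Batch and time axes match, so the features can be appended as a 4th channel.
combined = np.concatenate([waveforms, decoder_features], axis=-1)
print(combined.shape)  # (4, 6000, 4)
```

Inside a Keras model, `tf.keras.layers.Concatenate(axis=-1)` applies the same operation to symbolic tensors, such as the decoder output tensor returned by `get_decoder_model`.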

```python
"""
U-Trans Foundation Model Usage Example
=======================================

This script:

1) Loads pretrained U-Trans foundation weights.
2) Extracts the latent transformer representation (B, 75, 80).
3) Builds the decoder feature stream (B, 6000, 1), ready for concatenation
   with downstream tasks.

Expected input shape:
    (B, 6000, 3)

Where:
    B = batch size
    6000 = waveform length (samples)
    3 = number of waveform components
"""

import os
import sys
import numpy as np

# ---------------------------------------------------------
# If the notebook is inside examples/, add the repo root to the path
# ---------------------------------------------------------
sys.path.insert(0, os.path.abspath(".."))

from utrans.foundation import get_latent_model, get_decoder_model


# ---------------------------------------------------------
# Path to pretrained foundation weights
# ---------------------------------------------------------
UNET_WEIGHTS = "../weights/UTrans_Foundation.h5"


# ---------------------------------------------------------
# Model that outputs transformer latent tokens
# Expected output shape: (B, 75, 80)
# ---------------------------------------------------------
latent_model = get_latent_model(UNET_WEIGHTS)


# ---------------------------------------------------------
# Model that outputs decoder features
# ready_to_concatenate_model -> Keras model
# Featuear_Ready_to_Concatenate -> output tensor of that model
# Expected decoder feature shape: (B, 6000, 1)
# ---------------------------------------------------------
ready_to_concatenate_model, Featuear_Ready_to_Concatenate = get_decoder_model(UNET_WEIGHTS)


# ---------------------------------------------------------
# Display architecture and output shape
# ---------------------------------------------------------
ready_to_concatenate_model.summary()
Featuear_Ready_to_Concatenate.shape
```
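Both models expect inputs of shape (B, 6000, 3). The following is a minimal sketch for assembling such a batch from individual three-component traces. The helper `make_batch` is hypothetical, and the peak-amplitude normalization is an assumption for illustration only; use whatever preprocessing the pretrained weights were actually trained with.

```python
import numpy as np

WAVEFORM_LENGTH = 6000  # samples per trace, from the expected input shape
N_COMPONENTS = 3        # three-component waveform


def make_batch(traces):
    """Hypothetical helper: stack (6000, 3) traces into a float32 batch (B, 6000, 3)."""
    batch = np.stack([np.asarray(t, dtype=np.float32) for t in traces], axis=0)
    if batch.shape[1:] != (WAVEFORM_LENGTH, N_COMPONENTS):
        raise ValueError(
            f"expected traces of shape ({WAVEFORM_LENGTH}, {N_COMPONENTS}), "
            f"got {batch.shape[1:]}"
        )
    # ASSUMPTION: scale each trace by its peak absolute amplitude;
    # all-zero traces are left unchanged.
    peak = np.max(np.abs(batch), axis=(1, 2), keepdims=True)
    peak[peak == 0] = 1.0
    return (batch / peak).astype(np.float32)


# Placeholder data: two random traces standing in for real waveforms.
x = make_batch([np.random.randn(WAVEFORM_LENGTH, N_COMPONENTS) for _ in range(2)])
print(x.shape)  # (2, 6000, 3)

# Assuming both objects are Keras models, inference would look like:
# latent = latent_model.predict(x)                  # expected (2, 75, 80)
# features = ready_to_concatenate_model.predict(x)  # expected (2, 6000, 1)
```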


Output of `ready_to_concatenate_model.summary()` (truncated):

```text
Model: "UTrans_decoder_features"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
 input (InputLayer)             [(None, 6000, 3)]    0           []

 conv1d_29 (Conv1D)             (None, 6000, 5)      50          ['input[0][0]']

 conv1d_30 (Conv1D)             (None, 6000, 5)      80          ['conv1d_29[0][0]']

 max_pooling1d_5 (MaxPooling1D)  (None, 3000, 5)     0           ['conv1d_30[0][0]']

 ...

 layer_normalization_12 (LayerN  (None, 75, 80)      160         ['add_10[0][0]']
 ormalization)

 dense_10 (Dense)               (None, 75, 80)       6480        ['layer_normalization_12[0][0]']

 dropout_10 (Dropout)           (None, 75, 80)       0           ['dense_10[0][0]']

 dense_11 (Dense)               (None, 75, 80)       6480        ['dropout_10[0][0]']

 dropout_11 (Dropout)           (None, 75, 80)       0           ['dense_11[0][0]']

 ...

 activation (Activation)        (None, 75, 80)       0           ['batch_normalization[0][0]']

 conv1d_59 (Conv1D)             (None, 75, 80)       83280       ['activation[0][0]']

 batch_normalization_1 (BatchNo  (None, 75, 80)      320         ['conv1d_59[0][0]']
 rmalization)

 activation_1 (Activation)      (None, 75, 80)       0           ['batch_normalization_1[0][0]']

 add_16 (Add)                   (None, 75, 80)       0           ['activation_1[0][0]',
 ...
```

Output of `Featuear_Ready_to_Concatenate.shape`:

```text
TensorShape([None, 6000, 1])
```
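The parameter counts in the summary follow from the standard Keras formulas, which gives a quick way to sanity-check the architecture. In this sketch the kernel sizes (3 for `conv1d_29` and `conv1d_30`, 13 for `conv1d_59`) are inferred by solving those formulas for the printed counts, not read from the model configuration. The samples-per-token figure assumes the 6000-to-75 reduction in the time axis is uniform.

```python
# Standard parameter-count formulas for Keras layers.
def conv1d_params(kernel_size, in_channels, filters):
    # kernel_size x in_channels weights per filter, plus one bias per filter
    return kernel_size * in_channels * filters + filters


def dense_params(in_features, units):
    return in_features * units + units  # weight matrix + bias


def layer_norm_params(features):
    return 2 * features  # gamma and beta


def batch_norm_params(features):
    return 4 * features  # gamma, beta, moving mean, moving variance


# Kernel sizes 3 and 13 are INFERRED from the printed counts (assumption).
checks = {
    "conv1d_29": (conv1d_params(3, 3, 5), 50),
    "conv1d_30": (conv1d_params(3, 5, 5), 80),
    "layer_normalization_12": (layer_norm_params(80), 160),
    "dense_10": (dense_params(80, 80), 6480),
    "batch_normalization_1": (batch_norm_params(80), 320),
    "conv1d_59": (conv1d_params(13, 80, 80), 83280),
}
for name, (computed, reported) in checks.items():
    print(f"{name}: computed={computed} reported={reported} match={computed == reported}")

# Time-axis reduction from 6000 input samples to 75 latent tokens.
samples_per_token = 6000 // 75
print(samples_per_token)  # 80
```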