In [1]:
import os
import torch
import torch.nn as nn
from einops import rearrange
import ai_edge_torch
import numpy as np
from torch.serialization import add_safe_globals
import warnings

# Suppress the specific UserWarning from the Transformer module
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")

# Define the XYZProcessor
class XYZProcessor(nn.Module):
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()
        self.xyz_encoder = nn.Sequential(
            nn.Conv1d(3, hidden_dim // 2, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.xyz_encoder(x)

# Define the SMVProcessor
class SMVProcessor(nn.Module):
    def __init__(self, hidden_dim, sequence_length, dropout=0.2):
        super().__init__()
        self.smv_encoder = nn.Sequential(
            nn.Conv1d(1, hidden_dim // 2, kernel_size=5, padding=2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )
        
        self.threshold_learner = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.smv_encoder(x)
        threshold = self.threshold_learner(features)
        return features, threshold

# Define the DualPathFallDetector
class DualPathFallDetector(nn.Module):
    def __init__(
        self,
        acc_coords=4,
        sequence_length=128,
        hidden_dim=64,
        num_heads=8,
        depth=4,
        mlp_ratio=4,
        num_classes=2,
        dropout=0.3,
        use_skeleton=False
    ):
        super().__init__()
        
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim
        
        # Processors
        self.phone_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.phone_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        self.watch_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.watch_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        
        # Feature fusion
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        
        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * mlp_ratio,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # You can set this to False if nested tensors are required
        )
        
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth,
            norm=nn.LayerNorm(hidden_dim)
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def process_device_data(self, data):
        """Process data from one device with SMV calculation"""
        # Split XYZ and calculate SMV
        xyz_data = data[:, :, :3]  # [B, T, 3]
        smv_data = torch.norm(xyz_data, dim=2, keepdim=True)  # [B, T, 1]
        
        # Process XYZ coordinates
        xyz_data = rearrange(xyz_data, 'b t c -> b c t')
        xyz_features = self.phone_xyz_processor(xyz_data)  # [B, H, T/2]
        xyz_features = xyz_features.mean(dim=2)  # [B, H]
        
        # Process SMV signal
        smv_data = rearrange(smv_data, 'b t c -> b c t')
        smv_features, smv_threshold = self.phone_smv_processor(smv_data)
        smv_features = smv_features.mean(dim=2)  # [B, H]
        
        # Combine features
        device_features = torch.cat([xyz_features, smv_features], dim=1)  # [B, 2H]
        
        return device_features, smv_threshold

    def forward(self, data):
        """Forward pass with both classification and SMV features"""
        # Process phone data
        phone_features, phone_threshold = self.process_device_data(
            data['accelerometer_phone'].float()
        )
        
        # Process watch data
        watch_features, watch_threshold = self.process_device_data(
            data['accelerometer_watch'].float()
        )
        
        # Combine features
        combined = torch.cat([phone_features, watch_features], dim=1)
        fused = self.fusion(combined)
        
        # Temporal modeling
        temporal = fused.unsqueeze(1)
        temporal = self.transformer(temporal)
        
        # Classification
        pooled = temporal.mean(dim=1)
        logits = self.classifier(pooled)
        
        # Return both logits and SMV features
        smv_features = {
            'phone_smv': phone_threshold.squeeze(-1),
            'watch_smv': watch_threshold.squeeze(-1),
        }
        
        return logits, smv_features

# Define the Wrapper Module
class DualPathFallDetectorWrapper(nn.Module):
    def __init__(self, original_model):
        super(DualPathFallDetectorWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, accelerometer_phone, accelerometer_watch):
        data = {
            'accelerometer_phone': accelerometer_phone,
            'accelerometer_watch': accelerometer_watch
        }
        logits, smv_features = self.original_model(data)
        return logits, smv_features

# Initialize the model
model = DualPathFallDetector(
    acc_coords=4,
    sequence_length=128,
    hidden_dim=64,
    num_heads=8,
    depth=4,
    mlp_ratio=4,
    num_classes=2,
    dropout=0.3,
    use_skeleton=False
)

# Path to your checkpoint
model_path = "exps/smartfall_har/mobile_falldet/model_epoch_22_f1_0.9414.pth"  # Use forward slashes for cross-platform compatibility

# Option 1: Add safe globals and load checkpoint with weights_only=True
try:
    add_safe_globals([np.core.multiarray.scalar])
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    print("Checkpoint loaded successfully with weights_only=True.")
except AttributeError:
    print("add_safe_globals is not available in your PyTorch version. Please update PyTorch to >=2.1.0.")
    checkpoint = None
except Exception as e:
    print(f"Failed to load checkpoint with weights_only=True: {e}")
    print("Attempting to load without weights_only=True (security risk)...")
    try:
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        print("Checkpoint loaded successfully with weights_only=False.")
    except Exception as e2:
        print(f"Failed to load checkpoint with weights_only=False: {e2}")
        checkpoint = None

if checkpoint:
    # Load state_dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("State_dict loaded from checkpoint.")
    else:
        model.load_state_dict(checkpoint)
        print("State_dict loaded directly from checkpoint.")

    # Set the model to evaluation mode
    model.eval()

    # Wrap the model
    wrapped_model = DualPathFallDetectorWrapper(model).eval()

    # Prepare sample inputs as a tuple of tensors
    batch_size = 1
    sequence_length = 128
    channels_phone = 4  # Adjust based on your data
    channels_watch = 4  # Adjust based on your data

    sample_args = (
        torch.randn(batch_size, sequence_length, channels_phone),
        torch.randn(batch_size, sequence_length, channels_watch)
    )

    # Set PJRT_DEVICE to 'CPU' to address the CUDA-related RuntimeError
    os.environ['PJRT_DEVICE'] = 'CPU'

    # Convert the wrapped model to LiteRT
    try:
        edge_model = ai_edge_torch.convert(wrapped_model, sample_args)
        print("Model conversion successful!")
    except Exception as e:
        print(f"Model conversion failed: {e}")
else:
    print("Checkpoint loading failed. Conversion cannot proceed.")


2024-11-20 21:23:17.728622: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732159397.776086  275183 cuda_dnn.cc:8498] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732159397.794692  275183 cuda_blas.cc:1410] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-20 21:23:17.897936: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  add_safe_globals([np.core.multiarray.scalar])


Failed to load checkpoint with weights_only=True: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
Attempting to load without weights_only=True (security risk)...
Checkpoint loaded su

I0000 00:00:1732159415.398424  275183 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4057 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2060 with Max-Q Design, pci bus id: 0000:01:00.0                                                   , compute capability: 7.5


INFO:tensorflow:Assets written to: /tmp/tmptujrlrz3/assets


INFO:tensorflow:Assets written to: /tmp/tmptujrlrz3/assets
W0000 00:00:1732159418.625831  275183 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1732159418.625903  275183 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2024-11-20 21:23:38.627858: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmptujrlrz3
2024-11-20 21:23:38.629511: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-11-20 21:23:38.629543: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmptujrlrz3
I0000 00:00:1732159418.644897  275183 mlir_graph_optimization_pass.cc:402] MLIR V1 optimization pass is not enabled
2024-11-20 21:23:38.647251: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-11-20 21:23:38.751646: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmptujrlrz3
2024-11-20 21:23:38.781

Model conversion successful!


In [2]:
# Export the LiteRT model to TFLite
edge_model.export('mobile_falldet2.tflite')
print("Model successfully exported to 'mobile_falldet.tflite'")

Model successfully exported to 'mobile_falldet.tflite'


In [1]:
import os
import torch
import torch.nn as nn
from einops import rearrange
import ai_edge_torch
import numpy as np
from torch.serialization import add_safe_globals
import warnings

# Suppress the specific UserWarning from the Transformer module
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")

# Define the XYZProcessor
class XYZProcessor(nn.Module):
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()
        self.xyz_encoder = nn.Sequential(
            nn.Conv1d(3, hidden_dim // 2, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.xyz_encoder(x)

# Define the SMVProcessor
class SMVProcessor(nn.Module):
    def __init__(self, hidden_dim, sequence_length, dropout=0.2):
        super().__init__()
        self.smv_encoder = nn.Sequential(
            nn.Conv1d(1, hidden_dim // 2, kernel_size=5, padding=2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )
        
        self.threshold_learner = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.smv_encoder(x)
        threshold = self.threshold_learner(features)
        return features, threshold

# Define the DualPathFallDetector
class DualPathFallDetector(nn.Module):
    def __init__(
        self,
        acc_coords=4,
        sequence_length=128,
        hidden_dim=64,
        num_heads=8,
        depth=4,
        mlp_ratio=4,
        num_classes=2,
        dropout=0.3,
        use_skeleton=False
    ):
        super().__init__()
        
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim
        
        # Processors
        self.phone_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.phone_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        self.watch_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.watch_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        
        # Feature fusion
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        
        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * mlp_ratio,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # You can set this to False if nested tensors are required
        )
        
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth,
            norm=nn.LayerNorm(hidden_dim)
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def process_device_data(self, data):
        """Process data from one device with SMV calculation"""
        # Split XYZ and calculate SMV
        xyz_data = data[:, :, :3]  # [B, T, 3]
        smv_data = torch.norm(xyz_data, dim=2, keepdim=True)  # [B, T, 1]
        
        # Process XYZ coordinates
        xyz_data = rearrange(xyz_data, 'b t c -> b c t')
        xyz_features = self.phone_xyz_processor(xyz_data)  # [B, H, T/2]
        xyz_features = xyz_features.mean(dim=2)  # [B, H]
        
        # Process SMV signal
        smv_data = rearrange(smv_data, 'b t c -> b c t')
        smv_features, smv_threshold = self.phone_smv_processor(smv_data)
        smv_features = smv_features.mean(dim=2)  # [B, H]
        
        # Combine features
        device_features = torch.cat([xyz_features, smv_features], dim=1)  # [B, 2H]
        
        return device_features, smv_threshold

    def forward(self, data):
        """Forward pass with both classification and SMV features"""
        # Process phone data
        phone_features, phone_threshold = self.process_device_data(
            data['accelerometer_phone'].float()
        )
        
        # Process watch data
        watch_features, watch_threshold = self.process_device_data(
            data['accelerometer_watch'].float()
        )
        
        # Combine features
        combined = torch.cat([phone_features, watch_features], dim=1)
        fused = self.fusion(combined)
        
        # Temporal modeling
        temporal = fused.unsqueeze(1)
        temporal = self.transformer(temporal)
        
        # Classification
        pooled = temporal.mean(dim=1)
        logits = self.classifier(pooled)
        
        # Return both logits and SMV features
        smv_features = {
            'phone_smv': phone_threshold.squeeze(-1),
            'watch_smv': watch_threshold.squeeze(-1),
        }
        
        return logits, smv_features

# Define the Wrapper Module
class DualPathFallDetectorWrapper(nn.Module):
    def __init__(self, original_model):
        super(DualPathFallDetectorWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, accelerometer_phone, accelerometer_watch):
        data = {
            'accelerometer_phone': accelerometer_phone,
            'accelerometer_watch': accelerometer_watch
        }
        logits, smv_features = self.original_model(data)
        return logits, smv_features

# Initialize the model
model = DualPathFallDetector(
    acc_coords=4,
    sequence_length=128,
    hidden_dim=64,
    num_heads=8,
    depth=4,
    mlp_ratio=4,
    num_classes=2,
    dropout=0.3,
    use_skeleton=False
)

# Path to your checkpoint
model_path = "exps/smartfall_har/mobile_falldet/model_epoch_22_f1_0.9414.pth"  # Use forward slashes for cross-platform compatibility

# Option 1: Add safe globals and load checkpoint with weights_only=True
try:
    add_safe_globals([np.core.multiarray.scalar])
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    print("Checkpoint loaded successfully with weights_only=True.")
except AttributeError:
    print("add_safe_globals is not available in your PyTorch version. Please update PyTorch to >=2.1.0.")
    checkpoint = None
except Exception as e:
    print(f"Failed to load checkpoint with weights_only=True: {e}")
    print("Attempting to load without weights_only=True (security risk)...")
    try:
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        print("Checkpoint loaded successfully with weights_only=False.")
    except Exception as e2:
        print(f"Failed to load checkpoint with weights_only=False: {e2}")
        checkpoint = None

if checkpoint:
    # Load state_dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("State_dict loaded from checkpoint.")
    else:
        model.load_state_dict(checkpoint)
        print("State_dict loaded directly from checkpoint.")

    # Set the model to evaluation mode
    model.eval()

    # Wrap the model
    wrapped_model = DualPathFallDetectorWrapper(model).eval()

    # Prepare sample inputs as a tuple of tensors
    batch_size = 1
    sequence_length = 128
    channels_phone = 4  # Adjust based on your data
    channels_watch = 4  # Adjust based on your data

    sample_args = (
        torch.randn(batch_size, sequence_length, channels_phone),
        torch.randn(batch_size, sequence_length, channels_watch)
    )

    # Set PJRT_DEVICE to 'CPU' to address the CUDA-related RuntimeError
    os.environ['PJRT_DEVICE'] = 'CPU'

    # Quantization Steps
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch._export import capture_pre_autograd_graph

    from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
    from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
    from ai_edge_torch.quantize.quant_config import QuantConfig

    # Initialize the PT2E Quantizer with symmetric quantization configuration
    pt2e_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
    )

    # Capture the pre-autograd graph of the wrapped model
    pt2e_torch_model = capture_pre_autograd_graph(wrapped_model, sample_args)

    # Prepare the model for PT2E quantization
    pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

    # Run the prepared model with sample input data to ensure that internal observers are populated with correct values
    pt2e_torch_model(*sample_args)

    # Convert the prepared model to a quantized model
    pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

    # Convert to an ai_edge_torch model with quantization configuration and additional converter flags
    try:
        pt2e_drq_model = ai_edge_torch.convert(
            pt2e_torch_model,
            sample_args,
            quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer),
            _ai_edge_converter_flags={'experimental_enable_resource_variables': True}
        )
        print("Model conversion successful!")
    except Exception as e:
        print(f"Model conversion failed: {e}")
        # Depending on the error, consider alternative approaches below

    # Save the quantized model with a different name
    try:
        quantized_model_path = "exps/smartfall_har/mobile_falldet/model_epoch_22_f1_0.9414_quant.pth"
        torch.save(pt2e_drq_model.state_dict(), quantized_model_path)
        print(f"Quantized model saved successfully at '{quantized_model_path}'.")
    except Exception as e:
        print(f"Failed to save quantized model: {e}")

    # Optional: Convert the quantized model to TFLite
    try:
        # Export the quantized LiteRT model to TFLite
        pt2e_drq_model.export('mobile_falldet_quant.tflite')
        print("Quantized model successfully exported to 'mobile_falldet_quant.tflite'.")
    except Exception as e:
        print(f"Failed to export quantized model to TFLite: {e}")

    # Optional: Validate the Quantized Model
    try:
        # Perform inference with the quantized PyTorch model
        with torch.no_grad():
            torch_quant_output = pt2e_drq_model(*sample_args)

        # Perform inference with the quantized LiteRT model
        # Assuming that pt2e_drq_model can perform inference like this
        # If not, you may need to load the TFLite model separately for inference
        tfl_quant_output = pt2e_drq_model(*sample_args)

        # Extract logits and SMV features
        torch_quant_logits = torch_quant_output[0].detach().numpy()
        tfl_quant_logits = tfl_quant_output[0]

        torch_quant_smv_features = {k: v.detach().numpy() for k, v in torch_quant_output[1].items()}
        tfl_quant_smv_features = {k: v for k, v in tfl_quant_output[1].items()}

        # Compare logits
        if np.allclose(torch_quant_logits, tfl_quant_logits, atol=1e-5, rtol=1e-5):
            print("Quantized inference result for logits with PyTorch and LiteRT matches within tolerance.")
        else:
            print("Discrepancy found in quantized logits between PyTorch and LiteRT models.")

        # Compare SMV features
        for key in torch_quant_smv_features:
            if np.allclose(torch_quant_smv_features[key], tfl_quant_smv_features[key], atol=1e-5, rtol=1e-5):
                print(f"Quantized inference result for {key} matches within tolerance.")
            else:
                print(f"Discrepancy found in quantized {key} between PyTorch and LiteRT models.")
    except Exception as e:
        print(f"Validation of quantized model failed: {e}")
else:
    print("Checkpoint loading failed. Conversion cannot proceed.")


2024-11-20 21:10:47.494061: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732158647.519374  271538 cuda_dnn.cc:8498] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732158647.527195  271538 cuda_blas.cc:1410] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-20 21:10:47.569229: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  add_safe_globals([np.core.multiarray.scalar])
W1120 21:10:51.585000 271538 torch/_export/__init__.py:67] capture_pre_autogr

Failed to load checkpoint with weights_only=True: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
Attempting to load without weights_only=True (security risk)...
Checkpoint loaded su

I0000 00:00:1732158665.497585  271538 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4057 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2060 with Max-Q Design, pci bus id: 0000:01:00.0                                                   , compute capability: 7.5


INFO:tensorflow:Assets written to: /tmp/tmpt9ri5xs0/assets


INFO:tensorflow:Assets written to: /tmp/tmpt9ri5xs0/assets


Model conversion failed: Variable constant folding is failed. Please consider using enabling `experimental_enable_resource_variables` flag in the TFLite converter object. For example, converter.experimental_enable_resource_variables = True<unknown>:0: error: loc(callsite(callsite(callsite("torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl/__main__.DualPathFallDetector_original_model/__main__.SMVProcessor_phone_smv_processor/torch.nn.modules.container.Sequential_smv_encoder/torch.nn.modules.conv.Conv1d_0;" at fused["XlaCallModule:", "XlaCallModule@__inference_inner_572"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall@__inference_signature_wrapper_731"]) at fused["StatefulPartitionedCall:", "StatefulPartitionedCall"])): 'tfl.transpose' op has mismatched quantized axes of input and output
<unknown>:0: note: loc(fused["StatefulPartitionedCall:", "StatefulPartitionedCall"]): called from

Failed to save quantized model: name 'pt2e_drq_model' is not defined
Fa

W0000 00:00:1732158668.828573  271538 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1732158668.828642  271538 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2024-11-20 21:11:08.829464: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpt9ri5xs0
2024-11-20 21:11:08.831443: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-11-20 21:11:08.831467: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpt9ri5xs0
I0000 00:00:1732158668.850352  271538 mlir_graph_optimization_pass.cc:402] MLIR V1 optimization pass is not enabled
2024-11-20 21:11:08.853217: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-11-20 21:11:09.004122: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpt9ri5xs0
2024-11-20 21:11:09.039545: I tensorflow/cc/saved_model/loader.cc:466] SavedModel 

In [2]:
!pip install --upgrade torch torchvision ai-edge-torch tensorflow




Collecting ai-edge-torch
  Downloading ai_edge_torch-0.2.1-py3-none-any.whl (210 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.9/210.9 KB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting tensorflow
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m615.3/615.3 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting ai-edge-quantizer-nightly==0.0.1.dev20240718
  Using cached ai_edge_quantizer_nightly-0.0.1.dev20240718-py3-none-any.whl (100 kB)
Collecting torch-xla<2.6,>=2.4.0
  Downloading torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl (90.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.6/90.6 MB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hCollecting tabulate
  Using cached tabulate-0.9.0-py3-none-any.whl (35 kB)
Collecting tf-nightly>=2.18.0.dev20

In [3]:
import os
import torch
import torch.nn as nn
from einops import rearrange
import ai_edge_torch
import numpy as np
from torch.serialization import add_safe_globals
import warnings

# Suppress specific UserWarnings from the Transformer module
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")

# Define the XYZProcessor
class XYZProcessor(nn.Module):
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()
        self.xyz_encoder = nn.Sequential(
            nn.Conv1d(3, hidden_dim // 2, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.xyz_encoder(x)

# Define the SMVProcessor
class SMVProcessor(nn.Module):
    def __init__(self, hidden_dim, sequence_length, dropout=0.2):
        super().__init__()
        self.smv_encoder = nn.Sequential(
            nn.Conv1d(1, hidden_dim // 2, kernel_size=5, padding=2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )
        
        self.threshold_learner = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.smv_encoder(x)
        threshold = self.threshold_learner(features)
        return features, threshold

# Define the DualPathFallDetector
class DualPathFallDetector(nn.Module):
    def __init__(
        self,
        acc_coords=4,
        sequence_length=128,
        hidden_dim=64,
        num_heads=8,
        depth=4,
        mlp_ratio=4,
        num_classes=2,
        dropout=0.3,
        use_skeleton=False
    ):
        super().__init__()
        
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim
        
        # Processors
        self.phone_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.phone_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        self.watch_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.watch_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        
        # Feature fusion
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        
        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * mlp_ratio,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # Can be set to False if needed
        )
        
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth,
            norm=nn.LayerNorm(hidden_dim)
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def process_device_data(self, data):
        """Process data from one device with SMV calculation"""
        # Split XYZ and calculate SMV
        xyz_data = data[:, :, :3]  # [B, T, 3]
        smv_data = torch.norm(xyz_data, dim=2, keepdim=True)  # [B, T, 1]
        
        # Process XYZ coordinates
        xyz_data = rearrange(xyz_data, 'b t c -> b c t')
        xyz_features = self.phone_xyz_processor(xyz_data)  # [B, H, T/2]
        xyz_features = xyz_features.mean(dim=2)  # [B, H]
        
        # Process SMV signal
        smv_data = rearrange(smv_data, 'b t c -> b c t')
        smv_features, smv_threshold = self.phone_smv_processor(smv_data)
        smv_features = smv_features.mean(dim=2)  # [B, H]
        
        # Combine features
        device_features = torch.cat([xyz_features, smv_features], dim=1)  # [B, 2H]
        
        return device_features, smv_threshold

    def forward(self, data):
        """Forward pass with both classification and SMV features"""
        # Process phone data
        phone_features, phone_threshold = self.process_device_data(
            data['accelerometer_phone'].float()
        )
        
        # Process watch data
        watch_features, watch_threshold = self.process_device_data(
            data['accelerometer_watch'].float()
        )
        
        # Combine features
        combined = torch.cat([phone_features, watch_features], dim=1)
        fused = self.fusion(combined)
        
        # Temporal modeling
        temporal = fused.unsqueeze(1)
        temporal = self.transformer(temporal)
        
        # Classification
        pooled = temporal.mean(dim=1)
        logits = self.classifier(pooled)
        
        # Return both logits and SMV features
        smv_features = {
            'phone_smv': phone_threshold.squeeze(-1),
            'watch_smv': watch_threshold.squeeze(-1),
        }
        
        return logits, smv_features

# Define the Wrapper Module
class DualPathFallDetectorWrapper(nn.Module):
    def __init__(self, original_model):
        super(DualPathFallDetectorWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, accelerometer_phone, accelerometer_watch):
        data = {
            'accelerometer_phone': accelerometer_phone,
            'accelerometer_watch': accelerometer_watch
        }
        logits, smv_features = self.original_model(data)
        return logits, smv_features

# Initialize the model
model = DualPathFallDetector(
    acc_coords=4,
    sequence_length=128,
    hidden_dim=64,
    num_heads=8,
    depth=4,
    mlp_ratio=4,
    num_classes=2,
    dropout=0.3,
    use_skeleton=False
)

# Path to your checkpoint
model_path = "exps/smartfall_har/mobile_falldet/model_epoch_22_f1_0.9414.pth"  # Use forward slashes for cross-platform compatibility

# Option 1: Add safe globals and load checkpoint with weights_only=True
try:
    add_safe_globals([np.core.multiarray.scalar])
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    print("Checkpoint loaded successfully with weights_only=True.")
except AttributeError:
    print("add_safe_globals is not available in your PyTorch version. Please update PyTorch to >=2.1.0.")
    checkpoint = None
except Exception as e:
    print(f"Failed to load checkpoint with weights_only=True: {e}")
    print("Attempting to load without weights_only=True (security risk)...")
    try:
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        print("Checkpoint loaded successfully with weights_only=False.")
    except Exception as e2:
        print(f"Failed to load checkpoint with weights_only=False: {e2}")
        checkpoint = None

if checkpoint:
    # Load state_dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("State_dict loaded from checkpoint.")
    else:
        model.load_state_dict(checkpoint)
        print("State_dict loaded directly from checkpoint.")

    # Set the model to evaluation mode
    model.eval()

    # Wrap the model
    wrapped_model = DualPathFallDetectorWrapper(model).eval()

    # Prepare sample inputs as a tuple of tensors
    batch_size = 1
    sequence_length = 128
    channels_phone = 4  # Adjust based on your data
    channels_watch = 4  # Adjust based on your data

    sample_args = (
        torch.randn(batch_size, sequence_length, channels_phone),
        torch.randn(batch_size, sequence_length, channels_watch)
    )

    # Set PJRT_DEVICE to 'CPU' to address the CUDA-related RuntimeError
    os.environ['PJRT_DEVICE'] = 'CPU'

    # Quantization Steps
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    from torch._export import capture_pre_autograd_graph

    from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
    from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
    from ai_edge_torch.quantize.quant_config import QuantConfig

    # Initialize the PT2E Quantizer with symmetric quantization configuration (per-tensor and dynamic)
    pt2e_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=True)
    )

    # Capture the pre-autograd graph of the wrapped model
    pt2e_torch_model = capture_pre_autograd_graph(wrapped_model, sample_args)

    # Prepare the model for PT2E quantization
    pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

    # Run the prepared model with sample input data to ensure that internal observers are populated with correct values
    pt2e_torch_model(*sample_args)

    # Convert the prepared model to a quantized model
    pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

    # Conversion Flags
    _ai_edge_converter_flags = {
        'experimental_enable_resource_variables': True
    }

    # Convert to an ai_edge_torch model with quantization configuration and additional converter flags
    try:
        pt2e_drq_model = ai_edge_torch.convert(
            pt2e_torch_model,
            sample_args,
            quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer),
            _ai_edge_converter_flags=_ai_edge_converter_flags
        )
        print("Model conversion successful!")
    except Exception as e:
        print(f"Model conversion failed: {e}")
        # Depending on the error, consider alternative approaches below

    # Save the quantized model with a different name
    try:
        quantized_model_path = "exps/smartfall_har/mobile_falldet/model_epoch_22_f1_0.9414_quant.pth"
        torch.save(pt2e_drq_model.state_dict(), quantized_model_path)
        print(f"Quantized model saved successfully at '{quantized_model_path}'.")
    except Exception as e:
        print(f"Failed to save quantized model: {e}")

    # Optional: Convert the quantized model to TFLite
    try:
        # Export the quantized LiteRT model to TFLite
        pt2e_drq_model.export('mobile_falldet_quant.tflite')
        print("Quantized model successfully exported to 'mobile_falldet_quant.tflite'.")
    except Exception as e:
        print(f"Failed to export quantized model to TFLite: {e}")

    # Optional: Validate the Quantized Model
    try:
        # Perform inference with the quantized PyTorch model
        with torch.no_grad():
            torch_quant_output = pt2e_drq_model(*sample_args)

        # Perform inference with the quantized LiteRT model
        # Note: AI Edge Torch's LiteRT model might require a different inference approach.
        # Here, we assume it can be invoked similarly.
        # If not, you may need to load the TFLite model separately for inference.
        tfl_quant_output = pt2e_drq_model(*sample_args)

        # Extract logits and SMV features
        torch_quant_logits = torch_quant_output[0].detach().numpy()
        tfl_quant_logits = tfl_quant_output[0]

        torch_quant_smv_features = {k: v.detach().numpy() for k, v in torch_quant_output[1].items()}
        tfl_quant_smv_features = {k: v for k, v in tfl_quant_output[1].items()}

        # Compare logits
        if np.allclose(torch_quant_logits, tfl_quant_logits, atol=1e-5, rtol=1e-5):
            print("Quantized inference result for logits with PyTorch and LiteRT matches within tolerance.")
        else:
            print("Discrepancy found in quantized logits between PyTorch and LiteRT models.")

        # Compare SMV features
        for key in torch_quant_smv_features:
            if np.allclose(torch_quant_smv_features[key], tfl_quant_smv_features[key], atol=1e-5, rtol=1e-5):
                print(f"Quantized inference result for {key} matches within tolerance.")
            else:
                print(f"Discrepancy found in quantized {key} between PyTorch and LiteRT models.")
    except Exception as e:
        print(f"Validation of quantized model failed: {e}")
else:
    print("Checkpoint loading failed. Conversion cannot proceed.")


Failed to load checkpoint with weights_only=True: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
Attempting to load without weights_only=True (security risk)...
Checkpoint loaded su



INFO:tensorflow:Assets written to: /tmp/tmp0rx_dv7s/assets


INFO:tensorflow:Assets written to: /tmp/tmp0rx_dv7s/assets
W0000 00:00:1732158995.136254  271538 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1732158995.136791  271538 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2024-11-20 21:16:35.138567: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmp0rx_dv7s
2024-11-20 21:16:35.140593: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-11-20 21:16:35.140633: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmp0rx_dv7s
2024-11-20 21:16:35.156286: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-11-20 21:16:35.275884: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmp0rx_dv7s
2024-11-20 21:16:35.299945: I tensorflow/cc/saved_model/loader.cc:466] SavedModel load for tags { serve }; Status: success: OK. Took 161527

Model conversion successful!
Failed to save quantized model: 'TfLiteModel' object has no attribute 'state_dict'
Quantized model successfully exported to 'mobile_falldet_quant.tflite'.
Validation of quantized model failed: 'numpy.ndarray' object has no attribute 'detach'


INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [4]:
import numpy as np
import tensorflow as tf

# Path to your TFLite model
tflite_model_path = "mobile_falldet_quant.tflite"

# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()

# Get input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("Input Details:")
for detail in input_details:
    print(detail)

print("\nOutput Details:")
for detail in output_details:
    print(detail)

# Prepare sample input data
# Replace these with actual sensor data in practice
phone_input = np.random.rand(1, 128, 4).astype(np.float32)    # Shape: [1, 128, 4]
watch_input = np.random.rand(1, 128, 4).astype(np.float32)   # Shape: [1, 128, 4]

# Set tensor for 'accelerometer_phone'
interpreter.set_tensor(input_details[0]['index'], phone_input)

# Set tensor for 'accelerometer_watch'
interpreter.set_tensor(input_details[1]['index'], watch_input)

# Run the inference
interpreter.invoke()

# Get the output tensors
logits = interpreter.get_tensor(output_details[0]['index'])
smv_features = interpreter.get_tensor(output_details[1]['index'])

print("\nLogits:", logits)
print("SMV Features:", smv_features)

# Post-processing: Apply softmax to logits to get probabilities
probabilities = tf.nn.softmax(logits, axis=1).numpy()
print("\nProbabilities:", probabilities)

# Determine the predicted class
predicted_class = np.argmax(probabilities, axis=1)
print("Predicted Class:", predicted_class)


Input Details:
{'name': 'serving_default_args_0:0', 'index': 0, 'shape': array([  1, 128,   4], dtype=int32), 'shape_signature': array([  1, 128,   4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
{'name': 'serving_default_args_1:0', 'index': 1, 'shape': array([  1, 128,   4], dtype=int32), 'shape_signature': array([  1, 128,   4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}

Output Details:
{'name': 'StatefulPartitionedCall:1', 'index': 407, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': arra

InvalidArgumentError: `dim` must be in the range [-1, 1) where 1 is the number of dimensions in the input. Received: dim=1

In [5]:
import numpy as np
import tensorflow as tf

# Path to your TFLite model
tflite_model_path = "mobile_falldet2.tflite"

# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
interpreter.allocate_tensors()

# Get input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("Input Details:")
for detail in input_details:
    print(detail)

print("\nOutput Details:")
for detail in output_details:
    print(detail)

# Prepare sample input data
# Replace these with actual sensor data in practice
phone_input = np.random.rand(1, 128, 4).astype(np.float32)    # Shape: [1, 128, 4]
watch_input = np.random.rand(1, 128, 4).astype(np.float32)   # Shape: [1, 128, 4]

# Set tensor for 'accelerometer_phone'
interpreter.set_tensor(input_details[0]['index'], phone_input)

# Set tensor for 'accelerometer_watch'
interpreter.set_tensor(input_details[1]['index'], watch_input)

# Run the inference
interpreter.invoke()

# Assign outputs based on name and shape
logits = None
smv_features_1 = None
smv_features_2 = None

for detail in output_details:
    tensor = interpreter.get_tensor(detail['index'])
    name = detail['name']
    shape = detail['shape']
    
    # Debugging prints
    print(f"Processing Output Tensor: {name}, Shape: {shape}, Data: {tensor}")

    # Correct shape comparison using np.array_equal
    if name == 'StatefulPartitionedCall:0' and np.array_equal(shape, [1, 2]):
        logits = tensor
        print(f"Assigned '{name}' to logits.")
    elif name == 'StatefulPartitionedCall:1' and np.array_equal(shape, [1]):
        smv_features_1 = tensor
        print(f"Assigned '{name}' to smv_features_1.")
    elif name == 'StatefulPartitionedCall:2' and np.array_equal(shape, [1]):
        smv_features_2 = tensor
        print(f"Assigned '{name}' to smv_features_2.")
    else:
        print(f"Unrecognized tensor: {name} with shape: {shape}")

# Verify that logits have been correctly assigned
if logits is None:
    raise ValueError("Logits tensor not found in the model outputs.")
if smv_features_1 is None or smv_features_2 is None:
    print("Warning: One or more SMV features were not found in the model outputs.")

print("\nLogits:", logits)
print("SMV Features 1:", smv_features_1)
print("SMV Features 2:", smv_features_2)

# Post-processing: Apply softmax to logits to get probabilities
# Ensure that logits have shape [1, 2] before applying softmax
if logits.ndim == 2 and logits.shape[1] == 2:
    probabilities = tf.nn.softmax(logits, axis=1).numpy()
    print("\nProbabilities:", probabilities)
    
    # Determine the predicted class
    predicted_class = np.argmax(probabilities, axis=1)
    print("Predicted Class:", predicted_class)
else:
    print("Logits tensor has an unexpected shape:", logits.shape)


Input Details:
{'name': 'serving_default_args_0:0', 'index': 0, 'shape': array([  1, 128,   4], dtype=int32), 'shape_signature': array([  1, 128,   4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
{'name': 'serving_default_args_1:0', 'index': 1, 'shape': array([  1, 128,   4], dtype=int32), 'shape_signature': array([  1, 128,   4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}

Output Details:
{'name': 'StatefulPartitionedCall:1', 'index': 466, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': arra

In [1]:
import os
import torch
import torch.nn as nn
from einops import rearrange
import ai_edge_torch
import numpy as np
from torch.serialization import add_safe_globals
import warnings
import tensorflow as tf

# Suppress the specific UserWarning from the Transformer module
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")

# ------------------------------
# 1. Define the PyTorch Model
# ------------------------------

# Define the XYZProcessor
class XYZProcessor(nn.Module):
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()
        self.xyz_encoder = nn.Sequential(
            nn.Conv1d(3, hidden_dim // 2, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.xyz_encoder(x)

# Define the SMVProcessor
class SMVProcessor(nn.Module):
    def __init__(self, hidden_dim, sequence_length, dropout=0.2):
        super().__init__()
        self.smv_encoder = nn.Sequential(
            nn.Conv1d(1, hidden_dim // 2, kernel_size=5, padding=2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim // 2, hidden_dim, kernel_size=7, padding=3),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(dropout)
        )
        
        self.threshold_learner = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.smv_encoder(x)
        threshold = self.threshold_learner(features)
        return features, threshold

# Define the DualPathFallDetector
class DualPathFallDetector(nn.Module):
    def __init__(
        self,
        acc_coords=4,
        sequence_length=128,
        hidden_dim=64,
        num_heads=8,
        depth=4,
        mlp_ratio=4,
        num_classes=2,
        dropout=0.3,
        use_skeleton=False
    ):
        super().__init__()
        
        self.sequence_length = sequence_length
        self.hidden_dim = hidden_dim
        
        # Processors
        self.phone_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.phone_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        self.watch_xyz_processor = XYZProcessor(hidden_dim, dropout)
        self.watch_smv_processor = SMVProcessor(hidden_dim, sequence_length, dropout)
        
        # Feature fusion
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LayerNorm(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim)
        )
        
        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * mlp_ratio,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True  # You can set this to False if nested tensors are required
        )
        
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth,
            norm=nn.LayerNorm(hidden_dim)
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def process_device_data(self, data):
        """Process data from one device with SMV calculation"""
        # Split XYZ and calculate SMV
        xyz_data = data[:, :, :3]  # [B, T, 3]
        smv_data = torch.norm(xyz_data, dim=2, keepdim=True)  # [B, T, 1]
        
        # Process XYZ coordinates
        xyz_data = rearrange(xyz_data, 'b t c -> b c t')
        xyz_features = self.phone_xyz_processor(xyz_data)  # [B, H, T/2]
        xyz_features = xyz_features.mean(dim=2)  # [B, H]
        
        # Process SMV signal
        smv_data = rearrange(smv_data, 'b t c -> b c t')
        smv_features, smv_threshold = self.phone_smv_processor(smv_data)
        smv_features = smv_features.mean(dim=2)  # [B, H]
        
        # Combine features
        device_features = torch.cat([xyz_features, smv_features], dim=1)  # [B, 2H]
        
        return device_features, smv_threshold

    def forward(self, data):
        """Forward pass with both classification and SMV features"""
        # Process phone data
        phone_features, phone_threshold = self.process_device_data(
            data['accelerometer_phone'].float()
        )
        
        # Process watch data
        watch_features, watch_threshold = self.process_device_data(
            data['accelerometer_watch'].float()
        )
        
        # Combine features
        combined = torch.cat([phone_features, watch_features], dim=1)
        fused = self.fusion(combined)
        
        # Temporal modeling
        temporal = fused.unsqueeze(1)
        temporal = self.transformer(temporal)
        
        # Classification
        pooled = temporal.mean(dim=1)
        logits = self.classifier(pooled)
        
        # Return both logits and SMV features
        smv_features = {
            'phone_smv': phone_threshold.squeeze(-1),
            'watch_smv': watch_threshold.squeeze(-1),
        }
        
        return logits, smv_features

# Define the Wrapper Module
class DualPathFallDetectorWrapper(nn.Module):
    def __init__(self, original_model):
        super(DualPathFallDetectorWrapper, self).__init__()
        self.original_model = original_model

    def forward(self, accelerometer_phone, accelerometer_watch):
        data = {
            'accelerometer_phone': accelerometer_phone,
            'accelerometer_watch': accelerometer_watch
        }
        logits, smv_features = self.original_model(data)
        return logits, smv_features

# ------------------------------
# 2. Load and Convert the PyTorch Model
# ------------------------------

# Initialize the model
model = DualPathFallDetector(
    acc_coords=4,
    sequence_length=128,
    hidden_dim=64,
    num_heads=8,
    depth=4,
    mlp_ratio=4,
    num_classes=2,
    dropout=0.3,
    use_skeleton=False
)

# Path to your checkpoint
model_path = "exps/smartfall_har/mobile_falldet/model_epoch_22_f1_0.9414.pth"  # Use forward slashes for cross-platform compatibility

# Option 1: Add safe globals and load checkpoint with weights_only=True
try:
    add_safe_globals([np.core.multiarray.scalar])
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    print("Checkpoint loaded successfully with weights_only=True.")
except AttributeError:
    print("add_safe_globals is not available in your PyTorch version. Please update PyTorch to >=2.1.0.")
    checkpoint = None
except Exception as e:
    print(f"Failed to load checkpoint with weights_only=True: {e}")
    print("Attempting to load without weights_only=True (security risk)...")
    try:
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        print("Checkpoint loaded successfully with weights_only=False.")
    except Exception as e2:
        print(f"Failed to load checkpoint with weights_only=False: {e2}")
        checkpoint = None

if checkpoint:
    # Load state_dict
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        print("State_dict loaded from checkpoint.")
    else:
        model.load_state_dict(checkpoint)
        print("State_dict loaded directly from checkpoint.")

    # Set the model to evaluation mode
    model.eval()

    # Wrap the model
    wrapped_model = DualPathFallDetectorWrapper(model).eval()

    # Prepare sample inputs as a tuple of tensors
    batch_size = 1
    sequence_length = 128
    channels_phone = 4  # Adjust based on your data
    channels_watch = 4  # Adjust based on your data

    sample_args = (
        torch.randn(batch_size, sequence_length, channels_phone),
        torch.randn(batch_size, sequence_length, channels_watch)
    )

    # Set PJRT_DEVICE to 'CPU' to address the CUDA-related RuntimeError
    os.environ['PJRT_DEVICE'] = 'CPU'

    # Convert the wrapped model to LiteRT
    try:
        edge_model = ai_edge_torch.convert(wrapped_model, sample_args)
        print("Model conversion successful!")
    except Exception as e:
        print(f"Model conversion failed: {e}")
else:
    print("Checkpoint loading failed. Conversion cannot proceed.")

# ------------------------------
# 3. Save the TFLite Model
# ------------------------------

# Assuming the conversion was successful, save the TFLite model
if 'edge_model' in locals():
    try:
        tflite_model_path = "mobile_falldet2.tflite"
        edge_model.export(tflite_model_path)
        print(f"TFLite model saved successfully at '{tflite_model_path}'.")
    except Exception as e:
        print(f"Failed to save TFLite model: {e}")
else:
    print("Edge model not found. Cannot export to TFLite.")

# ------------------------------
# 4. Perform Inference and Compare Outputs
# ------------------------------

# Function to run PyTorch inference
def run_pytorch_inference(model, phone_input, watch_input):
    """
    Runs inference on the PyTorch model.

    Args:
        model (nn.Module): The wrapped PyTorch model.
        phone_input (torch.Tensor): Tensor for accelerometer_phone.
        watch_input (torch.Tensor): Tensor for accelerometer_watch.

    Returns:
        tuple: logits as NumPy array and SMV features as a dictionary.
    """
    with torch.no_grad():
        logits, smv_features = model(phone_input, watch_input)
    return logits.numpy(), {k: v.numpy() for k, v in smv_features.items()}

# Function to run TFLite inference
def run_tflite_inference(tflite_model_path, phone_input, watch_input):
    """
    Runs inference on the TFLite model.

    Args:
        tflite_model_path (str): Path to the TFLite model.
        phone_input (np.ndarray): Input array for accelerometer_phone.
        watch_input (np.ndarray): Input array for accelerometer_watch.

    Returns:
        tuple: logits as NumPy array and SMV features as a dictionary.
    """
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()

    # Get input and output tensor details
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Ensure the inputs are in the correct order
    # Assuming 'serving_default_args_0:0' corresponds to 'accelerometer_phone'
    # and 'serving_default_args_1:0' corresponds to 'accelerometer_watch'
    interpreter.set_tensor(input_details[0]['index'], phone_input)
    interpreter.set_tensor(input_details[1]['index'], watch_input)

    # Run the inference
    interpreter.invoke()

    # Assign outputs based on name and shape
    logits = None
    smv_features_1 = None
    smv_features_2 = None

    for detail in output_details:
        tensor = interpreter.get_tensor(detail['index'])
        name = detail['name']
        shape = detail['shape']
        
        # Debugging prints
        print(f"Processing Output Tensor: {name}, Shape: {shape}, Data: {tensor}")

        # Correct shape comparison using np.array_equal
        if name == 'StatefulPartitionedCall:0' and np.array_equal(shape, [1, 2]):
            logits = tensor
            print(f"Assigned '{name}' to logits.")
        elif name == 'StatefulPartitionedCall:1' and np.array_equal(shape, [1]):
            smv_features_1 = tensor
            print(f"Assigned '{name}' to smv_features_1.")
        elif name == 'StatefulPartitionedCall:2' and np.array_equal(shape, [1]):
            smv_features_2 = tensor
            print(f"Assigned '{name}' to smv_features_2.")
        else:
            print(f"Unrecognized tensor: {name} with shape: {shape}")

    # Verify that logits have been correctly assigned
    if logits is None:
        raise ValueError("Logits tensor not found in the model outputs.")
    if smv_features_1 is None or smv_features_2 is None:
        print("Warning: One or more SMV features were not found in the model outputs.")

    return logits, {'smv_features_1': smv_features_1, 'smv_features_2': smv_features_2}

# ------------------------------
# 5. Generate Consistent Random Input Data
# ------------------------------

# For reproducibility, set the random seed
np.random.seed(42)
torch.manual_seed(42)

# Generate deterministic random input data
phone_input_np = np.random.rand(1, 128, 4).astype(np.float32)    # Shape: [1, 128, 4]
watch_input_np = np.random.rand(1, 128, 4).astype(np.float32)   # Shape: [1, 128, 4]

phone_input_torch = torch.from_numpy(phone_input_np)
watch_input_torch = torch.from_numpy(watch_input_np)

# ------------------------------
# 6. Run Inference on Both Models
# ------------------------------

# Run PyTorch inference
if 'wrapped_model' in locals():
    pytorch_logits, pytorch_smv = run_pytorch_inference(wrapped_model, phone_input_torch, watch_input_torch)
    print("\nPyTorch Model Outputs:")
    print("Logits:", pytorch_logits)
    print("SMV Features:", pytorch_smv)
else:
    print("Wrapped model not found. Skipping PyTorch inference.")

# Run TFLite inference
if os.path.exists("mobile_falldet2.tflite"):
    tflite_logits, tflite_smv = run_tflite_inference("mobile_falldet2.tflite", phone_input_np, watch_input_np)
    print("\nTFLite Model Outputs:")
    print("Logits:", tflite_logits)
    print("SMV Features:", tflite_smv)
else:
    print("TFLite model file 'mobile_falldet2.tflite' not found. Skipping TFLite inference.")

# ------------------------------
# 7. Compare the Outputs
# ------------------------------

if 'pytorch_logits' in locals() and 'tflite_logits' in locals():
    # Compare logits
    logits_diff = np.abs(pytorch_logits - tflite_logits)
    logits_mse = np.mean((pytorch_logits - tflite_logits) ** 2)
    logits_mae = np.mean(logits_diff)
    
    print("\nLogits Comparison:")
    print("Absolute Differences:", logits_diff)
    print(f"Mean Squared Error (MSE): {logits_mse}")
    print(f"Mean Absolute Error (MAE): {logits_mae}")

    # Compare SMV features
    for key in pytorch_smv:
        pytorch_smv_feat = pytorch_smv[key]
        tflite_smv_feat = tflite_smv.get(key, None)
        
        if tflite_smv_feat is not None:
            smv_diff = np.abs(pytorch_smv_feat - tflite_smv_feat)
            smv_mse = np.mean((pytorch_smv_feat - tflite_smv_feat) ** 2)
            smv_mae = np.mean(smv_diff)
            
            print(f"\nSMV Feature '{key}' Comparison:")
            print("Absolute Differences:", smv_diff)
            print(f"Mean Squared Error (MSE): {smv_mse}")
            print(f"Mean Absolute Error (MAE): {smv_mae}")
        else:
            print(f"\nSMV Feature '{key}' not found in TFLite outputs.")
else:
    print("Insufficient outputs to perform comparison.")


2024-11-20 21:30:06.809417: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732159806.835889  277252 cuda_dnn.cc:8498] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732159806.844495  277252 cuda_blas.cc:1410] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-20 21:30:06.892289: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  add_safe_globals([np.core.multiarray.scalar])


Failed to load checkpoint with weights_only=True: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray.scalar was not an allowed global by default. Please use `torch.serialization.add_safe_globals([scalar])` to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
Attempting to load without weights_only=True (security risk)...
Checkpoint loaded su

I0000 00:00:1732159823.564281  277252 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4057 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2060 with Max-Q Design, pci bus id: 0000:01:00.0                                                   , compute capability: 7.5


INFO:tensorflow:Assets written to: /tmp/tmpay1u275n/assets


INFO:tensorflow:Assets written to: /tmp/tmpay1u275n/assets
W0000 00:00:1732159826.296060  277252 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1732159826.296149  277252 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2024-11-20 21:30:26.296846: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpay1u275n
2024-11-20 21:30:26.298441: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2024-11-20 21:30:26.298464: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpay1u275n
I0000 00:00:1732159826.314439  277252 mlir_graph_optimization_pass.cc:402] MLIR V1 optimization pass is not enabled
2024-11-20 21:30:26.316565: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2024-11-20 21:30:26.414157: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpay1u275n
2024-11-20 21:30:26.441

Model conversion successful!
TFLite model saved successfully at 'mobile_falldet2.tflite'.

PyTorch Model Outputs:
Logits: [[ 1.4914067 -1.5174068]]
SMV Features: {'phone_smv': array([0.01581427], dtype=float32), 'watch_smv': array([0.01431987], dtype=float32)}
Processing Output Tensor: StatefulPartitionedCall:1, Shape: [1], Data: [0.01581427]
Assigned 'StatefulPartitionedCall:1' to smv_features_1.
Processing Output Tensor: StatefulPartitionedCall:0, Shape: [1 2], Data: [[ 1.4914067 -1.517407 ]]
Assigned 'StatefulPartitionedCall:0' to logits.
Processing Output Tensor: StatefulPartitionedCall:2, Shape: [1], Data: [0.01431986]
Assigned 'StatefulPartitionedCall:2' to smv_features_2.

TFLite Model Outputs:
Logits: [[ 1.4914067 -1.517407 ]]
SMV Features: {'smv_features_1': array([0.01581427], dtype=float32), 'smv_features_2': array([0.01431986], dtype=float32)}

Logits Comparison:
Absolute Differences: [[0.0000000e+00 1.1920929e-07]]
Mean Squared Error (MSE): 7.105427357601002e-15
Mean Absol

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
