In [4]:
pip install ai-edge-torch


Note: you may need to restart the kernel to use updated packages.


In [7]:
# Cell 1: Model Definition (time2VecStudent.py)
import torch
import torch.nn as nn
import torch.nn.functional as F

class Time2Vec(nn.Module):
    """
    Minimal Time2Vec embedding.

    Maps a time tensor ``t`` of shape (N, 1) to an embedding of shape
    (N, out_channels):
      * channel 0 is the linear term ``lin_weight * t + lin_bias``;
      * channels 1..out_channels-1 are periodic terms
        ``sin(per_weight[i] * t + per_bias[i])``.
    When ``out_channels`` is 1 only the linear term is produced and the
    periodic parameters are ``None``.
    """
    def __init__(self, out_channels=8):
        super().__init__()
        self.out_channels = out_channels
        # Linear (trend) component: one scalar weight and one scalar bias.
        self.lin_weight = nn.Parameter(torch.randn(1))
        self.lin_bias   = nn.Parameter(torch.randn(1))
        # Periodic components: one frequency/phase pair per remaining channel.
        # Drawn after the linear terms so seeded initialisation is stable.
        n_periodic = out_channels - 1
        has_periodic = n_periodic > 0
        self.per_weight = nn.Parameter(torch.randn(n_periodic)) if has_periodic else None
        self.per_bias   = nn.Parameter(torch.randn(n_periodic)) if has_periodic else None

    def forward(self, t):
        # t: (N, 1) -> linear term of shape (N, 1)
        linear = self.lin_weight * t + self.lin_bias
        if self.per_weight is None:
            return linear
        # (1, K) parameters broadcast against (N, 1) time -> (N, K)
        periodic = torch.sin(self.per_weight[None, :] * t + self.per_bias[None, :])
        return torch.cat((linear, periodic), dim=-1)

class FallTime2VecTransformer(nn.Module):
    """
    A simple Transformer-based classifier over accelerometer windows.

    Takes a (B, T, 3) accelerometer input (x, y, z inertial data) plus an
    optional padding mask (B, T) and optional time tensor (B, T). A Time2Vec
    embedding of the time axis (``time2vec_dim`` channels, 8 by default) is
    concatenated with the accelerometer channels (3 + 8 = 11 = ``feat_dim`` by
    default), projected to ``d_model``, passed through a TransformerEncoder,
    averaged over the non-padded time steps, and mapped to (B, num_classes)
    logits.
    """

    def __init__(self,
                 feat_dim=11,      # 3 (accel) + 8 (time2vec) = 11
                 d_model=64,
                 nhead=4,
                 num_layers=2,
                 num_classes=2,
                 time2vec_dim=8,   # 8-D Time2Vec embedding
                 dropout=0.1,
                 dim_feedforward=128):
        super().__init__()
        self.feat_dim = feat_dim
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.time2vec_dim = time2vec_dim
        self.dropout = dropout
        self.dim_feedforward = dim_feedforward

        # 1) Time2Vec for the time axis
        self.time2vec = Time2Vec(out_channels=time2vec_dim)

        # 2) Input projection from feat_dim -> d_model
        self.input_proj = nn.Linear(feat_dim, d_model)

        # 3) Transformer Encoder
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True  # expecting input (B, T, d_model)
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # 4) Output layer for classification
        self.fc = nn.Linear(d_model, self.num_classes)

    def forward(self, accel_seq, accel_mask=None, accel_time=None):
        """
        Classify a batch of accelerometer windows.

        Args:
            accel_seq: (B, T, C) accelerometer data (C = 3 for x, y, z with the
                default ``feat_dim``); ``C + time2vec_dim`` must equal ``feat_dim``.
            accel_mask: optional (B, T) bool mask where True marks padding. It is
                passed to the encoder as ``src_key_padding_mask`` and padded
                steps are excluded from the temporal average.
            accel_time: optional (B, T) time values; if None, 0..T-1 is used.

        Returns:
            (B, num_classes) logits.

        Raises:
            ValueError: if ``C + time2vec_dim != feat_dim``.
        """
        B, T, C = accel_seq.shape
        if C + self.time2vec_dim != self.feat_dim:
            raise ValueError(
                f"accel_seq has {C} channels; {C} + time2vec_dim ({self.time2vec_dim}) "
                f"must equal feat_dim ({self.feat_dim})"
            )

        # 1) Create dummy time indices 0..T-1 if no timestamps are provided.
        if accel_time is None:
            time_idx = torch.arange(T, device=accel_seq.device).unsqueeze(0).expand(B, T).float()
        else:
            time_idx = accel_time

        # 2) Time2Vec on flattened time: (B*T, 1) -> (B*T, time2vec_dim) -> (B, T, time2vec_dim)
        time_flat = time_idx.reshape(-1, 1)
        t_emb = self.time2vec(time_flat).view(B, T, self.time2vec_dim)

        # 3) Concatenate accelerometer data and Time2Vec embedding => (B, T, feat_dim)
        x = torch.cat([accel_seq, t_emb], dim=-1)

        # 4) Project to d_model dimension => (B, T, d_model)
        x_proj = self.input_proj(x)

        # 5) Process through Transformer encoder => (B, T, d_model)
        out = self.encoder(x_proj, src_key_padding_mask=accel_mask)

        # 6) Average over time => (B, d_model). With a mask, only real steps are
        #    averaged; otherwise the padding length would dilute the result.
        if accel_mask is None:
            pooled = out.mean(dim=1)
        else:
            keep = accel_mask.to(torch.bool).logical_not().unsqueeze(-1).to(out.dtype)  # (B, T, 1)
            # clamp avoids 0/0 for a fully padded row.
            pooled = (out * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

        # 7) Final linear layer => (B, num_classes)
        logits = self.fc(pooled)
        return logits
# `os` is imported here so this cell also runs on a fresh kernel; otherwise it
# would depend on an import that only happens in a later cell.
# NOTE(review): PJRT_DEVICE is presumably read by torch_xla to pick its runtime
# device — confirm it is needed for the ai_edge_torch conversion.
import os
os.environ['PJRT_DEVICE'] = 'CPU'

In [12]:
import os
import torch
import ai_edge_torch
import torch.nn.functional as F

# Set USE_TORCH_XLA=0 (intended to disable Torch XLA).
# NOTE(review): the original comment said this "forces CUDA usage"; nothing here
# shows that. It is also unclear whether ai_edge_torch reads this variable, and
# whether it does so at import time — ai_edge_torch is imported above, before
# this line runs, so the setting may have no effect. Confirm.
os.environ["USE_TORCH_XLA"] = "0"

# Updated fallback implementation for scaled dot product attention.
def fallback_scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    if attn_mask is not None:
        # Ensure that the attention mask is a boolean tensor.
        if attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(torch.bool)
        scores = scores.masked_fill(attn_mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        attn = torch.nn.functional.dropout(attn, p=dropout_p)
    return torch.matmul(attn, v)

# Override the efficient attention op with our fallback.
F.scaled_dot_product_attention = fallback_scaled_dot_product_attention

def convert_to_tflite(output_path="fall_time2vec_transformer.tflite", model=None,
                      seq_len=10, batch_size=1):
    """
    Convert a FallTime2VecTransformer to a float32 TFLite flatbuffer and save it.

    The converter traces the model with the sample inputs built here. The
    exported model therefore presumably accepts only inputs of exactly
    (batch_size, seq_len, 3). NOTE(review): confirm this against the
    ai_edge_torch version in use, and set ``seq_len`` to the deployment window.

    Args:
        output_path: destination of the ``.tflite`` file.
        model: module to convert; it is moved to the chosen device and put in
            eval mode. Defaults to a freshly initialised
            ``FallTime2VecTransformer`` (random, untrained weights).
        seq_len: number of time steps T in the sample input.
        batch_size: batch size B in the sample input.

    Returns:
        ``output_path``.

    Raises:
        ValueError: if ``seq_len`` or ``batch_size`` is < 1.
    """
    if seq_len < 1 or batch_size < 1:
        raise ValueError(f"seq_len and batch_size must be >= 1, got {seq_len}, {batch_size}")

    # Use CUDA if available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 1. Instantiate the model (if none was given), move it to the device, and
    #    set it to evaluation mode.
    if model is None:
        from __main__ import FallTime2VecTransformer  # Ensure model definition cell is run
        model = FallTime2VecTransformer()
    model = model.to(device).eval()

    # 2. Sample inputs:
    #    - accel_seq:  (B, T, 3) float32 accelerometer data (x, y, z)
    #    - accel_mask: (B, T) bool, all False (no padding)
    #    - accel_time: (B, T) float32 time indices 0..T-1, passed as a tensor
    #      rather than None
    num_channels = 3
    accel_seq = torch.randn(batch_size, seq_len, num_channels, dtype=torch.float32, device=device)
    accel_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device)
    accel_time = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, seq_len).float()
    sample_inputs = (accel_seq, accel_mask, accel_time)

    # 3. Convert with AI Edge Torch (single-signature, float32 TFLite model).
    edge_model = ai_edge_torch.convert(model, sample_inputs)

    # 4. Export the converted model as a TFLite flatbuffer file.
    edge_model.export(output_path)
    print(f"TFLite model exported as {output_path}")
    return output_path

# Run the conversion; keep the path in a named variable instead of echoing it.
tflite_path = convert_to_tflite()


INFO:tensorflow:Assets written to: /tmp/tmpn5qvbumu/assets


INFO:tensorflow:Assets written to: /tmp/tmpn5qvbumu/assets


TFLite model exported as fall_time2vec_transformer.tflite


W0000 00:00:1739345238.951233    1726 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1739345238.951447    1726 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-02-12 01:27:18.953104: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpn5qvbumu
2025-02-12 01:27:18.954464: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-02-12 01:27:18.954491: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpn5qvbumu
I0000 00:00:1739345238.962546    1726 mlir_graph_optimization_pass.cc:402] MLIR V1 optimization pass is not enabled
2025-02-12 01:27:18.963641: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-02-12 01:27:19.038930: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpn5qvbumu
2025-02-12 01:27:19.053206: I tensorflow/cc/saved_model/loader.cc:466] SavedModel 

In [16]:
# Cell: Detailed Comparison of PyTorch and TFLite Models

import torch
import os
import tensorflow as tf


# 1. Inspect the original PyTorch model.
def inspect_pytorch_model(model=None):
    """
    Print the architecture, parameter count, memory footprint and
    per-parameter details of a PyTorch model.

    Args:
        model: the module to inspect. If None (the default), a freshly
            initialised ``FallTime2VecTransformer`` in eval mode is used. Its
            weights are random and unrelated to any previously converted model.

    Returns:
        The inspected model.
    """
    if model is None:
        model = FallTime2VecTransformer().eval()
    print("=== PyTorch Model Architecture ===")
    print(model)

    params = list(model.parameters())

    # Count total parameters.
    total_params = sum(p.numel() for p in params)
    print("\nTotal number of parameters: {:,}".format(total_params))

    # Size each tensor by its real element size instead of assuming 4 bytes,
    # so half-precision or other non-float32 models are reported correctly.
    mem_footprint_bytes = sum(p.numel() * p.element_size() for p in params)
    dtype_names = sorted({str(p.dtype).replace("torch.", "") for p in params})
    dtype_label = ", ".join(dtype_names) if dtype_names else "no parameters"
    print("Estimated memory footprint ({}): {:.2f} MB".format(dtype_label, mem_footprint_bytes / (1024 ** 2)))

    # List each parameter (name, shape, dtype)
    print("\n--- Parameter Details ---")
    for name, param in model.named_parameters():
        shape_str = str(tuple(param.shape))
        print(f"{name:40s} | shape: {shape_str:15s} | dtype: {param.dtype}")

    return model

# 2. Inspect the TFLite model.
def inspect_tflite_model(tflite_file):
    """
    Print the file size, input/output signature and every tensor of a
    TFLite flatbuffer.

    Args:
        tflite_file: path to the ``.tflite`` file.

    Returns:
        The ``tf.lite.Interpreter`` for the file, with tensors allocated.
    """
    print("\n=== TFLite Model Details ===")
    size_kb = os.path.getsize(tflite_file) / 1024
    print(f"TFLite model file size: {size_kb:.2f} KB")

    interpreter = tf.lite.Interpreter(model_path=tflite_file)
    interpreter.allocate_tensors()

    # Each entry: (section header, detail getter, how to read a tensor's name).
    # Input/output details index 'name' directly; the full tensor list falls
    # back to 'N/A' when a name is absent.
    sections = [
        ("\n--- Input Tensor Details ---", interpreter.get_input_details, lambda d: d['name']),
        ("\n--- Output Tensor Details ---", interpreter.get_output_details, lambda d: d['name']),
        ("\n--- All TFLite Tensor Details ---", interpreter.get_tensor_details,
         lambda d: d.get('name', 'N/A')),
    ]
    for header, get_details, name_of in sections:
        entries = get_details()
        print(header)
        for entry in entries:
            print(f"Name: {name_of(entry)}, shape: {entry['shape']}, dtype: {entry['dtype']}")

    return interpreter

# 3. Run comparisons.
print("***** Inspecting PyTorch Model *****")
pt_model = inspect_pytorch_model()

tflite_filename = "fall_time2vec_transformer.tflite"
# Reset first so a stale `interpreter` left in the kernel by an earlier run is
# never summarised after the file has been removed.
interpreter = None
if os.path.exists(tflite_filename):
    print("\n***** Inspecting TFLite Model *****")
    interpreter = inspect_tflite_model(tflite_filename)
else:
    print(f"\nTFLite file '{tflite_filename}' not found. Run the conversion script first.")

# 4. Summary of Precision and Comparison — derived from the inspected models
#    rather than printed as fixed claims (the exported mask input, for
#    example, is boolean, not float32).
print("\n***** Comparison Summary *****")
pt_dtypes = sorted({str(p.dtype) for p in pt_model.parameters()})
print(f"PyTorch parameter dtypes: {', '.join(pt_dtypes)}")

if interpreter is None:
    print("TFLite model not inspected; skipping TFLite dtype comparison.")
else:
    io_details = interpreter.get_input_details() + interpreter.get_output_details()
    for detail in io_details:
        print(f"TFLite I/O tensor {detail['name']}: {tf.as_dtype(detail['dtype']).name}")
    non_float32 = [d['name'] for d in io_details if tf.as_dtype(d['dtype']) != tf.float32]
    if non_float32:
        print(f"Non-float32 TFLite I/O tensors (e.g. boolean masks): {non_float32}")
    else:
        print("All TFLite input/output tensors are float32.")
print("Differences in internal tensor names and organization are expected due to the conversion process.")


***** Inspecting PyTorch Model *****
=== PyTorch Model Architecture ===
FallTime2VecTransformer(
  (time2vec): Time2Vec()
  (input_proj): Linear(in_features=11, out_features=64, bias=True)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc): Linear(in_features=64, out_features=2, bias=True)
)

Total number of parameters: 67,858
Estimated memory

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
