In [1]:
# 1) Uninstall any existing AI Edge packages
!pip uninstall -y ai-edge-torch ai-edge-torch-nightly ai-edge-quantizer ai-edge-quantizer-nightly

# 2) Re-install the newest nightly packages together
!pip install --upgrade --pre ai-edge-torch-nightly ai-edge-quantizer-nightly


[0mFound existing installation: ai-edge-torch-nightly 0.3.0.dev20250218
Uninstalling ai-edge-torch-nightly-0.3.0.dev20250218:
  Successfully uninstalled ai-edge-torch-nightly-0.3.0.dev20250218
[0mFound existing installation: ai-edge-quantizer-nightly 0.0.1.dev20250218
Uninstalling ai-edge-quantizer-nightly-0.0.1.dev20250218:
  Successfully uninstalled ai-edge-quantizer-nightly-0.0.1.dev20250218
Collecting ai-edge-torch-nightly
  Using cached ai_edge_torch_nightly-0.3.0.dev20250218-py3-none-any.whl (381 kB)
Collecting ai-edge-quantizer-nightly
  Using cached ai_edge_quantizer_nightly-0.0.1.dev20250218-py3-none-any.whl (146 kB)
Installing collected packages: ai-edge-quantizer-nightly, ai-edge-torch-nightly
Successfully installed ai-edge-quantizer-nightly-0.0.1.dev20250218 ai-edge-torch-nightly-0.3.0.dev20250218


In [2]:
old_sd = torch.load("Fold3_NoDistill_best_loss_weights.pth", map_location="cpu",weights_only=True)
for k in old_sd.keys():
    print(k)


time2vec.lin_weight
time2vec.lin_bias
time2vec.per_weight
time2vec.per_bias
input_proj.weight
input_proj.bias
encoder.layers.0.self_attn.in_proj_weight
encoder.layers.0.self_attn.in_proj_bias
encoder.layers.0.self_attn.out_proj.weight
encoder.layers.0.self_attn.out_proj.bias
encoder.layers.0.linear1.weight
encoder.layers.0.linear1.bias
encoder.layers.0.linear2.weight
encoder.layers.0.linear2.bias
encoder.layers.0.norm1.weight
encoder.layers.0.norm1.bias
encoder.layers.0.norm2.weight
encoder.layers.0.norm2.bias
encoder.layers.1.self_attn.in_proj_weight
encoder.layers.1.self_attn.in_proj_bias
encoder.layers.1.self_attn.out_proj.weight
encoder.layers.1.self_attn.out_proj.bias
encoder.layers.1.linear1.weight
encoder.layers.1.linear1.bias
encoder.layers.1.linear2.weight
encoder.layers.1.linear2.bias
encoder.layers.1.norm1.weight
encoder.layers.1.norm1.bias
encoder.layers.1.norm2.weight
encoder.layers.1.norm2.bias
encoder.layers.2.self_attn.in_proj_weight
encoder.layers.2.self_attn.in_proj_b

In [3]:
import ai_edge_torch
print("ai_edge_torch version:", ai_edge_torch.__version__)
help(ai_edge_torch.signature)


ai_edge_torch version: 0.3.0.dev20250218
Help on function signature in module ai_edge_torch._convert.converter:

signature(name: 'str', module: 'torch.nn.Module', sample_args=None, sample_kwargs=None, dynamic_shapes: 'Optional[Union[dict[str, Any], Tuple[Any, ...]]]' = None) -> 'Converter'
    Initiates a Converter object with the provided signature.
    
    Args:
      name: The name of the signature included in the converted edge model.
      module: The torch module to be converted.
      sample_args: Tuple of tensors by which the torch module will be traced with
        prior to conversion.
      sample_kwargs: Dict of str to tensor by which the torch module will be
        traced with prior to conversion.
      dynamic_shapes: Optional dict or tuple that specify dynamic shape
        specifications for each input in original order. See
        https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
          details.
    
    Returns:
      A Converter object wit

In [6]:
#!/usr/bin/env python
# coding: utf-8

##############################
# 1) Imports
##############################
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# AI Edge Torch (make sure you're using version >= 0.3.x nightly)
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

##############################
# 2) Time2Vec Definition
##############################
class Time2Vec(nn.Module):
    """
    This matches the submodule named in your checkpoint:
      time2vec.lin_weight
      time2vec.lin_bias
      time2vec.per_weight
      time2vec.per_bias
    """
    def __init__(self, out_channels=8):
        super().__init__()
        self.out_channels = out_channels
        self.lin_weight = nn.Parameter(torch.randn(1))
        self.lin_bias   = nn.Parameter(torch.randn(1))
        if out_channels > 1:
            self.per_weight = nn.Parameter(torch.randn(out_channels - 1))
            self.per_bias   = nn.Parameter(torch.randn(out_channels - 1))
        else:
            self.per_weight = None
            self.per_bias   = None

    def forward(self, t):
        # shape of t: (N, 1), e.g. flatten (B*T, 1)
        t_lin = self.lin_weight * t + self.lin_bias
        if self.per_weight is not None:
            alpha = self.per_weight.unsqueeze(0)  # (1, out_channels-1)
            beta  = self.per_bias.unsqueeze(0)    # (1, out_channels-1)
            t_per = torch.sin(alpha * t + beta)
            return torch.cat([t_lin, t_per], dim=-1)
        else:
            return t_lin

##############################
# 3) Main Model Definition
##############################
class FallTime2VecTransformer(nn.Module):
    """
    This matches your original layering structure and parameter naming:
      - time2vec.* for Time2Vec
      - input_proj (linear)
      - encoder.layers.(0..2) => each has
          self_attn.in_proj_weight, self_attn.in_proj_bias,
          self_attn.out_proj.weight, self_attn.out_proj.bias,
          linear1.weight, linear1.bias, linear2.weight, linear2.bias,
          norm1.weight, norm1.bias, norm2.weight, norm2.bias
      - fc.weight, fc.bias
    """
    def __init__(self,
                 feat_dim=19,        # e.g. 3 accel channels + 16 Time2Vec => 19
                 d_model=64,
                 nhead=4,
                 num_layers=3,
                 num_classes=2,
                 time2vec_dim=16,
                 dropout=0.1,
                 dim_feedforward=128):
        super().__init__()
        self.feat_dim = feat_dim
        self.time2vec_dim = time2vec_dim

        # 1) The same time2vec submodule
        self.time2vec = Time2Vec(out_channels=time2vec_dim)

        # 2) Project input => d_model
        self.input_proj = nn.Linear(feat_dim, d_model)

        # 3) A standard PyTorch TransformerEncoder with `num_layers`.
        #    Each layer's internal submodules map exactly to the checkpoint keys:
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # 4) Final classification
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, accel_xyz, accel_mask, accel_time):
        """
        accel_xyz:  (B, T, 3)  => raw accelerometer channels
        accel_mask: (B, T) bool => True means 'PAD' => used by Transformer
        accel_time: (B, T) => raw time indices
        """
        B, T, _ = accel_xyz.shape

        # Flatten time => pass through time2vec => reshape
        time_flat = accel_time.reshape(-1, 1)         # (B*T, 1)
        t_emb_flat = self.time2vec(time_flat)         # => (B*T, time2vec_dim)
        t_emb = t_emb_flat.view(B, T, self.time2vec_dim)

        # Concat => shape => (B, T, feat_dim)
        x = torch.cat([accel_xyz, t_emb], dim=-1)     # => 3 + time2vec_dim

        # Project
        x_proj = self.input_proj(x)                   # => (B, T, d_model)

        # Pass to Transformer
        out_seq = self.encoder(x_proj, src_key_padding_mask=accel_mask)

        # Global average pool
        feat = out_seq.mean(dim=1)

        # Final linear
        logits = self.fc(feat)
        return logits

##############################
# 4) Load Weights
##############################
def load_weights(model, ckpt_path):
    """
    Loads your original .pth. The checkpoint keys, for example, are:

      time2vec.lin_weight
      time2vec.lin_bias
      time2vec.per_weight
      ...
      encoder.layers.0.self_attn.in_proj_weight
      ...
      fc.weight
      fc.bias

    If no prefix, we can load directly:
    """
    ckpt = torch.load(ckpt_path, map_location='cpu')

    # If there's no mismatch (like "module." prefix), just load:
    model.load_state_dict(ckpt, strict=True)
    print(f"[INFO] Loaded weights from: {ckpt_path}")

##############################
# 5) Convert => TFLite
##############################
def convert_to_tflite(model,
                      tflite_path="fall_time2vec_transformer.tflite",
                      quantize=False):
    model.eval()

    # Example dummy input with (B=1, T=20)
    B = 1
    T = 20
    dummy_xyz  = torch.randn(B, T, 3, dtype=torch.float32)
    dummy_mask = torch.zeros(B, T, dtype=torch.bool)   # no pad
    dummy_time = torch.arange(T).unsqueeze(0).float()  # shape (1,T)

    # AI Edge Torch new signature uses sample_args= for input(s):
    converter = ai_edge_torch.signature(
        name="inference",
        module=model,
        sample_args=(dummy_xyz, dummy_mask, dummy_time),
    )

    quant_config = None
    if quantize:
        quant_config = quant_recipes.full_int8_weight_only_recipe()

    # Convert to TFLite
    tflite_model = converter.convert(quant_config=quant_config)
    tflite_model.export(tflite_path)
    print(f"[INFO] Exported TFLite model => {tflite_path}")

##############################
# 6) Main Orchestrator
##############################
if __name__ == "__main__":
    CKPT_PATH = "Fold3_NoDistill_best_loss_weights.pth"
    OUTPUT_TFLITE_PATH = "fall_time2vec_transformer.tflite"

    # Build model with same shapes used in your training
    model = FallTime2VecTransformer(
        feat_dim=19,      # e.g. 3 + 16
        d_model=64,
        nhead=4,
        num_layers=3,     # as your checkpoint uses layers.0, .1, .2
        num_classes=2,
        time2vec_dim=16,  # must match the time2vec block dimension
        dropout=0.1,
        dim_feedforward=128
    )

    # Load your trained checkpoint
    load_weights(model, CKPT_PATH)

    # Convert to TFLite
    convert_to_tflite(
        model,
        tflite_path=OUTPUT_TFLITE_PATH,
        quantize=False   # or True if you want int8 weight-only
    )


  ckpt = torch.load(ckpt_path, map_location='cpu')


[INFO] Loaded weights from: Fold3_NoDistill_best_loss_weights.pth


I0000 00:00:1739878493.214925   95114 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 4057 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2060 with Max-Q Design, pci bus id: 0000:01:00.0, compute capability: 7.5


INFO:tensorflow:Assets written to: /tmp/tmpextvdr50/assets


INFO:tensorflow:Assets written to: /tmp/tmpextvdr50/assets


[INFO] Exported TFLite model => fall_time2vec_transformer.tflite


W0000 00:00:1739878494.673421   95114 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1739878494.673463   95114 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-02-18 05:34:54.674993: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /tmp/tmpextvdr50
2025-02-18 05:34:54.675949: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-02-18 05:34:54.675962: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /tmp/tmpextvdr50
I0000 00:00:1739878494.683861   95114 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled
2025-02-18 05:34:54.685245: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-02-18 05:34:54.751366: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /tmp/tmpextvdr50
2025-02-18 05:34:54.765003: I tensorflow/cc/saved_model/loader.cc:471] SavedModel 