In [4]:
!pip install ai-edge-torch-nightly

Collecting ai-edge-torch-nightly
  Downloading ai_edge_torch_nightly-0.3.0.dev20250120-py3-none-any.whl (345 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m345.2/345.2 KB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting jax
  Downloading jax-0.5.0-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
Collecting ai-edge-litert-nightly
  Downloading ai_edge_litert_nightly-1.0.1.dev20250119-cp310-cp310-manylinux_2_17_x86_64.whl (2.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting tf-nightly>=2.19.0.dev20241201
  Downloading tf_nightly-2.19.0.dev20250118-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (641.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m641.6/641.6 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00

In [1]:
pip install --upgrade tensorflow==2.12.0


[0mCollecting tensorflow==2.12.0
  Using cached tensorflow-2.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (585.9 MB)
Collecting keras<2.13,>=2.12.0
  Using cached keras-2.12.0-py2.py3-none-any.whl (1.7 MB)
Collecting tensorboard<2.13,>=2.12
  Using cached tensorboard-2.12.3-py3-none-any.whl (5.6 MB)
Collecting gast<=0.4.0,>=0.2.1
  Using cached gast-0.4.0-py3-none-any.whl (9.8 kB)
Collecting jax>=0.3.15
  Using cached jax-0.4.38-py3-none-any.whl (2.2 MB)
Collecting jaxlib<=0.4.38,>=0.4.38
  Using cached jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl (101.7 MB)
Collecting jax>=0.3.15
  Using cached jax-0.4.37-py3-none-any.whl (2.2 MB)
Collecting jaxlib<=0.4.37,>=0.4.36
  Using cached jaxlib-0.4.36-cp310-cp310-manylinux2014_x86_64.whl (100.3 MB)
Collecting jax>=0.3.15
  Using cached jax-0.4.36-py3-none-any.whl (2.2 MB)
  Using cached jax-0.4.35-py3-none-any.whl (2.2 MB)
Collecting jaxlib<=0.4.35,>=0.4.34
  Using cached jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_6

In [6]:
import os
os.environ['PJRT_DEVICE'] = 'CPU'
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # Optionally disable CUDA
!export USE_TORCH_XLA=0
!export PJRT_DEVICE=CPU
!export CUDA_VISIBLE_DEVICES=""


In [8]:
import os
import warnings
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import ai_edge_torch

# Suppress user warnings from the Transformer module (if any)
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.transformer")

##############################################################################
# 1) Define StudentModel returning a single probability (via sigmoid)
##############################################################################
class PrecisionBlock(nn.Module):
    """
    Convolution-based residual block for capturing detailed temporal dynamics.
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                               padding=kernel_size // 2)
        self.bn1 = nn.BatchNorm1d(out_channels)
        
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                               padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Weighted residual connection
        self.res_weight = nn.Parameter(torch.ones(1))
        self.shortcut = (
            nn.Conv1d(in_channels, out_channels, 1)
            if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x):
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.res_weight * identity
        return F.relu(out)


class TemporalAttention(nn.Module):
    """
    Multi-scale temporal attention for highlighting critical fall segments.
    """
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels * 2, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, t = x.size()
        avg_pooled = self.avg_pool(x).view(b, c)
        max_pooled = self.max_pool(x).view(b, c)

        combined = torch.cat([avg_pooled, max_pooled], dim=1)
        scale = self.fc(combined).view(b, c, 1)
        return x * scale.expand_as(x)


class StudentModel(nn.Module):
    """
    Student model: Single-modality (watch accelerometer).
    Incorporates magnitude computation in place of (x, y, z) => (x, y, z, magnitude).
    Uses a stack of PrecisionBlocks + TemporalAttention to capture fall patterns.

    IMPORTANT CHANGE:
      Now returns *only* a single probability from sigmoid, rather than (prob, feat).
    """
    def __init__(self,
                 input_channels=4,  # x, y, z, + magnitude
                 hidden_dim=48,
                 num_blocks=4,
                 dropout_rate=0.2):
        super().__init__()
        self.input_proj = nn.Sequential(
            nn.Linear(input_channels, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5)
        )

        # Stacked temporal blocks
        self.temporal_blocks = nn.ModuleList([
            PrecisionBlock(hidden_dim, hidden_dim, kernel_size=(2*i + 3))
            for i in range(num_blocks)
        ])

        self.attention = TemporalAttention(channels=hidden_dim)

        # Classification head
        self.fall_confidence = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        x: [B, T, 3] => raw watch accelerometer data (x, y, z)
        Returns:
            output: [B] => final probability in [0, 1]
        """
        # 1) Compute magnitude => shape [B, T, 1]
        magnitude = torch.sqrt(torch.sum(x**2, dim=-1, keepdim=True))
        x = torch.cat([x, magnitude], dim=-1)  # => [B, T, 4]

        # 2) Project => [B, T, hidden_dim] => [B, hidden_dim, T]
        x = self.input_proj(x)
        x = x.transpose(1, 2)

        # 3) Pass through multiple blocks + attention
        for block in self.temporal_blocks:
            x = block(x)
            x = self.attention(x)

        # 4) Global average pool => [B, hidden_dim]
        student_feat = F.adaptive_avg_pool1d(x, 1).squeeze(-1)

        # 5) Final linear => shape [B], then sigmoid => [B]
        student_logits = self.fall_confidence(student_feat).squeeze(-1)
        output_prob = self.sigmoid(student_logits)

        return output_prob


##############################################################################
# 2) Instantiate Model & Load Weights (No Retraining Needed)
##############################################################################
if __name__ == "__main__":

    # Instantiate your model
    model = StudentModel(
        input_channels=4,
        hidden_dim=48,
        num_blocks=4,
        dropout_rate=0.2
    ).eval()

    # (Optional) Load existing checkpoint
    # If you have a trained checkpoint, do:
    checkpoint = torch.load("student_checkpoint.pth", map_location="cpu")
    model.load_state_dict(checkpoint)
    model.eval()

    ##############################################################################
    # 3) Convert with AI Edge Torch
    ##############################################################################
    # Example input shape: [B=1, T=128, channels=3], but the model internally appends magnitude => 4
    sample_input = torch.randn(1, 128, 3)

    # Baseline PyTorch inference
    with torch.no_grad():
        pt_output = model(sample_input).numpy()  # shape [1]

    # Convert to LiteRT
    edge_model = ai_edge_torch.convert(model.eval(), (sample_input,))

    # LiteRT inference
    edge_output = edge_model(sample_input)  # shape [1]

    # Compare
    print("PyTorch prob:", pt_output)
    print("LiteRT prob:", edge_output)
    print("Close?", np.allclose(pt_output, edge_output, atol=1e-4, rtol=1e-4))

    ##############################################################################
    # 4) Serialize to TFLite
    ##############################################################################
    edge_model.export("student_model.tflite")
    print("Model exported as 'student_model.tflite'")


  checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu" )


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x5 and 4x48)