# TCG Scanner - Model Conversion to TFLite

This notebook converts trained PyTorch models to TFLite format for Flutter mobile deployment.

**Models to convert:**
1. YOLOv8 Detection Model (best.pt) -> detection.tflite
2. FastViT Embedding Model (embedding_model.pt) -> embedding.tflite

**Requirements:**
- Trained models in Google Drive under `MyDrive/tcg-scanner/models/` (`detection/best.pt` and `embedding/embedding_model.pt`)
- T4 GPU runtime (recommended)

## 1. Setup

In [None]:
# Install required packages
!pip install -q ultralytics>=8.0.0
!pip install -q timm>=0.9.12
!pip install -q onnx>=1.14.0
!pip install -q onnxruntime>=1.16.0
!pip install -q onnx-tf>=1.10.0
!pip install -q tf2onnx>=1.14.0

print("Packages installed successfully!")

In [None]:
# Imports for filesystem handling, then the Colab Drive helper
import os
from pathlib import Path

from google.colab import drive

# Attach Google Drive to this runtime
drive.mount('/content/drive')

# Project layout on Drive: everything lives below DRIVE_ROOT
DRIVE_ROOT = Path('/content/drive/MyDrive/tcg-scanner')
MODELS_DIR = DRIVE_ROOT.joinpath('models')
OUTPUT_DIR = DRIVE_ROOT.joinpath('models/tflite')

# Converted models are written here; creating it is safe to repeat on re-runs
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations so a wrong Drive layout is spotted early
print(f"Drive root: {DRIVE_ROOT}")
print(f"Models dir: {MODELS_DIR}")
print(f"Output dir: {OUTPUT_DIR}")

In [None]:
# Locate the two trained checkpoints and report whether each one is present.
# Missing files only produce a warning here; nothing is raised.
detection_model = MODELS_DIR / 'detection/best.pt'
embedding_model = MODELS_DIR / 'embedding/embedding_model.pt'

has_detection = detection_model.exists()
has_embedding = embedding_model.exists()

print("Checking model files...")
print(f"Detection model: {has_detection} - {detection_model}")
print(f"Embedding model: {has_embedding} - {embedding_model}")

# (found?, warning line, hint line) for each checkpoint, in report order
missing_notices = (
    (has_detection,
     "\nWARNING: Detection model not found!",
     "Expected path: models/detection/best.pt"),
    (has_embedding,
     "\nWARNING: Embedding model not found!",
     "Expected path: models/embedding/embedding_model.pt"),
)
for found, warning_line, hint_line in missing_notices:
    if found:
        continue
    print(warning_line)
    print(hint_line)

## 2. Convert Detection Model (YOLOv8)

In [None]:
from ultralytics import YOLO

print("="*50)
print("Converting Detection Model (YOLOv8)")
print("="*50)

# Load the trained detector checkpoint
print(f"\nLoading model from: {detection_model}")
detector = YOLO(str(detection_model))

# Export directly to TFLite. Ultralytics drives the whole conversion chain
# itself and writes its artifacts next to the source weights.
# `export_path` is consumed by the next cell, which copies the result to
# OUTPUT_DIR. NOTE(review): the export uses the library's default image size;
# pass `imgsz=` here if the mobile app expects a different input resolution.
print("\nExporting to TFLite...")
export_path = detector.export(format='tflite')

print(f"Export complete: {export_path}")

In [None]:
import shutil

# Copy the exported detector into OUTPUT_DIR under the name the app expects.
detection_output = OUTPUT_DIR / 'detection.tflite'

# `export_path` is produced by the export cell. Read it defensively so this
# cell can still pick up a previously exported file from Drive when the export
# cell was skipped (instead of dying with a NameError).
exported = globals().get('export_path')
exported = Path(exported) if exported else None

if exported is not None and exported.suffix == '.tflite' and exported.is_file():
    # The exporter told us exactly which file it wrote: use that one.
    tflite_files = [exported]
elif exported is not None:
    # Otherwise look around the reported location. The export folder may hold
    # more than one .tflite variant, so sort for a reproducible choice.
    tflite_files = sorted(exported.parent.glob('**/*.tflite'))
else:
    tflite_files = []

if not tflite_files:
    # Last resort: any earlier export of best.pt anywhere under MODELS_DIR.
    tflite_files = sorted(MODELS_DIR.glob('**/best*.tflite'))

if tflite_files:
    src_file = tflite_files[0]
    shutil.copy(src_file, detection_output)
    size_mb = detection_output.stat().st_size / 1024 / 1024
    print(f"Source file: {src_file}")
    print(f"Detection model saved to: {detection_output}")
    print(f"Size: {size_mb:.2f} MB")
else:
    print("ERROR: Could not find generated TFLite file")
    print("Searching in:", MODELS_DIR)

## 3. Convert Embedding Model (FastViT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

# Section banner for the embedding-model conversion
banner = "=" * 50
print(banner)
print("Converting Embedding Model (FastViT)")
print(banner)

# Define the model architecture (MUST match training exactly!)
class EmbeddingModel(nn.Module):
    """timm backbone followed by an MLP head that emits L2-normalized embeddings.

    Args:
        backbone: Name of the timm architecture to instantiate (no pretrained
            weights are downloaded; the trained weights are loaded afterwards).
        embedding_dim: Size of the output embedding vector.
        dropout: Dropout probability used in the head. It has no effect in
            eval mode, but the layer must be present (see note in __init__).
    """

    def __init__(self, backbone='fastvit_t12', embedding_dim=384, dropout=0.2):
        super().__init__()
        # num_classes=0 asks timm for the backbone without a classifier layer
        self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0)

        # Probe the backbone once to learn the width of its feature vector
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            features = self.backbone(dummy)
            feature_dim = features.shape[-1]

        # The Dropout layer is part of the trained architecture. Leaving it out
        # renumbers the Sequential (the final Linear becomes `head.3` instead
        # of `head.4`), so the checkpoint keys no longer match on load.
        self.head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, embedding_dim),
        )

    def forward(self, x):
        """Return unit-length (L2-normalized along dim 1) embeddings for `x`."""
        features = self.backbone(x)
        embeddings = self.head(features)
        return F.normalize(embeddings, p=2, dim=1)

# Restore the trained weights on CPU and put the network in inference mode
print(f"\nLoading model from: {embedding_model}")
state_dict = torch.load(embedding_model, map_location='cpu')

model = EmbeddingModel()
model.load_state_dict(state_dict)
model.train(False)  # identical to model.eval()

print("Model loaded successfully!")

In [None]:
# Step 1 of the embedding conversion: PyTorch -> ONNX
print("\nExporting to ONNX...")

onnx_path = OUTPUT_DIR / 'embedding.onnx'
dummy_input = torch.randn(1, 3, 224, 224)

# Axis 0 of both the input and the output is exported as a dynamic batch axis
export_options = dict(
    export_params=True,
    opset_version=12,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['embedding'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'embedding': {0: 'batch_size'},
    },
)
torch.onnx.export(model, dummy_input, str(onnx_path), **export_options)

onnx_size_mb = onnx_path.stat().st_size / 1024 / 1024
print(f"ONNX model saved: {onnx_path}")
print(f"Size: {onnx_size_mb:.2f} MB")

In [None]:
# Step 2 of the embedding conversion: ONNX -> TensorFlow SavedModel
import onnx
from onnx_tf.backend import prepare

print("\nConverting ONNX to TensorFlow...")

# Directory that will receive the intermediate SavedModel
tf_path = OUTPUT_DIR / 'embedding_tf'

# Read the exported graph back and hand it to the onnx-tf backend
onnx_model = onnx.load(str(onnx_path))
tf_rep = prepare(onnx_model)

# Write the converted graph to disk
tf_rep.export_graph(str(tf_path))

print(f"TensorFlow model saved: {tf_path}")

In [None]:
# Step 3 of the embedding conversion: TensorFlow SavedModel -> TFLite
import tensorflow as tf

print("\nConverting TensorFlow to TFLite...")

converter = tf.lite.TFLiteConverter.from_saved_model(str(tf_path))
# Default optimizations with float16 as the only supported type
converter.target_spec.supported_types = [tf.float16]
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()

# Persist the serialized flatbuffer next to the detection model
embedding_output = OUTPUT_DIR / 'embedding.tflite'
embedding_output.write_bytes(tflite_model)

size_mb = embedding_output.stat().st_size / 1024 / 1024
print(f"\nEmbedding model saved to: {embedding_output}")
print(f"Size: {size_mb:.2f} MB")

In [None]:
# Remove conversion by-products that are not needed on the device
import shutil

print("\nCleaning up intermediate files...")

# The SavedModel directory was only a stepping stone towards the .tflite file
saved_model_present = tf_path.exists()
if saved_model_present:
    shutil.rmtree(tf_path)
    print(f"Removed: {tf_path}")

# The ONNX file is kept by default. To delete it as well, uncomment:
# if onnx_path.exists():
#     onnx_path.unlink()
#     print(f"Removed: {onnx_path}")

print("Cleanup complete!")

## 4. Verify Converted Models

In [None]:
# Summarize what was produced and what to do with it
rule = "=" * 50

print(rule)
print("Conversion Summary")
print(rule)

print(f"\nOutput directory: {OUTPUT_DIR}")
print("\nGenerated files:")

# One line per converted model found in the output directory
for tflite_file in OUTPUT_DIR.glob('*.tflite'):
    size_mb = tflite_file.stat().st_size / 1024 / 1024
    print(f"  - {tflite_file.name}: {size_mb:.2f} MB")

print("\n" + rule)
print("Next Steps:")
print(rule)
print("""
1. Download the TFLite files from Google Drive:
   - models/tflite/detection.tflite
   - models/tflite/embedding.tflite

2. Copy to your Flutter project:
   - mobile/flutter/assets/models/detection.tflite
   - mobile/flutter/assets/models/embedding.tflite

3. Run Flutter app:
   cd mobile/flutter
   flutter pub get
   flutter run
""")

In [None]:
# Optional: smoke-test the converted TFLite models
import numpy as np
import tensorflow as tf  # imported here too so the cell runs without the conversion cell


def inspect_tflite(model_path, title):
    """Load a TFLite model, allocate its tensors and print its input signature.

    Args:
        model_path: Path of the .tflite file to load.
        title: Heading printed above the model's details.

    Returns:
        Tuple of (interpreter, input_details, output_details), where the two
        detail lists are what the interpreter reports for the model.
    """
    interpreter = tf.lite.Interpreter(model_path=str(model_path))
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(f"\n{title}:")
    print(f"  Input shape: {input_details[0]['shape']}")
    print(f"  Input dtype: {input_details[0]['dtype']}")
    return interpreter, input_details, output_details


print("Testing TFLite models...")

# Track what was actually exercised so the final message cannot claim success
# for models that were never found.
checked, skipped = [], []

# Detection model: loading and tensor allocation is the whole check
detection_tflite = OUTPUT_DIR / 'detection.tflite'
if detection_tflite.exists():
    interpreter, input_details, output_details = inspect_tflite(
        detection_tflite, "Detection model")
    print(f"  Output shapes: {[o['shape'] for o in output_details]}")
    checked.append(detection_tflite.name)
else:
    skipped.append(detection_tflite.name)

# Embedding model: additionally run one inference and verify the output norm
embedding_tflite = OUTPUT_DIR / 'embedding.tflite'
if embedding_tflite.exists():
    interpreter, input_details, output_details = inspect_tflite(
        embedding_tflite, "Embedding model")
    print(f"  Output shape: {output_details[0]['shape']}")

    # Build the probe from the shape and dtype the model itself reports, so
    # the test keeps working if the converted layout differs from NCHW 224x224.
    rng = np.random.default_rng(0)
    probe_shape = tuple(int(d) for d in input_details[0]['shape'])
    test_input = rng.random(probe_shape).astype(input_details[0]['dtype'])

    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
    embedding_norm = float(np.linalg.norm(output[0]))
    print(f"  Test output shape: {output.shape}")
    print(f"  Embedding norm: {embedding_norm:.4f} (should be ~1.0)")

    # The model ends in an L2 normalization, so anything far from 1.0 means
    # the conversion damaged the graph.
    assert abs(embedding_norm - 1.0) < 1e-2, (
        f"Embedding norm is {embedding_norm:.4f}, expected ~1.0"
    )
    checked.append(embedding_tflite.name)
else:
    skipped.append(embedding_tflite.name)

if skipped:
    print(f"\nSkipped (file not found): {', '.join(skipped)}")

if checked and not skipped:
    print("\nAll tests passed!")
elif checked:
    print(f"\nTests passed for: {', '.join(checked)}")
else:
    print("\nNo TFLite models were tested.")