// Wall & ceiling segmentation demo screen (React Native + on-device TFLite model).
import React, { useEffect, useRef, useState } from 'react';
import { View, Text, Button, Image as RNImage } from 'react-native';
import { NativeModules } from 'react-native';
import ImagePicker, { launchImageLibrary } from 'react-native-image-picker';
import RNFS from 'react-native-fs';
import Canvas, { Image as CanvasImage, ImageData as CanvasImageData } from 'react-native-canvas';
import ImageResizer from 'react-native-image-resizer'; // Added for native resizing

const { TFLite } = NativeModules;

const WallCeilingSegmentation = () => {
  const [modelLoaded, setModelLoaded] = useState(false);
  const [prediction, setPrediction] = useState(null);

  // Load TFLite model
  useEffect(() => {
    async function loadModel() {
      try {
        const modelPath = 'assets/Final_Wall_Segmentation.tflite'; // Changed: Match optimized model file
        await TFLite.loadModel(modelPath);
        TFLite.setUseNNAPI(true); // Added: Enable NNAPI for int8 quantization speedup
        TFLite.setNumThreads(4); // Added: Optimize CPU usage for mid-range devices
        setModelLoaded(true);
        console.log('Model loaded successfully');
      } catch (error) {
        console.error('Model loading error:', error);
      }
    }
    loadModel();
  }, []);

  // Preprocess and predict
  async function predict() {
    if (!modelLoaded) {
      console.log('Model not loaded');
      return;
    }

    try {
      // Pick image
      let response;
      try {
        response = await ImagePicker.launchImageLibrary({
          mediaType: 'photo',
          includeBase64: false, // Changed: Disable base64 for efficiency
        });
      } catch (error) {
        console.error('ImagePicker error:', error);
        return;
      }

      if (response.didCancel || !response.assets) {
        console.log('Image selection cancelled');
        return;
      }

      const imageUri = response.assets[0].uri;

      // Resize image natively to 224x224
      const resizedImage = await ImageResizer.createResizedImage(
        imageUri,
        224, // Changed: Match model’s 224x224 input
        224,
        'JPEG',
        100, // Quality
        0, // Rotation
        undefined, // Temp file
        false, // Keep metadata
        { mode: 'stretch' } // Resize mode
      );

      // Read resized image
      const imageData = await RNFS.readFile(resizedImage.uri, 'base64');
      const image = new RNImage();
      image.src = data:image/jpeg;base64,${imageData};
      await new Promise((resolve) => {
        image.onload = resolve;
      });

      // Draw to canvas for pixel access
      const canvas = new Canvas(224, 224);
      const ctx = canvas.getContext('2d');
      ctx.drawImage(image, 0, 0, 224, 224);
      const pixelData = ctx.getImageData(0, 0, 224, 224);
      const pixels = pixelData.data;

      // Convert to Uint8Array (0–255)
      const inputBuffer = new Uint8Array(224 * 224 * 3);
      for (let i = 0, j = 0; i < pixels.length; i += 4, j += 3) {
        inputBuffer[j] = pixels[i]; // R
        inputBuffer[j + 1] = pixels[i + 1]; // G
        inputBuffer[j + 2] = pixels[i + 2]; // B
      }

      // Run prediction with timing
      const start = Date.now(); // Added: Measure inference time
      const output = await TFLite.run(inputBuffer);
      console.log(Inference time: ${Date.now() - start}ms); // Added: Verify ~1s inference

      if (!output || output.length !== 224 * 224 * 3) {
        throw new Error(Invalid output length: expected ${224 * 224 * 3}, got ${output?.length});
      }

      // Process output (uint8, matches int8 quantization)
      const outputArray = new Uint8Array(output);
      const mask = new Uint8Array(224 * 224);
      for (let i = 0; i < 224 * 224; i++) {
        const probs = outputArray.slice(i * 3, i * 3 + 3);
        mask[i] = probs.indexOf(Math.max(...probs)); // argmax: 0, 1, or 2
      }

      // Visualize mask
      const colorMap = {
        0: [0, 0, 255], // Background: Blue
        1: [0, 255, 0], // Wall: Green
        2: [255, 0, 0], // Ceiling: Red
      };
      const maskImage = new Uint8Array(224 * 224 * 4);
      for (let i = 0; i < mask.length; i++) {
        const classId = mask[i];
        maskImage[i * 4] = colorMap[classId][0];
        maskImage[i * 4 + 1] = colorMap[classId][1];
        maskImage[i * 4 + 2] = colorMap[classId][2];
        maskImage[i * 4 + 3] = 255;
      }

      setPrediction(maskImage);
      console.log('Prediction completed');

      // Clean up temporary file
      if (resizedImage.uri) {
        await RNFS.unlink(resizedImage.uri).catch((err) => console.warn('Failed to delete temp file:', err));
      }

    } catch (error) {
      console.error('Prediction error:', error);
    }
  }

  return (
    <View style={{ flex: 1, justifyContent: 'center', alignItems: 'center' }}>
      <Text>Wall & Ceiling Segmentation</Text>
      {modelLoaded ? <Text>Model Ready</Text> : <Text>Loading Model...</Text>}
      {prediction && (
        <Canvas ref={(canvas) => {
          if (canvas) {
            const ctx = canvas.getContext('2d');
            const imageData = ctx.createImageData(224, 224);
            imageData.data.set(prediction);
            ctx.putImageData(imageData, 0, 0);
          }
        }} style={{ width: 224, height: 224 }} />
      )}
      <Button title="Select Image and Predict" onPress={predict} />
    </View>
  );
};

export default WallCeilingSegmentation;

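// Optional, not wired up: sketch of real-time inference on a react-native-vision-camera
// frame feed. The body is pseudo-code (frame.toRGB is a placeholder), and TFLite.runSync
// is assumed, not confirmed, to exist on the native module.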
/*
import { Camera } from 'react-native-vision-camera';
const frameProcessor = (frame) => {
  const buffer = frame.toRGB(); // Pseudo-code: Convert frame to RGB
  const inputBuffer = new Uint8Array(224 * 224 * 3);
  // Copy buffer (resize to 224x224)
  const output = TFLite.runSync(inputBuffer); // Synchronous for real-time
  // Process output and update UI
};
*/