In [None]:
import base64
import io
from typing import List

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, HTTPException, Request
from PIL import Image, UnidentifiedImageError
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer

class VLMManager:
    """Holds a Keras image+text model and answers bounding-box queries.

    The model pools ResNet50 image features and uses them to initialise an
    LSTM that reads caption token ids and predicts a next-token distribution
    over ``vocab_size`` classes. ``identify`` does not use this model yet.
    """

    def __init__(
        self,
        vocab_size: int = 10000,
        embedding_dim: int = 256,
        lstm_units: int = 256,
        resnet_weights="imagenet",
    ):
        """Build and compile the model.

        Args:
            vocab_size: Number of token ids accepted by the embedding and
                predicted by the softmax head (example value by default).
            embedding_dim: Size of each token embedding.
            lstm_units: Hidden size of the LSTM.
            resnet_weights: Passed to ``ResNet50(weights=...)``; ``"imagenet"``
                (the default) loads pre-trained weights, ``None`` builds the
                backbone randomly initialised.
        """
        # Pre-trained ResNet50 backbone without its classification head.
        self.resnet = ResNet50(
            include_top=False, weights=resnet_weights, input_shape=(224, 224, 3)
        )

        # Freeze the backbone so only the captioning head is trained.
        for layer in self.resnet.layers:
            layer.trainable = False

        # Variable-length sequence of integer token ids.
        self.text_input = Input(shape=(None,), dtype="int32")

        # NOTE: despite its name this attribute holds the embedding *output*
        # tensor, shape (batch, seq_len, embedding_dim), not the layer object.
        # The name is kept for compatibility with existing callers.
        self.embedding_layer = Embedding(
            input_dim=vocab_size, output_dim=embedding_dim
        )(self.text_input)

        self.image_input = Input(shape=(224, 224, 3))

        # training=False keeps the frozen backbone (incl. BatchNorm) in
        # inference mode even if the head is later trained.
        image_features = self.resnet(self.image_input, training=False)
        # Pool the spatial feature map to one vector per image: (batch, 2048).
        image_features = tf.keras.layers.GlobalAveragePooling2D()(image_features)

        # Condition the LSTM on the image through its initial hidden/cell
        # state. Concatenating the rank-2 image vector with the rank-3 token
        # embeddings (the previous approach) is rejected by Keras, because
        # Concatenate needs inputs of equal rank; the constructor raised.
        initial_h = Dense(lstm_units, activation="tanh")(image_features)
        initial_c = Dense(lstm_units, activation="tanh")(image_features)
        lstm = LSTM(units=lstm_units)(
            self.embedding_layer, initial_state=[initial_h, initial_c]
        )
        output = Dense(vocab_size, activation="softmax")(lstm)

        self.model = Model(inputs=[self.image_input, self.text_input], outputs=output)
        self.model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

    def identify(self, image: bytes, caption: str) -> List[int]:
        """Return a bounding box for ``caption`` in the encoded ``image``.

        Placeholder: neither ``caption`` nor ``self.model`` is consulted. The
        result is ``[w // 4, h // 4, w // 2, h // 2]`` for an image of size
        ``(w, h)``, presumably ``[left, top, width, height]`` of a centred box
        (confirm the bbox format expected by consumers).

        Raises:
            PIL.UnidentifiedImageError: if ``image`` is not a readable image.
        """
        # Context manager releases PIL's handle as soon as the size is read.
        with Image.open(io.BytesIO(image)) as img:
            image_width, image_height = img.size
        return [image_width // 4, image_height // 4, image_width // 2, image_height // 2]

# Shared, process-wide model wrapper; constructing it builds the Keras model
# at import time (ResNet50 may fetch its pre-trained weights here).
vlm_manager = VLMManager()
# ASGI application that the route decorators below register against.
app = FastAPI()

@app.get("/health")
def health():
    """Liveness probe: report that the service is up."""
    status = {"message": "health ok"}
    return status

@app.post("/identify")
async def identify(request: Request):
    """
    Performs Object Detection and Identification given an image frame and a text query.

    Expects a JSON body ``{"instances": [{"b64": <base64 image>, "caption": <text>}, ...]}``
    and returns ``{"predictions": [...]}`` with one ``vlm_manager.identify``
    result per instance, in request order.

    Raises:
        HTTPException: 400 if the body is not JSON of that shape, an
            instance's ``b64`` value is not valid base64, or the decoded
            bytes cannot be opened as an image.
    """
    try:
        input_json = await request.json()
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from err

    instances = input_json.get("instances") if isinstance(input_json, dict) else None
    if not isinstance(instances, list):
        raise HTTPException(
            status_code=400,
            detail='Request body must be a JSON object with an "instances" list.',
        )

    predictions = []
    for index, instance in enumerate(instances):
        try:
            encoded_image = instance["b64"]
            caption = instance["caption"]
        except (KeyError, TypeError) as err:
            raise HTTPException(
                status_code=400,
                detail=f'Instance {index} must be an object with "b64" and "caption" keys.',
            ) from err

        # Get base64 encoded string of image, convert back into bytes.
        # binascii.Error (bad padding/characters) subclasses ValueError;
        # TypeError covers non-string values.
        try:
            image_bytes = base64.b64decode(encoded_image)
        except (ValueError, TypeError) as err:
            raise HTTPException(
                status_code=400, detail=f'Instance {index}: "b64" is not valid base64.'
            ) from err

        try:
            bbox = vlm_manager.identify(image_bytes, caption)
        except UnidentifiedImageError as err:
            raise HTTPException(
                status_code=400, detail=f"Instance {index}: image data could not be decoded."
            ) from err
        predictions.append(bbox)

    return {"predictions": predictions}


In [None]:
pip install "fastapi[all]"
