In [1]:
from huggingface_hub import snapshot_download
from moondream_model import VisionEncoder, TextModel
import torch 
from PIL import Image
import re

# Runtime configuration: both model halves are moved to this device and cast to this dtype.
DEVICE = "cuda"
DTYPE = torch.float16

# Fetch (or reuse the locally cached copy of) the moondream1 checkpoint from the Hugging Face Hub.
model_path = snapshot_download("vikhyatk/moondream1")

# Build the vision encoder first, then the text model, from the same checkpoint directory.
vision_encoder, text_model = (
    component(model_path).to(DEVICE, dtype=DTYPE)
    for component in (VisionEncoder, TextModel)
)

print(f"Using:{DEVICE}")
print(f"Type: {DTYPE}")

Fetching 16 files:   0%|          | 0/16 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Using:cuda
Type: torch.float16


In [36]:
def run_inference(image, prompt, max_new_tokens=128, encoder=None, model=None):
    """Answer ``prompt`` about ``image`` with the moondream vision/text pair.

    Args:
        image: Passed unchanged to the vision encoder (a PIL image in this notebook).
        prompt: Question/instruction forwarded to ``answer_question``.
        max_new_tokens: Generation budget forwarded to ``answer_question``.
        encoder: Optional vision encoder used instead of the notebook-level
            ``vision_encoder``; any callable ``encoder(image) -> embeds`` works.
        model: Optional text model used instead of the notebook-level
            ``text_model``; must provide
            ``answer_question(embeds, prompt, max_new_tokens=...)``.

    Returns:
        The generated text as ``str``. One trailing ``"END"`` and then one
        trailing ``"<"`` are removed, so a trailing ``"<END"`` marker is stripped.
        Non-string tensor results are flattened and their elements joined with
        single spaces. Any other non-string result goes through ``str()``.
    """
    # Fall back to the globals from the setup cell. Each global is only looked
    # up when no override is supplied.
    encoder = vision_encoder if encoder is None else encoder
    model = text_model if model is None else model

    # torch.autocast(device_type="cuda") replaces the deprecated
    # torch.cuda.amp.autocast(). Both use float16 by default on CUDA.
    with torch.inference_mode(), torch.autocast(device_type="cuda"):
        image_embeds = encoder(image)
        result = model.answer_question(image_embeds,
                                       prompt,
                                       max_new_tokens=max_new_tokens)

    # answer_question may return a tuple; the text is assumed to be its first
    # element (TODO confirm against moondream_model.TextModel).
    result_text = result[0] if isinstance(result, tuple) else result

    # Convert the result to a string if it is not one already.
    if not isinstance(result_text, str):
        if torch.is_tensor(result_text):
            # flatten() makes 0-d tensors work (tolist() would return a bare
            # scalar). Calling tolist() directly, without .numpy(), also handles
            # dtypes NumPy lacks, such as bfloat16.
            result_text = " ".join(map(str, result_text.flatten().tolist()))
        else:
            result_text = str(result_text)

    # Strip a trailing "END", then a trailing "<" (i.e. the "<END" marker).
    return re.sub(r"<$", "", re.sub(r"END$", "", result_text))

In [37]:
# Benchmark input shared by the timing cell below: one sample frame plus a fixed instruction.
SAMPLE_IMAGE_PATH = "img/output_000017.jpg"
img = Image.open(SAMPLE_IMAGE_PATH)
prompt = "Describe this image."

In [38]:
%%timeit
result = run_inference(img, prompt, max_new_tokens=20)
result

Inference time took 0.42591166496276855 seconds.
Inference time took 0.4257669448852539 seconds.
Inference time took 0.41782665252685547 seconds.
Inference time took 0.4163053035736084 seconds.
Inference time took 0.42008137702941895 seconds.
Inference time took 0.4188261032104492 seconds.
Inference time took 0.4212477207183838 seconds.
Inference time took 0.41730237007141113 seconds.
428 ms ± 3.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
