In [2]:
# Install requirements and Import
%pip install -r /content/requirements.txt --quiet

In [None]:
from dotenv import load_dotenv
import os

load_dotenv()

os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")


In [1]:

# Standard Library Imports
import logging
import warnings
import json
import os

# Third-Party Libraries
import shutil

# MLflow for Experiment Tracking and Model Management
import mlflow
from mlflow import MlflowClient
from mlflow.types.schema import Schema, ColSpec
from mlflow.types import ParamSchema, ParamSpec
from mlflow.models import ModelSignature

# Transformers
import torch
from transformers import pipeline
from transformers import BitsAndBytesConfig
from transformers import AutoConfig, Gemma3ForConditionalGeneration, AutoTokenizer, AutoProcessor

#Configurations

In [2]:

warnings.filterwarnings("ignore")

In [3]:
# Define global experiment and run names to be used throughout the notebook
MODEL_PERSONAL_NAME = "google/medgemma-4b-it"
EXPERIMENT_NAME = "Medgemma bot"
MODEL_NAME = "MEDGEMMA"
RUN_NAME = 'MEDGEMMA'
NAME = 'MEDGEMMA'


In [4]:
# Model configs
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

In [5]:
# === Create logger ===
logger = logging.getLogger("deployment-notebook")
logger.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s",
                             datefmt="%Y-%m-%d %H:%M:%S")

stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
logger.propagate = False

In [6]:

logger.info('Notebook execution started.')

2025-06-09 19:13:25 - INFO - Notebook execution started.


# Model

In [7]:

model_name = MODEL_PERSONAL_NAME

pipe = pipeline(
    "image-text-to-text",
    model=model_name,
)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Device set to use cuda:0


In [15]:
#Sample pipeline test
pipe(text="What is a lesion?", max_new_tokens=300)

Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.


[{'input_text': 'What is a lesion?',
  'generated_text': "What is a lesion? A lesion is an abnormal growth or change in a tissue or organ. In medicine, a lesion is a specific area of tissue that has been damaged or altered. Lesions can be caused by many things, including trauma, infection, and genetic mutations.\n\nHere's a breakdown of what a lesion is, and some common examples:\n\n**Key Features of a Lesion:**\n\n*   **Abnormal Growth:**  This is the core defining characteristic. A lesion isn't a normal, healthy tissue; it's something *different*.\n*   **Area of Change:** It represents a localized area of altered function or appearance.\n*   **Variety of Causes:** Lesions can arise from diverse origins, as mentioned earlier.\n*   **Diagnostic Importance:**  Identifying the *type* of lesion (its appearance, growth rate, response to treatment, etc.) is crucial for diagnosis and treatment.\n\n**Common Types of Lesions:**\n\nThe specific type of lesion is determined by its cause and loca

# MLFLOW MODEL FORMAT

In [8]:
import base64
from io import BytesIO
from PIL import Image



In [9]:
class MedGemmaModel(mlflow.pyfunc.PythonModel):

  def decode_base64_to_image(self, base64_str: str) -> Image.Image:
    """
    Decodes a base64-encoded string back into a PIL Image.

    Args:
        base64_str (str): The base64-encoded image string.

    Returns:
        PIL.Image.Image: The decoded image.
    """
    try:
        image_data = base64.b64decode(base64_str)
        image = Image.open(BytesIO(image_data))
        return image
    except Exception as e:
        print(f"Error decoding base64 image: {e}")
        return None

  def _preprocess(self,inputs):
    """
    Preprocesses the input data.

    Args:
        inputs: A dictionary containing two keys:
            - 'user_prompt': The user query.
            - 'system_instruction': The system prompt with the context
            - 'image': The image to be processed, if any.

    Returns:
        messages: A list of role based instructions to be sent to the model
    """
    try:
      user_prompt = inputs['user_prompt'][0]
      system_instruction = inputs['system_instruction'][0]
      image = inputs['image'][0]
      max_tokens = inputs['max_tokens'][0]

      print("pre processing", user_prompt[:10])
      if not system_instruction:
        system_instruction = 'You are an excellent medical expert.'
      if image:
        image = self.decode_base64_to_image(image)
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_instruction}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt},
                    {"type": "image", "image": image}
                ]
            }
        ]
      else:
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_instruction}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt}
                ]
            }
        ]
      return messages,max_tokens

    except Exception as e:
      logger.error(f"Error preprocessing the input data: {str(e)}")

  def load_context(self, context):
    try:
        # config = AutoConfig.from_pretrained(context.artifacts['model_path'])
        # tokenizer = AutoTokenizer.from_pretrained(context.artifacts['model_path'])
        # processor = AutoProcessor.from_pretrained(context.artifacts['model_path'])
        # model = Gemma3ForConditionalGeneration(config)
        # state_dict = torch.load(os.path.join(context.artifacts['model_path'], "pytorch_model.bin"), map_location="cpu")
        # model.load_state_dict(state_dict)

        # # Instantiate BitsAndBytesConfig locally
        # quant_config = BitsAndBytesConfig(load_in_4bit=True)

        # self.model = pipeline(
        #     "image-text-to-text",
        #     model=model,
        #     tokenizer=tokenizer,
        #     processor=processor,
        #     model_kwargs={
        #         "torch_dtype": torch.bfloat16,
        #         "device_map": "auto",
        #         "quantization_config": quant_config,
        #     }
        # )
        quant_config = BitsAndBytesConfig(load_in_4bit=True)
        self.model = pipeline(
            'image-text-to-text',
            model=context.artifacts['model'],  # should point to folder with pretrained files
            torch_dtype=torch.bfloat16,
            device_map="auto",
            model_kwargs={
                "quantization_config": quant_config,
            }
        )
    except Exception as e:
        logger.error(f"Error loading the image-text-to-text pipeline: {str(e)}")

  def predict(self, context, model_input, params):
    """
    Runs inference using the loaded model and input data.

    Args:
        context: The MLflow context object
            with access to artifacts.
        model_input: A dictionary containing 'context' and 'question' keys.

    Returns:
        The output from the model containing the predicted answer and optionally the score.
    """
    try:
      messages,max_tokens = self._preprocess(model_input)
      output = self.model(text=messages, max_new_tokens=max_tokens)
      output = output[0]["generated_text"][-1]["content"]
      return output
    except Exception as e:
      logger.error(f"Error running inference: {str(e)}")

  @classmethod
  def log_model(cls, model_name, source_pipeline=None,demo_folder="../demo"):
    #define the schema for the model
    try:
      input_schema = Schema(
          [
              ColSpec("string", "user_prompt"),
              ColSpec("string", "system_instruction"),
              ColSpec("string", "image"),
              ColSpec("integer", "max_tokens"),
          ]
      )
      output_schema = Schema([ColSpec("string","answer")])

      signature = ModelSignature(inputs=input_schema, outputs=output_schema)
      if source_pipeline:
        os.makedirs(model_name, exist_ok=True)

        # Save model and tokenizer properly
        source_pipeline.model.save_pretrained(model_name, safe_serialization=True)
        source_pipeline.tokenizer.save_pretrained(model_name)
        if hasattr(source_pipeline, "processor"):
            source_pipeline.processor.save_pretrained(model_name)
      # Log model via MLflow pyfunc
      mlflow.pyfunc.log_model(
          artifact_path=model_name,
          python_model=cls(),
          artifacts={"model": model_name, "demo": demo_folder},
          signature=signature,
          pip_requirements=[
              "transformers==4.52.4",
              "bitsandbytes==0.46.0",
              "torch",
              "pillow",
          ],
      )
      requirements = [
          "bitsandbytes==0.46.0",
          "transformers==4.52.4"
      ]
      shutil.rmtree(model_name)
      logger.info("Logging model to MLflow done successfully")
    except Exception as e:
      logger.error(f"Error logging model to MLflow: {str(e)}")





# MODEL REGISTRY

In [10]:
mlflow.set_tracking_uri('/phoenix/mlflow')
mlflow.set_experiment(experiment_name = EXPERIMENT_NAME)

<Experiment: artifact_location='/phoenix/mlflow/310480675827085248', creation_time=1749484917617, experiment_id='310480675827085248', last_update_time=1749484917617, lifecycle_stage='active', name='Medgemma bot', tags={}>

In [18]:
import pandas as pd
import numpy as np
from mlflow.models import infer_signature

example_input = pd.DataFrame({
    "user_prompt": ["what is a lesion?"],              # required
    "system_instruction": [None],                      # optional
    "image": [None],                                   # optional
    "max_tokens": [np.nan]                             # optional
})

signature = infer_signature(model_input=example_input)

In [11]:
with mlflow.start_run(run_name= RUN_NAME) as run:
    logger.info(f"Run's Artifact URI: {run.info.artifact_uri}")
    MedGemmaModel.log_model(model_name = MODEL_NAME, source_pipeline=pipe,demo_folder='/content/demo')
    mlflow.register_model(model_uri = f"runs:/{run.info.run_id}/{MODEL_NAME}", name = NAME)

2025-06-09 19:14:32 - INFO - Run's Artifact URI: /phoenix/mlflow/310480675827085248/b9c0202d6cc74389b7d3c61139aac79d/artifacts


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0


Downloading artifacts:   0%|          | 0/15 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/4 [00:00<?, ?it/s]

2025-06-09 19:20:33 - INFO - Logging model to MLflow done successfully
Registered model 'MEDGEMMA' already exists. Creating a new version of this model...
Created version '7' of model 'MEDGEMMA'.


# TESTING LATEST MODEL REGISTERED

In [12]:
client = mlflow.MlflowClient()
model_metadata = client.get_latest_versions('MEDGEMMA', stages=["None"])
print(model_metadata)
latest_model_version = model_metadata[0].version
print(latest_model_version, mlflow.models.get_model_info(f"models:/MEDGEMMA/{latest_model_version}").signature)

[<ModelVersion: aliases=[], creation_timestamp=1749496833130, current_stage='None', description=None, last_updated_timestamp=1749496833130, name='MEDGEMMA', run_id='b9c0202d6cc74389b7d3c61139aac79d', run_link=None, source='/phoenix/mlflow/310480675827085248/b9c0202d6cc74389b7d3c61139aac79d/artifacts/MEDGEMMA', status='READY', status_message=None, tags={}, user_id=None, version=7>]
7 inputs: 
  ['user_prompt': string (required), 'system_instruction': string (required), 'image': string (required), 'max_tokens': integer (required)]
outputs: 
  ['answer': string (required)]
params: 
  None



In [13]:
import pandas as pd
import numpy as np
model = mlflow.pyfunc.load_model(model_uri=f"models:/MEDGEMMA/{latest_model_version}")
input_df = pd.DataFrame({
    "user_prompt": ["what is a lesion?"],
    "system_instruction": [""],
    "image": [""],
    "max_tokens": np.array([1000], dtype=np.int32)  # 👈 force int32
})

# Predict
response = model.predict(input_df)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.


pre processing what is a 


In [14]:
response

'Okay, as an excellent medical expert, let me define the term "lesion" as it applies to medical contexts.\n\nA **lesion** is a general term used to describe any abnormal change in the structure or function of a part of the body.  This can encompass a wide variety of conditions, depending on the location, cause, and appearance of the abnormality.\n\nHere\'s a more detailed breakdown:\n\n*   **Definition:** A lesion refers to an area of damage or abnormal change in tissue, whether it is a result of disease, injury, or other factors.\n\n*   **Key Characteristics:** The defining characteristic of a lesion is that it\'s an **abnormality**.\n\n*   **Types of Lesions:** There are numerous types of lesions, categorized by their origin, appearance, and behavior. Common examples include:\n\n    *   **Skin Lesions:** These are lesions that occur on the skin. Examples include:\n        *   **Macules:** Flat, discolored spots (e.g., freckles, moles).\n        *   **Papules:** Small, raised, solid b

In [19]:
!mlflow models serve -m /phoenix/mlflow/310480675827085248/f70dcb4963b149c98c37318afa93bfd4/artifacts/MEDGEMMA --env-manager conda --port 5004

  value = self.callback(ctx, self, value)
Downloading artifacts:  84% 21/25 [01:07<00:12,  3.23s/it]
^C
