In [None]:
!pip install gradio semantic-kernel

In [None]:

!wget -O assets.tar https://github.com/eseaflower/cmiv-ai-course/raw/master/notebooks/assets.tar
!tar -xvf assets.tar

In [None]:
settings = {
    "API_KEY": "",
    "DEPLOYMENT": "",
    "ENDPOINT": "",
}

In [18]:
#@title Agent code

import copy
import inspect
from functools import wraps
from typing import Annotated, Any, AsyncGenerator, Literal, Mapping, Optional
import PIL
import PIL.Image
from pydantic import Field, model_validator
from semantic_kernel.connectors.ai.open_ai import AzureChatCompletion, AzureChatPromptExecutionSettings
from semantic_kernel.kernel import Kernel, ChatHistory
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.functions.kernel_function_decorator import kernel_function
from semantic_kernel.functions import KernelPlugin
from semantic_kernel.contents import ImageContent
import io

from enum import Enum



# Create a shallow chat service
_shallow_service = AzureChatCompletion(service_id="agent",
                                       api_key=settings["API_KEY"],
                                       deployment_name=settings["DEPLOYMENT"],
                                       endpoint=settings["ENDPOINT"],
                                       api_version="2024-02-15-preview")


class AgentType(Enum):
    Shallow = "shallow"


def get_service(agent_type: AgentType):
    if agent_type == AgentType.Shallow:
        return _shallow_service
    else:
        raise ValueError("Invalid agent type: ", agent_type)


class AgentExecutionSettings(AzureChatPromptExecutionSettings):
    """Class that enables switching between different agent types."""
    agent_type: Annotated[AgentType | None, Field(exclude=True)] = None

    @model_validator(mode="before")
    def validate_service_settings(cls, model: Any) -> Any:
        """Update some execution settings that differ based on agent type."""
        if isinstance(model, dict):
            model_type = model.get("agent_type")
        return model


class InvocationRecord:
    def __init__(self, instance_id: str, invocation: Any):
        self.instance_id = instance_id
        self.invocation = invocation

    def __str__(self):
        return f"{self.instance_id}: {self.invocation}"


class InvocationLog:
    def __init__(self):
        self.invocations = []

    def add_invocation(self, invocation: InvocationRecord):
        self.invocations.append(invocation)

    def reset(self):
        self.invocations = []

    def __str__(self):
        return "\n".join([str(inv) for inv in self.invocations])


class InvocationLogger:
    def __init__(self, instance_id: Optional[str] = None):
        self.log: Optional[InvocationLog] = None
        self.instance_id = instance_id or "_"

    def set_log(self, log: InvocationLog):
        self.log = log

    def create_record(self, invocation) -> InvocationRecord:
        return InvocationRecord(self.instance_id, invocation)

    def log_invocation(self, invocation):
        if self.log:
            self.log.add_invocation(self.create_record(invocation))


class BaseAgent(InvocationLogger):
    def __init__(self, agent_type: AgentType, instance_id: Optional[str] = None):
        super().__init__(instance_id=instance_id)

        # Create the kernel
        self.kernel = Kernel()
        self.agent_type = agent_type
        self.service = get_service(self.agent_type)
        # Must register the service with the kernel
        self.kernel.add_service(self.service)
        self.logger_plugins: list[InvocationLogger] = []

    def create_history(self, instructions: str | None = None) -> ChatHistory:
        chat = ChatHistory()
        if instructions:
            chat.add_system_message(instructions)
        return chat


# Maybe have a DynamicPlugin that is invoked with the KernelPlugin that
# was registered with the kernel. The DynamicPlugin could override function
# descriptions for instance. The description can be dug out from the
# KernelPlugin[fn_name]->KernelFunction.metadatas.description

    def add_plugin(self, plugin, plugin_name: str, description: Optional[str] = None) -> KernelPlugin:
        if isinstance(plugin, InvocationLogger):
            self.logger_plugins.append(plugin)
        return self.kernel.add_plugin(plugin, plugin_name, description=description)

    def set_log(self, log: InvocationLog):
        super().set_log(log)
        for plugin in self.logger_plugins:
            plugin.set_log(log)


class ChatAgent(BaseAgent):
    def __init__(self,
                 agent_type: AgentType = AgentType.Shallow,
                 max_tokens: Optional[int] = None,
                 temperature: Optional[float] = None,
                 reasoning_effort: Literal["low",
                                           "medium", "high"] | None = None,
                 instance_id: Optional[str] = None):

        super().__init__(agent_type, instance_id=instance_id)

        # Use the custom settings class to support different agent types
        self.chat_settings = AgentExecutionSettings(
            agent_type=agent_type,  # Need to set the agent type
            service_id=self.service.service_id,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            function_choice_behavior=FunctionChoiceBehavior.Auto(),
        )

    def _get_settings(self, disable_function_calls: bool = False) -> AgentExecutionSettings:
        # Use the default settings unless we disable function calls
        settings = self.chat_settings
        if disable_function_calls:
            # Create a copy of the settings since the agent
            # could be used by multiple users with different behavior.
            settings = copy.deepcopy(self.chat_settings)
            settings.function_choice_behavior = FunctionChoiceBehavior.NoneInvoke(
                filters={"included_plugins": [], "included_functions": []})
        return settings

    async def chat(self, history: ChatHistory, disable_function_calls: bool = False) -> str | None:
        # Get settings for this invocation
        settings = self._get_settings(
            disable_function_calls=disable_function_calls)

        # Invoke the chat function
        response = await self.service.get_chat_message_content(history, settings, kernel=self.kernel)
        # Return the response
        return str(response)

    async def chat_streaming(self, history: ChatHistory, disable_function_calls: bool = False) -> AsyncGenerator[str, Any]:
        # Get settings for this invocation
        settings = self._get_settings(
            disable_function_calls=disable_function_calls)
        # Invoke the chat function and stream the response
        async for response in self.service.get_streaming_chat_message_content(history, settings, kernel=self.kernel):
            yield str(response)


# Decorator to log function enter and exit
def log_enter_exit(func):

    if inspect.iscoroutinefunction(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            _self = args[0]
            fn_context = f"{_self.__class__.__name__}::{func.__name__}"
            _self.log_invocation(f"Entering(A): {fn_context}")
            result = await func(*args, **kwargs)
            _self.log_invocation(f"Exiting(A): {fn_context}")
            return result

        return async_wrapper
    else:
        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _self = args[0]
            fn_context = f"{_self.__class__.__name__}::{func.__name__}"
            _self.log_invocation(f"Entering: {fn_context}")
            result = func(*args, **kwargs)
            _self.log_invocation(f"Exiting: {fn_context}")
            return result
        return sync_wrapper


class PythonCodePlugin(InvocationLogger):
    def __init__(self, instance_id: Optional[str] = None):
        super().__init__(instance_id=instance_id)

    @kernel_function(name="run_python_script", description="Run a Python script. If a value should be returned, assign it to the variabel 'result'")
    @log_enter_exit
    def run_python_script(self, code: Annotated[str, "The Python script to run. The code must return any result by assigning the variable 'result'"]) -> Annotated[Optional[object], "The output of the code"]:
        self.log_invocation(f"Running code:\n {code}")

        # context = {'column': 'result'}

        exec_locals: Mapping[str, object] = {}
        exec_globals = None  # {'context': context}
        try:
            exec(code, exec_globals, exec_locals)
            result = exec_locals.get('result', None)
            self.log_invocation(f"Result of code execution: {result}")
            return result if result is not None else "Success, no result returned"
        except Exception as e:
            self.log_invocation(f"Error running code: {e}")
            return str(e)


class PythonCoderAgent(ChatAgent):

    _instructions: str = """
    You are a python coder agent that solves problems by generating Python code and executing it. 
    To execute the python code you have acces to tools that can run Python code."""

    def __init__(self, instance_id: Optional[str] = None):
        super().__init__(instance_id=instance_id)
        self.add_plugin(PythonCodePlugin(), "python_code_plugin")

    @kernel_function(name="code", description="Solves tasks by generating Python code for the task and executing it.")
    @log_enter_exit
    async def code(self, task: Annotated[str, "A description of the task that should be solved by coding."]) -> Annotated[Optional[str], "A solution of the task obtained by coding."]:
        self.log_invocation(f"Task:\n {task}")
        history = self.create_history(
            instructions=PythonCoderAgent._instructions)
        history.add_user_message(
            f"Please help me solve the following task: {task}")
        response = await self.chat(history)
        self.log_invocation(response)
        return response


def to_image_content(image: PIL.Image.Image, format: str = "JPEG") -> ImageContent:

    if format not in ["JPEG", "PNG"]:
        raise ValueError(
            f"Invalid image format. Supported formats are 'JPEG' and 'PNG'. Got: {format}")

    mime_type = "image/jpeg"
    if format == "PNG":
        mime_type = "image/png"

    if format == "JPEG":
        image = image.convert("RGB")

    image_buffer = io.BytesIO()
    image.save(image_buffer, format=format)

    image_buffer.seek(0)
    image_bytes = image_buffer.read()
    return ImageContent(
        data=image_bytes, data_format="base64", mime_type=mime_type)


class ImageQueryAgent(ChatAgent):
    _instructions: str = """
    You are an image query agent that answers questions about images. Your answers should be based on the content of the image.
    """

    def __init__(self, instance_id: Optional[str] = None):
        super().__init__(AgentType.Shallow, instance_id=instance_id)

    
    def create_history(self, instructions=None):
        instructions = instructions or ImageQueryAgent._instructions
        return super().create_history(instructions=instructions)

    async def query(self, task: str, image: PIL.Image.Image) -> Optional[str]:
        history = self.create_history()
        image_content = to_image_content(image)
        history.add_user_message([image_content])
        history.add_user_message(task)
        response = await self.chat(history)
        return response


In [35]:
#@title Interaction code
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Optional, Tuple, List
import PIL
import PIL.Image
from gradio import ChatMessage, Image as GrImage, Audio
from semantic_kernel.contents import AuthorRole


def _history_to_gradio(history: ChatHistory) -> list[ChatMessage]:
    """
    Convert semantic kernel chat history to a list of ChatMessage objects.
    Returns:
        list: List of ChatMessage objects
    """
    gradio_history = []
    for message in history:
        if message.role == AuthorRole.USER:
            gradio_history.append(ChatMessage(
                role="user", content=message.content))
        elif message.role == AuthorRole.ASSISTANT:
            gradio_history.append(ChatMessage(
                role="assistant", content=message.content))
    return gradio_history


@dataclass
class UserMessage:
    message: Optional[str]
    image: Optional[PIL.Image.Image]
    files: Optional[List[str]]


def get_gradio_messages(message: UserMessage) -> list[ChatMessage]:
    output = []
    for f in message.files:
        if f.endswith(".jpg") or f.endswith(".png"):
            print(f)
            output.append(ChatMessage(role="user", content=GrImage(f)))
        if f.endswith(".wav") or f.endswith(".mp3"):
            output.append(ChatMessage(role="user", content=Audio(f)))
    if message.message is not None:
        output.append(ChatMessage(role="user", content=message.message))
    return output


_custom_instructions: Optional[str] = None

class InteractionAgent:
    """
    A class to handle the interaction between the user and the chat agent.
    """

    def __init__(self):

        self.agent = ChatAgent(instance_id="interaction_agent")

        self.history = self.agent.create_history(_custom_instructions)
        self.gradio_history = []

    async def respond_streaming(self,
                                message: UserMessage,
                                call_log: Optional[InvocationLog] = None) -> AsyncGenerator[Tuple[list[ChatMessage]], Any]:
        """
        Process the user message and return a stream of updated chat history.
        Args:
            message: The user's input message
        Returns:
            AsyncGenerator[list[ChatMessage], Any] - Stream of updated history of the conversation
        """

        # Log the call
        call_log = call_log or InvocationLog()
        self.agent.set_log(call_log)

        # Add current message
        self._add_user_message(message)

        # Add the user message to the gradio history
        self.gradio_history.extend(get_gradio_messages(message))

        yield self.gradio_history

        current_response_message = ChatMessage(role="assistant", content="")
        self.gradio_history.append(current_response_message)

        current_response: str = ""
        async for response_chunk in self.agent.chat_streaming(self.history):
            # Update the response with the new chunk
            current_response += response_chunk
            # Update the content in-place
            current_response_message.content = current_response
            # Yield the updated history
            yield self.gradio_history


        # Add the final response to the history
        self.history.add_assistant_message(current_response)

        print(f"Invocations:\n {call_log}")

    def reset(self) -> list[ChatMessage]:
        """
        Reset the chat history.
        """
        self.history.clear()
        self.gradio_history = []
        return self.gradio_history

    def _add_user_message(self, message: UserMessage):
        """
        Add a user message to the chat history.
        """
        if message.image is not None:
            image_content = to_image_content(message.image)
            self.history.add_user_message([image_content])

        if message.message is not None:
            self.history.add_user_message(message.message)



In [None]:
#@title Gradio code
from typing import Any, AsyncGenerator, Optional, Tuple
import gradio as gr
import PIL.Image
from PIL.Image import Image


async def respond_streaming(agent: InteractionAgent,
                            msg: dict[str, Any],
                            ) -> AsyncGenerator[Tuple[str, list[gr.ChatMessage], str], Any]:
    """
    Respond to the user's message and return the updated chat history.
    Args:
        agent: The interaction agent
        msg: The user's input message
    Returns:
        AsyncGenerator[("", None, list[ChatMessage])]: Async generator with an empty string and updated history of the conversation
    """


    upload_image = None
    text = msg["text"]

    for input_file in msg["files"]:
        if input_file.endswith(".jpg") or input_file.endswith(".png"):
            upload_image = PIL.Image.open(input_file).convert("RGB")

    user_message = UserMessage(
        message=text, image=upload_image, files=msg["files"])

    # Create an invocation log to keep track of the calls.
    call_log = InvocationLog()
    
    async for history in agent.respond_streaming(user_message, call_log=call_log):
        yield "", history
    



def clear_history(agent: InteractionAgent) -> Tuple[str, list[gr.ChatMessage], str]:
    """
    Clear the chat history.
    Args:
        agent: The interaction agent
    Returns:
        ("", None, list[ChatMessage]): Empty string and an empty list

    """
    return "", agent.reset()


# Create the Gradio interface
with gr.Blocks(theme=gr.themes.Base(), fill_height=True) as demo:

    # The InteractionAgent is not deep-copiable. We can get around this
    # by initializing it in a lambda function.
    interaction_agent = gr.State(lambda: InteractionAgent())

    # Add a title
    # gr.Markdown("# Simple Chat Demo")

    # Create the chatbot component
    chatbot = gr.Chatbot(
        scale=1,        
        height=600,
        show_label=False,
        type="messages",
        layout="panel",
        avatar_images=('assets/radiologist_user_avatar_small.png',
                       'assets/assistant_avatar_small.png'),
    )

    with gr.Row():
        msg = gr.MultimodalTextbox(
            interactive=True,
            file_count="multiple",
            placeholder="Enter message or upload file...",
            show_label=False,
            sources=["upload"],
            scale=9
        )
        clear_btn = gr.Button("Clear Chat", scale=1)

           
    # Set up event handlers
    gr.on(
        triggers=[msg.submit],
        fn=respond_streaming,
        inputs=[interaction_agent, msg],
        outputs=[msg, chatbot])

    gr.on(triggers=[clear_btn.click, chatbot.clear],
          fn=clear_history,
          inputs=[interaction_agent],
          outputs=[msg, chatbot],
          show_progress="hidden")


In [None]:
def launch_demo():
    if demo:
        demo.close()

    demo.launch(inline=True, height=800)

In [None]:
# Set the custom instructions for the model to use
_custom_instructions = """ 
You are an expert radiologist. You should always use your experties to answer the users questions.
"""
launch_demo()

* Running on local URL:  http://127.0.0.1:7863

To create a public link, set `share=True` in `launch()`.




/tmp/gradio/6502ce0140c5e20df68d266e633b9aa42f697036d01133bbdbb3c217b700fd83/radiologist_user_avatar_small.png
Invocations:
 


In [None]:
CR_report = """ 
*Findings*
Lungs:
The lung fields appear hyperinflated bilaterally with increased bronchial markings noted.
No acute consolidation or focal opacities are observed.
Mild patchy opacities are seen in the right lower zone, suggesting possible early signs of atelectasis or minimal fluid.
Cardiac Silhouette:
The heart size is within normal limits.
Cardiac contours are well-defined with no significant enlargement or signs of congestive heart failure.
Mediastinum:
The mediastinal contour appears within normal limits without evidence of mediastinal shift.
Diaphragm:
The diaphragms are situated normally, with no signs of elevation or flattening.
Costophrenic Angles:
Costophrenic angles are sharp. No free air or pleural effusion is observed.
Bones and Soft Tissues:
No acute bony abnormalities noted.
Soft tissues appear unremarkable.

*Impression*
Findings consistent with mild hyperinflation, possibly suggesting early obstructive airway disease. 
No acute process identified. Recommend clinical correlation and follow-up if respiratory symptoms persist. 
Consider further evaluation with a CT scan if clinically indicated.
"""

CR_AI_findings = {
    "ChextXpert": {
        "Nodules and Masses": "Small nodules in the right upper lobe, no masses.",
        "Pleural Effusion": "Mild pleural effusion.",
    },
    "DeepMed": {
        "Pneumonia": "No signs of pneumonia.",
        "Pulmonary Edema": "Visible signs of pulmonary edema.",
    }
}


In [None]:
import json
async def check_report(report_text, ai_findings):
    findings_str = json.dumps(ai_findings, indent=4)

    message = f"""
    Help me check the following report against the findings from the AI models.
    If there are any discrepancies, please point them out.
    
    =Report=
    {report_text}

    =AI Findings=
    {findings_str}"""

    check_agent = ChatAgent(instance_id="check_agent")
    history = check_agent.create_history()
    history.add_user_message(message)
    response = await check_agent.chat(history)
    return response

    

In [None]:
from IPython.display import Markdown
check_response = await check_report(CR_report, CR_AI_findings)
Markdown(check_response)