In [35]:
import dotenv
import os
from google import genai
from enum import Enum

dotenv.load_dotenv()

class GeminiModel(Enum):
    """Model ids of the Gemini variants used in this notebook.

    Every member's value is the model id string, and ``str(member)`` yields
    that same string. The note above each member keeps the release date and
    headline capability recorded by the original author.
    """

    # Jan 21 2025 - Reasoning & thinking capabilities
    GEMINI_2_FLASH_THINKING = "gemini-2.0-flash-thinking-exp-01-21"
    # Dec 11 2024 - Next-gen features, speed, tools, multimodal
    GEMINI_2_FLASH = "gemini-2.0-flash-exp"
    # Dec 6 2024 - Quality improvements
    GEMINI = "gemini-exp-1206"
    # Nov 19 2024 - Audio, images, video, text input
    LEARNLM = "learnlm-1.5-pro-experimental"

    def __str__(self):
        """Return the member's value, i.e. the model id string."""
        model_id = self.value
        return model_id

def get_gemini_client(**kwargs):
    """Build a ``genai.Client`` whose API key defaults to the environment.

    The key is read from the ``GEMINI_API_KEY`` environment variable
    (``None`` when the variable is unset). An ``api_key`` passed by the
    caller takes precedence over the environment value rather than colliding
    with it and raising ``TypeError``.

    Args:
        **kwargs: Forwarded to ``genai.Client`` (for example
            ``http_options``).

    Returns:
        The constructed ``genai.Client``.
    """
    # setdefault keeps a caller-supplied api_key and only fills the gap.
    kwargs.setdefault("api_key", os.getenv("GEMINI_API_KEY"))
    return genai.Client(**kwargs)

class Gemini:
    """Small convenience wrapper that pairs a Gemini client with one model.

    Args:
        model: The model to use for every request, either a ``GeminiModel``
            member or a plain model id string. It is stored unchanged on
            ``self.model``.

    Attributes:
        client: Client obtained from ``get_gemini_client()``.
        model: The model exactly as passed to the constructor.
    """

    def __init__(self, model: "GeminiModel | str"):
        self.client = get_gemini_client()
        self.model = model

    def _model_id(self) -> str:
        """Return the model id as a plain string.

        Enum members are reduced to their value so the request receives a
        string rather than the Enum object; anything else goes through
        ``str``.
        """
        model = self.model
        if isinstance(model, Enum):
            return str(model.value)
        return str(model)

    def generate_content(self, contents: str):
        """Send ``contents`` to the configured model and return the client's response."""
        return self.client.models.generate_content(
            model=self._model_id(), contents=contents
        )

    def generate_content_stream(self, contents: str):
        """Send ``contents`` to the configured model and return the client's streaming result."""
        return self.client.models.generate_content_stream(
            model=self._model_id(), contents=contents
        )

In [18]:

# Wrapper instance bound to the GEMINI_2_FLASH model.
g = Gemini(model=GeminiModel.GEMINI_2_FLASH)

In [19]:
# Stream the answer; the '|' printed after each part shows where the
# streamed pieces begin and end.
for chunk in g.generate_content_stream('Explain how RLHF works in simple terms.'):
    # Candidates, content, parts and text are all optional on a streamed
    # chunk, so skip anything that carries no text instead of crashing or
    # printing "None".
    if not chunk.candidates:
        continue
    content = chunk.candidates[0].content
    if content is None or not content.parts:
        continue
    for part in content.parts:
        if part.text is not None:
            print(part.text, end = '|')

Okay|, let's break down Reinforcement Learning from Human Feedback (RLHF)| in simple terms. Imagine you're training a puppy to do a trick,| like sitting.

**Here's how it usually goes with basic training:**

1. **You give a command ("Sit!").** This is like| feeding an AI model with a prompt.
2. **The puppy does... something.** It might wiggle, jump, or maybe even sit by accident. This| is like the AI model generating a response.
3. **You provide feedback.** If it sits, you give a treat and praise ("Good boy!"). If it does anything else, you might say "No" or nothing at all|. This is like giving a signal of good or bad.
4. **The puppy learns over time.** With enough practice, it starts to understand what "Sit!" means and how to get the treat.

**RLHF is kind| of similar, but instead of a puppy, we have a large language model (like the ones powering chatbots) and instead of a single command, we're training it to generate human-like responses:**

**Here's how RLHF works in 3 main stages:*

## Search as a tool

In [28]:
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch

# Reuse the client already held by the `Gemini` wrapper.
client = g.client
# Take the id from the shared enum instead of repeating the literal string.
model_id = str(GeminiModel.GEMINI_2_FLASH)

# Offer Google Search to the model as a tool.
google_search_tool = Tool(google_search=GoogleSearch())

response = client.models.generate_content(
    model=model_id,
    contents="When is the next total solar eclipse in the United States?",
    config=GenerateContentConfig(
        tools=[google_search_tool],
        response_modalities=["TEXT"],
    ),
)

# `parts` may be None and a part need not carry text; print only real text.
for part in response.candidates[0].content.parts or []:
    if part.text is not None:
        print(part.text)
# Example response:
# The next total solar eclipse visible in the contiguous United States will be on ...


The next total solar eclipse visible in the contiguous United States will occur on **August 22, 2044.** However, a total solar eclipse will be visible in Alaska on March 30, 2033.

It is worth noting that a total solar eclipse will occur on April 8, 2024. This eclipse will cross North America, passing over Mexico, the United States, and Canada. The path of totality for this eclipse will cross over 27 park units as it makes its way across Texas, Arkansas, Missouri, Kentucky, Illinois, Indiana, Ohio, Pennsylvania, New York, Vermont, New Hampshire, and Maine.

Additionally, another total solar eclipse will occur on August 12, 2045. This eclipse will be a cross-country eclipse more impressive than those in 2017 and 2024, with areas from California to Florida being plunged into darkness.



In [29]:
# To get grounding metadata as web content.
# print(response.candidates[0].grounding_metadata)

from rich import print as rprint
from rich.panel import Panel
from rich.text import Text
from IPython.display import HTML, display

def print_grounding_chunks(grounding_metadata):
    """Pretty-print the grounding metadata of a response candidate.

    Renders, in this order and only when present: the web search queries,
    every web grounding chunk (title and URI), every grounding support
    (segment text, first confidence score, chunk indices) and finally the
    search entry point's rendered HTML.

    Model- and web-supplied strings are wrapped in ``Text`` so that square
    brackets in them are shown literally instead of being parsed as Rich
    markup.

    Args:
        grounding_metadata: The ``grounding_metadata`` of a candidate. A
            falsy value prints a short notice and returns.

    Returns:
        None. All output goes to the notebook cell.
    """
    if not grounding_metadata:
        rprint("[yellow]No grounding metadata available[/yellow]")
        return

    # Web search queries, one per line.
    if grounding_metadata.web_search_queries:
        rprint(Panel(
            Text("\n".join(grounding_metadata.web_search_queries)),
            title="Search Queries",
            border_style="green"
        ))

    # Grounding chunks: numbering counts every chunk, but only web chunks
    # are rendered.
    for i, chunk in enumerate(grounding_metadata.grounding_chunks or [], 1):
        if not chunk.web:
            continue
        rprint(Panel(
            Text(f"Title: {chunk.web.title}\nURI: {chunk.web.uri}"),
            title=f"Grounding Chunk {i}",
            border_style="blue"
        ))

    # Grounding supports: segment and confidence scores may be absent, so
    # fall back to placeholders rather than raising.
    for i, support in enumerate(grounding_metadata.grounding_supports or [], 1):
        segment_text = support.segment.text if support.segment else None
        scores = support.confidence_scores or []
        confidence = f"{scores[0]:.2f}" if scores else "n/a"
        panel_content = [
            f"Text: {segment_text}",
            f"Confidence: {confidence}",
            f"Chunk indices: {support.grounding_chunk_indices}"
        ]
        rprint(Panel(
            Text("\n".join(panel_content)),
            title=f"Support {i}",
            border_style="magenta"
        ))

    # The attribute can exist and still be None, so test the value itself
    # before reading `rendered_content`.
    entry_point = getattr(grounding_metadata, 'search_entry_point', None)
    if entry_point is not None and entry_point.rendered_content:
        rprint("[blue]Search Results HTML Preview:[/blue]")
        display(HTML(entry_point.rendered_content))

# Try the helper on the first candidate of the response obtained above.
print_grounding_chunks(grounding_metadata=response.candidates[0].grounding_metadata)


## Improved tool use + realtime

In [36]:
config={
    "generation_config": {"response_modalities": ["TEXT"]}}

MODEL = str(GeminiModel.GEMINI_2_FLASH)

client = get_gemini_client(http_options= {'api_version': 'v1alpha'})

async with client.aio.live.connect(model=MODEL, config=config) as session:
  message = "Hello? Gemini are you there?"
  print("> ", message, "\n")
  await session.send(input=message, end_of_turn=True)

  # For text responses, When the model's turn is complete it breaks out of the loop.
  turn = session.receive()
  async for chunk in turn:
    if chunk.text is not None:
      print(f'- {chunk.text}')

>  Hello? Gemini are you there? 

- Yes, I'
- m here! How can I help you today?



In [38]:
import contextlib
import wave

@contextlib.contextmanager
def wave_file(filename, channels=1, rate=24000, sample_width=2):
    """Context manager yielding a WAV writer opened on ``filename``.

    The writer is configured with ``channels`` channels, a sample width of
    ``sample_width`` bytes and a frame rate of ``rate`` before it is handed
    to the caller, and it is closed when the ``with`` block exits.
    """
    writer = wave.open(filename, "wb")
    try:
        writer.setframerate(rate)
        writer.setsampwidth(sample_width)
        writer.setnchannels(channels)
        yield writer
    finally:
        writer.close()

# Switch the live-session config to audio responses.
config = {
    "generation_config": {
        "response_modalities": ["AUDIO"],
    },
}



async def async_enumerate(it):
    """Async version of ``enumerate``: yield ``(index, item)`` pairs, counting from 0."""
    index = -1
    async for element in it:
        index += 1
        yield index, element

async with client.aio.live.connect(model=MODEL, config=config) as session:
  file_name = 'audio.wav'
  with wave_file(file_name) as wav:
    message = "Hello? Gemini are you there?"
    print("> ", message, "\n")
    await session.send(input=message, end_of_turn=True)

    turn = session.receive()
    async for n,response in async_enumerate(turn):
      if response.data is not None:
        wav.writeframes(response.data)

        if n==0:
          print(response.server_content.model_turn.parts[0].inline_data.mime_type)
        print('.', end='')


from IPython.display import display, Audio
display(Audio(file_name, autoplay=True))

>  Hello? Gemini are you there? 

audio/pcm;rate=24000
.............

In [39]:

import logging

# Named logger for the live-session code; its threshold is INFO.
logger = logging.getLogger('Live')
logger.setLevel(logging.INFO)

In [40]:
import asyncio
class AudioLoop:
    """Turn-based audio conversation over a live session.

    Each outgoing text message is followed by receiving the model's reply,
    which is written to its own ``audio_<index>.wav`` file and then shown as
    an autoplaying audio widget.

    Args:
        turns: Optional iterable of messages to send. When falsy, messages
            are read interactively with ``input`` until the user enters
            ``q`` (case-insensitive).
        config: Optional connect config. Defaults to requesting the
            ``AUDIO`` response modality.

    Relies on the module-level names ``client``, ``MODEL``, ``logger``,
    ``wave_file``, ``display`` and ``Audio`` defined in earlier cells.
    """

    def __init__(self, turns=None, config=None):
        self.session = None  # live session, set only while `run` is connected
        self.index = 0       # number of reply files started so far
        self.turns = turns
        if config is None:
            config = {
                "generation_config": {
                    "response_modalities": ["AUDIO"]}}
        self.config = config

    async def run(self):
        """Connect, then alternate sending one message and receiving one reply."""
        logger.debug('connect')
        async with client.aio.live.connect(model=MODEL, config=self.config) as session:
            self.session = session
            try:
                async for _ in self.send():
                    # Ideally send and recv would be separate tasks.
                    await self.recv()
            finally:
                # The connection is closing; do not keep a stale handle.
                self.session = None

    async def _iter(self):
        """Yield the messages to send: the scripted turns, or interactive input."""
        if self.turns:
            for text in self.turns:
                print("message >", text)
                yield text
        else:
            print("Type 'q' to quit")
            while True:
                # Run the blocking `input` call in a worker thread.
                text = await asyncio.to_thread(input, "message > ")

                # If the input returns 'q' quit.
                if text.lower() == 'q':
                    break

                yield text

    async def send(self):
        """Send each message as a complete turn and yield it once sent."""
        async for text in self._iter():
            logger.debug('send')

            # Send the message to the model.
            await self.session.send(input=text, end_of_turn=True)
            logger.debug('sent')
            yield text

    async def recv(self):
        """Write one reply's audio to a new ``.wav`` file, then display it."""
        # Start a new `.wav` file.
        file_name = f"audio_{self.index}.wav"
        with wave_file(file_name) as wav:
            self.index += 1

            logger.debug('receive')

            mime_type_shown = False
            # Read chunks from the socket.
            async for response in self.session.receive():
                # Lazy %-formatting: the message is only stringified when
                # DEBUG is actually enabled.
                logger.debug('got chunk: %s', response)

                if response.data is None:
                    logger.debug('Unhandled server message! - %s', response)
                    continue

                wav.writeframes(response.data)
                # Report the mime type on the first chunk that carries audio.
                if not mime_type_shown:
                    print(response.server_content.model_turn.parts[0].inline_data.mime_type)
                    mime_type_shown = True
                print('.', end='')

            print('\n')

        display(Audio(file_name, autoplay=True))
        # Fixed 2 s pause before the next turn; presumably to give playback
        # a head start. TODO confirm the intent.
        await asyncio.sleep(2)

In [42]:
await AudioLoop(["Respond in a woman's voice. Hello", "What's your name?"]).run()

message > Respond in a woman's voice. Hello
audio/pcm;rate=24000
...........



message > What's your name?
audio/pcm;rate=24000
.................

