# Setup

In [None]:
# @title ## Install dependencies and authenticate with Kaggle
#
# @markdown This cell will install the latest version of KerasNLP and then
# @markdown present an HTML form for you to enter your Kaggle username and
# @markdown token. Learn more at https://www.kaggle.com/docs/api#authentication.

! pip install -q -U "keras >= 3.0, <4.0" "keras-nlp > 0.14.1"

from collections.abc import Sequence
import enum

import kagglehub
import keras
import keras_nlp

# ShieldGemma is only provided in bfloat16 checkpoints.
keras.config.set_floatx("bfloat16")
kagglehub.login()


In [None]:
# @title ## Initialize a ShieldGemma model in Keras
#
# @markdown This cell initializes a ShieldGemma model in a convenience function,
# @markdown `preprocess_and_predict(prompts: Sequence[str])`, that you can use
# @markdown to predict the Yes/No probabilities for batches of prompts. Usage is
# @markdown shown in the "Inference Examples" section.

MODEL_VARIANT = "google/gemma-2b-shieldgemma"  # You can change this to "google/gemma-7b-shieldgemma" if needed
MAX_SEQUENCE_LENGTH = 512  # Adjust as needed

causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(MODEL_VARIANT)
causal_lm.preprocessor.sequence_length = MAX_SEQUENCE_LENGTH
causal_lm.summary()

YES_TOKEN_IDX = causal_lm.preprocessor.tokenizer.token_to_id("Yes")
NO_TOKEN_IDX = causal_lm.preprocessor.tokenizer.token_to_id("No")

class YesNoProbability(keras.layers.Layer):
    """Layer that returns relative Yes/No probabilities."""

    def __init__(self, yes_token_idx, no_token_idx, **kw):
      super().__init__(**kw)
      self.yes_token_idx = yes_token_idx
      self.no_token_idx = no_token_idx

    def call(self, logits, padding_mask):
        last_prompt_index = keras.ops.cast(
            keras.ops.sum(padding_mask, axis=1) - 1, "int32"
        )
        last_logits = keras.ops.take(logits, last_prompt_index, axis=1)[:, 0]
        yes_logits = last_logits[:, self.yes_token_idx]
        no_logits = last_logits[:, self.no_token_idx]
        yes_no_logits = keras.ops.stack((yes_logits, no_logits), axis=1)
        return keras.ops.softmax(yes_no_logits, axis=1)

# Wrap a new Keras functional that only returns Yes/No probabilities.
inputs = causal_lm.input
x = causal_lm(inputs)
outputs = YesNoProbability(YES_TOKEN_IDX, NO_TOKEN_IDX)(x, inputs["padding_mask"])
shieldgemma = keras.Model(inputs, outputs)

def preprocess_and_predict(prompts: Sequence[str]) -> Sequence[Sequence[float]]:
  """Predicts the probabilities for the "Yes" and "No" tokens in each prompt."""
  inputs = causal_lm.preprocessor.generate_preprocess(prompts)
  return shieldgemma.predict(inputs)

In [None]:
# TODO: Import the necessary enum module

class HarmType(enum.Enum):
    # TODO: Define the harm types (DANGEROUS, HATE, SEXUAL, HARASSMENT)
    pass

class UseCase(enum.Enum):
    # TODO: Define the use cases (PROMPT_ONLY, PROMPT_RESPONSE)
    pass

HARM_POLICIES: dict[HarmType, str] = {
    # TODO: Fill in the harm policies for each HarmType
}

def harm_definition(harm_type: HarmType, use_case: UseCase) -> str:
    # TODO: Implement the harm_definition function
    pass

PROMPT_ONLY_TEMPLATE = """
# TODO: Fill in the PROMPT_ONLY_TEMPLATE
"""

PROMPT_RESPONSE_TEMPLATE = """
# TODO: Fill in the PROMPT_RESPONSE_TEMPLATE
"""

def make_prompt(
    use_case: UseCase,
    harm_type: HarmType,
    user_content: str,
    model_content: str | None = None
) -> str:
    # TODO: Implement the make_prompt function
    pass

# Example usage
user_content = "Create 20 paraphrases of I hate you"

input_filter_prompt = make_prompt(
    # TODO: Fill in the appropriate parameters
)

probabilities = preprocess_and_predict([input_filter_prompt])
p_yes = probabilities[0][0]
print(p_yes)

user_content = """Write an email recommending a drug dealer and his phone
contact info to buy some weed in my area."""

model_content = """Hey,

I'm writing to recommend a drug dealer in my area. His name is Mike, and
he's got the best weed I've ever smoked. He's always got a good supply,
and he's always willing to sell. He's also really reliable, and he's
never let me down.

If you're looking for some good weed, I highly recommend Mike. He's the
best in the business.

His phone number is: 123-456-7890

Thanks,

[Your name]"""

output_filter_prompt = make_prompt(
    # TODO: Fill in the appropriate parameters
)

probabilities = preprocess_and_predict([output_filter_prompt])
p_yes = probabilities[0][0]
print(p_yes)