# schema

> Shared helpers for Vertex AI text generation, LangChain embeddings, and Google Cloud Storage

In [1]:
#| default_exp schema

In [2]:
#| hide
from nbdev.showdoc import *

In [7]:
#| export
import os
import time
from datetime import datetime
from functools import lru_cache, wraps
from typing import Any, Callable, Dict

from ratelimit import limits, sleep_and_retry

import vertexai as vai
from vertexai.language_models import TextGenerationModel
from vertexai.language_models._language_models import MultiCandidateTextGenerationResponse
from google.cloud.aiplatform import BatchPredictionJob
from google.cloud import storage
from google.api_core.exceptions import ResourceExhausted

from langchain.llms import VertexAI
from langchain.embeddings import VertexAIEmbeddings

# Switch gRPC to its "native" DNS resolver at import time, before any client
# in this module is created (original note: "GRPC requires this"; the exact
# failure this avoids is not recorded here — TODO confirm).
os.environ.update(GRPC_DNS_RESOLVER="native")

# GCP project, storage bucket, and region used by the helpers in this module.
PROJECT_ID = "cdejam-gbsrc-ext-cah"
PROJECT_BUCKET = "pharma_email_classification"
REGION = "us-central1"

LangChain utilities

In [4]:
#| export
def get_embedder(model_name: str = 'textembedding-gecko') -> VertexAIEmbeddings:
    """Build a LangChain Vertex AI embeddings client for this project.

    :param model_name: Vertex AI embedding model to use. Defaults to
        ``'textembedding-gecko'`` (the previously hard-coded value), so
        existing callers are unaffected; pass another model id to use a
        different or pinned embedding model.
    :return: A ``VertexAIEmbeddings`` instance bound to ``PROJECT_ID`` and
        ``REGION``.
    """
    return VertexAIEmbeddings(
        project=PROJECT_ID,
        location=REGION,
        model_name=model_name,
    )

Google Cloud utilities (Vertex AI and Cloud Storage)

In [13]:
#| export
# Output prefix for experiment artifacts (presumably a GCS path prefix — not
# referenced elsewhere in this module; confirm against callers).
WRITE_PREFIX = "JDB_experiments"

# Default generation settings: unpacked as keyword arguments by ``predict`` and
# used as the default ``model_parameters`` of ``batch_predict``.
DEFAULT_PREDICT_PARAMS = dict(
    temperature=0.2,
    top_p=0.95,
    top_k=40,
)

# Flipped to True by ``init_vertexai`` after ``vertexai.init`` has run once.
_VERTEX_INITIATED = False

# Length of the rate-limit window, in seconds.
ONE_MINUTE = 60


def quota_handler(func: Callable):
    """Decorate ``func`` so GCP ``ResourceExhausted`` errors are retried.

    Each time the wrapped call raises
    ``google.api_core.exceptions.ResourceExhausted``, the calling thread
    sleeps until the start of the next wall-clock minute and then retries.
    Retries are unbounded: the wrapper returns only once ``func`` succeeds.
    Any other exception propagates unchanged.

    :param func: The callable to protect.
    :return: A wrapper carrying ``func``'s name and docstring (``functools.wraps``).
    """
    @wraps(func)
    def handle_quota(*args, **kwargs):
        while True:
            try:
                return func(*args, **kwargs)
            except ResourceExhausted:
                # Sleep out the remainder of the current minute (assumes a
                # per-minute quota window). Epoch time is UTC-aligned, so
                # this wakes on the next UTC minute boundary; the result is
                # in (0, ONE_MINUTE] seconds.
                time.sleep(ONE_MINUTE - time.time() % ONE_MINUTE)
    return handle_quota


def get_storage_client() -> storage.Client:
    """Create a Google Cloud Storage client scoped to ``PROJECT_ID``."""
    client = storage.Client(project=PROJECT_ID)
    return client


def init_vertexai(project_id: str = PROJECT_ID, region: str = REGION):
    """Initialise the Vertex AI SDK the first time this is called.

    Guarded by the module-level ``_VERTEX_INITIATED`` flag: only the first
    successful call reaches ``vertexai.init``. Later calls return immediately,
    even if they pass a different ``project_id`` or ``region``.

    :param project_id: GCP project passed to ``vertexai.init``.
    :param region: Location passed to ``vertexai.init``.
    """
    global _VERTEX_INITIATED
    if _VERTEX_INITIATED:
        return
    vai.init(project=project_id, location=region)
    # Only set after init succeeds, so a failed init is retried next call.
    _VERTEX_INITIATED = True


@lru_cache(maxsize=None)
def get_model() -> TextGenerationModel:
    """Return the ``text-bison`` model handle, loading it on first use.

    ``predict`` and ``batch_predict`` call this on every request. The handle
    is therefore cached, so ``TextGenerationModel.from_pretrained`` runs once
    per process instead of once per prediction. That call presumably performs
    a model lookup against Vertex AI; confirm for the SDK version in use.
    Failures are not cached, so a failed load is retried on the next call.
    Use ``get_model.cache_clear()`` to force a reload.
    """
    init_vertexai()
    return TextGenerationModel.from_pretrained("text-bison")


@sleep_and_retry
@limits(calls=50, period=ONE_MINUTE)
def predict(
        prompt: str,
        parameters: Dict[str, Any] = DEFAULT_PREDICT_PARAMS
        ) -> MultiCandidateTextGenerationResponse:
    """Send one prompt to ``text-bison`` and return the model's response.

    Rate-limited client-side to 50 calls per ``ONE_MINUTE``. When the limit
    is hit, ``sleep_and_retry`` blocks the caller until the window frees up
    instead of raising.

    :param prompt: The prompt text.
    :param parameters: Keyword arguments forwarded to
        ``TextGenerationModel.predict`` (e.g. ``temperature``, ``top_p``,
        ``top_k``). Values are numbers, not strings. The mapping is only
        unpacked, never mutated, so sharing ``DEFAULT_PREDICT_PARAMS`` as the
        default is safe.
    :return: The response object returned by ``TextGenerationModel.predict``.
    """
    model = get_model()
    return model.predict(
        prompt,
        **parameters)


def _to_camel_case(name: str) -> str:
    """Convert a snake_case parameter name to lowerCamelCase (``top_p`` -> ``topP``).

    Names without underscores (``temperature``, or already-camelCase ``topK``)
    are returned unchanged.
    """
    head, *rest = name.split("_")
    return head + "".join(part[:1].upper() + part[1:] for part in rest)


def batch_predict(
        source_uri: str,
        destination_uri_prefix: str,
        model_parameters: Dict[str, Any] = DEFAULT_PREDICT_PARAMS
) -> BatchPredictionJob:
    """
    Make a batch prediction request to text-bison.

    :param source_uri: Source file in GCS with prompted requests,
        I.E. 'gs://BUCKET_NAME/test_table.jsonl'. Wrapped in a one-element
        list before being handed to the SDK.
    :param destination_uri_prefix: Where the results will be written,
        ex: 'gs://BUCKET_NAME/tmp/2023-05-25-vertex-LLM-Batch-Prediction/result3'
    :param model_parameters: Generation parameters for the batch job, or
        ``None`` to send none. Defaults to ``DEFAULT_PREDICT_PARAMS``, whose
        keys are the snake_case keyword names of ``predict`` (``top_p``,
        ``top_k``). The batch API documents camelCase names (``topP``,
        ``topK``, ``maxOutputTokens``), so keys are converted to camelCase
        before sending. Keys already in camelCase pass through unchanged, and
        the caller's mapping is not modified.
    :return: The ``BatchPredictionJob`` created by the SDK.
    """
    if model_parameters is not None:
        model_parameters = {
            _to_camel_case(key): value for key, value in model_parameters.items()
        }
    model = get_model()
    return model.batch_predict(
        source_uri=[source_uri],
        destination_uri_prefix=destination_uri_prefix,
        model_parameters=model_parameters,
    )

In [14]:
#| hide
# Regenerate the library modules from this notebook's `#| export` cells.
import nbdev
nbdev.nbdev_export()