# Michelin Restaurant Dataset: Image Generation (Shutterstock)

Text-to-img models:
- Shutterstock: the endpoint is already up and running on Databricks, but it does not support the `resolution` parameter, so the output resolution cannot be lowered
- stable-diffusion-2-1: available on [Huggingface](https://huggingface.co/stabilityai/stable-diffusion-2-1) and it is mentioned in this [doc](https://docs.google.com/document/d/1GFZzo8paONRC9YYM-nwu1z1DXEqyOqh5ziiSe9kS06Q/edit)

In [0]:
import pandas as pd
import os
from pyspark import SparkFiles
from pyspark.sql.functions import *

import requests
import base64
import io
from PIL import Image
from IPython.display import display

### Catalog and Schema Setup

In [0]:
# Read the target catalog and schema from the environment and make them the
# session defaults, so unqualified table names resolve against them.
catalog_ = os.getenv('CATALOG_NAME')
schema_ = os.getenv('SCHEMA_NAME')

# os.getenv returns None for an unset variable; without this check the failure
# would surface below as an opaque TypeError on string concatenation.
_missing_env = [
    env_name
    for env_name, env_value in (('CATALOG_NAME', catalog_), ('SCHEMA_NAME', schema_))
    if not env_value
]
if _missing_env:
    raise EnvironmentError(
        "Missing required environment variable(s): " + ", ".join(_missing_env)
    )

spark.sql("USE CATALOG " + catalog_)
spark.sql("USE SCHEMA " + schema_)

### Check if secret scope is set up to reach endpoints

In [0]:
# Secret scope that is checked below and read when calling the endpoint
scope_name_ = "michelin_scope"
# Key, inside that scope, of the secret sent as the bearer token
secret_name_ = "pat_ga"

In [0]:
# Check scope
existing_scopes = [scope.name for scope in dbutils.secrets.listScopes()]
scope_exists = scope_name_ in existing_scopes
if scope_exists:
    print("Secret scope exists!")
else:
    print("Secret scope doesn't exist, create it via CLI!")

# Check secret
# Only list the secrets when the scope was found: asking for the secrets of a
# scope that does not exist fails, which would abort this diagnostic cell
# before the hint below is printed.
existing_secrets = (
    [secret.key for secret in dbutils.secrets.list(scope_name_)]
    if scope_exists
    else []
)
if secret_name_ in existing_secrets:
    print("Secret exists!")
else:
    print("Secret doesn't exist, create it via CLI!")

### Read Silver Data

In [0]:
# Load every column of the silver table; the unqualified name resolves against
# the catalog/schema currently in use for the session.
silver_df = spark.table("silver_data")
display(silver_df)

### Testing Image Generation for 1 record

In [0]:
# Invocation URL of the Shutterstock image-generation serving endpoint
endpoint_url_ = (
    "https://e2-demo-field-eng.cloud.databricks.com"
    "/serving-endpoints/databricks-shutterstock-imageai/invocations"
)

In [0]:
# Select one fixed restaurant (by Res_ID) from the silver table to try the
# endpoint on a single record.
test_df = (
    spark.read
    .table('.'.join([catalog_, schema_, 'silver_data']))
    .select('Res_ID', 'Name', 'Description')
    .where("Res_ID = 'res-id-100092'")
)

display(test_df)

In [0]:
# Function to generate a text-to-image request for a single record
def generateImg(row_, endpoint_url = endpoint_url_, timeout = 120):
  """Request a generated image for one record.

  Args:
    row_: Record supporting key access for 'Res_ID' and 'Description'
      (e.g. a collected Spark Row). The description is sent as the prompt.
    endpoint_url: URL the prompt is POSTed to.
    timeout: Seconds to wait for the endpoint before giving up. Without a
      timeout, `requests` can block indefinitely on an unresponsive endpoint.

  Returns:
    dict with keys 'Res_ID' (the record id) and 'Img' (the 'b64_json' string
    from the response, or None when the request failed or returned a non-200
    status).
  """
  ## Extract info from the row
  id_ = row_['Res_ID']
  text_pyload = {
    "prompt": row_['Description']
  }

  ## Reach endpoint and process response
  headers = {
    "Authorization": "Bearer " + dbutils.secrets.get(scope=scope_name_, key=secret_name_)
  }

  image_data = None
  try:
    response = requests.post(endpoint_url, json=text_pyload, headers=headers, timeout=timeout)
  except requests.exceptions.RequestException as err:
    # Connection errors and timeouts are reported the same way as a bad status:
    # a message is printed and 'Img' stays None.
    print("%s :: Failed! Request error: %s" % (id_, err))
  else:
    if response.status_code == 200:
      image_data = response.json()['data'][0]['b64_json']
      print("%s :: Successful!" % id_)
    else:
      print("%s :: Failed! Status code: %s" % (id_, response.status_code))

  return {
    'Res_ID': id_,
    'Img': image_data
  }

In [0]:
# Call endpoint
# Collect the test record once: each collect() triggers a separate Spark job.
test_rows = test_df.limit(1).collect()
if not test_rows:
  raise ValueError("test_df is empty: no record matches the test Res_ID filter")
test_row = test_rows[0]

img_response = generateImg(test_row)
print(test_row)

# Extract output
if img_response['Img'] is not None:
  ## This string is what will be stored in the dataset. Only a short preview is
  ## printed, because the full base64 payload would flood the cell output.
  img_b64 = img_response['Img']
  print("base64 image string (%d chars): %s..." % (len(img_b64), img_b64[:60]))
  # Display img
  image = io.BytesIO(base64.b64decode(img_b64))
  decoded_img = Image.open(image)
  resized_img = decoded_img.resize((300, 300))
  display(resized_img)
else:
  print("%s :: no image returned, nothing to display" % img_response['Res_ID'])

In [0]:
## Write the img string into file
directory = "/Volumes/"+catalog_+"/"+schema_+"/init"
file_path = directory + "/" + img_response['Res_ID']

if img_response['Img'] is None:
  # Generation failed for this record: opening the file would create or
  # truncate it, and file.write(None) would then raise a TypeError.
  print("%s :: no image to write, skipping %s" % (img_response['Res_ID'], file_path))
else:
  with open(file_path, 'w') as file: # write img as string
    file.write(img_response['Img'])

### Applying img generation at scale with Pandas UDFs on Spark

Notes:
- `dbutils` cannot be used inside a Spark UDF (source [here](https://docs.databricks.com/en/dev-tools/databricks-utils.html#databricks-utilities)). Therefore, we read the PAT into a global variable outside the function and reference that variable inside it.

In [0]:
## Save PAT as global variable
PAT_ = dbutils.secrets.get(scope=scope_name_, key=secret_name_)

## Create the base function
## NOTE: this redefines generateImg with a text-based signature; the row-based
## version defined earlier in the notebook is shadowed from this cell onwards.
def generateImg(text_, endpoint_url = endpoint_url_, timeout = 120):
  """Request a generated image for one prompt text.

  Args:
    text_: Prompt text sent to the endpoint.
    endpoint_url: URL the prompt is POSTed to.
    timeout: Seconds to wait for the endpoint before giving up. Without a
      timeout, `requests` can block indefinitely on an unresponsive endpoint.

  Returns:
    The 'b64_json' string from the response, or None when the text is missing
    or blank, the request fails, or the endpoint returns a non-200 status.
  """
  ## Skip missing/blank prompts instead of spending an endpoint call on them
  if not isinstance(text_, str) or not text_.strip():
    return None

  ## Convert text into prompt format
  text_pyload = {
    "prompt": text_
    #"resolution": "256x256" ## Lowering img resolution (default 1024x1024) -- not supported by the endpoint
  }
  ## Reach endpoint
  headers = {
    "Authorization": "Bearer " + PAT_
  }
  try:
    response = requests.post(endpoint_url, json=text_pyload, headers=headers, timeout=timeout)
  except requests.exceptions.RequestException as err:
    # When this runs inside a UDF, an uncaught exception for one value would
    # fail the whole task; report it and return None like any other failure.
    print("Failed! Request error: %s" % err)
    return None

  ## Process response
  if response.status_code != 200:
    print("Failed! Status code: %s" % (response.status_code))
    return None
  return response.json()['data'][0]['b64_json']

In [0]:
## Encapsulate function into a PandasUDF
@pandas_udf(returnType="string")
def save_Imgs(res_descr_: pd.Series) -> pd.Series:
  """Apply generateImg to every value of the incoming Series.

  Args:
    res_descr_: Series of texts, each passed to generateImg as its first argument.

  Returns:
    Series holding the generateImg result for each input value.
  """
  generated_imgs = res_descr_.apply(generateImg)
  return generated_imgs

In [0]:
%sql
-- Drop any previous version of the table, then recreate it empty
-- with one string column per stored field
DROP TABLE IF EXISTS rest_descr_img_genai;
CREATE TABLE IF NOT EXISTS rest_descr_img_genai (
  Res_ID    STRING,
  `Name`    STRING,
  GenAI_Img STRING
);

In [0]:
## Generating Img for all entries requires lot of compute and memory -- Generate only for some records
## Keep records with Stars_score >= 1 located in USA or Italy, generate one image per
## description, and overwrite the target table with (Res_ID, Name, GenAI_Img).
(
  spark.table('.'.join([catalog_, schema_, 'silver_data']))
  .filter((col("Stars_score") >= 1) & col("Country").isin('USA', 'Italy'))
  .select('Res_ID', 'Name', save_Imgs('Description').alias('GenAI_Img'))
  .write
  .mode('overwrite')
  .saveAsTable('rest_descr_img_genai')
)