# Imports

In [4]:
# Setup cell: silence warnings, load environment variables, and import Snowpark.
# Blanket warning filter, presumably to keep library warnings out of the notebook output.
# NOTE(review): this hides *all* warnings (including deprecations that may matter);
# consider narrowing with `category=` / `module=` once the noisy source is known.
import warnings
warnings.filterwarnings("ignore")

# Load variables
# Reads key=value pairs from a .env file into os.environ; the connection cell below
# reads the SF_* settings back via os.getenv. Variables already present in the
# environment are not overridden (python-dotenv default).
import os
from dotenv import load_dotenv
load_dotenv()

# Snowpark Imports
# F / T are the conventional aliases for Snowpark column functions and data types.
from snowflake.snowpark.session import Session
import snowflake.snowpark.functions as F
from snowflake.snowpark import types as T

# Connect to Snowflake

In [5]:
# Snowpark connection options, read from SF_* environment variables
# (populated from .env by load_dotenv() in the imports cell).
# NOTE(review): the printed role below is ACCOUNTADMIN; prefer a least-privilege
# role for ML workloads.
_SF_ENV_VARS = {
    "ACCOUNT": "SF_ACCOUNT",
    "USER": "SF_USER",
    "ROLE": "SF_ROLE",
    "PASSWORD": "SF_PASSWORD",
    "DATABASE": "SF_DATABASE",
    "SCHEMA": "SF_SCHEMA",
    "WAREHOUSE": "SF_WAREHOUSE",
}

# Password authentication cannot succeed without these, so fail fast with a clear
# message instead of handing None to the connector and getting an opaque error.
# The remaining options are passed through as-is (None means "use the default").
_REQUIRED_ENV_VARS = ("SF_ACCOUNT", "SF_USER", "SF_PASSWORD")
_missing_env_vars = [name for name in _REQUIRED_ENV_VARS if not os.getenv(name)]
if _missing_env_vars:
    raise RuntimeError(
        "Missing required Snowflake settings: "
        + ", ".join(_missing_env_vars)
        + " (set them in the environment or in .env)"
    )

snowflake_connection_cfg = {key: os.getenv(env_name) for key, env_name in _SF_ENV_VARS.items()}

# Creating Snowpark Session
session = Session.builder.configs(snowflake_connection_cfg).create()

print('Role:     ', session.get_current_role())
print('Warehouse:', session.get_current_warehouse())
print('Database: ', session.get_current_database())
print('Schema:   ', session.get_current_schema())

Role:      "ACCOUNTADMIN"
Warehouse: "COMPUTE_WH"
Database:  "MACHINE_LEARNING"
Schema:    "PUBLIC"


In [29]:
# Create some test data to work with: sentence prefixes for the model to complete.
# The prompts deliberately have NO trailing space. GPT-2's byte-level BPE folds a
# leading space into the following word's token (e.g. " Paris"), so a prompt that
# ends in a bare space ends on a rare standalone-space token, which tends to derail
# the generated continuation.
incomplete_sentences = [
    "If I could travel anywhere in the world, I would go to",
    "One of the most important lessons I've learned in life is",
    "When I think about the future of technology, I wonder",
    "The best advice I ever received was to always",
    "Every morning, I start my day by",
]

# Single-column Snowpark DataFrame (INPUTS), one row per prompt; show every row untruncated.
df = session.create_dataframe(incomplete_sentences, schema=['INPUTS'])
df.show(n=len(incomplete_sentences), max_width=1000)

--------------------------------------------------------------
|"INPUTS"                                                    |
--------------------------------------------------------------
|If I could travel anywhere in the world, I would go to      |
|One of the most important lessons I've learned in life is   |
|When I think about the future of technology, I wonder       |
|The best advice I ever received was to always               |
|Every morning, I start my day by                            |
--------------------------------------------------------------



# Register & Run Text Generation Model

In [7]:
# Model Registry handle, scoped explicitly to the session's current database and
# schema so logged models are stored there.
from snowflake.ml.registry import Registry

registry_database = session.get_current_database()
registry_schema = session.get_current_schema()
reg = Registry(session=session, database_name=registry_database, schema_name=registry_schema)

In [37]:
# Load a small text-generation model from Hugging Face.
# It has to fit on a regular (CPU-only) Snowflake warehouse; models that need GPUs
# must be deployed to Snowpark Container Services instead.
from transformers import pipeline

pipe = pipeline(task="text-generation", model="distilbert/distilgpt2")

# Log the pipeline to the Model Registry, using a slice of the prompts as sample
# input and listing the packages the warehouse needs to execute it.
# NOTE(review): no version_name is passed, so re-running this cell presumably
# registers another version each time — pin one if idempotent re-runs matter.
snow_model = reg.log_model(
    pipe,
    model_name="distilgpt2",
    sample_input_data=df.limit(10),
    conda_dependencies=["tokenizers", "transformers", "sentencepiece"],
)

# Run the model inside Snowflake on every prompt and materialize the result, so the
# inference query is not re-executed each time `results` is used.
# run() is called without function_name; presumably the pipeline exposes a single
# method (not `predict`) — confirm with snow_model.show_functions().
results = snow_model.run(df).cache_result()

# OUTPUTS is parsed as JSON; keep the first generation's "generated_text" as a string.
generated_text = (
    F.parse_json(F.col("OUTPUTS"))[0]["generated_text"]
    .cast(T.StringType())
    .as_("GENERATED_TEXT")
)
results.select("INPUTS", generated_text).show(max_width=1000)

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"INPUTS"                                                    |"GENERATED_TEXT"                                                                                                                                                                                                |
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|If I could travel anywhere in the world, I would go to      |If I could travel anywhere in the world, I would go to ices.‍‍                                                            