# Imports

In [1]:
import warnings
warnings.filterwarnings("ignore")

# Load variables
import os
from dotenv import load_dotenv
load_dotenv()

# Snowpark Imports
from snowflake.snowpark.session import Session
import snowflake.snowpark.functions as F
import snowflake.snowpark.types as T

# Connect to Snowflake

In [2]:
snowflake_connection_cfg = {
    "ACCOUNT": os.getenv('SF_ACCOUNT'),
    "USER": os.getenv('SF_USER'),
    "ROLE": os.getenv('SF_ROLE'),
    "PASSWORD": os.getenv('SF_PASSWORD'),
    "DATABASE": os.getenv('SF_DATABASE'),
    "SCHEMA": os.getenv('SF_SCHEMA'),
    "WAREHOUSE": os.getenv('SF_WAREHOUSE')
}

# Creating Snowpark Session
session = Session.builder.configs(snowflake_connection_cfg).create()

print('Role:     ', session.get_current_role())
print('Warehouse:', session.get_current_warehouse())
print('Database: ', session.get_current_database())
print('Schema:   ', session.get_current_schema())

Role:      "ACCOUNTADMIN"
Warehouse: "COMPUTE_WH"
Database:  "MACHINE_LEARNING"
Schema:    "PUBLIC"


In [13]:
# Create some test data to work with
masked_texts = [
    "Snowflake is an awesome technology because of its superior [MASK].",
    "During the summer, many people enjoy going to the [MASK] for a refreshing swim.",
    "The chef added a pinch of [MASK] to enhance the flavor of the soup.",
    "To capture the perfect shot, the photographer adjusted the [MASK] on his camera.",
    "The novel's twist was so unexpected that it completely changed the [MASK] of the story."
]


df = session.create_dataframe(masked_texts, schema=['INPUTS'])
df.show(n=15, max_width=1000)

-------------------------------------------------------------------------------------------
|"INPUTS"                                                                                 |
-------------------------------------------------------------------------------------------
|Snowflake is an awesome technology because of its superior [MASK].                       |
|During the summer, many people enjoy going to the [MASK] for a refreshing swim.          |
|The chef added a pinch of [MASK] to enhance the flavor of the soup.                      |
|To capture the perfect shot, the photographer adjusted the [MASK] on his camera.         |
|The novel's twist was so unexpected that it completely changed the [MASK] of the story.  |
-------------------------------------------------------------------------------------------



# Register & Run Fill-Mask Model

In [4]:
# Get the model registry object
from snowflake.ml.registry import Registry
reg = Registry(
    session=session, 
    database_name=session.get_current_database(), 
    schema_name=session.get_current_schema()
    )

In [39]:
# Get the fill-mask model from Huggingface
# Make sure it fits into a Snowflake warehouse and does not require GPUs
# Otherwise the model must deployed in Snowpark Container Services
from transformers import pipeline
pipe = pipeline("fill-mask", model="google-bert/bert-base-uncased")

# use Snowpark Optimized Warehouse because model is big
session.use_warehouse('snowpark_opt_wh')

# Register the model to Snowflake
snow_model = reg.log_model(
    pipe, 
    model_name='bert_base_uncased', 
    sample_input_data=df.limit(10),
    conda_dependencies=['tokenizers','transformers','sentencepiece']
    )

# Outputs
filled_masks = snow_model.run(df).cache_result()
filled_masks.select('OUTPUTS').show(n=5, max_width=100)

--------------------------------------------------------------------------------------------------------
|"OUTPUTS"                                                                                             |
--------------------------------------------------------------------------------------------------------
|[{"score": 0.223093181848526, "token": 2836, "token_str": "performance", "sequence": "snowflake i...  |
|[{"score": 0.34172508120536804, "token": 5474, "token_str": "salt", "sequence": "the chef added a...  |
|[{"score": 0.3048401176929474, "token": 10014, "token_str": "lens", "sequence": "to capture the p...  |
|[{"score": 0.2125847488641739, "token": 4309, "token_str": "tone", "sequence": "the novel's twist...  |
|[{"score": 0.43861153721809387, "token": 3509, "token_str": "beach", "sequence": "during the summ...  |
--------------------------------------------------------------------------------------------------------



In [40]:
# Transform outputs into rows
filled_masks = filled_masks.join_table_function('FLATTEN', F.parse_json(F.col('OUTPUTS')))
filled_masks = filled_masks.with_column('SCORE', F.col('VALUE')['score'].cast(T.FloatType()))
filled_masks = filled_masks.with_column('SEQUENCE', F.col('VALUE')['sequence'].cast(T.StringType()))
filled_masks = filled_masks.with_column('TOKEN_STR', F.col('VALUE')['token_str'].cast(T.StringType()))
filled_masks.select(['INPUTS','SEQUENCE','TOKEN_STR','SCORE']).show(n=50, max_width=1000)

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"INPUTS"                                                                                 |"SEQUENCE"                                                                                  |"TOKEN_STR"  |"SCORE"               |
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|Snowflake is an awesome technology because of its superior [MASK].                       |snowflake is an awesome technology because of its superior performance.                     |performance  |0.223093181848526     |
|Snowflake is an awesome technology because of its superior [MASK].                       |snowflake is an aweso