# Imports

In [23]:
import warnings
# Install a catch-all "ignore" filter so Python warnings are not shown.
# It runs before the Snowpark imports below, so warnings emitted while importing them are hidden too.
# NOTE(review): this also hides deprecation warnings that may matter - consider narrowing the filter.
warnings.filterwarnings("ignore")

# Load variables from a .env file (if one is found) into the process environment,
# so that later os.getenv lookups can see them
import os
from dotenv import load_dotenv
load_dotenv()

# Snowpark Imports
# F = column functions, T = data types
from snowflake.snowpark.session import Session
from snowflake.snowpark import functions as F
from snowflake.snowpark import types as T

# Connect to Snowflake

In [2]:
# Connection settings are taken from environment variables so that no
# credentials are stored in the notebook itself.
snowflake_connection_cfg = {
    "ACCOUNT": os.getenv('SF_ACCOUNT'),
    "USER": os.getenv('SF_USER'),
    "ROLE": os.getenv('SF_ROLE'),
    "PASSWORD": os.getenv('SF_PASSWORD'),
    "DATABASE": os.getenv('SF_DATABASE'),
    "SCHEMA": os.getenv('SF_SCHEMA'),
    "WAREHOUSE": os.getenv('SF_WAREHOUSE')
}

# os.getenv returns None for unset variables. Fail fast and name the missing
# variables instead of letting the connection attempt fail with a less
# specific error. Only the settings this password-based config cannot work
# without are enforced; the remaining ones are passed through unchanged.
missing_env_vars = [
    f"SF_{setting}"
    for setting in ("ACCOUNT", "USER", "PASSWORD")
    if not snowflake_connection_cfg[setting]
]
if missing_env_vars:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(missing_env_vars)
    )

# Creating Snowpark Session
session = Session.builder.configs(snowflake_connection_cfg).create()

# Show the context the session ended up with
print('Role:     ', session.get_current_role())
print('Warehouse:', session.get_current_warehouse())
print('Database: ', session.get_current_database())
print('Schema:   ', session.get_current_schema())

Role:      "ACCOUNTADMIN"
Warehouse: "COMPUTE_WH"
Database:  "MACHINE_LEARNING"
Schema:    "PUBLIC"


In [7]:
# Sample sentences rich in named entities, used as input for the NER demo.
# Each entry is one string; adjacent literals are concatenated by Python.
ner_texts = [
    (
        "Elon Musk, the CEO of Tesla, announced on January 15, 2022, "
        "that the company will start manufacturing in Berlin by the end of the year."
    ),
    (
        "Microsoft, founded by Bill Gates and Paul Allen, has its headquarters in "
        "Redmond, Washington and was established on April 4, 1975."
    ),
    (
        "The Louvre Museum in Paris, France, houses the famous Mona Lisa painting "
        "and attracts millions of visitors from around the globe annually."
    ),
    (
        "The Treaty of Versailles was signed on June 28, 1919, by representatives "
        "from Germany and the Allied Powers, marking the end of World War I."
    ),
    (
        "Serena Williams, an American professional tennis player, won her 23rd Grand Slam "
        "singles title at the Australian Open in Melbourne on January 28, 2017."
    ),
]

# Turn the sentences into a single-column Snowpark DataFrame and preview it
df = session.create_dataframe(ner_texts, schema=['INPUTS'])
df.show(max_width=1000, n=15)

-----------------------------------------------------------------------------------------------------------------------------------------------------------
|"INPUTS"                                                                                                                                                 |
-----------------------------------------------------------------------------------------------------------------------------------------------------------
|Elon Musk, the CEO of Tesla, announced on January 15, 2022, that the company will start manufacturing in Berlin by the end of the year.                  |
|Microsoft, founded by Bill Gates and Paul Allen, has its headquarters in Redmond, Washington and was established on April 4, 1975.                       |
|The Louvre Museum in Paris, France, houses the famous Mona Lisa painting and attracts millions of visitors from around the globe annually.               |
|The Treaty of Versailles was signed on June 28, 1919, by repres

# Register & Run Token Classification Model

In [4]:
# Open a handle to the model registry that lives in the session's
# current database and schema
from snowflake.ml.registry import Registry

registry_location = {
    "database_name": session.get_current_database(),
    "schema_name": session.get_current_schema(),
}
reg = Registry(session=session, **registry_location)

In [29]:
# Get the token classification model from Huggingface
# Make sure it fits into a Snowflake warehouse and does not require GPUs
# Otherwise the model must be deployed in Snowpark Container Services
from transformers import pipeline
pipe = pipeline("token-classification", model="Babelscape/wikineural-multilingual-ner")

# Switch the session to the warehouse 'snowpark_opt_wh' because the model is big
# (SOWH presumably means Snowpark-optimized warehouse - confirm).
# Nothing in this cell switches back, so the session keeps using this warehouse afterwards.
session.use_warehouse('snowpark_opt_wh')

# Register the pipeline in the Snowflake model registry.
# The first 10 rows of df are passed as sample input, presumably so the model
# signature can be inferred - confirm against the snowflake-ml documentation.
# NOTE(review): no version_name is passed, so each run of this cell likely logs a new model version - confirm.
snow_model_custom = reg.log_model(
    pipe, 
    model_name='wikineural_multilingual_ner', 
    sample_input_data=df.limit(10),
    conda_dependencies=['tokenizers','transformers','sentencepiece']
    )

# Run inference on the test sentences and cache the result,
# so that the preview below and later cells work on the cached data
ner_tags = snow_model_custom.run(df).cache_result()
ner_tags.show(n=15, max_width=1000)

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [30]:
# Explode the model output (a JSON array with one element per predicted token)
# into one row per token and pull the token attributes out into typed columns.
#
# Gotcha: this cell rebinds ner_tags and drops the OUTPUTS column it reads from,
# so re-running it without first re-running the inference cell will fail.

# (result column name, key inside the token JSON object, Snowflake type)
token_fields = [
    ('ENTITY', 'entity', T.StringType()),
    ('SCORE', 'score', T.FloatType()),
    ('WORD', 'word', T.StringType()),
    ('INDEX', 'index', T.IntegerType()),
    ('START', 'start', T.IntegerType()),
    ('END', 'end', T.IntegerType()),
]

# FLATTEN exposes every array element in its VALUE column
flattened = ner_tags.join_table_function('FLATTEN', F.parse_json(F.col('OUTPUTS')))

# Keep the original sentence plus the typed token attributes, in this order
ner_tags = flattened.select(
    F.col('INPUTS'),
    *[
        F.col('VALUE')[json_key].cast(sf_type).alias(column_name)
        for column_name, json_key, sf_type in token_fields
    ],
)
ner_tags.show(max_width=1000)

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"INPUTS"                                                                                                                                    |"ENTITY"  |"SCORE"             |"WORD"     |"INDEX"  |"START"  |"END"  |
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|The Louvre Museum in Paris, France, houses the famous Mona Lisa painting and attracts millions of visitors from around the globe annually.  |B-LOC     |0.9992496371269226  |Louvre     |2        |4        |10     |
|The Louvre Museum in Paris, France, houses the famous Mona Lisa painting and attracts millions of visitors from around the globe annually. 