# Image Recognition in Snowflake using Snowpark Python and PyTorch

_For comments and feedback, please reach out to [Dash](https://twitter.com/iamontheinet)_.


In [13]:
# Snowpark
from snowflake.snowpark.session import Session
from snowflake.snowpark.functions import udf
from snowflake.snowpark.version import VERSION

# Misc
import pandas as pd
import json
import cachetools
import logging 
# Only let ERROR-and-above records from the Snowpark session logger through,
# so lower-severity session messages do not clutter the notebook output.
SNOWPARK_SESSION_LOGGER_NAME = "snowflake.snowpark.session"
logger = logging.getLogger(SNOWPARK_SESSION_LOGGER_NAME)
logger.setLevel(level=logging.ERROR)

### Establish Secure Connection to Snowflake

Using the Snowpark API, it’s quick and easy to establish a secure connection between Snowflake and this notebook.

 *Connection options: Username/Password, MFA, OAuth, Okta, SSO*

In [14]:
# Create Snowflake Session object.
# Read the connection settings inside a context manager so the file handle is closed.
with open('connection.json') as connection_file:
    connection_parameters = json.load(connection_file)
session = Session.builder.configs(connection_parameters).create()
session.sql_simplifier_enabled = True

snowflake_environment = session.sql('select current_user(), current_role(), current_database(), current_schema(), current_version(), current_warehouse()').collect()
snowpark_version = VERSION

# Current Environment Details.
# Column positions follow the SELECT above:
# 0=user, 1=role, 2=database, 3=schema, 4=version, 5=warehouse.
current_env = snowflake_environment[0]
print('User                        : {}'.format(current_env[0]))
print('Role                        : {}'.format(current_env[1]))
print('Database                    : {}'.format(current_env[2]))
print('Schema                      : {}'.format(current_env[3]))
print('Warehouse                   : {}'.format(current_env[5]))
print('Snowflake version           : {}'.format(current_env[4]))
print('Snowpark for Python version : {}.{}.{}'.format(snowpark_version[0], snowpark_version[1], snowpark_version[2]))

OperationalError: 250001: Could not connect to Snowflake backend after 0 attempt(s).Aborting

### Upload MobileNet V3 files to a Snowflake internal stage

In [None]:
# Upload the label mapping, model definition and pretrained weights to the
# same stage the UDF imports them from (session.add_import uses @evan_files).
# auto_compress=False keeps the files usable as-is from the import directory.
session.file.put('imagenet1000_clsidx_to_labels.txt','@evan_files',overwrite=True,auto_compress=False)
session.file.put('mobilenetv3.py','@evan_files',overwrite=True,auto_compress=False)
session.file.put('mobilenetv3-large-1cd25616.pth','@evan_files',overwrite=True,auto_compress=False)

### Snowpark Python User-Defined Function (UDF) for image recognition

Now, to deploy the pre-trained model for inference, let's **create and register a Snowpark Python UDF and add the model files as dependencies**. Once it is registered, getting new predictions is as simple as calling the function and passing in data.

*NOTE: Scalar UDFs operate on one row (one set of data points) at a time, which makes them a good fit for real-time online inference. This UDF is called from the [Snowpark_PyTorch_Streamlit_Upload_Image_Rec](Snowpark_PyTorch_Streamlit_Upload_Image_Rec.py) and [Snowpark_PyTorch_Streamlit_OpenAI_Image_Rec](Snowpark_PyTorch_Streamlit_OpenAI_Image_Rec.py) Streamlit apps.*

TIP: For more information on Snowpark Python User-Defined Functions, refer to the [docs](https://docs.snowflake.com/en/developer-guide/snowpark/python/creating-udfs.html).

In [None]:
# Reset any packages and imports previously registered on this session before
# declaring the UDF's dependencies.
session.clear_packages()
session.clear_imports()

# Add the ImageNet label mapping, the model definition and the pretrained
# weights as UDF imports; load_model() reads them from the UDF import directory.
session.add_import('@evan_files/imagenet1000_clsidx_to_labels.txt')
session.add_import('@evan_files/mobilenetv3.py')
session.add_import('@evan_files/mobilenetv3-large-1cd25616.pth')

# Add Python packages from the Snowflake Anaconda channel
session.add_packages('snowflake-snowpark-python','torchvision','joblib','cachetools')

@cachetools.cached(cache={})
def load_class_mapping(filename):
  """Return the full text content of *filename*.

  Memoized by ``cachetools.cached`` keyed on *filename*, so each distinct
  file is read from disk at most once per Python process.
  """
  with open(filename, "r") as handle:
    contents = handle.read()
  return contents

@cachetools.cached(cache={})
def load_model():
  """Build the MobileNetV3-Large classifier with its preprocessing and labels.

  Reads the pretrained weights and the ImageNet class-index mapping from the
  Snowflake UDF import directory (``sys._xoptions["snowflake_import_directory"]``).
  ``cachetools.cached`` memoizes the result, so this work runs once per Python
  process and later calls return the same objects.

  Returns:
    tuple: ``(model, transform, cls_idx)`` where
      * ``model`` is ``mobilenetv3_large`` with the pretrained state dict
        loaded, switched to eval mode, with ``requires_grad`` disabled;
      * ``transform`` is the torchvision pipeline: bicubic resize to 256,
        center crop to 224, ``ToTensor``, ImageNet mean/std normalization;
      * ``cls_idx`` is the label file parsed with ``ast.literal_eval``
        (presumably a ``{class_index: label}`` dict -- confirm against the file).
  """
  import ast
  import os
  import sys

  import torch
  from torchvision import transforms
  from mobilenetv3 import mobilenetv3_large

  IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
  import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]

  # os.path.join is correct whether or not import_dir ends with a separator.
  model_file = os.path.join(import_dir, 'mobilenetv3-large-1cd25616.pth')
  imgnet_class_mapping_file = os.path.join(import_dir, 'imagenet1000_clsidx_to_labels.txt')

  IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

  transform = transforms.Compose([
      transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
  ])

  # Load the Imagenet {class: label} mapping; the file holds a Python literal.
  cls_idx = ast.literal_eval(load_class_mapping(imgnet_class_mapping_file))

  # Load pretrained image recognition model.
  # map_location="cpu" keeps loading working if the checkpoint was saved from a
  # CUDA device while the UDF host has no GPU (assumed for standard warehouses).
  model = mobilenetv3_large()
  model.load_state_dict(torch.load(model_file, map_location="cpu"))

  # Configure pretrained model for inference: eval() disables training-only
  # layer behaviour, requires_grad_(False) stops autograd tracking the weights.
  model.eval().requires_grad_(False)

  return model, transform, cls_idx

def load_image(image_bytes_in_str):
  """Decode a hex-encoded image into an in-memory binary stream.

  Decoding in memory avoids a temporary file under /tmp, which would need
  explicit cleanup, an explicit close, and a unique name per concurrent call.

  Args:
    image_bytes_in_str: The image file's bytes encoded as a hexadecimal string.

  Returns:
    io.BytesIO: A seekable binary stream positioned at the start of the
    decoded bytes, usable directly with ``PIL.Image.open``.

  Raises:
    ValueError: If *image_bytes_in_str* is not valid hexadecimal.
  """
  import io

  image_bytes = bytes.fromhex(image_bytes_in_str)
  return io.BytesIO(image_bytes)

@udf(name='image_recognition_using_bytes',session=session,replace=True,is_permanent=True,stage_location='@dash_files')
def image_recognition_using_bytes(image_bytes_in_str: str) -> str:
  """Classify a hex-encoded image and return its top-1 ImageNet label.

  Registered as a permanent Snowpark Python UDF named
  ``image_recognition_using_bytes`` whose code is staged on ``@dash_files``.

  Args:
    image_bytes_in_str: The image file's bytes encoded as a hexadecimal string.

  Returns:
    The label of the highest-probability class, as a string.
  """
  import torch
  from PIL import Image

  model, transform, cls_idx = load_model()

  # Close the decoded stream and the PIL image once the pixels are copied out.
  # convert("RGB") turns RGBA, grayscale and palette images into the three
  # channels that the 3-value Normalize step and the model expect.
  with load_image(image_bytes_in_str) as image_file, Image.open(image_file) as raw_img:
    img = raw_img.convert("RGB")

  batch = transform(img).unsqueeze(0)  # add a leading batch dimension of size 1

  # Get model output and human text prediction
  with torch.no_grad():
    logits = model(batch)

  probs = torch.nn.functional.softmax(logits, dim=1)
  _, idx = torch.topk(probs, 1)
  # idx holds exactly one element, so .item() yields the plain int class index.
  predicted_label = cls_idx[idx.item()]

  return f"{predicted_label}"

*NOTE: This UDF is called from [Snowpark_PyTorch_Streamlit_Upload_Image_Rec](Snowpark_PyTorch_Streamlit_Upload_Image_Rec.py) and [Snowpark_PyTorch_Streamlit_OpenAI_Image_Rec](Snowpark_PyTorch_Streamlit_OpenAI_Image_Rec.py) Streamlit apps.*