# Define table name, stored procedure name, connections, and GraphQL query

In [None]:
# Names used by the load steps: the JDBC write targets `output_table_name`,
# and `stored_proc_name` is the procedure executed after that write.
stored_proc_name = "load_companies"
output_table_name = "companies_raw"

In [None]:
import os

# Azure SQL Database connection settings.
# Host, user and database name come from environment variables; the password is
# read from the Databricks secret scope "azure_key_vault".
# NOTE(review): the database name is read from SQLDB_BB -- confirm the variable
# name is intentional.
jdbcHostname = os.getenv("SQLDB_HOST")
user = os.getenv("SQLDB_USER")
password = dbutils.secrets.get(scope="azure_key_vault", key="SQLDB-PW")
jdbcDatabase = os.getenv("SQLDB_BB")
jdbcPort = 1433

# SQL Server JDBC URL, e.g. jdbc:sqlserver://host:1433;database=name
jdbcUrl = "jdbc:sqlserver://%s:%s;database=%s" % (jdbcHostname, jdbcPort, jdbcDatabase)

# Properties handed to Spark's JDBC writer.
connectionProperties = dict(
    user=user,
    password=password,
    driver="com.microsoft.sqlserver.jdbc.SQLServerDriver",
)

In [None]:
import shopify

# Shopify B2B private-app connection details.
# The API version is pinned. The previous note said 2023-04 is the latest
# version the ShopifyAPI package supports -- confirm before bumping it.
shop_url = "b2b.myshopify.com"
api_version = "2023-04"
# Private-app password, read from the "azure_key_vault" secret scope.
private_app_password = dbutils.secrets.get(scope="azure_key_vault", key="SHOPIFYB2B-PW")

# Session object used by the GraphQL calls; it is activated before the query
# loop and cleared afterwards.
session = shopify.Session(shop_url, api_version, private_app_password)

In [None]:
from datetime import datetime, timedelta
import pytz
# Cut-off date for the created_at filter: the calendar date seven days before
# "now" as seen in US Pacific time, rendered as YYYY-MM-DD.
tz = 'US/Pacific'  # time zone that decides what "today" is
delta = timedelta(days=7)  # look-back window
lastWeek = (datetime.now(pytz.utc).astimezone(pytz.timezone(tz)) - delta).strftime("%Y-%m-%d")

In [None]:
# Define the parameters for the GraphQL query below.
# Page size for the `companies` connection. It is kept as a string and
# %-interpolated into the query text; the pagination loop also matches the
# literal text "first: <FIRST>" to insert an `after` cursor.
FIRST = '100' # NOTE(review): an older note said to use "50" in production -- confirm the intended page size
# Shopify search filter: companies created after lastWeek, e.g. '>2023-09-20'.
CREATED_AT = '>' + lastWeek

# One page of companies matching CREATED_AT. Each company carries its first 5
# locations; pageInfo supports cursor pagination over companies.
# NOTE(review): only locations(first: 5) is requested, and nested location
# pages are never fetched -- companies with more than 5 locations are truncated.
query = """
{
  companies(first: %s, query: "created_at:%s")
    {
    edges {
      cursor
      node {
        id
        name
        createdAt
        externalId
        locations(first: 5)
        {
            edges {
                cursor
                node {
                        id
                        name
                        createdAt
                        }
                }
            pageInfo {
            hasNextPage
            hasPreviousPage
            startCursor
            endCursor
            }
        }
      }
    }
    pageInfo {
      hasNextPage
      hasPreviousPage
      startCursor
      endCursor
    }
  }
}
""" % (FIRST, CREATED_AT)

# Execute GraphQL query

In [None]:
import pyspark.sql.functions as F
import pyspark.sql.types as T
import json

# Define the schema for the DataFrame.
# NOTE(review): LocationIDs / LocationNames are declared as strings, but each row
# supplies a Python list; Spark presumably converts it to text -- confirm the
# stored procedure expects that textual form.
schema = T.StructType([
    T.StructField("CompanyID", T.StringType(), False),
    T.StructField("CompanyName", T.StringType(), False),
    T.StructField("CreatedAt", T.StringType(), False),
    T.StructField("CompanyExternalID", T.StringType(), True),
    T.StructField("LocationIDs", T.StringType(), False),
    T.StructField("LocationNames", T.StringType(), False),
])


def _paged_query(base_query, page_size, cursor):
    """Return *base_query* positioned after *cursor* on the companies connection.

    The rewrite is always applied to the untouched base text. Re-applying it to
    an already-paged query would stack duplicate ``after`` arguments, which
    GraphQL rejects. Only the first occurrence of ``first: <page_size>`` is
    replaced; in the query text that is the ``companies(...)`` call, so
    ``locations(first: ...)`` is never touched.
    When *cursor* is None (first page), *base_query* is returned unchanged.
    """
    if cursor is None:
        return base_query
    return base_query.replace(
        'first: %s' % page_size,
        'first: %s, after: "%s"' % (page_size, cursor),
        1,
    )


def _companies_page(response):
    """Return the ``companies`` connection from a decoded GraphQL response.

    Raises RuntimeError (including any GraphQL ``errors`` payload) when the
    response carries no companies data, instead of failing later with an
    opaque KeyError/TypeError.
    """
    companies = (response.get("data") or {}).get("companies")
    if companies is None:
        raise RuntimeError(
            "Shopify GraphQL query returned no companies data: %s"
            % response.get("errors")
        )
    return companies


# One tuple per company, in schema column order.
data = []

# Activate the Shopify session. It is cleared in `finally`, so a failed request
# or parse does not leave it active.
shopify.ShopifyResource.activate_session(session)
try:
    client = shopify.GraphQL()
    cursor = None  # None -> first page
    # Execute the GraphQL query and follow the cursor until the last page.
    while True:
        result = client.execute(_paged_query(query, FIRST, cursor))
        response = json.loads(result)
        companies = _companies_page(response)

        for company in companies['edges']:
            node = company["node"]
            locations = node["locations"]["edges"]
            data.append((
                node["id"],
                node["name"],
                node["createdAt"],
                node["externalId"],
                [loc["node"]["id"] for loc in locations],
                [loc["node"]["name"] for loc in locations],
            ))

        page_info = companies['pageInfo']
        if not page_info['hasNextPage']:
            break
        cursor = page_info['endCursor']
finally:
    # Deactivate the Shopify session.
    shopify.ShopifyResource.clear_session()

# Create a DataFrame
df = spark.createDataFrame(data, schema)

# Save the response to Azure SQL DB

In [None]:
# Columns written to the output table, in order: the six fields parsed from the
# GraphQL response, then the load-timestamp column added before the write.
output_cols = list((
    "CompanyID", "CompanyName", "CreatedAt", "CompanyExternalID",
    "LocationIDs", "LocationNames",
    "RecordCreatedDate",
))

In [None]:
# Load timestamp expressed as America/Los_Angeles wall-clock time.
# NOTE(review): from_utc_timestamp treats current_timestamp() as UTC, which is
# only accurate when the Spark session time zone is UTC -- confirm.
current_timestamp_pt = F.from_utc_timestamp(F.current_timestamp(), "America/Los_Angeles")

# Stamp every row with the load time.
df = df.withColumn("RecordCreatedDate", current_timestamp_pt)

# Replace the output table's contents with this run's rows.
# NOTE(review): without the JDBC "truncate" option, Spark's overwrite mode
# drops and recreates the table -- confirm that is acceptable.
df.select(output_cols).write.jdbc(
    jdbcUrl,
    output_table_name,
    mode="overwrite",
    properties=connectionProperties,
)

# Execute the stored procedure that deletes the existing companies and loads the new data into the table

In [None]:
# Run the stored procedure through a plain JDBC connection obtained from the
# JVM's java.sql.DriverManager (via Spark's py4j gateway).
# The statement and the connection are closed in `finally` blocks, so a failing
# procedure does not leak the database connection.
driver_manager = spark._sc._gateway.jvm.java.sql.DriverManager
connection = driver_manager.getConnection(jdbcUrl, user, password)
try:
    query = "EXEC {0};".format(stored_proc_name)
    statement = connection.prepareCall(query)
    try:
        statement.execute()
    finally:
        statement.close()
finally:
    connection.close()