# Experiment with Dgraph, LLMs and OpenAI embeddings
This notebook is a companion to the blog
[Boost your data with LLMs and OpenAI embeddings](https://dgraph.io/blog/post/20230602-ai-classification/)

It lets you experiment yourself with Dgraph and OpenAI API.

**Prerequisites**
- Dgraph
  - Get a [Dgraph Cloud account](https://cloud.dgraph.io/)
  - Have your account user name and password available
  - Have a Dgraph cluster running in your Dgraph Cloud account
  - Obtain the GraphQL Endpoint of the Dgraph cluster from the [cloud dashboard](https://cloud.dgraph.io/_/dashboard)
  - Obtain an Admin API key for the Dgraph Cluster from the [settings](https://cloud.dgraph.io/_/settings?tab=api-keys) tab.
- OpenAI
  - Create your account on [OpenAI’s Platform website](https://platform.openai.com)
  - obtain an API Secret Key


  The first step is to import the packages needed.

-  ``pydgraph``, the official [python client library for Dgraph Query Language](https://dgraph.io/docs/dql/clients/python/)
-  ``GraphqlClient``, a GraphQL client to invoke the GraphQL API generated from your schema and the GraphQL admin API of Dgraph.

**Make sure to update the endpoints with the correct values for your Dgraph cluster!**


In [None]:
!pip install pydgraph python-graphql-client ipycytoscape
import pydgraph
import json
import base64
import getpass
from python_graphql_client import GraphqlClient

# Dgraph Cloud admin API (Cerebro), used for login and lambda deployment.
# Don't change!
dgraph_cerebro = "https://cerebro.cloud.dgraph.io/graphql"

# Copy your own Dgraph Cloud endpoints below.
# The GraphQL endpoint is found at https://cloud.dgraph.io/_/dashboard
dgraph_graphql_endpoint = "https://withered-thunder.us-east-1.aws.cloud.dgraph.io/graphql"
# The gRPC endpoint is found at https://cloud.dgraph.io/_/settings
dgraph_grpc = "withered-thunder.grpc.us-east-1.aws.cloud.dgraph.io:443"

# The GraphQL admin endpoint is the same URL with /admin instead of /graphql.
# Only the trailing path segment is swapped, and a pasted endpoint without the
# /graphql suffix is rejected here instead of silently producing an admin URL
# that is identical to the GraphQL one.
_graphql_suffix = "/graphql"
if not dgraph_graphql_endpoint.endswith(_graphql_suffix):
    raise ValueError(
        "dgraph_graphql_endpoint must end with '/graphql': " + dgraph_graphql_endpoint
    )
dgraph_graphql_admin = dgraph_graphql_endpoint[: -len(_graphql_suffix)] + "/admin"



Enter your credentials and test the different clients


In [None]:

# Cloud credentials
# - The Dgraph Cloud login is needed to upload the Lambda code.
# - An Admin API key generated at https://cloud.dgraph.io/_/settings?tab=api-keys
#   is needed for DQL alter and query.

dgraph_cloud_user = input("Dgraph Cloud login?")
dgraph_cloud_passw = getpass.getpass("Dgraph Cloud password?")
APIAdminKey = getpass.getpass("API Admin key?")

OpenAIKey = getpass.getpass("Your OpenAI API key?")

# DQL client
client_stub = pydgraph.DgraphClientStub.from_cloud(dgraph_grpc, APIAdminKey)
client = pydgraph.DgraphClient(client_stub)

# GraphQL client, GraphQL admin client and Dgraph Cloud client
gql_client = GraphqlClient(endpoint=dgraph_graphql_endpoint)
headers = {"Dg-Auth": APIAdminKey}
gql_admin_client = GraphqlClient(endpoint=dgraph_graphql_admin, headers=headers)
gql_cloud_client = GraphqlClient(endpoint=dgraph_cerebro)

# Testing all the clients

#
# 1 - Log in to Dgraph Cloud
# The returned token is needed later to deploy the lambda code.

login = """
query  login($email: String!, $passw: String!){
  login(email: $email, password: $passw) {
    token
  }
}
"""
login_var = {"email": dgraph_cloud_user, "passw": dgraph_cloud_passw}
login_info = gql_cloud_client.execute(query=login, variables=login_var)
# Report a rejected login with the server's message instead of failing on the
# token lookup below.
if 'errors' in login_info:
    raise Exception(login_info['errors'][0]['message'])

token = login_info['data']['login']['token']
cerebro_headers = {"Content-Type": "application/json", "Authorization": "Bearer " + token}

print("Dgraph Cloud login succeeded.")
#
# 2 - Use GraphQL Admin to verify cluster health
#
data = gql_admin_client.execute(query="{health {instance version status}}")
if 'errors' in data:
    raise Exception(data['errors'][0]['message'])


print("Check cluster health:", json.dumps(data, indent=2))

#
# 3 - Use pydgraph client to get DQL schema
#
txn = client.txn()
try:
    query = "schema{}"
    res = txn.query(query)
    dqlschema = json.loads(res.json)
finally:
    # Always release the transaction, even when the query fails.
    txn.discard()
print("get DQL schema - succeeded")


In [None]:
# Drop all data, including the schema, from the Dgraph instance. This is useful
# for small examples such as this one since it puts Dgraph into a clean state.
confirm = input("drop schema and all data (y/n)?")
# Normalise the answer so that "Y" or " y" are not silently treated as "no".
if confirm.strip().lower() == "y":
    op = pydgraph.Operation(drop_all=True)
    client.alter(op)
    print("schema and data deleted")
else:
    # Make the skipped branch visible so the state of the cluster is never a guess.
    print("drop skipped - existing schema and data kept")


In [None]:
# Deploy the GraphQL schema through the GraphQL admin endpoint.
# Both types are declared with @lambdaOnMutate(add: true, ...): the lambda
# webhooks registered later in this notebook handle the "add" events.

graphql_schema = """
type Project @lambdaOnMutate(add: true, update: false, delete: false) {
  id: ID!
  title: String!  @search(by: [term])
  grade: String @search(by: [hash])
  category: Category
}
type Category @lambdaOnMutate(add: true, update: false, delete: false) {
  id: ID!
  name: String!
}
"""
mutation = """
mutation($sch: String!) {
  updateGQLSchema(input: { set: { schema: $sch}})
  {
    gqlSchema {
      schema
      generatedSchema
    }
  }
}
"""
variables = {"sch": graphql_schema}
schemadata = gql_admin_client.execute(query=mutation, variables=variables)
# Report a rejected schema with the server's message instead of failing on the
# nested lookup below.
if 'errors' in schemadata:
    raise Exception(schemadata['errors'][0]['message'])
print("GraphQL Schema after Update")
print(schemadata['data']['updateGQLSchema']['gqlSchema']['schema'])





In [None]:
# Add Lambda Function
# Lambda is deployed through the Dgraph Cloud (Cerebro) endpoint.
# We need the deployment ID that corresponds to the GraphQL endpoint.
# Note the double curly brackets: they are literal braces for str.format!
query = """
query {{
    searchDeployments(inputType: endpoint, searchText: "{0}") {{
        subdomain
        name
        uid
    }}
}}
""".format(dgraph_graphql_endpoint)

deployment_info = gql_cloud_client.execute(query=query, headers=cerebro_headers)
print(json.dumps(deployment_info, indent=2))
if 'errors' in deployment_info:
    raise Exception(deployment_info['errors'][0]['message'])

# An empty result would otherwise surface as a bare IndexError on [0].
deployments = deployment_info['data']['searchDeployments']
if not deployments:
    raise Exception("No Dgraph Cloud deployment found for endpoint " + dgraph_graphql_endpoint)
deploymentID = deployments[0]['uid']

print('DeploymentID: '+deploymentID)


In [None]:
# JavaScript lambda deployed to Dgraph Cloud.
# It registers two webhooks:
#   - Category.add: stores the OpenAI embedding of the category name in <embedding>
#   - Project.add:  stores the embedding of the project title, a <similarity> edge
#                   (with a `cosine` facet) to every category and, when the project
#                   was added without a category, a link to the closest category.
# The OpenAI key is injected by replacing the __OPENAI_API_KEY__ placeholder, so
# the deployed script contains the key in clear text.
script = """
function dotProduct(v,w) {
   // OpenAI embedding vectors are normalized,
   // so the dot product equals the cosine similarity
   return v.reduce((l,r,i)=>l+r*w[i],0)
}
async function mutateRDF(dql,rdfs) {
  // nothing to write when no triple was produced
  if (rdfs !== "") {
        return dql.mutate(`{
                set {
                    ${rdfs}
                }
            }`)
    }
}
async function embedding(text) {
  let url = `https://api.openai.com/v1/embeddings`;
  let response = await fetch(url,{
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      "Authorization": "Bearer __OPENAI_API_KEY__"
    },
    // JSON.stringify escapes quotes, backslashes and newlines found in the text
    body: JSON.stringify({ input: text, model: "text-embedding-ada-002" })
  })
  let data = await response.json();
  if (!response.ok || data.data === undefined) {
    // surface the OpenAI error instead of failing on data.data[0] below
    throw new Error(`OpenAI embedding request failed: ${JSON.stringify(data)}`)
  }
  console.log(`embedding = ${data.data[0].embedding}`)
  return data.data[0].embedding;
}

async function addProjectWebhook({event, dql, graphql, authHeader}) {

  const categoriesData = await dql.query(`{
        categories(func:type(Category))   {
          uid name:Category.name
          embedding
        }
      }`)
  for (let c of categoriesData.data.categories ) {
       c.vector = JSON.parse(c.embedding);
  }
  var rdfs = "";
  for (let i = 0; i < event.add.rootUIDs.length; ++i ) {
    console.log(`adding embedding for Project ${event.add.rootUIDs[i]} ${event.add.input[i]['title']}`)
    var uid = event.add.rootUIDs[i];
    const v1 = await embedding(event.add.input[i].title);
    const serialized = JSON.stringify(v1);
    if  (event.add.input[i]['category'] == undefined) {

      let category="";
      let max = 0.0;
      let similarityMutation = "";
      for (let c of categoriesData.data.categories ) {
        const similarity = dotProduct(v1,c.vector);
        similarityMutation += `<${uid}>  <similarity> <${c.uid}> (cosine=${similarity}) .\n`;
        if (similarity > max) {
          category = c.uid;
          max = similarity;
        }
      }
      rdfs += `${similarityMutation}
              <${uid}>  <embedding> "${serialized}" .
                `;
      // only link a category when one was found: an empty <> subject or
      // object would make the whole mutation fail
      if (category !== "") {
        console.log(`set closest category`)
        rdfs += `<${uid}> <Project.category> <${category}> .
                `;
      }
    } else {
      console.log(`Project ${event.add.rootUIDs[i]} added with category ${event.add.input[i]['category'].name}`)
      rdfs += `<${uid}>  <embedding> "${serialized}" .
                `;
    }
  }
  await mutateRDF(dql,rdfs);

}
async function addCategoryWebhook({event, dql, graphql, authHeader}) {
    var rdfs = "";
    // webhook may receive an array of UIDs
    // apply the same logic for each node
    for (let i = 0; i < event.add.rootUIDs.length; ++i ) {
        console.log(`adding embedding for ${event.add.rootUIDs[i]} ${event.add.input[i]['name']}`)
        const uid = event.add.rootUIDs[i];
        // retrieve the embedding for the category name
        const data = await embedding(event.add.input[i]['name']);
        const serialized = JSON.stringify(data);
        // create a triple to associate the embedding to the category using the predicate <embedding>
        rdfs += `<${uid}>  <embedding> "${serialized}" .
        `;
    }
    // use a single mutation to save all the embeddings
    await mutateRDF(dql,rdfs);
}


self.addWebHookResolvers({
   "Project.add": addProjectWebhook,
   "Category.add": addCategoryWebhook
})
""".replace("__OPENAI_API_KEY__", OpenAIKey.strip())

# The updateLambda mutation expects the script base64-encoded.
encoded = base64.b64encode(script.encode('utf-8'))


mutation = """
mutation ($deploymentID: ID!, $tenantID: Int!,$lambdaScript: String! ){
  updateLambda(input: { deploymentID: $deploymentID, tenantID: $tenantID, lambdaScript: $lambdaScript})
}
"""
variables = {
    "deploymentID":deploymentID,
    "tenantID":0,
    "lambdaScript": str(encoded, "utf-8")
}
deployment_status = gql_cloud_client.execute(query=mutation, variables=variables,headers=cerebro_headers)


print(deployment_status)
# Stop here when the deployment was rejected: every later cell relies on the lambda.
if 'errors' in deployment_status:
    raise Exception(deployment_status['errors'][0]['message'])


In [None]:
# Add predicates to the Dgraph (DQL) schema.
# The lambda logic writes these two predicates: <embedding> holds the serialized
# vector and <similarity> links a project to categories.
# If your cluster is in strict mode, predicates must be declared before they are used.

dqlschema = """
  embedding: string .
  similarity: [uid] .
"""
op = pydgraph.Operation(schema=dqlschema)
client.alter(op)


In [None]:
# Create the categories.
# Use the GraphQL API endpoint and one addCategory mutation per category.

addCategory = """
mutation NewCategoryMutation($name: String!) {
    addCategory(input: {name: $name}) {
        numUids
    }
}"""
category_names = [
    "Math & Science",
    "Health & Sports",
    "History & Civics",
    "Literacy & Language",
    "Music & The Arts",
    "Special Needs",
]
for category_name in category_names:
    result = gql_client.execute(query=addCategory, variables={"name": category_name})
    # Do not report success when the API rejected a mutation.
    if 'errors' in result:
        raise Exception(result['errors'][0]['message'])

print("Categories created")



At this stage you may want to see the data created from your Dgraph Cloud dashboard using the [Data Explorer](https://cloud.dgraph.io/_/data).

You should see the Category created.

Note that the predicate ``embedding`` is not exposed in the GraphQL API.
We have decided to hide it from the API and to only handle it at a lower level using Dgraph Query Language.

Let's execute a DQL query to look at this data:

In [None]:
# Check that every Category node now carries its embedding.
# The DQL text is prepared first so the transaction only wraps the round trip.
query = """
    {
       category(func: type(Category)){
        uid
        Category.name
        embedding
      }
    }
    """
txn = client.txn()
try:
    res = txn.query(query)
    data = json.loads(res.json)
    categories = data['category']
    print(json.dumps(categories, indent=2))
finally:
    # Release the transaction whether or not the query succeeded.
    txn.discard()

We are now adding some projects to our data.

Each time a project is added, Dgraph will trigger our logic to compute an embedding, use it to find the closest category and create the relationship between the project and the category.

In [None]:
# Use GraphQL mutations to create Projects.
# The projects are added without a category, so the Project.add webhook picks one.

addProject = """
mutation addProject($title : String!) {
  addProject(input: {title: $title}) {
    project {
      id
    }
  }
}
"""
project_titles = [
    "Multi-Use Chairs for Music Classes",
    "Photography and Memories....Yearbook in the Works",
    "Fund a Much Needed Acid Cabinet & Save Us from Corrosion!",
]
for project_title in project_titles:
    result = gql_client.execute(query=addProject, variables={"title": project_title})
    # Do not report success when the API rejected a mutation.
    if 'errors' in result:
        raise Exception(result['errors'][0]['message'])

print("Projects created")



In [None]:
# Retrieve the projects and verify they have a category associated!
queryProjects = """
query queryProjects {
    queryProject(first:100) {
        id title
        category {
            name
        }
    }
}
"""
data = gql_client.execute(query=queryProjects)
# Report an API error with the server's message instead of failing on the
# nested lookup below.
if 'errors' in data:
    raise Exception(data['errors'][0]['message'])

print(json.dumps(data['data']['queryProject'], indent=2))


In the same way that we have handled the ``embedding`` predicate, we have also added some information to each project to store the computed similarity with all categories.

In [None]:
# Explore Projects and their similarity information.
# The result is kept in `data`, which the graph cell below passes to toGraph().
# NOTE(review): the query asks for `embeddings`, but the predicate declared in
# the DQL schema and written by the lambda is `embedding` (singular), so this
# field presumably returns nothing. Confirm whether the vector was meant to be
# part of this output before renaming it - it would add the full vector of
# every project to the printed result.
txn = client.txn()
try:
  # Run query.
  # `label:` aliases the title/name so both node types expose the same key;
  # `@facets` also returns the facets stored on each similarity edge
  # (the lambda stores `cosine`).
    query = """
    {
      projects(func:type(Project)) {
        dgraph.type uid
        label:Project.title
        embeddings
        similarity @facets {
          dgraph.type uid label:Category.name
        }
      }
    }
    """
    res = txn.query(query)
    data = json.loads(res.json)['projects']
    print(json.dumps(data, indent=2))

finally:
  # Release the transaction whether or not the query succeeded.
  txn.discard()


Did you know that you can use the Cytoscape library to display Dgraph results as a graph in Python?

In [None]:
import ipycytoscape

# google.colab only exists inside Google Colab. Guard the import so this cell
# also runs in other Jupyter environments, where the custom widget manager
# does not need to be enabled.
try:
    from google.colab import output
except ImportError:
    output = None
else:
    output.enable_custom_widget_manager()

def addNodeToGraph(graph_data, n):
    """Add the DQL node *n*, and every node nested under it, to *graph_data*.

    Args:
        graph_data: dict with a ``"nodes"`` list and an ``"edges"`` list; both
            are extended in place.
        n: a node from a DQL JSON result. It must provide ``uid``, ``label``
            and a non-empty ``dgraph.type`` list. Any predicate whose value is
            a dict, or a list containing dicts, is followed recursively and
            produces one edge per nested node.

    Returns:
        The ``uid`` of *n*.

    Raises:
        KeyError: if *n* or a nested node lacks ``uid``, ``label`` or
            ``dgraph.type``.
    """
    node_id = n['uid']
    label = n['label']
    node_type = n['dgraph.type'][0]
    graph_data['nodes'].append({
        "data": {"id": node_id, "label": label, "type": node_type},
        "classes": node_type,
    })

    for predicate, value in n.items():
        if isinstance(value, list):
            # Facets of a list edge come back on the child as "<predicate>|<facet>".
            facet_prefix = predicate + "|"
            for child in value:
                if not isinstance(child, dict):
                    # Scalar list entries (e.g. the dgraph.type names) are not nodes.
                    continue
                target = addNodeToGraph(graph_data, child)
                edge_label = predicate
                for key in child:
                    if key.startswith(facet_prefix):
                        # Show the facet value on the edge instead of the predicate name.
                        edge_label = child[key]
                graph_data['edges'].append({"data": {
                    "uid": "{0}-{1}-{2}".format(predicate, node_id, target),
                    "source": node_id,
                    "target": target,
                    "label": edge_label,
                }})
        elif isinstance(value, dict):
            target = addNodeToGraph(graph_data, value)
            graph_data['edges'].append({"data": {
                "uid": "{0}-{1}-{2}".format(predicate, node_id, target),
                "source": node_id,
                "target": target,
                "label": predicate,
            }})

    return node_id

def toGraph(data):
    """Convert a list of DQL result nodes into a nodes/edges dictionary.

    Every entry of *data* is handed to ``addNodeToGraph``, which fills the
    ``"nodes"`` and ``"edges"`` lists of a fresh dictionary; that dictionary
    is returned.
    """
    graph = dict(nodes=[], edges=[])
    for root in data:
        addNodeToGraph(graph, root)
    return graph

# Cytoscape stylesheet: one base rule for all nodes, one override per Dgraph
# type (matched on the node's `type` data field) and one rule for edges.
cyto_styles = [
    {'selector': 'node', 'style': {
        'font-family': 'helvetica',
        'font-size': '6px',
        'label': 'data(label)',
        'textValign':'center',
        'textHalign':'center',
        'textMaxWidth': '60px',
        'textWrap': 'wrap'}},
    {'selector': 'node[type = "Project"]', 'style': {
        'width': '100px',
        'height': '80px',
        'background-color': 'rgb(222, 164, 192)'}},
    {'selector': 'node[type = "Category"]', 'style': {
        'textMaxWidth': '40px',
        'background-color': 'rgb(236, 202, 170)'}},
    {'selector': 'edge', 'style': {
        'target-arrow-shape': 'triangle',
        # haystack edges cannot draw target arrows; bezier edges can
        'curve-style': 'bezier',
        'width': '1px',
        'font-size': '6px',
        'label': 'data(label)',
        # edges are coloured with line-color (background-color is a node property)
        'line-color': 'blue',
        'target-arrow-color': 'blue'}}
]

# Turn the DQL result held in `data` into nodes and edges, load them into a
# Cytoscape widget, then apply the stylesheet and the layout.
graph_data = toGraph(data)
cytoscapeobj = ipycytoscape.CytoscapeWidget()
cytoscapeobj.graph.add_graph_from_json(graph_data)
cytoscapeobj.set_style(cyto_styles)
# 'cola' layout; the remaining keyword arguments are passed as layout options.
cytoscapeobj.set_layout(name='cola',edgeLength = 200, animate = True, nodeSpacing=10, edgeLengthVal=45, maxSimulationTime= 1500)

# Display: the widget is the last expression of the cell, so the notebook renders it.
cytoscapeobj