# Notebook for querying Gremlin


In [None]:
#| default_exp gremlin_connection

In [None]:
#| export


from azure.keyvault.secrets import SecretClient
import pandas as pd
import os
from contextlib import contextmanager
#from azure.cosmos import CosmosClient
import time
from gremlin_python.driver import serializer
from tqdm import tqdm
from datetime import  timedelta

from gremlin_python.driver.client import Client
from azure.cosmos import CosmosClient



from azure.identity import DefaultAzureCredential
from azure.keyvault.secrets import SecretClient

#from azure.cosmos import CosmosClient
from azure.identity import DefaultAzureCredential

# from nebari_dia_adx.credentials import DefaultAzureCredential, secret_client


In [None]:
#| export
# Record the Key Vault URL in the process environment; the client below reads
# it back from there.
os.environ['key_vault_url'] = "https://nebaridia.vault.azure.net/"

# Module-level credential and Key Vault client shared by the rest of the file.
credential = DefaultAzureCredential()
secret_client = SecretClient(
    vault_url=os.environ['key_vault_url'],
    credential=credential,
)

In [None]:
#| export

class GremlinClientManager():
    """Hand out authenticated Gremlin and Cosmos clients for one graph container.

    The manager reads its endpoints from the module-level ``secret_client``,
    keeps an access token obtained through ``DefaultAzureCredential`` and
    refreshes that token shortly before it expires.
    """

    def __init__(self, database_name, container_name):
        """Read the endpoints from Key Vault and fetch the first token.

        Args:
            database_name: Database name used in the Gremlin username and
                for the Cosmos database client.
            container_name: Container (graph) name used in the Gremlin
                username and for the Cosmos container client.
        """
        self.database_name = database_name
        self.container_name = container_name
        self.credential = DefaultAzureCredential()
        self.token = None
        self.token_expiration = 0
        # Never assigned anywhere else in this class; kept so existing
        # attribute access keeps working.
        self.gremlin_client = None
        self.gremlin_endpoint = secret_client.get_secret('gremlin-endpoint').value
        self.endpoint = secret_client.get_secret('cosmos-endpoint').value
        self.get_token()  # Get the initial token

    def get_token(self):
        """Fetch a fresh access token and remember when it expires.

        The scope passed to the credential is read from the
        ``token-cosmos-endpoint`` secret on every call.
        """
        scope = secret_client.get_secret('token-cosmos-endpoint').value
        token_response = self.credential.get_token(scope)
        self.token = token_response.token
        # ``expires_on`` is already an absolute Unix timestamp. Adding
        # time.time() to it pushed the expiry decades into the future, so the
        # token was never refreshed.
        self.token_expiration = token_response.expires_on

    def ensure_token_validity(self):
        """Refresh the token when it expires within the next five minutes."""
        if time.time() >= self.token_expiration - 300:
            print('Refreshing token...')
            self.get_token()  # Refresh token if expired or about to expire

    @contextmanager
    def client(self):
        """Yield a Gremlin client that is closed when the ``with`` block exits.

        Use as ``with manager.client() as c: ...``. A new client is built on
        every entry, using the current (refreshed if necessary) token as the
        password.
        """
        # Refresh before building the client: the token is copied into the
        # client at construction, so a later refresh would not reach it.
        self.ensure_token_validity()

        gremlin_client = Client(
            url=self.gremlin_endpoint,
            traversal_source="g",
            username=f"/dbs/{self.database_name}/colls/{self.container_name}",
            password=self.token,
            message_serializer=serializer.GraphSONSerializersV2d0()
        )

        try:
            yield gremlin_client
        finally:
            # Always release the connection, even if the caller's block raised.
            gremlin_client.close()

    def get_client_cosmos(self):
        """Return a new ``CosmosClient`` for the Cosmos endpoint."""
        return CosmosClient(self.endpoint, self.credential)

    def get_database(self):
        """Return the database client for ``self.database_name``."""
        return self.get_client_cosmos().get_database_client(self.database_name)

    def get_container(self):
        """Return the container client for ``self.container_name``."""
        return self.get_database().get_container_client(self.container_name)


# Create a single module-level manager so the token state is shared by every
# query. Database and container names come from Key Vault.
# Obtain a client with ``with gremlin_manager.client() as client:``; a bare
# ``gremlin_manager.client()`` call only builds an unused context-manager
# object, so the one that used to follow here was removed.
gremlin_manager = GremlinClientManager(
    database_name=secret_client.get_secret('gremlin-prod-databasename').value,
    container_name=secret_client.get_secret('gremlin-prod-containername').value,
)


In [None]:
#| export

# nest_asyncio patches asyncio so an event loop can be re-entered while it is
# already running.
# NOTE(review): presumably needed because the Gremlin driver runs its own
# event loop inside the notebook's already-running loop - confirm.
import nest_asyncio
nest_asyncio.apply()

def submit_query(query):  # TODO ensure no code injection here!!!
    """Run a Gremlin query string and return all of its results.

    A client is taken from the module-level ``gremlin_manager`` for the
    duration of the call. The elapsed time is printed on success.

    Args:
        query: Gremlin query text, sent to the server as-is.

    Returns:
        Whatever ``client.submit(query).all().result()`` produces, or ``None``
        when any exception occurred (the error is printed, not raised).
    """
    started = time.time()
    try:
        with gremlin_manager.client() as gremlin:
            result_set = gremlin.submit(query)
            rows = result_set.all().result()
            elapsed = time.time() - started
            print(f'Getting result took {elapsed} s')
            return rows
    except Exception as err:
        # Best-effort: report the failure and let the caller see None.
        print(f'An error occured in submit_query: {err}')
    return None

In [None]:
#| export

def convert_date_to_proper_format(date):
    """Format a date as ``YYYY-MM-DDTHH:MM:SSZ``.

    Args:
        date: Anything ``pd.to_datetime`` accepts (string, datetime, ...).

    Returns:
        The formatted timestamp string. Timezone-aware inputs are converted
        to UTC first; naive inputs are taken to be UTC already.
    """
    timestamp = pd.to_datetime(date)
    # The trailing 'Z' claims UTC, so an aware value in another zone must be
    # shifted to UTC first rather than merely relabelled.
    if getattr(timestamp, 'tzinfo', None) is not None:
        timestamp = timestamp.tz_convert('UTC')
    return timestamp.strftime('%Y-%m-%dT%H:%M:%SZ')

In [None]:
#| export

def add_day(date):
    """Return ``date`` plus one day, formatted as ``YYYY-MM-DDTHH:MM:SSZ``.

    Args:
        date: Anything ``pd.to_datetime`` accepts (string, datetime, ...).

    Returns:
        The formatted timestamp string one day later. Timezone-aware inputs
        are converted to UTC first; naive inputs are taken to be UTC already.
    """
    next_day = pd.to_datetime(date) + timedelta(days=1)
    # The trailing 'Z' claims UTC, so an aware value in another zone must be
    # shifted to UTC first rather than merely relabelled.
    if getattr(next_day, 'tzinfo', None) is not None:
        next_day = next_day.tz_convert('UTC')
    return next_day.strftime("%Y-%m-%dT%H:%M:%SZ")

# Single rig at a time

In [None]:
#| export

def get_rigs(start_date, end_date=None):
    """Get rigs that performed operations inside a time window.

    The query selects ``operations`` vertices whose ``startTimeUTC`` is after
    ``start_date`` and whose ``endTimeUTC`` is before ``end_date``, follows
    incoming ``PERFORMED`` edges and groups the resulting vertices' ids by
    ``rigName``.

    Args:
        start_date: Window start; anything ``pd.to_datetime`` accepts.
        end_date: Window end; defaults to one day after ``start_date``.

    Returns:
        The first item of the query result - given the ``group()`` step,
        presumably a dict of rig name to rig ids (confirm against live data).
        An empty dict when the query failed or returned nothing.
    """
    start_date = convert_date_to_proper_format(start_date)
    if end_date is None:
        end_date = add_day(start_date)
    else:
        # Normalise a caller-supplied bound like the start date, so both
        # bounds share one format and no raw text is spliced into the query.
        end_date = convert_date_to_proper_format(end_date)
    query = f"""
    g.V().hasLabel('operations')
        .has('startTimeUTC', gt('{start_date}')).has('endTimeUTC', lt('{end_date}'))
        .in('PERFORMED')
        .group().by(values('properties', 'rigName')).by(values('id'))
    """

    result = submit_query(query)
    # submit_query returns None after a failure (it has already printed the
    # error); indexing that would raise an opaque TypeError.
    if not result:
        return {}
    return result[0]


In [None]:
#| export

# TODO rig might have no relations
def get_wellbores_by_rig(rigId, start_date, end_date=None):
    """Get the wellbores a rig operated on inside a time window.

    Starting from the vertex ``rigId``, the query follows outgoing
    ``PERFORMED`` edges to vertices whose ``startTimeUTC`` is after
    ``start_date`` and whose ``endTimeUTC`` is before ``end_date``, then
    follows outgoing ``ON`` edges and groups the resulting vertices' ids by
    ``wellboreName``.

    Args:
        rigId: Id of the rig vertex.
        start_date: Window start; anything ``pd.to_datetime`` accepts.
        end_date: Window end; defaults to one day after ``start_date``.

    Returns:
        The first item of the query result - given the ``group()`` step,
        presumably a dict of wellbore name to wellbore ids (confirm against
        live data). An empty dict when the query failed or returned nothing.
    """
    start_date = convert_date_to_proper_format(start_date)
    if end_date is None:
        end_date = add_day(start_date)
    else:
        # Normalise a caller-supplied bound like the start date, so both
        # bounds share one format and no raw text is spliced into the query.
        end_date = convert_date_to_proper_format(end_date)
    # Escape backslashes and single quotes so the id cannot terminate the
    # quoted literal it is placed in.
    rig_id = str(rigId).replace("\\", "\\\\").replace("'", "\\'")
    query = f"""g.V('{rig_id}')
                .out('PERFORMED').has('startTimeUTC', gt('{start_date}')).has('endTimeUTC',lt('{end_date}'))
                .out('ON')
                   .group().by(values('properties', 'wellboreName')).by(values('id'))
                """

    result = submit_query(query)
    # submit_query returns None after a failure (it has already printed the
    # error); indexing that would raise an opaque TypeError.
    if not result:
        return {}
    return result[0]

In [None]:
#| export

def query_rigs(rigId, start_date, end_date):
    """List the operations a rig performed inside a time window.

    Starting from the vertex ``rigId``, the query follows outgoing
    ``PERFORMED`` edges to vertices whose ``startTimeUTC`` is after
    ``start_date`` and whose ``endTimeUTC`` is before ``end_date``, and
    projects seven properties of each.

    Args:
        rigId: Id of the rig vertex.
        start_date: Window start; anything ``pd.to_datetime`` accepts.
        end_date: Window end; anything ``pd.to_datetime`` accepts.

    Returns:
        Whatever ``submit_query`` returns: the projected rows (keys
        ``startTimeUTC``, ``endTimeUTC``, ``conveyance``, ``mainActivity``,
        ``activityName``, ``description``, ``activityCategory``), or ``None``
        when the query failed.
    """
    # Normalise both bounds like the sibling query helpers do, so they share
    # one format and no raw text is spliced into the query.
    start_date = convert_date_to_proper_format(start_date)
    end_date = convert_date_to_proper_format(end_date)
    # Escape backslashes and single quotes so the id cannot terminate the
    # quoted literal it is placed in.
    rig_id = str(rigId).replace("\\", "\\\\").replace("'", "\\'")
    query = f"""g.V('{rig_id}').out('PERFORMED')
                .has('startTimeUTC', gt('{start_date}')).has('endTimeUTC',lt('{end_date}'))
                    .project('startTimeUTC', 'endTimeUTC', 'conveyance','mainActivity',  'activityName', 'description', 'activityCategory')
                    .by(values('properties', 'startTimeUTC'))
                    .by(values('properties', 'endTimeUTC'))
                    .by(values('properties', 'conveyance'))
                    .by(values('properties', 'mainActivity'))
                    .by(values('properties', 'activityName'))
                    .by(values('properties', 'description'))
                    .by(values('properties', 'activityCategory'))"""
    return submit_query(query)