In [None]:
!pip install pydgraph
#!pip install grapl_analyzerlib
!pip install --index-url https://test.pypi.org/simple/ grapl_analyzerlib

In [None]:
# Replace <YOUR_DEPLOYMENT> with your deployment name as a quoted string.
# Later cells lower-case it to build the Dgraph host name and use it unchanged
# as the prefix of the '-user_auth_table' DynamoDB table name.
deployment_name = <YOUR_DEPLOYMENT>

In [None]:
from typing import *
import json
import pydgraph

from pydgraph import DgraphClient, DgraphClientStub
from grapl_analyzerlib.schemas import *
from grapl_analyzerlib.prelude import *
from grapl_analyzerlib.node_types import EdgeRelationship, PropPrimitive, PropType
from grapl_analyzerlib.nodes.base import BaseSchema

In [None]:
def set_schema(client, schema) -> None:
    """Send ``schema`` to Dgraph as an alter operation.

    Args:
        client: object exposing ``alter(operation)``.
        schema: value passed as ``schema=`` to ``pydgraph.Operation``.
    """
    client.alter(pydgraph.Operation(schema=schema))


def drop_all(client) -> None:
    """Send an alter operation built with ``drop_all=True`` through ``client``.

    Args:
        client: object exposing ``alter(operation)``.
    """
    client.alter(pydgraph.Operation(drop_all=True))


def format_schemas(schema_defs: List["BaseSchema"]) -> str:
    """Render the given schema objects into one schema document.

    The result has a "Type Definitions" section (each object's
    ``generate_type()``) followed by a "Schema Definitions" section (each
    object's ``generate_schema()``); entries within a section are separated
    by a blank line.

    Args:
        schema_defs: objects exposing ``generate_schema()`` and
            ``generate_type()``. The collection is iterated twice.

    Returns:
        The combined text.
    """
    # All generate_schema() calls run before any generate_type() call.
    schema_section = "\n\n".join(definition.generate_schema() for definition in schema_defs)
    type_section = "\n\n".join(definition.generate_type() for definition in schema_defs)

    parts = (
        "  # Type Definitions",
        type_section,
        "\n  # Schema Definitions",
        schema_section,
    )
    return "\n".join(parts)


def query_dgraph_predicate(client: "GraphClient", predicate_name: str):
    """Fetch the schema metadata Dgraph reports for one predicate.

    Args:
        client: object exposing ``txn(read_only=True)``.
        predicate_name: predicate name interpolated into the schema query.

    Returns:
        The first entry of the ``"schema"`` list in the JSON response.

    Raises:
        KeyError / IndexError: if the response has no ``"schema"`` entry.
    """
    query = f"""
        schema(pred: {predicate_name}) {{  }}
    """
    txn = client.txn(read_only=True)
    try:
        response = txn.query(query)
        return json.loads(response.json)["schema"][0]
    finally:
        # The transaction is discarded on success and on failure alike.
        txn.discard()


def meta_into_edge(schema, predicate_meta):
    """Build an ``EdgeT`` for a predicate from its Dgraph metadata.

    The relationship is ``OneToMany`` when ``predicate_meta`` has a truthy
    ``"list"`` entry and ``OneToOne`` otherwise. The edge is constructed from
    ``type(schema)``, ``BaseSchema`` and that relationship.
    """
    relationship = (
        EdgeRelationship.OneToMany
        if predicate_meta.get("list")
        else EdgeRelationship.OneToOne
    )
    return EdgeT(type(schema), BaseSchema, relationship)


def meta_into_property(schema, predicate_meta):
    """Build a ``PropType`` for a predicate from its Dgraph metadata.

    Args:
        schema: unused by this function.
        predicate_meta: mapping with a required ``"type"`` entry and optional
            ``"list"`` and ``"index"`` entries.

    Returns:
        ``PropType(primitive, is_set, index=...)`` where ``primitive`` is the
        ``PropPrimitive`` for ``"string"``, ``"int"`` or ``"bool"`` and
        ``None`` for any other type name.
    """
    is_set = predicate_meta.get("list")
    primitive_by_type = {
        "string": PropPrimitive.Str,
        "int": PropPrimitive.Int,
        "bool": PropPrimitive.Bool,
    }
    # Type names outside the table fall through to None.
    primitive = primitive_by_type.get(predicate_meta["type"])

    return PropType(primitive, is_set, index=predicate_meta.get("index", []))


def meta_into_predicate(schema, predicate_meta):
    """Convert Dgraph predicate metadata into an edge or a property.

    Predicates whose ``"type"`` is ``"uid"`` go through ``meta_into_edge``;
    every other type goes through ``meta_into_property``.

    Raises:
        Exception: any failure is logged through ``LOGGER.error`` and re-raised.
    """
    try:
        convert = meta_into_edge if predicate_meta["type"] == "uid" else meta_into_property
        return convert(schema, predicate_meta)
    except Exception as err:
        LOGGER.error(f"Failed to convert meta to predicate: {predicate_meta} {err}")
        raise err


def query_dgraph_type(client: "GraphClient", type_name: str):
    """Return the predicate metadata for every field of a Dgraph type.

    Args:
        client: object exposing ``txn(read_only=True)``.
        type_name: type name interpolated into the schema query.

    Returns:
        One ``query_dgraph_predicate`` result per field of the first entry in
        the response's ``"types"`` list, in field order; an empty list when
        the response is empty or has no ``"types"``.
    """
    query = f"""
        schema(type: {type_name}) {{ type }}
    """
    txn = client.txn(read_only=True)
    try:
        payload = json.loads(txn.query(query).json)
    finally:
        txn.discard()

    if not payload or not payload.get("types"):
        return []

    # Collect every field name first, then issue one lookup per name.
    field_names = [field["name"] for field in payload["types"][0]["fields"]]
    return [query_dgraph_predicate(client, name) for name in field_names]


def extend_schema(graph_client: "GraphClient", schema: "BaseSchema"):
    """Add to ``schema`` every predicate Dgraph already stores for its type.

    The predicates of ``schema.self_type()`` are looked up with
    ``query_dgraph_type`` and converted with ``meta_into_predicate``.
    ``PropType`` results are registered with ``schema.add_property``; anything
    else is registered with ``schema.add_edge`` and an empty third argument.

    Args:
        graph_client: client passed through to ``query_dgraph_type``.
        schema: schema object mutated in place.
    """
    # The annotation is quoted, like the other helpers in this cell, so that
    # defining this function does not require GraphClient to be importable.
    for predicate_meta in query_dgraph_type(graph_client, schema.self_type()):
        predicate = meta_into_predicate(schema, predicate_meta)
        predicate_name = predicate_meta["predicate"]
        if isinstance(predicate, PropType):
            schema.add_property(predicate_name, predicate)
        else:
            schema.add_edge(predicate_name, predicate, "")


def provision_master_graph(
        master_graph_client: "GraphClient", schemas: List["BaseSchema"]
) -> None:
    """Render ``schemas`` with ``format_schemas`` and apply the result via ``set_schema``.

    Args:
        master_graph_client: client passed to ``set_schema``.
        schemas: schema objects passed to ``format_schemas``.
    """
    # The client annotation is quoted, like the other helpers in this cell, so
    # that defining this function does not require GraphClient to be importable.
    mg_schema_str = format_schemas(schemas)
    set_schema(master_graph_client, mg_schema_str)


In [None]:
# Dgraph client used by the provisioning cell below. The address is the
# lower-cased deployment name followed by '.dgraph.grapl', on port 9080.
mclient = DgraphClient(DgraphClientStub(deployment_name.lower() + '.dgraph.grapl:9080'))

In [None]:
# drop_all(mclient)

def provision_mg(mclient) -> None:
    """Instantiate the notebook's node schemas and provision them through ``mclient``.

    Every schema first has ``init_reverse()`` called, then every schema is
    extended via ``extend_schema``, and finally the whole tuple is handed to
    ``provision_master_graph``.

    Args:
        mclient: Dgraph client passed to ``extend_schema`` and
            ``provision_master_graph``.
    """
    # drop_all(mclient)

    schema_classes = (
        AssetSchema,
        ProcessSchema,
        FileSchema,
        IpConnectionSchema,
        IpAddressSchema,
        IpPortSchema,
        NetworkConnectionSchema,
        ProcessInboundConnectionSchema,
        ProcessOutboundConnectionSchema,
        RiskSchema,
        LensSchema,
    )
    schemas = tuple(schema_cls() for schema_cls in schema_classes)

    # Two separate passes: init_reverse() runs on every schema before any
    # schema is extended.
    for node_schema in schemas:
        node_schema.init_reverse()

    for node_schema in schemas:
        extend_schema(mclient, node_schema)

    provision_master_graph(mclient, schemas)

# Run the provisioning with the notebook-level Dgraph client.
provision_mg(mclient)

In [None]:
import os
import string

from hashlib import pbkdf2_hmac, sha256
from random import randint, choice

import boto3

def hash_password(cleartext, salt) -> str:
    """Derive the stored password hash from ``cleartext`` and ``salt``.

    ``cleartext`` is first hashed with SHA-256; that digest is then fed to
    PBKDF2-HMAC-SHA256 with ``salt`` and 512000 iterations.

    Args:
        cleartext: bytes to hash.
        salt: bytes used as the PBKDF2 salt.

    Returns:
        The derived key as a hex string.
    """
    prehashed = sha256(cleartext).digest()
    derived_key = pbkdf2_hmac('sha256', prehashed, salt, 512000)
    return derived_key.hex()

def create_user(username, cleartext):
    """Hash ``cleartext`` for ``username`` and store the credentials in DynamoDB.

    The item is written to the table named ``deployment_name`` followed by
    ``'-user_auth_table'`` and holds the username, a fresh 16-byte random salt
    and the ``hash_password`` result.

    Args:
        username: user name; also mixed into the pre-hash input.
        cleartext: non-empty cleartext password (str).

    Raises:
        AssertionError: if ``cleartext`` is empty.
    """
    assert cleartext
    user_table = boto3.resource('dynamodb').Table(deployment_name + '-user_auth_table')

    # The cleartext is pre-hashed here, before `hash_password` is called, because
    # the frontend will also perform client side hashing.
    peppered = (
        cleartext
        + "f1dafbdcab924862a198deaa5b6bae29aef7f2a442f841da975f1c515529d254"
        + username
    )

    # One initial SHA-256 round followed by 5000 further rounds over the hex digest.
    digest = sha256(peppered.encode('utf8')).hexdigest()
    for _ in range(5000):
        digest = sha256(digest.encode('utf8')).hexdigest()

    salt = os.urandom(16)
    user_table.put_item(
        Item={
            'username': username,
            'salt': salt,
            'password': hash_password(digest.encode('utf8'), salt),
        }
    )
    
    

# `secrets` draws from the operating system's CSPRNG; the `random` module is
# not intended for generating passwords.
import secrets

username = ''
# Check the username first so no password is generated or printed when the
# cell has not been filled in yet.
assert username, 'Replace the username with your desired username'

allchar = string.ascii_letters + string.punctuation + string.digits
# Random length of 14, 15 or 16 characters.
password = "".join(secrets.choice(allchar) for _ in range(14 + secrets.randbelow(3)))
print(f'your password is {password}')
create_user(username, password)
# Drop the cleartext from the kernel namespace; the printed output still shows it.
password = ""
print("""Make sure to clear this cell and restart the notebook to ensure your password does not leak!""")

In [None]:
# CLEAR CACHE
def clear_redis_caches(cache_addrs=None, cache_port=6379, batch_size=10000):
    """Delete every key in db 0 of each given Redis cache.

    Args:
        cache_addrs: Redis host name, or iterable of host names, to clear.
            With the default (``None``) no host is contacted and the call
            does nothing.
        cache_port: port used for every host.
        batch_size: maximum number of keys passed to a single ``delete`` call.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    # Imported here so the rest of the notebook does not need the redis package.
    from redis import Redis

    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if cache_addrs is None:
        cache_addrs = ()
    elif isinstance(cache_addrs, str):
        # A bare string would otherwise be iterated character by character.
        cache_addrs = (cache_addrs,)

    for cache_addr in cache_addrs:
        r = Redis(host=cache_addr, port=cache_port, db=0, decode_responses=True)

        keys = list(r.keys())
        # Delete in batches so a single call never carries more than batch_size keys.
        for start in range(0, len(keys), batch_size):
            r.delete(*keys[start:start + batch_size])
            
# Run the cache clearing defined above.
clear_redis_caches()
