In [1]:
import os, sys
sys.path.append(os.path.join(os.path.abspath(''), '../'))

from databases.ExperimentDb import (
    db,
    Environment,
    ImageOnlyExperiment,
    ImageOnlyResource,
    ImageOnlyShader,
    ImageOnlyTrace,
    init_from_default_db
)

import logging
import json
import pickle
import peewee as pw
from tqdm import tqdm

# Module-level logger; the import helpers further down use it to emit warnings.
logger=logging.getLogger(__name__)

# Close the connection if one is currently open before initialising the database again
# (presumably so that re-running this cell in a live kernel starts clean -- TODO confirm).
if not db.is_closed():
    db.close()

# NOTE(review): presumably binds `db` to the project's default database -- confirm in databases.ExperimentDb.
init_from_default_db()

# Path of the pickled export handed to doImport(); it is relative, so it is resolved
# against the current working directory.
IMPORT_SOURCE="4060-db-exported-fixed.dat"

The import process is as follows:
1. Load the pickled database
2. Resolve an import order for the ImageOnlyExperiment records such that every record's parent has already been imported by the time the record itself is processed
   > Note that an `id` does not necessarily reflect the record's index position in the list.
3. For each ImageOnlyExperiment
   1. If error, just import it and continue
   2. Check for associated Resource & Shader instance
      - If not present in the current database, insert it and obtain a new ID
      - The new ID is then stored on the imported experiment record
   3. If traced, check whether the current trace is a duplicate of an existing one
      - If it is a duplicate, good - just reuse the existing trace
      - If not, warn, insert the new trace, link it back, and continue
   4. Store the ImageOnlyExperiment object - Also patch `parentId` if not null
4. Patch trace (optional)

In [2]:
def resolveExprImportOrder(experimentIdToIdx: 'dict[int, int]', dataExperiments) -> 'list[int]':
    """Order experiment ids so that every parent precedes its children.

    Args:
        experimentIdToIdx: maps an experiment id to its position in
            `dataExperiments`.
        dataExperiments: indexable collection of experiment records; each
            record is only read through its "parent" key (the parent
            experiment id, or None for a root).

    Returns:
        Every id of `experimentIdToIdx` exactly once: the roots first, in the
        mapping's iteration order, followed by their descendants in
        breadth-first order.

    Raises:
        RuntimeError: if some experiments can never be reached from a root,
            i.e. they are their own parent, are part of a cycle, or descend
            from a parent id that is not a key of `experimentIdToIdx`.
    """
    # Roots go straight into the result; everything else is grouped under its
    # parent. Nothing is removed from a collection while it is being iterated.
    exprQueue = []
    childrenOf = {}
    for exprId, idx in experimentIdToIdx.items():
        parentId = dataExperiments[idx]["parent"]
        if parentId is None:
            exprQueue.append(exprId)
        else:
            childrenOf.setdefault(parentId, []).append(exprId)

    # Breadth-first walk: exprQueue is both the work list and the result, so a
    # child is appended only after its parent already sits in the list.
    cursor = 0
    while cursor < len(exprQueue):
        exprQueue.extend(childrenOf.get(exprQueue[cursor], ()))
        cursor += 1

    # Each id has exactly one parent, so it can be appended at most once; a
    # shorter result therefore means some ids were never reached.
    if len(exprQueue) != len(experimentIdToIdx):
        remainingNodes = set(experimentIdToIdx.keys()) - set(exprQueue)
        raise RuntimeError(f"Circular or self dependencies encountered at expr id set {remainingNodes}")

    return exprQueue

def getOrCreateShader(shaderId, shaderIdToIdx, dataShaders):
    """Translate a shader id of the imported dump into a shader row id of the current database.

    Args:
        shaderId: id of the shader inside the imported dump, or None.
        shaderIdToIdx: maps dump shader ids to positions in `dataShaders`.
        dataShaders: shader records of the dump; the "shader_id" and
            "fragment_spv" keys are read.

    Returns:
        None when `shaderId` is None; otherwise the id of the ImageOnlyShader
        row whose shader_id and fragment_spv both equal the record's values,
        creating that row first if no such row exists.
    """
    if shaderId is None:
        return None

    record = dataShaders[shaderIdToIdx[shaderId]]

    # An exact match on both columns means the shader is already known.
    exactMatch = ImageOnlyShader.get_or_none(
        ImageOnlyShader.shader_id == record["shader_id"],
        ImageOnlyShader.fragment_spv == record["fragment_spv"]
    )
    if exactMatch is not None:
        return exactMatch.id

    # Same shader_id but a different binary: report it, then store the new
    # variant as its own row.
    sameNameShader = ImageOnlyShader.get_or_none(ImageOnlyShader.shader_id == record['shader_id'])
    if sameNameShader:
        logger.warning(
            f"Shader {record['shader_id']} compiled to different SPIR-V from imported databases"
        )

    return ImageOnlyShader.create(
        shader_id=record['shader_id'],
        fragment_spv=record["fragment_spv"]
    ).id

def getOrCreateResource(resourceId, resourceIdToIdx, dataResources):
    """Translate a resource id of the imported dump into a resource row id of the current database.

    Args:
        resourceId: id of the resource inside the imported dump, or None.
        resourceIdToIdx: maps dump resource ids to positions in `dataResources`.
        dataResources: resource records of the dump; only "uniform_block" is read.

    Returns:
        None when `resourceId` is None; otherwise the id of the row returned by
        ImageOnlyResource.get_or_create for the record's uniform_block.
    """
    if resourceId is None:
        return None

    uniformBlock = dataResources[resourceIdToIdx[resourceId]]["uniform_block"]
    row, _ = ImageOnlyResource.get_or_create(uniform_block=uniformBlock)
    return row.id

def getOrCreateEnvironments(environmentId, environmentIdToIdx, dataEnvironments):
    """Translate an environment id of the imported dump into an environment row id of the current database.

    Args:
        environmentId: id of the environment inside the imported dump. Unlike
            the sibling helpers there is no None shortcut; the id is looked up
            in `environmentIdToIdx` unconditionally.
        environmentIdToIdx: maps dump environment ids to positions in
            `dataEnvironments`.
        dataEnvironments: environment records of the dump; the keys "node",
            "os", "cpu", "gpu", "gpu_driver" and "comment" are read.

    Returns:
        The id of the row returned by Environment.get_or_create for those six
        field values.
    """
    record = dataEnvironments[environmentIdToIdx[environmentId]]
    # Forward exactly these six columns, in this order, as keyword arguments.
    columns = {
        name: record[name]
        for name in ("node", "os", "cpu", "gpu", "gpu_driver", "comment")
    }
    row, _ = Environment.get_or_create(**columns)
    return row.id

def getOrCreateTrace(traceId, traceIdIdToIdx, dataTraces) -> 'int':
    """Translate a trace id of the imported dump into a trace row id of the current database.

    Args:
        traceId: id of the trace inside the imported dump, or None.
        traceIdIdToIdx: maps dump trace ids to positions in `dataTraces`.
        dataTraces: trace records of the dump; the keys "bb_idx_map",
            "bb_trace_counters" and "traced_fragment_spv" are read.

    Returns:
        The id of the row returned by ImageOnlyTrace.get_or_create for those
        three values. Note that None is returned when `traceId` is None, even
        though the annotation only mentions int.
    """
    if traceId is None:
        return None

    record = dataTraces[traceIdIdToIdx[traceId]]
    bbIdxMap = record["bb_idx_map"]
    bbTraceCounters = record["bb_trace_counters"]
    tracedFragmentSpv = record["traced_fragment_spv"]

    row, _ = ImageOnlyTrace.get_or_create(
        bb_idx_map=bbIdxMap,
        bb_trace_counters=bbTraceCounters,
        traced_fragment_spv=tracedFragmentSpv
    )
    return row.id

def _cachedResolver(resolve, idToIdx, records):
    """Wrap a getOrCreate* helper so that each remote id is resolved only once.

    The returned callable takes a remote id and returns whatever
    `resolve(remoteId, idToIdx, records)` returned the first time that id was
    asked for; later requests for the same id are answered from a dict.
    """
    localIds = {}

    def lookup(remoteId):
        if remoteId not in localIds:
            localIds[remoteId] = resolve(remoteId, idToIdx, records)
        return localIds[remoteId]

    return lookup


def doImport(importSource: str):
    """Import a pickled database export into the current database.

    The file is expected to unpickle to a mapping with the keys "experiments",
    "shaders", "resources", "traces" and "environments", each a list of
    records carrying an "id". Experiments are inserted in the order given by
    resolveExprImportOrder, so a parent is always inserted before its
    children; shader / resource / environment / trace ids and the parent id
    are rewritten from dump ids to ids of the current database. All inserts
    happen inside one `db.atomic()` block.

    Args:
        importSource: path of the pickled export to read.

    Raises:
        RuntimeError: propagated from resolveExprImportOrder when the parent
            links of the experiments cannot be ordered.
        KeyError: if the dump lacks an expected key or an experiment refers to
            an id that is absent from the corresponding list.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only import dumps that come from a trusted source.
    with open(importSource, "rb") as fp:
        combinedDb = pickle.load(fp)

    dataExperiments = combinedDb["experiments"]
    dataShaders = combinedDb["shaders"]
    dataResources = combinedDb["resources"]
    dataTraces = combinedDb["traces"]
    dataEnvironments = combinedDb["environments"]

    # maps Id from remote_db -> cur_db
    experimentIdMap = {}

    # maps Id -> Idx; an id does not have to equal the record's list position
    environmentIdToIdx = {elem["id"]: idx for idx, elem in enumerate(dataEnvironments)}
    shaderIdToIdx = {elem["id"]: idx for idx, elem in enumerate(dataShaders)}
    resourceIdToIdx = {elem["id"]: idx for idx, elem in enumerate(dataResources)}
    traceIdIdToIdx = {elem["id"]: idx for idx, elem in enumerate(dataTraces)}
    experimentIdToIdx = {elem["id"]: idx for idx, elem in enumerate(dataExperiments)}

    exprProcIdQueue = resolveExprImportOrder(experimentIdToIdx, dataExperiments)

    # Many experiments share the same shader / resource / environment / trace.
    # Once a remote id has been resolved the answer cannot change during this
    # import (the row it maps to exists from then on), so remember it instead
    # of querying the database again for every experiment.
    resolveShader = _cachedResolver(getOrCreateShader, shaderIdToIdx, dataShaders)
    resolveResource = _cachedResolver(getOrCreateResource, resourceIdToIdx, dataResources)
    resolveEnvironment = _cachedResolver(getOrCreateEnvironments, environmentIdToIdx, dataEnvironments)
    resolveTrace = _cachedResolver(getOrCreateTrace, traceIdIdToIdx, dataTraces)

    with db.atomic():
        for exprId in tqdm(exprProcIdQueue):
            expr = dataExperiments[experimentIdToIdx[exprId]]

            # check for associated resources
            shdrId = resolveShader(expr["shader"])
            resourceId = resolveResource(expr["resource"])
            envId = resolveEnvironment(expr["environment"])
            traceId = resolveTrace(expr["trace"])

            # The import order guarantees the parent has been inserted already.
            parentId = experimentIdMap[expr["parent"]] if expr["parent"] is not None else None

            exprInst = ImageOnlyExperiment.create(
                time=expr["time"],
                environment=envId,
                augmentation=expr["augmentation"],
                augmentation_annotation=expr["augmentation_annotation"],
                width=expr["width"],
                height=expr["height"],
                shader_shadertoy_id=expr["shader_shadertoy_id"],
                shader=shdrId,
                resource=resourceId,
                trace=traceId,
                parent=parentId,
                measurement=expr["measurement"],
                image_hash=expr["image_hash"],
                num_cycles=expr["num_cycles"],
                num_trials=expr["num_trials"],
                errors=expr["errors"],
                results=expr["results"],
            )

            experimentIdMap[expr["id"]] = exprInst.id


In [3]:
# Run the import for the dump named by IMPORT_SOURCE. The recorded run below
# processed 20669 experiments in roughly 24 minutes.
doImport(IMPORT_SOURCE)

 58%|█████▊    | 11903/20669 [13:38<03:24, 42.77it/s] Shader 3tBfDG compiled to different SPIR-V from imported databases
100%|██████████| 20669/20669 [24:19<00:00, 14.17it/s]
