## Dataset snapshot

### Train / Test / Val Pre-split

In [1]:
import os, sys
# Extend the module search path with two locations resolved against the
# current working directory (os.path.abspath('') is the CWD): the directory
# two levels up, and the `toyDb` folder inside it.
# NOTE(review): presumably this is what makes the `toyDb` import below
# resolvable -- keep these lines ahead of that import.
sys.path.append(os.path.join(os.path.abspath(''), '../../'))
sys.path.append(os.path.join(os.path.abspath(''), '../../toyDb'))

import numpy as np
import json
from toyDb.databases import ExperimentDb
# NOTE(review): presumably binds the ExperimentDb models to the project's
# default database; it runs before any query is issued in this notebook.
ExperimentDb.init_from_default_db()

def getShaderIdSplit() -> 'tuple[np.ndarray, np.ndarray, np.ndarray]':
  """Deterministically partition all shader ids into train / test / val groups.

  The distinct ``shader_shadertoy_id`` values of ``ImageOnlyExperiment`` are
  read in ascending order, shuffled with a fixed-seed NumPy generator and cut
  into three consecutive, non-overlapping slices.

  The shuffled order is written to ``shaderIdSplitCache.json`` in the current
  working directory on the first run; on later runs it is compared against
  that file so that any drift in the split is detected immediately.

  Returns:
    ``(trainIds, testIds, valIds)``: three arrays of shader ids.  Train holds
    ``int(n * 0.8)`` ids, test ``int(n * 0.05)`` ids and val the remainder.

  Raises:
    AssertionError: if the database does not contain the expected number of
      shaders, or if the shuffled order differs from the cached one.
  """
  # Hard-coded on purpose so the split never changes; per the original note
  # the value is the I3D'24 submission deadline written as an integer.
  RNGSeed = 20240105
  TRAIN_RATIO = 0.8
  TEST_RATIO = 0.05
  # Number of shaders in the database snapshot this split was designed for.
  EXPECTED_NUM_SHADERS = 20669

  CACHE = os.path.join(os.path.abspath(''), 'shaderIdSplitCache.json')

  # Original author's note: ImageOnlyExperiment is more complete than
  # ImageOnlyShader, hence the ids are collected from the experiment table.
  allShaderIdQuery = ExperimentDb.ImageOnlyExperiment.select(
    ExperimentDb.ImageOnlyExperiment.shader_shadertoy_id
  ).order_by(ExperimentDb.ImageOnlyExperiment.shader_shadertoy_id.asc()).distinct()

  # Guard against running on a different database: the ratios below are only
  # meaningful for the exact shader population the split was defined on.
  numShaders = len(allShaderIdQuery)
  print(f"Shaders that took part in the partition: {numShaders}")
  assert numShaders == EXPECTED_NUM_SHADERS, (
    f"expected {EXPECTED_NUM_SHADERS} shaders in the database, found {numShaders}"
  )

  rng = np.random.default_rng(RNGSeed)

  # Sorted input + fixed seed => reproducible permutation.
  shaderIds = np.array([i.shader_shadertoy_id for i in allShaderIdQuery])
  shuffledIds = rng.permutation(shaderIds)
  print(shaderIds)
  print(shuffledIds)

  shuffledIdList = shuffledIds.tolist()
  if not os.path.isfile(CACHE):
    # First run: record the order so later runs can be checked against it.
    with open(CACHE, "w") as fp:
      json.dump(shuffledIdList, fp)
  else:
    with open(CACHE, "r") as fp:
      idFromFile = json.load(fp)
    assert idFromFile == shuffledIdList, (
      f"shuffled shader ids differ from the order cached in {CACHE}"
    )

  numTrain = int(len(shuffledIds) * TRAIN_RATIO)
  numTest = int(len(shuffledIds) * TEST_RATIO)
  # Validation takes whatever is left, so the three parts always add up.
  numVal = len(shuffledIds) - numTrain - numTest
  print(f"numTrain: {numTrain}; numTest: {numTest}; numVal: {numVal}")

  return (shuffledIds[:numTrain], shuffledIds[numTrain:numTrain + numTest], shuffledIds[numTrain + numTest:])

# Compute the split once; the three id arrays are later registered as the
# ShadertoyIdFilter groups (trainShdrExprs / testShdrExprs / valShdrExprs).
trainIds, testIds, valIds = getShaderIdSplit()

Shaders that took part in the partition: 20669
['3d23DR' '3d23Dc' '3d23R3' ... 'wtyfzz' 'wtyyWD' 'wtyyz3']
['wd3GDH' '3sXyDn' 'MdK3Dd' ... 'WltBzM' 'MsVXzc' 'sdVGWW']
numTrain: 16535; numTest: 1033; numVal: 3101


### Dataset splits for each architecture


In [2]:
from misc.ComplexDatasetSnapshotter import (
    EnvironmentFilter,
    CycleTrialsFilter,
    ErrorFilter,
    ShadertoyIdFilter,
    WidthHeightFilter,
    ResourceFilter,
    TraceAvailabilityFilter,
    SpvTokenizedLengthFilter,
    TraceDuplicationPostFilter,
    AugmentationFilter,
    TimeThresholdFilter,
    ImageHashFilter,
    ComplexDatasetSnapshotter
)
from misc.Directory import (
  getIntermediateDir
)

import logging
import hashlib
import pickle

# Timestamped INFO-level logging for everything that runs in this notebook.
logging.basicConfig(
  level=logging.INFO,
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

snapshotter = ComplexDatasetSnapshotter()

# Basic filters
snapshotter.registerFilter(EnvironmentFilter())
snapshotter.registerFilter(WidthHeightFilter())
snapshotter.registerFilter(ResourceFilter())
snapshotter.registerFilter(CycleTrialsFilter())
snapshotter.registerFilter(ErrorFilter())
snapshotter.registerFilter(TraceAvailabilityFilter())
snapshotter.registerFilter(ImageHashFilter())
snapshotter.registerFilter(AugmentationFilter())
snapshotter.registerFilter(TimeThresholdFilter(10))

# Token-length filter with a threshold of 4096.
lengthFilter = SpvTokenizedLengthFilter()
lengthFilter.setThreshold(4096)

# The filter result is cached on disk: reuse the cache when it exists,
# otherwise compute it and store it.  The path is built exactly once so the
# existence check, the read and the write can never point at different files.
lengthFilterCachePath = os.path.join(getIntermediateDir(), "./lengthFilterCache.json")
if os.path.isfile(lengthFilterCachePath):
    lengthFilter.readFromCache(lengthFilterCachePath)
else:
    lengthFilter.process(parallel=True)
    lengthFilter.writeToCache(lengthFilterCachePath)

snapshotter.registerFilter(lengthFilter)

# Train / test split filter: one named group per pre-split id array.
shdrIdFilter = ShadertoyIdFilter()
shdrIdFilter.registerGroup("trainShdrExprs", trainIds.tolist())
shdrIdFilter.registerGroup("testShdrExprs", testIds.tolist())
shdrIdFilter.registerGroup("valShdrExprs", valIds.tolist())
snapshotter.registerFilter(shdrIdFilter)

# NOTE(review): presumably what prints the "<Filter>_<group> = <count>"
# summary seen in this cell's output -- confirm against the snapshotter.
snapshotter.examineGroups(1)

2023-12-31 13:54:44,527 - misc.ComplexDatasetSnapshotter - INFO - White hash for (1024, 768): 908b6cfc9aef496dd5ab5c5540d80c6383ed6e92f86044574c996315381bc064
2023-12-31 13:54:44,528 - misc.ComplexDatasetSnapshotter - INFO - Transparent black hash for (1024, 768): bbd05cf6097ac9b1f89ea29d2542c1b7b67ee46848393895f5a9e43fa1f621e5
2023-12-31 13:54:44,537 - misc.ComplexDatasetSnapshotter - INFO - White hash for (800, 600): d883267b40e389a772f00ef4c50d49471138afed85da8f101de16ad6cf5a9d9f
2023-12-31 13:54:44,538 - misc.ComplexDatasetSnapshotter - INFO - Transparent black hash for (800, 600): 124617c1f65e92d3bc895fbd869e4bb16a30754b198f59e6e973949b9aaa1b01
2023-12-31 13:54:44,564 - misc.ComplexDatasetSnapshotter - INFO - White hash for (1920, 1080): 598ddbfa658eaf70a2e0c50fa12f914c28a4496bbb150a21d6d77d73b0d8c55d
2023-12-31 13:54:44,565 - misc.ComplexDatasetSnapshotter - INFO - Transparent black hash for (1920, 1080): 788ae0147bdf979a6575938ca2d7d4403788588f7be2010f03776c968fd1ab49


  0%|          | 0/28278 [00:00<?, ?it/s]

  0%|          | 0/28278 [00:00<?, ?it/s]

EnvironmentFilter_EnvId1 = 447888
EnvironmentFilter_EnvId2 = 20669
EnvironmentFilter_EnvId3 = 20669
EnvironmentFilter_EnvId4 = 20669
EnvironmentFilter_EnvId5 = 20669
EnvironmentFilter_EnvId6 = 20669
EnvironmentFilter_EnvId7 = 20669
EnvironmentFilter_EnvId8 = 20669
WidthHeightFilter_1024-768 = 551233
WidthHeightFilter_800-600 = 20669
WidthHeightFilter_1920-1080 = 20669
ResourceFilter_resource1 = 126226
ResourceFilter_resourceNone = 188875
ResourceFilter_resource2 = 13878
ResourceFilter_resource3 = 13873
ResourceFilter_resource4 = 13873
ResourceFilter_resource5 = 13875
ResourceFilter_resource6 = 13876
ResourceFilter_resource7 = 13872
ResourceFilter_resource8 = 13867
ResourceFilter_resource9 = 13872
ResourceFilter_resource10 = 13881
ResourceFilter_resource11 = 13873
ResourceFilter_resource12 = 13879
ResourceFilter_resource13 = 13871
ResourceFilter_resource14 = 13869
ResourceFilter_resource15 = 13868
ResourceFilter_resource16 = 13874
ResourceFilter_resource17 = 13874
ResourceFilter_resourc

In [3]:
# Filter selections are lists of one-element lists, each holding a
# (filter name, group name) tuple.

# Selection shared by every dataset's base population.
canonicalFilter = [
    [('CycleTrialsFilter', '30cycles-10trials')],
    [('WidthHeightFilter', '1024-768')],
    [('AugmentationFilter', 'aug0')]
]

# Selection applied on top of the base population for all three splits.
commonAdditionalFilters = [
    [('ResourceFilter', 'resource1')],
    [('ErrorFilter', 'error0')],
    [('TraceAvailabilityFilter', 'haveTrace')],
    [('ImageHashFilter', 'normalHash')],
    [('SpvTokenizedLengthFilter', 'belowOrEqualThreshold4096')],
    [('TimeThresholdFilter', 'meanBelowOrEqualThreshold10')]
]


def _makeDatasetDesc(envGroup):
    """Build the filter description of one dataset.

    Args:
        envGroup: group name of ``EnvironmentFilter`` to select, e.g. ``'EnvId7'``.

    Returns:
        A dict with the keys ``baseFilters``, ``additionalFilters``,
        ``trainAdditionalFilters``, ``testAdditionalFilters`` and
        ``valAdditionalFilters``.  Datasets differ only in the environment
        group appended to ``canonicalFilter``; everything else is shared.
    """
    return {
        "baseFilters": canonicalFilter + [[('EnvironmentFilter', envGroup)]],
        "additionalFilters": commonAdditionalFilters,
        "trainAdditionalFilters": [
            [('ShadertoyIdFilter', 'trainShdrExprs')]
        ],
        "testAdditionalFilters": [
            [('ShadertoyIdFilter', 'testShdrExprs')]
        ],
        "valAdditionalFilters": [
            [('ShadertoyIdFilter', 'valShdrExprs')]
        ]
    }


# One entry per dataset to export, keyed by the name used in the output file.
candidateDatasets = {
    "RX7900GRE": _makeDatasetDesc('EnvId7'),
    "RX6600XT-Refresh": _makeDatasetDesc('EnvId8'),
}

# For every candidate dataset: work out which experiment ids each split should
# contain, then either create the snapshot file (plus an .md5sum side file) or,
# when the file already exists, check that its split sizes still match.
for setName, setDesc in candidateDatasets.items():
    outputName = f"FragPerfSnapshotTracedFinalDataset-{setName}-Val-TimeFiltered.dat"
    outputPath = os.path.join(getIntermediateDir(), f"./{outputName}")
    print(f"Dataset {outputName}:")

    # Ids selected by the base filters; every split narrows this down further.
    baseResults = snapshotter.evalFilters(setDesc['baseFilters'])
    sharedFilters = setDesc['additionalFilters']

    # Each interpretation gets its own fresh set built from the base ids.
    print("=> Train Set Filter Interpretation:")
    trainExprIds = snapshotter.interpretFilters(
        set(baseResults), sharedFilters + setDesc['trainAdditionalFilters']
    )

    print("=> Test Set Filter Interpretation:")
    testExprIds = snapshotter.interpretFilters(
        set(baseResults), sharedFilters + setDesc['testAdditionalFilters']
    )

    print("=> Val Set Filter Interpretation:")
    valExprIds = snapshotter.interpretFilters(
        set(baseResults), sharedFilters + setDesc['valAdditionalFilters']
    )

    if not os.path.isfile(outputPath):
        # No snapshot yet: create it and make sure the sizes it reports agree
        # with the interpretation above.
        allSplitFilters = setDesc['baseFilters'] + setDesc['additionalFilters']
        trainLen, testLen, valLen = snapshotter.doSnapshot(
            outputPath,
            trainGroupFilters=allSplitFilters + setDesc['trainAdditionalFilters'],
            testGroupFilters=allSplitFilters + setDesc['testAdditionalFilters'],
            valGroupFilters=allSplitFilters + setDesc['valAdditionalFilters']
        )
        assert len(trainExprIds) == trainLen
        assert len(testExprIds) == testLen
        assert len(valExprIds) == valLen

        # Hash the written file in 8 KiB chunks.
        with open(outputPath, "rb") as f:
            file_hash = hashlib.md5()
            for chunk in iter(lambda: f.read(8192), b""):
                file_hash.update(chunk)

        md5Hex = file_hash.hexdigest()
        print(f"Hash for {outputPath}:\n- md5sum: {md5Hex}")
        with open(outputPath + ".md5sum", "w") as f:
            f.write(md5Hex)
        continue

    print(f"=> file {outputPath} is already there, verify")

    # NOTE: pickle.load can execute arbitrary code; only load snapshot files
    # that this notebook (or another trusted source) produced.
    with open(outputPath, "rb") as f:
        dataDict = pickle.load(f)
        for splitKey, expectedIds in (
            ('train', trainExprIds),
            ('test', testExprIds),
            ('val', valExprIds),
        ):
            storedLen = len(dataDict[splitKey])
            if storedLen != len(expectedIds):
                raise Exception(f"{outputPath} got {splitKey} length {storedLen}, but from filters we got {len(expectedIds)}")

    # Drop the reference to the loaded snapshot before the next iteration.
    del dataDict

Dataset FragPerfSnapshotTracedFinalDataset-RX7900GRE-Val-TimeFiltered.dat:
=> Train Set Filter Interpretation:
  Base expr id: 20669 items
  Intersect with ResourceFilter_resource1: 14084 items
  After ResourceFilter: 14084 items
  Intersect with ErrorFilter_error0: 14084 items
  After ErrorFilter: 14084 items
  Intersect with TraceAvailabilityFilter_haveTrace: 14080 items
  After TraceAvailabilityFilter: 14080 items
  Intersect with ImageHashFilter_normalHash: 13925 items
  After ImageHashFilter: 13925 items
  Intersect with SpvTokenizedLengthFilter_belowOrEqualThreshold4096: 11360 items
  After SpvTokenizedLengthFilter: 11360 items
  Intersect with TimeThresholdFilter_meanBelowOrEqualThreshold10: 11360 items
  After TimeThresholdFilter: 11360 items
  Intersect with ShadertoyIdFilter_trainShdrExprs: 9049 items
  After ShadertoyIdFilter: 9049 items
  Returning expr id: 9049 items
=> Test Set Filter Interpretation:
  Base expr id: 20669 items
  Intersect with ResourceFilter_resource1: 1