In [2]:
from WGR import *
import numpy as np
from tqdm.notebook import tqdm
from numba import njit

# import emcee

In [3]:
# Licensed under the terms of http://www.apache.org/licenses/LICENSE-2.0
# Author/s (©): Alvaro del Castillo

from math import sqrt

import chunk
import logging

class Voxel:
    def __init__(self, bytes):
        self.x = bytes[0]
        self.z = bytes[1] # swap YZ
        self.y = bytes[2] # swap YZ
        self.color_index = bytes[3] - 1

class Color:
    """ RGBA format palette """
    _color2minecraft = {}  # Cache to convert a color to minecraft color

    def __init__(self, hex_str):
        self.hex_str = hex_str
        self.colorint = int(self.hex_str[0:8], 16)

    def rgb(self):
        red = int(self.hex_str[0:2], 16)
        green = int(self.hex_str[2:4], 16)
        blue = int(self.hex_str[4:6], 16)

        return red, green, blue

    def compare(self, color):
        # https://www.compuphase.com/cmetric.htm
        r1, g1, b1 = self.rgb()
        r2, g2, b2 = color.rgb()

        read_mean = (r1 + r2) / 2
        r = r1 - r2
        g = g1 - g2
        b = b1 - b2
        return sqrt((round((512 + read_mean) * r * r) >> 8) + 4.0 * g * g + (round((767 - read_mean) * b * b) >> 8))

    def minecraft(self):
        # https://gaming.stackexchange.com/questions/47212/what-are-the-color-values-for-dyed-wool
        mc_colors = [
            ("White", "e4e4e4"),
            ("Orange", "ea7e35"),
            ("Magenta", "be49c9"),
            ("Light Blue", "6387d2"),
            ("Yellow", "c2b51c"),
            ("Lime", "39ba2e"),
            ("Pink", "d98199"),
            ("Grey", "414141"),
            ("Light grey", "a0a7a7"),
            ("Cyan", "267191"),
            ("Purple", "7e34bf"),
            ("Blue", "253193"),
            ("Brown", "56331c"),
            ("Green", "364b18"),
            ("Red", "9e2b27"),
            ("Black", "181414")
        ]

        mc_color_number = {}
        mc_colors_hex = {}
        for i in range(0, len(mc_colors)):
            mc_color_number[mc_colors[i][1]] = i
            mc_colors_hex[mc_colors[i][1]] = mc_colors[i][0]

        # Find the closest Minecraft color
        rgb = self.hex_str[0:6]
        dist = float('inf')
        if rgb in mc_colors_hex:
            # Direct mapping
            color = rgb
        elif rgb in Color._color2minecraft:
            # Color already mapped
            color = Color._color2minecraft[rgb]
        else:
            for mc_color in mc_colors_hex:
                cdist = self.compare(Color(mc_color))
                if cdist < dist:
                    dist = cdist
                    color = mc_color
            self._color2minecraft[rgb] = color

        return mc_color_number[color]


class VoxDefaultPalette:
    # Removed first "0x00000000" (it does not appear in MV)
    # Reverse order: ABGR
    palette = [
        "0xffffffff", "0xffccffff", "0xff99ffff", "0xff66ffff", "0xff33ffff", "0xff00ffff", "0xffffccff",
        "0xffccccff",
        "0xff99ccff", "0xff66ccff", "0xff33ccff", "0xff00ccff", "0xffff99ff", "0xffcc99ff", "0xff9999ff",
        "0xff6699ff", "0xff3399ff", "0xff0099ff", "0xffff66ff", "0xffcc66ff", "0xff9966ff", "0xff6666ff", "0xff3366ff",
        "0xff0066ff",
        "0xffff33ff", "0xffcc33ff", "0xff9933ff", "0xff6633ff", "0xff3333ff", "0xff0033ff", "0xffff00ff",
        "0xffcc00ff", "0xff9900ff", "0xff6600ff", "0xff3300ff", "0xff0000ff", "0xffffffcc", "0xffccffcc", "0xff99ffcc",
        "0xff66ffcc",
        "0xff33ffcc", "0xff00ffcc", "0xffffcccc", "0xffcccccc", "0xff99cccc", "0xff66cccc", "0xff33cccc",
        "0xff00cccc", "0xffff99cc", "0xffcc99cc", "0xff9999cc", "0xff6699cc", "0xff3399cc", "0xff0099cc", "0xffff66cc",
        "0xffcc66cc",
        "0xff9966cc", "0xff6666cc", "0xff3366cc", "0xff0066cc", "0xffff33cc", "0xffcc33cc", "0xff9933cc",
        "0xff6633cc", "0xff3333cc", "0xff0033cc", "0xffff00cc", "0xffcc00cc", "0xff9900cc", "0xff6600cc", "0xff3300cc",
        "0xff0000cc",
        "0xffffff99", "0xffccff99", "0xff99ff99", "0xff66ff99", "0xff33ff99", "0xff00ff99", "0xffffcc99",
        "0xffcccc99", "0xff99cc99", "0xff66cc99", "0xff33cc99", "0xff00cc99", "0xffff9999", "0xffcc9999", "0xff999999",
        "0xff669999",
        "0xff339999", "0xff009999", "0xffff6699", "0xffcc6699", "0xff996699", "0xff666699", "0xff336699",
        "0xff006699", "0xffff3399", "0xffcc3399", "0xff993399", "0xff663399", "0xff333399", "0xff003399", "0xffff0099",
        "0xffcc0099",
        "0xff990099", "0xff660099", "0xff330099", "0xff000099", "0xffffff66", "0xffccff66", "0xff99ff66",
        "0xff66ff66", "0xff33ff66", "0xff00ff66", "0xffffcc66", "0xffcccc66", "0xff99cc66", "0xff66cc66", "0xff33cc66",
        "0xff00cc66",
        "0xffff9966", "0xffcc9966", "0xff999966", "0xff669966", "0xff339966", "0xff009966", "0xffff6666",
        "0xffcc6666", "0xff996666", "0xff666666", "0xff336666", "0xff006666", "0xffff3366", "0xffcc3366", "0xff993366",
        "0xff663366",
        "0xff333366", "0xff003366", "0xffff0066", "0xffcc0066", "0xff990066", "0xff660066", "0xff330066",
        "0xff000066", "0xffffff33", "0xffccff33", "0xff99ff33", "0xff66ff33", "0xff33ff33", "0xff00ff33", "0xffffcc33",
        "0xffcccc33",
        "0xff99cc33", "0xff66cc33", "0xff33cc33", "0xff00cc33", "0xffff9933", "0xffcc9933", "0xff999933",
        "0xff669933", "0xff339933", "0xff009933", "0xffff6633", "0xffcc6633", "0xff996633", "0xff666633", "0xff336633",
        "0xff006633",
        "0xffff3333", "0xffcc3333", "0xff993333", "0xff663333", "0xff333333", "0xff003333", "0xffff0033",
        "0xffcc0033", "0xff990033", "0xff660033", "0xff330033", "0xff000033", "0xffffff00", "0xffccff00", "0xff99ff00",
        "0xff66ff00",
        "0xff33ff00", "0xff00ff00", "0xffffcc00", "0xffcccc00", "0xff99cc00", "0xff66cc00", "0xff33cc00",
        "0xff00cc00", "0xffff9900", "0xffcc9900", "0xff999900", "0xff669900", "0xff339900", "0xff009900", "0xffff6600",
        "0xffcc6600",
        "0xff996600", "0xff666600", "0xff336600", "0xff006600", "0xffff3300", "0xffcc3300", "0xff993300",
        "0xff663300", "0xff333300", "0xff003300", "0xffff0000", "0xffcc0000", "0xff990000", "0xff660000", "0xff330000",
        "0xff0000ee",
        "0xff0000dd", "0xff0000bb", "0xff0000aa", "0xff000088", "0xff000077", "0xff000055", "0xff000044",
        "0xff000022", "0xff000011", "0xff00ee00", "0xff00dd00", "0xff00bb00", "0xff00aa00", "0xff008800", "0xff007700",
        "0xff005500",
        "0xff004400", "0xff002200", "0xff001100", "0xffee0000", "0xffdd0000", "0xffbb0000", "0xffaa0000",
        "0xff880000", "0xff770000", "0xff550000", "0xff440000", "0xff220000", "0xff110000", "0xffeeeeee", "0xffdddddd",
        "0xffbbbbbb",
        "0xffaaaaaa", "0xff888888", "0xff777777", "0xff555555", "0xff444444", "0xff222222", "0xff111111"
    ]


class Vox():
    file_path = None
    """ file path for the MagicaVoxel vox file """

    def parse_vox_file(self):
        if not self.file_path:
            RuntimeError("Missing file_path param")

        self.voxels = []
        self.palette = []
        self.materials = []

        # Read the vox data in RIFF format
        # https://github.com/python/cpython/blob/3.8/Lib/chunk.py
        vox_file = open(self.file_path, "rb")
        vox_chunk = chunk.Chunk(vox_file, bigendian=False)
        chunk_name = vox_chunk.getname().decode("utf-8")
        if chunk_name != 'VOX ':
            raise RuntimeError('File %s is not a VOX file' % self.file_path)
        version = vox_chunk.chunksize
        if version != 150:
            raise RuntimeError('File %s has a not supported VOX version %i' % (self.file_path, version))
        # Let's read chunks
        """
        2. Chunk Structure
        -------------------------------------------------------------------------------
        # Bytes  | Type       | Value
        -------------------------------------------------------------------------------
        1x4      | char       | chunk id
        4        | int        | num bytes of chunk content (N)
        4        | int        | num bytes of children chunks (M)      
        N        |            | chunk content
        M        |            | children chunks
        -------------------------------------------------------------------------------
        """
        # MAIN Chunk
        main_chunk = chunk.Chunk(vox_file, bigendian=False)
        # Pass last 4 bytes for MAIN Chunk with children chunks
        vox_file.seek(vox_file.tell() + 4)

        # SIZE CHUNK
        """
        -------------------------------------------------------------------------------
        # Bytes  | Type       | Value
        -------------------------------------------------------------------------------
        4        | int        | size x
        4        | int        | size y
        4        | int        | size z : gravity direction
        -------------------------------------------------------------------------------
        """
        size_chunk = chunk.Chunk(vox_file, bigendian=False)
        vox_file.seek(vox_file.tell() + 4)  # number of children chunks
        x = size_chunk.read(4)
        y = size_chunk.read(4)
        z = size_chunk.read(4)

        # XYZI voxels
        """
        -------------------------------------------------------------------------------
        # Bytes  | Type       | Value
        -------------------------------------------------------------------------------
        4        | int        | numVoxels (N)
        4 x N    | int        | (x, y, z, colorIndex) : 1 byte for each component
        -------------------------------------------------------------------------------
        """
        xyzi_chunk = chunk.Chunk(vox_file, bigendian=False)
        vox_file.seek(vox_file.tell() + 4)  # number of children chunks
        n_voxels_bytes = xyzi_chunk.read(4)
        n_voxels = int.from_bytes(n_voxels_bytes, "little")
        for i in range(0, n_voxels):
            self.voxels.append(Voxel(xyzi_chunk.read(4)))
        # Transform or palette chunk or no more chunk if default palette
        try:
            transform_chunk = chunk.Chunk(vox_file, bigendian=False)
        except EOFError:
            transform_chunk = None
            logging.info("Legacy vox file with default palette")
        if transform_chunk and transform_chunk.chunkname.decode("utf-8") == "nTRN":
            vox_file.seek(vox_file.tell() + 4)  # number of children chunks
            transform_chunk.skip()

            group_chunk = chunk.Chunk(vox_file, bigendian=False)
            vox_file.seek(vox_file.tell() + 4)  # number of children chunks
            group_chunk.skip()

            transform_chunk = chunk.Chunk(vox_file, bigendian=False)
            vox_file.seek(vox_file.tell() + 4)  # number of children chunks
            # transform_chunk.skip()  # it is skipping 1 byte in the next chunk
            vox_file.read(transform_chunk.getsize())

            shape_chunk = chunk.Chunk(vox_file, bigendian=False)
            vox_file.seek(vox_file.tell() + 4)  # number of children chunks
            vox_file.read(shape_chunk.getsize())

            # Layers chunk
            NUM_LAYERS = 8  # By default
            for i in range(0, NUM_LAYERS):
                layer_chunk = chunk.Chunk(vox_file, bigendian=False)
                vox_file.seek(vox_file.tell() + 4)  # number of children chunks
                vox_file.read(layer_chunk.getsize())
            """
            7. Chunk id 'RGBA' : palette
            -------------------------------------------------------------------------------
            # Bytes  | Type       | Value
            -------------------------------------------------------------------------------
            4 x 256  | int        | (R, G, B, A) : 1 byte for each component
                                  | * <NOTICE>
                                  | * color [0-254] are mapped to palette index [1-255], e.g : 
                                  | 
                                  | for ( int i = 0; i <= 254; i++ ) {
                                  |     palette[i + 1] = ReadRGBA(); 
                                  | }
            -------------------------------------------------------------------------------
            """
            rgba_chunk = chunk.Chunk(vox_file, bigendian=False)
            vox_file.seek(vox_file.tell() + 4)  # number of children chunks
        else:
            rgba_chunk = transform_chunk
            vox_file.seek(vox_file.tell() + 4)  # notice
        if rgba_chunk:
            if rgba_chunk.getname().decode("utf-8") != 'RGBA':
                raise RuntimeError('VOX format not supported (multimodel?)')
            for i in range(0, round(rgba_chunk.getsize() / 4)):
                # RGBA
                color_bytes = rgba_chunk.read(1) + rgba_chunk.read(1) + rgba_chunk.read(1) + rgba_chunk.read(1)
                self.palette.append(Color(color_bytes.hex()))
        else:
            # Default palette
            for i in range(0, len(VoxDefaultPalette.palette)):
                color = VoxDefaultPalette.palette[i].replace('0x', '')
                # Convert ABGR to RGBA
                color = color[::-1]
                self.palette.append(Color(color))

        # Read the materials palette
        """
        (4) Material Chunk : "MATL"

        int32	: material id
        DICT	: material properties
                    (_type : str) _diffuse, _metal, _glass, _emit
                    (_weight : float) range 0 ~ 1
                    (_rough : float)
                    (_spec : float)
                    (_ior : float)
                    (_att : float)
                    (_flux : float)
                    (_plastic)
        """
        # One material per each color
        for i in range(0, len(self.palette)):
            try:
                materials_chunk = chunk.Chunk(vox_file, bigendian=False)
                if materials_chunk.getname().decode("utf-8") != 'MATL':
                    logging.info("Material data not found")
                    break
                else:
                    vox_file.seek(vox_file.tell() + 4)  # number of children chunks
                    material_id = int.from_bytes(materials_chunk.read(4), "little")
                    dict_entries_len = int.from_bytes(materials_chunk.read(4), "little")
                    # Read the _type key from dict
                    key_str_len = int.from_bytes(materials_chunk.read(4), "little")
                    key_str = materials_chunk.read(key_str_len).decode('utf-8')
                    value_str_len = int.from_bytes(materials_chunk.read(4), "little")
                    value_str = materials_chunk.read(value_str_len).decode('utf-8')
                    self.materials.append(value_str)
                    materials_chunk.skip()
                    if materials_chunk.tell() > materials_chunk.getsize():
                        vox_file.seek(vox_file.tell() - 1)  # Hack: not sure why skip goes 1 byte more
            except EOFError:
                logging.info("Material data not found")
                break

    def create(self):

        self.parse_vox_file()
        
        minxyz = np.ones(3, dtype = np.int32) * 10000
        maxxyz = np.ones(3, dtype = np.int32) * (-10000)
        
        for voxel in self.voxels:
            vpos = np.array([voxel.x, voxel.y, voxel.z])
            minxyz = np.minimum(minxyz, vpos)
            maxxyz = np.maximum(maxxyz, vpos + np.ones_like(vpos))
        
        self.bounds = (minxyz, maxxyz)
        self.arr = np.zeros(maxxyz - minxyz, dtype = np.uint32)
        print("Got voxels as ndarray: ", end = '')
        print(self.arr.shape)
        
        for voxel in self.voxels:
            vpos = np.array([voxel.x, voxel.y, voxel.z]) - minxyz
            self.arr[vpos[0], vpos[1], vpos[2]] = self.palette[voxel.color_index].colorint

In [9]:
from matplotlib import pyplot as plt
from PIL import Image
im = Image.fromarray(np.uint8(np.random.uniform(size=(1024,1024,3)) * 255))
im.save("random.png")

In [22]:
m = Vox()
m.file_path = "alfheimTest.vox"
m.create()

Got voxels as ndarray: (16, 11, 16)


In [23]:
monu10 = m.arr

In [24]:
temp = np.zeros((16, 18, 16), np.uint32)
temp[:, :11, :] = monu10

np.save("alfheimTest.npy", temp)

In [5]:
xs, ys, zs = monu10.nonzero()
print(xs.shape)

for x, y, z in tqdm(zip(xs, ys, zs)):
    SetBlocks(x, y + 65, z, monu10[x,y,z])
SetBlocks()

(150764,)


0it [00:00, ?it/s]

In [6]:
# Block index
@njit
def jit_CalcBlockIndex(blocks):
    blockCount = [0]
    blockDict = {0: 0}
    for p, b in np.ndenumerate(blocks):
        if b not in blockDict:
            blockDict[b] = len(blockDict)
            blockCount.append(0)
        blockCount[blockDict[b]] += 1
    return blockDict, blockCount

def CalcBlockIndex(blocks):
    bD, bC = jit_CalcBlockIndex(blocks)
    blockDict = {}
    for k in bD:
        blockDict[k] = [bD[k], bC[bD[k]]]
    return blockDict

blockDict = CalcBlockIndex(monu10)
print(blockDict)

{0: [0, 481684], 3165097215: [1, 144148], 1263225855: [2, 204], 4287889663: [3, 84], 1986287359: [4, 5680], 2661576447: [5, 648]}


In [19]:
from itertools import product
from numba.experimental import jitclass
from numba import typed, types, int32

N = 3
step = 1

# For building slice weights
@jitclass([('slices', types.DictType(types.UniTuple(types.uint32, 27), types.int32)),
           ('totalNum', int32),
           ('blockDict', types.DictType(types.uint32, types.int32))])
class SliceWeights:
    
    def __init__(self, blockDict):
        self.slices = typed.Dict.empty(key_type = types.UniTuple(), value_type = types.int32)
        self.totalNum = 0
        self.blockDict = blockDict
#         self.bDTotal = sum([b[1] for b in self.blockDict.values()])
    
    def NdToStr(self, arr):
        # return "".join([chr(self.blockDict[c]) for c in arr.flat])
        return arr.tobytes()
    
    def Add(self, sstr):
        if sstr not in self.slices:
            self.slices[sstr] = 0
        self.slices[sstr] += 1
        self.totalNum += 1
        
    # def AddNd(self, arr):
    #     self.Add(self.NdToStr(arr))
        
    def Prob(self, sstr, eps = 0.01):
        if sstr not in self.slices:
            return eps
        return self.slices[sstr] / self.totalNum
    
    # def ProbNd(self, arr, eps = 0.01):
    #     return self.Prob(self.NdToStr(arr), eps)
    
class SliceWeights:
    
    def __init__(self):
        self.slices = {}
        self.totalNum = 0
    
    def NdToStr(self, arr):
        return arr.tobytes()
    
    def Add(self, sstr):
        if sstr not in self.slices:
            self.slices[sstr] = 0
        self.slices[sstr] += 1
        self.totalNum += 1
        
    def Prob(self, sstr, eps = 0.01):
        if sstr not in self.slices:
            return eps
        return self.slices[sstr] / self.totalNum
    
def RandomBlocks(blockDict, size):
    bDTotal = sum([b[1] for b in blockDict.values()])
    return np.random.choice(
        np.array([b[0] for b in blockDict.values()], dtype = np.uint8),
        size = size,
        replace = True,
        p = np.array([b[1] / bDTotal for b in blockDict.values()])
    )

In [20]:
import numba as nb

a = monu10[0:3, 0:3, 0:3]
nb.typeof(tuple(a.flatten()))

UniTuple(uint32 x 27)

In [21]:
import time

jit_blockDict = typed.Dict.empty(types.uint32, types.int32)
for k in blockDict:
    jit_blockDict[k] = blockDict[k][0]

sW = SliceWeights(jit_blockDict)

sW = SliceWeights()

# np slices
approxTotal = [(monu10.shape[i] - N) // step for i in range(3)]
print(approxTotal[0] * approxTotal[1] * approxTotal[2])

start = time.perf_counter()
for x, y, z in product(*[range(0, monu10.shape[i] - N, step) for i in range(3)]):
    nSlice = monu10[x:x+N, y:y+N, z:z+N]
    
    # Flips only rn
    for dx, dz in product(*[[-1, 1] for i in range(2)]):
        f = tuple(nSlice[::dx, :, ::dz].flatten())
        sW.Add(f)
        
end = time.perf_counter()
print("Finished in %f sec" % (end - start))

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1m[1mFailed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1mInvalid use of <class 'numba.core.types.containers.UniTuple'> with parameters ()
No type info available for <class 'numba.core.types.containers.UniTuple'> as a callable.[0m
[0m[1mDuring: resolving callee type: typeref[<class 'numba.core.types.containers.UniTuple'>][0m
[0m[1mDuring: typing of call at C:\Users\betairya\AppData\Local\Temp/ipykernel_74348/2564548238.py (15)
[0m
[1m
File "C:\Users\betairya\AppData\Local\Temp\ipykernel_74348\2564548238.py", line 15:[0m
[1m<source missing, REPL/exec in use?>[0m

[0m[1mDuring: resolving callee type: jitclass.SliceWeights#24238818d00<slices:DictType[UniTuple(uint32 x 27),int32]<iv=None>,totalNum:int32,blockDict:DictType[uint32,int32]<iv=None>>[0m
[0m[1mDuring: typing of call at <string> (3)
[0m
[0m[1mDuring: resolving callee type: jitclass.SliceWeights#24238818d00<slices:DictType[UniTuple(uint32 x 27),int32]<iv=None>,totalNum:int32,blockDict:DictType[uint32,int32]<iv=None>>[0m
[0m[1mDuring: typing of call at <string> (3)
[0m
[1m
File "<string>", line 3:[0m
[1m<source missing, REPL/exec in use?>[0m


In [49]:
print(sW.slices)

{  : 1,   : 1,   : 1,   : 1,   : 2,   : 2,   : 2,   : 2,    : 12,    : 12,   : 4,   : 4,   : 4,   : 4,  : 4,  : 4,  : 4,  : 4, : 394776,     : 1,     : 1,     : 1,     : 1,     : 2,     : 2,     : 2,     : 2,       : 12,       : 12,     : 4,     : 4

In [9]:
print(len(sW.slices))

airSlice = np.ones((N,N,N), dtype = np.uint8) * 0
print("All air: %d" % sW.slices["".join([chr(c) for c in airSlice.flat])])
sW.slices["".join([chr(c) for c in airSlice.flat])] = 500
print("Set to 500")

solidSlice = np.ones((N,N,N), dtype = np.uint8) * 1
print("All solid: %d" % sW.slices["".join([chr(c) for c in solidSlice.flat])])
sW.slices["".join([chr(c) for c in airSlice.flat])] = 400
print("Set to 400")

10782
All air: 1529480
Set to 500
All solid: 394776
Set to 400


In [30]:
def RefreshNd(arr, pos, full = False):
    if full:
        for p, b in tqdm(np.ndenumerate(arr), total = arr.size):
            SetBlocks(p[0] + pos[0], p[1] + pos[1], p[2] + pos[2], b)
    else:
        xs, ys, zs = arr.nonzero()
        print(xs.shape)
        for x, y, z in tqdm(zip(xs, ys, zs)):
            SetBlocks(x + pos[0], y + pos[1], z + pos[2], arr[x,y,z])
        
def IdxToColors(arr, blockDict):
    invMap = np.zeros(256, dtype = np.uint32)
    for k in blockDict:
        v = blockDict[k]
        invMap[v[0]] = k
    return invMap[arr]

In [44]:
# Rough mcmc
nIters = 1200000
blockMap = np.array([blockDict])
size = 16

# Initial
mcmc = RandomBlocks(blockDict, (size, size, size))

In [45]:
# Check what we've got
ox = 96
oy = 65
oz = 0

RefreshNd(IdxToColors(mcmc, blockDict), [ox, oy, oz], True)

  0%|          | 0/4096 [00:00<?, ?it/s]

In [46]:
# Get block distribution
bDistGT = np.zeros(len(sW.blockDict))
bDistLambda = 1e4

for k in blockDict:
    bDistGT[blockDict[k][0]] = blockDict[k][1]

bDistGT = bDistGT / bDistGT.sum()
print(bDistGT)

import ot

@njit
def KLDiv(p, q):
    kldiv = p * np.log(np.maximum(p / np.maximum(q, 1e-10), 1e-10))
    return p.sum()

@njit
def JSDiv(p, q):
    m = p + q / 2
    return KLDiv(p, m) / 2 + KLDiv(q, m) / 2

bDistRaw = np.array([(mcmc == i).sum() for i in range(len(sW.blockDict))])
bDist = bDistRaw / (size ** 3)
print(bDist)

wd = ot.emd2(bDist, bDistGT, 1 - np.eye(len(blockDict)))

print("Initial JSDiv = %f" % JSDiv(bDist, bDistGT))
print("Initial Wass1 = %f" % (wd * bDistLambda))

[7.61618346e-01 2.27920714e-01 3.22556163e-04 1.32817243e-04
 8.98097551e-03 1.02459016e-03]
[7.37304688e-01 2.50000000e-01 4.88281250e-04 2.44140625e-04
 1.00097656e-02 1.95312500e-03]
Initial JSDiv = 1.000000
Initial Wass1 = 243.136590


In [None]:
# E
E = 0
eps = 2e-5
for x, y, z in product(*[range(0, mcmc.shape[i] - N, 1) for i in range(3)]):
    nSlice = mcmc[x:x+N, y:y+N, z:z+N]
    E += np.log(sW.Prob("".join([chr(c) for c in nSlice.flat]), eps))

wd = ot.emd2(bDist, bDistGT, 1 - np.eye(len(blockDict)))
wd_prev = wd
E -= wd * bDistLambda

print(E)

bDTotal = sum([b[1] for b in blockDict.values()])
choice = np.array([b[0] for b in blockDict.values()], dtype = np.uint8)
weight = np.array([b[1] / bDTotal for b in blockDict.values()])

rejected = 0

@njit
def rand_choice_nb(arr, prob):
    """
    :param arr: A 1D numpy array of values to sample from.
    :param prob: A 1D numpy array of probabilities for the given samples. *** NEED TO BE NORMALIZED ***
    :return: A random sample from the given array with a given probability.
    """
    return arr[np.searchsorted(np.cumsum(prob), np.random.random(), side="right")]
    # return np.random.choice(arr, size = (1,), replace = True, p = prob)[0]

@njit
def ProposeAndCalcPatterns(mcmc, choice, weight, size, eps, sW):
    
    # Proposal
    pos = np.random.randint(0, size, size = (3,))
    blk = rand_choice_nb(choice, weight)

    # Original
    original = mcmc[pos[0], pos[1], pos[2]]
    
    dE = 0    
    # Calculate dE = Enew - Eold
    for x in range(max(0, pos[0] - N + 1)):
        for y in range(max(0, pos[1] - N + 1)):
            for z in range(max(0, pos[2] - N + 1)):
                nSlice = mcmc[x:x+N, y:y+N, z:z+N]
                dE -= np.log(sW.Prob("".join([chr(c) for c in nSlice.flat]), eps))
    
    # Set
    mcmc[pos[0], pos[1], pos[2]] = blk
    
    # Calculate dE
    for x in range(max(0, pos[0] - N + 1)):
        for y in range(max(0, pos[1] - N + 1)):
            for z in range(max(0, pos[2] - N + 1)):
                nSlice = mcmc[x:x+N, y:y+N, z:z+N]
                dE += np.log(sW.Prob("".join([chr(c) for c in nSlice.flat]), eps))
    
    return pos, blk, original, dE

startTime = None

# MCMC
for i in range(nIters):
    
    pos, blk, original, dE = ProposeAndCalcPatterns(mcmc, choice, weight, size, eps, sW)
    
#     # Proposal
#     pos = np.random.randint(0, size, size = (3,))
#     blk = np.random.choice(choice, size = (1,), replace = True, p = weight)[0]

#     # Original
#     original = mcmc[pos[0], pos[1], pos[2]]
    
#     dE = 0
#     for x, y, z in product(*[range(max(0, pos[di] - N + 1)) for di in range(3)]):
#         nSlice = mcmc[x:x+N, y:y+N, z:z+N]
#         dE -= np.log(sW.Prob("".join([chr(c) for c in nSlice.flat]), eps))
    
#     # Set
#     mcmc[pos[0], pos[1], pos[2]] = blk
    
#     # Calculate dE
#     for x, y, z in product(*[range(max(0, pos[di] - N + 1)) for di in range(3)]):
#         nSlice = mcmc[x:x+N, y:y+N, z:z+N]
#         dE += np.log(sW.Prob("".join([chr(c) for c in nSlice.flat]), eps))

    # Block dist
    bDistRaw_new = np.copy(bDistRaw)
    bDistRaw_new[original] -= 1
    bDistRaw_new[blk] += 1
    bDist_new = bDistRaw_new / (size ** 3)
    wd_new = ot.emd2(bDist_new, bDistGT, 1 - np.eye(len(sW.blockDict)))
    dE += ( - wd_new + wd_prev) * bDistLambda
    
    # MH step
    # Find accept / reject when dE < 0 (alpha < 1)
    accepted = True
    
    if dE < 0:
        alpha = np.exp(dE)
        if np.random.uniform() > alpha: # Reject
            mcmc[pos[0], pos[1], pos[2]] = original
            rejected += 1
            accepted = False
    
    if accepted:
        E += dE
        wd_prev = wd_new
        bDistRaw = bDistRaw_new
        SetBlocks(ox + pos[0], oy + pos[1], oz + pos[2], IdxToColors(mcmc[pos[0], pos[1], pos[2]], blockDict))
    
    if i % (nIters // 100) == 0:
        if startTime is None:
            startTime = time.perf_counter() - 0.001
        T = time.perf_counter()
        print("#%8d >> E = %+11.3f | Rejected: %8d [ Wasserstein-1 %9.3f ]    %8.2f it / sec" % (i, E, rejected, wd_prev * bDistLambda, (i - nIters // 100) / (T - startTime)))

-24018.55008061815
#       0 >> E =  -24018.550 | Rejected:        0 [ Wasserstein-1   243.137 ]    -11997600.48 it / sec
#   12000 >> E =  -23778.006 | Rejected:     2303 [ Wasserstein-1     2.593 ]        0.00 it / sec
#   24000 >> E =  -23781.823 | Rejected:     4628 [ Wasserstein-1     6.409 ]      374.61 it / sec
#   36000 >> E =  -23779.773 | Rejected:     6884 [ Wasserstein-1     4.359 ]      501.31 it / sec
#   48000 >> E =  -23780.667 | Rejected:     9137 [ Wasserstein-1     5.254 ]      560.72 it / sec
#   60000 >> E =  -23779.188 | Rejected:    11451 [ Wasserstein-1     3.775 ]      597.33 it / sec
#   72000 >> E =  -23779.492 | Rejected:    13757 [ Wasserstein-1     4.078 ]      625.47 it / sec
#   84000 >> E =  -23781.823 | Rejected:    16076 [ Wasserstein-1     6.409 ]      643.24 it / sec
#   96000 >> E =  -23780.558 | Rejected:    18456 [ Wasserstein-1     5.145 ]      658.29 it / sec
#  108000 >> E =  -23779.663 | Rejected:    20781 [ Wasserstein-1     4.250 ]      669

In [1]:
!pip install pyvox

ERROR: Could not find a version that satisfies the requirement pyvox (from versions: none)
ERROR: No matching distribution found for pyvox


In [4]:
# Save Snapshot as *.vox

from pyvox.models import Vox, Color
from pyvox.writer import VoxWriter
from datetime import datetime

ox = 0
oy = 65
oz = 0
size = 255
sizey = 40

bD = {}

a = GetBlockWithin(ox, oy, oz, ox + size, oy + sizey, oz + size)[:, ::-1, ::-1]
b = np.zeros_like(a, dtype = 'B')

for p, c in np.ndenumerate(a):
    if c == 0:
        continue
    if c not in bD:
        bD[c] = len(bD)
    b[p[0], p[1], p[2]] = bD[c] + 1

pal = [Color(i // (256 ** 3), (i // (256 ** 2)) % 256, (i // 256) % 256, i % 256) for i in bD] + [Color(255, 0, 255, 255)]
# np.save("Snapshots/Pattern3+W1_%s.npy" % datetime.now().strftime("%Y%m%d-%H%M%S"), a)
vox = Vox.from_dense(b)
vox.palette = pal
VoxWriter('../../Julia/Snapshots/GATest-BlockMargin-PathConnect%s.vox' % datetime.now().strftime("%Y%m%d-%H%M%S"), vox).write()

In [33]:
a = np.array([1,2,3])
b = np.array([1,2,3])

d = {}
d[a.flatten().tobytes()] = 1
d[b.flatten().tobytes()]

1