In [1]:
%load_ext autoreload
# %reload_ext autoreload

# Reload all modules imported with %aimport every time before executing the Python code typed
%autoreload 1

%aimport context_nn 
%aimport phrase_feeder 
%aimport notes
%aimport watch_point
%aimport cluster
import numpy as np
from bitarray import bitarray
from context_nn import ContextNN
from watch_point import WatchPoint
from cluster import Cluster
from phrase_feeder import PhraseFeeder
from pprint import pprint
import math
import constants as const

In [8]:
# Build the note encoder that later cells use to turn phrases into bit patterns.
# The keyword values go straight to notes.Notes; note_count=26 presumably stands
# for the letters of the Latin alphabet — TODO confirm against notes.py.
# NOTE(review): the same object is constructed again in the network setup cell below.
abc_notes = notes.Notes(note_count=26, 
                        notation_count=10, 
                        active_bits=8, 
                        bit_count=255)

In [2]:
import pickle

def load_phrase_base(file_name: str) -> tuple:
    """Load the phrase base and its marks from a pickle file.

    Args:
        file_name: Path to a pickle file holding a dict-like object that may
            contain the keys 'phrase_base' and 'marks'.

    Returns:
        A ``(phrase_base, marks)`` tuple. Each element falls back to an empty
        dict when its key is missing from the file.

    Warning:
        ``pickle.load`` can execute arbitrary code while unpickling, so only
        call this on files from a trusted source.
    """
    with open(file_name, 'rb') as f:
        data = pickle.load(f)
    # The file handle is no longer needed once the payload is unpickled.
    phrase_base = data.get('phrase_base', {})
    marks = data.get('marks', {})
    return phrase_base, marks

In [3]:
# Load the phrase base and the marks mapping from the project's data directory
# (path is relative to the notebook's working directory).
phrase_base, marks = load_phrase_base('./data/texts/phrase_base.pickle')

In [6]:
from ipythonblocks import BlockGrid

def draw_notation(notation: np.ndarray):
    """Render a bit sequence as a single row of coloured blocks.

    Args:
        notation: Indexable sequence of truthy/falsy values supporting
            ``len()``; one block is drawn per element.

    Every block starts with the background colour; blocks whose bit is set are
    repainted with the highlight colour. The grid is shown in the notebook
    output and nothing is returned.
    """
    bit_grid = BlockGrid(len(notation), 1, fill=(17, 41, 129))
    for block in range(bit_grid.width):
        if notation[block]:
            # Row 0 is the only row; the column index equals the bit position.
            bit_grid[0, block] = (244, 195, 173)
    bit_grid.lines_on = False
    bit_grid.show()

def draw_note_notations(notes, note_idx: int):
    """Draw every notation of one note, one block row per notation.

    Args:
        notes: Object that supports ``notes[note_idx]`` (sized) and
            ``note_notation_as_bits(note_idx, notation_idx)``. Inside this
            function the parameter hides the imported ``notes`` module.
        note_idx: Index of the note whose notations are drawn.
    """
    notation_total = len(notes[note_idx])
    for notation_idx in range(notation_total):
        bits = notes.note_notation_as_bits(note_idx, notation_idx)
        draw_notation(bits)

In [7]:
# Take the first phrase stored under the first key of the phrase base.
key = list(phrase_base.keys())[0]
phrase = phrase_base[key][0]
# Draw the combined bit pattern of the whole phrase first ...
bit_chord = abc_notes.phrase_chord(phrase)
draw_notation(abc_notes.notation_as_bits(bit_chord))
print(phrase)
# ... then one row per (note index, notation index) pair of the phrase.
for note_idx, notation_idx in phrase:
    notation = abc_notes.note_notation_as_bits(note_idx, notation_idx)
    draw_notation(notation)

[(8, 0), (0, 1), (13, 2), (24, 3), (18, 4)]


In [4]:
# Wrap the loaded phrase base and marks in a PhraseFeeder; later cells pull
# phrase batches from it.
feeder = PhraseFeeder(phrase_base, marks)    

In [None]:
# NOTE(review): the other cells that pull phrases call take_phrase_batch(count=...);
# this cell has no execution count, so take_phrase_bunch may be a stale method
# name — confirm against PhraseFeeder.
phrases, output_bits = feeder.take_phrase_bunch()
# marks is keyed by tuples of output bits, so output_bits is presumably such a tuple.
print(output_bits, marks[output_bits])
print(phrases)

In [4]:
# Full experiment setup: note encoder, phrase feeder and the network itself.
# NOTE(review): abc_notes and feeder are also created in earlier cells with the
# same arguments; re-running this cell replaces them with fresh objects.
abc_notes = notes.Notes(note_count=26, 
                        notation_count=10, 
                        active_bits=8, 
                        bit_count=255)

feeder = phrase_feeder.PhraseFeeder(phrase_base, marks)  

# input_bit_count matches the encoder's bit_count (255); the other values are
# tuning parameters whose meaning is defined in context_nn.
cxnn = context_nn.ContextNN(input_bit_count=255,
                            output_bit_count=8,
                            watch_point_count=100,
                            watch_bit_count=32,
                            cluster_make_threshold=6,
                            cluster_activate_threshold=4)

In [99]:
# Sanity check: number of watch points held by the network
# (the setup cell requested watch_point_count=100).
len(cxnn.watch_points.keys())

100

In [8]:
%%time
# Timing run for a single batch of 200 phrases; performs the same steps as
# feed_phrase_batch defined further down.
phrases, output_bits = feeder.take_phrase_batch(count=200)
# Converted to a set once, then shared by every phrase of the batch.
output_bits = set(output_bits)
for phrase in phrases:
    bit_chord = abc_notes.phrase_chord(phrase)
    cxnn.receive_bits(input_bits=bit_chord, output_bits=output_bits)
    

Wall time: 3.11 s


In [5]:
def feed_phrase_batch(feeder, count=200):
    """Take one batch of phrases from ``feeder`` and feed it to the network.

    Args:
        feeder: Object providing ``take_phrase_batch(count=...)`` that returns
            a ``(phrases, output_bits)`` pair.
        count: Number of phrases to request from the feeder.

    Uses the notebook globals ``abc_notes`` (converts a phrase to its bit
    chord) and ``cxnn`` (receives the chord together with the output bits).
    """
    # Pass the requested size through so that callers such as
    # feed_n_batches(batch_size=...) actually control the batch size.
    phrases, output_bits = feeder.take_phrase_batch(count=count)
    # One set of output bits is shared by every phrase of the batch.
    output_bits = set(output_bits)
    for phrase in phrases:
        bit_chord = abc_notes.phrase_chord(phrase)
        cxnn.receive_bits(input_bits=bit_chord, output_bits=output_bits)

def feed_n_batches(feeder, n=10, batch_size=200):
    """Feed ``n`` consecutive phrase batches and report progress.

    Args:
        feeder: Phrase source handed on to ``feed_phrase_batch``.
        n: Number of batches to feed.
        batch_size: Value passed as ``count`` to ``feed_phrase_batch``.

    Prints one progress line per batch and, at the end, the cluster count
    reported by the global network ``cxnn``.
    """
    for batch_no in range(1, n + 1):
        feed_phrase_batch(feeder, count=batch_size)
        print(f'batch {batch_no}/{n}')
    total_clusters = cxnn.cluster_count()
    print(f'clusters: {total_clusters}')
        
def show_point_stats(cxnn: ContextNN):
    """Print the total cluster count and a per-watch-point summary.

    Args:
        cxnn: Network whose ``watch_points`` mapping is inspected.

    Prints the sum of ``cluster_count()`` over all watch points, followed by a
    list of ``(output_bit, number of cluster objects)`` pairs, one per point.
    """
    total_clusters = sum(wp.cluster_count() for wp in cxnn.watch_points.values())
    per_point = []
    for wp in cxnn.watch_points.values():
        per_point.append((wp.output_bit, len(wp.cluster_objects)))
    print('Cluster count:', total_clusters)
    print(per_point)
  

In [None]:
# Switch the network to the state named STATE_LEARN in constants; what the
# network does in this state is defined in context_nn.
cxnn.state = const.STATE_LEARN

In [35]:
# Switch the network to the state named STATE_CONSOLIDATE in constants; what the
# network does in this state is defined in context_nn.
cxnn.state = const.STATE_CONSOLIDATE

In [11]:
%%time
# Main feeding run: 30 batches with batch_size=200; prints one progress line per
# batch and the resulting cluster count.
feed_n_batches(feeder, n=30, batch_size=200)

batch 1/30
batch 2/30
batch 3/30
batch 4/30
batch 5/30
batch 6/30
batch 7/30
batch 8/30
batch 9/30
batch 10/30
batch 11/30
batch 12/30
batch 13/30
batch 14/30
batch 15/30
batch 16/30
batch 17/30
batch 18/30
batch 19/30
batch 20/30
batch 21/30
batch 22/30
batch 23/30
batch 24/30
batch 25/30
batch 26/30
batch 27/30
batch 28/30
batch 29/30
batch 30/30
clusters: 117674
Wall time: 1min 33s


In [14]:
%%time
# Ask the network to reduce its clusters. The exact criteria behind
# min_component and min_activations live in context_nn — presumably clusters
# below these thresholds are dropped; TODO confirm.
cxnn.reduce_clusters(min_component=0.1, min_activations=10)

Wall time: 4.19 s


In [15]:
# Print per-watch-point statistics, then display the marks mapping
# (tuple of output bits -> language code) as the cell result for reference.
show_point_stats(cxnn)
marks

Cluster count: 103477
[(7, 910), (4, 564), (1, 1458), (0, 892), (4, 611), (6, 1069), (7, 1078), (4, 513), (7, 753), (4, 842), (0, 1119), (4, 435), (5, 877), (4, 777), (0, 1103), (5, 1289), (5, 1025), (7, 1066), (7, 692), (0, 925), (4, 1124), (4, 1206), (3, 773), (6, 814), (3, 999), (4, 698), (3, 817), (2, 718), (3, 1170), (6, 1073), (2, 714), (1, 1183), (3, 1088), (0, 1106), (1, 1339), (1, 1389), (3, 1233), (4, 1081), (6, 1593), (2, 665), (4, 647), (7, 1244), (2, 519), (0, 864), (7, 853), (3, 1139), (6, 1703), (1, 1770), (2, 630), (6, 1648), (6, 1963), (0, 832), (2, 785), (0, 936), (4, 586), (7, 974), (6, 1528), (1, 1317), (7, 986), (1, 1195), (3, 1575), (0, 832), (0, 1056), (4, 1022), (4, 757), (4, 972), (5, 993), (1, 1485), (6, 1767), (5, 724), (2, 734), (1, 1336), (0, 1181), (5, 1224), (3, 933), (5, 847), (1, 1024), (6, 1341), (4, 1269), (6, 1363), (5, 1130), (2, 673), (6, 1531), (2, 696), (3, 1041), (6, 1329), (5, 758), (6, 1525), (7, 747), (1, 1045), (5, 915), (2, 730), (5, 1086),

{(0, 1, 3, 5): 'ukr',
 (0, 1, 6, 7): 'blg',
 (0, 3, 5, 6): 'epo',
 (0, 5, 6, 7): 'bel',
 (1, 2, 3, 6): 'pol',
 (1, 2, 4, 7): 'jbo',
 (1, 4, 6, 7): 'eng',
 (2, 3, 4, 5): 'rus'}

In [13]:
def dict_value(d: dict, n=0) -> object:
    """Return the value stored under the n-th key of ``d`` in iteration order.

    Args:
        d: Mapping to read from.
        n: Zero-based position of the key whose value is wanted.

    Returns:
        ``d[key]`` for the key at position ``n``.

    Raises:
        IndexError: If ``n`` is negative or not smaller than ``len(d)``.
    """
    if not 0 <= n < len(d):
        raise IndexError(f'position {n} is out of range for a dict of {len(d)} items')
    keys = iter(d)
    # Skip the first n keys; the next one is the requested key.
    for _ in range(n):
        next(keys)
    return d[next(keys)]

# Inspect the first cluster of the first watch point.
point = dict_value(cxnn.watch_points, n=0)
# NOTE(review): this rebinds the global name `cluster`, which the first cell
# bound to the imported `cluster` module.
cluster = point.cluster_objects[0]
# Raw stats: bit tuple -> count. component_stats(): the same keys with each
# count expressed as a share of the total.
pprint(cluster.stats)
pprint(cluster.component_stats())
# pprint(cluster.bit_rate())

# Components ordered from the largest share to the smallest.
print(sorted(cluster.component_stats().items(), key=lambda x: x[1], reverse=True))
print(point.cluster_masks)
print(tuple(sorted(cluster.bits)))
print(point.watch_bits)
print(point.cluster_masks[0])
print(point.cluster_masks[1])
# The number of masks is printed next to the number of cluster objects for comparison.
print(len(point.cluster_masks), len(point.cluster_objects))

# Disabled experiment: Shannon entropy of the component shares.
# parts = cluster.component_stats()
# entropy = 0.0
# for part in parts.values():
#     entropy += -part * math.log2(part)
# print(entropy)

{(64, 103, 184, 253): 3,
 (64, 154, 168, 253): 12,
 (103, 154, 168, 184): 3,
 (103, 154, 184, 253): 1,
 (103, 168, 184, 253): 1,
 (154, 168, 184, 253): 1}
{(64, 103, 184, 253): 0.14285714285714285,
 (64, 154, 168, 253): 0.5714285714285714,
 (103, 154, 168, 184): 0.14285714285714285,
 (103, 154, 184, 253): 0.047619047619047616,
 (103, 168, 184, 253): 0.047619047619047616,
 (154, 168, 184, 253): 0.047619047619047616}
[((64, 154, 168, 253), 0.5714285714285714), ((103, 154, 168, 184), 0.14285714285714285), ((64, 103, 184, 253), 0.14285714285714285), ((103, 154, 184, 253), 0.047619047619047616), ((154, 168, 184, 253), 0.047619047619047616), ((103, 168, 184, 253), 0.047619047619047616)]
[[0 0 0 ..., 0 0 1]
 [0 1 0 ..., 0 1 0]
 [0 1 0 ..., 0 0 1]
 ..., 
 [0 0 0 ..., 1 0 1]
 [0 1 0 ..., 0 1 0]
 [0 0 1 ..., 1 0 1]]
(64, 103, 154, 168, 184, 253)
(0, 10, 25, 41, 64, 70, 83, 90, 101, 103, 111, 113, 121, 126, 133, 137, 141, 154, 158, 163, 168, 171, 180, 184, 196, 212, 215, 223, 226, 233, 248, 253)


In [34]:
# Survey all clusters of all watch points.
noize_clusters = 0   # clusters for which has_big_component(...) is False ("noise")
total_clusters = 0
cluster_lens = {}    # len(cluster.bits)  -> number of clusters with that many bits
cluster_acts = {}    # len(cluster.stats) -> number of clusters with that many stats entries
for wp in cxnn.watch_points.values():
    # NOTE(review): the loop variable rebinds the global name `cluster`, which
    # the first cell bound to the imported `cluster` module.
    for cluster in wp.cluster_objects:
        total_clusters += 1
        cluster_lens[len(cluster.bits)] = cluster_lens.get(len(cluster.bits), 0) + 1
        cluster_acts[len(cluster.stats)] = cluster_acts.get(len(cluster.stats), 0) + 1
        # The thresholds are passed straight to Cluster.has_big_component; their
        # exact meaning is defined in cluster.py.
        if not cluster.has_big_component(threshold=0.15, min_activations=15):
            noize_clusters += 1
print(total_clusters)
print(noize_clusters)
pprint(cluster_lens)
pprint(cluster_acts)


204808
128930
{6: 83104,
 7: 62789,
 8: 35167,
 9: 15740,
 10: 5725,
 11: 1728,
 12: 418,
 13: 115,
 14: 18,
 15: 4}
{0: 233,
 1: 920,
 2: 1898,
 3: 2453,
 4: 2945,
 5: 3461,
 6: 4149,
 7: 4728,
 8: 5450,
 9: 6102,
 10: 6728,
 11: 7249,
 12: 7829,
 13: 7975,
 14: 7748,
 15: 7600,
 16: 7353,
 17: 6728,
 18: 5932,
 19: 4988,
 20: 4022,
 21: 3327,
 22: 2657,
 23: 2248,
 24: 2369,
 25: 2460,
 26: 2476,
 27: 2458,
 28: 2456,
 29: 2584,
 30: 2539,
 31: 2564,
 32: 2472,
 33: 2483,
 34: 2292,
 35: 2251,
 36: 2231,
 37: 2072,
 38: 2134,
 39: 1971,
 40: 1849,
 41: 1760,
 42: 1678,
 43: 1543,
 44: 1481,
 45: 1411,
 46: 1212,
 47: 1198,
 48: 1115,
 49: 1077,
 50: 989,
 51: 906,
 52: 857,
 53: 869,
 54: 804,
 55: 804,
 56: 734,
 57: 751,
 58: 674,
 59: 691,
 60: 658,
 61: 676,
 62: 659,
 63: 621,
 64: 631,
 65: 645,
 66: 608,
 67: 624,
 68: 623,
 69: 621,
 70: 611,
 71: 561,
 72: 536,
 73: 555,
 74: 488,
 75: 464,
 76: 492,
 77: 502,
 78: 459,
 79: 433,
 80: 448,
 81: 457,
 82: 444,
 83: 398,
 84: 

In [26]:
# Smoke test: construct a Cluster by hand and look at its initial state.
# NOTE(review): bitarray(10) creates a 10-bit array; depending on the bitarray
# version its contents may be uninitialised rather than all zeros — confirm.
my_cluster = Cluster((10, 20, 30), bitarray(10), 4)
print(type(my_cluster))
# Not displayed: only the last expression of a cell is echoed.
my_cluster.bits
my_cluster.stat_parts()

<class 'cluster.Cluster'>


{}