In [1]:
%load_ext autoreload
# %reload_ext autoreload

# Reload all modules imported with %aimport every time before executing the Python code typed
%autoreload 1

%aimport context_nn 
%aimport phrase_feeder 
%aimport notes
%aimport watch_point
%aimport cluster
import numpy as np
from bitarray import bitarray
from context_nn import ContextNN
from watch_point import WatchPoint
from cluster import Cluster
from phrase_feeder import PhraseFeeder
from pprint import pprint
import math
import constants as const

In [8]:
abc_notes = notes.Notes(note_count=26, 
                        notation_count=10, 
                        active_bits=8, 
                        bit_count=255)

In [2]:
import pickle

def load_phrase_base(file_name: str) -> (dict, list):
    with open(file_name, 'rb') as f:
        data = pickle.load(f)
        phrase_base = data.get('phrase_base', {})
        marks = data.get('marks', {})
        return phrase_base, marks

In [3]:
phrase_base, marks = load_phrase_base('./data/texts/phrase_base.pickle')

In [6]:
from ipythonblocks import BlockGrid

def draw_notation(notation: np.array):
    bit_grid = BlockGrid(len(notation), 1, fill=(17, 41, 129))
    for block in range(bit_grid.width):
        color = bit_grid[0, block]
        if notation[block]:
            bit_grid[0, block] = (244, 195, 173)
    bit_grid.lines_on = False
    bit_grid.show()

def draw_note_notations(notes, note_idx: int):
    for i in range(len(notes[note_idx])):
        draw_notation(notes.note_notation_as_bits(note_idx, i))

In [7]:
key = list(phrase_base.keys())[0]
phrase = phrase_base[key][0]
bit_chord = abc_notes.phrase_chord(phrase)
draw_notation(abc_notes.notation_as_bits(bit_chord))
print(phrase)
for note_idx, notation_idx in phrase:
    notation = abc_notes.note_notation_as_bits(note_idx, notation_idx)
    draw_notation(notation)

[(8, 0), (0, 1), (13, 2), (24, 3), (18, 4)]


In [96]:
feeder = PhraseFeeder(phrase_base, marks)    

In [None]:
phrases, output_bits = feeder.take_phrase_bunch()
print(output_bits, marks[output_bits])
print(phrases)

In [24]:
abc_notes = notes.Notes(note_count=26, 
                        notation_count=10, 
                        active_bits=8, 
                        bit_count=255)

feeder = phrase_feeder.PhraseFeeder(phrase_base, marks)  

cxnn = context_nn.ContextNN(input_bit_count=255,
                            output_bit_count=8,
                            watch_point_count=100,
                            watch_bit_count=32,
                            cluster_make_threshold=6,
                            cluster_activate_threshold=4)

In [99]:
len(cxnn.watch_points.keys())

100

In [8]:
%%time
phrases, output_bits = feeder.take_phrase_batch(count=200)
output_bits = set(output_bits)
for phrase in phrases:
    bit_chord = abc_notes.phrase_chord(phrase)
    cxnn.receive_bits(input_bits=bit_chord, output_bits=output_bits)
    

Wall time: 3.11 s


In [11]:
def feed_phrase_batch(feeder, count=200):
    phrases, output_bits = feeder.take_phrase_batch(count=200)
    output_bits = set(output_bits)
    for phrase in phrases:
        bit_chord = abc_notes.phrase_chord(phrase)
        cxnn.receive_bits(input_bits=bit_chord, output_bits=output_bits)

def feed_n_batches(feeder, n=10, batch_size=200):
    for i in range(n):
        feed_phrase_batch(feeder, count=batch_size)
        print(f'batch {i+1}/{n}')
    print(f'clusters: {cxnn.cluster_count()}')
        
def show_point_stats(cxnn: ContextNN):
    cluster_counts = [wp.cluster_count() for wp in cxnn.watch_points.values()]
    output_bits = [(wp.output_bit, len(wp.cluster_objects)) 
                   for wp in cxnn.watch_points.values()]
    print('Cluster count:', sum(cluster_counts))
    print(output_bits)
  

In [None]:
cxnn.state = const.STATE_LEARN

In [35]:
cxnn.state = const.STATE_CONSOLIDATE

In [36]:
%%time
feed_n_batches(feeder, n=10, batch_size=200)

batch 1/10
batch 2/10
batch 3/10
batch 4/10
batch 5/10
batch 6/10
batch 7/10
batch 8/10
batch 9/10
batch 10/10
clusters: 200134
Wall time: 34.5 s


In [37]:
show_point_stats(cxnn)
marks

Cluster count: 200134
[(6, 1529), (2, 1970), (4, 2136), (3, 1518), (6, 2049), (7, 1563), (4, 2203), (2, 2984), (4, 2061), (2, 3229), (1, 1309), (3, 1650), (5, 2665), (6, 1370), (4, 2325), (5, 3085), (2, 2516), (4, 2382), (4, 1507), (5, 2553), (5, 2231), (1, 1589), (7, 1179), (7, 1007), (4, 2355), (5, 2760), (7, 1204), (7, 937), (6, 1285), (0, 2287), (2, 2145), (7, 1273), (6, 1849), (3, 3270), (0, 2512), (1, 1486), (6, 1595), (2, 2930), (7, 732), (4, 1813), (6, 1354), (6, 1836), (5, 2166), (3, 2739), (4, 1666), (1, 1776), (7, 1583), (3, 2329), (4, 1533), (6, 1225), (4, 2462), (1, 2137), (7, 1571), (6, 1327), (7, 1340), (5, 3256), (0, 2510), (0, 1991), (2, 2107), (4, 2041), (5, 2480), (4, 1538), (2, 2112), (1, 1277), (6, 1417), (3, 2614), (2, 1796), (4, 1965), (1, 1370), (4, 2460), (1, 1223), (4, 1783), (2, 2197), (6, 1491), (3, 2982), (3, 2322), (0, 2616), (6, 1419), (0, 2199), (2, 2806), (0, 2342), (5, 2496), (7, 1304), (6, 2109), (2, 2590), (4, 1924), (7, 1111), (5, 2020), (1, 2126), 

{(0, 1, 3, 6): 'jbo',
 (0, 2, 3, 5): 'rus',
 (0, 2, 3, 7): 'bel',
 (0, 2, 4, 6): 'blg',
 (1, 3, 4, 5): 'pol',
 (1, 4, 5, 7): 'eng',
 (2, 3, 5, 7): 'epo',
 (2, 4, 5, 6): 'ukr'}

In [15]:
def dict_value(d: dict, n=0) -> any:
    i = 0
    diter = iter(d)
    while i <= n:
        key = next(diter)
        i += 1
    return d[key]

In [41]:
point = dict_value(cxnn.watch_points)
cluster = point.cluster_objects[7]
# pprint(cluster.stats)
# pprint(cluster.component_stats())
pprint(cluster.bit_rate())

sorted(cluster.component_stats().items(), key=lambda x: x[1], reverse=True)

# parts = cluster.component_stats()
# enthropy = 0.0
# for part in parts.values():
#     enthropy += -part * math.log2(part)
# print(enthropy)

(array([ 0.70469799,  0.75167785,  0.97315436,  0.34228188,  1.        ,
        0.89932886]),
 [7, 77, 83, 102, 127, 146])


[((77, 83, 127, 146), 0.2619047619047619),
 ((7, 83, 127, 146), 0.19642857142857142),
 ((7, 77, 83, 127), 0.09523809523809523),
 ((7, 77, 83, 127, 146), 0.07738095238095238),
 ((7, 83, 102, 127), 0.05952380952380952),
 ((77, 102, 127, 146), 0.05357142857142857),
 ((7, 77, 83, 146), 0.03571428571428571),
 ((7, 77, 127, 146), 0.02976190476190476),
 ((7, 77, 83, 102), 0.02976190476190476),
 ((7, 102, 127, 146), 0.02976190476190476),
 ((77, 83, 102, 146), 0.023809523809523808),
 ((83, 102, 127, 146), 0.023809523809523808),
 ((7, 77, 83, 102, 127, 146), 0.017857142857142856),
 ((7, 83, 102, 127, 146), 0.017857142857142856),
 ((7, 77, 102, 146), 0.017857142857142856),
 ((77, 83, 102, 127, 146), 0.005952380952380952),
 ((77, 83, 102, 127), 0.005952380952380952),
 ((7, 83, 102, 146), 0.005952380952380952),
 ((7, 77, 102, 127), 0.005952380952380952),
 ((7, 77, 83, 102, 127), 0.005952380952380952)]

In [45]:
noize_clusters = 0
total_clusters = 0
cluster_lens = {}
cluster_acts = {}
for wp in cxnn.watch_points.values():
    for cluster in wp.cluster_objects:
        total_clusters += 1
        cluster_lens[len(cluster.bits)] = cluster_lens.get(len(cluster.bits), 0) + 1
        cluster_acts[len(cluster.stats)] = cluster_acts.get(len(cluster.stats), 0) + 1
        if not cluster.has_big_component(threshold=0.1, min_activations=15):
            noize_clusters += 1
print(total_clusters)
print(noize_clusters)
pprint(cluster_lens)
pprint(cluster_acts)


200134
82945
{6: 81220,
 7: 61724,
 8: 34821,
 9: 15222,
 10: 5229,
 11: 1517,
 12: 309,
 13: 79,
 14: 13}
{0: 2,
 1: 63,
 2: 246,
 3: 656,
 4: 1214,
 5: 1790,
 6: 2601,
 7: 3444,
 8: 4183,
 9: 5195,
 10: 5807,
 11: 6743,
 12: 7298,
 13: 7628,
 14: 7832,
 15: 7994,
 16: 7696,
 17: 7219,
 18: 6419,
 19: 5540,
 20: 4429,
 21: 3418,
 22: 2677,
 23: 2164,
 24: 2271,
 25: 2421,
 26: 2395,
 27: 2596,
 28: 2510,
 29: 2550,
 30: 2644,
 31: 2635,
 32: 2671,
 33: 2566,
 34: 2410,
 35: 2430,
 36: 2531,
 37: 2351,
 38: 2293,
 39: 2217,
 40: 2068,
 41: 2072,
 42: 1848,
 43: 1708,
 44: 1623,
 45: 1547,
 46: 1473,
 47: 1333,
 48: 1167,
 49: 1177,
 50: 1082,
 51: 1039,
 52: 992,
 53: 945,
 54: 857,
 55: 813,
 56: 798,
 57: 759,
 58: 767,
 59: 802,
 60: 749,
 61: 729,
 62: 697,
 63: 702,
 64: 700,
 65: 718,
 66: 683,
 67: 666,
 68: 628,
 69: 656,
 70: 606,
 71: 643,
 72: 655,
 73: 630,
 74: 629,
 75: 565,
 76: 593,
 77: 611,
 78: 557,
 79: 554,
 80: 524,
 81: 538,
 82: 531,
 83: 483,
 84: 468,
 85: 442

In [26]:
my_cluster = Cluster((10, 20, 30), bitarray(10), 4)
print(type(my_cluster))
my_cluster.bits
my_cluster.stat_parts()

<class 'cluster.Cluster'>


{}