In [120]:
%load_ext autoreload
# %reload_ext autoreload

# Reload all modules imported with %aimport every time before executing the Python code typed
%autoreload 1

%aimport context_nn 
%aimport phrase_feeder 
%aimport notes
%aimport watch_point
%aimport cluster
%aimport constants
import numpy as np
from bitarray import bitarray
from context_nn import ContextNN
from watch_point import WatchPoint
from cluster import Cluster
from phrase_feeder import PhraseFeeder
from notes import Notes
from pprint import pprint
import math
import constants as const

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [60]:
import pickle

def load_phrase_base(file_name: str) -> (dict, list):
    with open(file_name, 'rb') as f:
        data = pickle.load(f)
        phrase_base = data.get('phrase_base', {})
        marks = data.get('marks', {})
        return phrase_base, marks
    
phrase_base, marks = load_phrase_base('./data/texts/phrase_base.pickle')

In [43]:
from ipythonblocks import BlockGrid

def draw_notation(notation: np.array):
    bit_grid = BlockGrid(len(notation), 1, fill=(17, 41, 129))
    for block in range(bit_grid.width):
        color = bit_grid[0, block]
        if notation[block]:
            bit_grid[0, block] = (244, 195, 173)
    bit_grid.lines_on = False
    bit_grid.show()

def draw_note_notations(notes, note_idx: int):
    for i in range(len(notes[note_idx])):
        draw_notation(notes.note_notation_as_bits(note_idx, i))

In [44]:
abc_notes = notes.Notes(note_count=26, 
                        notation_count=5, 
                        active_bits=8, 
                        bit_count=255)

key = list(phrase_base.keys())[0]
phrase = phrase_base[key][0]
bit_chord = abc_notes.phrase_chord(phrase)
draw_notation(abc_notes.notation_as_bits(bit_chord))
print(phrase)
for note_idx, notation_idx in phrase:
    notation = abc_notes.note_notation_as_bits(note_idx, notation_idx)
    draw_notation(notation)

[(8, 0), (0, 1), (13, 2), (24, 3), (18, 4)]


In [4]:
feeder = PhraseFeeder(phrase_base, marks)    

In [None]:
phrases, output_bits = feeder.take_phrase_bunch()
print(output_bits, marks[output_bits])
print(phrases)

In [125]:
def feed_phrase_batch(cxnn: ContextNN, feeder: PhraseFeeder, abc_notes: Notes, count=200):
    phrases, output_bits = feeder.take_phrase_batch(count=200)
    output_bits = set(output_bits)
    for phrase in phrases:
        bit_chord = abc_notes.phrase_chord(phrase)
        cxnn.receive_bits(input_bits=bit_chord, output_bits=output_bits)

def feed_n_batches(cxnn: ContextNN, feeder: PhraseFeeder, abc_notes: Notes, n=10, batch_size=200):
    for i in range(n):
        feed_phrase_batch(cxnn, feeder, abc_notes, count=batch_size)
        print(f'batch {i+1}/{n}')
    print(f'clusters: {cxnn.cluster_count()}')
        
def learn_cycle(cxnn: ContextNN, feeder: PhraseFeeder, abc_notes: Notes):
    print('accumulating clusters...')
    cxnn.state = const.STATE_ACCUMULATE
    feed_n_batches(cxnn, feeder, abc_notes, n=50, batch_size=200)
    
    print('\nverifying clusters...')
    cxnn.state = const.STATE_VERIFY
    feed_n_batches(cxnn, feeder, abc_notes, n=30, batch_size=200)
    
    print('\nreducing clusters...')
    cxnn.reduce_clusters(min_component=0.2, 
                         min_activations=100,
                         trim=True,
                         remain_part=0.3,
                         clear_stats=True,
                         consolidate=True,
                         amnesty=True)
    print(f'clusters: {cxnn.cluster_count()}')
    

In [122]:
abc_notes = notes.Notes(note_count=26, 
                        notation_count=5, 
                        active_bits=8, 
                        bit_count=255)

feeder = phrase_feeder.PhraseFeeder(phrase_base, marks)  

cxnn = context_nn.ContextNN(input_bit_count=255,
                            output_bit_count=8,
                            watch_point_count=100,
                            watch_bit_count=32,
                            cluster_make_threshold=6,
                            cluster_activate_threshold=4)

In [132]:
%%time
learn_cycle(cxnn, feeder, abc_notes)

accumulating clusters...
batch 1/50
batch 2/50
batch 3/50
batch 4/50
batch 5/50
batch 6/50
batch 7/50
batch 8/50
batch 9/50
batch 10/50
batch 11/50
batch 12/50
batch 13/50
batch 14/50
batch 15/50
batch 16/50
batch 17/50
batch 18/50
batch 19/50
batch 20/50
batch 21/50
batch 22/50
batch 23/50
batch 24/50
batch 25/50
batch 26/50
batch 27/50
batch 28/50
batch 29/50
batch 30/50
batch 31/50
batch 32/50
batch 33/50
batch 34/50
batch 35/50
batch 36/50
batch 37/50
batch 38/50
batch 39/50
batch 40/50
batch 41/50
batch 42/50
batch 43/50
batch 44/50
batch 45/50
batch 46/50
batch 47/50
batch 48/50
batch 49/50
batch 50/50
clusters: 161119

verifying clusters...
batch 1/30
batch 2/30
batch 3/30
batch 4/30
batch 5/30
batch 6/30
batch 7/30
batch 8/30
batch 9/30
batch 10/30
batch 11/30
batch 12/30
batch 13/30
batch 14/30
batch 15/30
batch 16/30
batch 17/30
batch 18/30
batch 19/30
batch 20/30
batch 21/30
batch 22/30
batch 23/30
batch 24/30
batch 25/30
batch 26/30
batch 27/30
batch 28/30
batch 29/30
batch

In [127]:
print(cxnn.cluster_count())
print(cxnn.point_stats())
marks

70251
[(7, 327), (6, 170), (1, 824), (3, 1018), (3, 661), (5, 574), (0, 876), (3, 1744), (3, 1407), (4, 1546), (2, 298), (5, 835), (4, 1237), (4, 656), (1, 700), (0, 924), (6, 138), (0, 679), (6, 97), (5, 916), (0, 823), (2, 296), (1, 1113), (5, 591), (2, 155), (0, 652), (6, 264), (3, 843), (4, 956), (2, 334), (7, 277), (6, 257), (2, 169), (3, 797), (6, 122), (1, 758), (4, 1789), (2, 390), (7, 116), (1, 593), (2, 261), (5, 564), (3, 1347), (3, 1100), (3, 1376), (5, 701), (1, 354), (4, 943), (2, 281), (4, 832), (7, 280), (1, 799), (3, 747), (1, 1198), (0, 1062), (4, 1401), (1, 790), (5, 690), (7, 193), (4, 1251), (3, 1203), (2, 167), (3, 540), (7, 133), (6, 194), (6, 144), (3, 1110), (7, 172), (1, 746), (0, 795), (2, 396), (6, 47), (4, 1340), (5, 960), (4, 1402), (6, 88), (3, 773), (6, 195), (6, 203), (1, 765), (4, 1173), (7, 365), (3, 1152), (7, 455), (1, 825), (0, 917), (3, 1208), (7, 123), (0, 1108), (3, 1319), (5, 633), (3, 1351), (2, 458), (4, 1317), (1, 782), (7, 190), (4, 1293), 

{(0, 1, 3, 4): 'bel',
 (0, 1, 3, 6): 'pol',
 (0, 1, 4, 5): 'ukr',
 (0, 1, 5, 6): 'epo',
 (0, 3, 4, 5): 'jbo',
 (1, 3, 4, 7): 'blg',
 (2, 3, 4, 5): 'rus',
 (2, 3, 4, 7): 'eng'}

In [133]:
print(cxnn.cluster_count())
pprint(cxnn.cluster_len_stats())
pprint(cxnn.cluster_activity_stats())

100941
{4: 74315, 5: 17518, 6: 7496, 7: 1318, 8: 271, 9: 22, 10: 1}
{0: 100941}


In [134]:
cxnn.cluster_consolidated_stats()
    

{0: 7290, 1: 43837, 2: 27609, 3: 22205}

In [17]:
def dict_value(d: dict, n=0) -> object:
    i = 0
    diter = iter(d)
    while i <= n:
        key = next(diter)
        i += 1
    return d[key]


point = dict_value(cxnn.watch_points, n=0)
cluster_idx = 0
cluster = point.cluster_objects[cluster_idx]
pprint(cluster.stats)
print(sum(cluster.stats.values()))
# pprint(cluster.component_stats())
# pprint(cluster.bit_rate())

pprint(sorted(cluster.component_stats().items(), key=lambda x: x[1], reverse=True))
print(point.cluster_masks)
print(tuple(sorted(cluster.bits)))
print(point.watch_bits)
print(point.cluster_masks[cluster_idx])
print(len(point.cluster_masks), len(point.cluster_objects))

# parts = cluster.component_s
# enthropy = 0.0
# for part in parts.values():
#     enthropy += -part * math.log2(part)
# print(enthropy)

{}
0
[]
[[0 0 0 ..., 0 0 0]
 [0 0 0 ..., 1 0 0]
 [0 0 0 ..., 0 0 0]
 ..., 
 [0 0 0 ..., 0 0 0]
 [0 0 0 ..., 0 0 0]
 [0 1 0 ..., 0 0 0]]
(57, 122, 212, 215)
(2, 3, 22, 30, 40, 54, 57, 94, 102, 106, 122, 128, 132, 137, 140, 154, 157, 164, 175, 179, 200, 208, 211, 212, 214, 215, 216, 223, 227, 239, 245, 249)
[0 0 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 0 0 0 0]
358 358


In [131]:
cluster_with_stats = None
for wp in cxnn.point_objects:
    for cluster in wp.cluster_objects:    
        if cluster.stats:
            cluster_with_stats = cluster
            break
    else: 
        continue
    break    

if cluster_with_stats:
    print(cluster_with_stats.stats, cluster_with_stats.consolidated)
else:
    print('nothing to print')
    
    

nothing to print


In [26]:
my_cluster = Cluster((10, 20, 30), bitarray(10), 4)
print(type(my_cluster))
my_cluster.bits
my_cluster.stat_parts()

<class 'cluster.Cluster'>


{}

In [6]:
[(8, 0), (0, 1), (13, 2), (24, 3), (18, 4)].__sizeof__()

80