In [14]:
%load_ext autoreload
# %reload_ext autoreload

# Reload all modules imported with %aimport every time before executing the Python code typed
%autoreload 1

%aimport context_nn 
%aimport phrase_feeder 
%aimport notes
%aimport watch_point
%aimport cluster
import numpy as np
from bitarray import bitarray
from context_nn import ContextNN
from watch_point import WatchPoint
from cluster import Cluster
from phrase_feeder import PhraseFeeder

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [57]:
# Notes instance shared by the cells below. bit_count=255 is the same value
# that is passed as input_bit_count to ContextNN later in this notebook.
abc_notes = notes.Notes(note_count=26, 
                        notation_count=10, 
                        active_bits=8, 
                        bit_count=255)

In [53]:
import pickle

def load_phrase_base(file_name: str) -> tuple:
    """Load the phrase base and its marks from a pickle file.

    Args:
        file_name: Path to a pickle file whose top-level object supports
            ``.get`` and may contain 'phrase_base' and 'marks' entries.

    Returns:
        A ``(phrase_base, marks)`` pair. Each element falls back to an empty
        dict when its key is missing from the file.
    """
    # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
    with open(file_name, 'rb') as f:
        data = pickle.load(f)
    # The handle is closed here; only the unpickled object is needed below.
    phrase_base = data.get('phrase_base', {})
    marks = data.get('marks', {})
    return phrase_base, marks

In [54]:
# Load the pickled phrase base and its marks (path is relative to the notebook's working directory).
phrase_base, marks = load_phrase_base('./data/texts/phrase_base.pickle')

In [55]:
from ipythonblocks import BlockGrid

def draw_notation(notation: np.ndarray):
    """Render a bit sequence as a single row of colored blocks.

    Builds a one-row BlockGrid with one block per element of ``notation``,
    recolors the blocks whose element is truthy, and displays the grid
    without grid lines.

    Args:
        notation: Indexable, sized sequence of truthy/falsy values.
    """
    # Every block starts with the fill color; only set bits are recolored.
    bit_grid = BlockGrid(len(notation), 1, fill=(17, 41, 129))
    for block in range(bit_grid.width):
        if notation[block]:
            bit_grid[0, block] = (244, 195, 173)
    bit_grid.lines_on = False
    bit_grid.show()

def draw_note_notations(notes, note_idx: int):
    """Draw every notation of one note, one grid row per notation.

    Args:
        notes: Object indexable by note index (the indexed item must support
            ``len``) that also provides ``note_notation_as_bits``.
        note_idx: Index of the note whose notations are drawn.
    """
    notation_total = len(notes[note_idx])
    for notation_idx in range(notation_total):
        bits = notes.note_notation_as_bits(note_idx, notation_idx)
        draw_notation(bits)

In [58]:
# Visual check on a single phrase: take the first phrase stored under the
# first key of phrase_base, draw the bits of its combined chord, then draw
# the notation of every (note_idx, notation_idx) pair in the phrase.
key = list(phrase_base.keys())[0]
phrase = phrase_base[key][0]
bit_chord = abc_notes.phrase_chord(phrase)
draw_notation(abc_notes.notation_as_bits(bit_chord))
print(phrase)
for note_idx, notation_idx in phrase:
    notation = abc_notes.note_notation_as_bits(note_idx, notation_idx)
    draw_notation(notation)

[(8, 0), (0, 1), (13, 2), (24, 3), (18, 4)]


In [59]:
# Feeder built from the loaded phrase base and marks; later cells pull phrase bunches from it.
feeder = phrase_feeder.PhraseFeeder(phrase_base, marks)    

In [None]:
# Pull one bunch with the feeder's default size and show its mark.
# output_bits is used directly as a key into marks, so it must be hashable
# (the marks keys shown later in this notebook are tuples).
phrases, output_bits = feeder.take_phrase_bunch()
print(output_bits, marks[output_bits])
print(phrases)

In [60]:
# Network under test. input_bit_count=255 equals the bit_count given to
# abc_notes above.
# NOTE(review): output_bit_count=8 presumably corresponds to the bit indices
# 0-7 that appear in the marks keys - confirm against ContextNN.
cxnn = context_nn.ContextNN(input_bit_count=255,
                            output_bit_count=8,
                            watch_point_count=100,
                            watch_bit_count=32,
                            cluster_make_threshold=6,
                            cluster_activate_threshold=4)

In [9]:
# Number of watch points held by the network.
# NOTE(review): the saved output (1000) comes from an earlier execution
# (In [9]) and does not match watch_point_count=100 used when cxnn is built
# above; re-run after Restart & Run All.
len(cxnn.watch_points.keys())

1000

In [273]:
%%time
# Timed one-off run: feed a single bunch of 200 phrases to the network.
# This is the same sequence of calls that feed_phrase_bunch wraps in a
# function further down. (%%time must remain the first line of the cell.)
phrases, output_bits = feeder.take_phrase_bunch(count=200)
output_bits = set(output_bits)
for phrase in phrases:
    bit_chord = abc_notes.phrase_chord(phrase)
    cxnn.receive_bits(input_bits=bit_chord, output_bits=output_bits)
    

Wall time: 1.63 s


In [61]:
def feed_phrase_bunch(feeder, count=200):
    """Take one bunch of phrases from ``feeder`` and feed it to the network.

    Each phrase is converted to a bit chord and passed to
    ``cxnn.receive_bits`` together with the bunch's output bits as a set.
    Relies on the notebook-level globals ``abc_notes`` and ``cxnn``.

    Args:
        feeder: Object providing ``take_phrase_bunch(count=...)`` that returns
            a ``(phrases, output_bits)`` pair.
        count: Number of phrases to request from the feeder.
    """
    # Forward the caller's count; it used to be hard-coded to 200 here.
    phrases, output_bits = feeder.take_phrase_bunch(count=count)
    output_bits = set(output_bits)
    for phrase in phrases:
        bit_chord = abc_notes.phrase_chord(phrase)
        cxnn.receive_bits(input_bits=bit_chord, output_bits=output_bits)

def show_point_stats(cxnn: ContextNN):
    """Print cluster statistics for every watch point of ``cxnn``.

    Prints the total number of clusters, then a list of
    ``(output_bit, cluster_count)`` pairs, one per watch point.
    """
    per_point = []
    total = 0
    for point in cxnn.watch_points.values():
        size = len(point.clusters.values())
        total += size
        per_point.append((point.output_bit, size))
    print('Cluster count:', total)
    print(per_point)
  

In [64]:
%%time
# Feed ten bunches to the network; the loop index itself is not used.
# (%%time must remain the first line of the cell.)
for i in range(10):
    feed_phrase_bunch(feeder, count=200)

Wall time: 25.9 s


In [65]:
# Print per-watch-point cluster statistics, then display marks (the cell's
# last expression) so the output bits can be read against their labels.
show_point_stats(cxnn)
marks

Cluster count: 53463
[(5, 826), (6, 209), (6, 324), (7, 330), (4, 703), (3, 769), (6, 308), (6, 386), (7, 321), (1, 686), (6, 250), (5, 852), (7, 493), (4, 826), (1, 696), (4, 600), (6, 416), (1, 535), (5, 711), (6, 279), (5, 759), (6, 364), (1, 498), (2, 454), (7, 337), (7, 273), (3, 877), (2, 355), (1, 949), (0, 449), (4, 663), (7, 393), (7, 350), (6, 192), (0, 400), (2, 824), (1, 675), (4, 608), (4, 663), (6, 256), (1, 545), (6, 402), (0, 525), (6, 208), (4, 823), (6, 254), (0, 730), (7, 425), (0, 582), (1, 498), (7, 409), (7, 272), (1, 431), (4, 748), (2, 494), (5, 835), (1, 688), (5, 751), (5, 828), (1, 292), (3, 925), (1, 470), (2, 574), (7, 222), (4, 643), (4, 356), (5, 754), (2, 660), (2, 604), (1, 469), (7, 215), (3, 1028), (4, 719), (6, 253), (7, 297), (2, 332), (5, 683), (1, 459), (1, 521), (3, 732), (1, 504), (7, 326), (1, 633), (6, 314), (4, 598), (5, 868), (0, 418), (3, 635), (1, 535), (2, 545), (5, 727), (1, 681), (4, 739), (7, 309), (2, 441), (2, 637), (3, 855), (1, 551

{(0, 1, 3, 6): 'jbo',
 (0, 2, 3, 5): 'rus',
 (0, 2, 3, 7): 'bel',
 (0, 2, 4, 6): 'blg',
 (1, 3, 4, 5): 'pol',
 (1, 4, 5, 7): 'eng',
 (2, 3, 5, 7): 'epo',
 (2, 4, 5, 6): 'ukr'}

In [66]:
def dict_value(d: dict, n=0) -> object:
    """Return the value stored under the n-th key of ``d`` in iteration order.

    Args:
        d: Mapping to read from.
        n: Zero-based position of the key in iteration order.

    Returns:
        ``d[key]`` for the key at position ``n``.

    Raises:
        IndexError: If ``n`` is negative or not smaller than the number of keys.
    """
    from itertools import islice

    # A negative position used to fall through to an unbound local variable.
    if n < 0:
        raise IndexError('dict position out of range: {}'.format(n))
    # Skip the first n keys; the loop body runs at most once.
    for key in islice(d, n, None):
        return d[key]
    # Running out of keys used to leak a bare StopIteration to the caller.
    raise IndexError('dict position out of range: {}'.format(n))

In [70]:
# Inspect the first cluster of the first watch point.
# watch_points.values()[0] is not used because a dict view cannot be indexed
# (assuming watch_points is a plain dict - confirm); dict_value walks the
# keys instead.
point = dict_value(cxnn.watch_points)
cluster = dict_value(point.clusters, 0)
print(cluster.stats)
cluster.stat_parts()

{(50, 86, 111, 117, 125, 130, 137, 158): 1, (50, 86, 117, 130, 137): 2, (50, 86, 130, 137): 1, (50, 111, 117, 125): 1, (86, 111, 117, 125): 2, (86, 111, 117, 130): 2, (111, 117, 125, 130, 158): 3, (111, 125, 130, 158): 10, (111, 125, 130, 137, 158): 3, (111, 125, 137, 158): 4, (111, 117, 130, 137): 3, (86, 117, 125, 130): 12, (117, 125, 130, 137): 1, (117, 130, 137, 158): 2, (111, 130, 137, 158): 4, (111, 117, 125, 158): 6, (50, 111, 117, 158): 2, (86, 111, 125, 137, 158): 1, (86, 111, 125, 130, 137, 158): 1, (86, 125, 130, 137): 3, (86, 117, 130, 158): 6, (50, 86, 111, 117): 3, (50, 125, 137, 158): 3, (111, 117, 130, 158): 2, (50, 125, 130, 158): 4, (86, 111, 117, 125, 130, 158): 6, (50, 86, 117, 137): 3, (86, 111, 117, 125, 158): 12, (86, 111, 117, 137): 1, (86, 111, 130, 137, 158): 3, (86, 111, 130, 158): 2, (111, 117, 125, 137, 158): 3, (50, 111, 125, 158): 1, (50, 86, 125, 137): 1, (50, 86, 117, 125, 137): 1, (50, 117, 125, 137): 1, (86, 117, 137, 158): 1, (125, 130, 137, 158): 3,

{(50, 86, 111, 117): 0.023622047244094488,
 (50, 86, 111, 117, 125, 130, 137, 158): 0.007874015748031496,
 (50, 86, 111, 117, 137): 0.007874015748031496,
 (50, 86, 111, 130, 137): 0.007874015748031496,
 (50, 86, 117, 125, 137): 0.007874015748031496,
 (50, 86, 117, 130, 137): 0.015748031496062992,
 (50, 86, 117, 137): 0.023622047244094488,
 (50, 86, 125, 137): 0.007874015748031496,
 (50, 86, 130, 137): 0.007874015748031496,
 (50, 86, 130, 137, 158): 0.007874015748031496,
 (50, 111, 117, 125): 0.007874015748031496,
 (50, 111, 117, 158): 0.015748031496062992,
 (50, 111, 125, 158): 0.007874015748031496,
 (50, 117, 125, 137): 0.007874015748031496,
 (50, 117, 130, 158): 0.007874015748031496,
 (50, 125, 130, 158): 0.031496062992125984,
 (50, 125, 137, 158): 0.023622047244094488,
 (86, 111, 117, 125): 0.015748031496062992,
 (86, 111, 117, 125, 130, 158): 0.047244094488188976,
 (86, 111, 117, 125, 158): 0.09448818897637795,
 (86, 111, 117, 130): 0.015748031496062992,
 (86, 111, 117, 137): 0.007

In [51]:
# Build a standalone Cluster by hand and check its type and stat_parts().
# NOTE(review): the bare `my_cluster.bit_set` line is not the cell's last
# expression, so its value is never displayed.
my_cluster = Cluster((10, 20, 30), bitarray(10), 4)
print(type(my_cluster))
my_cluster.bit_set
my_cluster.stat_parts()

<class 'cluster.Cluster'>


{}