In [1]:
import sys
sys.path.append("C:/Users/User/Desktop/hopfield-izhikevich/hopfield/")
from hopfield import Hopfield
sys.path.pop();

In [2]:
import os
import matplotlib.pyplot as plt
import numpy as np
import random
from pprint import pprint
from tqdm.notebook import tqdm, tqdm_notebook
from typing import Callable
import numpy
import itertools as it

In [3]:
def load_dir(dataset: list, path: str, label: bool, sort=True):
    """Load every subject found under ``path`` and append it to ``dataset``.

    ``path`` must contain ``Red``, ``Green`` and ``Blue`` sub-directories with
    one ``<id>_Red.txt`` / ``<id>_Green.txt`` / ``<id>_Blue.txt`` file per
    subject. The first line of every file is skipped (presumably a header --
    TODO confirm) and the remaining non-blank lines are parsed as floats.

    Args:
        dataset: list of subject dicts, extended in place. New entries get
            consecutive integer ``"id"`` values starting right after the
            largest existing id (0 for an empty list).
        path: directory holding the ``Red``/``Green``/``Blue`` folders.
        label: value stored under ``"label"`` for every subject loaded here.
        sort: if True, each channel's values are sorted ascending.

    Raises:
        FileNotFoundError: if a subject listed in ``Red`` is missing from
            ``Green`` or ``Blue``.
        ValueError: if a non-blank data line is not a float.
    """
    next_id = max((person["id"] for person in dataset), default=-1) + 1

    # Subject ids come from the Red folder in sorted order, so ids (and any
    # positional slicing of ``dataset`` later on) do not depend on the
    # filesystem's directory enumeration order. Files that do not match the
    # ``<id>_Red.txt`` pattern (desktop.ini, .DS_Store, ...) are ignored.
    suffix = "_Red.txt"
    subject_ids = [
        name[: -len(suffix)]
        for name in sorted(os.listdir(path + "/Red"))
        if name.endswith(suffix)
    ]

    for subject_id in subject_ids:
        person = {"id": next_id, "label": label}
        for key, color in (("r", "Red"), ("g", "Green"), ("b", "Blue")):
            with open(f"{path}/{color}/{subject_id}_{color}.txt", "r") as channel_file:
                lines = channel_file.readlines()[1:]
            # Skip blank lines (e.g. a trailing newline) instead of failing in float().
            values = [float(line) for line in lines if line.strip()]
            if sort:
                values.sort()
            person[key] = values

        dataset.append(person)
        next_id += 1
        
def load_dataset(dataset: list, path: str, pathpos: str, pathneg: str, sort=True):
    """Append two cohorts from ``path`` to ``dataset``.

    Subjects under ``path/pathpos`` are loaded first with ``label=True``,
    then subjects under ``path/pathneg`` with ``label=False``.
    """
    for subdir, label in ((pathpos, True), (pathneg, False)):
        load_dir(dataset, path + "/" + subdir, label, sort=sort)
        
        
# def get_filter(dataset: list, filter: str):
#     if filter!="r" and filter!="g" and filter!="b":
#         raise ValueError('Wrong filter: must be "r", "g", "b"')
#     new_dataset = []
#     for person in dataset:
#         new_dataset.append({"id": person["id"], "label": person["label"], "data": person[filter]})
    
#     return new_dataset

def show_plot_by_filter(dataset: list, filter: str):
    """Scatter every value of one colour channel, one column per subject.

    The subject at position ``i`` in ``dataset`` is drawn at x = i. Points are
    red when the subject's ``"label"`` is truthy and green otherwise.

    Args:
        dataset: subject dicts with ``"label"`` and ``"r"``/``"g"``/``"b"`` lists.
        filter: which channel to plot: "r", "g" or "b".

    Raises:
        ValueError: if ``filter`` is not one of "r", "g", "b".
    """
    # Validate before clearing, so a bad call leaves the current figure intact.
    if filter not in ("r", "g", "b"):
        raise ValueError('Wrong filter: must be "r", "g", "b"')
    plt.clf()
    for x, person in enumerate(dataset):
        values = person[filter]
        color = "red" if person["label"] else "green"
        # One plot call per subject instead of one per point: same markers,
        # far fewer Line2D artists.
        plt.plot([x] * len(values), values, "o", markersize=0.5, color=color)
    plt.title(f'Channel "{filter}" per subject (red = label True)')
    plt.xlabel("subject index")
    plt.ylabel("value")
    plt.show()

    
def binarize_person(person: dict, precision=1e-3, up=1.8, down=0.2, radius=0, flatten=True, filters=3):
    """Encode a subject's channel values as a 0/1 occupancy grid.

    For each encoded channel:
    1. Clip the values to ``[down, up]``.
    2. Bucket them into bins of width ``precision``.
    3. Set every occupied bin to 1, together with the ``radius`` bins on each
       side of it (clamped to the grid).
    All other bins stay 0.

    Channels are taken in the order ``"b"``, ``"g"``, ``"r"``, and only the
    first ``filters`` are encoded, so ``filters=1`` uses the blue channel only.
    NOTE(review): this is the reverse of the r/g/b order used elsewhere in the
    notebook; confirm it is intentional.

    Args:
        person: dict with ``"id"``, ``"label"`` and per-channel value lists.
        precision: bin width.
        up, down: clipping range for the values.
        radius: number of neighbouring bins also marked on each side.
        flatten: if True, return the grid flattened to 1-D.
        filters: number of channels to encode (0..3).

    Returns:
        dict with ``"id"`` and ``"label"`` copied from ``person``, plus
        ``"data"``. ``"data"`` is a float32 array of 0/1 with shape
        ``(filters, n_bins)``, or ``(filters * n_bins,)`` when ``flatten``
        is set. Here ``n_bins = int((up - down) // precision) + 1``.

    Raises:
        ValueError: if ``filters`` is outside 0..3.
    """
    channels = ("b", "g", "r")
    if not 0 <= filters <= len(channels):
        raise ValueError(f"filters must be between 0 and {len(channels)}, got {filters}")

    n_bins = int((up - down) // precision) + 1
    person_data = np.zeros((filters, n_bins), np.float32)
    # Signed offsets: ``dot - radius`` may be negative. With the previous
    # uint32 indices that subtraction wrapped around and the slice came out
    # empty, so bins near ``down`` were silently left unset.
    offsets = np.arange(-radius, radius + 1)

    for row, key in zip(person_data, channels[:filters]):
        dots = ((np.array(person[key]).clip(down, up) - down) // precision).astype(np.int64)
        # Dilate every occupied bin by ``radius`` and drop indices outside the grid.
        spread = (dots[:, None] + offsets[None, :]).ravel()
        row[spread[(spread >= 0) & (spread < n_bins)]] = 1

    data = person_data.flatten() if flatten else person_data
    return {"id": person["id"], "label": person["label"], "data": data}

def show_plot_avg_by_filter(dataset: list, filter: str):
    """Strip-plot one colour channel, pooled by class.

    Despite the name, nothing is averaged. Every value of channel ``filter``
    from every subject is drawn: subjects with a truthy ``"label"`` in red at
    x = 0, all others in green at x = 1.

    Raises:
        ValueError: if ``filter`` is not one of "r", "g", "b".
    """
    # Validate before clearing, so a bad call leaves the current figure intact.
    if filter not in ("r", "g", "b"):
        raise ValueError('Wrong filter: must be "r", "g", "b"')
    plt.clf()
    positives = [v for person in dataset if person["label"] for v in person[filter]]
    negatives = [v for person in dataset if not person["label"] for v in person[filter]]
    # One plot call per class instead of one per point.
    plt.plot([0] * len(positives), positives, "o", markersize=0.5, color="red")
    plt.plot([1] * len(negatives), negatives, "o", markersize=0.5, color="green")
    plt.xticks([0, 1], ["label True", "label False"])
    plt.title(f'Channel "{filter}" by class')
    plt.ylabel("value")
    plt.show()
    


In [4]:
def dice_score(prediction, data) -> float:
    """Dice coefficient between two bipolar (+/-) patterns.

    Elements > 0 count as "on" and elements < 0 as "off"; exact zeros are
    ignored. Counts are taken over aligned positions:
    - tp: on in both;
    - fp: on only in ``prediction``;
    - fn: on only in ``data``.
    The result is ``2*tp / (2*tp + fp + fn)``.

    Both inputs are flattened. If their sizes differ, only the first
    ``min(size)`` elements are compared, matching ``zip`` truncation.

    Returns 1.0 when ``tp + fp + fn == 0`` (e.g. both patterns entirely
    negative): two empty foregrounds are treated as a perfect match instead of
    raising ZeroDivisionError.
    """
    pred = np.asarray(prediction).ravel()
    truth = np.asarray(data).ravel()
    n = min(pred.size, truth.size)
    pred, truth = pred[:n], truth[:n]

    tp = int(np.count_nonzero((truth > 0) & (pred > 0)))
    fn = int(np.count_nonzero((truth > 0) & (pred < 0)))
    fp = int(np.count_nonzero((truth < 0) & (pred > 0)))

    denominator = 2 * tp + fp + fn
    if denominator == 0:
        return 1.0
    return 2 * tp / denominator

In [5]:
def max_score_i(dataset: np.ndarray, X: np.ndarray, i: int, score_fn: Callable):
    """Return the index of the entry in ``dataset`` (other than ``i``) that
    scores highest against ``X``.

    ``score_fn(X, dataset[j])`` is evaluated exactly once for each candidate
    ``j != i``. Ties go to the *last* such index. A score that does not
    compare ``>=`` the running best (e.g. NaN) never replaces it.

    Args:
        dataset: any sized, indexable collection of candidates.
        X: query passed as the first argument to ``score_fn``.
        i: index to exclude (the held-out item).
        score_fn: ``(X, candidate) -> comparable score``.

    Raises:
        ValueError: if ``dataset`` has no index other than ``i``.
    """
    best_index = None
    best_score = None
    for j in range(len(dataset)):
        if j == i:
            continue
        score = score_fn(X, dataset[j])
        # ``best <= score`` keeps the original last-wins tie-breaking; the
        # best score is cached instead of being recomputed every iteration.
        if best_index is None or best_score <= score:
            best_index, best_score = j, score
    if best_index is None:
        raise ValueError("dataset must contain at least one entry other than index i")
    return best_index

In [6]:
def iou_score(prediction, data) -> float:
    """Intersection-over-union (Jaccard index) between two bipolar patterns.

    Elements > 0 count as "on" and elements < 0 as "off"; exact zeros are
    ignored. Counts are taken over aligned positions:
    - tp: on in both;
    - fp: on only in ``prediction``;
    - fn: on only in ``data``.
    The result is ``tp / (tp + fp + fn)``.

    It is monotonically related to the Dice coefficient
    (Dice = 2*IoU / (1 + IoU)), so both rank candidates identically.

    Both inputs are flattened. If their sizes differ, only the first
    ``min(size)`` elements are compared, matching ``zip`` truncation.

    Returns 1.0 when ``tp + fp + fn == 0`` (e.g. both patterns entirely
    negative) instead of raising ZeroDivisionError.
    """
    pred = np.asarray(prediction).ravel()
    truth = np.asarray(data).ravel()
    n = min(pred.size, truth.size)
    pred, truth = pred[:n], truth[:n]

    tp = int(np.count_nonzero((truth > 0) & (pred > 0)))
    fn = int(np.count_nonzero((truth > 0) & (pred < 0)))
    fp = int(np.count_nonzero((truth < 0) & (pred > 0)))

    denominator = tp + fp + fn
    if denominator == 0:
        return 1.0
    return tp / denominator

In [7]:
# Load both cohorts: subjects under "BC" get label True, those under "Control" label False.
dataset = list()
load_dataset(dataset, path="data_orig/Data", pathpos="BC", pathneg="Control")

## Leave-one-out evaluation with Dice score

In [8]:
# Hyper-parameter grid: Hopfield ``q`` values x binarization radii.
# Named *_grid so the next cell's ``for q, r in tests`` loop does not shadow them.
q_range = (0.0, 0.6, 6)
r_range = (1, 5, 5)
q_grid = np.linspace(*q_range, dtype=np.float32)
r_grid = np.linspace(*r_range, dtype=int)
tests = list(it.product(q_grid, r_grid))

# NOTE(review): only subjects from position 39 onward are evaluated; the
# reason for this cut-off is not recorded here. Confirm it and document it.
SUBSET_START = 39
_dataset = dataset[SUBSET_START:]

In [9]:
# Grid search: for every (q, radius) pair, binarize the subjects, then run a
# leave-one-out retrieval with the Hopfield network. Record how often the
# best-matching stored pattern (by Dice) has the held-out subject's label.
filters = 1
time = 500
dtype = np.float32

test_result = []

for q, r in tqdm(tests):
    result, answer, maxs = [], [], []
    dataset_bin = [
        binarize_person(person, precision=1e-3, radius=r, filters=filters, down=0.8, up=1.1)
        for person in _dataset
    ]
    # Map the {0, 1} occupancy grid to bipolar {-1, +1} states.
    raw_data = np.array([entry["data"] * 2 - 1 for entry in dataset_bin], dtype=dtype)

    for i in tqdm(range(len(_dataset))):
        model = Hopfield(q=q, np_type=dtype)
        # Train on every pattern except the held-out i-th one.
        _mask = np.ones(len(dataset_bin), dtype=bool)
        _mask[i] = False
        _X, _images = raw_data[i], raw_data[_mask]

        model.train(images=_images, method='strokey')
        out = model.run(X=_X, time=time)

        max_i = max_score_i(dataset=raw_data, X=out["output"], i=i, score_fn=dice_score)
        maxs.append(max_i)

        predicted_label = dataset_bin[max_i]["label"]
        answer.append(predicted_label)
        result.append(1 if predicted_label == dataset_bin[i]["label"] else 0)

    score = sum(result) / len(_dataset)
    test_result.append({"q": q, "r": r, "score": score, "result": result, "maxs": maxs})
    print("qr", (q, r), " score ", score)
    



  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/58 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Single configuration for the detailed leave-one-out run in the next cells.
q = 0
radius = 5
#0 25 0.65 10-4
# NOTE(review): the meaning of the line above is not recorded; confirm or drop it.

filters = 1

time = 250
dtype = np.float32

dataset_bin = [
    binarize_person(p, precision=1e-3, radius=radius, down=0.37, up=1.7, filters=filters)
    for p in _dataset
]
# Bipolar {-1, +1} patterns: one row per subject, one column per bin.
raw_data = np.array([entry["data"] * 2 - 1 for entry in dataset_bin], dtype=dtype)

# aspect="auto": with ~60 rows and >1000 columns the default equal aspect
# renders the image as an unreadable thin strip.
fig, ax = plt.subplots()
ax.imshow(raw_data, aspect="auto", interpolation="nearest", cmap="gray")
ax.set(title="Binarized patterns (white = +1)", xlabel="bin", ylabel="subject")
plt.show()
print(raw_data.shape)
# Ratio of stored patterns (rows) to pattern length (columns).
print(f"coef: {raw_data.shape[0]/raw_data.shape[1]}")


In [10]:
# Count bins that carry no information: columns of ``raw_data`` with the same
# value for every subject (+1 everywhere or -1 everywhere).
constant_bins = (raw_data > 0).all(axis=0) | (raw_data < 0).all(axis=0)
# One flag per column (bin): True where the bin varies across subjects.
# (Previously the mask was sized by rows and then overwritten with a scalar.)
mask_info = ~constant_bins
n_constant_bins = int(constant_bins.sum())
n_constant_bins
        
        
        

0

In [14]:


# Detailed leave-one-out run for the single configuration set above.
# Each subject is held out in turn: the network is trained on the remaining
# patterns and run on the held-out one. The output is then matched to the
# closest pattern by IoU. IoU and Dice are monotonically related
# (Dice = 2*IoU / (1 + IoU)), so this picks the same neighbour as dice_score.
result = []
answer = []
maxs = []
for i in tqdm(range(len(_dataset))):  # tqdm_notebook is a deprecated alias
    model = Hopfield(q=q, np_type=dtype)
    _mask = np.ones(len(dataset_bin), dtype=bool)
    _mask[i] = False  # exclude the held-out subject from training

    _X = raw_data[i]
    _images = raw_data[_mask]

    model.train(images=_images, method='strokey')
    out = model.run(X=_X, time=time)

    max_i = max_score_i(dataset=raw_data, X=out["output"], i=i, score_fn=iou_score)
    maxs.append(max_i)

    answer.append(dataset_bin[max_i]["label"])
    result.append(int(dataset_bin[max_i]["label"] == dataset_bin[i]["label"]))

# Index of the retrieved neighbour for every held-out subject.
print(*maxs)
score = sum(result) / len(_dataset)
print(score)

  0%|          | 0/58 [00:00<?, ?it/s]

1 31 23 34 49 1 47 40 7 43 22 54 19 43 36 46 35 43 4 12 1 4 10 2 47 31 31 40 36 4 56 25 51 56 53 8 5 34 20 34 7 51 53 26 52 29 50 57 31 4 51 31 4 34 42 4 50 47 
0.5344827586206896
