In [None]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os
import re
import pandas as pd
import matplotlib as mpl
from glob import glob
from tifffile import imread

import yaml

In [None]:
# Folder that receives every figure saved by this notebook.
output_folder = Path('output')

# Create the folder (and any missing parents) unless it is already there.
output_folder.mkdir(parents=True, exist_ok=True)

In [None]:
def readDataset(datasetName, verbose=False, dataset_root=None,
                data_class='masks', data_purpose='train'):
    """Read every ``*.tif`` file of one dataset split into memory.

    Files are searched in
    ``<dataset_root>/<datasetName>/<data_purpose>/<data_class>/*.tif`` and
    read with ``imread`` in sorted filename order.

    Parameters
    ----------
    datasetName : str
        Name of the dataset folder below ``dataset_root``.
    verbose : bool, optional
        If True, print the glob pattern that is searched.
    dataset_root : str or path-like, optional
        Folder containing all datasets. Defaults to ``../datasets``.
    data_class : str, optional
        Sub-folder holding the images to read (default ``'masks'``).
    data_purpose : str, optional
        Split sub-folder; also the single key of the returned dict
        (default ``'train'``).

    Returns
    -------
    dict
        ``{data_purpose: [image, ...]}`` with one ``imread`` result per file.
        The list is empty when no file matches.
    """
    if dataset_root is None:
        # Assembled with os.path.join so the default also resolves on
        # non-Windows systems (the former literal was r'..\datasets').
        dataset_root = os.path.join('..', 'datasets')

    pattern = os.path.join(dataset_root, datasetName, data_purpose, data_class, '*.tif')

    # The previous verbose branch referenced the undefined names
    # `data_purposes` / `data_classes` and raised NameError.
    if verbose:
        print(pattern)

    return {data_purpose: [imread(path) for path in sorted(glob(pattern))]}

In [None]:
# Load the project configuration stored one level above this notebook.
# The path is built with pathlib: the former literal '..\config.yaml' relied on
# '\c' being an unrecognised escape sequence and only resolved on Windows.
# NOTE(review): FullLoader can construct more than plain data types; switch to
# yaml.safe_load if this file could ever come from an untrusted source.
with open(Path('..') / 'config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

datasetname = 'full_semimanual-raw'
outputname = 'test.png'

# Model names read from the config; each one is later expanded by
# get_minor_models into its lower-epoch variants.
final_models = config['cellpose_models_raw_full_low']
final_models

In [None]:
def get_minor_models(modelname):
    """Expand a model name into the names of its lower-epoch variants.

    ``modelname`` is expected to end in ``_ep<epochs>_dep<delta>``. The
    result contains the same name with ``<epochs>`` replaced by ``epochs``,
    ``epochs - delta``, ``epochs - 2 * delta``, ... for every value greater
    than 1, starting with ``modelname`` itself.

    Parameters
    ----------
    modelname : str
        Model name containing ``_ep<epochs>`` followed by ``_dep<delta>``.

    Returns
    -------
    list of str
        Model names ordered from the highest to the lowest epoch count.

    Raises
    ------
    ValueError
        If the epoch or delta token is missing or not an integer, or if the
        delta is zero.
    """
    # Parse the last '_ep' token, exactly as the epoch count is located.
    prefix, tail = modelname.rsplit('_ep', 1)
    tmp = tail.split('_dep')
    epochs_token = tmp[0]
    epochs = int(epochs_token)
    delta_epochs = int(tmp[-1])

    # Everything after the epoch digits ('_dep<delta>...') is kept verbatim.
    suffix = tail[len(epochs_token):]

    # Substitute the parsed epoch count instead of a hard-coded '_ep500',
    # which left names unchanged for models not trained for 500 epochs.
    return [f'{prefix}_ep{ep}{suffix}' for ep in range(epochs, 1, -delta_epochs)]

In [None]:
# Flat list of all model names: every final model followed by its
# lower-epoch variants, in config order.
models = []
for final_name in final_models:
    models.extend(get_minor_models(final_name))
        

In [None]:
# Start from an empty list; the following cell assigns the real entries.
accuracy_files = list()

In [None]:
# Candidate accuracy CSVs, one per model, located in ../data/<model>/.
# Path('..') / 'data' replaces the literal '..\data', whose '\d' is an invalid
# escape sequence and which only resolved as a relative path on Windows.
data_root = Path('..') / 'data'
accuracy_files = [str(data_root / m / 'accuracy_manual_raw_v3.csv') for m in models]
#accuracy_files = [f'data/{m}/accuracy_full_semimanual-raw.csv' for m in models]

# Keep only models whose CSV exists on disk. Note that from here on the list
# holds the model *folders* (the CSV's parent), not the CSV files themselves.
accuracy_files = [Path(f).parent for f in accuracy_files if Path(f).is_file()]

accuracy_files

In [None]:
# Build one table row per model folder by parsing the training parameters
# (percentage, replicate, epoch) out of the folder path.

# Raw string: the pattern contains '\d' and '\.', which are invalid escape
# sequences in a normal string literal.
p = r'.*True_(?P<percentage>[\d\.]+)prc_rep(?P<replicate>\d+)_ep(?P<epoch>\d+)_dep.*'
pattern = re.compile(p)

# Collect plain dicts and build the frame once; DataFrame.append was removed
# in pandas 2.0 and copied the whole frame on every call.
records = []
for f in accuracy_files:
    match = pattern.match(str(f))
    if match is None:
        # Previously this crashed with AttributeError on None.groupdict().
        print(f'Skipping {f}: path does not match the expected model-name pattern')
        continue
    records.append({'path': str(f), 'type': 'cellpose', **match.groupdict()})

# dtype=object keeps every column as object, as the append-based construction
# did; 'cell_number' and the accuracy columns start out as NaN and are filled
# in by a later cell.
df = pd.DataFrame(
    records,
    columns=['path', 'type', 'percentage', 'replicate', 'epoch', 'cell_number', 'accuracy_manual', 'accuracy_semimanual'],
    dtype=object,
)

In [None]:
# Display the table of parsed models (one row per model folder kept above).
df

In [None]:
# Load the images of this dataset; readDataset returns a dict whose 'train'
# entry is the list of images read from the dataset's *.tif files.
Y = readDataset('patches-semimanual-raw-64x128x128')

In [None]:
# Number of loaded patches before empty ones are filtered out.
len(Y['train'])

In [None]:
# Total of all pixel values per patch; a total of 0 means the patch is empty.
sum_Y = [np.sum(patch) for patch in Y['train']]

# Keep only the patches with a positive total.
Y['train'] = [patch for patch, total in zip(Y['train'], sum_Y) if total > 0]

In [None]:
# Number of patches left after dropping the empty ones.
len(Y['train'])

In [None]:
# Number of distinct non-zero labels per patch.
# NOTE(review): assumes label 0 is background - confirm against the dataset.
# Counting non-zero labels directly (instead of len(unique) - 1) stays correct
# for a patch that happens to contain no background pixel at all.
N_cells = [int(np.count_nonzero(np.unique(y))) for y in Y['train']]

In [None]:
# Fill 'cell_number' and the two accuracy columns for every model row.
for index, row in df.iterrows():
    # Replicate number seeds the permutation for cellpose rows; any other type
    # falls back to the fixed seed 42. (The seed was computed but never used
    # before; for 'cellpose' rows the result is unchanged.)
    seed = int(row.replicate) if row.type == 'cellpose' else 42
    rng = np.random.RandomState(seed)
    ind = rng.permutation(len(Y['train']))

    # Number of patches corresponding to the row's percentage (at least one).
    # NOTE(review): presumably this reproduces the subset selection used when
    # the model was trained - verify against the training script.
    n_val = max(1, int(round(float(row.percentage) / 100 * len(ind))))

    # Write with .loc on the frame itself: the chained form
    # df.iloc[index][col] = ... assigns to a temporary row object and is lost
    # under pandas copy-on-write.
    df.loc[index, 'cell_number'] = np.sum([N_cells[i] for i in ind[:n_val]])

    for data_name, col in zip(['accuracy_manual_raw_v3.csv', 'accuracy_full_semimanual-raw.csv'], ['accuracy_manual', 'accuracy_semimanual']):
        csv_path = Path(row.path) / data_name
        accuracy = np.nan  # stays NaN when the file or the 0.5 entry is missing

        if csv_path.is_file():
            # Row 0 holds the threshold values, row 1 the accuracies.
            data = np.genfromtxt(csv_path, delimiter=' ')
            hits = np.flatnonzero(np.isclose(data[0], 0.5))
            if hits.size:
                accuracy = data[1][hits[0]]

        df.loc[index, col] = accuracy

df

In [None]:
# Convert the numeric columns (parsed as strings / stored as objects) to float.
df = df.astype(
    dict.fromkeys(['accuracy_manual', 'accuracy_semimanual', 'percentage', 'epoch'], 'float')
)

In [None]:
# Show the cellpose rows with percentage 100 and epoch 500.
df[(df.type=='cellpose') & (df.percentage == 100) & (df.epoch==500)]

In [None]:
# read default cellpose iterative training

# One accuracy curve (second row of the CSV) per cellpose model with
# percentage 100 and epoch 500.
cellpose_vals = []
final_cellpose = df[(df.type=='cellpose') & (df.percentage == 100) & (df.epoch==500)]
for model_dir in final_cellpose['path']:
    acc_file = Path(model_dir) / 'accuracy_full_semimanual-raw.csv'
    if not acc_file.exists():
        # Previously a missing file surfaced as IndexError on an empty glob result.
        print(f'Skipping {model_dir}: {acc_file.name} not found')
        continue
    data = np.genfromtxt(str(acc_file), delimiter=' ')
    cellpose_vals.append(data[1])
    

In [None]:
# read horovod
# Path('..') / 'data' replaces the literal '..\data', whose '\d' is an invalid
# escape sequence and which only resolved as a relative path on Windows.
acc_files_horovod = sorted((Path('..') / 'data').glob('horovod*prc100*/accuracy_full_semimanual-raw.csv'))

In [None]:
# Show which horovod accuracy files were found.
acc_files_horovod

In [None]:
# Compare the mean accuracy curves (+/- one standard deviation) of the horovod
# models and the default cellpose models, and save the figure.

horovod_vals = []
tau_vals = None
for acc_file in acc_files_horovod:
    # Row 0 holds the threshold values, row 1 the accuracies.
    data = np.genfromtxt(str(acc_file), delimiter=' ')
    horovod_vals.append(data[1])
    # Taken inside the loop so the x axis can never come from a stale `data`
    # left behind by an earlier cell.
    tau_vals = data[0]

if tau_vals is None:
    raise FileNotFoundError('No horovod accuracy files found; nothing to plot.')
if len(cellpose_vals) == 0:
    raise ValueError('No cellpose accuracy curves loaded; nothing to plot.')

horovod_mean = np.mean(horovod_vals, axis=0)
cellpose_mean = np.mean(cellpose_vals, axis=0)

horovod_std = np.std(horovod_vals, axis=0)
cellpose_std = np.std(cellpose_vals, axis=0)

f, ax = plt.subplots(1)
h, = ax.plot(tau_vals, horovod_mean, label='horovod')
c, = ax.plot(tau_vals, cellpose_mean, label='cellpose')

# Shaded band of one standard deviation around each mean, in the line's colour.
ax.fill_between(tau_vals, horovod_mean - horovod_std, horovod_mean + horovod_std,
    color=h.get_color(), alpha=0.2)

ax.fill_between(tau_vals, cellpose_mean - cellpose_std, cellpose_mean + cellpose_std,
    color=c.get_color(), alpha=0.2)

ax.legend()
ax.grid()

ax.set_xlabel('Intersection over Union [a.u.]')
ax.set_ylabel('Mean accuracy [a.u.]')

plt.savefig(str(output_folder / 'horovod_vs_default_cellpose_100prc.svg'))

In [None]:
# Print the folder of every cellpose model with percentage 100 and epoch 500.
is_final_cellpose = (df.type == 'cellpose') & (df.percentage == 100) & (df.epoch == 500)
for model_path in df.loc[is_final_cellpose, 'path']:
    print(model_path)

In [None]:
# Horovod accuracy files found below ../data.
# Path('..') / 'data' replaces the literal '..\data', whose '\d' is an invalid
# escape sequence and which only resolved as a relative path on Windows.
acc_files_horovod = sorted((Path('..') / 'data').glob('horovod*prc100*/accuracy_full_semimanual-raw.csv'))
acc_files_horovod