Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
783 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import os | ||
from src.animal.generate_dgm_provider import generate_dgm_provider | ||
from multiprocessing import cpu_count | ||
|
||
|
||
# Build the animal persistence-diagram provider once, if it is not on disk yet.
# Bug fix: the original tested `if os.path.exists(...)` (inverted logic) and
# checked a misspelled path ('dgm_rovider') that never matches the file it creates.
if not os.path.exists('./data/dgm_provider/animal_corrected.h5'):
    print('Persistence diagram provider does not exist, creating ... (this may need some time)')
    # Leave one core free for the rest of the system.
    n_cores = max(1, cpu_count() - 1)
    generate_dgm_provider('./data/raw_data/animal_corrected',
                          './data/dgm_provider/animal_corrected.h5',
                          32,
                          n_cores=n_cores)
else:
    print('Found Persistence diagram provider!')
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
raw_data |
Empty file.
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import multiprocessing | ||
|
||
import numpy as np | ||
import scipy.misc | ||
import scipy.ndimage | ||
import skimage.morphology | ||
from pershombox import calculate_discrete_NPHT_2d | ||
|
||
from src.sharedCode.fileSys import Folder | ||
from src.sharedCode.gui import SimpleProgressCounter | ||
from src.sharedCode.provider import Provider | ||
|
||
|
||
def preprocess_img(img):
    """Reduce a binary image to its largest 4-connected foreground component.

    Returns a boolean array that is True exactly on the pixels of the
    component with the most pixels.
    """
    label_map, n_components = skimage.morphology.label(
        img, neighbors=4, background=0, return_num=True)

    # Pixel count of every labelled component; component labels start at 1.
    component_sizes = [np.count_nonzero(label_map == label)
                       for label in range(1, n_components + 1)]

    largest_label = int(np.argmax(component_sizes)) + 1
    return label_map == largest_label
|
||
|
||
def get_npht(img, number_of_directions):
    """Compute the discrete NPHT (normalized persistent homology transform)
    of a binary image for the given number of filtration directions."""
    binary_img = np.ndarray.astype(img, bool)
    return calculate_discrete_NPHT_2d(binary_img, number_of_directions)
|
||
|
||
def job(args):
    """Pool worker: load one sample image and compute its NPHT diagrams.

    Args:
        args: dict with keys 'file_path', 'label', 'sample_id',
            'number_of_directions'.

    Returns:
        dict with 'label', 'sample_id' and 'dgms' (view name -> diagram);
        on failure an 'error' key is set instead of/alongside 'dgms'.
    """
    n_directions = args['number_of_directions']
    result = {'label': args['label'],
              'sample_id': args['sample_id'],
              'dgms': {}}

    # NOTE(review): scipy.misc.imread was deprecated and removed in
    # scipy >= 1.2 — confirm the project pins an old scipy or migrate.
    img = scipy.misc.imread(args['file_path'], flatten=True)
    img = preprocess_img(img)

    try:
        npht = get_npht(img, n_directions)
    except Exception as ex:
        result['error'] = ex
    else:
        # npht yields, per direction, the diagrams of dimension 0 and 1.
        for dir_i, dgms in zip(range(1, n_directions + 1), npht):
            dgm_0, dgm_1 = dgms[0], dgms[1]

            # An empty dimension-0 diagram means the filtration degenerated.
            if len(dgm_0) == 0:
                result['error'] = 'Degenerate diagram detected.'
                break

            result['dgms']['dim_0_dir_{}'.format(dir_i)] = dgm_0
            result['dgms']['dim_1_dir_{}'.format(dir_i)] = dgm_1

    return result
|
||
|
||
def generate_dgm_provider(data_path, output_file_path, number_of_directions, n_cores=4):
    """Compute NPHT persistence diagrams for every sample under ``data_path``
    and dump them as an h5 ``Provider`` file.

    Args:
        data_path: root folder with one sub-folder per class label and one
            image file per sample.
        output_file_path: destination path of the h5 provider file.
        number_of_directions: number of filtration directions for the NPHT.
        n_cores: size of the worker pool used to parallelize the per-sample jobs.
    """
    src_folder = Folder(data_path)
    class_folders = src_folder.folders()

    # Windows drops Thumbs.db files into image folders; ignore them.
    not_thumbs = lambda name: name != 'Thumbs.db'

    n = sum(len(cf.files(name_pred=not_thumbs)) for cf in class_folders)
    progress = SimpleProgressCounter(n)
    progress.display()

    # One view per (homology dimension, direction) pair, each mapping
    # label -> sample_id -> diagram.
    views = {}
    for i in range(1, number_of_directions + 1):
        views['dim_0_dir_{}'.format(i)] = {}
        views['dim_1_dir_{}'.format(i)] = {}

    job_args = []
    for class_folder in class_folders:
        for view in views.values():
            view[class_folder.name] = {}

        for sample_file in class_folder.files(name_pred=not_thumbs):
            job_args.append({'file_path': sample_file.path,
                            'label': class_folder.name,
                            'sample_id': sample_file.name,
                            'number_of_directions': number_of_directions})

    errors = []
    # Bug fix: the original never closed/joined the pool; the context manager
    # guarantees the worker processes are terminated.
    with multiprocessing.Pool(n_cores) as pool:
        for result in pool.imap(job, job_args):
            try:
                label = result['label']
                sample_id = result['sample_id']

                if 'error' in result:
                    errors.append((sample_id, result['error']))
                else:
                    for view_id, dgm in result['dgms'].items():
                        views[view_id][label][sample_id] = dgm
                    # Only successfully processed samples advance the counter.
                    progress.trigger_progress()

            except Exception as ex:
                errors.append(ex)

    prv = Provider()
    for key, view_data in views.items():
        prv.add_view(key, view_data)

    prv.add_meta_data({'number_of_directions': number_of_directions})

    prv.dump_as_h5(output_file_path)

    if len(errors) > 0:
        print(errors)
|
||
|
||
if __name__ == '__main__':
    from argparse import ArgumentParser
    import os.path

    parser = ArgumentParser()
    parser.add_argument('input_folder_path', type=str)
    parser.add_argument('output_file_path', type=str)
    parser.add_argument('number_of_directions', type=int)
    parser.add_argument('--n_cores', type=int, default=4)

    args = parser.parse_args()

    # Bug fix: dirname('') is returned for a bare filename, and
    # os.path.exists('') is False — that incorrectly aborted the run.
    # An empty dirname means the current directory, which always exists.
    output_dir = os.path.dirname(args.output_file_path)
    if output_dir and not os.path.exists(output_dir):
        print(output_dir, 'does not exist.')
    else:
        generate_dgm_provider(args.input_folder_path, args.output_file_path, args.number_of_directions, n_cores=args.n_cores)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
import torch | ||
import numpy as np | ||
import time | ||
import shutil | ||
import json | ||
import numpy | ||
import datetime | ||
import os | ||
|
||
from torch.utils.data import DataLoader | ||
from sklearn.model_selection import StratifiedShuffleSplit | ||
from sklearn.preprocessing.label import LabelEncoder | ||
from collections import defaultdict | ||
|
||
|
||
class PersistenceDiagramProviderCollate: | ||
def __init__(self, provider, wanted_views: [str] = None, | ||
label_map: callable = lambda x: x, | ||
output_type=torch.FloatTensor, | ||
target_type=torch.LongTensor, | ||
gpu=False): | ||
provided_views = provider.view_names | ||
|
||
if wanted_views is None: | ||
self.wanted_views = provided_views | ||
|
||
else: | ||
for wv in wanted_views: | ||
if wv not in provided_views: | ||
raise ValueError('{} is not provided by {} which provides {}'.format(wv, provider, provided_views)) | ||
|
||
self.wanted_views = wanted_views | ||
|
||
if not callable(label_map): | ||
raise ValueError('label_map is expected to be callable.') | ||
|
||
self.label_map = label_map | ||
|
||
self.output_type = output_type | ||
self.target_type = target_type | ||
self.gpu = gpu | ||
|
||
def __call__(self, sample_target_iter): | ||
batch_views, targets = defaultdict(list), [] | ||
|
||
for dgm_dict, label in sample_target_iter: | ||
for view_name in self.wanted_views: | ||
dgm = list(dgm_dict[view_name]) | ||
dgm = self.output_type(dgm) | ||
|
||
batch_views[view_name].append(dgm) | ||
|
||
targets.append(self.label_map(label)) | ||
|
||
targets = self.target_type(targets) | ||
|
||
if self.gpu: | ||
targets = targets.cuda() | ||
|
||
return batch_views, targets | ||
|
||
|
||
class SubsetRandomSampler:
    """Yields the given indices in a fresh random order on every iteration."""

    def __init__(self, indices):
        # The fixed pool of indices to sample from (without replacement).
        self.indices = indices

    def __iter__(self):
        permutation = torch.randperm(len(self.indices))
        for position in permutation:
            yield self.indices[position]

    def __len__(self):
        return len(self.indices)
|
||
|
||
def train_test_from_dataset(dataset,
                            test_size=0.2,
                            batch_size=64,
                            gpu=False,
                            wanted_views=None):
    """Build stratified train/test ``DataLoader``s over a diagram dataset.

    Args:
        dataset: provider exposing ``sample_labels`` and the wanted views.
        test_size: fraction of samples assigned to the test split.
        batch_size: mini-batch size of both loaders.
        gpu: forwarded to the collate function (moves targets to the GPU).
        wanted_views: views to collate; ``None`` selects all provided views.

    Returns:
        ``(train_loader, test_loader)``.
    """
    raw_labels = list(dataset.sample_labels)
    encoder = LabelEncoder().fit(raw_labels)
    encoded_labels = encoder.transform(raw_labels)

    def label_map(label):
        # Map a raw label to its integer class id.
        return int(encoder.transform([label])[0])

    collate_fn = PersistenceDiagramProviderCollate(dataset, label_map=label_map,
                                                   gpu=gpu, wanted_views=wanted_views)

    splitter = StratifiedShuffleSplit(n_splits=1, test_size=test_size)
    train_idx, test_idx = next(iter(splitter.split([0] * len(encoded_labels), encoded_labels)))

    def _loader(indices):
        # Shuffling is delegated to the sampler, hence shuffle=False.
        return DataLoader(dataset,
                          batch_size=batch_size,
                          collate_fn=collate_fn,
                          shuffle=False,
                          sampler=SubsetRandomSampler(indices.tolist()))

    return _loader(train_idx), _loader(test_idx)
|
||
|
||
class UpperDiagonalThresholdedLogTransform:
    """Rotate persistence-diagram points into (diagonal, off-diagonal)
    coordinates and log-compress points closer than ``nu`` to the diagonal.

    For a point (birth, death): x = (birth + death) / sqrt(2) is the
    position along the diagonal, y = (death - birth) / sqrt(2) the distance
    from it; y <= nu is replaced by log(y / nu) + nu.
    """

    def __init__(self, nu):
        # Orthonormal basis: b_1 spans the diagonal, b_2 is its normal.
        self.b_1 = (torch.Tensor([1, 1]) / np.sqrt(2))
        self.b_2 = (torch.Tensor([-1, 1]) / np.sqrt(2))
        self.nu = nu

    def __call__(self, dgm):
        # Scalar (0-dim) placeholders pass through untouched.
        if dgm.ndimension() == 0:
            return dgm

        # Broadcasting replaces the original explicit .repeat(n, 1).
        x = torch.sum(dgm * self.b_1, dim=1)
        y = torch.sum(dgm * self.b_2, dim=1)

        # Bug fix: the original applied .squeeze(), which collapsed a
        # single-point diagram to 0-dim and made torch.stack([x, y], 1) fail.
        low = (y <= self.nu)
        y[low] = torch.log(y[low] / self.nu) + self.nu

        return torch.stack([x, y], dim=1)
|
||
|
||
def pers_dgm_center_init(n_elements):
    """Sample ``n_elements`` points uniformly from the upper triangle of the
    unit square (death > birth), e.g. to initialize diagram centers.

    Returns an ``(n_elements, 2)`` float tensor.
    """
    accepted = []
    # Rejection sampling: keep drawing until enough points lie strictly
    # above the diagonal.
    while len(accepted) < n_elements:
        candidate = np.random.rand(2)
        if candidate[1] > candidate[0]:
            accepted.append(candidate.tolist())

    return torch.Tensor(accepted)
|
||
|
||
def run_experiment_n_times(n, experiment, experiment_file_path):
    """Run ``experiment`` ``n`` times and archive results in a result folder.

    A temporary folder (named by the current timestamp) is created in the
    working directory; the experiment script is copied into it and a JSON
    file with all run results is written there. Afterwards the folder is
    renamed to ``<experiment>_<avg_acc>_acc_on_<date>``.

    Args:
        n: number of runs.
        experiment: zero-argument callable returning a result dict with at
            least the keys 'model' (dropped before JSON dumping) and
            'test_accuracies' (list of floats).
        experiment_file_path: path of the experiment script to archive.
    """
    tmp_dir_path = os.path.join(os.getcwd(), str(time.time()))
    os.mkdir(tmp_dir_path)

    exp_file_name = os.path.basename(experiment_file_path)
    shutil.copy(experiment_file_path, os.path.join(tmp_dir_path, exp_file_name))

    date = datetime.datetime.now()
    date = date.strftime("%Y-%m-%d %H:%M:%S").replace(' ', '_')

    res_pth = os.path.join(tmp_dir_path, 'results__' + date + '.json')

    result = []

    for i in range(n):

        print('==================^================')
        print('Run {}'.format(i))
        res_of_run = experiment()

        # The trained model is not JSON serializable; drop it before dumping.
        del res_of_run['model']

        result.append(res_of_run)

        # Dump after every run so partial results survive a crash.
        with open(res_pth, 'w') as f:
            json.dump(result, f)

    # Average test accuracy over the last 10 epochs of every run.
    avg_test_acc = numpy.mean([numpy.mean(r['test_accuracies'][-10:]) for r in result])

    new_folder_name = '{}_{:.2f}_acc_on_{}'.format(exp_file_name.split('.py')[0], avg_test_acc, date)
    # Bug fix: str.replace returns a new string; the original discarded it,
    # leaving the dot of the accuracy in the folder name.
    new_folder_name = new_folder_name.replace('.', '_')
    os.rename(tmp_dir_path, os.path.join(os.path.dirname(tmp_dir_path), new_folder_name))
Oops, something went wrong.