Commit

starting to refactor code
c-hofer committed Oct 16, 2017
1 parent cd9145d commit 0e22daa
Showing 11 changed files with 783 additions and 0 deletions.
16 changes: 16 additions & 0 deletions animal_with_dim_1_essentials.py
@@ -0,0 +1,16 @@
import os
from src.animal.generate_dgm_provider import generate_dgm_provider
from multiprocessing import cpu_count


if not os.path.exists('./data/dgm_provider/animal_corrected.h5'):
    print('Persistence diagram provider does not exist, creating ... (this may need some time)')
    n_cores = max(1, cpu_count() - 1)
    generate_dgm_provider('./data/raw_data/animal_corrected',
                          './data/dgm_provider/animal_corrected.h5',
                          32,
                          n_cores=n_cores)
else:
    print('Found persistence diagram provider!')


1 change: 1 addition & 0 deletions data/.gitignore
@@ -0,0 +1 @@
raw_data
Empty file added src/__init__.py
Empty file.
Empty file added src/animal/__init__.py
Empty file.
Empty file added src/animal/experiments.py
Empty file.
134 changes: 134 additions & 0 deletions src/animal/generate_dgm_provider.py
@@ -0,0 +1,134 @@
import multiprocessing

import numpy as np
import scipy.misc
import scipy.ndimage
import skimage.morphology
from pershombox import calculate_discrete_NPHT_2d

from src.sharedCode.fileSys import Folder
from src.sharedCode.gui import SimpleProgressCounter
from src.sharedCode.provider import Provider


def preprocess_img(img):
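    # Label the 4-connected foreground components and keep only the largest
    # one, so the persistence computation sees a single binary shape.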
label_map, n = skimage.morphology.label(img, neighbors=4, background=0, return_num=True)
volumes = []
for i in range(n):
volumes.append(np.count_nonzero(label_map == (i + 1)))

arg_max = np.argmax(volumes)
img = (label_map == (arg_max + 1))

return img


def get_npht(img, number_of_directions):
    # Cast to a boolean image and compute the discrete normalized persistent
    # homology transform (NPHT) over number_of_directions filtration directions.
    img = img.astype(bool)

    npht = calculate_discrete_NPHT_2d(img, number_of_directions)
    return npht


def job(args):
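    # Pool worker: reads one sample image as grayscale, keeps its largest
    # connected component, and computes dimension-0 and dimension-1
    # persistence diagrams for each direction. Errors are reported in the
    # result dict instead of being raised.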
sample_file_path = args['file_path']
label = args['label']
sample_id = args['sample_id']
number_of_directions = args['number_of_directions']
return_value = {'label': label, 'sample_id': sample_id, 'dgms': {}}

img = scipy.misc.imread(sample_file_path, flatten=True)
img = preprocess_img(img)
try:
npht = get_npht(img, number_of_directions)

except Exception as ex:
return_value['error'] = ex
else:
dgms_dim_0 = [x[0] for x in npht]
dgms_dim_1 = [x[1] for x in npht]

for dir_i, dgm_0, dgm_1 in zip(range(1, number_of_directions + 1), dgms_dim_0, dgms_dim_1):
if len(dgm_0) == 0:
return_value['error'] = 'Degenerate diagram detected.'
break

return_value['dgms']['dim_0_dir_{}'.format(dir_i)] = dgm_0
return_value['dgms']['dim_1_dir_{}'.format(dir_i)] = dgm_1

return return_value


def generate_dgm_provider(data_path, output_file_path, number_of_directions, n_cores=4):
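    # Computes NPHT persistence diagrams for every sample below data_path
    # (one subfolder per class) in parallel and dumps them, grouped into
    # per-direction views, to an HDF5 file at output_file_path.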
src_folder = Folder(data_path)
class_folders = src_folder.folders()

n = sum([len(cf.files(name_pred=lambda n: n != 'Thumbs.db')) for cf in class_folders])
progress = SimpleProgressCounter(n)
progress.display()

views = {}
for i in range(1, number_of_directions + 1):
views['dim_0_dir_{}'.format(i)] = {}
views['dim_1_dir_{}'.format(i)] = {}
job_args = []

for class_folder in class_folders:
for view in views.values():
view[class_folder.name] = {}

for sample_file in class_folder.files(name_pred=lambda n: n != 'Thumbs.db'):
args = {'file_path': sample_file.path,
'label': class_folder.name,
'sample_id': sample_file.name,
'number_of_directions': number_of_directions}
job_args.append(args)

pool = multiprocessing.Pool(n_cores)

errors = []
for result in pool.imap(job, job_args):
try:
label = result['label']
sample_id = result['sample_id']

if 'error' in result:
errors.append((sample_id, result['error']))
else:
for view_id, dgm in result['dgms'].items():
views[view_id][label][sample_id] = dgm
progress.trigger_progress()

        except Exception as ex:
            errors.append(ex)

    pool.close()
    pool.join()

prv = Provider()
for key, view_data in views.items():
prv.add_view(key, view_data)

meta = {'number_of_directions': number_of_directions}
prv.add_meta_data(meta)

prv.dump_as_h5(output_file_path)

if len(errors) > 0:
print(errors)


if __name__ == '__main__':
from argparse import ArgumentParser
import os.path

parser = ArgumentParser()
parser.add_argument('input_folder_path', type=str)
parser.add_argument('output_file_path', type=str)
parser.add_argument('number_of_directions', type=int)
parser.add_argument('--n_cores', type=int, default=4)

args = parser.parse_args()

    output_dir = os.path.dirname(args.output_file_path)
    # An empty dirname means the file goes to the current working directory.
    if output_dir and not os.path.exists(output_dir):
        print(output_dir, 'does not exist.')
    else:
        generate_dgm_provider(args.input_folder_path,
                              args.output_file_path,
                              args.number_of_directions,
                              n_cores=args.n_cores)
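
Assuming the repository root is on the Python path (the script imports from src.sharedCode), an invocation could look roughly like the following; the paths are placeholders:

python -m src.animal.generate_dgm_provider ./data/raw_data/animal_corrected ./data/dgm_provider/animal_corrected.h5 32 --n_cores 4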
Empty file added src/sharedCode/__init__.py
Empty file.
171 changes: 171 additions & 0 deletions src/sharedCode/experiments.py
@@ -0,0 +1,171 @@
import torch
import numpy as np
import time
import shutil
import json
import datetime
import os

from typing import List
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import LabelEncoder
from collections import defaultdict


class PersistenceDiagramProviderCollate:
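    # Collate function for torch DataLoaders: turns an iterable of
    # (diagram dict, label) pairs into per-view lists of diagram tensors
    # plus a target tensor.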
    def __init__(self, provider, wanted_views: List[str] = None,
                 label_map: callable = lambda x: x,
                 output_type=torch.FloatTensor,
                 target_type=torch.LongTensor,
                 gpu=False):
provided_views = provider.view_names

if wanted_views is None:
self.wanted_views = provided_views

else:
for wv in wanted_views:
if wv not in provided_views:
raise ValueError('{} is not provided by {} which provides {}'.format(wv, provider, provided_views))

self.wanted_views = wanted_views

if not callable(label_map):
raise ValueError('label_map is expected to be callable.')

self.label_map = label_map

self.output_type = output_type
self.target_type = target_type
self.gpu = gpu

def __call__(self, sample_target_iter):
batch_views, targets = defaultdict(list), []

for dgm_dict, label in sample_target_iter:
for view_name in self.wanted_views:
dgm = list(dgm_dict[view_name])
dgm = self.output_type(dgm)

batch_views[view_name].append(dgm)

targets.append(self.label_map(label))

targets = self.target_type(targets)

if self.gpu:
targets = targets.cuda()

return batch_views, targets


class SubsetRandomSampler:
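    # Yields the given indices in a new random order on every iteration
    # (a minimal stand-in for torch.utils.data.sampler.SubsetRandomSampler).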
def __init__(self, indices):
self.indices = indices

def __iter__(self):
return (self.indices[i] for i in torch.randperm(len(self.indices)))

def __len__(self):
return len(self.indices)


def train_test_from_dataset(dataset,
test_size=0.2,
batch_size=64,
gpu=False,
wanted_views=None):
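    # Encodes the string labels, performs one stratified shuffle split, and
    # wraps the train/test index sets in DataLoaders that share the same
    # collate function.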

sample_labels = list(dataset.sample_labels)
label_encoder = LabelEncoder().fit(sample_labels)
sample_labels = label_encoder.transform(sample_labels)

label_map = lambda l: int(label_encoder.transform([l])[0])
collate_fn = PersistenceDiagramProviderCollate(dataset, label_map=label_map, gpu=gpu, wanted_views=wanted_views)

sp = StratifiedShuffleSplit(n_splits=1, test_size=test_size)
train_i, test_i = list(sp.split([0]*len(sample_labels), sample_labels))[0]

data_train = DataLoader(dataset,
batch_size=batch_size,
collate_fn=collate_fn,
shuffle=False,
sampler=SubsetRandomSampler(train_i.tolist()))

data_test = DataLoader(dataset,
batch_size=batch_size,
collate_fn=collate_fn,
shuffle=False,
sampler=SubsetRandomSampler(test_i.tolist()))

return data_train, data_test


class UpperDiagonalThresholdedLogTransform:
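    # Maps each diagram point (birth, death) to the rotated coordinates
    # ((birth + death) / sqrt(2), (death - birth) / sqrt(2)), so the second
    # coordinate measures (scaled) persistence; second coordinates at or
    # below the threshold nu are squashed logarithmically.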
def __init__(self, nu):
self.b_1 = (torch.Tensor([1, 1]) / np.sqrt(2))
self.b_2 = (torch.Tensor([-1, 1]) / np.sqrt(2))
self.nu = nu

def __call__(self, dgm):
if dgm.ndimension() == 0:
return dgm

x = torch.mul(dgm, self.b_1.repeat(dgm.size(0), 1))
x = torch.sum(x, 1).squeeze()
        y = torch.mul(dgm, self.b_2.repeat(dgm.size(0), 1))
y = torch.sum(y, 1).squeeze()
i = (y <= self.nu)
y[i] = torch.log(y[i] / self.nu) + self.nu
ret = torch.stack([x, y], 1)
return ret


def pers_dgm_center_init(n_elements):
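    # Rejection-sample n_elements points uniformly from the part of the unit
    # square strictly above the diagonal (death > birth); intended as initial
    # center coordinates for diagram-based layers.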
centers = []
while len(centers) < n_elements:
x = np.random.rand(2)
if x[1] > x[0]:
centers.append(x.tolist())

return torch.Tensor(centers)


def run_experiment_n_times(n, experiment, experiment_file_path):
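    # Repeats `experiment` n times in a temporary folder, dumps all run
    # results (without the model object) to a JSON file, and finally renames
    # the folder to include the mean over runs of the average of the last 10
    # recorded test accuracies.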
tmp_dir_path = os.path.join(os.getcwd(), str(time.time()))
os.mkdir(tmp_dir_path)

exp_file_name = os.path.basename(experiment_file_path)
shutil.copy(experiment_file_path, os.path.join(tmp_dir_path, exp_file_name))

date = datetime.datetime.now()
date = date.strftime("%Y-%m-%d %H:%M:%S").replace(' ', '_')

res_pth = os.path.join(tmp_dir_path, 'results__' + date + '.json')

result = []

for i in range(n):

        print('=================================')
print('Run {}'.format(i))
res_of_run = experiment()

# model = res_of_run['model']
#
# with open(os.path.join(tmp_dir_path, 'model_run_{}.pickle'.format(i)), 'bw') as f:
# pickle.dump(model, f)

del res_of_run['model']

result.append(res_of_run)

with open(res_pth, 'w') as f:
json.dump(result, f)

    avg_test_acc = np.mean([np.mean(r['test_accuracies'][-10:]) for r in result])

    new_folder_name = '{}_{:.2f}_acc_on_{}'.format(exp_file_name.split('.py')[0], avg_test_acc, date)
    new_folder_name = new_folder_name.replace('.', '_')
    os.rename(tmp_dir_path, os.path.join(os.path.dirname(tmp_dir_path), new_folder_name))
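
For intuition, a minimal sketch of the coordinate transform above on a made-up diagram; the values and the nu setting are only illustrative:

import torch
from src.sharedCode.experiments import UpperDiagonalThresholdedLogTransform

dgm = torch.Tensor([[0.1, 0.9], [0.4, 0.5]])  # two (birth, death) points
transform = UpperDiagonalThresholdedLogTransform(nu=0.1)
out = transform(dgm)
# Row 0: persistence (0.9 - 0.1) / sqrt(2) lies above nu and is unchanged.
# Row 1: persistence (0.5 - 0.4) / sqrt(2) falls below nu and is log-squashed.
print(out)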
