In [None]:
import os
import json
import subprocess
import traceback
import torch
import mne
import scipy
import json
import torchvision
import multiprocessing

import scipy.signal

import matplotlib.pyplot as plt
import numpy as np

class dotdict(dict):
	"""
	Dictionary whose items can also be accessed as attributes.

	Reading an attribute that is not a key yields None (dict.get semantics),
	assigning an attribute stores an item, and deleting an attribute removes
	the item (KeyError if the key is absent, as with dict.__delitem__).
	"""

	def __getattr__(self, name):
		# Only reached when regular attribute lookup fails
		return dict.get(self, name)

	def __setattr__(self, name, value):
		dict.__setitem__(self, name, value)

	def __delattr__(self, name):
		dict.__delitem__(self, name)


# > Configuration

In [None]:
# Path config

# Directory with the source recordings; convert_edf() reads its Visual/Audial subdirectories
DATA_PATH = 'Cleared'
# Directory that receives the converted EDF files; list_*_edf()/open_*_edf() read from it
OUT_PATH  = f'{DATA_PATH}-converted'
# Alias of OUT_PATH used by the "wavelet pass over all data" cell
CLEARED_PATH   = OUT_PATH
VISUAL_SUBPATH = 'Visual'
AUDIAL_SUBPATH = 'Audial'

# Data config
# 2X - internal speaking
# 1X - stimulus
# Phoneme length ~300ms
# The extract_* functions read these codes from the 'METKA' channel: the tens
# digit selects the kind of label, the units digit is the phoneme ID.

# Number of EEG channels picked by subselect_channels()
TARGET_CHANNELS = 4
# For every target channel: alternative channel names, the first one found in the EDF is used
TARGET_CHANNEL_SETS = [
	[ 'EEG F7-A1', 'EEG F7-M1' ],
	[ 'EEG F3-A1', 'EEG F3-M1' ],
	[ 'EEG T3-A1', 'EEG T3-M1' ],
	[ 'EEG C3-A1', 'EEG C3-M1' ],
]
SOURCE_FREQ     = 1000 # Article: 1000Hz
# Length of one extracted sector, in samples
SECTOR_LENGTH   = 600
# Number of time steps kept per sector after rescale_morlet_plz()
SECTOR_LENGTH_STEPS = 600
# Wavelet frequencies are 1..MAX_MORLET_FREQ (one per integer step)
MAX_MORLET_FREQ = 30
# Number of frequency rows expected in the stored morlet data
MORLET_FREQ_STEPS = 30
# NOTE: despite the names, these two values are passed to mne.filter.filter_data
# as the 3rd and 4th positional arguments in butterworth_filter_pass(), i.e.
# presumably as the lower and upper edge of the pass band - TODO confirm
LOW_PASS_FREQ   = 3
HIGH_PASS_FREQ  = 30
# Not referenced in the visible part of the notebook
MAX_SAMPLE_LENGTH = 1.5

# Phonemes are enumerated in range 2, 3, 4, 5, 6, 7, 8
MIN_PHONEME_ID = 2
PHONEME_COUNT  = 7

# Directories
# Root directory of the saved morlet .npy files and count.json
MORLET_ORIGINAL_SAVE_DIR = 'morlet-original'

# List of EDF files to use
# These files are taken from CLEARED_PATH/VISUAL_SUBPATH and CLEARED_PATH/AUDIAL_SUBPATH
# The position in this list is used as the person index
INPUT_EDF_LIST = [
	'Antonovazrf_och',
	'BazvlkDzrf_och',
	'DachaPapzrf_och',
	'Drachenkozrf_och',
	'Gordokovzrf_och',
	'Manenkovzrf_och',
	'pavluhinNzrf_och',
	'rylkovSzrf_och',
	'Sazanovazrf_och',
	'vinickiDzrf_och',
]


# Flags
# CONVERT_EDF is not read in the visible part of the notebook
CONVERT_EDF = False
# Enables the "Perform wavelet pass over all data" cell
CONVERT_MORLETS = False
# Enables the sample usage demo cells
ENABLE_DEMO = False

# > Convert all files from custom EDF to normal EDF

**WARNING: Remove headers for each EDF file via EDFBrowser**

In [None]:
def _convert_edf_directory(subpath):
	"""
	Convert every file of DATA_PATH/subpath and move the result to OUT_PATH/subpath.
	"""

	source_dir = f'{DATA_PATH}/{subpath}'
	target_dir = f'{OUT_PATH}/{subpath}'

	for name in os.listdir(source_dir):
		# The converter output is expected next to the input file, named after
		# the input without its last 4 characters (the '.edf' extension)
		processed_file = f'{source_dir}/{name[:-4]}_.EDF'
		target_file    = f'{target_dir}/{name}'
		file           = f'{source_dir}/{name}'

		print('Processing', file)
		# Best effort: drop a stale converter output left by a previous run
		try:
			os.remove(processed_file)
		except OSError:
			pass
		# Argument list instead of a command string, so paths with spaces survive
		subprocess.run(['Mitr_Edf.exe', file])

		print('Moving', processed_file, 'to', target_file)
		try:
			# os.replace overwrites an existing target on every platform
			os.replace(processed_file, target_file)
		except OSError:
			print('Critical failture')
			traceback.print_exc()
			break

def convert_edf():
	"""
	Convert EDF files using Mitr_Edf.exe utility.

	Every file found in DATA_PATH/VISUAL_SUBPATH and DATA_PATH/AUDIAL_SUBPATH
	is passed to the converter and the produced file is moved to the matching
	subdirectory of OUT_PATH under the original file name.

	If a converted file cannot be moved, the error is printed and the rest of
	that subdirectory is skipped; the other subdirectory is still processed.
	"""

	os.makedirs(f'{OUT_PATH}/{VISUAL_SUBPATH}', exist_ok=True)
	os.makedirs(f'{OUT_PATH}/{AUDIAL_SUBPATH}', exist_ok=True)

	for subpath in (VISUAL_SUBPATH, AUDIAL_SUBPATH):
		_convert_edf_directory(subpath)

# > Basic functions for data processing & train

## Read EDF files, extract, split into sectors 

In [None]:
def extract_sectors(edf): # Returns sectors[begin,end] and missing_labels
	"""
	Extract valid and invalid sectors from edf data file

	Labels are read from the 'METKA' channel. A label with tens digit 1 opens
	a phoneme, a label with tens digit 2 closes it; both must carry the same
	units digit (phoneme ID) to form a sector.

	Returns (sectors, missing_labels):

	sectors - list of (begin_index, end_index) sample indices of matched pairs

	missing_labels - sample indices of labels that could not be paired
	"""
	sectors = []
	missing_labels = []
	# Phoneme ID and sample index of the begin label waiting for its end label
	open_phoneme = None
	open_index = None

	Y = edf['METKA'][0].T[:,0]

	for index, value in enumerate(Y):
		if not value > 0:
			continue

		kind, phoneme = divmod(int(value), 10)

		# Phoneme begin
		if kind == 1:
			if open_phoneme is not None:
				# Previous begin label never got its end label
				missing_labels.append(open_index)

			open_phoneme = phoneme
			open_index = index

		# Phoneme end
		elif kind == 2:
			if open_phoneme is None:
				missing_labels.append(index)
			elif open_phoneme != phoneme:
				missing_labels.append(open_index)
				missing_labels.append(index)
				# The begin label has been reported as invalid: forget it, so
				# it is neither reported twice nor reused by a later end label
				open_phoneme = None
			else:
				sectors.append((open_index, index))
				open_phoneme = None

	return sectors, missing_labels

def extract_strict_sectors(edf, sector_length): # Returns sectors[begin,end] and labels
	"""
	Extract fixed-length sectors anchored at begin and end labels.
	Sample usage is extracting sectors of length 600 ms (1000Hz).

	A begin label (tens digit 1) yields the sector [index, index + sector_length),
	an end label (tens digit 2) yields [index - sector_length, index). A label
	that falls inside the previously accepted sector, or whose sector would
	leave the recording, is reported as invalid.

	Returns (sectors, invalid_labels, labels) where labels holds the units
	digit (phoneme ID) of the label of every accepted sector.
	"""

	marks = edf['METKA']
	timestamps = marks[1]
	values = marks[0].T[:,0]

	sectors, labels, invalid_labels = [], [], []
	# Exclusive end of the last accepted sector
	busy_until = None

	for position, raw in enumerate(values):
		if not raw > 0:
			continue

		# Label lies inside the previously accepted sector
		if busy_until is not None and position < busy_until:
			invalid_labels.append(position)
			continue

		kind, phoneme = divmod(int(raw), 10)

		if kind == 1:
			# Phoneme begin: sector starts at the label
			begin, end = position, position + sector_length
			fits = not (end > len(timestamps))
		elif kind == 2:
			# Phoneme end: sector finishes at the label
			begin, end = position - sector_length, position
			fits = not (begin < 0)
		else:
			# Other label kinds are ignored
			continue

		if not fits:
			invalid_labels.append(position)
			continue

		sectors.append((begin, end))
		labels.append(phoneme)
		busy_until = end

	return sectors, invalid_labels, labels

def extract_strict_sectors_with_offset(edf, sector_length, first_label_offset): # Returns sectors[begin,end] and labels
	"""
	Extract sectors of the given length using begin and end labels, shifting
	the sectors anchored at begin labels by first_label_offset samples.
	Sample usage is extracting sectors of length 600 ms (1000Hz).

	A begin label (tens digit 1) yields the sector
	[index + first_label_offset, index + first_label_offset + sector_length),
	an end label (tens digit 2) yields [index - sector_length, index). A label
	that falls inside the previously accepted sector, or whose sector would
	leave the recording, is reported as invalid.

	Returns (sectors, invalid_labels, labels) where labels holds the units
	digit (phoneme ID) of the label of every accepted sector.
	"""

	sectors = []
	labels = []
	invalid_labels = []
	last_sector_end_index = None

	METKA = edf['METKA']
	X = METKA[1]
	Y = METKA[0].T[:,0]

	for index, value in enumerate(Y):
		if not value > 0:
			continue

		# Label lies inside the previously accepted sector
		if last_sector_end_index is not None and index < last_sector_end_index:
			invalid_labels.append(index)
			continue

		kind, phoneme = divmod(int(value), 10)

		# Phoneme begin
		if kind == 1:
			begin = index + first_label_offset
			end = begin + sector_length

			# The shifted sector must stay inside the recording; the offset
			# has to be part of this check, otherwise the sector can overflow
			if begin < 0 or end > len(X):
				invalid_labels.append(index)
				continue

		# Phoneme end
		elif kind == 2:
			begin = index - sector_length
			end = index

			# Ignore underflow
			if begin < 0:
				invalid_labels.append(index)
				continue

		# Other label kinds are ignored
		else:
			continue

		sectors.append((begin, end))
		labels.append(phoneme)
		last_sector_end_index = end

	return sectors, invalid_labels, labels

def print_sectors_summary(edf, sectors, missing_labels):
	"""
	Print summary info about given set of sectors

	Prints the number of sectors and invalid labels and, when at least one
	sector exists, the min / max / average sector length measured on the
	time axis of the 'METKA' channel.
	"""
	print('sectors:', len(sectors))
	print('invalid:', len(missing_labels))

	# np.min / np.max raise ValueError on an empty array
	if len(sectors) == 0:
		return

	X = edf['METKA'][1]
	last = len(X) - 1

	# The end index of a sector is exclusive and may equal len(X); clamp it to
	# the last sample so the timestamp lookup stays in range
	diff = np.array([ X[min(b, last)] - X[a] for (a, b) in sectors ])

	print('min sector length:', np.min(diff))
	print('max sector length:', np.max(diff))
	print('avg sector length:', np.average(diff))

def plot_labels(edf, missing_labels):
	"""
	Plot distribution of valid and invalid labels

	Every non-zero sample of the 'METKA' channel is drawn at (time, value):
	red crosses for sample indices listed in missing_labels, blue dots for
	all other labels.
	"""

	METKA = edf['METKA']
	X = np.asarray(METKA[1])
	Y = np.asarray(METKA[0].T[:,0])

	plt.rcParams["figure.figsize"] = (25, 5)
	plt.rcParams["font.size"] = 14

	# Boolean masks replace the per-sample membership test and the
	# one-scatter-call-per-point loop
	is_label = Y > 0
	is_invalid = np.isin(np.arange(len(X)), list(missing_labels))

	bad = is_label & is_invalid
	good = is_label & ~is_invalid

	plt.scatter(X[bad], Y[bad], color='red', marker='x')
	plt.scatter(X[good], Y[good], color='blue', marker='.')

	plt.show()

def plot_sectors(sectors, missing_labels):
	"""
	Plot sectors on a single line, including invalid labels

	Every sector is drawn as a blue segment between its begin and end index,
	every invalid label as a red tick; only the first item of each kind
	contributes a legend entry.
	"""

	plt.rcParams["figure.figsize"] = (25, 2.5)
	plt.rcParams["font.size"] = 14

	# Valid sectors
	for position, span in enumerate(sectors):
		legend_name = None if position else 'sectors'
		plt.plot(span, (0, 0), color='blue', marker='|', label=legend_name)

	# Invalid labels
	for position, bad_index in enumerate(missing_labels):
		legend_name = None if position else 'invalid sectors'
		plt.scatter(bad_index, 0, color='red', marker='|', label=legend_name)

	plt.legend()
	plt.show()

def list_visual_edf():
	"""
	Return the names of the entries in the converted visual EDF directory
	"""

	visual_dir = f'{OUT_PATH}/{VISUAL_SUBPATH}'
	return os.listdir(visual_dir)

def open_visual_edf(filename):
	"""
	Open a file of the converted visual EDF directory with mne and return
	the resulting object
	"""

	return mne.io.read_raw_edf(f'{OUT_PATH}/{VISUAL_SUBPATH}/{filename}')

def list_audial_edf():
	"""
	Return the names of the entries in the converted audial EDF directory
	"""

	audial_dir = f'{OUT_PATH}/{AUDIAL_SUBPATH}'
	return os.listdir(audial_dir)

def open_audial_edf(filename):
	"""
	Open a file of the converted audial EDF directory with mne and return
	the resulting object
	"""

	return mne.io.read_raw_edf(f'{OUT_PATH}/{AUDIAL_SUBPATH}/{filename}')

## Channel extraction and wavelet pass

In [None]:
def subselect_channels(edf):
	"""
	Pick the TARGET_CHANNELS data channels out of the given EDF object.

	For every target channel the alternative names listed in
	TARGET_CHANNEL_SETS are tried in order and the first one that can be read
	is used.

	Returns a list with one entry per target channel (edf[name][0][0]).

	Raises RuntimeError if none of the alternative names of a channel can be
	read.
	"""
	print(f'Available channels: {edf.ch_names}')

	channels = [ None ] * TARGET_CHANNELS
	for i in range(TARGET_CHANNELS):

		# Iterate over all compatible channel names and take the first readable one
		for compatible_name in TARGET_CHANNEL_SETS[i]:
			try:
				channels[i] = edf[compatible_name][0][0]
				break
			except Exception:
				# Name not available in this recording, try the next one.
				# KeyboardInterrupt / SystemExit are deliberately not caught.
				continue

		if channels[i] is None:
			raise RuntimeError(f'No compatible channels found for channels {TARGET_CHANNEL_SETS[i]}')

	return channels

def butterworth_filter_pass(edf, channels_data):
	"""
	Filter every channel with mne.filter.filter_data (method='iir').

	edf - unused, kept for interface compatibility

	channels_data - iterable of per-channel sample arrays

	SOURCE_FREQ is passed as the sampling frequency, LOW_PASS_FREQ and
	HIGH_PASS_FREQ as the two cut-off frequencies (in this order).

	Returns a list with one filtered array per input channel, so any number of
	channels is supported (the result is no longer tied to TARGET_CHANNELS).
	"""

	return [
		mne.filter.filter_data(cd, SOURCE_FREQ, LOW_PASS_FREQ, HIGH_PASS_FREQ, method='iir')
		for cd in channels_data
	]

def split_sectors(edf, channels_data, sectors):
	"""
	Performs slicing of the given channels using sector info data.

	edf - object providing the 'METKA' channel (labels and time axis)

	channels_data - list of per-channel sample arrays

	sectors - list of (begin, end) sample indices, end is exclusive

	Returns (labels, lengths, durations, splitted):

	labels - units digit of the 'METKA' value at the begin of each sector
	(only meaningful for sectors that start exactly at a label)

	lengths - sector lengths in samples

	durations - sector lengths on the time axis

	splitted - splitted[channel][sector] is channels_data[channel][begin:end]
	"""

	METKA = edf['METKA']
	X = METKA[1]
	Y = METKA[0][0] # .T[:,0]
	last = len(X) - 1

	labels    = []
	lengths   = []
	durations = []
	splitted  = [ [] for _ in range(len(channels_data)) ]

	for (a, b) in sectors:
		labels.append(int(Y[a]) % 10)
		lengths.append(b - a)
		# The end index is exclusive and may equal len(X); clamp it to the
		# last sample so the timestamp lookup stays in range
		durations.append(X[min(b, last)] - X[a])

		for channel_parts, samples in zip(splitted, channels_data):
			channel_parts.append(samples[a:b])

	return labels, lengths, durations, splitted

def single_morlet_wavelet_pass(sample, w = 6.):
	"""
	Apply the morlet wavelet transform to a single sample.

	Returns (t, freq, transformed): the time axis, the frequency axis and the
	result of scipy.signal.cwt, all subsampled with the strides
	SECTOR_LENGTH // SECTOR_LENGTH_STEPS (time) and
	MAX_MORLET_FREQ // MORLET_FREQ_STEPS (frequency).
	"""

	t, dt = np.linspace(0, SECTOR_LENGTH / SOURCE_FREQ, SECTOR_LENGTH, retstep=True)
	freq = np.linspace(1, MAX_MORLET_FREQ, MAX_MORLET_FREQ)
	# NOTE(review): linspace includes the end point, so fs is slightly below
	# SOURCE_FREQ - confirm this is intended
	fs = 1 / dt
	widths = w * fs / (2 * freq * np.pi)

	time_stride = SECTOR_LENGTH // SECTOR_LENGTH_STEPS
	freq_stride = MAX_MORLET_FREQ // MORLET_FREQ_STEPS

	t_steps = t[::time_stride]
	freq_steps = freq[::freq_stride]
	transformed = scipy.signal.cwt(sample, scipy.signal.morlet2, widths, w=w)

	return t_steps, freq_steps, transformed[::freq_stride, ::time_stride]

def rescale_morlet_plz(sample):
	"""
	Rescale from shape (MAX_MORLET_FREQ, SECTOR_LENGTH) to shape
	(MORLET_FREQ_STEPS, SECTOR_LENGTH_STEPS)

	The time axis is reduced by averaging consecutive groups of
	SECTOR_LENGTH // SECTOR_LENGTH_STEPS samples. Reducing the frequency axis
	is not implemented: RuntimeError is raised when it would be required. If
	no time reduction is needed the sample is returned untouched.
	"""

	group_size = SECTOR_LENGTH // SECTOR_LENGTH_STEPS

	# Nothing to do on the time axis
	if SECTOR_LENGTH == SECTOR_LENGTH_STEPS:
		return sample

	grouped = np.reshape(sample, (MAX_MORLET_FREQ, SECTOR_LENGTH_STEPS, group_size))
	averaged = grouped.mean(axis=2)

	if MAX_MORLET_FREQ != MORLET_FREQ_STEPS:
		raise RuntimeError('Incomplete code')

	return averaged

def morlet_wavelet_pass(channel_splitted_data, w = 6.):
	"""
	Performs wavelet transform over the given data. Returns 2D matrixes
	representing morlet transform application result for each of
	TARGET_CHANNELS channels for each of N samples.

	channel_splitted_data[channel][index] is the index-th sample of the
	channel (the layout produced by split_sectors).

	Returns (t, freq, morlet):

	t - time axis subsampled with stride SECTOR_LENGTH // SECTOR_LENGTH_STEPS

	freq - frequency axis subsampled with stride MAX_MORLET_FREQ // MORLET_FREQ_STEPS

	morlet - morlet[channel][index] is the transform of the sample passed
	through rescale_morlet_plz
	"""

	t, dt = np.linspace(0, SECTOR_LENGTH / SOURCE_FREQ, SECTOR_LENGTH, retstep=True)
	freq = np.linspace(1, MAX_MORLET_FREQ, MAX_MORLET_FREQ)
	fs = 1 / dt
	widths = w * fs / (2 * freq * np.pi)

	# Subsampling strides of the returned frequency and time axes
	FW = (MAX_MORLET_FREQ // MORLET_FREQ_STEPS)
	FH = (SECTOR_LENGTH // SECTOR_LENGTH_STEPS)

	# NOTE(review): scipy.signal.cwt / morlet2 are deprecated in recent SciPy
	# releases - confirm the pinned SciPy version still provides them
	return t[::FH], freq[::FW], [
		[
			rescale_morlet_plz(scipy.signal.cwt(sample, scipy.signal.morlet2, widths, w=w))
		for sample in channel_splitted_data[channel]
		]
	# Channel count comes from the configuration instead of a hard-coded 4
	for channel in range(TARGET_CHANNELS)
	]

def transpose_morlet_channel_data(morlet_channel_data):
	"""
	Swap the two outer levels of the channel data, turning

	morlet_channel_data[channel][index]

	into

	result[index][channel]

	The number of samples is taken from channel 0, the number of channels is
	TARGET_CHANNELS.
	"""

	sample_count = len(morlet_channel_data[0])

	transposed = []
	for index in range(sample_count):
		per_channel = []
		for channel in range(TARGET_CHANNELS):
			per_channel.append(morlet_channel_data[channel][index])
		transposed.append(per_channel)

	return transposed

def abs_morlet_data(morlet):
	"""
	Return the element-wise absolute value of the given data as numpy array
	"""

	return np.absolute(morlet)


## Data save & load

In [None]:
def save_morlet(filename, morlet):
	"""
	Store the given morlet data in the file using numpy.save
	"""

	np.save(file=filename, arr=morlet)

def read_morlet(filename):
	"""
	Load morlet data from the file using numpy.load

	Returns None when the file does not exist.
	"""

	if not os.path.exists(filename):
		return None

	return np.load(filename)
	
def auto_save_morlet(directory, person, phoneme, channel, sample, phoneme_data):
	"""
	Save morlet data of one phoneme sample under an automatically built name:

	directory/morlet_{person}_{phoneme}_{channel}_{sample}.npy

	directory - directory to place the file in (created when missing)

	person - index of person / edf file

	phoneme - index of phoneme

	channel - index of used channel

	sample - index of this phoneme's sample

	phoneme_data - phoneme morlet data
	"""

	os.makedirs(directory, exist_ok=True)

	target = f'{directory}/morlet_{person}_{phoneme}_{channel}_{sample}.npy'
	save_morlet(target, phoneme_data)

def auto_load_morlet(directory, person, phoneme, channel, sample):
	"""
	Load morlet data from the file name that auto_save_morlet() builds for
	the same directory, person, phoneme, channel and sample
	"""

	source = f'{directory}/morlet_{person}_{phoneme}_{channel}_{sample}.npy'
	return read_morlet(source)

def get_morlet_count(directory, person, phoneme):
	"""
	Get count of morlet samples for given person and phoneme ID

	The value is read from directory/count.json. Returns 0 when the file is
	missing or unreadable, is not valid JSON, or holds no entry for the given
	person / phoneme.
	"""

	try:
		with open(f'{directory}/count.json', 'r') as f:
			data = json.load(f)

		return data[f'person_{person}'][f'phoneme_{phoneme}']
	except (OSError, ValueError, KeyError, TypeError):
		# OSError - file missing / unreadable, ValueError - broken JSON,
		# KeyError / TypeError - no such entry or unexpected structure
		return 0

def set_morlet_count(directory, person, phoneme, count):
	"""
	Set count of morlet samples for given person and phoneme ID

	The value is stored in directory/count.json (the directory is created
	when missing). Existing entries of the file are kept; a missing,
	unreadable or invalid file is replaced by a fresh one.
	"""

	os.makedirs(directory, exist_ok=True)

	try:
		with open(f'{directory}/count.json', 'r') as f:
			data = json.load(f)
	except (OSError, ValueError):
		# OSError - file missing / unreadable, ValueError - broken JSON
		data = {}

	if f'person_{person}' not in data:
		data[f'person_{person}'] = {}

	data[f'person_{person}'][f'phoneme_{phoneme}'] = count

	with open(f'{directory}/count.json', 'w') as f:
		json.dump(data, f)

def get_total_morlet_count(directory, person):
	"""
	Get total count of morlet samples for given person ID

	The value is read from directory/count.json. Returns 0 when the file is
	missing or unreadable, is not valid JSON, or holds no total for the given
	person.
	"""

	try:
		with open(f'{directory}/count.json', 'r') as f:
			data = json.load(f)

		return data[f'person_{person}']['total']
	except (OSError, ValueError, KeyError, TypeError):
		# OSError - file missing / unreadable, ValueError - broken JSON,
		# KeyError / TypeError - no such entry or unexpected structure
		return 0

def set_total_morlet_count(directory, person, count):
	"""
	Set total count of morlet samples for given person ID

	The value is stored in directory/count.json (the directory is created
	when missing). Existing entries of the file are kept; a missing,
	unreadable or invalid file is replaced by a fresh one.
	"""

	os.makedirs(directory, exist_ok=True)

	try:
		with open(f'{directory}/count.json', 'r') as f:
			data = json.load(f)
	except (OSError, ValueError):
		# OSError - file missing / unreadable, ValueError - broken JSON
		data = {}

	if f'person_{person}' not in data:
		data[f'person_{person}'] = {}

	data[f'person_{person}']['total'] = count

	with open(f'{directory}/count.json', 'w') as f:
		json.dump(data, f)

def normalize_labels(labels):
	"""
	Shift phoneme IDs so that they start from zero.

	Source phoneme IDs start from MIN_PHONEME_ID; the result is a new list
	holding every ID reduced by MIN_PHONEME_ID.
	"""

	normalized = []
	for phoneme_id in labels:
		normalized.append(phoneme_id - MIN_PHONEME_ID)

	return normalized

def group_morlet_by_phoneme(normalized_morlet_labels, morlet_list):
	"""
	Group morlet data by phoneme ID

	normalized_morlet_labels - zero-based phoneme ID of every sample

	morlet_list - morlet data of every sample, in the same order

	Returns a list of PHONEME_COUNT lists, result[phoneme] holds all samples
	of that phoneme.

	Raises IndexError for a label outside 0..PHONEME_COUNT-1. Negative labels
	are rejected explicitly: list indexing would otherwise wrap around and
	silently file the sample under a different phoneme.
	"""

	result = [ [] for _ in range(PHONEME_COUNT) ]
	for label, morlet in zip(normalized_morlet_labels, morlet_list):
		if not 0 <= label < PHONEME_COUNT:
			raise IndexError(f'Phoneme label {label} is out of range 0..{PHONEME_COUNT - 1}')
		result[label].append(morlet)
	return result

def ungroup_morlet_by_phoneme(grouped_morlet_list):
	"""
	Reverse of group_morlet_by_phoneme: concatenate the groups of all
	PHONEME_COUNT phonemes in order.

	Returns (labels, result) where labels[i] is the phoneme index of result[i].
	"""

	labels = []
	result = []
	for phoneme in range(PHONEME_COUNT):
		samples = grouped_morlet_list[phoneme]
		result = result + samples
		labels = labels + [ phoneme ] * len(samples)

	return labels, result

def save_person_grouped_morlet_list(directory, person, grouped_morlet_list):
	"""
	Store grouped morlet data of one person and update the sample counters.

	grouped_morlet_list[phoneme][index][channel] is the morlet data of one
	channel of one sample; grouped_morlet_list[0] contains all samples for
	phoneme 0, etc.

	The total sample count and the count of every phoneme are written to
	count.json, every channel of every sample goes to its own .npy file.
	"""

	os.makedirs(directory, exist_ok=True)

	# Total number of samples over all phonemes
	total = sum([ len(gml) for gml in grouped_morlet_list ])
	set_total_morlet_count(directory, person, total)

	for phoneme in range(PHONEME_COUNT):
		samples = grouped_morlet_list[phoneme]

		# Number of samples of this phoneme
		set_morlet_count(directory, person, phoneme, len(samples))

		# One file per sample and channel
		for index, sample in enumerate(samples):
			for channel in range(TARGET_CHANNELS):
				auto_save_morlet(directory, person, phoneme, channel, index, sample[channel])

def load_person_grouped_morlet_list(directory, person):
	"""
	Load grouped morlet data of the given person.

	Returns result[phoneme][index][channel], where the number of samples per
	phoneme comes from get_morlet_count() and every entry is loaded with
	auto_load_morlet().
	"""

	grouped = []
	for phoneme in range(PHONEME_COUNT):
		sample_count = get_morlet_count(directory, person, phoneme)

		samples = []
		for index in range(sample_count):
			channels = []
			for channel in range(TARGET_CHANNELS):
				channels.append(auto_load_morlet(directory, person, phoneme, channel, index))
			samples.append(channels)

		grouped.append(samples)

	return grouped


## Dataset

In [None]:
import torch.utils.data 
import multiprocessing

class MorletDataset(torch.utils.data.Dataset):
	"""
	Dataset of (morlet, label) pairs with optional transforms.

	labels / morlets - indexable containers of the same length

	transform - optional callable applied to the morlet of a fetched item

	target_transform - optional callable applied to the label of a fetched item
	"""

	def __init__(self, labels, morlets, transform=None, target_transform=None):
		self.labels, self.morlets = labels, morlets
		self.transform, self.target_transform = transform, target_transform

	def __len__(self):
		# The dataset size is defined by the labels container
		return len(self.labels)

	def __getitem__(self, idx):
		morlet = self.morlets[idx]
		label = self.labels[idx]

		morlet = self.transform(morlet) if self.transform else morlet
		label = self.target_transform(label) if self.target_transform else label

		return morlet, label


## Transform

In [None]:
import random
import cv2

class NoiseTransform(object):
	"""
	Add gaussian noise with zero mean to the sample

	noise_scale - standard deviation of the noise
	"""

	def __init__(self, noise_scale):
		self.noise_scale = noise_scale

	def __call__(self, sample):
		# One independent noise value per element of the sample
		noise = np.random.normal(0, self.noise_scale, sample.shape)
		return sample + noise

# class ShiftTransform(object):
# 	"""
# 	Randomly shift morlet spectrogram
# 	"""

# 	def __call__(self, sample):
# 		morlet, label = sample
# 		return np.roll(morlet, random.randint(0, SECTOR_LENGTH_STEPS)), label

class ResizeShiftTransform(object):
	"""
	Randomly stretch the morlet along its last axis, shift it and crop it
	back to SECTOR_LENGTH_STEPS steps.

	max_scale - upper bound of the stretch factor; the stretch adds up to
	SECTOR_LENGTH_STEPS * (max_scale - 1) steps

	max_roll - upper bound of the shift, as a fraction of the added steps
	"""

	def __init__(self, max_scale, max_roll):
		self.max_scale = max_scale
		self.max_roll = max_roll
		self.max_pad = int(SECTOR_LENGTH_STEPS * (max_scale - 1))

	def __call__(self, sample):
		# Move axis 0 to the end for cv2.resize
		# (assumes sample is (channel, frequency, time) - TODO confirm)
		image = sample.transpose(1, 2, 0)
		height, width = image.shape[0], image.shape[1]

		# Stretch the width by a random amount of extra steps
		extra_width = random.randint(0, self.max_pad)
		stretched = cv2.resize(image, dsize=(width + extra_width, height), interpolation=cv2.INTER_CUBIC)
		stretched = stretched.transpose(2, 0, 1)

		# Shift towards the beginning by a random part of the added steps
		shift = random.randint(0, int(extra_width * self.max_roll))
		shifted = np.roll(stretched, -shift, 2)

		# Cut back to the original amount of steps
		return shifted[:, :, 0:SECTOR_LENGTH_STEPS]

class FlipAlongTime(object):
	"""
	Reverse the sample along axis 2 with probability 1/2
	"""

	def __call__(self, sample):
		if random.randint(0, 1):
			return np.flip(sample, axis=2)

		return sample

class ToTensor(object):
	"""
	Convert ndarray to tensor

	The tensor is built from a copy of the array, so it does not share memory
	with the given sample.
	"""

	def __call__(self, sample):
		detached = sample.copy()

		return torch.from_numpy(detached)


## Network model class

In [None]:
class MNetwork(torch.nn.Module):
	"""
	Base class for morlet classification network.

	The network is a stack of conv -> activation -> max-pool -> dropout stages
	followed by linear -> activation -> dropout stages and a final linear
	layer with num_classes outputs (no activation is applied to it).

	Input is expected as
	(batch, TARGET_CHANNELS, MORLET_FREQ_STEPS, SECTOR_LENGTH_STEPS).

	conv_layers - list of dicts defining each convolutional layer:
		'out' - number of output channels
		'kernel' - kernel size (height, width)
		'pool' - max-pool size (height, width)
		'dropout' - dropout probability (optional, defaults to 0.0)

	dense_layers - list of dicts defining each hidden fully connected layer:
		'count' - number of neurons
		'dropout' - dropout probability (optional, defaults to 0.0)

	num_classes - number of outputs of the last layer

	use_conv_sigmoid / use_dense_sigmoid - use sigmoid instead of relu after
	conv / dense layers

	print_log - print the dimension of the data passed from conv to dense part
	"""
	def __init__(self, conv_layers, dense_layers, num_classes, use_conv_sigmoid=False, use_dense_sigmoid=False, print_log=True):
		super().__init__()

		# Conv layers
		self.conv = []
		# MaxPool layers
		self.pool = []
		# Dropout for layers
		self.conv_dropout = []
		# Fully connected layers
		self.fc = []
		# Dropout for layers
		self.fc_dropout = []

		# Use sigmoid or relu
		self.use_conv_sigmoid = use_conv_sigmoid
		self.use_dense_sigmoid = use_dense_sigmoid

		# Save useful parameters
		self.conv_layers = conv_layers
		self.dense_layers = dense_layers
		self.num_classes = num_classes

		# Add all conv layers
		# Input of the first conv layer is TARGET_CHANNELS x MORLET_FREQ_STEPS x SECTOR_LENGTH_STEPS
		# The plain lists above are not tracked by torch, so every layer is
		# additionally registered as an attribute (conv0, pool0, ...)
		last_conv_out = TARGET_CHANNELS
		result_dimension = [ TARGET_CHANNELS, MORLET_FREQ_STEPS, SECTOR_LENGTH_STEPS ]
		for i, layer in enumerate(conv_layers):
			self.conv.append(torch.nn.Conv2d(last_conv_out, layer['out'], layer['kernel']))
			setattr(self, f'conv{i}', self.conv[-1])
			last_conv_out = layer['out']

			self.pool.append(torch.nn.MaxPool2d(layer['pool']))
			setattr(self, f'pool{i}', self.pool[-1])

			# 'dropout' is optional, 0.0 disables it
			self.conv_dropout.append(torch.nn.Dropout(layer.get('dropout', 0.0)))
			setattr(self, f'conv_dropout{i}', self.conv_dropout[-1])

			# Recalculate dimensions: convolution without padding, then pooling
			result_dimension[0] = layer['out']
			result_dimension[1] = result_dimension[1] - layer['kernel'][0] + 1
			result_dimension[2] = result_dimension[2] - layer['kernel'][1] + 1
			result_dimension[1] = result_dimension[1] // layer['pool'][0]
			result_dimension[2] = result_dimension[2] // layer['pool'][1]

		if print_log:
			print(f'conv-dense dimension:', result_dimension)

		# Add fully connected
		# dense_layers contain only layers sizes between input and output of dense net
		# Input of the dense has dimensions of the last conv layer
		# Outputs of the dense has dimensions of the num_classes
		last_fc_out = result_dimension[0] * result_dimension[1] * result_dimension[2]
		for i, layer in enumerate(dense_layers):
			self.fc.append(torch.nn.Linear(last_fc_out, layer['count']))
			setattr(self, f'fc{i}', self.fc[-1])
			last_fc_out = layer['count']

			self.fc_dropout.append(torch.nn.Dropout(layer.get('dropout', 0.0)))
			setattr(self, f'fc_dropout{i}', self.fc_dropout[-1])

		# Append last fc layer
		self.fc.append(torch.nn.Linear(last_fc_out, num_classes))
		# The output layer has to be registered too, otherwise it is missing
		# from parameters() and is neither trained by an optimizer nor moved
		# by .to(device)
		setattr(self, f'fc{len(dense_layers)}', self.fc[-1])

	def forward(self, x):
		"""
		Run the network and return the raw outputs of the last linear layer
		"""

		# Apply conv
		for conv, pool, drop in zip(self.conv, self.pool, self.conv_dropout):
			if self.use_conv_sigmoid:
				x = drop(pool(torch.sigmoid(conv(x))))
			else:
				x = drop(pool(torch.relu(conv(x))))

		# Flatten
		x = torch.flatten(x, 1)

		# Dense (all but the output layer)
		for fc, drop in zip(self.fc[:-1], self.fc_dropout):
			if self.use_dense_sigmoid:
				x = drop(torch.sigmoid(fc(x)))
			else:
				x = drop(torch.relu(fc(x)))

		x = self.fc[-1](x)

		return x

	def save_model(self, filename):
		"""
		Save the model as a set of files that share the prefix filename:

		{filename}-config.json - constructor arguments

		{filename}-conv{i}-{key}.tensor / {filename}-fc{i}-{key}.tensor - one
		file per state dict entry of every conv / fc layer

		The directory of the prefix is created when missing.
		"""

		target_dir = os.path.dirname(filename)
		if target_dir:
			os.makedirs(target_dir, exist_ok=True)

		config = {
			'conv_layers': self.conv_layers,
			'dense_layers': self.dense_layers,
			'num_classes': self.num_classes,
			'use_conv_sigmoid': self.use_conv_sigmoid,
			'use_dense_sigmoid': self.use_dense_sigmoid
		}

		# Save config
		with open(f'{filename}-config.json', 'w') as f:
			json.dump(config, f)

		# Save all conv parameters
		for i, conv in enumerate(self.conv):
			state = conv.state_dict()

			for key in state.keys():
				torch.save(state[key], f'{filename}-conv{i}-{key}.tensor')

		# Save all fc parameters
		for i, fc in enumerate(self.fc):
			state = fc.state_dict()

			for key in state.keys():
				torch.save(state[key], f'{filename}-fc{i}-{key}.tensor')

	@staticmethod
	def load_model(filename):
		"""
		Create a model from the files written by save_model() with the same
		prefix filename and return it.
		"""

		# Load config
		with open(f'{filename}-config.json', 'r') as f:
			config = json.load(f)

		# Create model
		model = MNetwork(config['conv_layers'], config['dense_layers'], config['num_classes'], config['use_conv_sigmoid'], config['use_dense_sigmoid'])

		# Load all conv parameters
		for i, conv in enumerate(model.conv):
			state = conv.state_dict()

			for key in state.keys():
				state[key] = torch.load(f'{filename}-conv{i}-{key}.tensor')

			conv.load_state_dict(state)

		# Load all fc parameters
		for i, fc in enumerate(model.fc):
			state = fc.state_dict()

			for key in state.keys():
				state[key] = torch.load(f'{filename}-fc{i}-{key}.tensor')

			fc.load_state_dict(state)

		return model


## Model evaluation, accuracy

In [None]:
def calculate_match_on_dataset(model: torch.nn.Module, test_dataset):
	"""
	Calculate the fraction of dataset items classified correctly.

	model - network returning raw outputs (logits)

	test_dataset - sized iterable of (morlet, label) pairs, one sample each

	With a single output the prediction is 1 when sigmoid(output) > 0.5,
	matching train_model(), which applies the sigmoid before the loss. With
	several outputs the prediction is the index of the largest one.

	The model is switched to eval mode. Returns 0.0 for an empty dataset.
	"""

	if len(test_dataset) == 0:
		return 0.0

	count = 0
	with torch.no_grad():
		model.eval()
		for morlet, label in test_dataset:
			# Add the batch dimension
			output = model(morlet[None, ...].float())
			if output.shape[-1] == 1:
				predicted = (torch.sigmoid(output) > 0.5).int()
			else:
				predicted = torch.argmax(output, 1)
			count += (predicted == label).sum().item()

	return count / len(test_dataset)

def calculate_fail_on_dataset(model: torch.nn.Module, test_dataset):
	"""
	Calculate the fraction of dataset items classified incorrectly.

	model - network returning raw outputs (logits)

	test_dataset - sized iterable of (morlet, label) pairs, one sample each

	With a single output the prediction is 1 when sigmoid(output) > 0.5,
	matching train_model(), which applies the sigmoid before the loss. With
	several outputs the prediction is the index of the largest one.

	The model is switched to eval mode. Returns 0.0 for an empty dataset.
	"""

	if len(test_dataset) == 0:
		return 0.0

	count = 0
	with torch.no_grad():
		model.eval()
		for morlet, label in test_dataset:
			# Add the batch dimension
			output = model(morlet[None, ...].float())
			if output.shape[-1] == 1:
				predicted = (torch.sigmoid(output) > 0.5).int()
			else:
				predicted = torch.argmax(output, 1)
			count += (predicted != label).sum().item()

	return count / len(test_dataset)


## Model training

In [None]:
def train_model(model: MNetwork, train_loader, optimizer, scheduler, criterion, train_epochs, print_log=False, print_iters=100, print_first_iter=True, print_last_iter=True):
	"""
	Train the model for the given number of epochs.

	Every batch of train_loader is a (morlets, labels) pair. With a single
	model output the sigmoid of the flattened output and the labels as floats
	are passed to the criterion; otherwise the raw outputs and the labels as
	longs are passed.

	scheduler - optional, stepped once per epoch

	print_log - print the loss every print_iters iterations; the first
	iteration of an epoch is printed only when print_first_iter is set, the
	last one is additionally printed when print_last_iter is set
	"""

	model.train()
	for epoch in range(train_epochs):
		for step, (morlets, labels) in enumerate(train_loader):
			optimizer.zero_grad()
			outputs = model(morlets.float())

			single_output = outputs.shape[-1] == 1
			if single_output:
				outputs = torch.sigmoid(outputs.flatten())
				loss = criterion(outputs, labels.float())
			else:
				loss = criterion(outputs, labels.long())

			loss.backward()
			optimizer.step()

			if print_log:
				if step % print_iters == 0:
					# Regular print step; step 0 only on request
					should_print = step > 0 or print_first_iter
				else:
					# Otherwise only the last step of the epoch, on request
					should_print = print_last_iter and step == len(train_loader) - 1

				if should_print:
					print(f'epoch = {epoch}, iter = {step}, loss = {loss.item()}')

		if scheduler is not None:
			scheduler.step()


## Dataset utilities for training

In [None]:
def load_dataset(visual = None, person = None, sector_length_steps = None, morlet_freq_steps = None):
	"""
	Load saved morlet data with the given options.

	visual - True for visual data, False for audial data, None for both

	person - person index, None for all persons of INPUT_EDF_LIST

	sector_length_steps / morlet_freq_steps - resolution of the stored data,
	None selects SECTOR_LENGTH_STEPS / MORLET_FREQ_STEPS

	Returns (labels, morlets) as numpy arrays in ungrouped form.
	"""

	persons = list(range(len(INPUT_EDF_LIST))) if person is None else [ person ]
	modes = [ False, True ] if visual is None else [ visual ]

	width = SECTOR_LENGTH_STEPS if sector_length_steps is None else sector_length_steps
	height = MORLET_FREQ_STEPS if morlet_freq_steps is None else morlet_freq_steps

	all_labels, all_morlets = [], []

	for use_visual in modes:
		# Subdirectory matching either audial or visual data
		edf_subdir = VISUAL_SUBPATH if use_visual else AUDIAL_SUBPATH
		# Full directory path to morlet files
		morlet_dir = f'{MORLET_ORIGINAL_SAVE_DIR}/width-{width}_height-{height}/{edf_subdir}'

		for person_id in persons:
			grouped = load_person_grouped_morlet_list(morlet_dir, person=person_id)
			person_labels, person_morlets = ungroup_morlet_by_phoneme(grouped)

			all_labels += person_labels
			all_morlets += person_morlets

	return np.array(all_labels), np.array(all_morlets)

def select_phonemes(labels, morlets, phoneme_pair = None):
	"""
	Keep only the samples of the given phonemes and relabel them for binary
	classification.

	Samples labelled phoneme_pair[0] get label 0, all other kept samples get
	label 1. With phoneme_pair None the inputs are returned unchanged. The
	given arrays are never modified.
	"""

	if phoneme_pair is None:
		return labels, morlets

	# Subselect required classes
	keep = np.isin(labels, phoneme_pair)
	kept_labels = labels[keep]
	kept_morlets = morlets[keep]

	# False -> 0 (first phoneme), True -> 1 (anything else), original dtype kept
	binary_labels = (kept_labels != phoneme_pair[0]).astype(kept_labels.dtype)

	return binary_labels, kept_morlets

def normalize_morlets(morlets):
	"""
	Divide the data by its value range (max - min); no offset is subtracted
	"""

	value_range = morlets.max() - morlets.min()
	return morlets / value_range

def train_test_split(labels, morlets, test_size):
	"""
	Randomly split the data into a train and a test part.

	labels / morlets - numpy arrays of the same length

	test_size - fraction of the samples that goes to the test part (0..1)

	Returns (train_labels, train_morlets, test_labels, test_morlets). The
	shuffle uses the global numpy random state.
	"""

	indices = np.arange(len(morlets))
	np.random.shuffle(indices)
	labels, morlets = labels[indices], morlets[indices]
	train_count = int(len(morlets) * (1.0 - test_size))

	train_labels, train_morlets = labels[:train_count], morlets[:train_count]
	# Everything after the train part, including the very last sample
	test_labels, test_morlets = labels[train_count:], morlets[train_count:]
	return train_labels, train_morlets, test_labels, test_morlets


# > Sample usage demos

## Load EDF

In [None]:
if ENABLE_DEMO:
	# Open the audial recording of the first person from INPUT_EDF_LIST
	edf = open_audial_edf(f'{INPUT_EDF_LIST[0]}.edf')

	# Fixed-length sectors anchored at the labels (strict extraction)
	sectors, invalid_sectors, labels = extract_strict_sectors(edf, SECTOR_LENGTH)
	print_sectors_summary(edf, sectors, invalid_sectors)


## Preview labels

In [None]:
if ENABLE_DEMO:
	# Uses edf and invalid_sectors produced by the "Load EDF" demo cell
	plot_labels(edf, invalid_sectors)

## Preview sectors

In [None]:
if ENABLE_DEMO:
	# Uses sectors and invalid_sectors produced by the "Load EDF" demo cell
	plot_sectors(sectors, invalid_sectors)

## Preprocess

In [None]:
if ENABLE_DEMO:
	# Pick the target channels and cut them into the extracted sectors
	channels_data = subselect_channels(edf)
	# The labels returned here are ignored: the ones from extract_strict_sectors are used
	_, lengths, durations, splitted = split_sectors(edf, channels_data, sectors)
	# Wavelet transform, result is morlet[channel][index]
	t, freq, morlet = morlet_wavelet_pass(splitted)
	# Reorder to morlet[index][channel] and keep absolute values only
	morlet = transpose_morlet_channel_data(morlet)
	morlet = abs_morlet_data(morlet)

## Preview channel samples

In [None]:
if ENABLE_DEMO:
	# Index of the sample to preview
	index = 0

	plt.rcParams["figure.figsize"] = (20, 10)
	plt.rcParams["font.size"] = 14
	# 2 x 2 grid: one plot per channel (4 channels are hard-coded here)
	fig, axs = plt.subplots(2, 2)

	for channel in range(4):
		value = np.abs(morlet[index][channel])
		print(t.shape, freq.shape, value.shape)
		
		# Channels fill the grid column by column
		axs[channel % 2][channel // 2].pcolormesh(t, freq, value, cmap='viridis', shading='gouraud')
	
	plt.show()

## Example data statistics

In [None]:
if ENABLE_DEMO:
	# Group the samples by zero-based phoneme ID
	normalized_labels = normalize_labels(labels)
	grouped_morlet_list = group_morlet_by_phoneme(normalized_labels, morlet)
	
	# Sizes of every nesting level: phoneme -> sample -> channel -> frequency -> tick
	# (the nested lookups assume phoneme 0 has at least one sample)
	print('phonemes:   ', 'len(grouped_morlet_list)', len(grouped_morlet_list))
	print('samples ph0:', 'len(grouped_morlet_list[0])', len(grouped_morlet_list[0]))
	print(' total:     ', sum([ len(s) for s in grouped_morlet_list ]))
	print('channels:   ', 'len(grouped_morlet_list[0][0])', len(grouped_morlet_list[0][0]))
	print('frequencies:', 'len(grouped_morlet_list[0][0][0])', len(grouped_morlet_list[0][0][0]))
	print('ticks       ', 'len(grouped_morlet_list[0][0][0][0])', len(grouped_morlet_list[0][0][0][0]))

## Example save & load

In [None]:
if ENABLE_DEMO:
	# Store the grouped demo data as person 0 directly in MORLET_ORIGINAL_SAVE_DIR
	save_person_grouped_morlet_list(MORLET_ORIGINAL_SAVE_DIR, 0, grouped_morlet_list)

In [None]:
if ENABLE_DEMO:
	# Read back the data stored for person 0 by the previous demo cell
	loaded_data = load_person_grouped_morlet_list(MORLET_ORIGINAL_SAVE_DIR, person=0)

In [None]:
if ENABLE_DEMO:
	# Sizes of every nesting level of the loaded data:
	# phoneme -> sample -> channel -> frequency -> tick
	print('phonemes:   ', 'len(grouped_morlet_list)', len(loaded_data))
	print('samples ph0:', 'len(grouped_morlet_list[0])', len(loaded_data[0]))
	print(' total:     ', sum([ len(s) for s in loaded_data ]))
	print('channels:   ', 'len(grouped_morlet_list[0][0])', len(loaded_data[0][0]))
	print('frequencies:', 'len(grouped_morlet_list[0][0][0])', len(loaded_data[0][0][0]))
	print('ticks       ', 'len(grouped_morlet_list[0][0][0][0])', len(loaded_data[0][0][0][0]))
	
	labels, ungrouped_loaded_data = ungroup_morlet_by_phoneme(loaded_data)
	ungrouped_loaded_data = np.asarray(ungrouped_loaded_data)
	# A bare expression inside an if block is not displayed by the notebook,
	# so the shape is printed explicitly
	print('ungrouped shape:', ungrouped_loaded_data.shape)

## Sample network init + save & load

In [None]:
if ENABLE_DEMO:
	# Single conv layer; MNetwork reads 'out', 'kernel', 'pool' and 'dropout'
	# of every conv layer, so 'dropout' has to be present (0.0 disables it)
	conv_layers = [
		{
			'out': 8,
			'kernel': (5, 10),
			'pool': (2, 2),
			'dropout': 0.0
		}
	]

	# No hidden dense layers: conv output goes straight to the output layer
	dense_layers = [
		
	]

	num_classes = 2

	model = MNetwork(conv_layers, dense_layers, num_classes)

In [None]:
if ENABLE_DEMO:
	# save_model() writes files with the prefix 'temp/model', the directory must exist
	os.makedirs('temp', exist_ok=True)
	model.save_model('temp/model')

In [None]:
if ENABLE_DEMO:
	# Rebuild the model from the files written by the previous demo cell
	model = MNetwork.load_model('temp/model')

# > Perform wavelet pass over all data

In [None]:
if CONVERT_MORLETS:
	# Audial data first, then visual data
	for visual in [ False, True ]:
		# Subdirectory matching either audial or visual data
		edf_subdir = VISUAL_SUBPATH if visual else AUDIAL_SUBPATH
		# Full directory path to input edf
		edf_dir = f'{CLEARED_PATH}/{edf_subdir}'
		# Full directory path to morlet files
		morlet_dir = f'{MORLET_ORIGINAL_SAVE_DIR}/width-{SECTOR_LENGTH_STEPS}_height-{MORLET_FREQ_STEPS}/{edf_subdir}'
		
		print()
		print(f'Preprocessing data in {edf_dir}')
		print(f'Wriging morlets to {morlet_dir}')
		
		# The position of a file in INPUT_EDF_LIST is used as the person index
		for person, edf_file in enumerate(INPUT_EDF_LIST):
			print()
			print(f'Processing {edf_file}')
			
			# Open
			if visual:
				edf = open_visual_edf(f'{edf_file}.edf')
			else:
				edf = open_audial_edf(f'{edf_file}.edf')
			
			# Select segments
			sectors, invalid_sectors, labels = extract_strict_sectors(edf, SECTOR_LENGTH)
			print_sectors_summary(edf, sectors, invalid_sectors)
			
			# Morlet transform
			print(f'Applying morlet transform for {edf_file}')
			channels_data = subselect_channels(edf)
			# Labels returned by split_sectors are ignored, the ones from extract_strict_sectors are used
			_, lengths, durations, splitted = split_sectors(edf, channels_data, sectors)
			t, freq, morlet = morlet_wavelet_pass(splitted)
			# Reorder to morlet[index][channel] and keep absolute values only
			morlet = transpose_morlet_channel_data(morlet)
			morlet = abs_morlet_data(morlet)
			
			# print(morlet[0].shape, freq.shape, t.shape)
			# break
			
			print('t.shape', t.shape)
			print('freq.shape', freq.shape)
			print('morlet.shape', morlet.shape)
			
			# Save data
			print(f'Saving morlet data for {edf_file}')
			# Group the samples by zero-based phoneme ID before saving
			normalized_labels = normalize_labels(labels)
			grouped_morlet_list = group_morlet_by_phoneme(normalized_labels, morlet)
			save_person_grouped_morlet_list(morlet_dir, person, grouped_morlet_list)


# > Train for different tasks with different configurations

## Lazy data preload for large datasets and multiperson classification

In [None]:
# Enable preloading the dataset of ALL persons (audial and visual) as one
PRELOAD_ALL_DATASET = False

if PRELOAD_ALL_DATASET:
	# visual=None and person=None make load_dataset() load every combination
	base_labels, base_morlets = load_dataset(None, None)


## Test running code

In [None]:
import time
import multiprocessing

def _print_label_stats(title, labels):
	"""Print `title` followed by one `label: count` line per distinct value in `labels`."""
	print(title)
	values, counts = np.unique(labels, return_counts=True)
	print('\n'.join([ f'{value}: {count}' for value, count in zip(values, counts) ]))


def _ensure_parent_dir(path):
	"""Create the directory that will contain `path`, when the path has a directory part."""
	parent = os.path.dirname(path)
	if parent:
		os.makedirs(parent, exist_ok=True)


def run_test():
	"""
	Train and evaluate models as described by the module-level `test_config`.

	For every selected person (or the combined data of all persons), every phoneme
	selection and every seed, the function selects and normalizes the morlet data,
	splits it into train and test parts, trains a fresh `MNetwork` and prints the
	match before and after training.

	Module-level names read:
		test_config            - dotdict with the run configuration
		SAVE_CONFIG            - dump `test_config` to `test_config.test_config_json`
		RELOAD_DATASET         - call `load_dataset` for every person; when disabled the
		                         existing global `base_labels` / `base_morlets` are used
		GLOBAL_LAST_CHECKPOINT - keep the last trained model in the global `checkpoint`
		SAVE_CHECKPOINTS       - save every trained model and append its match values
		                         to `test_config.test_json`

	Module-level names written: `base_labels`, `base_morlets`, `checkpoint`.
	"""
	global base_labels
	global base_morlets
	global checkpoint

	if SAVE_CONFIG:
		_ensure_parent_dir(test_config.test_config_json)
		with open(test_config.test_config_json, 'w') as f:
			json.dump(test_config, f)

	print('binary:', test_config.binary)

	# What data to use during train
	if test_config.all_person_data:
		# A single None entry is passed to load_dataset in place of a person ID
		persons_list = [ None ]
		print('Combine all person data')
	elif test_config.person is None:
		persons_list = list(range(len(INPUT_EDF_LIST)))
		print('Train on each person separately')
	else:
		persons_list = [ test_config.person ]
		print('Train on single person:', test_config.person)

	# What phoneme combinations to use
	if test_config.binary:
		if test_config.phoneme_classes is not None:
			# Binary classifier for given class numbers
			phoneme_list = [ test_config.phoneme_classes ]
			print('Train on single phoneme pair:', test_config.phoneme_classes)
		else:
			# Binary classifier for every unordered pair of classes, as (i, j) with i > j
			phoneme_list = [ (i, j) for i in range(PHONEME_COUNT) for j in range(i) ]
			print('Train on each phoneme pair')
	else:
		# Multiclass mode: a single pass, classes are taken from test_config.phoneme_classes
		phoneme_list = [ None ]

		if test_config.phoneme_classes is None:
			print('Train on all phoneme classes')
		else:
			print('Train on selected phoneme classes:', test_config.phoneme_classes)

	for person in persons_list:

		if RELOAD_DATASET:
			base_labels, base_morlets = load_dataset(test_config.visual, person)
			print('labels:', set(base_labels))
		else:
			print('-----------------------------')
			print('WARNING: Dataset not reloaded')
			print('-----------------------------')

		for phoneme_pair in phoneme_list:

			for seed_offset in range(test_config.seed_steps):
				seed = test_config.seed + seed_offset

				print('seed', seed)

				# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
				# L O A D   D A T A
				# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

				print()
				print()
				if test_config.binary:
					print('person =', person, 'phoneme1 =', phoneme_pair[0], 'phoneme2 =', phoneme_pair[1])
				else:
					print('person =', person, 'classes =', test_config.phoneme_classes)

				# Phoneme classes to select from data
				if test_config.binary:
					phoneme_classes = phoneme_pair
					# Binary mode uses a single network output (paired with BCELoss below)
					num_classes = 1
				elif test_config.phoneme_classes is None:
					phoneme_classes = None
					num_classes = PHONEME_COUNT
				else:
					phoneme_classes = test_config.phoneme_classes
					num_classes = len(test_config.phoneme_classes)

				# Select phonemes from the loaded dataset
				labels, morlets = select_phonemes(base_labels, base_morlets, phoneme_classes)
				morlets = normalize_morlets(morlets)

				_print_label_stats('labels stats:', labels)

				# Set seed before the split
				np.random.seed(seed)
				torch.manual_seed(seed)

				# Split
				train_labels, train_morlets, test_labels, test_morlets = train_test_split(labels, morlets, test_config.test_size)

				_print_label_stats('train labels stats:', train_labels)
				_print_label_stats('test labels stats:', test_labels)

				print('Train count:', len(train_labels))
				print('Test count: ', len(test_labels))

				# Transform: augmentations are applied to the train data only
				train_transform = torchvision.transforms.Compose([
					ResizeShiftTransform(test_config.shift_transform_scale, test_config.shift_transform_roll),
					NoiseTransform(test_config.noise_transform_scale),
					FlipAlongTime(),
					ToTensor()
				])

				test_transform = torchvision.transforms.Compose([
					ToTensor()
				])

				# Dataset
				train = MorletDataset(train_labels, train_morlets, train_transform)
				test = MorletDataset(test_labels, test_morlets, test_transform)

				# Overall dataset (train and test parts together, without augmentation)
				overall = MorletDataset(labels, morlets, test_transform)

				# DataLoader for train
				train_loader = torch.utils.data.DataLoader(train, batch_size=test_config.batch_size, shuffle=True)

				# Set seed again before the model is created and trained
				np.random.seed(seed)
				torch.manual_seed(seed)

				# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
				# T R A I N _ M O D E L
				# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

				print('num_classes', num_classes)

				# Create model
				model = MNetwork(
					test_config.conv_layers,
					test_config.dense_layers,
					num_classes,
					use_conv_sigmoid=test_config.use_conv_sigmoid,
					use_dense_sigmoid=test_config.use_dense_sigmoid
				)
				print('Network:', 'conv_layers', test_config.conv_layers, 'dense_layers', test_config.dense_layers, 'num_classes', num_classes, 'use_conv_sigmoid', test_config.use_conv_sigmoid, 'use_dense_sigmoid', test_config.use_dense_sigmoid)

				print(f'Train size:    {len(train)}')
				print(f'Test size:     {len(test)}')
				init_match = calculate_match_on_dataset(model, test) * 100
				print(f'Initial match: {round(init_match, 2)}%')
				init_overall_match = calculate_match_on_dataset(model, overall) * 100
				print(f'Initial overall match: {round(init_overall_match, 2)}%')

				# Training requirements
				if num_classes > 1:
					criterion = torch.nn.CrossEntropyLoss()
				else:
					criterion = torch.nn.BCELoss()

				optimizer = torch.optim.SGD(model.parameters(), lr=test_config.lr_start, momentum=0.9)
				scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=test_config.lr_step_size, gamma=0.1)

				# Train for test_config.epochs epochs
				train_model(model, train_loader, optimizer, scheduler, criterion, test_config.epochs, print_log=True, print_iters=100, print_last_iter=False)

				# Save global checkpoint
				if GLOBAL_LAST_CHECKPOINT:
					checkpoint = model

				# Evaluate on the test part
				result_match = calculate_match_on_dataset(model, test) * 100
				print(f'After train match: {round(result_match, 2)}%')

				# Evaluate on the overall dataset
				result_overall_match = calculate_match_on_dataset(model, overall) * 100
				print(f'Overall train match: {round(result_overall_match, 2)}%')

				# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
				# S A V E   M O D E L
				# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

				if SAVE_CHECKPOINTS:
					if test_config.binary:
						checkpoint_name = f'checkpoint-{test_config.checkpoint_prefix}-{test_config.visual}-{person}-binary-ph_{phoneme_pair[0]}-ph_{phoneme_pair[1]}-ep_{test_config.epochs}-seed_{seed}'
					else:
						# str.join accepts only strings, so class IDs are converted explicitly
						phoneme_tag = '_'.join(str(phoneme) for phoneme in (test_config.phoneme_classes or [ 'all' ]))
						checkpoint_name = f'checkpoint-{test_config.checkpoint_prefix}-{test_config.visual}-{person}-multiclass-phs_{phoneme_tag}-ep_{test_config.epochs}-seed_{seed}'

					os.makedirs('checkpoints', exist_ok=True)
					model.save_model(f'checkpoints/{checkpoint_name}')

					# Merge the new result into the results already stored for this run
					checkpoint_info = {}
					try:
						with open(test_config.test_json, 'r') as f:
							checkpoint_info = json.load(f)
					except FileNotFoundError:
						# Nothing stored yet: this is the first result of the run
						pass
					except (OSError, ValueError) as error:
						# json.JSONDecodeError is a ValueError subclass
						print(f'Could not read previous results from {test_config.test_json}: {error!r}')
						print('Starting from an empty results file')

					checkpoint_info[checkpoint_name] = {
						'init_match': init_match,
						'result_match': result_match,
						'init_overall_match': init_overall_match,
						'result_overall_match': result_overall_match
					}
					try:
						# Serialize before opening the file, so that a serialization
						# failure cannot truncate the results stored earlier
						serialized_info = json.dumps(checkpoint_info)
						_ensure_parent_dir(test_config.test_json)
						with open(test_config.test_json, 'w') as f:
							f.write(serialized_info)
					except Exception as error:
						# Best effort: training continues, the values are printed instead
						print('Saving error')
						print(repr(error))
						print(checkpoint_info)


## Binary single person

In [None]:
# Enable / disable config
BINARY_SINGLE = False

if BINARY_SINGLE:

	# Data selection: what data to use (visual, audial, all together),
	# and whether binary classification mode is enabled (True enables it)
	test_config = dotdict(visual=None, binary=True)

	if test_config.binary:

		# - - - B I N A R Y   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=None,
			# True combines data from all persons into a single dataset
			all_person_data=False,
			# What phonemes to use during train:
			# None trains on each pair combination, a pair trains on that pair only
			phoneme_classes=None,
		)

	else:

		# - - - M U L T I C L A S S   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=1,
			# True combines data from all persons into a single dataset
			all_person_data=True,
			# Phoneme classes to use during train, None uses all phonemes
			phoneme_classes=None,
		)

	# Train properties
	test_config.update(
		test_size=0.2,
		batch_size=4,
		epochs=50,
		use_conv_sigmoid=False,
		use_dense_sigmoid=False, # test_config.binary
	)

	# Train autoconfig
	test_config.update(lr_step_size=20, lr_start=0.01)

	# Transform properties
	test_config.update(
		shift_transform_scale=1.1,
		shift_transform_roll=1.0,
		noise_transform_scale=0.002,
	)

	# Seeding and workers: seed_steps consecutive seeds are used, starting from seed
	test_config.update(
		seed=43,
		seed_steps=1,
		num_workers=multiprocessing.cpu_count() // 4,
	)

	# Output properties: file names are derived from a millisecond timestamp
	timestamp = round(time.time() * 1000)
	test_config.checkpoint_prefix = f'autorun-{timestamp}'
	test_config.update(
		test_json=f'checkpoints/{test_config.checkpoint_prefix}.json',
		test_config_json=f'checkpoints/{test_config.checkpoint_prefix}-config.json',
	)

	# Model properties
	test_config.conv_layers = [
		dict(out=6, kernel=(3, 5), pool=(1, 2), dropout=0.001),
		dict(out=12, kernel=(3, 3), pool=(2, 2), dropout=0.001),
	]

	test_config.dense_layers = [
		dict(count=64, dropout=0),
	]


## Binary all persons

In [None]:
# Enable / disable config
BINARY_ALL = False

if BINARY_ALL:

	# Data selection: what data to use (visual, audial, all together),
	# and whether binary classification mode is enabled (True enables it)
	test_config = dotdict(visual=None, binary=True)

	if test_config.binary:

		# - - - B I N A R Y   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=1,
			# True combines data from all persons into a single dataset
			all_person_data=True,
			# What phonemes to use during train:
			# None trains on each pair combination, a pair trains on that pair only
			phoneme_classes=None,
		)

	else:

		# - - - M U L T I C L A S S   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=1,
			# True combines data from all persons into a single dataset
			all_person_data=False,
			# Phoneme classes to use during train, None uses all phonemes
			phoneme_classes=None,
		)

	# Force enable/disable dataset reloading, for example, when performing multiple restarts
	RELOAD_DATASET = False

	# Train properties
	test_config.update(
		test_size=0.2,
		batch_size=8,
		epochs=50,
		use_conv_sigmoid=False,
		use_dense_sigmoid=False, # test_config.binary
	)

	# Train autoconfig
	test_config.update(lr_step_size=40, lr_start=0.01)

	# Transform properties
	test_config.update(
		shift_transform_scale=1.1,
		shift_transform_roll=1.0,
		noise_transform_scale=0.002,
	)

	# Seeding and workers: seed_steps consecutive seeds are used, starting from seed
	test_config.update(
		seed=42,
		seed_steps=1,
		num_workers=multiprocessing.cpu_count() // 4,
	)

	# Output properties: file names are derived from a millisecond timestamp
	timestamp = round(time.time() * 1000)
	test_config.checkpoint_prefix = f'autorun-{timestamp}'
	test_config.update(
		test_json=f'checkpoints/{test_config.checkpoint_prefix}.json',
		test_config_json=f'checkpoints/{test_config.checkpoint_prefix}-config.json',
	)

	# Model properties
	test_config.conv_layers = [
		dict(out=8, kernel=(3, 5), pool=(1, 2), dropout=0.001),
		dict(out=16, kernel=(3, 3), pool=(2, 2), dropout=0.001),
	]

	test_config.dense_layers = [
		dict(count=128, dropout=0),
	]


## Multiclass single person

In [None]:
# Enable / disable config
MULTICLASS_SINGLE = False

if MULTICLASS_SINGLE:

	# Data selection: what data to use (visual, audial, all together),
	# and whether binary classification mode is enabled (True enables it)
	test_config = dotdict(visual=None, binary=False)

	if test_config.binary:

		# - - - B I N A R Y   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=1,
			# True combines data from all persons into a single dataset
			all_person_data=True,
			# What phonemes to use during train:
			# None trains on each pair combination, a pair trains on that pair only
			phoneme_classes=None,
		)

	else:

		# - - - M U L T I C L A S S   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=None,
			# True combines data from all persons into a single dataset
			all_person_data=False,
			# Phoneme classes to use during train, None uses all phonemes
			phoneme_classes=None,
		)

	# Force enable/disable dataset reloading, for example, when performing multiple restarts
	RELOAD_DATASET = True

	# Train properties
	test_config.update(
		test_size=0.2,
		batch_size=8,
		epochs=50,
		use_conv_sigmoid=False,
		use_dense_sigmoid=False, # test_config.binary
	)

	# Train autoconfig
	test_config.update(lr_step_size=40, lr_start=0.01)

	# Transform properties
	test_config.update(
		shift_transform_scale=1.1,
		shift_transform_roll=1.0,
		noise_transform_scale=0.002,
	)

	# Seeding and workers: seed_steps consecutive seeds are used, starting from seed
	test_config.update(
		seed=42,
		seed_steps=1,
		num_workers=multiprocessing.cpu_count() // 4,
	)

	# Output properties: file names are derived from a millisecond timestamp
	timestamp = round(time.time() * 1000)
	test_config.checkpoint_prefix = f'autorun-{timestamp}'
	test_config.update(
		test_json=f'checkpoints/{test_config.checkpoint_prefix}.json',
		test_config_json=f'checkpoints/{test_config.checkpoint_prefix}-config.json',
	)

	# Model properties
	test_config.conv_layers = [
		dict(out=4, kernel=(1, 4), pool=(1, 4), dropout=0.001),
		dict(out=8, kernel=(3, 5), pool=(1, 2), dropout=0.001),
		dict(out=16, kernel=(3, 3), pool=(2, 2), dropout=0.001),
	]

	test_config.dense_layers = [
		dict(count=128, dropout=0),
	]


## Multiclass all persons

In [None]:
# Enable / disable config
MULTICLASS_ALL = False

if MULTICLASS_ALL:

	# Data selection: what data to use (visual, audial, all together),
	# and whether binary classification mode is enabled (True enables it)
	test_config = dotdict(visual=None, binary=False)

	if test_config.binary:

		# - - - B I N A R Y   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=1,
			# True combines data from all persons into a single dataset
			all_person_data=True,
			# What phonemes to use during train:
			# None trains on each pair combination, a pair trains on that pair only
			phoneme_classes=None,
		)

	else:

		# - - - M U L T I C L A S S   O P T I O N S - - -
		test_config.update(
			# What person to use to train on:
			# None trains on each person, a number trains on that single person
			person=None,
			# True combines data from all persons into a single dataset
			all_person_data=True,
			# Phoneme classes to use during train, None uses all phonemes
			phoneme_classes=None,
		)

	# Force enable/disable dataset reloading, for example, when performing multiple restarts
	RELOAD_DATASET = False

	# Train properties
	test_config.update(
		test_size=0.1,
		batch_size=16,
		epochs=80,
		use_conv_sigmoid=False,
		use_dense_sigmoid=False, # test_config.binary
	)

	# Train autoconfig
	test_config.update(lr_step_size=40, lr_start=0.01)

	# Transform properties
	test_config.update(
		shift_transform_scale=1.1,
		shift_transform_roll=1.0,
		noise_transform_scale=0.002,
	)

	# Seeding and workers: seed_steps consecutive seeds are used, starting from seed
	test_config.update(
		seed=42,
		seed_steps=1,
		num_workers=multiprocessing.cpu_count() // 4,
	)

	# Output properties: file names are derived from a millisecond timestamp
	timestamp = round(time.time() * 1000)
	test_config.checkpoint_prefix = f'autorun-{timestamp}'
	test_config.update(
		test_json=f'checkpoints/{test_config.checkpoint_prefix}.json',
		test_config_json=f'checkpoints/{test_config.checkpoint_prefix}-config.json',
	)

	# Model properties
	test_config.conv_layers = [
		dict(out=4, kernel=(1, 4), pool=(1, 4), dropout=0.001),
		dict(out=8, kernel=(3, 5), pool=(1, 2), dropout=0.001),
		dict(out=16, kernel=(3, 3), pool=(2, 2), dropout=0.001),
	]

	test_config.dense_layers = [
		dict(count=256, dropout=0),
		dict(count=128, dropout=0),
	]


## Run test

In [None]:
# Module-level switches read by run_test():
#   RELOAD_DATASET         - load the dataset again for every person
#   GLOBAL_LAST_CHECKPOINT - keep the last trained model in the global `checkpoint`
#   SAVE_CHECKPOINTS       - save every trained model together with its match values
#   SAVE_CONFIG            - dump test_config to its JSON file before training
# NOTE: when the notebook is run top to bottom, this cell executes after the
# configuration cells, so the RELOAD_DATASET value set here replaces theirs.
RELOAD_DATASET         = True
GLOBAL_LAST_CHECKPOINT = False
SAVE_CHECKPOINTS       = True
SAVE_CONFIG            = True

if __name__ == '__main__':
	# test_config exists only when one of the configuration cells was enabled;
	# without it run_test() would fail with a NameError on its first access
	if 'test_config' in globals():
		run_test()
	else:
		print('No test configuration is enabled: set one of BINARY_SINGLE, BINARY_ALL, MULTICLASS_SINGLE or MULTICLASS_ALL to True, re-run that cell and then this one')