In [1]:
import autokeras as ak

In [None]:
from dataclasses import dataclass
from sklearn.base import TransformerMixin, BaseEstimator

@dataclass
class Connectivity(TransformerMixin, BaseEstimator):
    """Parcellate regions, aggregate networks, and extract connectivity.

    ``transform`` chains ``Parcellation``, optionally ``NetworkAggregator``,
    and ``ConnectivityExtractor`` and wraps the result in an
    ``xarray.DataArray`` named ``'connectivity'`` with dims
    ``('subject', 'node', 'node')``. The most recent result is also stored
    on ``self.dataset_``.

    Attributes:
        atlas: Atlas name passed to ``Parcellation``.
        kind: Connectivity kind passed to ``ConnectivityExtractor``.
        agg_networks: If true, a ``NetworkAggregator`` step is inserted
            between parcellation and connectivity extraction and the node
            labels come from its ``networks_``; otherwise they come from
            the index of the parcellation's ``labels_``.
        mock: If true, ``transform`` returns random matrices produced by
            ``mock_transform`` instead of running the pipeline.
    """

    atlas: str = 'dosenbach2010'
    kind: str = 'correlation'
    agg_networks: bool = True
    mock: bool = False

    def fit(self, X, y=None, **fit_params):
        """Do nothing and return ``self``; all work happens in ``transform``."""
        return self

    def transform(self, X):
        """Return connectivity matrices for the requested subjects.

        Args:
            X: Subject identifiers (any array-like; it is flattened to 1-D),
                or ``None`` to keep every subject the parcellation provides.

        Returns:
            The ``xarray.DataArray`` also stored on ``self.dataset_``.
        """
        # The mock flag is evaluated on every call rather than bound once at
        # construction, so changing it afterwards (e.g. via ``set_params``)
        # is honoured.
        if self.mock:
            return self.mock_transform(X)

        p = Parcellation(self.atlas, bids_dir=BIDS_DIR, cache_dir=PARCELLATION_CACHE_DIR)
        n = NetworkAggregator(p.labels_)
        c = ConnectivityExtractor(self.kind)

        if self.agg_networks:
            conn = make_pipeline(p, n, c).fit_transform(X)
            nodes = n.networks_
        else:
            conn = make_pipeline(p, c).fit_transform(X)
            nodes = p.labels_.index.to_list()

        self.dataset_ = xr.DataArray(
            conn,
            coords={'subject': p.dataset_['subject'],
                    'node': nodes},
            dims=['subject', 'node', 'node'],
            name='connectivity')

        # select only queried subjects
        if X is not None:
            # np.asarray lets lists / DataFrames through, not only ndarrays.
            subjects_1d = np.asarray(X).reshape(-1).tolist()
            self.dataset_ = self.dataset_.sel(dict(subject=subjects_1d))

        return self.dataset_

    def mock_transform(self, X):
        """Return random connectivity matrices with a realistic node count.

        Args:
            X: Subject identifiers (any array-like; it is flattened to 1-D),
                or ``None`` to take the subject list from a fitted
                ``Parcellation(self.atlas)``.

        Returns:
            An ``xarray.DataArray`` of uniform random values with the same
            dims and name as the one built by ``transform``; nodes are
            labelled ``node_0 .. node_{k-1}``. Also stored on
            ``self.dataset_``.

        Raises:
            KeyError: If ``(atlas, agg_networks)`` is not in the table below.
        """

        # Number of nodes per (atlas, agg_networks) combination.
        n_features_dict = {
            ('gordon2014_2mm', True): 13,
            ('gordon2014_2mm', False): 333,
            ('dosenbach2010', True): 6,
            ('dosenbach2010', False): 160,
            ('difumo_64_2mm', True): 7,
            ('difumo_64_2mm', False): 64,
            ('seitzman2018', True): 14,
            ('seitzman2018', False): 300,
        }

        if X is None:
            subjects = Parcellation(self.atlas).fit(None).dataset_['subject'].values
        else:
            # Flatten like ``transform`` does: a 2-D column of subject ids
            # cannot be used directly as the 1-D 'subject' coordinate.
            subjects = np.asarray(X).reshape(-1)

        n_features = n_features_dict[(self.atlas, self.agg_networks)]

        nodes = [f'node_{n}' for n in range(n_features)]

        mock_conn = np.random.rand(len(subjects), n_features, n_features)

        self.dataset_ = xr.DataArray(
            mock_conn,
            coords={'subject': subjects,
                    'node': nodes},
            dims=['subject', 'node', 'node'],
            name='connectivity')

        return self.dataset_

    def get_feature_names_out(self, input_features=None):
        """Return a square table of names for every node pair.

        Args:
            input_features: Node names (any 1-D array-like). When ``None``,
                the names are read from the ``'node'`` coordinate of
                ``self.transform(None)``, which runs the transform.

        Returns:
            A ``pandas.DataFrame`` indexed and columned by the node names.
            Off-diagonal cells hold the two names joined by a left-right
            arrow; diagonal cells hold the single node name.
        """
        sep = ' \N{left right arrow} '
        if input_features is None:
            input_features = self.transform(None).coords['node'].values
        # Accept plain lists / tuples as well as arrays (``.shape`` is used below).
        input_features = np.asarray(input_features)
        feature_names = pd.DataFrame(
            np.zeros((input_features.shape[0], input_features.shape[0])),
            columns=input_features, index=input_features)
        # After stack(), each row's ``name`` is the (row_label, column_label) pair.
        feature_names = feature_names.stack().to_frame().apply(lambda x:
            sep.join(x.name) if x.name[0] != x.name[1] else x.name[0],
            axis=1).unstack()
        return feature_names

In [None]:
# Structured-data classifier: the search tries up to 10 candidate models and
# discards results of any previous run in the same project directory.
# NOTE(review): column_types is given while column_names is None; this
# presumably relies on the training data carrying its own column names
# (e.g. a DataFrame) -- confirm against the data passed to fit().
clf = ak.StructuredDataClassifier(
    overwrite=True,
    max_trials=10,
    column_types={'sex': 'categorical', 'fare': 'numerical'},
    column_names=None,
)

In [None]:
# AutoModel wires input nodes to output heads. Both lists must contain node
# *instances*; the original passed the StructuredDataInput class itself.
model = ak.AutoModel(
    inputs=[ak.StructuredDataInput()],
    outputs=[ak.ClassificationHead()]
)

In [None]:
# Hold out everything after the first `split` samples as a validation set.
# Each right-hand side is evaluated in full before rebinding, so the
# validation slice is taken from the untruncated training data.
split = 500
x_train, x_val = x_train[:split], x_train[split:]
y_train, y_val = y_train[:split], y_train[split:]

In [None]:
# First run: split the training data and use the last 15% as validation data.
clf.fit(x_train, y_train, validation_split=0.15, epochs=10)

# Second run: use your own validation set.
clf.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10)