# 0. Add a scanner column to the IXI labels file

In [None]:
import os
import pandas as pd

# All inputs live under the user's home directory.
home = os.path.expanduser('~')
data_folder = os.path.join(home, 'data', 'IXI', 'resized')
labels_file = os.path.join(data_folder, 'labels.csv')
project_folder = os.path.join(home, 'projects', 'scanner-adaptation')

# The scanner label is the second dash-separated field of each subject id
# (presumably the IXI acquisition site, e.g. "Guys" in "IXI002-Guys-0828" — confirm).
# The labels file is overwritten in place with the extra column.
df = pd.read_csv(labels_file, index_col=False)
df['scanner'] = [subject_id.split('-')[1] for subject_id in df['id']]
df.to_csv(labels_file, index=False)
df.head()

# 1. Configure a label encoder for the scanners

In [None]:
# Make the helper scripts importable. The folder is resolved relative to the
# notebook's working directory. The membership check keeps repeated runs of
# this cell from piling up duplicate sys.path entries.
import json
import sys

scripts_folder = os.path.join(os.pardir, 'scripts')
if scripts_folder not in sys.path:
    sys.path.append(scripts_folder)

from configure_label import configure_label

# Build a categorical label from the 'scanner' column of the labels file and
# store its configuration in the project folder. The kwargs are serialized with
# json.dumps, the same way the model kwargs are built later in this notebook.
# With {"encoding": "index"}, each scanner is presumably mapped to an integer
# index — confirm against configure_label.
configure_label(name='scanner', variabletype='categorical', filenames=[labels_file],
                columns=['scanner'], destination=os.path.join(project_folder, 'scanners.json'),
                kwargs=json.dumps({'encoding': 'index'}))

# 2. Split the data into folds for training/validation/test

In [None]:
from configure_nifti_folds import configure_nifti_folds

# Split the images into k=5 folds plus a 20% test portion, stratified on age,
# sex and scanner. The scanner encoder from the previous step is passed along
# (presumably so the 'scanner' column can be encoded — confirm).
scanner_encoder = os.path.join(project_folder, 'scanners.json')
configure_nifti_folds(
    folders=[data_folder],
    targets='age',
    stratification=['age', 'sex', 'scanner'],
    k=5,
    test_portion=0.2,
    encoders=[scanner_encoder],
    destination=os.path.join(project_folder, 'data'),
)

# 3. Configure a domain-adaptive SFCN regression model

In [None]:
import json

from configure_model import configure_model

# Keyword arguments for the SFCN regression model, serialized to a JSON string
# for configure_model.
model_kwargs = json.dumps({
    'input_shape': [43, 54, 41],
    'dropout': 0.2,
    'weight_decay': 1e-3,
    'prediction_range': [19, 87],  # presumably the age range of the targets — confirm
    'domains': 3,  # presumably one domain per scanner/site — confirm against the data
})

configure_model(
    model='sfcn-reg',
    kwargs=model_kwargs,
    destination=os.path.join(project_folder, 'adaptive_model'),
)

# 4. Configure a preprocessor, an augmenter and a learning rate schedule

In [None]:
from pyment.data.augmenters import NiftiAugmenter
from pyment.data.preprocessors import NiftiPreprocessor
from pyment.utils.learning_rate import LearningRateSchedule


def _save_and_show(component, filename):
    """Save ``component`` as ``filename`` in the project folder, print it, and return it."""
    component.save(os.path.join(project_folder, filename))
    print(component)
    return component


# Intensity preprocessing. sigma=255. is presumably a divisor that scales
# 8-bit intensities into [0, 1] — confirm against NiftiPreprocessor.
preprocessor = _save_and_show(NiftiPreprocessor(sigma=255.), 'preprocessor.json')

# Random flipping with probability 0.5 on the first axis only (0 for the
# other two axes).
augmenter = _save_and_show(NiftiAugmenter(flip_probabilities=[0.5, 0, 0]),
                           'augmenter.json')

# Keys are presumably epochs at which the learning rate switches — confirm.
# NOTE(review): the rate rises from 1e-3 to 3e-3 at 20 and from 1e-4 to 3e-4 at
# 60, which does not look like a decay schedule (3e-4 / 3e-5 intended?). Confirm.
learning_rate_schedule = _save_and_show(
    LearningRateSchedule({0: 1e-3, 20: 3e-3, 40: 1e-4, 60: 3e-4}),
    'learning_rate_schedule.json',
)

In [None]:
from shutil import rmtree

from fit_model import fit_model

# Train the domain-adaptive model saved to 'adaptive_model' in step 3, using
# 'scanner' as the domain label. Folds 0-3 are used for training and fold 4
# for validation. The output goes to a dedicated folder so that other training
# runs in this notebook, which clear their own output folder first, cannot
# delete these results.
run_folder = os.path.join(project_folder, 'adaptive_run')

# Start from a clean output folder; any previous adaptive run is deleted.
if os.path.isdir(run_folder):
    rmtree(run_folder)

fit_model(model=os.path.join(project_folder, 'adaptive_model'),
          training=[os.path.join(project_folder, 'data', f'fold_{i}.json') for i in range(4)],
          validation=[os.path.join(project_folder, 'data', 'fold_4.json')],
          preprocessor=os.path.join(project_folder, 'preprocessor.json'),
          augmenter=os.path.join(project_folder, 'augmenter.json'),
          batch_size=4,
          num_threads=8,
          loss='mse',
          metrics=['mae'],
          learning_rate_schedule=os.path.join(project_folder, 'learning_rate_schedule.json'),
          epochs=5,
          domain='scanner',
          destination=run_folder)

In [None]:
from shutil import rmtree

from fit_model import fit_model

# Baseline run without domain adaptation (domain=None), using the same folds,
# preprocessing, augmentation and schedule. The output goes to a dedicated
# folder so that clearing it below cannot delete another run's results.
# NOTE(review): 'model' is not configured anywhere in this notebook (step 3
# writes 'adaptive_model'). Confirm which model configuration the baseline
# should use.
run_folder = os.path.join(project_folder, 'baseline_run')

# Start from a clean output folder; any previous baseline run is deleted.
if os.path.isdir(run_folder):
    rmtree(run_folder)

fit_model(model=os.path.join(project_folder, 'model'),
          training=[os.path.join(project_folder, 'data', f'fold_{i}.json') for i in range(4)],
          validation=[os.path.join(project_folder, 'data', 'fold_4.json')],
          preprocessor=os.path.join(project_folder, 'preprocessor.json'),
          augmenter=os.path.join(project_folder, 'augmenter.json'),
          batch_size=4,
          num_threads=8,
          loss='mse',
          metrics=['mae'],
          learning_rate_schedule=os.path.join(project_folder, 'learning_rate_schedule.json'),
          epochs=5,
          domain=None,
          destination=run_folder)