# Step 4: Fit Barnacle model to your data

Use this notebook to fit the Barnacle model to your normalized data tensor. Fitting Barnacle to data requires tuning two model parameters: 
1. `R` -- the number of components
2. `lambda` -- the sparsity parameter

There are many methods for fitting model parameters. The cross-validated parameter search strategy used here is the method used to fit Barnacle to metatranscriptomic data in the [original Barnacle manuscript](https://doi.org/10.1101/2024.07.15.603627). This strategy aims to reduce resource costs by fitting `R` first and then `lambda`, rather than both parameters simultaneously. It also relies on sample replicates to perform cross-validation. If your data does not have sample replicates, you might instead consider split-half analysis for parameter selection. If you have more compute resources at your disposal, you might consider a full grid search over `R` and `lambda` together to find the optimal combination.
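
To make the cost argument concrete, the sketch below counts parameter combinations for a full grid search versus the sequential search. The candidate values and `best_rank` are hypothetical placeholders chosen only for illustration. Each combination is fit repeatedly in practice (bootstraps, initializations), which scales both counts by the same factor.

```python
from itertools import product

# Hypothetical candidate values, chosen only for illustration
ranks = [1, 5, 10, 15, 20]
lambdas = [0.0, 0.1, 1.0, 10.0]

# Full grid search: every (R, lambda) combination is fit
full_grid = list(product(ranks, lambdas))

# Sequential search: sweep R with sparsity switched off (lambda = 0),
# then sweep lambda at the selected rank.
# best_rank is a placeholder; in practice it comes from the rank search results.
best_rank = 10
sequential = [(r, 0.0) for r in ranks] + [(best_rank, lam) for lam in lambdas]

n_full = len(full_grid)         # 5 * 4 = 20 combinations
n_sequential = len(sequential)  # 5 + 4 = 9 combinations
print(f"full grid: {n_full} combinations, sequential: {n_sequential} combinations")
```

The full grid grows multiplicatively with the number of candidate values, while the sequential search grows additively, at the price of not exploring interactions between `R` and `lambda`.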

Please refer to the notebook [3-tensorize-data.ipynb](https://github.com/blasks/barnacle-boilerplate/blob/main/3-tensorize-data.ipynb) for proper formatting of your input data tensor. To facilitate bootstrapping, sample ID and replicate ID are combined into a unique identifier called `'sample_replicate_id'` (how creative). This will be the name of the third mode of your tensor. The sample ID and replicate ID information is preserved in separate metadata arrays in the dataset. The script uses this information to shuffle replicates between bootstraps, which enables more robust parameter selection and confidence intervals in the final model.
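
As a rough illustration of the idea, the sketch below builds a combined identifier and reassigns replicate labels within each sample. The labels, the `_` separator, and the helper `shuffle_replicates` are all hypothetical, and this is only one plausible reading of "shuffling replicates", not the procedure the Barnacle script actually implements.

```python
import random

# Hypothetical sample and replicate labels, one pair per tensor slice
sample_ids = ["S1", "S1", "S1", "S2", "S2", "S2"]
replicate_ids = ["A", "B", "C", "A", "B", "C"]

# Combine both labels into one unique identifier per slice
# (the "_" separator is an assumption made for this example)
sample_replicate_ids = [f"{s}_{r}" for s, r in zip(sample_ids, replicate_ids)]

def shuffle_replicates(sample_ids, replicate_ids, rng):
    """Randomly reassign replicate labels within each sample."""
    # Group slice positions by the sample they belong to
    positions_by_sample = {}
    for i, sample in enumerate(sample_ids):
        positions_by_sample.setdefault(sample, []).append(i)
    shuffled = list(replicate_ids)
    for positions in positions_by_sample.values():
        labels = [replicate_ids[i] for i in positions]
        rng.shuffle(labels)  # permute the labels of this sample only
        for i, label in zip(positions, labels):
            shuffled[i] = label
    return shuffled

rng = random.Random(0)  # seeded generator so the shuffle is reproducible
shuffled_ids = shuffle_replicates(sample_ids, replicate_ids, rng)
print(sample_replicate_ids)
print(shuffled_ids)
```

Keeping the sample and replicate labels as separate arrays is what makes this kind of regrouping possible; the combined identifier alone would have to be parsed back apart.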

In [41]:
# imports

import os
import pprint
import toml
import xarray as xr


In [39]:
# get user inputs

# data file
datapath = input('Enter the filepath of your input data tensor (e.g. directory/example-tensor.nc):')
# check data file exists
if not os.path.isfile(datapath):
    raise Exception(f'Unable to find the file "{datapath}"')
ds = xr.open_dataset(datapath)

# check input data formats
required_vars = ['data', 'sample_id', 'replicate_id', 'sample_replicate_id']
for var in required_vars:
    if var not in list(ds.variables):
        raise Exception(f"Tensor missing variable '{var}'. See `3-tensorize-data.ipynb` for proper formatting of input tensor dataset.")
if input(f'Found tensor with the following modes and dimensions:\n{dict(ds.sizes)}\nIs this correct? (Y/N):').strip().lower() == 'n':
    raise Exception('Please double check input tensor dataset.')
else:
    # use the dimension names of the data variable so modes are listed in tensor order
    modes = list(ds['data'].dims)
    mode_description = f'1-{modes[0]}, 2-{modes[1]}, 3-{modes[2]}'

# output directory
outdir = input('Enter the filepath of the output directory where you want files saved:')
# check output directory exists
if not os.path.isdir(outdir):
    raise Exception(f'Unable to find the directory "{outdir}"')

# input constant model parameters
nonneg_modes = [int(x)-1 for x in input(f'Which modes are non-negative? {mode_description}, (Enter 1/2/3, comma-separated)').split(',')]
# an empty response falls back to the advertised default of 1
sparse_modes = int(input('How many modes will sparsity be applied to? (Enter 0/1/2/3, default is 1)').strip() or 1)
if sparse_modes == 1:
    if input(f'Sparsity will be applied to mode 1-{modes[0]}\nIs this correct? (Y/N):').strip().lower() == 'n':
        raise Exception('Sparsity constraint should be applied to the first mode in the tensor. Please rearrange input tensor.')
elif sparse_modes == 2:
    if input(f'Sparsity will be applied to modes 1-{modes[0]} and 2-{modes[1]}\nIs this correct? (Y/N):').strip().lower() == 'n':
        raise Exception('Sparsity constraints should be applied to first and second modes in the tensor. Please rearrange input tensor.')

ds

Enter the filepath of your input data tensor (e.g. directory/example-tensor.nc): data/data-tensor.nc
Found tensor with the following modes and dimensions:
{'KOfam': 20069, 'phylum': 99, 'sample_replicate_id': 31}
Is this correct? (Y/N): y
Enter the filepath of the output directory where you want files saved: data
Which modes are non-negative? 1-KOfam, 2-phylum, 3-sample_replicate_id, (Enter 1/2/3, comma-separated) 2,3
How many modes will sparsity be applied to? (Enter 0/1/2/3, default is 1) 1
Sparsity will be applied to mode 1-KOfam
Is this correct? (Y/N): y
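
The mode prompt above takes 1-based mode numbers and converts them to 0-based indices. A slightly more defensive version of that conversion might look like the sketch below; `parse_modes` is a hypothetical helper, not part of Barnacle or this notebook.

```python
def parse_modes(text, n_modes=3):
    """Convert a comma-separated, 1-based mode list (e.g. '2,3') to sorted 0-based indices."""
    text = text.strip()
    if not text:
        return []  # empty input means no modes were selected
    indices = set()
    for token in text.split(","):
        mode = int(token.strip())  # raises ValueError on non-integer input
        if not 1 <= mode <= n_modes:
            raise ValueError(f"Mode {mode} is outside the valid range 1-{n_modes}")
        indices.add(mode - 1)  # shift from 1-based to 0-based
    return sorted(indices)

print(parse_modes("2,3"))  # [1, 2]
```

With the responses recorded above (`2,3`), this yields `[1, 2]`, meaning the phylum and sample_replicate_id modes are constrained to be non-negative.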


### Part A: Identifying optimal rank

In this step we will identify the optimal rank `R` of the model. All sparsity parameters are set to 0 for this search, so only the rank varies.

In [52]:
# build config file

# config file structure
config = {
    "grid": {
        "ranks": [int(x) for x in input('Enter values of R (rank) to test (comma-separated): ').split(',')],
        "lambdas": [[0., 0., 0.]]
    },
    "params": {
        "nonneg_modes": nonneg_modes,
        "tol": 0.00001, 
        "n_iter_max": 2000,
        "n_initializations": 5
    },
    "script": {
        "input_filepath": datapath,
        "outdir": outdir,
        "n_bootstraps": 10,
        "replicates": [str(l) for l in set(ds.replicate_id.data)],
        "max_processes": os.cpu_count() or 1,  # cpu_count() can return None
        "seed": int(input('Enter random seed integer: '))
    }
}

# function to save config as toml file
def save_toml(config, filename="config.toml"):
    with open(filename, "w") as f:
        toml.dump(config, f)
    print(f"TOML file '{filename}' created successfully.")

# check config and save
print(f"Config file for identifying optimal rank parameter:\n\n{toml.dumps(config)}")
if input("\nDoes this look correct? (Y/N):").strip().lower() == "y":
    save_toml(config, filename=f"{outdir}/1-rank-search.toml")
else:
    print("Please review parameters and re-generate config file.")


Enter values of R to test (comma-separated):  1,5,10,15,20,25,30,35,40
Enter random seed integer:  9481


Config file for identifying optimal rank parameter:

[grid]
ranks = [ 1, 5, 10, 15, 20, 25, 30, 35, 40,]
lambdas = [ [ 0.0, 0.0, 0.0,],]

[params]
nonneg_modes = [ 1, 2,]
tol = 1e-5
n_iter_max = 2000
n_initializations = 5

[script]
input_filepath = "data/data-tensor.nc"
outdir = "data"
n_bootstraps = 10
replicates = [ "C", "B", "A",]
max_processes = 16
seed = 9481




Does this look correct? (Y/N): y


TOML file 'data/1-rank-search.toml' created successfully.
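
Before launching a long parameter search, it can help to sanity-check the config beyond reading it by eye. The sketch below applies a few basic checks to a config dictionary shaped like the one above. `check_config` is a hypothetical helper, and the rules it enforces (one `lambda` value per mode, non-negative values, 0-based mode indices) are assumptions inferred from this notebook, not constraints documented by Barnacle.

```python
def check_config(config, n_modes=3):
    """Return a list of problems found in a config dictionary (empty if none)."""
    problems = []
    ranks = config["grid"]["ranks"]
    if not ranks or any(not isinstance(r, int) or r < 1 for r in ranks):
        problems.append("ranks must be a non-empty list of positive integers")
    for lam in config["grid"]["lambdas"]:
        # assumed: one sparsity value per tensor mode, none negative
        if len(lam) != n_modes or any(x < 0 for x in lam):
            problems.append(f"lambda entry {lam} must hold {n_modes} non-negative values")
    for mode in config["params"]["nonneg_modes"]:
        if mode not in range(n_modes):
            problems.append(f"nonneg mode {mode} is not a valid 0-based index")
    if config["script"]["n_bootstraps"] < 1:
        problems.append("n_bootstraps must be at least 1")
    return problems

# Values copied from the config printed above
example = {
    "grid": {"ranks": [1, 5, 10, 15, 20, 25, 30, 35, 40], "lambdas": [[0.0, 0.0, 0.0]]},
    "params": {"nonneg_modes": [1, 2], "tol": 1e-5, "n_iter_max": 2000, "n_initializations": 5},
    "script": {"n_bootstraps": 10, "seed": 9481},
}
print(check_config(example))  # []
```

An empty list means none of the checks failed; any returned message points at the entry to fix before regenerating the config file.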
