# Introduction to ML using ML4H

## Prerequisites
- Basic comfort with Python, some linear algebra, and some data science
- Follow the instructions in the main [README](https://github.com/broadinstitute/ml4h) for installing ML4H
- Data used in this notebook is available here ([gs://fc-500bd872-4a53-45c9-87d3-39656bd83f85/data/hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz](gs://fc-500bd872-4a53-45c9-87d3-39656bd83f85/data/hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz))
- Now we are ready to teach the machines!

In [None]:
# Imports
import os
import sys
import pickle
import gzip
from typing import Dict

import h5py
import numpy as np


from ml4h.arguments import parse_args
from ml4h.TensorMap import TensorMap, Interpretation
from ml4h.tensor_generators import test_train_valid_tensor_generators
from ml4h.recipes import compare_multimodal_scalar_task_models, train_multimodal_multitask

%matplotlib inline

In [None]:
# Constants
HD5_FOLDER = './mnist_tensors/'
OUTPUT_FOLDER = './outputs/'

## Python features used heavily in this notebook
- f-strings
- Callback functions
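Both features show up throughout ML4H. f-strings build paths and channel names (for example `f'digit_{i}'` in the `mnist_label` TensorMap), and each TensorMap receives a `tensor_from_file` callback that ML4H calls later to read each sample. The snippet below is a generic plain-Python illustration; `apply_loader` and `name_length` are hypothetical names, not part of ML4H.

```python
from typing import Callable

# f-strings interpolate expressions directly into string literals.
run_id = 'learn_2d_cnn'
output_folder = './outputs/'
model_path = f'{output_folder}{run_id}/{run_id}.h5'

# A dict comprehension with f-string keys, like the channel_map used for mnist_label.
channel_map = {f'digit_{i}': i for i in range(10)}


# A callback is a function passed as an argument and invoked later by the receiver.
# apply_loader and name_length are hypothetical helpers for illustration only.
def apply_loader(name: str, loader: Callable[[str], int]) -> str:
    value = loader(name)  # the receiver decides when to call the callback
    return f'{name} -> {value}'


def name_length(name: str) -> int:
    return len(name)


print(model_path)                                # ./outputs/learn_2d_cnn/learn_2d_cnn.h5
print(channel_map['digit_3'])                    # 3
print(apply_loader('mnist_label', name_length))  # mnist_label -> 11
```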

## TensorMaps
The critical data structure in the ml4h codebase is the TensorMap.
This abstraction provides a way to translate ***any*** kind of input data into structured numeric tensors with clear semantics for interpretation and modeling. TensorMaps guarantee three things: a shape, a way to construct tensors of that shape from the HD5 files created during tensorization, and a meaning for the values in the tensor that the TensorMap yields.

For example, in the `mnist.py` file these TensorMaps are defined:

In [None]:
def mnist_image_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
    return np.array(hd5['mnist_image'])


mnist_image = TensorMap('mnist_image', shape=(28, 28, 1), tensor_from_file=mnist_image_from_hd5)


def mnist_label_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
    one_hot = np.zeros(tm.shape, dtype=np.float32)
    one_hot[int(hd5['mnist_label'][0])] = 1.0
    return one_hot


mnist_label = TensorMap(
    'mnist_label', Interpretation.CATEGORICAL, tensor_from_file=mnist_label_from_hd5,
    channel_map={f'digit_{i}': i for i in range(10)},
)
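Because `mnist_label_from_hd5` only reads `tm.shape` and indexes `hd5['mnist_label']`, its logic can be tried without an HD5 file. The sketch below re-implements the same one-hot logic as a hypothetical `label_from_hd5` and passes stand-ins: a `SimpleNamespace` with a `shape` attribute instead of a TensorMap, and a plain dict instead of an open `h5py.File`. Note that `mnist_label` is declared without a `shape`; we assume ML4H infers `(10,)` from the 10-entry `channel_map`, so the stand-in sets it explicitly.

```python
from types import SimpleNamespace
from typing import Dict, Optional

import numpy as np


def label_from_hd5(tm, hd5, dependents: Optional[Dict] = None) -> np.ndarray:
    # Hypothetical copy of mnist_label_from_hd5's logic: allocate zeros with the
    # TensorMap's shape and set the channel indexed by the stored digit to 1.0.
    one_hot = np.zeros(tm.shape, dtype=np.float32)
    one_hot[int(hd5['mnist_label'][0])] = 1.0
    return one_hot


# Stand-ins (assumptions): any object with .shape works as tm, and any mapping
# supporting hd5['mnist_label'][0] works as the HD5 file.
fake_tm = SimpleNamespace(shape=(10,))
fake_hd5 = {'mnist_label': np.array([7])}

label = label_from_hd5(fake_tm, fake_hd5)
print(label)                 # one-hot vector with 1.0 at index 7
print(int(label.argmax()))   # 7
```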

Similarly, in the `gatk.py` file we define TensorMaps to encode data about genomic variants. Specifically, we create 3 TensorMaps: `reference` is a one-hot encoded 128 base-pair window of DNA sequence. `read_tensor` is an alignment of as many as 128 different DNA reads overlapping a 128 base-pair window of reference DNA; its 15 channels encode the bases from the reference, the bases from the read sequence, and metadata belonging to each read. Lastly, we define the `CATEGORICAL` TensorMap `variant_label`, which encodes the truth status of this particular genomic variant. In this dataset we consider only SNPs and small insertions or deletions, giving us 4 labels: `'NOT_SNP', 'NOT_INDEL', 'SNP', 'INDEL'`.
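To make these one-hot encodings concrete, here is a minimal sketch of encoding a 128 base-pair reference window and a `variant_label`. The `A, C, G, T` channel order and the choice to leave non-ACGT bases (such as `N`) as all-zero rows are assumptions for illustration; the actual channel maps in `gatk.py` may differ, and the 15 channels of `read_tensor` are not modeled here.

```python
import numpy as np

# Assumed channel order; the real reference channel_map in gatk.py may differ.
DNA_CHANNELS = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
# Label order as listed in the text above.
VARIANT_LABELS = ['NOT_SNP', 'NOT_INDEL', 'SNP', 'INDEL']


def one_hot_reference(sequence: str, window: int = 128) -> np.ndarray:
    """Encode a DNA window as a (window, 4) one-hot array; non-ACGT bases stay all-zero."""
    if len(sequence) != window:
        raise ValueError(f'expected {window} bases, got {len(sequence)}')
    encoded = np.zeros((window, len(DNA_CHANNELS)), dtype=np.float32)
    for position, base in enumerate(sequence.upper()):
        if base in DNA_CHANNELS:
            encoded[position, DNA_CHANNELS[base]] = 1.0
    return encoded


def one_hot_label(label: str) -> np.ndarray:
    """Encode one of the 4 variant labels as a one-hot vector."""
    encoded = np.zeros(len(VARIANT_LABELS), dtype=np.float32)
    encoded[VARIANT_LABELS.index(label)] = 1.0
    return encoded


ref = one_hot_reference('ACGT' * 32)  # 128 bases
print(ref.shape)                      # (128, 4)
print(one_hot_label('SNP'))           # [0. 0. 1. 0.]
```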

In [None]:
def load_data(dataset):
    ''' Loads the dataset
    :param dataset: the path to the dataset (here MNIST)'''
    data_dir, data_file = os.path.split(dataset)
    if data_dir == "" and not os.path.isfile(dataset):
        # Check if dataset is in the data directory.
        new_path = os.path.join("data", dataset)
        if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz':
            dataset = new_path

    if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz':
        from urllib.request import urlretrieve
        origin = ('http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz')
        print('Downloading data from %s' % origin)
        if not os.path.exists(os.path.dirname(dataset)):
            os.makedirs(os.path.dirname(dataset))
        urlretrieve(origin, dataset)

    print('loading data...')
    # mnist.pkl.gz was pickled under Python 2, so decode its byte strings as latin1.
    with gzip.open(dataset, 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

    return train_set, valid_set, test_set


def mnist_as_hd5(hd5_folder):
    train, _, _ = load_data('mnist.pkl.gz')
    mnist_images = train[0].reshape((-1, 28, 28, 1))
    os.makedirs(hd5_folder, exist_ok=True)
    for i, mnist_image in enumerate(mnist_images):
        with h5py.File(os.path.join(hd5_folder, f'{i}.hd5'), 'w') as hd5:
            hd5.create_dataset('mnist_image', data=mnist_image)
            hd5.create_dataset('mnist_label', data=[train[1][i]])
        if (i+1) % 5000 == 0:
            print(f'Wrote {i+1} MNIST images and labels as HD5 files')

This is the type of data used by the GATK tool CNNScoreVariants to filter DNA sequencing data. The tensorization code is part of GATK, not ML4H; however, tensorized data is available at: `gs://fc-500bd872-4a53-45c9-87d3-39656bd83f85/data/hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz`. Once the data has been localized, you can unpack the HD5 files into the `HD5_FOLDER` with the cell below (assuming the tar.gz file is in the same directory as:

In [None]:
mnist_as_hd5(HD5_FOLDER)
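The cell above regenerates MNIST locally. If you are instead working with the GATK tensor archive, one way to unpack it from Python is the standard-library `tarfile` module. This is a sketch: `unpack_tensors` is a hypothetical helper, and because the archive's internal directory layout is not documented here, it counts `.hd5` files recursively rather than assuming they land directly in the target folder. The safer `filter='data'` extraction option is used only when the running Python supports it.

```python
import os
import tarfile


def unpack_tensors(archive_path: str, hd5_folder: str) -> int:
    """Hypothetical helper: extract a .tar.gz of HD5 tensors and count the .hd5 files."""
    os.makedirs(hd5_folder, exist_ok=True)
    with tarfile.open(archive_path, 'r:gz') as archive:
        if hasattr(tarfile, 'data_filter'):
            # Newer Pythons: reject absolute paths, '..' escapes, and special files.
            archive.extractall(path=hd5_folder, filter='data')
        else:
            archive.extractall(path=hd5_folder)
    # Count recursively, since the archive may contain its own top-level directory.
    count = 0
    for _, _, files in os.walk(hd5_folder):
        count += sum(name.endswith('.hd5') for name in files)
    return count
```

For example, `unpack_tensors('hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz', HD5_FOLDER)`, assuming the archive sits in the working directory.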

## The Model Factory
The function ***make_multimodal_multitask_model()*** takes lists of TensorMaps and connects them with intelligent goo. Specifically, given a list of TensorMaps that are model inputs and a list of TensorMaps that are desired outputs, the model factory builds a model and loss appropriate for the dimensions and interpretations of the data at hand. Depending on the input and output TensorMaps provided, the Model Factory can build models for many different situations, including:
- Classification
- Regression
- Multitask
- Multimodal
- Multimodal Multitask
- Autoencoders
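The toy function below illustrates the kind of decision the model factory makes from its input and output TensorMaps; it is not ML4H's implementation. `Spec` is a hypothetical stand-in for a TensorMap, and the interpretation-to-(activation, loss) pairs are assumed typical Keras choices, not values read from ML4H's source.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Spec:
    """Hypothetical stand-in for a TensorMap: just a name and an interpretation."""
    name: str
    interpretation: str  # 'categorical' or 'continuous' in this sketch


# Assumed output head per interpretation: (final activation, loss).
HEAD_FOR = {
    'categorical': ('softmax', 'categorical_crossentropy'),
    'continuous': ('linear', 'mse'),
}


def describe_model(inputs: List[Spec], outputs: List[Spec]) -> dict:
    """Classify the task from the TensorMap lists and pick a head for each output."""
    if {tm.name for tm in inputs} == {tm.name for tm in outputs}:
        kind = 'autoencoder'
    elif len(inputs) > 1 and len(outputs) > 1:
        kind = 'multimodal multitask'
    elif len(inputs) > 1:
        kind = 'multimodal'
    elif len(outputs) > 1:
        kind = 'multitask'
    elif outputs[0].interpretation == 'categorical':
        kind = 'classification'
    else:
        kind = 'regression'
    heads = {tm.name: HEAD_FOR[tm.interpretation] for tm in outputs}
    return {'kind': kind, 'heads': heads}


image = Spec('mnist_image', 'continuous')
label = Spec('mnist_label', 'categorical')
print(describe_model([image], [label]))
print(describe_model([image], [image]))
```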



In [None]:
# Baseline model: MNIST image in, one-hot digit label out
sys.argv = ['train',
            '--tensors', HD5_FOLDER,
            '--input_tensors', 'mnist.mnist_image',
            '--output_tensors', 'mnist.mnist_label',
            '--batch_size', '16',
            '--epochs', '12',
            '--output_folder', OUTPUT_FOLDER,
            '--id', 'learn_2d_cnn'
           ]
args = parse_args()
metrics = train_multimodal_multitask(args)

# Same task again, with mish activation and --dense_blocks 64 64 64
sys.argv = ['train',
            '--tensors', HD5_FOLDER,
            '--input_tensors', 'mnist.mnist_image',
            '--output_tensors', 'mnist.mnist_label',
            '--activation', 'mish',
            '--dense_blocks', '64', '64', '64',
            '--batch_size', '16',
            '--epochs', '12',
            '--output_folder', OUTPUT_FOLDER,
            '--id', 'learn_2d_cnn2'
           ]
args = parse_args()
metrics = train_multimodal_multitask(args)
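Why assign to `sys.argv`? `argparse.ArgumentParser.parse_args()` called with no arguments reads `sys.argv[1:]`, so overwriting `sys.argv` lets a notebook cell mimic a command line (in a Jupyter kernel, `sys.argv` otherwise holds the kernel's own launch arguments). We assume ML4H's `parse_args()` follows this standard pattern; the miniature parser below is hypothetical and defines only a few flags with made-up defaults. Under that assumption, it also shows why the first list element (`'train'` above) sets no option: it is treated as the program name.

```python
import argparse
import sys


def toy_parse_args() -> argparse.Namespace:
    """Hypothetical miniature parser; not ML4H's real argument list or defaults."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--dense_blocks', nargs='*', type=int, default=[32, 32])
    parser.add_argument('--id', default='no_id')
    # Called with no list, parse_args() reads sys.argv[1:]; sys.argv[0] is the program name.
    return parser.parse_args()


saved_argv = sys.argv
try:
    sys.argv = ['train', '--dense_blocks', '64', '64', '64', '--batch_size', '16', '--id', 'demo']
    args = toy_parse_args()
finally:
    sys.argv = saved_argv  # restore the real arguments

print(args.dense_blocks, args.batch_size, args.epochs, args.id)  # [64, 64, 64] 16 10 demo
```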

### Compare Models that have been trained for the same task (i.e., with the same output TensorMap)

In [None]:
sys.argv = ['compare_scalar', 
            '--tensors', HD5_FOLDER, 
            '--input_tensors', 'mnist.mnist_image',
            '--output_tensors', 'mnist.mnist_label',
            '--id', 'mnist_model_comparison',
            '--output_folder', OUTPUT_FOLDER,
            '--model_files', f'{OUTPUT_FOLDER}learn_2d_cnn/learn_2d_cnn.h5',
                            f'{OUTPUT_FOLDER}learn_2d_cnn2/learn_2d_cnn2.h5',
            '--test_steps', '100', 
            '--batch_size', '16',
           ]
args = parse_args()

generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)
compare_multimodal_scalar_task_models(args)
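Comparing models trained for the same task is only meaningful when both are scored on the same held-out examples against the same output TensorMap. As a rough illustration of that idea (not of what `compare_multimodal_scalar_task_models` computes or plots), the sketch below scores two models' hypothetical softmax outputs by classification accuracy.

```python
import numpy as np


def accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of rows where the predicted class (argmax) matches the one-hot truth."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))


# Hypothetical one-hot truth for 4 examples over 3 classes, shared by both models.
y_true = np.eye(3)[[0, 1, 2, 1]]
# Hypothetical softmax outputs from two models on those same examples.
model_a = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7], [0.6, 0.3, 0.1]])
model_b = np.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.2, 0.7, 0.1]])

scores = {name: accuracy(y_true, preds) for name, preds in [('model_a', model_a), ('model_b', model_b)]}
print(scores)  # {'model_a': 0.75, 'model_b': 1.0}
```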

## Custom Architectures
The default architecture produced by the ModelFactory is based on the [DenseNet](https://arxiv.org/abs/1608.06993) Convolutional Neural Network. It is extremely customizable, as shown below.

In [None]:
sys.argv = ['train', 
            '--tensors', HD5_FOLDER, 
            '--input_tensors', 'mnist.mnist_image',
            '--output_tensors', 'mnist.mnist_label',
            '--output_folder', OUTPUT_FOLDER,
            '--activation', 'swish',
            '--conv_layers', '32',
            '--conv_width', '32', '32', '32',
            '--dense_blocks', '32', '24', '16',
            '--dense_layers', '32', '32',
            '--block_size', '4',
            '--pool_x', '2',
            '--pool_y', '2',
            '--inspect_model',
            '--epochs', '1',
            '--batch_size', '4',
            '--id', 'hypertuned_2d',
           ]
args = parse_args()
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)
train_multimodal_multitask(args)

After running the cell above the diagram of the model architecture will be saved at: `./outputs/hypertuned_2d/architecture_graph_hypertuned_2d.png`
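The defining idea of DenseNet is dense connectivity: within a block, each layer receives the concatenation of the block input and the outputs of every earlier layer, and adds a fixed number of new channels (the growth rate). The toy NumPy sketch below shows only that channel bookkeeping. A random 1x1 projection plus ReLU stands in for the real convolution and normalization layers, `dense_block` is a hypothetical name, and how ML4H flags such as `--block_size` or `--dense_blocks` map onto these quantities is an assumption we have not verified.

```python
import numpy as np


def dense_block(x: np.ndarray, num_layers: int, growth_rate: int,
                rng: np.random.Generator) -> np.ndarray:
    """Toy DenseNet-style block on a (height, width, channels) array."""
    features = [x]
    for _ in range(num_layers):
        stacked = np.concatenate(features, axis=-1)  # dense connectivity: all earlier maps
        weights = rng.standard_normal((stacked.shape[-1], growth_rate))
        new = np.maximum(stacked @ weights, 0.0)     # 1x1 "conv" + ReLU stand-in
        features.append(new)
    return np.concatenate(features, axis=-1)


rng = np.random.default_rng(0)
image = rng.standard_normal((28, 28, 1))
out = dense_block(image, num_layers=4, growth_rate=32, rng=rng)
print(out.shape)  # (28, 28, 129): 1 input channel + 4 layers * 32 new channels
```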