# Introduction to ML using ML4H

## Prerequisites
- Basic comfort with Python, some linear algebra, and some data science
- Follow the instructions in the main [README](https://github.com/broadinstitute/ml4h) for installing ML4H
- Data used in this notebook is available here ([gs://fc-500bd872-4a53-45c9-87d3-39656bd83f85/data/hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz](gs://fc-500bd872-4a53-45c9-87d3-39656bd83f85/data/hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz))
- Now we are ready to teach the machines!

In [None]:
# Imports
import os
import sys
import pickle
import random
from typing import List, Dict, Callable
from collections import defaultdict, Counter

import h5py
import numpy as np


from ml4h.defines import StorageType
from ml4h.arguments import parse_args
from ml4h.TensorMap import TensorMap, Interpretation
from ml4h.tensor_generators import test_train_valid_tensor_generators
from ml4h.models.train import train_model_from_generators
from ml4h.models.legacy_models import make_multimodal_multitask_model
from ml4h.models.inspect import plot_and_time_model
from ml4h.recipes import compare_multimodal_scalar_task_models, train_multimodal_multitask

%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec

In [None]:
# Constants
HD5_FOLDER = './tensors/'
OUTPUT_FOLDER = './outputs/'

# Python features used heavily in this notebook:
- f-strings
- Callback functions
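
As a quick refresher, the two features can be sketched as follows. The helper names `apply_to_all` and `halve` are made up for this illustration and are not part of ml4h; the path and channel-map expressions mirror patterns that appear later in this notebook.

```python
from typing import Callable, Dict, List


def apply_to_all(values: List[float], callback: Callable[[float], float]) -> List[float]:
    """Call `callback` on every value; the caller decides what the callback does."""
    return [callback(v) for v in values]


def halve(x: float) -> float:
    return x / 2


# f-string: expressions inside {} are evaluated and formatted in place.
model_id = 'learn_1d_cnn'
output_folder = './outputs/'
model_path = f'{output_folder}{model_id}/{model_id}.h5'

# f-strings also work inside comprehensions, which is handy for building label maps.
channel_map: Dict[str, int] = {f'digit_{i}': i for i in range(3)}

# A callback is just a function passed as an argument and called later.
halved = apply_to_all([2.0, 4.0, 8.0], halve)

print(model_path)
print(channel_map)
print(halved)
```

TensorMaps rely on the same callback idea: each one is handed a function that knows how to read its tensor from a file.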

## TensorMaps
The critical data structure in the ml4h codebase is the TensorMap.
This abstraction provides a way to translate ***any*** kind of input data into structured numeric tensors with clear semantics for interpretation and modeling. A TensorMap guarantees three things: a shape, a way to construct tensors of that shape from the HD5 files created during tensorization, and a meaning for the values in the tensor it yields.

For example, in the `mnist.py` file these TensorMaps are defined:

In [None]:
def mnist_image_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
    return np.array(hd5['mnist_image'])


mnist_image = TensorMap('mnist_image', shape=(28, 28, 1), tensor_from_file=mnist_image_from_hd5)


def mnist_label_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
    one_hot = np.zeros(tm.shape, dtype=np.float32)
    one_hot[int(hd5['mnist_label'][0])] = 1.0
    return one_hot


mnist_label = TensorMap(
    'mnist_label', Interpretation.CATEGORICAL, tensor_from_file=mnist_label_from_hd5,
    channel_map={f'digit_{i}': i for i in range(10)},
)
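
To make the callback pattern above concrete without installing anything, here is a minimal sketch. `MiniTensorMap` is a hypothetical, heavily simplified stand-in written for this illustration, not ml4h's `TensorMap`, and a plain dict plays the role of the opened HD5 file.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class MiniTensorMap:
    """Hypothetical, simplified stand-in for ml4h's TensorMap (illustration only)."""
    name: str
    shape: Tuple[int, ...]
    channel_map: Dict[str, int]
    tensor_from_file: Callable[['MiniTensorMap', Dict], List[float]]


def label_from_file(tm: MiniTensorMap, fake_hd5: Dict) -> List[float]:
    # Build a one-hot vector whose length comes from the map's declared shape.
    one_hot = [0.0] * tm.shape[0]
    one_hot[int(fake_hd5[tm.name])] = 1.0
    return one_hot


mini_label = MiniTensorMap(
    name='mnist_label',
    shape=(10,),
    channel_map={f'digit_{i}': i for i in range(10)},
    tensor_from_file=label_from_file,
)

# A plain dict stands in for an opened HD5 file holding the digit 3.
fake_hd5 = {'mnist_label': 3}
tensor = mini_label.tensor_from_file(mini_label, fake_hd5)
print(tensor)
```

The map carries the name, shape, and meaning of the channels, while the callback carries the file-reading logic, so either can change without touching the other.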

Similarly, in the `gatk.py` file we define tensors to encode data about genomic variants. Specifically, we create 3 TensorMaps. `reference` is a 1-hot encoded 128 base-pair window of DNA sequence. `read_tensor` is an alignment of as many as 128 different DNA reads overlapping a 128 base-pair window of reference DNA; it includes 15 channels, which encode the DNA bases from the reference and from the read sequence, plus metadata belonging to each read. Lastly, we define the `CATEGORICAL` TensorMap `variant_label`, which encodes the truth status of this particular genomic variant. In this dataset we consider only SNPs and small insertions or deletions, giving us the 4 labels: `'NOT_SNP', 'NOT_INDEL', 'SNP', 'INDEL'`.

In [None]:
DNA_SYMBOLS = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
VARIANT_LABELS = {'NOT_SNP': 0, 'NOT_INDEL': 1, 'SNP': 2, 'INDEL': 3}


def tensor_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
    return np.array(hd5[tm.name])


reference = TensorMap('reference', shape=(128, len(DNA_SYMBOLS)), tensor_from_file=tensor_from_hd5)
read_tensor = TensorMap('read_tensor', shape=(128, 128, 15), tensor_from_file=tensor_from_hd5)


def variant_label_from_hd5(tm: TensorMap, hd5: h5py.File, dependents: Dict = {}) -> np.ndarray:
    one_hot = np.zeros(tm.shape, dtype=np.float32)
    variant_str = str(hd5['variant_label'][()], 'utf-8')
    for channel in tm.channel_map:
        if channel.lower() == variant_str.lower():
            one_hot[tm.channel_map[channel]] = 1.0
    if one_hot.sum() != 1:
        raise ValueError(f'TensorMap {tm.name} missing or invalid label: {variant_str} one_hot: {one_hot}')
    return one_hot


variant_label = TensorMap(
    'variant_label', Interpretation.CATEGORICAL,
    shape=(len(VARIANT_LABELS),),
    tensor_from_file=variant_label_from_hd5,
    channel_map=VARIANT_LABELS,
)
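
The 1-hot layout that `reference` stores can be sketched in plain Python. The function `one_hot_dna` is a hypothetical helper for illustration, and the handling of symbols outside `DNA_SYMBOLS` (an all-zero row) is an assumption of this sketch, not necessarily what the GATK tensorization does.

```python
from typing import List

DNA_SYMBOLS = {'A': 0, 'C': 1, 'G': 2, 'T': 3}


def one_hot_dna(sequence: str) -> List[List[float]]:
    """Return one row per base, with a 1.0 in the column given by DNA_SYMBOLS.

    Bases outside DNA_SYMBOLS (for example 'N') become all-zero rows; that is a
    choice made for this sketch only.
    """
    rows = []
    for base in sequence.upper():
        row = [0.0] * len(DNA_SYMBOLS)
        if base in DNA_SYMBOLS:
            row[DNA_SYMBOLS[base]] = 1.0
        rows.append(row)
    return rows


encoded = one_hot_dna('ACGTN')
for base, row in zip('ACGTN', encoded):
    print(base, row)
```

A 128-base window encoded this way has 128 rows of 4 columns, matching the `(128, len(DNA_SYMBOLS))` shape declared for `reference`.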

This is the type of data used by the GATK tool CNNScoreVariants to filter DNA sequencing data. The tensorization code is part of GATK, not ML4H, but tensorized data is available at: `gs://fc-500bd872-4a53-45c9-87d3-39656bd83f85/data/hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz`. Once the data has been localized, you can unpack the HD5 files into `HD5_FOLDER` with the cell below (assuming the tar.gz file is in the same directory as this notebook):

In [None]:
# Create the destination folder if needed, then unpack the archive into it.
os.makedirs(HD5_FOLDER, exist_ok=True)
!tar -zxvf ./hg002_na24385_ml4h_tensors_v2021_10_14.tar.gz -C {HD5_FOLDER}
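
After unpacking, it can be worth confirming that tensor files actually landed in the folder. The sketch below is a hypothetical helper, not part of ml4h, and it assumes the unpacked files end in `.hd5`; adjust the extension if the archive uses a different one.

```python
from pathlib import Path
from typing import List


def list_tensor_files(folder: str, extension: str = '.hd5') -> List[Path]:
    """Return the sorted files under `folder` (searched recursively) ending in `extension`."""
    root = Path(folder)
    if not root.is_dir():
        # A missing folder simply means nothing has been unpacked yet.
        return []
    return sorted(p for p in root.rglob(f'*{extension}') if p.is_file())


tensor_files = list_tensor_files('./tensors/')
print(f'Found {len(tensor_files)} tensor files')
```

A count of zero usually means the archive was not in the expected location or was extracted somewhere else.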

# The Model Factory
The function ***make_multimodal_multitask_model()*** takes lists of TensorMaps and connects them with intelligent goo. Specifically, given a list of input TensorMaps and a list of desired output TensorMaps, the model factory builds a model and loss appropriate for the dimensions and interpretations of the data at hand. Depending on the input and output TensorMaps provided, the Model Factory will build models for many different situations, including:
- Classification
- Regression
- Multitask
- Multimodal
- Multimodal Multitask
- Autoencoders
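
The idea of choosing an output head from a TensorMap's interpretation can be sketched as below. `OutputKind`, `HEAD_FOR_KIND`, and `describe_head` are all hypothetical names for this illustration, and the activation and loss pairings are conventional Keras-style defaults; the choices ml4h actually makes may differ.

```python
from enum import Enum, auto
from typing import Dict, Tuple


class OutputKind(Enum):
    """Hypothetical stand-in for the interpretation attached to an output TensorMap."""
    CATEGORICAL = auto()
    CONTINUOUS = auto()


# Conventional pairings, assumed for this sketch only.
HEAD_FOR_KIND: Dict[OutputKind, Tuple[str, str]] = {
    OutputKind.CATEGORICAL: ('softmax', 'categorical_crossentropy'),
    OutputKind.CONTINUOUS: ('linear', 'mean_squared_error'),
}


def describe_head(name: str, kind: OutputKind, shape: Tuple[int, ...]) -> str:
    """Summarize the output layer implied by an output's kind and shape."""
    activation, loss = HEAD_FOR_KIND[kind]
    return f'{name}: {shape[-1]} units, activation={activation}, loss={loss}'


print(describe_head('variant_label', OutputKind.CATEGORICAL, (4,)))
print(describe_head('some_measurement', OutputKind.CONTINUOUS, (1,)))
```

Because the decision depends only on the declared interpretation and shape, adding a new output to a model amounts to adding another TensorMap to the list.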



## 1D CNN for Classification of Genomic Variants
Jupyter is great, but it can complicate productionizing code. We try to mitigate this by interacting with the Jupyter notebook as if it were a command-line call to one of ml4h's modes.

In [None]:
sys.argv = ['train', 
            '--tensors', HD5_FOLDER, 
            '--input_tensors', 'gatk.reference',
            '--output_tensors', 'gatk.variant_label',
            '--batch_size', '16',
            '--epochs', '12',
            '--output_folder', OUTPUT_FOLDER,
            '--id', 'learn_1d_cnn'
           ]
args = parse_args()
metrics = train_multimodal_multitask(args)
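
The mechanism behind overwriting `sys.argv` can be seen with a small standalone parser. `parse_demo_args` is a hypothetical toy built on the standard `argparse` module; it is far smaller than ml4h's `parse_args` and only borrows a few flag names for illustration.

```python
import argparse
import sys


def parse_demo_args() -> argparse.Namespace:
    """Hypothetical toy parser (illustration only, not ml4h's parse_args)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_tensors', nargs='*', default=[])
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=1)
    # With no explicit argument list, argparse reads sys.argv[1:].
    return parser.parse_args()


# argparse treats element 0 as the program name and skips it,
# so in this sketch 'train' acts only as a label.
sys.argv = ['train', '--input_tensors', 'gatk.reference', '--batch_size', '16']
demo_args = parse_demo_args()
print(demo_args.input_tensors, demo_args.batch_size, demo_args.epochs)
```

Flags that are not supplied fall back to their defaults, which is why the cells in this notebook only list the options they want to change.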

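The same label can also be learned from the 2D `read_tensor` input; only `--input_tensors` and `--id` change:

In [None]: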
sys.argv = ['train', 
            '--tensors', HD5_FOLDER, 
            '--input_tensors', 'gatk.read_tensor',
            '--output_tensors', 'gatk.variant_label',
            '--batch_size', '16',
            '--epochs', '12',
            '--output_folder', OUTPUT_FOLDER,
            '--id', 'learn_2d_cnn'
           ]
args = parse_args()
metrics = train_multimodal_multitask(args)

### Compare models that have been trained for the same task (i.e., with the same output TensorMap)

In [None]:
sys.argv = ['compare_scalar', 
            '--tensors', HD5_FOLDER, 
            '--input_tensors', 'gatk.reference', 'gatk.read_tensor',
            '--output_tensors', 'gatk.variant_label',
            '--id', 'gatk_model_comparison',
            '--output_folder', OUTPUT_FOLDER,
            '--model_files', f'{OUTPUT_FOLDER}learn_1d_cnn/learn_1d_cnn.h5',
                             f'{OUTPUT_FOLDER}learn_2d_cnn/learn_2d_cnn.h5',
            '--test_steps', '100', 
            '--batch_size', '16',
           ]
args = parse_args()

generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)
compare_multimodal_scalar_task_models(args)

## Custom Architectures
The default architecture produced by the Model Factory is based on the [DenseNet](https://arxiv.org/abs/1608.06993) Convolutional Neural Network. It is extremely customizable, as shown below.

In [None]:
sys.argv = ['train', 
            '--tensors', HD5_FOLDER, 
            '--input_tensors', 'gatk.reference',
            '--output_tensors', 'gatk.variant_label',
            '--output_folder', OUTPUT_FOLDER,
            '--activation', 'swish',
            '--conv_layers', '32',
            '--conv_width', '32', '32', '32',
            '--dense_blocks', '32', '24', '16',
            '--dense_layers', '32',  '32', 
            '--block_size', '4',
            '--pool_x', '2',
            '--pool_y', '2',
            '--inspect_model',
            '--epochs', '1',
            '--batch_size', '4',
            '--id', 'hypertuned_1d',
           ]
args = parse_args()
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(**args.__dict__)
train_multimodal_multitask(args)

After running the cell above, the diagram of the model architecture will be saved at: `./outputs/hypertuned_1d/architecture_graph_hypertuned_1d.png`
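
When reading that diagram, it helps to recall how channel counts grow inside a DenseNet-style block: each layer's output is concatenated with everything before it. The arithmetic below follows the DenseNet paper; the function names are hypothetical, and plugging in 32 filters with 4 layers per block is only an illustrative assumption about how the flags above map onto a block, since ml4h's exact wiring may differ.

```python
from typing import List


def dense_block_input_channels(initial_channels: int, growth_rate: int, num_layers: int) -> List[int]:
    """Input channel count seen by each layer of one DenseNet-style block.

    Every layer emits `growth_rate` feature maps, and its output is concatenated
    with everything that came before it in the block.
    """
    return [initial_channels + growth_rate * layer for layer in range(num_layers)]


def dense_block_output_channels(initial_channels: int, growth_rate: int, num_layers: int) -> int:
    """Channel count after the last concatenation in the block."""
    return initial_channels + growth_rate * num_layers


print(dense_block_input_channels(32, 32, 4))
print(dense_block_output_channels(32, 32, 4))
```

The steady growth in channels is why DenseNet-style models usually place pooling or a narrower convolution between blocks.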