<a href="https://colab.research.google.com/github/broadinstitute/igv-experiments/blob/main/igv-dash/IGV_Dash_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interactive IGV Technology Demo

## This notebook demonstrates how to use [Dash](https://plotly.com/dash/) to create a plot that will drive the current position within an embedded instance of [IGV](https://igv.org/).

In [None]:
#@markdown # Step One: Build and Configure
#@markdown This demo builds its own copy of the Python package
#@markdown [cyvcf2](https://github.com/brentp/cyvcf2), so this initialization cell can take up to
#@markdown two minutes to complete.

def install_cyvcf2(path_lib='./lib'):
    """Build cyvcf2 from source (once) and add it to ``sys.path``.

    The build is skipped when ``<path_lib>/cyvcf2`` already exists. Output of
    the build script (stdout and stderr combined) is written line by line to
    ``cyvcf2-build.log`` in the current working directory.

    Args:
        path_lib: Destination handed to ``cp -r`` for the built checkout; the
            directory ``<path_lib>/cyvcf2`` is what gets added to ``sys.path``.

    Raises:
        RuntimeError: If the build script exits with a non-zero status.
    """
    import sys, subprocess, os
    # `set -e` makes the script stop at the first failing command, so the
    # exit status checked below reflects the first failure.
    cyvcf_install_script = f"""
    set -e
    package=cyvcf2
    cwd=$(pwd)
    cd /content
    rm -rf $package
    git clone --recursive https://github.com/brentp/$package
    apt-get install -y autoconf
    cd $package/htslib
    autoheader
    autoconf
    ./configure --enable-gcs --enable-libcurl
    make 
    cd .. 
    pip install -r requirements.txt 
    CYTHONIZE=1 pip install -e .
    cd ..
    cp -r $package {path_lib}
    rm -rf $package
    cd $cwd
    """

    cyvcf2_path = f"{path_lib}/cyvcf2"
    if not os.path.exists(cyvcf2_path):
        print("Building python cyvcf2")
        # Popen's context manager closes the pipe and waits for the process,
        # so `returncode` is populated once the `with` block exits.
        with open('cyvcf2-build.log', 'w') as fh, subprocess.Popen(
            cyvcf_install_script,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        ) as build_proc:
            for raw_line in build_proc.stdout:
                # Keep one log line per output line (re-add the newline that
                # stripping removes) and tolerate undecodable bytes.
                fh.write(raw_line.decode(errors='replace').rstrip() + '\n')
        if build_proc.returncode != 0:
            raise RuntimeError(
                f"cyvcf2 build failed with exit status {build_proc.returncode}; "
                "see cyvcf2-build.log"
            )

    # Re-running the cell should not pile up duplicate entries.
    if cyvcf2_path not in sys.path:
        sys.path.append(cyvcf2_path)
    
def configure_dash_proxy(port=8050):
    """Tell jupyter-dash which proxied URL serves ``port`` on this notebook host.

    A Terra notebook is recognised by the GOOGLE_PROJECT and RUNTIME_NAME
    environment variables both being set. Otherwise Google Colab is tried by
    importing ``google.colab``. If neither environment is detected, nothing
    is configured.

    Args:
        port: Port the Dash app server will listen on.
    """
    import os

    project = os.environ.get("GOOGLE_PROJECT")
    runtime = os.environ.get("RUNTIME_NAME")

    proxy_url, subpath = None, None
    if project and runtime:
        # Terra Notebooks
        subpath = f'/notebooks/{project}/{runtime}/'
        proxy_url = f'https://notebooks.firecloud.org/proxy/{project}/{runtime}/jupyter/proxy/{port}/'
    else:
        # Google Colab
        try:
            from google.colab.output import eval_js
            proxy_url = eval_js(f"google.colab.kernel.proxyPort({port})")
            subpath = '/'
        except ImportError:
            # Not running under Colab either; leave jupyter-dash untouched.
            pass

    if not (proxy_url and subpath):
        return

    from jupyter_dash.comms import _jupyter_config
    from jupyter_dash import JupyterDash

    # NOTE(review): the computed subpath only gates this branch; the config
    # below always receives '/' as base_subpath — confirm this is intended
    # for Terra.
    _jupyter_config.update(dict(
        type='base_url_response',
        server_url=proxy_url,
        base_subpath='/',
        frontend='notebook',
    ))
    JupyterDash.default_server_url = proxy_url

def init_dash(port=8050):
    from importlib.util import find_spec
    if find_spec('jupyter_dash') is None:
        !pip install jupyter-dash==0.4.2 dash-bio >/dev/null 2>&1
        #!pip install jupyter-dash==0.4.1 dash-bio pyyaml==5.4.1 >/dev/null 2>&1
    configure_dash_proxy(port=port)
    # squelch confusing log messages from app server
    import logging
    logging.getLogger('werkzeug').setLevel(logging.ERROR)
    print("jupyter-dash configured")
    return True
    
# Run the one-time setup for this notebook: build cyvcf2 (the function skips
# the build when its output directory already exists), then set up jupyter-dash.
install_cyvcf2()
# init_dash() returns True; binding it to `_` presumably keeps the value from
# being echoed as the cell's output.
_ = init_dash()

In [None]:
#@markdown # Step Two: Sample VCF Data
#@markdown Before we can show an interactive plot, we need data first. This cell will sample 
#@markdown thousands of sites from a VCF for HG001 from GIAB, and cache the data 
#@markdown into a pandas data frame.

import re
import random
import pandas as pd
from cyvcf2 import VCF

class VariantSampler:
    """Yield runs of consecutive VCF records that start at random positions."""

    # Contig names of the form chr<one or two digits>, e.g. chr1 .. chr22.
    # Names such as chrX, chrY or chrM do not match and are never sampled
    # when the sequence map is derived from the VCF header.
    _NUMBERED_CONTIG = re.compile(r"^chr\d{1,2}$")

    @classmethod
    def _numbered_contigs(cls, seqnames, seqlens):
        """Return ``(name, length)`` pairs for contigs named ``chr<digits>``.

        Names and lengths are paired *before* filtering so that each kept
        name stays attached to its own length.
        """
        return [
            (name, length)
            for name, length in zip(seqnames, seqlens)
            if cls._NUMBERED_CONTIG.match(name)
        ]

    def random_region(self, vcf, seqmap):
        """Query ``vcf`` at a random position on a randomly chosen contig.

        Args:
            vcf: Callable taking a ``"<contig>:<pos>"`` string (a cyvcf2
                ``VCF`` object in this notebook).
            seqmap: Non-empty sequence of ``(contig_name, contig_length)``.

        Returns:
            Whatever ``vcf`` returns for the chosen coordinate.
        """
        seqname, seqlen = random.choice(seqmap)
        # Positions are 1-based, so never ask for position 0.
        pos = random.randint(1, max(1, seqlen))
        return vcf(f"{seqname}:{pos}")

    def sample(self, vcf_path=None, n_sites=100, n_samples=100, seqmap=None):
        """Generate ``n_sites * n_samples`` records from the VCF at ``vcf_path``.

        For each of ``n_sites`` random starting points, records are read in
        order until ``n_samples`` have been yielded. If a region runs out of
        records first, another random region is drawn to make up the rest.

        Args:
            vcf_path: Path or URL of the VCF file (required).
            n_sites: Number of random starting points.
            n_samples: Number of records to yield per starting point.
            seqmap: Optional sequence of ``(contig_name, contig_length)``
                pairs to sample from. Defaults to the numbered ``chr``
                contigs listed in the VCF header.

        Yields:
            Records produced by iterating the queried regions.

        Raises:
            ValueError: If ``vcf_path`` is missing or there are no contigs
                to sample from.
        """
        if vcf_path is None:
            raise ValueError("vcf_path is required")
        vcf_fn = vcf_path.split('/')[-1]
        print(f"Sampling {n_sites * n_samples} variants from {vcf_fn}")
        vcf = VCF(vcf_path)
        if seqmap is None:
            seqmap = self._numbered_contigs(vcf.seqnames, vcf.seqlens)
        if not seqmap:
            raise ValueError("no contigs available to sample from")

        for _ in range(n_sites):
            region = self.random_region(vcf, seqmap)
            sample_cnt = 0
            # NOTE: this keeps redrawing regions until enough records are
            # found, so it does not terminate for a VCF without records.
            while sample_cnt < n_samples:
                try:
                    rec = next(region)
                except StopIteration:
                    # Region exhausted: jump to a fresh random location.
                    region = self.random_region(vcf, seqmap)
                    continue
                yield rec
                sample_cnt += 1

def vcf_to_dict(itr):
    """Flatten VCF records into plain dicts, one per record.

    Each yielded dict holds, in this order (later keys overwrite earlier
    ones with the same name):

    * the fixed columns ``chr``, ``pos``, ``ref``, ``alt`` and ``filter``;
    * every INFO entry, with string values split on ``,`` into lists;
    * the FORMAT fields ``DP``, ``GQ``, ``AD`` and ``ADALL``, each reduced
      to element ``[0][0]`` of what ``rec.format(name)`` returns
      (presumably the first value for the first sample — confirm against
      the cyvcf2 documentation).

    Args:
        itr: Iterable of records exposing ``CHROM``, ``POS``, ``REF``,
            ``ALT``, ``FILTER``, ``INFO`` and ``format(name)``.

    Yields:
        dict: The flattened record.
    """
    format_fields = ('DP', 'GQ', 'AD', 'ADALL')
    for rec in itr:
        row = {
            'chr': rec.CHROM,
            'pos': rec.POS,
            'ref': rec.REF,
            'alt': rec.ALT,
            'filter': rec.FILTER,
        }
        for key, val in dict(rec.INFO).items():
            row[key] = val.split(',') if isinstance(val, str) else val
        for fn in format_fields:
            row[fn] = rec.format(fn)[0][0]
        yield row

def sample_vcf_as_dataframe(vcf_path=None, n_sites=50, n_samples=100):
    """Sample records from a VCF and return them as a pandas DataFrame.

    Missing values are replaced with 0, rows whose ``filter`` is 0 after
    that replacement are labelled ``'PASS'``, and two helper columns are
    added: ``locus`` (``"<chr>:<pos>"``) and ``_index`` (a copy of the
    frame's index).

    Args:
        vcf_path: Path or URL of the VCF to sample.
        n_sites: Number of random starting points.
        n_samples: Number of records taken per starting point.

    Returns:
        pandas.DataFrame with one row per sampled record.
    """
    sampler = VariantSampler()
    records = sampler.sample(vcf_path, n_sites=n_sites, n_samples=n_samples)
    frame = pd.DataFrame(list(vcf_to_dict(records))).fillna(0)

    # A filter that was missing has just been filled with 0; relabel it.
    unfiltered = frame['filter'] == 0
    frame.loc[unfiltered, 'filter'] = 'PASS'

    frame['locus'] = list(map('{}:{}'.format, frame['chr'], frame['pos']))
    frame['_index'] = frame.index
    return frame

# Variant calls to sample from; the index is expected next to the VCF with a
# '.tbi' suffix appended to the same URL.
vcf_path = "https://ftp-trace.ncbi.nlm.nih.gov/ReferenceSamples/giab/release/NA12878_HG001/latest/GRCh38/SupplementaryFiles/HG001_GRCh38_1_22_v4.2.1_all.vcf.gz"
vcf_index_path = vcf_path + '.tbi'

# Alignments (CRAM plus its .crai index); in this notebook they are only used
# as an IGV track, not read by the sampling code.
cram_path = "gs://broad-public-datasets/NA12878/NA12878.cram"
cram_index_path = "gs://broad-public-datasets/NA12878/NA12878.cram.crai"

# 10 random starting points x 100 records each = 1000 sampled variants.
vcf_data = sample_vcf_as_dataframe(vcf_path=vcf_path, n_sites=10, n_samples=100)

In [None]:
#@markdown # Step Three: Explore Variant Data with Plotly and IGV
#@markdown Now that we have our sampled variant data, let's explore it. We will create a scatter plot of all the variants, with the X axis representing variant quality (GQ) and the Y axis representing variant diversity within the larger cohort (ADALL). The color of the marker indicates the variant's filter status. Once the graph loads, click on any marker to focus IGV on the variant in question. On the right side of the scatter plot is a data pane that shows tabular information.
#@markdown This cell will output a URL that points to your plotly application.

import logging
import json
import dash
import dash_bio
import plotly.express as px
from jupyter_dash import JupyterDash
from dash import dcc, html
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate

# Track definitions handed to the IGV component (as its `tracks` argument)
# every time the viewer is rebuilt.
igv_tracks = [
    # The same VCF the scatter plot data was sampled from.
    {
        "name": "HG001 Variants",
        "url": vcf_path,
        "indexURL": vcf_index_path,
        "format": "vcf",
        "displayMode": "SQUISHED",
    },
    # Read alignments from the CRAM file configured earlier in the notebook.
    {
        "name": "HG001 Alignments",
        "url": cram_path,
        "indexURL": cram_index_path,
        "format": "cram",
        "displayMode": "SQUISHED",
    },
    # Annotation track of common SNPs for hg38.
    {
        "name": "Common SNPs (150)",
        "type": "snp",
        "format": "snp",
        "url": "https://s3.dualstack.us-east-1.amazonaws.com/igv.org.genomes/hg38/annotations/snp150Common.txt.gz",
        "indexURL": "https://s3.dualstack.us-east-1.amazonaws.com/igv.org.genomes/hg38/annotations/snp150Common.txt.gz.tbi",
        "visibilityWindow": 100000,
    },
]


# Build App
# Layout: a heading, then a one-row table with the scatter plot on the left
# and a text pane for the clicked record on the right, then the IGV viewer.
app = JupyterDash(__name__)
app.layout = html.Div([
    html.H1("VCF IGV Scatter Demo"),
    html.H2("To interact with the demo, click on a marker within the scatter plot."),
    html.Table([
        html.Tr([
            html.Td([
                # Filled in by the `update_figure` callback.
                dcc.Loading(type='default', id='graph-loading', children=dcc.Graph(id='graph-output')),
            ], style={'width': '1024px'}),
            html.Td([
                # Filled in by the `click_data` callback.
                html.Pre(id="where")
            ]),
        ]),
    ]),
    # Filled in by the `return_igv` callback.
    dcc.Loading(type='default', id='igv-loading', children=html.Div(id='igv-output')),
], style={"backgroundColor": "white"})  # Dash style dicts use camelCased CSS property names

# Define callback to update graph
@app.callback(
    Output('graph-output', 'figure'),
    Input('graph-loading', 'children')
)
def update_figure(children):
    """Build the scatter plot of the sampled variants.

    Args:
        children: Value of the triggering input; it is not used.

    Returns:
        A plotly express scatter figure of ``vcf_data`` with GQ on the x
        axis, ADALL on the y axis and as marker size, and colour by filter.
        Each point carries its ``_index`` value as custom data so click
        handlers can look the row up again.
    """
    scatter_options = dict(
        x="GQ",
        y="ADALL",
        color="filter",
        size="ADALL",
        custom_data=["_index"],
        render_mode="webgl",
        title="HG001: ADALL vs GQ",
    )
    figure = px.scatter(vcf_data, **scatter_options)
    return figure

@app.callback(
    Output('igv-output', 'children'),
    Input('graph-output', 'clickData')
)
def return_igv(clickData):
    """Rebuild the IGV viewer, focused on the clicked variant when there is one.

    Args:
        clickData: Click payload from the scatter plot; falsy when nothing
            has been clicked yet.

    Returns:
        An ``html.Div`` wrapping a ``dash_bio.Igv`` component for hg38 with
        the notebook's tracks. The component's ``locus`` is set only when
        a point was clicked.
    """
    igv_config = {
        'id': 'igv-chart',
        'genome': 'hg38',
        'tracks': igv_tracks,
    }

    if clickData:
        # The first custom-data entry of the clicked point is the row's `_index`.
        row_label = clickData['points'][0]['customdata'][0]
        igv_config["locus"] = vcf_data.loc[row_label]['locus']

    return html.Div([dash_bio.Igv(**igv_config)])
    
@app.callback(
    Output('where', 'children'),
    Input('graph-output', 'clickData')
)
def click_data(clickData):
    """Return the text shown in the data pane for the clicked variant.

    Args:
        clickData: Click payload from the scatter plot; falsy when nothing
            has been clicked yet.

    Returns:
        The ``repr`` of the clicked point's row in ``vcf_data``, or an
        empty string when there is no click.
    """
    if clickData:
        # The first custom-data entry of the clicked point is the row's `_index`.
        row_label = clickData['points'][0]['customdata'][0]
        return repr(vcf_data.loc[row_label])
    return ''

# Start the app server. With mode 'external' the app is not embedded in the
# cell; per the cell description above, a URL pointing to the app is output.
# The single `debug` flag also switches the dev-tools UI and dev bundles.
debug = False
mode = 'external'
_ = app.run_server(debug=debug, mode=mode, dev_tools_ui=debug, dev_tools_serve_dev_bundles=debug)