In [1]:
from mutagene.profiles import Sample
from mutagene.signatures import Identify
from mutagene.signatures.constants import sig_by_etiology
from scipy.stats import gaussian_kde
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
from glob import glob
%matplotlib tk

## Correlated class
Decomposes a profile using four signature sets (5, 10, 30 and 49 signatures) and lets you compare the results by etiology.

In [2]:
class Correlated:
    """Decompose one profile with every signature set and group the hits by etiology.

    Attributes:
        profile: the profile passed to the constructor.
        decompositions: one ``Identify`` result per entry of ``SIG_SETS``, in the same order.
        results: one dict per entry of ``sig_by_etiology`` that lists at least one
            signature for set 10, shaped as
            ``{'etiology': ..., 'sig_set': {set_num: [{'sig', 'score', 'mutations'}, ...]}}``.
            A set number is only present when the etiology lists a signature for it.
    """

    # Signature sets handed to Identify; self.decompositions follows this order.
    SIG_SETS = (5, 10, 30, 49)
    # Bar colour used by bar_chart() for each entry of SIG_SETS.
    _BAR_COLOURS = ('#247BA0', '#70C1B3', '#B2DBBF', '#E9C46A')

    def __init__(self, profile):
        """Run the decompositions for `profile` and build the per-etiology summary."""
        self.profile = profile

        self.decompositions = [
            Identify(self.profile, sig_set, bootstrap=False, dummy_sigs=True)
            for sig_set in self.SIG_SETS
        ]

        self.results = []
        for corr in sig_by_etiology:
            # Etiologies without a signature in set 10 are not reported at all.
            if not corr['sig'][10]:
                continue
            entry = {'etiology': corr['etiology'], 'sig_set': {}}
            for sig_set, decomp in zip(self.SIG_SETS, self.decompositions):
                for sig in corr['sig'][sig_set]:
                    # First decomposition entry for this signature, if it was reported.
                    hit = next((d for d in decomp.decomposition if d['name'] == sig), None)
                    entry['sig_set'].setdefault(sig_set, []).append({
                        'sig': sig,
                        'score': hit['score'] if hit is not None else 0,
                        'mutations': hit['mutations'] if hit is not None else 0,
                    })
            self.results.append(entry)

    def _paired_totals(self, set_a, set_b):
        """Return labels, signature names and mutation totals for etiologies present in both sets.

        Returns a tuple ``(labels, names_a, names_b, totals_a, totals_b)`` of
        equally long lists, one element per etiology that has both sets.
        """
        labels, names_a, names_b, totals_a, totals_b = [], [], [], [], []
        for res in self.results:
            sets = res['sig_set']
            if set_a not in sets or set_b not in sets:
                continue
            labels.append(res['etiology'])
            names_a.append(' '.join(sig['sig'] for sig in sets[set_a]))
            names_b.append(' '.join(sig['sig'] for sig in sets[set_b]))
            totals_a.append(sum(sig['mutations'] for sig in sets[set_a]))
            totals_b.append(sum(sig['mutations'] for sig in sets[set_b]))
        return labels, names_a, names_b, totals_a, totals_b

    def scatter_plot(self, set_a, set_b, title=None):
        """Scatter the mutation totals of `set_a` (x) against `set_b` (y), one point per etiology.

        Arguments:
        `set_a`, `set_b` - signature set numbers to compare
        `title` - optional text appended to the plot title (e.g. the sample name)

        Raises:
        `ValueError` - if no etiology has signatures in both sets
        """
        labels, xticks, yticks, x, y = self._paired_totals(set_a, set_b)
        if not labels:
            raise ValueError(f"No etiology has signatures in both set {set_a} and set {set_b}")

        # adjustText only de-overlaps the point labels; the plot is still usable without it.
        try:
            from adjustText import adjust_text
        except ImportError:
            adjust_text = None

        max_val = max(max(x), max(y))
        plt.plot([0, max_val], [0, max_val], '-g')  # y = x reference line
        plt.scatter(x, y)
        plt.xticks(x, xticks)
        plt.xticks(rotation=90)
        plt.yticks(y, yticks)
        plt.ylabel('Signature b')
        plt.xlabel('Signature a')
        plt.title('Scatter Plot' if title is None else 'Scatter Plot: {}'.format(title))
        texts = [
            plt.text(xi, yi, f'{label} ({int(round(xi))}, {int(round(yi))})')
            for label, xi, yi in zip(labels, x, y)
        ]
        if adjust_text is not None:
            adjust_text(texts)
        plt.show()

    def _get_plot_data(self, sig_set):
        """Return the summed mutations of `sig_set` per etiology (0 where the set is absent)."""
        totals = []
        for res in self.results:
            if sig_set in res['sig_set']:
                totals.append(sum(sig['mutations'] for sig in res['sig_set'][sig_set]))
            else:
                totals.append(0)
        return totals

    def bar_chart(self):
        """Draw grouped horizontal bars: mutations per etiology for every signature set."""
        # Each etiology occupies one row per signature set plus one empty spacer row.
        stride = len(self.SIG_SETS) + 1
        for i, (sig_set, color) in enumerate(zip(self.SIG_SETS, self._BAR_COLOURS)):
            ind = range(i, len(self.results) * stride, stride)
            y = self._get_plot_data(sig_set)
            plt.barh(
                ind, y,
                color=color,
                label=sig_set
            )
            # Only non-zero bars get a value label.
            for pos, value in zip(ind, y):
                plt.annotate(value if value > 0 else '', (value + 1, pos), va='center')

        # Signature names drawn left of the axis, one per bar row ('' for the spacer).
        ticks = []
        for res in self.results:
            for sig_set in self.SIG_SETS:
                if sig_set in res['sig_set']:
                    ticks.append(' '.join(sig['sig'] for sig in res['sig_set'][sig_set]))
                else:
                    ticks.append('')
            ticks.append('')
        for row, tick in enumerate(ticks):
            plt.annotate(tick, (-1, row), ha='right', va='center')

        labels = [res['etiology'] for res in self.results]
        plt.yticks(range(3, len(self.results) * stride, stride), labels)
        plt.legend()
        plt.ylabel('Etiology')
        plt.xlabel('Number of Mutations')
        plt.title('Etiology vs mutations over all signature sets')
        plt.tick_params(left=False)
        plt.show()

    def stacked_bars(self, set_a, set_b):
        """Draw, per etiology, the share of mutations attributed by `set_a` versus `set_b`.

        Raises:
        `ValueError` - if no etiology has signatures in both sets
        """
        labels, ticks_a, ticks_b, y_a, y_b = self._paired_totals(set_a, set_b)
        if not labels:
            raise ValueError(f"No etiology has signatures in both set {set_a} and set {set_b}")

        def norm(a, b):
            # Split the bar evenly when neither set attributes any mutation.
            s = a + b
            if s == 0:
                return 0.5, 0.5
            return a / s, b / s

        norm_a, norm_b = zip(*map(norm, y_a, y_b))

        ind = range(0, len(labels) * 3, 3)
        width = 1
        bars_a = plt.barh(ind, norm_a, width, color='#FFCC7A')
        bars_b = plt.barh(ind, norm_b, width, left=norm_a, color='#819BC1')
        plt.legend((bars_a[0], bars_b[0]), (set_a, set_b))
        plt.plot([0.5 for _ in labels], ind, '|w')  # 50 % marker on every bar
        for i in range(len(labels)):
            plt.annotate(labels[i], (0.5, ind[i]+0.8), ha='center')
            plt.annotate(f"{y_a[i]} mut(s)", (norm_a[i]-0.01, ind[i]), ha='right', va='center')
            plt.annotate(f"{y_b[i]} mut(s)", (norm_a[i]+0.01, ind[i]), ha='left', va='center')
            plt.annotate(ticks_a[i], (0, ind[i]-0.8), ha='left', va='top', color='#616161')
            plt.annotate(ticks_b[i], (1, ind[i]-0.8), ha='right', va='top', color='#616161')
        plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
        plt.show()

## BenchmarkCorrelated class
Analyses how many samples are decomposed by etiology and signature set

In [3]:
class BenchmarkCorrelated:
    """Aggregate the `Correlated` results of many samples by etiology and signature set."""

    # One entry per signature set; 'index' is the position of that set in
    # Correlated.decompositions, 'num' the set number, 'name' the display name.
    sig_sets = [
        {'index': 0,
         'num': 5,
         'colour': '#3BCEAC',
         'name': 'MutaGene 5'},
        {'index': 1,
         'num': 10,
         'colour': '#FFD23F',
         'name': 'MutaGene 10'},
        {'index': 2,
         'num': 30,
         'colour': '#EE4266',
         'name': 'COSMIC V2'},
        {'index': 3,
         'num': 49,
         'colour': '#540D6E',
         'name': 'COSMIC V3'}
    ]
    # Scatter colours used by compare(), indexed by etiology position.
    eti_colours = ['#ff0000', '#ff8400', '#51ff00', '#00fff7', '#0000ff', '#b700ff', '#ff0090', '#99D588', '#00ab25', '#ffdd00', '#e1ff00', '#ab0000', '#aba200', '#7002a3']

    def __init__(self, samples):
        """
        Benchmarks Correlated results for the samples given.

        Arguments:
        `samples` - non-empty list of Sample objects
        """
        assert len(samples) > 0, 'Empty list'
        self.correlated_objs = []
        total_samples = len(samples)
        for i, sample in enumerate(samples):
            print(f'Processing sample {i+1}/{total_samples} ...', end="\r")
            self.correlated_objs.append(Correlated(sample.profile))
        print('\nDone!')

    def _sig_set_info(self, num):
        """Return the `sig_sets` entry whose 'num' equals `num`, or None."""
        for info in self.sig_sets:
            if info['num'] == num:
                return info
        return None

    def _summed_scores(self, eti_i, sig_num):
        """Return, per sample, the summed score of etiology `eti_i` in set `sig_num`.

        The list is empty when the first sample has no entry for that set under
        this etiology (the first sample is used as the reference structure).
        """
        ref = self.correlated_objs[0].results
        if sig_num not in ref[eti_i]['sig_set']:
            return []
        return [
            sum(sig['score'] for sig in corr.results[eti_i]['sig_set'][sig_num])
            for corr in self.correlated_objs
        ]

    @staticmethod
    def _fit_kde(values, bandwidth):
        """Fit a Gaussian KDE to `values`, or return None when no density can be estimated."""
        # Same precondition the plot relies on: at least two samples and some contribution.
        if len(values) < 2 or sum(values) <= 0:
            return None
        try:
            return gaussian_kde(values, bandwidth)
        except np.linalg.LinAlgError:
            # All values identical: zero variance, so there is no density to draw.
            return None

    def avg_decomp(self):
        """Return the mean score of every signature over all samples, per signature set.

        Returns a list with one dict per signature set:
        ``{'set': <set name>, 'avg_decomp': [{'sig': <name>, 'score': <mean>}, ...]}``.
        Samples that do not report a signature count as 0 for it.
        """
        data = []
        n_samples = len(self.correlated_objs)
        for sig_i, sig_info in enumerate(self.sig_sets):
            first = self.correlated_objs[0].decompositions[sig_i]
            assert sig_info['num'] == first.sig_set, "Signature set missmatch"
            averages = []
            for sig in first.W_and_labels[1]:
                total = 0
                for corr_obj in self.correlated_objs:
                    matches = [
                        o for o in corr_obj.decompositions[sig_i].decomposition
                        if o['name'] == sig
                    ]
                    assert len(matches) < 2
                    if matches:
                        total += matches[0]['score']
                averages.append({'sig': sig, 'score': total / n_samples})
            data.append({'set': sig_info['name'], 'avg_decomp': averages})
        return data

    def distribution(self, bandwidth=0.3, resolution=1000, interval=(0, 1), plot=True, return_results=False):
        """
        Plot distribution of proposed contributions by etiology and signature set

        Arguments:
        `bandwidth` - bandwidth parameter for the Gaussian kernel density estimate
        `resolution` - number of points to calculate for the probability density plot
        `interval` - (start, end) of the contribution range the density is evaluated on
        `plot` - if True, the data is plotted
        `return_results` - if True, the data is returned
        """
        # Process data
        results = []
        ref = self.correlated_objs[0].results  # reference structure
        for eti_i, ref_entry in enumerate(ref):
            entry = {'etiology': ref_entry['etiology'], 'sig_set': {}}
            for sig_set in self.sig_sets:
                contrib = self._summed_scores(eti_i, sig_set['num'])
                # The KDE is fitted once on the complete list, not once per sample.
                entry['sig_set'][sig_set['num']] = {
                    'contrib': contrib,
                    'kde': self._fit_kde(contrib, bandwidth)
                }
            results.append(entry)
        if plot:
            # Plot data on a square grid; subplot() needs plain ints.
            grid_n = int(np.ceil(np.sqrt(len(ref))))
            xs = np.linspace(interval[0], interval[1], resolution)
            for i, res in enumerate(results):
                plt.subplot(grid_n, grid_n, i+1)
                plt.title(res['etiology'])
                plt.ylabel('Probability density')
                plt.xlabel('Contribution')
                for sig_set in self.sig_sets:
                    kde = res['sig_set'][sig_set['num']]['kde']
                    if kde is not None:
                        plt.plot(
                            xs,
                            kde.evaluate(xs),
                            marker='',
                            color=sig_set['colour'],
                            label=sig_set['name']
                        )
                plt.legend()
            plt.subplots_adjust(
                left=.05,
                bottom=.05,
                right=.95,
                top=.95,
                wspace=.2,
                hspace=.55)
            plt.show()
        if return_results:
            return results

    def compare(self, set_a, set_b, cutoff=0.05, plot=True, return_results=False):
        """
        Scatter the summed contributions of two signature sets, one point per sample and etiology

        Arguments:
        `set_a`, `set_b` - signature set numbers (x and y axis)
        `cutoff` - a sample is kept for an etiology only if either set scores above this value
        `plot` - if True, the data is plotted
        `return_results` - if True, the data is returned
        """
        info_a = self._sig_set_info(set_a)
        info_b = self._sig_set_info(set_b)
        assert info_a is not None, "Invalid value for set_a"
        assert info_b is not None, "Invalid value for set_b"
        # Process data
        results = []
        ref = self.correlated_objs[0].results  # reference structure
        assert len(self.eti_colours) >= len(ref), "Not enough colours"
        for eti_i, ref_entry in enumerate(ref):
            if set_a not in ref_entry['sig_set'] or set_b not in ref_entry['sig_set']:
                continue
            entry = {
                'etiology': ref_entry['etiology'],
                'set_a': [],
                'set_b': [],
                'colour': self.eti_colours[eti_i]
            }
            for corr in self.correlated_objs:
                sets = corr.results[eti_i]['sig_set']
                score_sum_a = sum(sig['score'] for sig in sets[set_a])
                score_sum_b = sum(sig['score'] for sig in sets[set_b])
                if score_sum_a > cutoff or score_sum_b > cutoff:
                    entry['set_a'].append(score_sum_a)
                    entry['set_b'].append(score_sum_b)
            results.append(entry)
        if plot:
            # Plot results
            plt.plot([0, 1], [0, 1], '-k')  # y = x reference line
            for res in results:
                plt.scatter(
                    # np.round replaces np.round_, which NumPy 2.0 removed
                    np.round(res['set_a'], 3),
                    np.round(res['set_b'], 3),
                    marker='o',
                    c=res['colour'],
                    label=res['etiology']
                )
            plt.legend()
            plt.xlabel(f"{info_a['name']}")
            plt.ylabel(f"{info_b['name']}")
            plt.title('Predicted contributions by signature set')
            plt.show()
        if return_results:
            return results

    def histogram(self, bins=20, threshold=0.05, plot=True, return_results=False, savefig=None):
        """
        Plot histogram of proposed contributions by etiology and signature set

        Arguments:
        `bins` - number of histogram bins
        `threshold` - etiologies whose best mean contribution is not above this value are
                      not plotted; within a panel only contributions above it are counted
        `plot` - if True, the data is plotted
        `return_results` - if True, the unfiltered data for every etiology is returned
        `savefig` - if given, the figure is saved to this path instead of being shown
        """
        # Process data
        results = []
        ref = self.correlated_objs[0].results  # reference structure
        for eti_i, ref_entry in enumerate(ref):
            entry = {'etiology': ref_entry['etiology']}
            for sig_set in self.sig_sets:
                entry[sig_set['num']] = self._summed_scores(eti_i, sig_set['num'])
            results.append(entry)
        if plot:
            def best_mean(res):
                # Skip sets without data: np.mean([]) is nan and would poison max().
                means = [
                    np.mean(res[sig_set['num']])
                    for sig_set in self.sig_sets
                    if res[sig_set['num']]
                ]
                return max(means, default=float('-inf'))

            # Kept separate from `results` so the returned data does not depend on `plot`.
            shown = [res for res in results if best_mean(res) > threshold]
            grid_n = int(np.ceil(np.sqrt(len(shown))))  # subplot() needs plain ints
            plt.clf()
            for i, res in enumerate(shown):
                plt.subplot(grid_n, grid_n, i+1)
                plt.title(res['etiology'])
                plt.ylabel('Counts')
                plt.xlabel('Contribution')
                xs, cs, ls = [], [], []
                for sig_set in self.sig_sets:
                    values = res[sig_set['num']]
                    if values:
                        xs.append([x for x in values if x > threshold])
                        cs.append(sig_set['colour'])
                        ls.append(sig_set['name'])
                plt.hist(
                    xs, bins,
                    color=cs,
                    label=ls
                )
                plt.legend()
            plt.minorticks_on()
            plt.subplots_adjust(
                left=.05,
                bottom=.05,
                right=.95,
                top=.95,
                wspace=.2,
                hspace=.55)
            # rcParams only affects figures created later, so size the current one too.
            plt.gcf().set_size_inches(16, 9)
            plt.rcParams["figure.figsize"] = (16, 9)
            if savefig is None:
                plt.show()
            else:
                plt.savefig(savefig)
        if return_results:
            return results

    def heatmap(self, sig_set):
        """
        Plot a signature-by-sample heatmap of scores for one signature set

        Arguments:
        `sig_set` - signature set number; must be one of the 'num' values in `sig_sets`
        """
        sig_set_obj = self._sig_set_info(sig_set)
        assert sig_set_obj is not None, "Signature set provided not found"
        set_index = sig_set_obj['index']
        sig_names = self.correlated_objs[0].decompositions[set_index].W_and_labels[1]
        matrix = []
        for corr_obj in self.correlated_objs:
            decomposition = corr_obj.decompositions[set_index].decomposition
            sample_v = []
            for sig_name in sig_names:
                hit = next((o for o in decomposition if o['name'] == sig_name), None)
                # Signatures a sample does not report are shown as 0.
                sample_v.append(hit['score'] if hit is not None else 0)
            matrix.append(sample_v)
        # Transposed so signatures run along y and samples along x.
        plt.imshow(np.transpose(matrix), cmap="Reds")
        plt.yticks(range(len(sig_names)), sig_names)
        plt.ylim(len(sig_names), -1)
        plt.ylabel('Signature')
        plt.xlabel('Sample number')
        plt.title(f"Signature set: {sig_set_obj['name']}")
        plt.show()

## Available samples

First we list and index the available samples.

In [6]:
# Find sample files
# glob() returns matches in arbitrary, filesystem-dependent order; sort them so the
# indices shown below (and the 'id' values used in `rep`) are stable across runs.
sample_files = sorted(glob('samples/benchmark/*.maf'))
list(enumerate(sample_files))

[(0, 'samples/benchmark/tcga_brca.maf'),
 (1, 'samples/benchmark/tcga_dlbc.maf'),
 (2, 'samples/benchmark/tcga_luad.maf'),
 (3, 'samples/benchmark/tcga_skcm.maf')]

## Report
You can generate an automatic report.

Now we choose where we want to save the report and resulting plots. Choose the `root_dir` of your liking. A folder called `assets_[date and time of report]` will be created at the `root_dir` and the resulting plots will be exported to this folder.

Modify the `rep` list to choose the samples you would like in your report.
```python
{'name': 'Sample category / type',
'samp': [
    {'name': 'Sample A', # sample name
    'code': 'samp_a', # cBioPortal study id
    'id':0}, # corresponding index in `sample_files`
    ...
]}
```

In [None]:
# Automatic report
import datetime as dt
import os

# Destination of the report. Set the MUTAGENE_REPORT_DIR environment variable to
# override the default location without editing this cell.
root_dir = os.environ.get(
    "MUTAGENE_REPORT_DIR",
    "/mnt/c/Users/fernando.hernandez/Documents/1_Projects/10_MutaGene/documentation/samples",
)

# Samples to include, grouped by cancer type; 'id' indexes into `sample_files`.
rep = [
    {'name': 'Breast cancer',
    'samp': [
        {'name': 'Breast Invasive Carcinoma (British Columbia, Nature 2012)',
        'code': 'brca_bccrc',
        'id':0},
        {'name': 'Breast Invasive Carcinoma (Broad, Nature 2012)',
        'code': 'brca_broad',
        'id':1},
        {'name': 'Breast Cancer (MSKCC, 2019)',
        'code': 'brca_mskcc_2019',
        'id':3},
        {'name': 'Breast Invasive Carcinoma (Sanger, Nature 2012)',
        'code': 'brca_sanger',
        'id':4}
    ]},
    {'name': 'Lung cancer',
    'samp': [
        {'name': 'Lung Adenocarcinoma (Broad, Cell 2012)',
        'code': 'luad_broad',
        'id':6},
        {'name': 'Lung Adenocarcinoma (MSKCC, Science 2015)',
        'code': 'luad_mskcc_2015',
        'id':7},
        {'name': 'Lung Adenocarcinoma (TCGA, PanCancer Atlas)',
        'code': 'luad_tcga_pan_can_atlas_2018',
        'id':9},
        {'name': 'Lung Adenocarcinoma (TCGA, Nature 2014)',
        'code': 'luad_tcga_pub',
        'id':10}
    ]},
    {'name': 'Skin cancer',
    'samp': [
        {'name': 'Skin Cutaneous Melanoma (Broad, Cell 2012)',
        'code': 'skcm_broad',
        'id':12},
        {'name': 'Skin Cutaneous Melanoma(Broad, Cancer Discov 2014)',
        'code': 'skcm_broad_brafresist_2012',
        'id':13},
        {'name': 'Skin Cutaneous Melanoma (Yale, Nat Genet 2012)',
        'code': 'skcm_yale',
        'id':14}
    ]},
]

# Check every 'id' before anything is written, so a bad index cannot leave a
# half-written report and assets folder behind.
n_files = len(sample_files)
bad_ids = [
    (samp['code'], samp['id'])
    for can in rep
    for samp in can['samp']
    if not -n_files <= samp['id'] < n_files
]
if bad_ids:
    raise IndexError(
        f"`rep` refers to sample ids outside `sample_files` ({n_files} file(s) found): {bad_ids}"
    )

curr_dt = dt.datetime.now().strftime("%y-%m-%d_%H-%M-%S")
assets_dir = os.path.join(root_dir,f"assets_{curr_dt}")
# makedirs also creates root_dir when it does not exist yet
os.makedirs(assets_dir, exist_ok=True)
with open(os.path.join(root_dir,f"report_{curr_dt}.md"), "w", encoding="utf-8") as f:
    f.write("""
# Sample decompositions

Setup:
- Bootstrap: no
- Dummy signatures: yes
- Genome: hg19

""")
    print("Starting new automatic report","Progress:", sep="\n")
    for can_i, can in enumerate(rep):
        print(f"{can_i+1}/{len(rep)} {can['name']}...")
        f.write(f"## {can_i+1}. {can['name']}\n")
        for samp_i, samp in enumerate(can['samp']):
            print(f"{samp_i+1}/{len(can['samp'])} {samp['name']}...", end="\r")
            f.write(f"""
### {can_i+1}.{samp_i+1} [{samp['name']}](https://www.cbioportal.org/study?id={samp['code']})
[Download sample](http://download.cbioportal.org/{samp['code']}.tar.gz)

Histogram:
![Histogram](assets_{curr_dt}/{samp['code']}_hist.png)
""")
            # The histogram is saved under the path the markdown entry above links to.
            this_samp = Sample.multisample(sample_files[samp['id']], 'samples/hg19.2bit')
            samp_bench = BenchmarkCorrelated(this_samp[:])
            samp_bench.histogram(threshold=0.05, savefig=os.path.join(assets_dir,f"{samp['code']}_hist.png"))

## Quick tests
From here on, you can pick any file from `sample_files`, create a `BenchmarkCorrelated` object from it and call its methods manually.

In [32]:
# Get samples from multisample file
# NOTE(review): this cell reads 'samples/hg38.2bit' while the automatic report cell
# reads 'samples/hg19.2bit' — confirm which genome build matches this MAF file.
samples = Sample.multisample(sample_files[2], 'samples/hg38.2bit')

                              

In [33]:
# Build one Correlated object per sample (progress is printed while it runs);
# a shallow copy of the list is passed in.
bench = BenchmarkCorrelated(samples[:])

Processing sample 563/563 ...
Done!


In [30]:
# Signature-by-sample heatmap of scores for signature set 30 ('COSMIC V2')
bench.heatmap(30)

In [34]:
# Mean score of every signature over all samples, listed per signature set
bench.avg_decomp()

[{'set': 'MutaGene 5',
  'avg_decomp': [{'sig': '1', 'score': 0.4076982705196225},
   {'sig': '2', 'score': 0.14090940686612413},
   {'sig': '3', 'score': 0.06635092332898074},
   {'sig': '4', 'score': 0.09593326567295726},
   {'sig': '5', 'score': 0.07594747942508472}]},
 {'set': 'MutaGene 10',
  'avg_decomp': [{'sig': '1', 'score': 0.053383638035822925},
   {'sig': '2', 'score': 0.05290563051126785},
   {'sig': '3', 'score': 0.10368295239109553},
   {'sig': '4', 'score': 0.19668062307538434},
   {'sig': '5', 'score': 0.10868778617563164},
   {'sig': '6', 'score': 0.04195383450585331},
   {'sig': '7', 'score': 0.03209422181013312},
   {'sig': '8', 'score': 0.060399093807835515},
   {'sig': '9', 'score': 0.07306095751446862},
   {'sig': '10', 'score': 0.058268809831031616}]},
 {'set': 'COSMIC V2',
  'avg_decomp': [{'sig': '1', 'score': 0.061407276833564665},
   {'sig': '2', 'score': 0.052239193163211024},
   {'sig': '3', 'score': 0.011496249627491711},
   {'sig': '4', 'score': 0.285841

In [10]:
# Histograms of summed contributions per etiology; with threshold=0.0 only
# values strictly above zero are counted
bench.histogram(threshold=0.0)

In [11]:
# Kernel density estimate of contributions per etiology and signature set,
# using the defaults (bandwidth 0.3, 1000 points on the interval [0, 1])
bench.distribution()

In [None]:
# Scatter 'MutaGene 10' (x) against 'COSMIC V2' (y) summed contributions per etiology,
# keeping samples where either set scores above 0.1
bench.compare(10,30,cutoff=.1)