In [81]:
from mutagene.profiles import Sample
from mutagene.signatures import Identify
from mutagene.signatures.constants import sig_by_etiology
from pprint import pprint
import matplotlib.pyplot as plt
from glob import glob
%matplotlib tk


# Path handed to Sample(...) as its second argument by Correlated.__init__.
genome: str = 'samples/hg38.2bit'

class Correlated:
    """Decompose one sample with several signature sets and group hits by etiology.

    For every signature set in ``SIG_SETS`` the sample profile is run through
    ``Identify``.  The decompositions are then re-indexed per entry of
    ``sig_by_etiology`` so that the same etiology can be compared across
    signature sets, either as a scatter plot (two sets) or a grouped
    horizontal bar chart (all sets).

    Attributes:
        infile: Path of the sample file this instance was built from.
        sample: The ``Sample`` loaded from ``infile``.
        decompositions: One ``Identify`` result per entry of ``SIG_SETS``,
            in the same order.
        results: One dict per etiology, shaped
            ``{'etiology': ..., 'sig_set': {set_id: [{'sig', 'score',
            'mutations'}, ...]}}``.  A set id is present only when the
            etiology lists at least one signature for it.
    """

    # Signature-set identifiers passed to Identify; one decomposition each.
    SIG_SETS = (5, 10, 30, 49)
    # Bar colour per signature set, in the same order as SIG_SETS.
    BAR_COLORS = ('#247BA0', '#70C1B3', '#B2DBBF', '#E9C46A')

    def __init__(self, infile, genome_path=None):
        """Load ``infile`` and decompose it with every signature set.

        Args:
            infile: Path of the sample file to analyse.
            genome_path: Second argument for ``Sample``.  Defaults to the
                module-level ``genome`` so existing ``Correlated(infile)``
                calls behave as before.
        """
        # Remembered so plot titles name the file actually analysed.
        self.infile = infile
        if genome_path is None:
            genome_path = genome
        self.sample = Sample(infile, genome_path)

        self.decompositions = [
            Identify(self.sample.profile, sig_set, bootstrap=True, dummy_sigs=True)
            for sig_set in self.SIG_SETS
        ]

        # Index every decomposition by signature name once, instead of
        # scanning it again for each signature.  setdefault keeps the first
        # entry when a name occurs more than once.
        lookups = []
        for decomp in self.decompositions:
            by_name = {}
            for entry in decomp.decomposition:
                by_name.setdefault(entry['name'], entry)
            lookups.append(by_name)

        self.results = []
        for corr in sig_by_etiology:
            per_set = {}
            for sig_set, by_name in zip(self.SIG_SETS, lookups):
                entries = [
                    self._entry_for(sig, by_name)
                    for sig in corr['sig'][sig_set]
                ]
                # A set with no signatures for this etiology is left out so
                # the plotting code can test membership with `in`.
                if entries:
                    per_set[sig_set] = entries
            self.results.append({
                'etiology': corr['etiology'],
                'sig_set': per_set,
            })

    @staticmethod
    def _entry_for(sig, by_name):
        """Return the result record for ``sig``; zeros when it was not found."""
        match = by_name.get(sig)
        if match is None:
            return {'sig': sig, 'score': 0, 'mutations': 0}
        return {
            'sig': sig,
            'score': match['score'],
            'mutations': match['mutations'],
        }

    @staticmethod
    def _names_and_total(entries):
        """Return (space-joined signature names, summed 'mutations') for entries."""
        names = ' '.join(entry['sig'] for entry in entries)
        total = sum(entry['mutations'] for entry in entries)
        return names, total

    def scatter_plot(self, set_a, set_b):
        """Plot per-etiology mutation totals of ``set_a`` (x) against ``set_b`` (y).

        Only etiologies that have signatures in both sets are drawn.  A green
        diagonal from the origin to the largest total marks x == y.

        Args:
            set_a: Signature-set id used for the x axis.
            set_b: Signature-set id used for the y axis.

        Raises:
            ValueError: If no etiology has signatures in both sets.
        """
        labels, xticks, x, yticks, y = [], [], [], [], []
        for res in self.results:
            per_set = res['sig_set']
            if set_a not in per_set or set_b not in per_set:
                continue
            names_a, total_a = self._names_and_total(per_set[set_a])
            names_b, total_b = self._names_and_total(per_set[set_b])
            labels.append(res['etiology'])
            xticks.append(names_a)
            x.append(total_a)
            yticks.append(names_b)
            y.append(total_b)

        if not labels:
            # Previously surfaced as "max() arg is an empty sequence".
            raise ValueError(
                'no etiology has signatures in both signature sets '
                '{!r} and {!r}'.format(set_a, set_b)
            )

        max_val = max(max(x), max(y))
        plt.plot([0, max_val], [0, max_val], '-g')
        plt.scatter(x, y)
        plt.xticks(x, xticks, rotation=90)
        plt.yticks(y, yticks)
        plt.ylabel('Signature b')
        plt.xlabel('Signature a')
        plt.title('Scatter Plot: {}'.format(self.infile))
        for label, x_val, y_val in zip(labels, x, y):
            # The text is nudged 3 units above its point.
            plt.annotate(
                f'{label} ({int(round(x_val))}, {int(round(y_val))})',
                (x_val, y_val + 3),
                ha='center',
            )
        plt.show()

    def _get_plot_data(self, sig_set):
        """Return the summed 'mutations' per etiology for ``sig_set``.

        The list has one value per entry of ``self.results``, in order; an
        etiology without that signature set contributes 0.
        """
        return [
            sum(entry['mutations'] for entry in res['sig_set'].get(sig_set, ()))
            for res in self.results
        ]

    def bar_chart(self):
        """Draw a grouped horizontal bar chart of mutation totals for all sets.

        Each etiology occupies ``len(SIG_SETS) + 1`` rows: one bar per
        signature set followed by an empty spacer row.
        """
        stride = len(self.SIG_SETS) + 1
        n_rows = len(self.results) * stride

        for offset, (sig_set, color) in enumerate(zip(self.SIG_SETS, self.BAR_COLORS)):
            plt.barh(
                range(offset, n_rows, stride),
                self._get_plot_data(sig_set),
                color=color,
                label=sig_set,
            )

        # Etiology names sit left of the axis, one row above the group's
        # first bar.
        for group, res in enumerate(self.results):
            plt.annotate(res['etiology'], (-1, group * stride + 1), ha='right')

        ticks = []
        for res in self.results:
            for sig_set in self.SIG_SETS:
                # Joining an empty sequence yields '' for a missing set.
                ticks.append(' '.join(
                    entry['sig'] for entry in res['sig_set'].get(sig_set, ())
                ))
            ticks.append('')  # spacer row between etiology groups

        plt.yticks(range(n_rows), ticks)
        plt.legend()
        plt.ylabel('Signature')
        plt.xlabel('Number of Mutations')
        plt.title('Bar Chart: {}'.format(self.infile))
        plt.show()
                        
# glob() yields paths in arbitrary, platform-dependent order; sorting makes
# sample_files[0] (the sample that gets analysed) reproducible across runs.
sample_files = sorted(glob('samples/*.vcf'))

In [82]:
# Analyse the first matched sample file; raises IndexError if the glob found none.
sample1 = Correlated(sample_files[0])

                                                     

In [83]:
# Scatter plot with signature set 30 on the x axis and set 49 on the y axis.
sample1.scatter_plot(30, 49)

In [84]:
# Grouped horizontal bar chart covering every signature set.
sample1.bar_chart()