In [66]:
from mutagene.profiles import Sample
from mutagene.signatures import Identify
from mutagene.signatures.constants import sig_by_etiology
from pprint import pprint
import matplotlib.pyplot as plt
from adjustText import adjust_text
import numpy as np
from glob import glob
import matplotlib.patches as mpatches
%matplotlib tk

# Reference genome handed to every Sample built by Correlated
# (presumably a UCSC .2bit file, judging by the extension — confirm).
genome = "samples/hg38.2bit"

class Correlated:
    """Compare one sample's signature decompositions across signature sets.

    The sample is decomposed (via ``Identify``) against every set in
    ``SIG_SETS``. For each etiology in ``sig_by_etiology``, the score and
    mutation count of each associated signature are collected per set. The
    collected ``results`` can then be plotted as a score scatter plot or as
    a horizontal bar chart.
    """

    # Signature sets the sample is decomposed against (passed to Identify).
    SIG_SETS = (5, 10, 30, 49)

    # Scatter-plot colour per etiology; cycled if there are more etiologies
    # than colours, so the plot never fails on a colour/point count mismatch.
    ETIOLOGY_COLORS = (
        '#2300A8', '#00A658', '#FF333C', '#0F0101', '#55043E', '#470455',
        '#8897D5', '#88CBD5', '#88D5B9', '#99D588', '#C4D533', '#D5B833',
        '#D58E33',
    )

    # Bar colour per signature set in bar_chart.
    BAR_COLORS = {5: '#247BA0', 10: '#70C1B3', 30: '#B2DBBF', 49: '#E9C46A'}

    def __init__(self, infile):
        """Load *infile* and decompose it against every set in ``SIG_SETS``.

        Parameters
        ----------
        infile : str
            Path of the sample file, read with the module-level ``genome``.
            It is also used in the plot titles.
        """
        self.infile = infile
        self.sample = Sample(infile, genome)
        self.decompositions = [
            Identify(self.sample.profile, sig_set, bootstrap=False, dummy_sigs=True)
            for sig_set in self.SIG_SETS
        ]
        self.results = self._match_etiologies(
            sig_by_etiology, self.decompositions, self.SIG_SETS)

    @staticmethod
    def _match_etiologies(etiologies, decompositions, sig_sets):
        """Collect the score and mutations of each etiology's signatures per set.

        Parameters
        ----------
        etiologies : iterable of dict
            Entries with ``'etiology'`` and ``'sig'`` keys. ``'sig'`` maps a
            signature-set id to the signature names linked to that etiology.
        decompositions : sequence
            Objects whose ``.decomposition`` is a list of dicts with
            ``'name'``, ``'score'`` and ``'mutations'``, in the same order as
            *sig_sets*.
        sig_sets : sequence of int
            Signature-set ids matching *decompositions* position by position.

        Returns
        -------
        list of dict
            One ``{'etiology': ..., 'sig_set': {set_id: [{'sig', 'score',
            'mutations'}, ...]}}`` per etiology. A set id is present only if
            the etiology lists at least one signature for it. Signatures
            missing from a decomposition get a score and mutation count of 0.
        """
        # Index each decomposition by signature name once, instead of
        # scanning it for every signature. setdefault keeps the FIRST entry
        # for a duplicated name, like the old ``filter(...)[0]`` lookup.
        lookups = []
        for decomp in decompositions:
            by_name = {}
            for entry in decomp.decomposition:
                by_name.setdefault(entry['name'], entry)
            lookups.append(by_name)

        results = []
        for corr in etiologies:
            per_set = {}
            for sig_set, by_name in zip(sig_sets, lookups):
                # .get: an etiology without signatures for a set is skipped
                # rather than raising KeyError.
                for sig in corr['sig'].get(sig_set, ()):
                    entry = by_name.get(sig)
                    per_set.setdefault(sig_set, []).append({
                        'sig': sig,
                        'score': entry['score'] if entry is not None else 0,
                        'mutations': entry['mutations'] if entry is not None else 0,
                    })
            results.append({'etiology': corr['etiology'], 'sig_set': per_set})
        return results

    def score_scatter_plot(self, set_a, set_b):
        """Scatter each etiology's summed score: *set_a* on x, *set_b* on y.

        Only etiologies that have signatures in both sets are plotted. Each
        point gets its own colour and a legend entry named after its etiology.
        """
        labels, x, y = [], [], []
        for res in self.results:
            sets = res['sig_set']
            if set_a in sets and set_b in sets:
                labels.append(res['etiology'])
                x.append(sum(sig['score'] for sig in sets[set_a]))
                y.append(sum(sig['score'] for sig in sets[set_b]))

        palette = self.ETIOLOGY_COLORS
        colors = [palette[i % len(palette)] for i in range(len(x))]

        # Fresh figure so a previous plot in the session is not drawn over.
        plt.figure()
        plt.scatter(x, y, color=colors)
        # Tick marks sit exactly on the data values, labelled to 3 decimals.
        plt.xticks(x, np.round(x, 3))
        plt.yticks(y, np.round(y, 3))
        plt.ylabel('Score : {}'.format(set_b))
        plt.xlabel('Score : {}'.format(set_a))
        plt.title('Scatter Plot: {}'.format(self.infile))
        # One legend patch per plotted etiology (no fixed count, no duplicates).
        handles = [mpatches.Patch(color=color, label=label)
                   for color, label in zip(colors, labels)]
        plt.legend(handles=handles)
        plt.show()

    def _get_plot_data(self, sig_set):
        """Return, per etiology, the total mutations for *sig_set* (0 if absent)."""
        return [
            sum(sig['mutations'] for sig in res['sig_set'][sig_set])
            if sig_set in res['sig_set'] else 0
            for res in self.results
        ]

    def bar_chart(self):
        """Draw a horizontal bar chart of mutations per etiology and signature set.

        Each etiology takes a group of ``len(SIG_SETS) + 1`` y slots. The
        group holds one bar per signature set, then an empty spacer slot.
        Tick labels show the signature names, and the etiology name is
        annotated to the left of its group.
        """
        group = len(self.SIG_SETS) + 1
        n = len(self.results)

        # Fresh figure so a previous plot in the session is not drawn over.
        plt.figure()
        for offset, sig_set in enumerate(self.SIG_SETS):
            plt.barh(
                range(offset, n * group, group), self._get_plot_data(sig_set),
                color=self.BAR_COLORS.get(sig_set),
                label=sig_set,
            )

        for j, res in enumerate(self.results):
            plt.annotate(res['etiology'], (-1, j * group + 1), ha='right')

        ticks = []
        for res in self.results:
            for sig_set in self.SIG_SETS:
                sigs = res['sig_set'].get(sig_set, [])
                ticks.append(' '.join(sig['sig'] for sig in sigs))
            ticks.append('')  # spacer slot between etiology groups

        plt.yticks(range(n * group), ticks)
        plt.legend()
        plt.ylabel('Signature')
        plt.xlabel('Number of Mutations')
        plt.title('Bar Chart: {}'.format(self.infile))
        plt.show()
            
# All VCF samples, sorted: glob() returns files in arbitrary, filesystem-
# dependent order, so without sorting sample_files[0] could differ between runs.
sample_files = sorted(glob('samples/*.vcf'))

In [67]:
# Analyse the first sample; fail with a clear message instead of a bare
# IndexError when no .vcf files were found.
if not sample_files:
    raise FileNotFoundError("no .vcf files found matching 'samples/*.vcf'")
sample1 = Correlated(sample_files[0])

                                        

In [68]:
# Scores from the 30-signature set (x) against the 49-signature set (y).
sample1.score_scatter_plot(set_a=30, set_b=49)

In [28]:
# Horizontal bar chart of mutations per etiology for each signature set (5, 10, 30, 49).
sample1.bar_chart()