In [127]:
from mutagene.profiles import Sample
from mutagene.signatures import Identify
from mutagene.signatures.constants import sig_by_etiology
from scipy.stats import gaussian_kde
import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
from glob import glob
%matplotlib tk

In [201]:
class Correlated:
    """Group a profile's signature decompositions by etiology.

    The profile is decomposed once per signature set in ``SIG_SETS``.  For
    every entry of ``sig_by_etiology`` the score and mutation count of each
    listed signature are then looked up in the matching decomposition
    (0 when the signature is absent from it).

    Attributes set by the constructor:
        profile: the profile object passed in.
        decompositions: one ``Identify`` result per entry of ``SIG_SETS``,
            in the same order.
        results: one dict per etiology, shaped as
            ``{'etiology': ..., 'sig_set': {set_id: [{'sig': ...,
            'score': ..., 'mutations': ...}, ...]}}``.  A set id is only
            present when the etiology lists at least one signature for it.
    """

    # Signature-set identifiers handed to Identify, in plotting order.
    SIG_SETS = (5, 10, 30, 49)

    def __init__(self, profile):
        """Decompose ``profile`` against every signature set and index by etiology.

        Args:
            profile: value forwarded unchanged to ``Identify``.
        """
        self.profile = profile

        self.decompositions = [
            Identify(self.profile, sig_set, bootstrap=False, dummy_sigs=True)
            for sig_set in self.SIG_SETS
        ]

        self.results = []
        for corr in sig_by_etiology:
            by_set = {}
            for sig_set, decomp in zip(self.SIG_SETS, self.decompositions):
                for sig in corr['sig'][sig_set]:
                    # First decomposition entry carrying this signature name, if any.
                    hit = next(
                        (entry for entry in decomp.decomposition if entry['name'] == sig),
                        None,
                    )
                    by_set.setdefault(sig_set, []).append({
                        'sig': sig,
                        'score': hit['score'] if hit is not None else 0,
                        'mutations': hit['mutations'] if hit is not None else 0,
                    })
            self.results.append({'etiology': corr['etiology'], 'sig_set': by_set})

    @staticmethod
    def _sig_names(res, sig_set):
        """Return the space-joined signature names ``res`` lists for ``sig_set``."""
        return ' '.join(sig['sig'] for sig in res['sig_set'][sig_set])

    @staticmethod
    def _total_mutations(res, sig_set):
        """Return the summed 'mutations' of the signatures ``res`` lists for ``sig_set``."""
        return sum(sig['mutations'] for sig in res['sig_set'][sig_set])

    @staticmethod
    def _normalise_pair(a, b):
        """Scale ``a`` and ``b`` so they sum to 1; an all-zero pair becomes (0.5, 0.5)."""
        total = a + b
        if total == 0:
            return 0.5, 0.5
        return a / total, b / total

    def _paired_totals(self, set_a, set_b):
        """Collect plot data for the etiologies that have both signature sets.

        Returns:
            Five parallel lists ``(labels, names_a, names_b, totals_a,
            totals_b)``: etiology names, joined signature names per set, and
            summed mutation counts per set.

        Raises:
            ValueError: if no etiology has signatures in both sets.
        """
        labels, names_a, names_b, totals_a, totals_b = [], [], [], [], []
        for res in self.results:
            if set_a in res['sig_set'] and set_b in res['sig_set']:
                labels.append(res['etiology'])
                names_a.append(self._sig_names(res, set_a))
                names_b.append(self._sig_names(res, set_b))
                totals_a.append(self._total_mutations(res, set_a))
                totals_b.append(self._total_mutations(res, set_b))
        if not labels:
            raise ValueError(
                'No etiology has signatures in both sets {} and {}'.format(set_a, set_b)
            )
        return labels, names_a, names_b, totals_a, totals_b

    def scatter_plot(self, set_a, set_b, title=None):
        """Scatter per-etiology mutation totals of ``set_a`` (x) against ``set_b`` (y).

        Args:
            set_a: signature-set id plotted on the x axis.
            set_b: signature-set id plotted on the y axis.
            title: figure title.  When omitted, the title is built from the
                first entry of the notebook-level ``sample_files`` list if
                that global exists, otherwise it is plain 'Scatter Plot'.

        Raises:
            ValueError: if no etiology has signatures in both sets.
        """
        labels, xticks, yticks, x, y = self._paired_totals(set_a, set_b)

        # Diagonal reference line: points on it have equal totals in both sets.
        max_val = max(x + y)
        plt.plot([0, max_val], [0, max_val], '-g')
        plt.scatter(x, y)
        plt.xticks(x, xticks, rotation=90)
        plt.yticks(y, yticks)
        plt.ylabel('Signature b')
        plt.xlabel('Signature a')

        if title is None:
            # Keep the notebook's historic title without crashing when the
            # cell defining `sample_files` has not been run yet.
            files = globals().get('sample_files')
            title = 'Scatter Plot: {}'.format(files[0]) if files else 'Scatter Plot'
        plt.title(title)

        texts = [
            plt.text(x[i], y[i], f'{label} ({int(round(x[i]))}, {int(round(y[i]))})')
            for i, label in enumerate(labels)
        ]
        # adjustText only de-overlaps the labels; the plot is still valid
        # without it, so treat the package as optional.
        try:
            from adjustText import adjust_text
        except ImportError:
            adjust_text = None
        if adjust_text is not None:
            adjust_text(texts)
        plt.show()

    def _get_plot_data(self, sig_set):
        """Return one mutation total per etiology for ``sig_set`` (0 where the set is absent)."""
        return [
            self._total_mutations(res, sig_set) if sig_set in res['sig_set'] else 0
            for res in self.results
        ]

    def bar_chart(self):
        """Draw horizontal bars of mutation totals for every etiology and signature set."""
        palette = ('#247BA0', '#70C1B3', '#B2DBBF', '#E9C46A')
        # Each etiology occupies 5 rows: one bar per signature set plus a blank spacer.
        n_rows = len(self.results) * 5

        for offset, (sig_set, color) in enumerate(zip(self.SIG_SETS, palette)):
            ind = range(offset, n_rows, 5)
            y = self._get_plot_data(sig_set)
            plt.barh(
                ind, y,
                color=color,
                label=sig_set
            )
            # Print the total next to each non-empty bar.
            for row, total in zip(ind, y):
                plt.annotate(total if total > 0 else '', (total + 1, row), va='center')

        # Signature names to the left of the axis, one per row (blank for
        # missing sets and for the spacer row).
        ticks = []
        for res in self.results:
            for sig_set in self.SIG_SETS:
                if sig_set in res['sig_set']:
                    ticks.append(self._sig_names(res, sig_set))
                else:
                    ticks.append('')
            ticks.append('')
        for row, tick in enumerate(ticks):
            plt.annotate(tick, (-1, row), ha='right', va='center')

        labels = [res['etiology'] for res in self.results]
        plt.yticks(range(3, n_rows, 5), labels)
        plt.legend()
        plt.ylabel('Etiology')
        plt.xlabel('Number of Mutations')
        plt.title('Etiology vs mutations over all signature sets')
        plt.tick_params(left=False)
        plt.show()

    def stacked_bars(self, set_a, set_b):
        """Draw, per etiology, one unit-length bar split by the two sets' mutation shares.

        Args:
            set_a: signature-set id drawn as the left segment.
            set_b: signature-set id drawn as the right segment.

        Raises:
            ValueError: if no etiology has signatures in both sets.
        """
        labels, ticks_a, ticks_b, y_a, y_b = self._paired_totals(set_a, set_b)

        norm_a, norm_b = zip(*(self._normalise_pair(a, b) for a, b in zip(y_a, y_b)))

        # Bars are spaced 3 units apart to leave room for the annotations.
        ind = range(0, len(labels) * 3, 3)
        bars_a = plt.barh(ind, norm_a, height=1, color='#FFCC7A')
        bars_b = plt.barh(ind, norm_b, height=1, left=norm_a, color='#819BC1')
        plt.legend((bars_a[0], bars_b[0]), (set_a, set_b))
        # White tick at the 50 % mark of every bar.
        plt.plot([0.5] * len(labels), ind, '|w')
        for i, row in enumerate(ind):
            plt.annotate(labels[i], (0.5, row + 0.8), ha='center')
            plt.annotate(f"{y_a[i]} mut(s)", (norm_a[i] - 0.01, row), ha='right', va='center')
            plt.annotate(f"{y_b[i]} mut(s)", (norm_a[i] + 0.01, row), ha='left', va='center')
            plt.annotate(ticks_a[i], (0, row - 0.8), ha='left', va='top', color='#616161')
            plt.annotate(ticks_b[i], (1, row - 0.8), ha='right', va='top', color='#616161')
        plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)
        plt.show()

In [202]:
# Find sample files.
# glob() returns matches in arbitrary, filesystem-dependent order; sort them so
# that index 0 (used by the following cells) is the same file on every run.
sample_files = sorted(glob('samples/*.txt'))
list(enumerate(sample_files))

[(0, 'samples/data_mutations_mskcc.txt')]

In [None]:
# Get samples from multisample file.
# The first discovered sample file is passed to Sample.multisample together with
# the path 'samples/hg38.2bit' (presumably the reference genome -- confirm).
# The result is indexed and iterated by the cells below.
samples = Sample.multisample(sample_files[0], 'samples/hg38.2bit')

In [None]:
# You can pick a specific sample and plot its profile.
# Change the index to inspect a different entry of `samples`.
samples[0].plot_profile()

In [205]:
# Build one Correlated object per sample, in the same order as `samples`
corrs = [Correlated(sample.profile) for sample in samples]

In [207]:
# You can pick a specific sample and see the correlated results.
# Here: the first sample, comparing signature set 10 against signature set 49.
corrs[0].stacked_bars(10,49)

In [191]:
# Calculate distribution of contributions per etiology and signature set:
# for one etiology, estimate the density (over all samples) of the summed
# signature scores within each signature set.
etiology = 4 # Index into sig_by_etiology (same order as Correlated.results)
sig_sets = [5,10,30,49]
plot_x = np.linspace(0, 1, 1000)
# Both dicts keep an empty list for sets that are skipped below; the plotting
# cell relies on that to decide which curves to draw.
plot_y = {sig_set: [] for sig_set in sig_sets}
all_contribs = {sig_set: [] for sig_set in sig_sets}
for sig_set in sig_sets:
    # Every Correlated object is built from the same sig_by_etiology table, so
    # checking the first one tells whether this etiology has the set at all.
    if sig_set not in corrs[0].results[etiology]['sig_set']:
        continue
    contribs = [
        sum(sig['score'] for sig in corr.results[etiology]['sig_set'][sig_set])
        for corr in corrs
    ]
    all_contribs[sig_set] = contribs
    try:
        # 0.3 is passed as gaussian_kde's bw_method (scalar bandwidth factor).
        dist = gaussian_kde(contribs, 0.3)
    except (ValueError, np.linalg.LinAlgError) as err:
        # gaussian_kde cannot be fitted to fewer than two samples or to data
        # with zero variance (e.g. every contribution is 0); skip that curve
        # instead of aborting the whole cell.
        print(f'Skipping KDE for signature set {sig_set}: {err}')
        continue
    plot_y[sig_set] = dist.evaluate(plot_x)

In [192]:
# Draw one estimated density curve per signature set that produced data
for set_id, style in zip(sig_sets, ['-b','-r','-c','-g']):
    density = plot_y[set_id]
    if len(density) > 0:
        plt.plot(plot_x, density, style)

plt.show()