Skip to content

Commit 84d3dd2

Browse files
authored
Merge pull request #41 from jcha40/master
MS Timing normalization & remove numba
2 parents f1c0d14 + 991e096 commit 84d3dd2

File tree

6 files changed

+4
-10
lines changed

6 files changed

+4
-10
lines changed

BuildTree/CellPopulationEngine.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from collections import defaultdict
55
import collections
66
import numpy as np
7-
import numba as nb
87
import itertools
98
import operator
109
import logging
@@ -46,7 +45,6 @@ def sample_ccf(xk, pk):
4645
return custm.rvs(size=1)[0]
4746

4847
@staticmethod
49-
@nb.njit()
5048
def logSumExp(ns):
5149
max_ = np.max(ns)
5250
ds = ns - max_

BuildTree/ClusterObject.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import numpy as np
3-
import numba as nb
43

54

65
class Cluster:
@@ -112,7 +111,6 @@ def remove_mutation(self, mutation, update_cluster_hist=True):
112111
self._identifier))
113112

114113
@staticmethod
115-
@nb.njit()
116114
def logSumExp(ns):
117115
max_ = np.max(ns)
118116
ds = ns - max_

BuildTree/Tree.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import numpy as np
3-
import numba as nb
43
import itertools
54
import functools
65

@@ -265,7 +264,6 @@ def get_possible_configurations(potential_children):
265264
return new_children
266265

267266
@staticmethod
268-
@nb.njit()
269267
def logSumExp(ns):
270268
max_ = np.max(ns)
271269
ds = ns - max_

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from bitnami/minideb
22
RUN install_packages python-pip build-essential python-dev r-base r-base-dev git graphviz python-tk
33
RUN pip install setuptools wheel
4-
RUN pip install numpy scipy matplotlib pandas numba
4+
RUN pip install numpy scipy matplotlib pandas
55
COPY req /tmp/req
66
RUN apt-get -y upgrade
77
RUN apt-get -y update

SinglePatientTiming/TimingEngine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def get_concordant_cn_states(self):
5353
states_across_samples = {} # hold union of all copy number states in all samples
5454
for sample in self.sample_list:
5555
if region in sample.cn_states:
56+
# Coerce copy number states to integer states
5657
if 0. <= sample.cn_states[region].cn_a1 < 1.:
5758
cn_a1 = 0.
5859
elif 1. < sample.cn_states[region].cn_a1 <= 2.:
@@ -182,7 +183,7 @@ def _get_cluster_ccfs(self):
182183
cluster_ccfs.setdefault(c, np.zeros((n_samples, 101)))
183184
cluster_ccfs[c] += np.log(mut.ccf_dist + 1e-10)
184185
for c in cluster_ccfs:
185-
cluster_ccfs[c] = np.exp(cluster_ccfs[c] - logsumexp(cluster_ccfs[c]))
186+
cluster_ccfs[c] = np.exp(cluster_ccfs[c] - logsumexp(cluster_ccfs[c], axis=1, keepdims=True))
186187
return cluster_ccfs
187188

188189
def time_events(self):
@@ -375,7 +376,7 @@ def call_wgd(self, use_concordant_states=False):
375376
regions_supporting_WGD.append(cn_state)
376377
if cn_state.cn_a1 >= 2 and cn_state.cn_a2 >= 2:
377378
regions_both_arms_gained.append(cn_state)
378-
if len(regions_both_arms_gained) >= 5 or len(regions_supporting_WGD) * 2 >= \
379+
if len(regions_both_arms_gained) >= 5 and len(regions_supporting_WGD) * 2 >= \
379380
len(self.arm_regions) - len(self.missing_arms):
380381
supporting_arm_states = [TimingCNState([self], s.chrN, s.arm, (s.cn_a1, s.cn_a2), s.purity, supporting_muts=s.supporting_muts) for
381382
s in supporting_arm_states]

req

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
intervaltree==2.1.0
22
scikit-learn==0.18.1
33
networkx==1.11
4-
numba==0.45.1
54
seaborn
65

0 commit comments

Comments
 (0)