Skip to content

Commit

Permalink
Fix for dtype with boundary assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
johnlees committed Oct 2, 2020
1 parent 7414af7 commit e95a720
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.2.0'
__version__ = '2.2.1'
22 changes: 14 additions & 8 deletions PopPUNK/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ClusterFit:
The output prefix used for reading/writing
'''

def __init__(self, outPrefix):
def __init__(self, outPrefix, default_dtype = np.float32):
self.outPrefix = outPrefix
if outPrefix != "" and not os.path.isdir(outPrefix):
try:
Expand All @@ -96,6 +96,7 @@ def __init__(self, outPrefix):

self.fitted = False
self.indiv_fitted = False
self.default_dtype = default_dtype


def fit(self, X = None):
Expand All @@ -110,6 +111,8 @@ def fit(self, X = None):
preprocess is set.
(default = None)
default_dtype (numpy dtype)
Type to use if no X provided
'''
# set output dir
if not os.path.isdir(self.outPrefix):
Expand All @@ -119,6 +122,9 @@ def fit(self, X = None):
sys.stderr.write(self.outPrefix + " already exists as a file! Use a different --output\n")
sys.exit(1)

if X is not None:
self.default_dtype = X.dtype

# preprocess subsampling
if self.preprocess:
if X.shape[0] > self.max_samples:
Expand Down Expand Up @@ -154,7 +160,7 @@ def no_scale(self):
'''Turn off scaling (useful for refine, where optimization
is done in the scaled space).
'''
self.scale = np.array([1, 1])
self.scale = np.array([1, 1], dtype = self.default_dtype)


class BGMMFit(ClusterFit):
Expand Down Expand Up @@ -479,7 +485,7 @@ def assign(self, X, no_scale = False):
raise RuntimeError("Trying to assign using an unfitted model")
else:
if no_scale:
scale = np.array([1,1])
scale = np.array([1, 1], dtype = X.dtype)
else:
scale = self.scale
y = assign_samples_dbscan(X, self.hdb, scale)
Expand Down Expand Up @@ -629,22 +635,22 @@ def apply_threshold(self, X, threshold):
y (numpy.array)
Cluster assignments of samples in X
'''
self.scale = np.array([1,1])
self.scale = np.array([1, 1], dtype = X.dtype)

# Blank values to pass to plot
self.mean0 = None
self.mean1 = None
self.start_point = None
self.min_move = None
self.max_move = None

# Sets threshold
self.core_boundary = threshold
self.accessory_boundary = np.nan
self.optimal_x = threshold
self.optimal_y = np.nan
self.slope = 0

# Flags on refine model
self.fitted = True
self.threshold = True
Expand Down Expand Up @@ -713,7 +719,7 @@ def plot(self, X, y=None):
# Subsamples huge plots to save on memory
max_points = int(0.5*(5000)**2)
if X.shape[0] > max_points:
plot_X = utils.shuffle(X, random_state=random.randint(1,10000))[0:max_points,]
plot_X = utils.shuffle(X, random_state=random.randint(1, 10000))[0:max_points, ]
else:
plot_X = X

Expand Down

0 comments on commit e95a720

Please sign in to comment.