Skip to content

Commit

Permalink
clean warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Feb 6, 2020
1 parent 7077570 commit 5bab5eb
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 13 deletions.
5 changes: 2 additions & 3 deletions getdist/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
from copy import deepcopy
from collections import namedtuple
from typing import Sequence, Any, Optional
from typing import Sequence, Any, Optional, Union

# whether to write to terminal chain names and burn in details when loaded from file
print_load_details = True
Expand Down Expand Up @@ -747,7 +747,7 @@ def mean_diff(self, paramVec, where=None):
else:
return paramVec[where] - self.mean(paramVec, where)

def mean_diffs(self, pars=None, where=None) -> Sequence:
def mean_diffs(self, pars: Union[None, int, Sequence] = None, where=None) -> Sequence:
"""
Calculates a list of parameter vectors giving distances from parameter means
Expand All @@ -762,7 +762,6 @@ def mean_diffs(self, pars=None, where=None) -> Sequence:
means = self.getMeans()
return [self.samples[:, i] - means[i] for i in range(pars)]
else:
# noinspection PyTypeChecker
return [self.mean_diff(i, where) for i in pars]

def twoTailLimits(self, paramVec, confidence):
Expand Down
3 changes: 2 additions & 1 deletion getdist/cobaya_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numbers import Number
import numpy as np
import os
from collections import Mapping
from typing import Mapping

# Conventions
_label = "label"
Expand Down Expand Up @@ -152,6 +152,7 @@ def get_info_params(info):
return info_params_full


# noinspection PyUnboundLocalVariable
def get_range(param_info):
# Sampled
if is_sampled_param(param_info):
Expand Down
2 changes: 1 addition & 1 deletion getdist/mcsamples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
import math
import time
from collections import Mapping
from typing import Mapping

import numpy as np
from scipy.stats import norm
Expand Down
7 changes: 3 additions & 4 deletions getdist/paramnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ def makeList(roots):
def escapeLatex(text):
if text:
import matplotlib
if text and matplotlib.rcParams['text.usetex']:
return text.replace('_', '{\\textunderscore}')
else:
return text
if matplotlib.rcParams['text.usetex']:
return text.replace('_', '{\\textunderscore}')
return text


def mergeRenames(*dicts, **kwargs):
Expand Down
10 changes: 6 additions & 4 deletions getdist/plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import copy
from collections import Mapping
from typing import Mapping
import matplotlib
import sys
import warnings
Expand Down Expand Up @@ -30,6 +30,7 @@
from getdist.matplotlib_ext import BoundedMaxNLocator, SciFuncFormatter
from getdist._base import _BaseObject
from types import MappingProxyType
from typing import Sequence, Union

"""Plotting scripts for GetDist outputs"""

Expand Down Expand Up @@ -1375,7 +1376,7 @@ def _get_list(tag):

def _make_contour_args(self, nroots, **kwargs):
contour_args = self._make_line_args(nroots, **kwargs)
filled = kwargs.get('filled')
filled: Union[None, bool, Sequence] = kwargs.get('filled')
if filled and not isinstance(filled, bool):
for cont, fill in zip(contour_args, filled):
cont['filled'] = fill
Expand Down Expand Up @@ -1725,7 +1726,7 @@ def make_figure(self, nplot=1, nx=None, ny=None, xstretch=1.0, ystretch=1.0, sha
self.subplots[:, :] = None
return self.plot_col, self.plot_row

def get_param_array(self, root, params=None, renames=None):
def get_param_array(self, root, params: Union[None, str, Sequence] = None, renames: Mapping = None):
"""
Gets an array of :class:`~.paramnames.ParamInfo` for named params
in the given `root`.
Expand Down Expand Up @@ -2749,6 +2750,7 @@ def add_3d_scatter(self, root, params, color_bar=True, alpha=1, extra_thin=1, sc
else:
pts = self.sample_analyser.load_single_samples(root)
weights = 1
mcsamples = None
names = self.param_names_for_root(root)
fixed_color = kwargs.get('fixed_color') # if actually just a plain scatter plot
samples = []
Expand All @@ -2757,7 +2759,7 @@ def add_3d_scatter(self, root, params, color_bar=True, alpha=1, extra_thin=1, sc
samples.append(param.getDerived(self._make_param_object(names, pts)))
else:
samples.append(pts[:, names.numberOfName(param.name)])
if alpha_samples:
if mcsamples:
# use most samples, but alpha with weight
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize, to_rgb
Expand Down
1 change: 1 addition & 0 deletions getdist/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def compare_method_nsims(g, probs, sizes=(1000, 10000), **kwargs):

def compare_method(probs, nx=2, fname='', **kwargs):
ny = (len(probs) - 1) // nx + 1
# noinspection PyTypeChecker
fig, axs = plt.subplots(ny, nx, sharex=True, sharey=True, squeeze=False, figsize=(nx * 3, ny * 3))
for i, prob in enumerate(probs):
ax = axs.reshape(-1)[i]
Expand Down

0 comments on commit 5bab5eb

Please sign in to comment.