Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
martinvonk committed Mar 25, 2024
1 parent 2b57e2f commit 119003a
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/spei/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.dates import date2num
from numpy import array, concatenate, linspace, meshgrid, reshape
from pandas import Series
from pandas import DatetimeIndex, Series
from scipy.stats import gaussian_kde

from ._typing import Axes
Expand Down Expand Up @@ -49,9 +50,9 @@ def si(
nmax = ybound

if cmap in ("roma", "roma_r"):
cmap = roma(_r=True if "_r" in cmap else False)
colormap = roma(_r=True if "_r" in cmap else False)
else:
cmap = plt.get_cmap(cmap)
colormap = plt.get_cmap(cmap)

if background:
ax.plot(si.index, si.values, linewidth=0.8, color="k")
Expand All @@ -63,17 +64,19 @@ def si(
nodroughts[nodroughts < 0] = 0

x, y = meshgrid(si.index, linspace(nmin, nmax, 100))
ax.contourf(x, y, y, cmap=cmap, levels=linspace(nmin, nmax, 100))
ax.contourf(x, y, y, cmap=colormap, levels=linspace(nmin, nmax, 100))
ax.fill_between(x=si.index, y1=droughts, y2=nmin, color="w")
ax.fill_between(x=si.index, y1=nodroughts, y2=nmax, color="w")
else:
x = mpl.dates.date2num(si.index.to_pydatetime())
points = array([x, si.values]).T.reshape(-1, 1, 2)
datetime = DatetimeIndex(si.index).to_pydatetime()
x = date2num(datetime)
y = si.values.astype(float)
points = array([x, y]).T.reshape(-1, 1, 2)
segments = concatenate([points[:-1], points[1:]], axis=1)
lc = mpl.collections.LineCollection(
segments, cmap=cmap, norm=plt.Normalize(nmin, nmax)
segments, cmap=colormap, norm=plt.Normalize(nmin, nmax)
)
lc.set_array(si.values)
lc.set_array(y)
lc.set_linewidth(1.2)
_ = ax.add_collection(lc)

Expand Down Expand Up @@ -147,7 +150,7 @@ def monthly_density(
return ax


def roma(_r: bool = False) -> mpl.colors.LinearSegmentedColormap:
def roma(_r: bool = False) -> mpl.colors.Colormap:
colors = [
[0.492325, 0.090787, 7.6e-05],
[0.49673, 0.102802, 0.003675],
Expand Down

0 comments on commit 119003a

Please sign in to comment.