From ab2dcd0a8582f48501d29efaaf5af9abca1d7077 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 25 Jan 2024 16:47:27 +0200 Subject: [PATCH] Defer importing matplotlib --- .changeset/old-rice-laugh.md | 5 +++++ gradio/helpers.py | 7 ++++--- gradio/utils.py | 5 ++++- 3 files changed, 13 insertions(+), 4 deletions(-) create mode 100644 .changeset/old-rice-laugh.md diff --git a/.changeset/old-rice-laugh.md b/.changeset/old-rice-laugh.md new file mode 100644 index 0000000000000..eaa258bcd813a --- /dev/null +++ b/.changeset/old-rice-laugh.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Defer importing matplotlib diff --git a/gradio/helpers.py b/gradio/helpers.py index faf5bb3a900c8..92faed81b7ab4 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -15,13 +15,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional -import matplotlib.pyplot as plt import numpy as np import PIL import PIL.Image from gradio_client import utils as client_utils from gradio_client.documentation import document, set_documentation_group -from matplotlib import animation from gradio import components, oauth, processing_utils, routes, utils, wasm_utils from gradio.context import Context, LocalContext @@ -903,6 +901,9 @@ def make_waveform( Returns: A filepath to the output video in mp4 format. """ + import matplotlib.pyplot as plt + from matplotlib.animation import FuncAnimation + if isinstance(audio, str): audio_file = audio audio = processing_utils.audio_from_file(audio) @@ -1024,7 +1025,7 @@ def _animate(_): b.set_y((-rand_height * samples)[idx]) frames = int(duration * 10) - anim = animation.FuncAnimation( + anim = FuncAnimation( fig, # type: ignore _animate, repeat=False, diff --git a/gradio/utils.py b/gradio/utils.py index a9b2ab52fe0a4..a819953856dd9 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -39,7 +39,6 @@ import anyio import httpx -import matplotlib from typing_extensions import ParamSpec import gradio @@ -881,10 +880,14 @@ def __str__(self): class MatplotlibBackendMananger: def __enter__(self): + import matplotlib + self._original_backend = matplotlib.get_backend() matplotlib.use("agg") def __exit__(self, exc_type, exc_val, exc_tb): + import matplotlib + matplotlib.use(self._original_backend)