Skip to content

Commit

Permalink
clean catalog command
Browse files Browse the repository at this point in the history
  • Loading branch information
miquelduranfrigola committed Jun 12, 2023
1 parent 8d190aa commit 8a666f4
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 159 deletions.
71 changes: 31 additions & 40 deletions ersilia/cli/commands/catalog.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import click


from . import ersilia_cli
from ...hub.content.catalog import ModelCatalog
from ...hub.content.search import ModelSearcher
from ...hub.content.table_update import table


def catalog_cmd():
Expand All @@ -14,53 +12,46 @@ def catalog_cmd():
@ersilia_cli.command(help="List a catalog of models")
@click.option(
"-l",
"--local",
"--local/--hub",
is_flag=True,
default=False,
help="Show catalog of models available in the local computer",
)
@click.option(
"-t",
"--text",
default=None,
type=click.STRING,
help="Shows the model related to input keyword",
)
@click.option(
"-m",
"--mode",
default=None,
type=click.STRING,
help="Shows the model trained via input mode",
"--file_name", "-f", default=None, type=click.STRING, help="Catalog file name"
)
@click.option(
"-n", "--next", is_flag=True, default=False, help="Shows the next table"
"--browser", is_flag=True, default=False, help="Show catalog in the browser"
)
@click.option(
"-p", "--previous", is_flag=True, default=False, help="Shows previous table"
"--more/--less",
is_flag=True,
default=False,
help="Show more information than just the EOS identifier",
)
def catalog(
local=False, search=None, text=None, mode=None, next=False, previous=False
):
mc = ModelCatalog()
if not (local or text or mode):
catalog = mc.hub()
if not (next or previous):
catalog = table(catalog).initialise()

if next:
catalog = table(catalog).next_table()

if previous:
catalog = table(catalog).prev_table()

def catalog(local=False, file_name=None, browser=False, more=False):
if local is True and browser is True:
click.echo(
"You cannot show the local model catalog in the browser", fg="red"
)
if more:
only_identifier = False
else:
only_identifier = True
mc = ModelCatalog(only_identifier=only_identifier)
if browser:
mc.airtable()
return
if local:
catalog = mc.local()

if text:
catalog = mc.hub()
catalog = ModelSearcher(catalog).search_text(text)
if mode:
catalog = mc.hub()
catalog = ModelSearcher(catalog).search_mode(mode)
if file_name is None:
catalog = mc.local().as_json()
else:
mc.local().write(file_name)
catalog = None
else:
if file_name is None:
catalog = mc.hub().as_json()
else:
mc.hub().write(file_name)
catalog = None
click.echo(catalog)
133 changes: 79 additions & 54 deletions ersilia/hub/content/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import subprocess
import requests
import os
import json
import csv
from .card import ModelCard
from ... import ErsiliaBase
from ...utils.identifiers.model import ModelIdentifier
Expand All @@ -20,40 +22,50 @@
except ModuleNotFoundError as err:
Github = None

try:
from tabulate import tabulate
except ModuleNotFoundError as err:
tabulate = None


class CatalogTable(object):
def __init__(self, data, columns):
self.data = data
self.columns = columns

def as_table(self):
if not tabulate:
return None
else:
return tabulate(
self.data,
headers=self.columns,
tablefmt="fancy_grid",
colalign=("center", "center", "center"),
)
def as_list_of_dicts(self):
    """Return the table rows as a list of dictionaries.

    Each row of ``self.data`` becomes one dictionary whose keys are the
    entries of ``self.columns`` and whose values are the row items found at
    the matching positions.
    """
    # Index by position (rather than zip) so that a row shorter than the
    # column list still fails loudly with IndexError.
    return [
        {column: row[position] for position, column in enumerate(self.columns)}
        for row in self.data
    ]

def as_json(self):
    """Return the table serialized as a JSON string indented by 4 spaces.

    The serialized value is the list of dictionaries produced by
    ``as_list_of_dicts``.
    """
    return json.dumps(self.as_list_of_dicts(), indent=4)

def write(self, file_name):
    """Write the table to ``file_name`` as delimited text.

    The delimiter is chosen from the file extension: ``.csv`` uses a comma
    and ``.tsv`` uses a tab. The first row written is ``self.columns``,
    followed by every row of ``self.data``.

    Any other extension is unsupported: nothing is written, the file is not
    created or truncated, and ``None`` is returned.
    """
    # Validate the extension before opening: opening in "w" mode first would
    # create (or wipe) the target file even when nothing can be written.
    if file_name.endswith(".csv"):
        delimiter = ","
    elif file_name.endswith(".tsv"):
        delimiter = "\t"
    else:
        return None
    # newline="" hands line-ending control to the csv module, which avoids
    # doubled line breaks on platforms that translate "\n" on write.
    with open(file_name, "w", newline="") as f:
        writer = csv.writer(f, delimiter=delimiter)
        writer.writerow(self.columns)
        writer.writerows(self.data)

def __str__(self):
return self.as_table()
return self.as_json()

def __repr__(self):
return self.__str__()


class ModelCatalog(ErsiliaBase):
def __init__(self, tabular_view=True, config_json=None):
def __init__(self, config_json=None, only_identifier=True):
ErsiliaBase.__init__(self, config_json=config_json)
self.mi = ModelIdentifier()
self.tabular_view = tabular_view
self.only_identifier = only_identifier

def _is_eos(self, s):
if self.mi.is_valid(s):
Expand All @@ -75,21 +87,21 @@ def _get_slug(self, card):
return card["Slug"]
return None

def _get_mode(self, card):
if "mode" in card:
return card["mode"]
if "Mode" in card:
return card["Mode"]

def airtable(self):
"""List models available in AirTable Ersilia Model Hub base"""
if webbrowser: # TODO: explore models online
if not self.tabular_view:
webbrowser.open(
"https://airtable.com/shr9sYjL70nnHOUrP/tblZGe2a2XeBxrEHP"
)
if webbrowser:
webbrowser.open("https://airtable.com/shrUcrUnd7jB9ChZV")

def _get_all_github_public_repos(self):
url = "https://api.github.com/users/{0}/repos".format(GITHUB_ORG)
while url:
response = requests.get(url, params={"per_page": 100})
response.raise_for_status()
yield from response.json()
if "next" in response.links:
url = response.links["next"]["url"] # get the next page
else:
webbrowser.open("https://airtable.com/shrUcrUnd7jB9ChZV")
break # no more pages, stop the loop

def github(self):
"""List models available in the GitHub model hub repository"""
Expand All @@ -101,6 +113,7 @@ def github(self):
"Looking for model repositories in {0} organization".format(GITHUB_ORG)
)
if token:
self.logger.debug("Token provided: ***")
g = Github(token)
repo_list = [i for i in g.get_user().get_repos()]
repos = []
Expand All @@ -110,11 +123,10 @@ def github(self):
continue
repos += [name]
else:
self.logger.debug("Token not provided!")
repos = []
url = "https://api.github.com/users/{0}/repos".format(GITHUB_ORG)
results = requests.get(url).json()
for r in results:
repos += [r["name"]]
for repo in self._get_all_github_public_repos():
repos += [repo["name"]]
models = []
for repo in repos:
if self._is_eos(repo):
Expand All @@ -126,34 +138,47 @@ def hub(self):
"""List models available in Ersilia model hub repository"""
mc = ModelCard()
models = self.github()
R = []
for model_id in models:
card = mc.get(model_id)
if card is None:
continue
slug = self._get_slug(card)
title = self._get_title(card)
mode = self._get_mode(card)
R += [[model_id, slug, title, mode]]
return CatalogTable(R, columns=["MODEL_ID", "SLUG", "TITLE", "MODE"])
if self.only_identifier:
R = []
for model_id in models:
R += [[model_id]]
return CatalogTable(R, columns=["Identifier"])
else:
R = []
for model_id in models:
card = mc.get(model_id)
if card is None:
continue
slug = self._get_slug(card)
title = self._get_title(card)
R += [[model_id, slug, title]]
return CatalogTable(R, columns=["Identifier", "Slug", "Title"])

def local(self):
"""List models available locally"""
mc = ModelCard()
R = []
logger.debug("Looking for models in {0}".format(self._bundles_dir))
for model_id in os.listdir(self._bundles_dir):
if not self._is_eos(model_id):
continue
card = mc.get(model_id)
slug = self._get_slug(card)
title = self._get_title(card)
mode = self._get_mode(card)
R += [[model_id, slug, title, mode]]
if self.only_identifier:
R = []
for model_id in os.listdir(self._bundles_dir):
if not self._is_eos(model_id):
continue
R += [[model_id]]
columns = ["Identifier"]
else:
for model_id in os.listdir(self._bundles_dir):
if not self._is_eos(model_id):
continue
card = mc.get(model_id)
slug = self._get_slug(card)
title = self._get_title(card)
R += [[model_id, slug, title]]
columns = ["Identifier", "Slug", "Title"]
logger.info("Found {0} models".format(len(R)))
if len(R) == 0:
return None
return CatalogTable(data=R, columns=["MODEL_ID", "SLUG", "TITLE", "MODE"])
return CatalogTable(data=R, columns=columns)

def bentoml(self):
"""List models available as BentoServices"""
Expand All @@ -175,5 +200,5 @@ def bentoml(self):
for i, idx in enumerate(zip(cut_idxs, cut_idxs[1:] + [None])):
r += [row[idx[0] : idx[1]].rstrip()]
R += [[r[0].split(":")[0]] + r]
columns = ["MODEL_ID"] + columns
columns = ["Identifier"] + columns
return CatalogTable(data=R, columns=columns)
65 changes: 0 additions & 65 deletions ersilia/hub/content/table_update.py

This file was deleted.

0 comments on commit 8a666f4

Please sign in to comment.