In [None]:
#default_exp data.loader
%load_ext autoreload
%autoreload 2

In [None]:
# export
from fastprogress.fastprogress import progress_bar
from pathlib import Path
import requests
import os, sys
from getpass import getpass
from datetime import datetime
from git import Repo
import re
from pymemri.gitlab_api import MEMRI_PATH, MEMRI_GITLAB_BASE_URL, ACCESS_TOKEN_PATH, GITLAB_API_BASE_URL, TIME_FORMAT_GITLAB, \
PROJET_ID_PATTERN, DEFAULT_PACKAGE_VERSION, GitlabAPI

In [None]:
# export
# Package name under which model files are written to / read from a project's package registry.
DEFAULT_PLUGIN_MODEL_PACKAGE_NAME = "plugin-model-package"
# File name used for the serialized PyTorch state dict.
DEFAULT_PYTORCH_MODEL_NAME = "pytorch_model.bin"
# File name used for the model's JSON config.
# NOTE(review): "HUGGINFACE" is misspelled, but the name is exported, so it is kept as-is.
DEFAULT_HUGGINFACE_CONFIG_NAME = "config.json"

# - Downloading & Uploading functions for Models

In [None]:
# export
def write_huggingface_model_to_package_registry(project_name, model, version=DEFAULT_PACKAGE_VERSION, client=None):
    """Upload a Hugging Face model's config and weights to a project's package registry.

    The model's state dict and its JSON config are serialized into a private
    temporary directory and then uploaded, one file at a time, to the package
    named `DEFAULT_PLUGIN_MODEL_PACKAGE_NAME`. The temporary directory is removed
    when the upload finishes or fails.

    Args:
        project_name: Name resolved to a project id via `GitlabAPI.project_id_from_name`.
        model: Object exposing `state_dict()` and `config.to_json_file(path)`.
        version: Package version the files are written under.
        client: Forwarded unchanged to `GitlabAPI`.
    """
    import tempfile
    import torch  # local import: torch is only needed when a model is actually uploaded

    api = GitlabAPI(client=client)
    project_id = api.project_id_from_name(project_name)

    # A private temp dir instead of a hard-coded "/tmp": portable, and concurrent
    # uploads can no longer overwrite each other's config.json / pytorch_model.bin.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_save_dir = Path(tmp_dir)
        torch.save(model.state_dict(), local_save_dir / DEFAULT_PYTORCH_MODEL_NAME)
        model.config.to_json_file(local_save_dir / DEFAULT_HUGGINFACE_CONFIG_NAME)

        # Uploads must stay inside the `with` block so the files still exist.
        # NOTE(review): assumes write_file_to_package_registry has finished reading
        # the file when it returns — confirm against GitlabAPI.
        for f in [DEFAULT_HUGGINFACE_CONFIG_NAME, DEFAULT_PYTORCH_MODEL_NAME]:
            file_path = local_save_dir / f
            print(f"writing {f} to package registry of {project_name} with project id {project_id}")
            api.write_file_to_package_registry(project_id, file_path, package_name=DEFAULT_PLUGIN_MODEL_PACKAGE_NAME, version=version)

In [None]:
# export
def write_model_to_package_registry(model, project_name=None, client=None):
    """Upload a model to the package registry of a project.

    Only `transformers.PreTrainedModel` instances are supported; they are
    delegated to `write_huggingface_model_to_package_registry`.

    Args:
        model: The model to upload.
        project_name: Target project name. When None, the result of
            `find_git_repo()` is used instead.
        client: Forwarded unchanged to the Hugging Face upload helper.

    Raises:
        ValueError: If `model` is not a supported model type.
    """
    if project_name is None:
        # NOTE(review): `find_git_repo` is not among the imports visible in this
        # notebook — confirm it is in scope in the exported module.
        project_name = find_git_repo()

    # `transformers` is imported only when the model's class comes from that package,
    # so the library stays optional. The isinstance check has to live inside this
    # branch: referencing `transformers` outside it fails with UnboundLocalError for
    # every other model instead of reaching the ValueError below.
    if type(model).__module__.startswith("transformers"):
        import transformers
        if isinstance(model, transformers.PreTrainedModel):
            write_huggingface_model_to_package_registry(project_name, model, client=client)
            return

    raise ValueError(f"Model type not supported: {type(model)}")

In [None]:
# export
def download_huggingface_model_for_project(project_path=None, files=None, download_if_exists=False, client=None):
    """Download a project's model files from its package registry.

    Args:
        project_path: Forwarded unchanged to `GitlabAPI.download_package_file`.
        files: Iterable of file names to fetch from the package
            `DEFAULT_PLUGIN_MODEL_PACKAGE_NAME`. Defaults to the Hugging Face
            config file and the PyTorch weights file.
        download_if_exists: Forwarded to `GitlabAPI.download_package_file`.
            NOTE(review): the keyword name is taken from the "`download_if_exists`==False,
            using cached version" message that call prints — confirm the signature.
        client: Forwarded unchanged to `GitlabAPI`.

    Returns:
        The parent directory of the last downloaded file.

    Raises:
        ValueError: If `files` is given but empty.
    """
    if files is None:
        files = [DEFAULT_HUGGINFACE_CONFIG_NAME, DEFAULT_PYTORCH_MODEL_NAME]
    else:
        files = list(files)
    # Fail early with a clear message; otherwise the loop below never binds
    # `out_file_path` and the return statement raises UnboundLocalError.
    if not files:
        raise ValueError("`files` must contain at least one file name to download")

    api = GitlabAPI(client=client)
    out_file_path = None
    for f in files:
        out_file_path = api.download_package_file(
            f,
            project_path=project_path,
            package_name=DEFAULT_PLUGIN_MODEL_PACKAGE_NAME,
            download_if_exists=download_if_exists,
        )
    return out_file_path.parent

In [None]:
# export
def load_huggingface_model_for_project(project_path=None, files=None, download_if_exists=False, client=None):
    """Fetch a project's model files and load them as a sequence-classification model.

    All arguments are handed to `download_huggingface_model_for_project`; the
    directory it returns is then loaded with
    `AutoModelForSequenceClassification.from_pretrained`.

    Returns:
        The model instance created by `from_pretrained`.
    """
    model_dir = download_huggingface_model_for_project(
        project_path=project_path,
        files=files,
        download_if_exists=download_if_exists,
        client=client,
    )
    # transformers is imported at call time, after the download step has run.
    from transformers import AutoModelForSequenceClassification as _SequenceClassifier
    return _SequenceClassifier.from_pretrained(model_dir)

# - Transformers tests

In [None]:
# skip
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoModel
# Example model for the upload/download demos below: the "distilroberta-base"
# checkpoint wrapped in a sequence-classification head with 10 labels.
# NOTE(review): AutoTokenizer and AutoModel are not used by any cell visible here.
model = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", num_labels=10)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.weight'

In [None]:
from pymemri.gitlab_api import GitlabAPI

In [None]:
# Client built with default arguments.
# NOTE(review): `api` is not referenced by the cells that follow, and this cell has no `# skip` flag.
api = GitlabAPI()

In [None]:
# skip
# Upload the example model created above to the package registry of project "test-1234".
write_model_to_package_registry(model, project_name="test-1234")

In [None]:
# skip
# Download (or reuse cached copies of) the model files of "memri/finetuning-example" and load them.
model = load_huggingface_model_for_project(project_path="memri/finetuning-example")

/Users/koen/.memri/projects/finetuning-example/config.json
/Users/koen/.memri/projects/finetuning-example/config.json already exists, and `download_if_exists`==False, using cached version
/Users/koen/.memri/projects/finetuning-example/pytorch_model.bin
/Users/koen/.memri/projects/finetuning-example/pytorch_model.bin already exists, and `download_if_exists`==False, using cached version


In [None]:
# skip
# Two-step variant: fetch the files first, then load them manually so that
# extra `from_pretrained` arguments (here num_labels=20) can be passed.
out_dir = download_huggingface_model_for_project(project_path="memri/finetuning-example")
model = AutoModelForSequenceClassification.from_pretrained(out_dir, num_labels=20)

/Users/koen/.memri/projects/finetuning-example/config.json
/Users/koen/.memri/projects/finetuning-example/config.json already exists, and `download_if_exists`==False, using cached version
/Users/koen/.memri/projects/finetuning-example/pytorch_model.bin
/Users/koen/.memri/projects/finetuning-example/pytorch_model.bin already exists, and `download_if_exists`==False, using cached version


# Export -

In [None]:
# hide
# nbdev export step: converts the project's notebooks (see the "Converted ..." output below).
from nbdev.export import *
notebook2script()

Converted basic.ipynb.
Converted cvu.utils.ipynb.
Converted data.dataset.ipynb.
Converted data.loader.ipynb.
Converted data.oauth.ipynb.
Converted data.photo.ipynb.
Converted exporters.exporters.ipynb.
Converted gitlab_api.ipynb.
Converted index.ipynb.
Converted itembase.ipynb.
Converted plugin.authenticators.credentials.ipynb.
Converted plugin.authenticators.oauth.ipynb.
Converted plugin.listeners.ipynb.
Converted plugin.pluginbase.ipynb.
Converted plugin.states.ipynb.
Converted plugins.authenticators.password.ipynb.
Converted pod.api.ipynb.
Converted pod.client.ipynb.
Converted pod.db.ipynb.
Converted pod.utils.ipynb.
Converted template.config.ipynb.
Converted template.formatter.ipynb.
Converted test_schema.ipynb.
Converted test_utils.ipynb.
