In [1]:
import sys
import os
from pathlib import Path

# Project root: one level above the notebook's working directory.
# Adding it to sys.path lets `from networks import vit` in the next cell
# resolve (assumes the `networks` package sits directly under the project root).
project_root = Path.cwd().parent

# Guarded append keeps the cell idempotent: re-running it never adds a
# duplicate entry to sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

In [2]:
import torch
import torch.onnx
from huggingface_hub import HfApi
from networks import vit

In [3]:
def convert_pth_to_onnx(pth_path, onnx_path, input_shape=(1, 1, 224, 224),
                        opset_version=14, input_names=None, output_names=None):
    """Export a ViT checkpoint (.pth) to an ONNX file.

    Rebuilds the network with ``vit(n_channels=1, num_classes=1,
    fine_tune='full')``, loads the checkpoint weights into it on CPU and
    traces one forward pass on a random dummy input to produce the ONNX graph.

    Args:
        pth_path: Path to the checkpoint. Either a dict holding the weights
            under ``'model_state_dict'`` (the layout of these training
            checkpoints) or a bare state dict.
        onnx_path: Destination path of the exported ``.onnx`` file.
        input_shape: Shape of the dummy tracing input. Defaults to one
            single-channel 224x224 image, as in the original export.
        opset_version: ONNX opset to target (default 14, as before).
        input_names: Optional names for the graph inputs. ``None`` keeps the
            exporter's auto-generated names.
        output_names: Optional names for the graph outputs. ``None`` keeps the
            exporter's auto-generated names.

    Returns:
        ``onnx_path``, the file that was written.

    Raises:
        RuntimeError: if the checkpoint weights do not match the rebuilt
            architecture (``load_state_dict`` loads strictly).
    """
    # The architecture must match the one used in training for the weights to load.
    # (Hugging Face alternative kept for reference:
    #  ViTForImageClassification.from_pretrained("google/vit-base-patch16-224"))
    model = vit(n_channels=1, num_classes=1, fine_tune='full')

    # NOTE: torch.load unpickles the file; only use it on checkpoints you trust.
    checkpoint = torch.load(pth_path, map_location="cpu")
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        # Training checkpoint: weights are bundled with other entries.
        state_dict = checkpoint['model_state_dict']
    else:
        # Otherwise assume a bare state dict was saved.
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    # Put layers such as dropout (if present) into inference mode before tracing.
    model.eval()

    # Random values; the input's shape/dtype is what the trace is meant to capture.
    dummy_input = torch.randn(*input_shape)

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
    )
    return onnx_path

def upload_to_huggingface(repo_id, file_path, path_in_repo=None):
    """Upload a single local file to a Hugging Face Hub model repository.

    Args:
        repo_id: Target repository id, e.g. ``"user/model-name"``.
        file_path: Local file to upload (``str`` or ``pathlib.Path``).
        path_in_repo: Destination path inside the repo. Defaults to the
            file's base name, i.e. the file lands at the repo root.

    Returns:
        The value returned by ``HfApi.upload_file``.
    """
    if path_in_repo is None:
        # Path.name works for both str and Path inputs and honours the
        # platform's separators (e.g. backslashes on Windows), unlike split("/").
        path_in_repo = Path(file_path).name

    # No token is passed explicitly, so authentication relies on
    # huggingface_hub's own defaults.
    api = HfApi()
    return api.upload_file(
        path_or_fileobj=file_path,
        path_in_repo=path_in_repo,
        repo_id=repo_id,
        repo_type="model",
    )

In [4]:
# Example usage: convert each trained checkpoint to ONNX and upload it to the Hub repo.
repo_id = "SemilleroCV/vit-base-patch16-224-thermal-breast-cancer"

checkpoints_path = '../checkpoints/vit_32_full_00001/vit-base-patch16-224-thermal-breast-cancer'

# Checkpoint file stems (no extension); the trailing number presumably
# identifies the run/fold — TODO confirm.
models_list = ['h7knv1x1_1_checkpoint', '7bv92e7b_2_checkpoint',
               '7sraw3yj_3_checkpoint', 'cypjhdg1_4_checkpoint',
               'ok5dhuqe_5_checkpoint', 'st7nnu18_6_checkpoint',
               '0kma18yo_7_checkpoint']

checkpoints_dir = Path(checkpoints_path)

# Fail fast before converting or uploading anything, so a typo in the list
# cannot leave the Hub repository with only some of the models uploaded.
missing_checkpoints = [
    name for name in models_list
    if not (checkpoints_dir / f"{name}.pth").is_file()
]
if missing_checkpoints:
    raise FileNotFoundError(
        f"Missing .pth checkpoints in {checkpoints_dir}: {missing_checkpoints}"
    )

for checkpoint_name in models_list:
    # as_posix() keeps '/' separators on every OS, so the upload step can
    # derive the in-repo file name from the path.
    pth_path = (checkpoints_dir / f"{checkpoint_name}.pth").as_posix()
    onnx_path = (checkpoints_dir / f"{checkpoint_name}.onnx").as_posix()

    # .pth -> .onnx
    convert_pth_to_onnx(pth_path, onnx_path)

    # Push the exported file to the Hugging Face Hub.
    upload_to_huggingface(repo_id, onnx_path)

print("Conversión y carga completadas.")

  assert condition, message


h7knv1x1_1_checkpoint.onnx:   0%|          | 0.00/342M [00:00<?, ?B/s]

7bv92e7b_2_checkpoint.onnx:   0%|          | 0.00/342M [00:00<?, ?B/s]

7sraw3yj_3_checkpoint.onnx:   0%|          | 0.00/342M [00:00<?, ?B/s]

cypjhdg1_4_checkpoint.onnx:   0%|          | 0.00/342M [00:00<?, ?B/s]

ok5dhuqe_5_checkpoint.onnx:   0%|          | 0.00/342M [00:00<?, ?B/s]

st7nnu18_6_checkpoint.onnx:   0%|          | 0.00/342M [00:00<?, ?B/s]

0kma18yo_7_checkpoint.onnx:   0%|          | 0.00/342M [00:00<?, ?B/s]

Conversión y carga completadas.
