In [5]:
import os
from pathlib import Path
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm.auto import tqdm


def convert_model_tp_fp8(file_path: str):
    """Convert a safetensors checkpoint to float8 (e4m3fn) next to the source file.

    Only floating-point tensors are cast. Integer and boolean tensors are
    copied through unchanged, because an 8-bit float cannot represent most
    integer values exactly and casting them would silently alter the data.
    Floating-point values are clamped to the finite range of
    ``torch.float8_e4m3fn`` before the cast so that out-of-range weights
    saturate at the largest representable magnitude.

    The result is written to ``<name>.fp8.safetensors`` in the same directory
    as the input. Metadata found in the source file is preserved, and a
    ``"format": "pt"`` entry is supplied when the source does not define one.

    Args:
        file_path: Path to the source ``.safetensors`` file.

    Returns:
        True if the converted file was written; False if reading, converting
        or saving failed. Failures are printed rather than raised.
    """
    print(f"Converting: {file_path}")

    try:
        target_dtype = torch.float8_e4m3fn
        fp8_max = torch.finfo(target_dtype).max

        converted = {}
        # A single pass over the file yields both the metadata and the tensors.
        with safe_open(file_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}
            for key in tqdm(f.keys(), desc="Converting tensors"):
                tensor = f.get_tensor(key)
                if tensor.is_floating_point() and tensor.dtype != target_dtype:
                    # Upcast first so the clamp runs in a dtype every CPU
                    # kernel supports, then narrow to fp8.
                    tensor = (
                        tensor.to(torch.float32)
                        .clamp(-fp8_max, fp8_max)
                        .to(target_dtype)
                    )
                # Non-float tensors (and tensors already in fp8) pass through.
                converted[key] = tensor

        # Save the converted model (in the original directory)
        output_dir = os.path.dirname(file_path)
        model_name = os.path.splitext(os.path.basename(file_path))[0]
        output_path = os.path.join(output_dir, f"{model_name}.fp8.safetensors")
        save_file(converted, output_path, metadata={"format": "pt", **metadata})
        print(f"Saved: {output_path}")

        return True
    except Exception as e:
        # Best-effort contract: report the failure and signal it via the
        # return value instead of interrupting the notebook.
        print(f"ERROR: {str(e)}")
        return False


# Correctly spelled alias; the original (misspelled) name is kept for callers.
convert_model_to_fp8 = convert_model_tp_fp8
    
        

In [6]:
def convert_model_to_int8(file_path: str):
    """Quantize a safetensors checkpoint to int8 next to the source file.

    A plain ``tensor.to(torch.int8)`` truncates floating-point weights toward
    zero (anything in (-1, 1) becomes 0), which destroys the model. Instead,
    each floating-point tensor is quantized symmetrically with a per-tensor
    scale::

        scale = max(|w|) / 127
        q     = clamp(round(w / scale), -127, 127)   # stored as int8

    The scale is stored as a 0-dim float32 tensor under the key
    ``"<key>.int8_scale"`` so the weights can be recovered approximately with
    ``q.float() * scale``. Integer and boolean tensors are copied through
    unchanged.

    NOTE(review): the ``.int8_scale`` key suffix is a convention chosen here;
    confirm that whatever loads this file looks for the same key.

    The result is written to ``<name>.int8.safetensors`` in the same directory
    as the input. Metadata found in the source file is preserved, and a
    ``"format": "pt"`` entry is supplied when the source does not define one.

    Args:
        file_path: Path to the source ``.safetensors`` file.

    Returns:
        True if the quantized file was written; False if reading, quantizing
        or saving failed. Failures are printed rather than raised.
    """
    print(f"Converting: {file_path}")

    scale_suffix = ".int8_scale"
    int8_max = 127  # symmetric range [-127, 127]; -128 is left unused

    try:
        quantized = {}
        # A single pass over the file yields both the metadata and the tensors.
        with safe_open(file_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}
            for key in tqdm(f.keys(), desc="Converting tensors"):
                tensor = f.get_tensor(key)
                if not tensor.is_floating_point():
                    # Index buffers, counters, masks: keep exact values.
                    quantized[key] = tensor
                    continue

                values = tensor.to(torch.float32)
                if not torch.isfinite(values).all():
                    # NaN/inf have no meaningful int8 representation.
                    raise ValueError(
                        f"Tensor '{key}' contains non-finite values and cannot be quantized"
                    )

                absmax = float(values.abs().max()) if values.numel() > 0 else 0.0
                # An all-zero (or empty) tensor gets a neutral scale to avoid 0/0.
                scale = absmax / int8_max if absmax > 0 else 1.0

                quantized[key] = (
                    torch.round(values / scale)
                    .clamp(-int8_max, int8_max)
                    .to(torch.int8)
                )
                quantized[f"{key}{scale_suffix}"] = torch.tensor(
                    scale, dtype=torch.float32
                )

        # Save the converted model (in the original directory)
        output_dir = os.path.dirname(file_path)
        model_name = os.path.splitext(os.path.basename(file_path))[0]
        output_path = os.path.join(output_dir, f"{model_name}.int8.safetensors")
        save_file(quantized, output_path, metadata={"format": "pt", **metadata})
        print(f"Saved: {output_path}")

        return True
    except Exception as e:
        # Best-effort contract: report the failure and signal it via the
        # return value instead of interrupting the notebook.
        print(f"ERROR: {str(e)}")
        return False
        

In [7]:
# Checkpoint to convert. The path is relative, so it is resolved against the
# kernel's current working directory.
file_path = "models/diffusion/checkpoints/sdxl/animagine-xl-4.0-opt.safetensors"

In [8]:
# Run the int8 conversion; the value displayed for this cell is the function's
# success flag (True on success, False on failure).
convert_model_to_int8(file_path)

Converting: models/diffusion/checkpoints/sdxl/animagine-xl-4.0-opt.safetensors


Converting tensors:   0%|          | 0/2514 [00:00<?, ?it/s]

Saved: models/diffusion/checkpoints/sdxl/animagine-xl-4.0-opt.int8.safetensors


True