In [None]:
import json
from pathlib import Path
from safetensors.torch import load_file, save_file
import torch
from tqdm import tqdm
from typing import Dict, Any
import sys
from huggingface_hub import hf_hub_download
from google.colab import drive

print("Google Driveをマウントしています...")
drive.mount('/content/drive')  # SAVE_DIR below lives on this Drive mount

# Download the original AuraFlow v0.3 checkpoint from the Hugging Face Hub;
# `path` is the local filesystem path of the downloaded file.
path = hf_hub_download(
    repo_id="fal/AuraFlow-v0.3",
    filename="aura_flow_0.3.safetensors",
    revision="main",
)
model_file: Path = Path(path)

# Destination: ComfyUI's checkpoints folder on Google Drive (created if missing).
SAVE_DIR: Path = Path("/content/drive/MyDrive/ComfyUI/models/checkpoints")
output_path: Path = SAVE_DIR / "aura_flow_0.3-fp8.safetensors"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# read safetensors metadata
def read_safetensors_metadata(path: str) -> Dict[str, Any]:
    """Return the ``__metadata__`` entry of a safetensors file's JSON header.

    Only the header is read; tensor data is not loaded.

    Args:
        path: Location of the .safetensors file (str or anything ``open`` accepts,
            e.g. a ``pathlib.Path``).

    Returns:
        The ``__metadata__`` mapping from the header, or an empty dict when the
        header has no such entry.

    Raises:
        ValueError: if the file is shorter than the 8-byte length prefix, if the
            prefix announces more header bytes than the file contains, or if the
            header is not a JSON object. Malformed JSON / non-UTF-8 headers raise
            ``json.JSONDecodeError`` / ``UnicodeDecodeError`` (both ValueError
            subclasses).
        OSError: if the file cannot be opened.
    """
    import os  # local import: only needed for the file-size sanity check

    with open(path, 'rb') as f:
        # The file starts with an 8-byte unsigned little-endian header length,
        # followed by that many bytes of UTF-8 JSON.
        prefix = f.read(8)
        if len(prefix) < 8:
            raise ValueError(f"{path}: file too short to be a safetensors file")
        header_size = int.from_bytes(prefix, 'little')

        # Reject an impossible length before reading, so a corrupt prefix cannot
        # request an enormous read.
        available = os.fstat(f.fileno()).st_size - 8
        if header_size > available:
            raise ValueError(
                f"{path}: safetensors header claims {header_size} bytes "
                f"but only {available} bytes follow"
            )
        header_json = f.read(header_size).decode('utf-8')

    header = json.loads(header_json)
    if not isinstance(header, dict):
        raise ValueError(f"{path}: safetensors header is not a JSON object")
    return header.get('__metadata__', {})

metadata = read_safetensors_metadata(path)
print(json.dumps(metadata, indent=4))  # show the checkpoint's embedded metadata

# Key prefixes whose tensors are copied through unchanged instead of being cast
# to FP8. Every prefix ends with "." so it matches a whole path component:
# "model.single_layers.31." selects block 31 only (not e.g. a block 310).
non_quantized_layers = {"vae.",
                        "model.double_layers.",
                        "model.single_layers.31."}

# Target dtype and its largest finite magnitude. float8_e4m3fn has no infinity
# and a plain cast does not saturate (out-of-range values become NaN), so
# tensors are clamped to this range before casting.
FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

sd_pruned = {}     # output state dict: key -> original or FP8 tensor
layer_status = []  # (layer name, status): "スキップ" = kept as-is, "処理済み" = converted to FP8
state_dict = load_file(path)  # load every tensor of the safetensors file

for key, tensor in tqdm(state_dict.items()):
    # Keep a tensor unchanged when it matches a protected prefix, or when it is
    # not floating point (an FP8 cast would destroy integer/bool data).
    keep_as_is = (not tensor.is_floating_point()
                  or any(key.startswith(prefix) for prefix in non_quantized_layers))
    if keep_as_is:
        layer_status.append((key, "スキップ"))
        sd_pruned[key] = tensor
    else:
        layer_status.append((key, "処理済み"))
        sd_pruned[key] = tensor.clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

# Save the converted state dict; entries from the source metadata override "format".
save_file(sd_pruned, output_path, metadata={"format": "pt", **metadata})

# Print each layer name together with whether it was skipped or converted.
for layer_name, status in layer_status:
    print(f"{layer_name}, {status}")

print(model_file)
print(output_path)


Google Driveをマウントしています...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
{}


100%|██████████| 824/824 [00:01<00:00, 736.41it/s]


model.cond_seq_linear.weight, 処理済み
model.double_layers.0.attn.w1k.weight, スキップ
model.double_layers.0.attn.w1o.weight, スキップ
model.double_layers.0.attn.w1q.weight, スキップ
model.double_layers.0.attn.w1v.weight, スキップ
model.double_layers.0.attn.w2k.weight, スキップ
model.double_layers.0.attn.w2o.weight, スキップ
model.double_layers.0.attn.w2q.weight, スキップ
model.double_layers.0.attn.w2v.weight, スキップ
model.double_layers.0.mlpC.c_fc1.weight, スキップ
model.double_layers.0.mlpC.c_fc2.weight, スキップ
model.double_layers.0.mlpC.c_proj.weight, スキップ
model.double_layers.0.mlpX.c_fc1.weight, スキップ
model.double_layers.0.mlpX.c_fc2.weight, スキップ
model.double_layers.0.mlpX.c_proj.weight, スキップ
model.double_layers.0.modC.1.weight, スキップ
model.double_layers.0.modX.1.weight, スキップ
model.double_layers.1.attn.w1k.weight, スキップ
model.double_layers.1.attn.w1o.weight, スキップ
model.double_layers.1.attn.w1q.weight, スキップ
model.double_layers.1.attn.w1v.weight, スキップ
model.double_layers.1.attn.w2k.weight, スキップ
model.double_layers.1.attn.w2o.

In [9]:
import json
from pathlib import Path
from safetensors.torch import load_file, save_file
import torch
from tqdm import tqdm
from typing import Dict, Any
import sys
from huggingface_hub import hf_hub_download
from google.colab import drive

print("Google Driveをマウントしています...")
drive.mount('/content/drive')  # SAVE_DIR below lives on this Drive mount

# Download the original AuraFlow v0.3 checkpoint from the Hugging Face Hub;
# `path` is the local filesystem path of the downloaded file.
path = hf_hub_download(
    repo_id="fal/AuraFlow-v0.3",
    filename="aura_flow_0.3.safetensors",
    revision="main",
)
model_file: Path = Path(path)

# Destination: ComfyUI's checkpoints folder on Google Drive (created if missing).
SAVE_DIR: Path = Path("/content/drive/MyDrive/ComfyUI/models/checkpoints")
output_path: Path = SAVE_DIR / "aura_flow_0.3-fp8.safetensors"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# read safetensors metadata
def read_safetensors_metadata(path: str) -> Dict[str, Any]:
    """Return the ``__metadata__`` entry of a safetensors file's JSON header.

    Only the header is read; tensor data is not loaded.

    Args:
        path: Location of the .safetensors file (str or anything ``open`` accepts,
            e.g. a ``pathlib.Path``).

    Returns:
        The ``__metadata__`` mapping from the header, or an empty dict when the
        header has no such entry.

    Raises:
        ValueError: if the file is shorter than the 8-byte length prefix, if the
            prefix announces more header bytes than the file contains, or if the
            header is not a JSON object. Malformed JSON / non-UTF-8 headers raise
            ``json.JSONDecodeError`` / ``UnicodeDecodeError`` (both ValueError
            subclasses).
        OSError: if the file cannot be opened.
    """
    import os  # local import: only needed for the file-size sanity check

    with open(path, 'rb') as f:
        # The file starts with an 8-byte unsigned little-endian header length,
        # followed by that many bytes of UTF-8 JSON.
        prefix = f.read(8)
        if len(prefix) < 8:
            raise ValueError(f"{path}: file too short to be a safetensors file")
        header_size = int.from_bytes(prefix, 'little')

        # Reject an impossible length before reading, so a corrupt prefix cannot
        # request an enormous read.
        available = os.fstat(f.fileno()).st_size - 8
        if header_size > available:
            raise ValueError(
                f"{path}: safetensors header claims {header_size} bytes "
                f"but only {available} bytes follow"
            )
        header_json = f.read(header_size).decode('utf-8')

    header = json.loads(header_json)
    if not isinstance(header, dict):
        raise ValueError(f"{path}: safetensors header is not a JSON object")
    return header.get('__metadata__', {})

metadata = read_safetensors_metadata(path)
print(json.dumps(metadata, indent=4))  # show the checkpoint's embedded metadata

# Key prefixes whose tensors are copied through unchanged instead of being cast
# to FP8. Every prefix ends with "." so it matches a whole path component:
# "model.single_layers.31." selects block 31 only (not e.g. a block 310).
non_quantized_layers = {"vae.",
                        "model.double_layers.",
                        "model.single_layers.31."}

# Target dtype and its largest finite magnitude. float8_e4m3fn has no infinity
# and a plain cast does not saturate (out-of-range values become NaN), so
# tensors are clamped to this range before casting.
FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

sd_pruned = {}     # output state dict: key -> original or FP8 tensor
layer_status = []  # (layer name, status): "スキップ" = kept as-is, "処理済み" = converted to FP8
state_dict = load_file(path)  # load every tensor of the safetensors file

for key, tensor in tqdm(state_dict.items()):
    # Keep a tensor unchanged when it matches a protected prefix, or when it is
    # not floating point (an FP8 cast would destroy integer/bool data).
    keep_as_is = (not tensor.is_floating_point()
                  or any(key.startswith(prefix) for prefix in non_quantized_layers))
    if keep_as_is:
        layer_status.append((key, "スキップ"))
        sd_pruned[key] = tensor
    else:
        layer_status.append((key, "処理済み"))
        sd_pruned[key] = tensor.clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

# Save the converted state dict; entries from the source metadata override "format".
save_file(sd_pruned, output_path, metadata={"format": "pt", **metadata})

# Print each layer name together with whether it was skipped or converted.
for layer_name, status in layer_status:
    print(f"{layer_name}, {status}")

print(model_file)
print(output_path)


Google Driveをマウントしています...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
{}


100%|██████████| 824/824 [00:01<00:00, 736.41it/s]


model.cond_seq_linear.weight, 処理済み
model.double_layers.0.attn.w1k.weight, スキップ
model.double_layers.0.attn.w1o.weight, スキップ
model.double_layers.0.attn.w1q.weight, スキップ
model.double_layers.0.attn.w1v.weight, スキップ
model.double_layers.0.attn.w2k.weight, スキップ
model.double_layers.0.attn.w2o.weight, スキップ
model.double_layers.0.attn.w2q.weight, スキップ
model.double_layers.0.attn.w2v.weight, スキップ
model.double_layers.0.mlpC.c_fc1.weight, スキップ
model.double_layers.0.mlpC.c_fc2.weight, スキップ
model.double_layers.0.mlpC.c_proj.weight, スキップ
model.double_layers.0.mlpX.c_fc1.weight, スキップ
model.double_layers.0.mlpX.c_fc2.weight, スキップ
model.double_layers.0.mlpX.c_proj.weight, スキップ
model.double_layers.0.modC.1.weight, スキップ
model.double_layers.0.modX.1.weight, スキップ
model.double_layers.1.attn.w1k.weight, スキップ
model.double_layers.1.attn.w1o.weight, スキップ
model.double_layers.1.attn.w1q.weight, スキップ
model.double_layers.1.attn.w1v.weight, スキップ
model.double_layers.1.attn.w2k.weight, スキップ
model.double_layers.1.attn.w2o.