In [None]:
!pip install --upgrade torch torchvision torchaudio

In [None]:
# Install diffusers from the main branch until future stable release
!pip install git+https://github.com/huggingface/diffusers.git

In [None]:
!pip show torch

In [None]:
# Inspect the attached GPU(s) and driver status before loading the model.
!nvidia-smi

In [5]:
import os
from pathlib import Path

# Root for every compiler cache used by this notebook. It sits on Google Drive so the
# compiled artifacts can be reused after a Colab runtime restart.
CACHE_DIR = Path('/content/drive/MyDrive/artifacts/flux_kontext_mega_cache')

# CACHE_DIR only persists if Drive is actually mounted; otherwise the caches silently
# end up on the runtime's ephemeral local disk. Mount Drive when running on Colab and it
# is not mounted yet; outside Colab this is a no-op.
try:
    from google.colab import drive as _colab_drive
except ImportError:  # not running on Colab
    _colab_drive = None
if _colab_drive is not None and not Path('/content/drive/MyDrive').exists():
    _colab_drive.mount('/content/drive')

# Point the compiler caches at CACHE_DIR. This cell runs before `import torch`,
# presumably so the variables are already set when torch / the CUDA driver read them.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["CUDA_CACHE_PATH"] = str(CACHE_DIR / ".nv_cache")
os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(CACHE_DIR / ".inductor_cache")
os.environ["TRITON_CACHE_DIR"] = str(CACHE_DIR / ".triton_cache")

# Optional: Enable debug logs for PyTorch compilation/cache behavior
os.environ["TORCH_LOGS"] = "+torch._inductor.codecache"

# File holding the serialized torch.compiler mega-cache artifacts (saved/loaded below).
mega_cache_path = CACHE_DIR / ".mega_cache"

In [None]:
# Authenticate with the Hugging Face Hub (presumably required because the
# FLUX.1-Kontext-dev checkpoint loaded below is gated).
import huggingface_hub

huggingface_hub.notebook_login()

In [None]:
import torch

# (major, minor) CUDA compute capability of the current device, shown as the cell output.
capability = torch.cuda.get_device_capability()
capability

In [None]:
import torch
from diffusers import FluxKontextPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import load_image
# NOTE(review): TorchAoConfig / PipelineQuantizationConfig are not used in this notebook,
# and the cache_dit experiment (apply_cache_on_pipe, CacheType) is currently disabled.

MODEL_ID = "black-forest-labs/FLUX.1-Kontext-dev"

# Load the whole pipeline in bfloat16, then move every component onto the GPU.
pipe = FluxKontextPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

In [9]:
CAT_IMAGE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"

# Reference image that every edit below is applied to.
input_image = load_image(CAT_IMAGE_URL)

In [None]:
%%time
image = pipe(
  image=input_image,
  prompt="Add a hat to the cat",
  guidance_scale=2.5
).images[0]


In [None]:
# If compilation has previously been performed, hot-load the saved compiler artifacts.
# This must run *before* the torch.compile cell below so that compilation can pick them up.
if mega_cache_path.exists():
    artifact_bytes = mega_cache_path.read_bytes()
    if artifact_bytes:
        torch.compiler.load_cache_artifacts(artifact_bytes)
        print("Loaded torch mega-cache artifacts")
    else:
        # An empty file (e.g. an interrupted earlier write) carries nothing to load.
        print(f"Torch mega-cache file {mega_cache_path} is empty, will generate new cache")
else:
    print("Torch mega-cache artifacts not found, will generate new cache")

In [None]:
# The order of events here matters
%%time
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.disable_progress = False
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.shape_padding = True

pipe.transformer.fuse_qkv_projections()
pipe.vae.fuse_qkv_projections()
pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=True
)
pipe.vae.decode = torch.compile(
    pipe.vae.decode, mode="max-autotune-no-cudagraphs", dynamic=True
)

pipe(
  image=input_image,
  prompt="Add a hat to the cat",
)

In [None]:
from pprint import pprint
import gc

# Free unreferenced Python objects and cached-but-unused CUDA memory, then report GPU
# memory usage. memory_summary() returns a preformatted multi-line table as a string,
# so print it directly (pprint would render it as an escaped, quoted string).
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [None]:
# Display the transformer's module tree; after the compile cell above this is the
# module returned by torch.compile (it replaced pipe.transformer).
pipe.transformer

In [16]:
# Use for first time compilation to get artifacts: serialize the compiler artifacts
# recorded in this session (e.g. by the warm-up cell) so they can be written to disk
# and hot-loaded by a later session.
# save_cache_artifacts() is typed Optional[tuple[bytes, CacheInfo]]; guard the None case
# so it fails with a clear message instead of a tuple-unpacking TypeError.
saved_artifacts = torch.compiler.save_cache_artifacts()
if saved_artifacts is None:
    raise RuntimeError(
        "torch.compiler.save_cache_artifacts() returned nothing to save; "
        "run the torch.compile warm-up cell first"
    )
artifact_bytes, cache_info = saved_artifacts

In [17]:
# Persist the serialized artifacts next to the other compiler caches. Create the
# directory first so the write cannot fail just because it does not exist yet.
mega_cache_path.parent.mkdir(parents=True, exist_ok=True)
mega_cache_path.write_bytes(artifact_bytes)

In [18]:
# Sanity check: read the file back and confirm it round-trips byte-for-byte.
# Assert unconditionally - an empty read-back is exactly the failure worth catching.
artifact_bytes_on_disk = mega_cache_path.read_bytes()
assert artifact_bytes_on_disk == artifact_bytes, (
    f"{mega_cache_path} does not match the in-memory artifacts "
    f"({len(artifact_bytes_on_disk)} vs {len(artifact_bytes)} bytes)"
)

In [None]:
# sanity check
!ls -al /content/drive/MyDrive/artifacts/flux_kontext_mega_cache
!ls -al /content/drive/MyDrive/artifacts/flux_kontext_mega_cache/.mega_cache

In [None]:
import time

# Time one end-to-end edit with the compiled pipeline. perf_counter() is a monotonic,
# high-resolution clock intended for measuring intervals (time.time() is wall-clock
# time and can jump). Assuming the default PIL output, the GPU work has finished by
# the time the call returns, so no explicit CUDA synchronization is added.
start = time.perf_counter()
image = pipe(
    image=input_image,
    prompt="Add a hat to the cat",
    guidance_scale=2.5,
).images[0]
elapsed = time.perf_counter() - start
print(f"Inference time: {elapsed:.2f} seconds")

In [None]:
from IPython import display as ipy_display

# Render the edited image from the timed run above inline.
ipy_display.display(image)