
Mac crash #1

Closed
jerlinn opened this issue Mar 3, 2024 · 27 comments · Fixed by #94
Labels
bug Something isn't working upstream Issue shared across all LayerDiffuse impls

Comments

@jerlinn

jerlinn commented Mar 3, 2024

Decode node always crashes

💻 Mac M2

export PYTORCH_ENABLE_MPS_FALLBACK=1
--force-fp16
❌ just python main.py
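For reference, a small diagnostics sketch to confirm the MPS setup this runs under (standard PyTorch calls only; nothing here is specific to this repo):

import os, torch
# Confirm that the MPS backend is built and available, and whether the
# CPU-fallback env var is set in this process.
print("mps built:", torch.backends.mps.is_built())
print("mps available:", torch.backends.mps.is_available())
print("PYTORCH_ENABLE_MPS_FALLBACK =", os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK"))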

[screenshot: CleanShot 2024-03-03 at 14 44 43@2x]

@ynie

ynie commented Mar 3, 2024

Same here, running into this issue :(

@huchenlei huchenlei added the bug Something isn't working label Mar 3, 2024
@yiwangsimple

M1. Same here, running into this issue :( +1

@huchenlei
Owner

I do not have a MacBook with an M-series chip. Can you help confirm whether the issue exists for SD Forge's impl as well? https://github.com/layerdiffusion/sd-forge-layerdiffusion

@yiwangsimple

[screenshot: 2024-03-05 10 35 36]
My observation today is that the error occurs when executing this node, and it causes Python to crash outright.

@huchenlei
Owner

Has any Mac user tried https://github.com/layerdiffusion/sd-forge-layerdiffusion with any success?

@Tukeping

Tukeping commented Mar 6, 2024

@huchenlei Hi, Bro. Have you tried to fix this problem? I also have a macOS machine with an M1 chip, and I have this problem.

@huchenlei
Owner

@huchenlei Hi, Bro. Have you tried to fix this problem? I also have a macOS machine with an M1 chip, and I have this problem.

I would like to first confirm whether this issue is ComfyUI-specific or affects SD Forge as well.

@Tukeping

Tukeping commented Mar 6, 2024

@huchenlei Hi, Bro. Have you tried to fix this problem? I also have a macOS machine with an M1 chip, and I have this problem.

I would like to first confirm whether this issue is ComfyUI-specific or affects SD Forge as well.

@huchenlei
I tried running sd-forge-layerdiffuse and it reported the same error message.

------ Logs below ------
Running on local URL: http://127.0.0.1:7860

To create a public link, set share=True in launch().
model_type EPS
UNet ADM Dimension 2816
Startup time: 14.5s (prepare environment: 0.4s, import torch: 5.2s, import gradio: 1.6s, setup paths: 2.1s, other imports: 3.0s, load scripts: 1.1s, create ui: 0.4s, gradio launch: 0.6s).
Using split attention in VAE
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
Using split attention in VAE
extra {'cond_stage_model.clip_g.transformer.text_model.embeddings.position_ids', 'cond_stage_model.clip_l.transformer.text_model.embeddings.position_ids', 'cond_stage_model.clip_l.logit_scale', 'cond_stage_model.clip_l.text_projection', 'cond_stage_model.clip_g.logit_scale'}
To load target model SDXLClipModel
Begin to load 1 model
Moving model(s) has taken 0.00 seconds
Model loaded in 7.7s (load weights from disk: 0.6s, forge load real models: 5.7s, calculate empty prompt: 1.3s).
[Layer Diffusion] LayerMethod.FG_ONLY_ATTN
To load target model SDXL
Begin to load 1 model
Moving model(s) has taken 3.52 seconds
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:11<00:00, 1.69it/s]
To load target model AutoencoderKL██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:10<00:00, 1.81it/s]
Begin to load 1 model
Moving model(s) has taken 0.10 seconds
To load target model UNet1024
Begin to load 1 model
Moving model(s) has taken 0.27 seconds
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00, 2.16it/s]
/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArraySort.mm:287: failed assertion `(null)" Axis = 4. This class only supports axis = 0, 1, 2, 3
'
./webui.sh: line 292: 13941 Abort trap: 6 "${python_cmd}" -u "${LAUNCH_SCRIPT}" "$@"
/opt/homebrew/Caskroom/miniconda/base/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '

@gabrie

gabrie commented Mar 9, 2024

M2. Same here, running into this issue :( +1

@huchenlei huchenlei added the upstream Issue shared across all LayerDiffuse impls label Mar 9, 2024
@jerlinn
Author

jerlinn commented Mar 9, 2024

Maybe the MPS framework cannot sort along the batch_size dimension.

@tilseam

tilseam commented Mar 10, 2024

M2 Pro + 32 GB RAM gets the same issue.

@BannyLon

Same here. On my Mac M2, as soon as the workflow reaches the LayerDiffusion decode (RGBA) node, Python aborts and an error message pops up.

@yiwangsimple

The Mac doesn't have to get hung up on it. You can replace the node's function with another process.

@BannyLon

The Mac doesn't have to get hung up on it. You can replace the node's function with another process.

How do you replace the node's function with another process?

@hike2008

Same error here. How do we resolve it?

@BannyLon

Looking forward to professional solutions, thank you very much!

@BannyLon

This plugin is so popular; how can such a common bug be ignored, with no experts stepping in to solve it?

@dingshanliang

dingshanliang commented Apr 8, 2024

Same issue with a Mac M2.
It seems to be a problem related to the MPS framework: the code sorts data along a dimension (axis) that MPS doesn't support. Currently it can only handle sorting along the first 4 dimensions of an N-dimensional array, while we end up with 5 dimensions in this situation.

/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArraySort.mm:287: failed assertion `(null)" Axis = 4. This class only supports axis = 0, 1, 2, 3

@nigo81

nigo81 commented Apr 9, 2024

M3 Pro Max + 128 GB gets the same issue.

@franksuni

Same issue with a Mac M2 Ultra.
It does seem to be an MPS issue.
The workaround I found is to use a custom node called "ImageSelector". You can apply it directly after the VAE decoder to select the layer you want; however, this loses the matted FG (the RGBA result with mask). Using the [Generate BG + FG + Blended together] workflow as an example, apply 3 ImageSelectors after VAE Decode and set selected_index to 1, 2, 3 on the three selectors. What you get is the blended FG+BG, the separate FG (with a gray background, not transparent), and the generated BG.
I suppose there are ways to get the mask back, either by using a different selector that supports masks, or by applying a segmentation method to create a new mask from the FG's gray background. I haven't tested anything yet, but you are welcome to try it yourself.

@devgdovg

devgdovg commented Apr 22, 2024

The workaround below works fine on my M1 Max MacBook.

In file lib_layerdiffusion/models.py, find

median = torch.median(result, dim=0).values

and modify like this

if self.load_device == torch.device("mps"):
    # Apple Silicon devices crash when calling torch.median() on MPS tensors
    # with more than 4 dimensions, so move the tensor to the CPU, take the
    # median there, and move the result back to the GPU.
    median = torch.median(result.cpu(), dim=0).values
    median = median.to(device=self.load_device, dtype=self.dtype)
else:
    median = torch.median(result, dim=0).values

Save and restart ComfyUI.
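If other nodes hit the same kernel limit, the same idea can be pulled out into a small standalone helper. This is only a sketch of that approach; the name mps_safe_median is made up here and is not part of the repo:

import torch

def mps_safe_median(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Hypothetical helper, not in the repo: whenever the tensor lives on MPS,
    # take the median on the CPU (the MPS sort/median kernels only handle up to
    # 4 array dimensions) and move the result back to the original device.
    if t.device.type == "mps":
        return torch.median(t.cpu(), dim=dim).values.to(t.device, dtype=t.dtype)
    return torch.median(t, dim=dim).values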

PS:
I tried to make a pull request, but it turns out I have no access rights to this repo. I don't know if the repo's author has configured the repo settings.

ERROR: Permission to huchenlei/ComfyUI-layerdiffuse.git denied to devgdovg.
fatal: Could not read from remote repository.

Please make sure you have the correct access rights

Can anyone help me with this? Thanks.

@devgdovg

To reproduce the crash in a simple scenario, try the code below on your M-series MacBook:

import torch
a = torch.randn(8, 1, 4, 512, 512)
mps_device = torch.device("mps")
b = a.to(mps_device)
tt = torch.median(b, dim=0) # crash here

But if you try a tensor with fewer dimensions, e.g. a = torch.randn(8, 1, 4, 512), there is no crash.
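For comparison, a sketch of the same repro with the CPU fallback from the workaround above applied (the only change is where the reduction runs):

import torch
a = torch.randn(8, 1, 4, 512, 512)
b = a.to(torch.device("mps"))
# Take the median on the CPU to avoid the MPS axis limit, then move the result back.
median = torch.median(b.cpu(), dim=0).values
median = median.to("mps")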

@forgetphp

(Quoting @devgdovg's workaround above.)

I'm an M1 user. Following this method, I successfully fixed the error.

[screenshot]

@tilseam

tilseam commented May 8, 2024

(Quoting @forgetphp's reply above.)

Small detail: for some reason, if I edit the file with the text editor or Xcode, the plugin errors out and won't run. Editing it with VSCode worked.

@feihuang520

(Quoting @tilseam's reply above.)

I'm also on an M1. Could you upload your modified file so we can try overwriting the original with it?

@tilseam

tilseam commented May 13, 2024

I'm also on an M1. Could you upload your modified file so we can try overwriting the original with it?
models.py.zip

@forgetphp

I'm also on an M1. Could you upload your modified file so we can try overwriting the original with it?
models.py.zip
@tilseam Hi! Here is my modified source file.

import torch.nn as nn
import torch
import cv2
import numpy as np

from tqdm import tqdm
from typing import Optional, Tuple
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


def check_diffusers_version():
    import diffusers
    from packaging.version import parse

    assert parse(diffusers.__version__) >= parse(
        "0.25.0"
    ), "diffusers>=0.25.0 requirement not satisfied. Please install correct diffusers version."


check_diffusers_version()


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class LatentTransparencyOffsetEncoder(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = torch.nn.Sequential(
            torch.nn.Conv2d(4, 32, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            torch.nn.Conv2d(128, 128, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2),
            nn.SiLU(),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            zero_module(torch.nn.Conv2d(256, 4, kernel_size=3, padding=1, stride=1)),
        )

    def __call__(self, x):
        return self.blocks(x)


# 1024 * 1024 * 3 -> 16 * 16 * 512 -> 1024 * 1024 * 3
class UNet1024(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = (
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int] = (32, 32, 64, 128, 256, 512, 512),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 4,
        norm_eps: float = 1e-5,
    ):
        super().__init__()

        # input
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
        )
        self.latent_conv_in = zero_module(
            nn.Conv2d(4, block_out_channels[2], kernel_size=1)
        )

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=None,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=(
                    attention_head_dim
                    if attention_head_dim is not None
                    else output_channel
                ),
                downsample_padding=downsample_padding,
                resnet_time_scale_shift="default",
                downsample_type=downsample_type,
                dropout=dropout,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=None,
            dropout=dropout,
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            output_scale_factor=mid_block_scale_factor,
            resnet_time_scale_shift="default",
            attention_head_dim=(
                attention_head_dim
                if attention_head_dim is not None
                else block_out_channels[-1]
            ),
            resnet_groups=norm_num_groups,
            attn_groups=None,
            add_attention=True,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[
                min(i + 1, len(block_out_channels) - 1)
            ]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=None,
                add_upsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=(
                    attention_head_dim
                    if attention_head_dim is not None
                    else output_channel
                ),
                resnet_time_scale_shift="default",
                upsample_type=upsample_type,
                dropout=dropout,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
        )
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=3, padding=1
        )

    def forward(self, x, latent):
        sample_latent = self.latent_conv_in(latent)
        sample = self.conv_in(x)
        emb = None

        down_block_res_samples = (sample,)
        for i, downsample_block in enumerate(self.down_blocks):
            if i == 3:
                sample = sample + sample_latent

            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            down_block_res_samples += res_samples

        sample = self.mid_block(sample, emb)

        for upsample_block in self.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[
                : -len(upsample_block.resnets)
            ]
            sample = upsample_block(sample, res_samples, emb)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        return sample


def checkerboard(shape):
    return np.indices(shape).sum(axis=0) % 2


def fill_checkerboard_bg(y: torch.Tensor) -> torch.Tensor:
    alpha = y[..., :1]
    fg = y[..., 1:]
    B, H, W, C = fg.shape
    cb = checkerboard(shape=(H // 64, W // 64))
    cb = cv2.resize(cb, (W, H), interpolation=cv2.INTER_NEAREST)
    cb = (0.5 + (cb - 0.5) * 0.1)[None, ..., None]
    cb = torch.from_numpy(cb).to(fg)
    vis = fg * alpha + cb * (1 - alpha)
    return vis


class TransparentVAEDecoder:
    def __init__(self, sd, device, dtype):
        self.load_device = device
        self.dtype = dtype

        model = UNet1024(in_channels=3, out_channels=4)
        model.load_state_dict(sd, strict=True)
        model.to(self.load_device, dtype=self.dtype)
        model.eval()
        self.model = model

    @torch.no_grad()
    def estimate_single_pass(self, pixel, latent):
        y = self.model(pixel, latent)
        return y

    @torch.no_grad()
    def estimate_augmented(self, pixel, latent):
        args = [
            [False, 0],
            [False, 1],
            [False, 2],
            [False, 3],
            [True, 0],
            [True, 1],
            [True, 2],
            [True, 3],
        ]

        result = []

        for flip, rok in tqdm(args):
            feed_pixel = pixel.clone()
            feed_latent = latent.clone()

            if flip:
                feed_pixel = torch.flip(feed_pixel, dims=(3,))
                feed_latent = torch.flip(feed_latent, dims=(3,))

            feed_pixel = torch.rot90(feed_pixel, k=rok, dims=(2, 3))
            feed_latent = torch.rot90(feed_latent, k=rok, dims=(2, 3))

            eps = self.estimate_single_pass(feed_pixel, feed_latent).clip(0, 1)
            eps = torch.rot90(eps, k=-rok, dims=(2, 3))

            if flip:
                eps = torch.flip(eps, dims=(3,))

            result += [eps]

        result = torch.stack(result, dim=0)
        if self.load_device == torch.device("mps"):
            # MPS sort/median kernels only support the first 4 array dimensions
            # and `result` is 5-D here, so do the reduction on the CPU and move
            # the result back (the fix discussed earlier in this thread).
            median = torch.median(result.cpu(), dim=0).values
            median = median.to(device=self.load_device, dtype=self.dtype)
        else:
            median = torch.median(result, dim=0).values
        return median

    @torch.no_grad()
    def decode_pixel(
        self, pixel: torch.TensorType, latent: torch.TensorType
    ) -> torch.TensorType:
        # pixel.shape = [B, C=3, H, W]
        assert pixel.shape[1] == 3
        pixel_device = pixel.device
        pixel_dtype = pixel.dtype

        pixel = pixel.to(device=self.load_device, dtype=self.dtype)
        latent = latent.to(device=self.load_device, dtype=self.dtype)
        # y.shape = [B, C=4, H, W]
        y = self.estimate_augmented(pixel, latent)
        y = y.clip(0, 1)
        assert y.shape[1] == 4
        # Restore image to original device of input image.
        return y.to(pixel_device, dtype=pixel_dtype)
