In [2]:
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
from torchvision import transforms
from tqdm.auto import tqdm

from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel

from dataclasses import dataclass
import wandb
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import (
    KarrasDiffusionSchedulers,
    SchedulerMixin,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput

from diffusers.configuration_utils import ConfigMixin, register_to_config


from diffusers.models.modeling_utils import ModelMixin

# from diffusers.models.unets.unet_2d_blocks import get_down_block, get_up_block
from diffusers.models.unets.unet_2d import UNet2DOutput

import torch.nn as nn

from diffusers.schedulers.scheduling_ddpm import DDPMSchedulerOutput
from diffusers.utils.torch_utils import is_torch_version


from diffusers.models.activations import get_activation

from diffusers.utils import deprecate
from functools import partial
import numbers

from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    SlicedAttnProcessor,
    AttentionProcessor,
)
import inspect
import math
from typing import List, Optional, Tuple, Union

logger = logging.get_logger(__name__)

from diffusers.utils.torch_utils import apply_freeu
from diffusers.configuration_utils import register_to_config
from diffusers.utils import deprecate, logging


from diffusers.models.normalization import (
    RMSNorm,
)


class ResnetBlock2D(nn.Module):
    r"""ok
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
            stronger conditioning with scale and shift.
        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
            [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `True`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift,
        kernel: Optional[torch.Tensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()

        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        # self.up=False
        # self.down=False
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(
            num_groups=groups, num_channels=in_channels, eps=eps, affine=True
        )

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.time_emb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = torch.nn.GroupNorm(
            num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
        )

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = nn.Conv2d(
            out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1
        )
        # self.nonlinearity=SiLU()
        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        # self.use_in_shortcut=False
        self.use_in_shortcut = (
            self.in_channels != conv_2d_out_channels
            if use_in_shortcut is None
            else use_in_shortcut
        )

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.Tensor,
        temb: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        hidden_states = input_tensor

        # torch.Size([16, 128, 64, 64])
        hidden_states = self.norm1(hidden_states)
        # torch.Size([16, 128, 64, 64])
        hidden_states = self.nonlinearity(hidden_states)
        # torch.Size([16, 128, 64, 64])
        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]
        # self.time_embedding_norm="default"

        # temb=torch.Size([16, 128, 1, 1])
        if temb is not None:
            hidden_states = hidden_states + temb
        hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        # torch.Size([16, 128, 64, 64])
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor.contiguous())
        # input_tensor=torch.Size([4, 128, 64, 64])
        # hidden_states=torch.Size([4, 128, 64, 64])
        # self.output_scale_factor=1.0
        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
temp_tensor = torch.randn(
    (16, 128, 64, 64),
    device="cuda",
)
temb_local = torch.randn(
    (16, 512),
    device="cuda",
)

In [4]:
default_resnet = ResnetBlock2D(
    in_channels=128,
    out_channels=128,
    temb_channels=512,
    eps=1e-05,
    groups=128,
    dropout=0.0,
    time_embedding_norm="default",
    non_linearity="silu",
    output_scale_factor=1.0,
    pre_norm=True,
)
default_resnet = default_resnet.cuda()

In [4]:
default_resnet.skip_time_act

False

In [6]:
import time

amount = 1000
start = time.time()
for _ in range(amount):
    result = default_resnet(
        temp_tensor,
        temb_local,
    )
duration = time.time() - start
print("duration s 1000, ", duration)

duration s 1000,  0.9223389625549316


### torch compile

In [13]:
compile_resnet = ResnetBlock2D(
    in_channels=128,
    out_channels=128,
    temb_channels=512,
    eps=1e-05,
    groups=128,
    dropout=0.0,
    time_embedding_norm="default",
    non_linearity="silu",
    output_scale_factor=1.0,
    pre_norm=True,
)
compile_resnet = compile_resnet.cuda()
compile_resnet = torch.compile(
    compile_resnet,
    fullgraph=True,
    mode="reduce-overhead",
    # mode="max-autotune",
)

In [8]:
compile_resnet = compile_resnet.to('cuda')

In [16]:
import time

amount = 1000
start = time.time()
for _ in range(amount):
    compile_resnet(
        temp_tensor,
        temb_local,
    )
duration = time.time() - start
print("duration s 1000, ", duration)

duration s 1000,  0.4762232303619385


### Profile

In [17]:
from torch.profiler import profile, record_function, ProfilerActivity

activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
device = "cuda"
sort_by_keyword = device + "_time_total"

compile_resnet = ResnetBlock2D(
    in_channels=128,
    out_channels=128,
    temb_channels=512,
    eps=1e-05,
    groups=128,
    dropout=0.0,
    time_embedding_norm="default",
    non_linearity="silu",
    output_scale_factor=1.0,
    pre_norm=True,
)

compile_resnet = compile_resnet.cuda()
compile_resnet = torch.compile(
    compile_resnet,
    fullgraph=True,
    mode="reduce-overhead",
)
amount = 10

with profile(
    activities=activities,
    record_shapes=True,
) as prof:
    for _ in range(amount):
        result = compile_resnet(
            temp_tensor,
            temb_local,
        )



In [18]:
print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             Torch-Compiled Region: 0/1         0.03%     456.878us        99.98%        1.352s     135.241ms       0.000us         0.00%       9.882ms     988.188us            10  
                                       CompiledFunction         0.53%       7.136ms        99.94%        1.352s     135.195ms       9.125ms        92.34%       9.882ms     988.188us            10  
         

In [19]:
prof.export_chrome_trace("trace_default_reduce-overhead.json")