In [1]:
import rootutils
rootutils.setup_root(".", indicator=".project-root", pythonpath=True, cwd=True)

PosixPath('/home/dgcnz/development/thesis/PART')

In [2]:
import pytest
from PIL import Image
import torchvision.transforms.v2.functional as TTFv2
import hydra
from omegaconf import OmegaConf
import torch
from src.models.part_vit_module import PARTViTModule
import lightning as L
from plotly.subplots import make_subplots

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from copy import deepcopy
# import plotly dash
from torch.profiler import ProfilerActivity
from torch.profiler import record_function




In [3]:
def get_cfg():
    with hydra.initialize(version_base=None, config_path="../../configs"):
        cfg = hydra.compose(
            config_name="train.yaml",
            overrides=[
                "experiment=partmae_im1k",
                "data.batch_size=16",
            ],
        )
    try:
        OmegaConf.register_new_resolver("eval", eval)
    except:
        pass
    return cfg

cfg = get_cfg()   

In [4]:
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# MAE: https://github.com/facebookresearch/mae
# DropPos: https://github.com/Haochen-Wang409/DropPos
# --------------------------------------------------------

import math
from functools import partial
from typing import Type, TypedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int, Bool
from src.data.components.sampling_utils import sample_and_stitch
from src.models.components.utils.part_utils import (
    compute_gt_transform,
    get_all_pairs_subset,
)
from torch import Tensor
from torch.nn.functional import mse_loss
from timm.models.vision_transformer import PatchEmbed, Block
from src.models.components.utils.pos_embed import get_2d_sincos_pos_embed


class EncoderOutput(TypedDict):
    "The output of the encoder"
    z: Float[Tensor, "B N_vis D"]

    # TODO: validate if the following shapes are correct, not sure if it's N or N_vis
    mask: Bool[Tensor, "B N"]
    "The indices of the unmasked tokens"
    ids_keep: Int[Tensor, "B N_vis"]
    ids_restore: Int[Tensor, "B ?"] # not sure what N is this
    ids_restore_pos: Int[Tensor, "B ?"] 
    "The indices of the tokens with unmasked posembeds, relative to the unmasked tokens"
    ids_keep_pos: Int[Tensor, "B N_pos"]
    "False if the position embedding "
    mask_pos: Bool[Tensor, "B N_vis"]


class DecoderOutput(TypedDict):
    pred_T: Float[Tensor, "B N_nopos**2 2"]


class LossOutput(TypedDict):
    loss: Float[Tensor, "B"]
    patch_pair_indices: Int[Tensor, "B N_nopos 2"]
    gt_T: Float[Tensor, "B N_nopos**2 2"]
    pred_T: Float[Tensor, "B N_nopos**2 2"]


class ForwardOutput(EncoderOutput, DecoderOutput, LossOutput):
    pass


class PARTMaskedAutoEncoderViT(nn.Module):
    def __init__(
        self,
        # Encoder params
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        mask_ratio: float = 0.75,
        pos_mask_ratio: float = 0.75,
        # Decoder params
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        num_targets: int = 2,
    ):
        super().__init__()

        # --------------------------------------------------------------------------
        self.embed_dim = embed_dim
        self.img_size = img_size
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.pos_mask_ratio = pos_mask_ratio
        self.num_targets = num_targets
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # fixed sin-cos embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False
        )
        # mask token for position
        self.mask_pos_token = nn.Parameter(
            torch.zeros(1, 1, embed_dim), requires_grad=True
        )

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # PART decoder specifics (w/o position embedding)
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    decoder_num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(decoder_depth)
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, num_targets, bias=False
        )  # decoder to patch
        self.tanh = nn.Tanh()
        # label smoothing for positions
        self.initialize_weights()

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches**0.5),
            cls_token=True,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        torch.nn.init.normal_(self.mask_pos_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x: Float[Tensor, "B N D"], mask_ratio: float):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        B, N, D = x.shape  # batch, length, dim
        len_keep = int(N * (1 - mask_ratio))

        noise = torch.rand(B, N, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        # remove the second subset
        ids_remove = ids_shuffle[:, len_keep:]

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, N], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore).bool()

        return ids_keep, mask, ids_restore, ids_remove

    def forward_encoder(
        self, x: Float[Tensor, "B C H W"], mask_ratio: float, pos_mask_ratio: float
    ) -> EncoderOutput:
        # ------------------- EMBED PATCHES w/o [cls] TOKEN -------------------
        x = self.patch_embed(x)
        B, N, D = x.shape
        # --------------------------- GENERATE MASK ---------------------------
        ids_keep, mask, ids_restore, ids_remove = self.random_masking(x, mask_ratio)

        # gather patch embeddings and position embeddings
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        pos_embed_all = self.pos_embed[:, 1:, :].data.repeat(B, 1, 1)  # w/o [cls] token
        pos_embed_vis = torch.gather(
            pos_embed_all, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        ).detach()

        # random masking for position embedding
        ids_keep_pos, mask_pos, ids_restore_pos, ids_remove_pos = self.random_masking(
            x, pos_mask_ratio
        )

        # gather position embeddings
        pos_embed = torch.gather(
            pos_embed_vis, dim=1, index=ids_keep_pos.unsqueeze(-1).repeat(1, 1, D)
        )
        mask_pos_length = math.ceil(math.floor(N * (1 - mask_ratio)) * pos_mask_ratio)
        mask_pos_tokens = self.mask_pos_token.repeat(B, mask_pos_length, 1)
        pos_embed = torch.cat([pos_embed, mask_pos_tokens], dim=1)

        # restore position embeddings before adding
        pos_embed = torch.gather(
            pos_embed, dim=1, index=ids_restore_pos.unsqueeze(-1).repeat(1, 1, D)
        )

        # add position embedding w/o [cls] token
        x = x + pos_embed

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return {
            "z": x,
            "mask": mask,
            "ids_keep": ids_keep,
            "ids_restore": ids_restore,
            "mask_pos": mask_pos,
            "ids_keep_pos": ids_keep_pos,
            "ids_restore_pos": ids_restore_pos,
        }

    def forward_decoder(
        self,
        x: Float[Tensor, "B N_vis D"],
        img_size: int,
        mask_pos: Bool[Tensor, "B N_vis"],
    ) -> DecoderOutput:
        x = self.decoder_embed(x)
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        x: Float[Tensor, "B N_vis 2"] = x[:, 1:, :]  # remove CLS
        B, N_vis, C = x.shape
        # only predict pairwise transformations for visible patches with no position embeddings
        x: Float[Tensor, "B N_nopos C"] = x[mask_pos].view(B, -1, C)
        # compute pairwise differences
        x: Float[Tensor, "B N_nopos N_nopos C"] = x.unsqueeze(2) - x.unsqueeze(1)
        x: Float[Tensor, "B N_nopos**2 C"] = x.flatten(1, 2)
        x = self.tanh(x) * img_size
        return {"pred_T": x}

    def forward_loss(
        self,
        pred_T: Float[Tensor, "B N_nopos**2 2"],
        ids_nopos: Int[Tensor, "B N_vis"],
        patch_positions: Int[Tensor, "B N 2"],
        img_size: int,
    ) -> LossOutput:
        """
        Compute the loss for the model.
        :param pred_T: predicted transformations
        :param ids_nopos: indices of the visible patches without position embeddings
        :param patch_positions: positions of the patches
        :param img_size: size of the image
        """

        # only compute loss over the visible paches **without** position embeddings
        patch_pair_indices: Int[Tensor, "B N_nopos**2 2"] = get_all_pairs_subset(
            ids=ids_nopos
        )
        gt_T = compute_gt_transform(patch_pair_indices, patch_positions)
        loss = mse_loss(pred_T / img_size, gt_T / img_size)
        return {
            "loss": loss,
            "patch_pair_indices": patch_pair_indices,
            "gt_T": gt_T,
        }

    def forward(self, x: Float[Tensor, "B C H W"]) -> ForwardOutput:
        """
        :param x: image batch
        :return: ForwardOutput
        """
        out = dict()
        img_size = x.shape[-2]
        assert img_size == x.shape[-1], "Input image must be square"
        x_shuffled, patch_positions = sample_and_stitch(x, self.patch_size, "canonical")
        out |= self.forward_encoder(x_shuffled, self.mask_ratio, self.pos_mask_ratio)
        out |= self.forward_decoder(out["z"], img_size, out["mask_pos"])
        ids_nopos = out["ids_keep"][out["mask_pos"]].view(x.shape[0], -1)
        out |= self.forward_loss(out["pred_T"], ids_nopos, patch_positions, img_size)
        return out


def PART_mae_vit_base_patch16_dec512d8b(**kwargs):
    model = PARTMaskedAutoEncoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def patchify(
    imgs: Float[Tensor, "B C H W"], patch_size: int
) -> Float[Tensor, "B N (P^2 * C)"]:
    B, C, H, W = imgs.shape
    P = patch_size
    assert H == W and H % P == 0
    assert C == 3

    h = w = imgs.shape[2] // P
    x = imgs.reshape(shape=(B, C, h, P, w, P))
    x = torch.einsum("nchpwq->nhwpqc", x)
    x = x.reshape(shape=(B, h * w, P**2 * C))
    return x


def unpatchify(
    x: Float[Tensor, "B N (P^2 * C)"], patch_size: int
) -> Float[Tensor, "B C H W"]:
    B, N, _ = x.shape
    C = 3
    p = patch_size
    h = w = int(N**0.5)
    assert h * w == N
    x = x.reshape(shape=(B, h, w, p, p, C))
    x = torch.einsum("nhwpqc->nchpwq", x)
    imgs = x.reshape(shape=(B, C, h * p, h * p))
    return imgs


# set recommended archs
PART_mae_vit_base_patch16 = (
    PART_mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks
)

# if __name__ == "__main__":
#     backbone = PART_mae_vit_base_patch16(pos_mask_ratio=0.4, mask_ratio=0.3)
#     x = torch.randn(2, 3, 224, 224)
#     # def forward(self, imgs, mask_ratio, pos_mask_ratio):
#     patch_size = 16
#     num_patches = (224 // patch_size) ** 2
#     print(backbone.patch_embed.patch_size)
#     out = backbone.forward(x)
#     print(out["loss"])
#     # compute expected loss
#     best_T = torch.zeros_like(out["gt_T"]) / 224
#     print(mse_loss(best_T, out["gt_T"] / 224))
# 

In [5]:
model = PART_mae_vit_base_patch16(pos_mask_ratio=0.4, mask_ratio=0.3).eval().cuda()

In [6]:
batch_size = 128
x: torch.Tensor = torch.randn(batch_size, 3, 224, 224).cuda()

In [7]:
# (c) Meta Platforms, Inc. and affiliates. 
import logging
import socket
from datetime import datetime, timedelta

import torch

from torch.autograd.profiler import record_function
from torchvision import models

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

def trace_handler(prof: torch.profiler.profile):
   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   # Construct the trace file.
   prof.export_chrome_trace(f"{file_prefix}.json.gz")

   # Construct the memory timeline file.
   prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    # schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
    with_modules=True,
    experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)
    # on_trace_ready=trace_handler,
) as prof:
    with torch.no_grad():
        out = model(x)  # Forward pass
    # with record_function("## forward ##"):
    #    out = module.model_step({"image": x})  # Forward pass
    #    loss = out["loss"]                    # Get loss from output dict
    # with record_function("## backward ##"):
    #     loss.backward()                       # Backward pass
    # with record_function("## optimizer ##"):
    #     optimizer.step()                      # Optimizer step
    #     optimizer.zero_grad(set_to_none=True) # Zero grad
# prof.export_memory_timeline(f"profile.html", device="cuda:0")


In [8]:
import time
start_time = time.time()

In [9]:
# example stack
# stack=['<built-in method unbind of Tensor object at 0x7e9491e75490>',
#   'timm/models/vision_transformer.py(85): forward',
#   'torch/nn/modules/module.py(1740): _call_impl',
#   'nn.Module: Attention_6',
#   'timm/models/vision_transformer.py(164): forward',
#   'torch/nn/modules/module.py(1740): _call_impl',
#   'nn.Module: Block_6',
#   'torch/nn/modules/container.py(248): forward',
#   'torch/nn/modules/module.py(1740): _call_impl',
#   'nn.Module: Sequential_0',
#   'timm/models/vision_transformer.py(802): forward_features',
#   'timm/models/vision_transformer.py(828): forward',
#   'torch/nn/modules/module.py(1740): _call_impl',
#   'nn.Module: VisionTransformer_0',
#   'src/models/components/part_vit.py(75): forward',
#   'torch/nn/modules/module.py(1740): _call_impl',
#   'nn.Module: PARTViT_0',
#   'src/models/part_vit_module.py(147): forward',
#   'src/models/part_vit_module.py(177): model_step',
#   '/tmp/ipykernel_161062/3606682951.py(33): <module>',
#   'IPython/core/interactiveshell.py(3577): run_code',
#   'IPython/core/interactiveshell.py(3517): run_ast_nodes',
#   'IPython/core/interactiveshell.py(3334): run_cell_async',
#   'IPython/core/async_helpers.py(128): _pseudo_sync_runner',
#   'IPython/core/interactiveshell.py(3130): _run_cell',
#   'IPython/core/interactiveshell.py(3075): run_cell',
#   'ipykernel/zmqshell.py(549): run_cell',
#   'ipykernel/ipkernel.py(449): do_execute',
#   'ipykernel/kernelbase.py(778): execute_request',
#   'ipykernel/ipkernel.py(362): execute_request',
#   'ipykernel/kernelbase.py(437): dispatch_shell',]
# 
# parse all "nn.Module"
import re
def parse_stack(stack):
    pat = re.compile(r"nn.Module: ([\w\d_]+)")
    modules = [ m.group(1) for m in [pat.match(s) for s in stack] if m is not None]
    # now return module names with slashes, example: PARTViT_0/VisonTransformer_0/Sequential_0/Block_6/Attention_6
    return "/".join(modules[::-1])


In [10]:
prof.events()[0].stack

['<built-in method arange of type object at 0x7da2a7d6dd40>',
 'src/data/components/sampling_utils.py(108): _sample_ongrid',
 'src/data/components/sampling_utils.py(43): sample_and_stitch',
 '/tmp/ipykernel_21294/2319727133.py(311): forward',
 'torch/nn/modules/module.py(1753): _call_impl',
 'nn.Module: PARTMaskedAutoEncoderViT_0',
 '/tmp/ipykernel_21294/3251402764.py(33): <module>',
 'IPython/core/interactiveshell.py(3577): run_code',
 'IPython/core/interactiveshell.py(3517): run_ast_nodes',
 'IPython/core/interactiveshell.py(3334): run_cell_async',
 'IPython/core/async_helpers.py(128): _pseudo_sync_runner',
 'IPython/core/interactiveshell.py(3130): _run_cell',
 'IPython/core/interactiveshell.py(3075): run_cell',
 'ipykernel/zmqshell.py(549): run_cell',
 'ipykernel/ipkernel.py(449): do_execute',
 'ipykernel/kernelbase.py(778): execute_request',
 'ipykernel/ipkernel.py(362): execute_request',
 'ipykernel/kernelbase.py(437): dispatch_shell',
 'ipykernel/kernelbase.py(534): process_one',

In [11]:
evs = [ev for ev in prof.events() if "3251402764.py" in "\n".join(ev.stack)]
# [ev.stack[:6] for ev in evs]
# plot the memory usa
import plotly.express as px
import pandas as pd

# plot timeline with color coding for memory usage
from datetime import datetime

df = pd.DataFrame(
    [
        {
            # "cpu_start": datetime.fromtimestamp(start_time + ev.time_range.start / 1e6),
            # "cpu_end": datetime.fromtimestamp(start_time + ev.time_range.end / 1e6),
            "cpu_start": ev.time_range.start / 1e3,
            "cpu_end": ev.time_range.end / 1e3,
            "name": ev.name,
            "id": ev.id,
            "memory": ev.device_memory_usage / 1024**3,
            "stack": "\n".join(ev.stack[:5]),
            "module": parse_stack(ev.stack),
        }
        for ev in evs
    ]
)

In [20]:

df["main_module"] = df["module"].apply(
    lambda x: x.split("/")[1] if len(x.split("/")) > 1 else x
)
# df['delta'] = df['Finish'] - df['Start']
df["delta"] = df["cpu_end"] - df["cpu_start"]
df = df.sort_values("cpu_start")


# Create cumulative memory usage
df["cumulative_memory"] = df["memory"].cumsum()


# fig = px.timeline(evs, x_start="cpu_start", x_end="cpu_end", y="name", color="device", title="Memory Usage Timeline")
# fig = px.timeline(
#     df,
#     x_start="cpu_start",
#     x_end="cpu_end",
#     y="main_module",
#     color="memory",
#     title="Memory Usage Timeline",
#     hover_data="module",
# )  # , color_continuous_scale=px.colors.sequential.Teal)
# fig.update_yaxes(visible=False)
#
# fig.update_yaxes(autorange="reversed")
#
# fig.layout.xaxis.type = 'linear'
# make gantt instead of timeline
fig_timeline = px.timeline(
    df,
    x_start="cpu_start",
    x_end="cpu_end",
    y="name",
    color="memory",
    title="Memory Usage Timeline",
    hover_data="module",
    color_continuous_scale=px.colors.sequential.BuGn,
)
fig_timeline.update_yaxes(visible=False)
fig_timeline.layout.xaxis.type = "linear"
fig_timeline.data[0].x = df.delta.tolist()

# # Create the plot
fig_area = px.area(
    df,
    x="cpu_start",
    y="cumulative_memory",
    title="GPU Memory Usage Over Time",
    labels={"timestamp": "Time", "cumulative_memory": "Memory Usage (GB)"},
)

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# i want to overlay the timeline and the area plot that is, both in the same figure
# it needs to be a timeline

fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.02)
timeline_trace = fig.add_trace(fig_timeline.data[0], row=1, col=1)
area_trace = fig.add_trace(fig_area.data[0], row=2, col=1)

# colors = 
max_mem = df.memory.max()
# timeblock_color = [px.colors.sequential.RdBu_r[int(m)] for m in df.memory]
# timeblock_color = [px.colors.sequential.Plotly3[int(9*m/max_mem)] for m in df.memory]

# invisible y acis for the timeline
# fig.update_yaxes(visible=False, row=1, col=1)
# update y axis to 0 to 10 for the area plot
# fig.update_yaxes(range=[0, 16], row=2, col=1)
# y label is event
fig.update_yaxes(title_text="Event", row=1, col=1)
# remove ticks from y axis
fig.update_yaxes(showticklabels=False, row=1, col=1)

# update colorscheme of timeline to Teal
# timeline_trace
# fig.update_traces(marker=dict(color=df["memory"], colorscale=px.colors.sequential.BuGn), selector=dict(type="bar"))
# fig.update_traces(marker=dict(color=timeblock_color, ), selector=dict(type="bar"), overwrite=True)
# fig.update_layout(coloraxis_colorbar=dict(title="Memory (GB)"))
# add colorbar to legend
fig.update_layout(coloraxis_colorbar=dict(title="Memory (GB)"))

# x axis is ms
fig.update_xaxes(title_text="Time (ms)", row=2, col=1)
# y axis is memory in GB
fig.update_yaxes(title_text="Memory (GB)", row=2, col=1)

# add vertical lines for each main_module's start_time with its main_module as label
# main_df = df.groupby("main_module").first()
# for i, row in main_df.iterrows():
#     fig.add_vline(x=row.cpu_start, line_width=1, line_dash="dot", line_color="red", row=1, col=1)
#     fig.add_vline(x=row.cpu_start, line_width=1, line_dash="dot", line_color="red", row=2, col=1)

# main_df = df.groupby("main_module").first()
# for i, row in main_df.iterrows():
#     # set annotation text at 45 degree angle
#     # fig.add_vline(x=row.cpu_start, line_width=1, line_dash="dot", line_color="red", row=1, col=1, annotation_text=row.name, annotation_angle=45)
#     # fig.add_vline(x=row.cpu_start, line_width=1, line_dash="dot", line_color="red", row=2, col=1, annotation_text=row.name)
#     # use vrect instead
#     fig.add_vrect(x0=row.cpu_start, x1=row.cpu_end, fillcolor="red", opacity=0.1, layer="below", row=1, col=1)
#     fig.add_vrect(x0=row.cpu_start, x1=row.cpu_end, fillcolor="red", opacity=0.1, layer="below", row=2, col=1)






# doesn't work try something else
# add a colorbar to the right of the figure


# add  to fig overlay adapt to px.bar

fig.show()

In [14]:
df.columns

Index(['cpu_start', 'cpu_end', 'name', 'id', 'memory', 'stack', 'module',
       'main_module', 'delta', 'cumulative_memory'],
      dtype='object')

In [15]:
df["module"].apply(lambda x: x.split("/")[1] if len(x.split("/")) > 1 else x).unique()

array(['PARTMaskedAutoEncoderViT_0', 'PatchEmbed_0', 'Block_0', 'Block_1',
       'Block_2', 'Block_3', 'Block_4', 'Block_5', 'Block_6', 'Block_7',
       'Block_8', 'Block_9', 'Block_10', 'Block_11', 'LayerNorm_24',
       'Linear_48', 'Block_12', 'Block_13', 'Block_14', 'Block_15',
       'Block_16', 'Block_17', 'Block_18', 'Block_19', 'LayerNorm_41',
       'Linear_81', 'Tanh_0', ''], dtype=object)

In [16]:
evs[0].time_range.start / 1e3

0.263396

In [17]:
# Create a dataframe for memory usage over time with datetime
memory_events = pd.DataFrame(
    [
        {
            "timestamp": ev.time_range.start / 1e3,  # convert to datetime
            # "memory": ev.device_memory_usage,
            # "memory": ev.device_memory_usage / 1024 ** 2,
            "memory": ev.device_memory_usage / 1024 ** 3,
            "name": ev.name
        }
        for ev in evs
        if ev.device_memory_usage is not None
    ]
)

# Sort by timestamp 
memory_events = memory_events.sort_values("timestamp")

# Create cumulative memory usage
memory_events["cumulative_memory"] = memory_events["memory"].cumsum()

# Create the plot
fig = px.area(
    memory_events, 
    x="timestamp",
    y="cumulative_memory",
    title="GPU Memory Usage Over Time",
    labels={
        "timestamp": "Time",
        "cumulative_memory": "Memory Usage (GB)"
    }
)

# add a vertical line at each time a module starts
# for example module PartViT_0/VisonTransformer_0/Sequential_0/Block_6/Attention_6 starts on the ev with minimum ev.time_range.start
# let's keep only PartViT_0/VisonTransformer_0/Sequential_0/Block_6
# df["last_module"]  = df["module"].apply(lambda x: x.split("/")[-1])
# df_grouped_by_module = df.groupby("last_module")
# for module, df_module in df_grouped_by_module:
#     module_start = df_module["cpu_start"].min()
#     # print(f"Module {module} starts at {module_start}")
#     fig.add_vline(
#         x=module_start.timestamp() * 1000,
#         line_dash="dash",
#         line_color="red",
#         annotation_text=module
#     )

df["main_module"] = df["module"].apply(lambda x: x.split("/")[1] if len(x.split("/")) > 1 else x)
df_grouped_by_module = df.groupby("main_module")
for module, df_module in df_grouped_by_module:
    module_start = df_module["cpu_start"].min()
    print(f"Module {module} starts at {module_start}")
    # add a vertical line at each time a module starts



fig.add_hline(
    y=memory_events["cumulative_memory"].max(), 
    line_dash="dash", 
    line_color="red",
    annotation_text=f"Max Memory Usage: {memory_events['cumulative_memory'].max():.2f} GB"
)
fig.update_xaxes(
    title_text="Time",
    tickformat='%L',
    ticklabelmode="period",
    rangeslider_visible=True,
)

fig.update_layout(
    xaxis_title="Time",
    yaxis_title="Memory Usage (GB)",
    showlegend=False
)

fig.show()

Module  starts at 541.148466
Module Block_0 starts at 77.460755
Module Block_1 starts at 122.48560099999999
Module Block_10 starts at 125.442081
Module Block_11 starts at 125.752296
Module Block_12 starts at 126.134806
Module Block_13 starts at 126.507028
Module Block_14 starts at 126.790522
Module Block_15 starts at 127.132666
Module Block_16 starts at 127.427091
Module Block_17 starts at 127.749208
Module Block_18 starts at 128.04085700000002
Module Block_19 starts at 128.338829
Module Block_2 starts at 122.844347
Module Block_3 starts at 123.227769
Module Block_4 starts at 123.53571000000001
Module Block_5 starts at 123.84277800000001
Module Block_6 starts at 124.155758
Module Block_7 starts at 124.481411
Module Block_8 starts at 124.776137
Module Block_9 starts at 125.118311
Module LayerNorm_24 starts at 126.04656
Module LayerNorm_41 starts at 128.640156
Module Linear_48 starts at 126.08223699999999
Module Linear_81 starts at 128.704337
Module PARTMaskedAutoEncoderViT_0 starts at 0

In [18]:
print(prof.key_averages(group_by_stack_n=5).table(sort_by="self_cuda_memory_usage", row_limit=5))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  Source Location                                                              
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
     

In [19]:
import torch.autograd.profiler_util
e: torch.autograd.profiler_util.FunctionEvent = prof.events()[100]
# get first event with stack
