In [None]:
#!/usr/bin/env python
# coding: utf-8

import re
import contextlib
from datetime import datetime
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import torch
import plotly.graph_objects as go

# strftime-style timestamp pattern (abbreviated month, day, hour, minute,
# second joined by underscores). Not referenced by the code visible in this
# file — presumably intended for naming output files; verify before removing.
TIME_FORMAT_STR = "%b_%d_%H_%M_%S"

class MemoryProfiler:
    """
    Profile a block of PyTorch code and visualise per-event memory with Plotly.

    Typical flow: wrap the workload in ``profile()``, call
    ``extract_events_df()``, build figures with the ``plot_*`` methods and,
    optionally, stack them with ``create_combined_plot()``.
    """

    # Compiled once at class creation; matches stack entries that start with
    # "nn.Module: <name>" (the dot is escaped so it only matches a literal dot).
    _MODULE_PATTERN = re.compile(r"nn\.Module: ([\w\d_]+)")

    # Column order of the events DataFrame (``delta`` is appended afterwards).
    # Declared explicitly so an empty result still has every column.
    _EVENT_COLUMNS = [
        "cpu_start",
        "cpu_end",
        "name",
        "id",
        "memory",
        "stack",
        "short_stack",
        "module",
    ]

    def __init__(self, filter_keyword: str):
        """
        Initialize the MemoryProfiler.

        Args:
            filter_keyword (str): Only events whose joined stack text contains
                this substring are kept by ``extract_events_df()``.
        """
        self.filter_keyword = filter_keyword
        self.prof = None  # torch.profiler.profile instance, set by profile()
        self.df_events = None  # DataFrame cached by extract_events_df()

    @staticmethod
    def parse_stack(stack) -> str:
        """
        Extract nn.Module names from a profiler stack.

        Args:
            stack: Iterable of stack-entry strings, or None.

        Returns:
            The captured module names joined by '/', in the REVERSE of the
            order in which they appear in ``stack``. Assuming the profiler
            lists the innermost frame first, this reads outermost -> innermost
            (NOTE(review): confirm the frame order for your torch version).
            Returns '' when nothing matches or ``stack`` is empty/None.
        """
        matches = (MemoryProfiler._MODULE_PATTERN.match(s) for s in (stack or ()))
        modules = [m.group(1) for m in matches if m is not None]
        return "/".join(reversed(modules))

    @contextlib.contextmanager
    def profile(self, profiler_kwargs=None):
        """
        A context manager for profiling a block of code.

        Usage:
            with profiler.profile():
                # Code to profile, e.g., model(*input)

        Args:
            profiler_kwargs (dict | None): Keyword arguments forwarded to
                ``torch.profiler.profile``. They override the defaults set
                below (e.g. ``{"record_shapes": False}``).

        Side effects:
            Stores the profiler object in ``self.prof``. When CUDA is
            available, also starts CUDA allocator history recording; it is
            not stopped here, so a snapshot can be dumped afterwards.
        """
        cuda_available = torch.cuda.is_available()
        activities = [torch.profiler.ProfilerActivity.CPU]
        if cuda_available:
            activities.append(torch.profiler.ProfilerActivity.CUDA)

        # Defaults live in the dict so caller overrides replace them instead
        # of colliding with a second explicit keyword (which raised TypeError).
        kwargs = dict(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
            with_modules=True,
            experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True),
        )
        if profiler_kwargs:
            kwargs.update(profiler_kwargs)

        self.prof = torch.profiler.profile(**kwargs)
        if cuda_available:
            # Private torch API; only meaningful with a CUDA device present.
            torch.cuda.memory._record_memory_history(max_entries=1000000)   # start
        # ``with`` guarantees the profiler is stopped even if the body raises.
        with self.prof:
            yield

    def extract_events_df(self) -> pd.DataFrame:
        """
        Extract profiler events as a pandas DataFrame.

        Only events whose stack text contains ``filter_keyword`` are kept;
        events without a stack are treated as having an empty stack.

        Returns:
            DataFrame sorted by ``cpu_start`` with columns: cpu_start/cpu_end
            (profiler time_range values divided by 1e3), name, id, memory
            (``device_memory_usage`` divided by 1024**3, or 0 when None),
            stack (full text), short_stack (stack without its first five
            entries), module (see ``parse_stack``) and delta
            (cpu_end - cpu_start). Also cached in ``self.df_events``.

        Raises:
            ValueError: If ``profile()`` has not been run yet.
        """
        if self.prof is None:
            raise ValueError("No profile recorded. Run profile() first.")

        rows = []
        for ev in self.prof.events():
            frames = list(ev.stack or ())  # some events carry no stack at all
            full_stack = "\n".join(frames)
            if self.filter_keyword not in full_stack:
                continue
            mem = ev.device_memory_usage
            rows.append(
                {
                    "cpu_start": ev.time_range.start / 1e3,
                    "cpu_end": ev.time_range.end / 1e3,
                    "name": ev.name,
                    "id": ev.id,
                    "memory": (mem / 1024**3) if mem is not None else 0,
                    "stack": full_stack,
                    # Drops the first five frames — presumably profiler/framework
                    # boilerplate; TODO confirm the offset.
                    "short_stack": "\n".join(frames[5:]),
                    "module": self.parse_stack(frames),
                }
            )

        # Explicit columns keep the frame well-formed when nothing matched.
        df = pd.DataFrame(rows, columns=self._EVENT_COLUMNS)
        df["delta"] = df["cpu_end"] - df["cpu_start"]
        self.df_events = df.sort_values("cpu_start")
        return self.df_events

    def plot_timeline(self) -> "go.Figure":
        """
        Create an interactive timeline plot of events using Plotly.

        Each event is a box from its start to its end time, coloured by its
        memory value; hover data shows the parsed module hierarchy.

        Raises:
            ValueError: If ``extract_events_df()`` has not been run.
        """
        if self.df_events is None:
            raise ValueError("No events extracted. Run extract_events_df() first.")
        fig = px.timeline(
            self.df_events,
            x_start="cpu_start",
            x_end="cpu_end",
            y="name",
            color="memory",
            title="Memory Usage Timeline",
            hover_data=["module"],
        )
        fig.update_yaxes(visible=False)
        # px.timeline assumes a date axis; the data here are plain numbers.
        fig.layout.xaxis.type = "linear"
        # Explicitly override each bar's start (base) and width (x) with the
        # numeric values. Relies on the bars being in DataFrame row order.
        fig.update_traces(
            base=self.df_events["cpu_start"].tolist(),
            x=self.df_events["delta"].tolist(),
            selector=dict(type="bar"),  # only affect bar traces
        )
        return fig

    def plot_cumulative_memory(self) -> "go.Figure":
        """
        Plot the accumulated memory allocations over time.

        Note: This is a running sum of the ``memory`` column and may exceed
        the device's total memory.

        Raises:
            ValueError: If ``extract_events_df()`` has not been run.
        """
        if self.df_events is None:
            raise ValueError("No events extracted. Run extract_events_df() first.")
        df = self.df_events.copy()  # keep the cached frame untouched
        df["cumulative_memory"] = df["memory"].cumsum()
        fig = px.area(
            df,
            x="cpu_start",
            y="cumulative_memory",
            title="Accumulated GPU Memory Allocation Over Time",
            labels={"cpu_start": "Time (ms)", "cumulative_memory": "Accumulated Memory (GB)"},
        )
        return fig

    def compute_net_memory_usage(self) -> pd.DataFrame:
        """
        Compute the net memory usage (i.e., memory in use) over time.

        Each event contributes ``+memory`` at its start and ``-memory`` at its
        end; the changes are ordered by time and summed cumulatively.

        Returns:
            DataFrame with columns ``time`` and ``net_memory``, two rows per
            event.

        Raises:
            ValueError: If ``extract_events_df()`` has not been run.
        """
        if self.df_events is None:
            raise ValueError("No events extracted. Run extract_events_df() first.")
        events = self.df_events
        changes = []
        # Column-wise zip avoids building a Series per row as iterrows() does.
        for start, end, mem in zip(events["cpu_start"], events["cpu_end"], events["memory"]):
            changes.append((start, mem))
            changes.append((end, -mem))
        # Stable sort: changes at the same time keep their insertion order.
        changes.sort(key=lambda change: change[0])
        times, net_memory = [], []
        current = 0
        for t, delta in changes:
            current += delta
            times.append(t)
            net_memory.append(current)
        return pd.DataFrame({"time": times, "net_memory": net_memory})

    def plot_net_memory_usage(self) -> "go.Figure":
        """
        Create an area plot of the net memory usage over time.

        Plots the output of ``compute_net_memory_usage()``.

        Raises:
            ValueError: If ``extract_events_df()`` has not been run.
        """
        net_df = self.compute_net_memory_usage()
        fig = px.area(
            net_df,
            x="time",
            y="net_memory",
            title="Net GPU Memory Usage Over Time",
            labels={"time": "Time (ms)", "net_memory": "Memory Usage (GB)"},
        )
        return fig

    def plot_record_functions_timeline(self, record_fn_names=("forward", "backward")) -> "go.Figure":
        """
        Create an interactive timeline plot for recorded function events.

        Every profiler event whose name equals one of ``record_fn_names``
        (case-insensitive) becomes a box spanning its start to end time.

        Args:
            record_fn_names: Iterable of event names to include.

        Raises:
            ValueError: If ``profile()`` has not been run yet.
        """
        if self.prof is None:
            raise ValueError("No profile recorded. Run profile() first.")
        # Lower-cased set for case-insensitive, O(1) membership tests.
        target_names = {name.lower() for name in record_fn_names}
        record_events = [
            ev for ev in self.prof.events() if ev.name and ev.name.lower() in target_names
        ]
        df_record = pd.DataFrame(
            {
                "cpu_start": [ev.time_range.start / 1e3 for ev in record_events],
                "cpu_end": [ev.time_range.end / 1e3 for ev in record_events],
                "name": [ev.name for ev in record_events],
            }
        )
        df_record["delta"] = df_record["cpu_end"] - df_record["cpu_start"]
        fig = px.timeline(
            df_record,
            x_start="cpu_start",
            x_end="cpu_end",
            y="name",
            title="Recorded Functions Timeline",
        )
        # Same numeric-axis override as in plot_timeline().
        fig.layout.xaxis.type = "linear"
        fig.update_traces(
            base=df_record["cpu_start"].tolist(),
            x=df_record["delta"].tolist(),
            selector=dict(type="bar"),  # only affect bar traces
        )
        return fig

    def create_combined_plot(self, timeline_fig, memory_fig, record_fig, height=1200, row_heights=(0.3, 0.4, 0.3)) -> "go.Figure":
        """
        Combine three subplots:
          1. The recorded functions timeline.
          2. The main memory events timeline.
          3. The memory usage plot (accumulated or net).
        All subplots share the same x-axis.

        Args:
            timeline_fig: The main memory events timeline figure.
            memory_fig: The memory usage figure.
            record_fig: The recorded functions timeline figure.
            height (int): Total height of the combined figure.
            row_heights (sequence): Relative heights for the three subplots.

        Returns:
            The combined Plotly figure.
        """
        fig = make_subplots(
            rows=3,
            cols=1,
            shared_xaxes=True,
            vertical_spacing=0.02,
            row_heights=list(row_heights),
        )
        # Rows top to bottom: record functions, memory events, memory usage.
        for row, source_fig in enumerate((record_fig, timeline_fig, memory_fig), start=1):
            for trace in source_fig.data:
                fig.add_trace(trace, row=row, col=1)
        fig.update_layout(coloraxis_colorbar=dict(title="Memory (GB)"), height=height)
        fig.update_xaxes(title_text="Time (ms)", row=3, col=1)
        fig.update_yaxes(title_text="Memory Usage (GB)", row=3, col=1)
        fig.update_yaxes(title_text="Event", row=2, col=1, showticklabels=False)
        return fig

# =============================================
# Example usage:
# =============================================
import torch
import lightning as L 
from src.models.components.partmae_v6 import PARTMaskedAutoEncoderViT
from src.data.components.transforms.multi_crop_v4 import ParametrizedMultiCropV4
import timm
from src.data.components.hf_dataset import HFDataset
from torch.utils.data import DataLoader
# Assume you have a CUDA model instance (an nn.Module) and an input tensor

# Example driver: build the model through Lightning Fabric, fetch one batch,
# profile two forward passes and show the combined figure.
# Statement order matters (device placement, Fabric setup, warm-up, profiling).

# Set the float32 matmul precision mode to "medium".
torch.set_float32_matmul_precision("medium")
# Fabric handles device placement and bf16 mixed precision on the GPU.
fabric = L.Fabric(precision="bf16-mixed", accelerator='gpu')
# fabric = L.Fabric(accelerator='cpu')
# Number of global / local crops handed to the multi-crop transform below.
gV, lV = 2, 10
# Total crop count; not used by the code visible in this file.
V = gV + lV
# Instantiate the model (and its optimizer) inside Fabric's init_module context.
with fabric.init_module():
    model = PARTMaskedAutoEncoderViT(
        sampler="stratified_jittered",
        alpha_ts=0.8,
        mask_ratio=0.75,
        pos_mask_ratio=0.75,
        alpha_t=0.75,
        max_scale_ratio=6.0,
        canonical_img_size=512,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        embed_dim=768,
        decoder_embed_dim = 512,
        decoder_depth = 8,
        decoder_num_heads = 16,

    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=8e-6)


# Explicit move to CUDA before fabric.setup() — NOTE(review): possibly
# redundant with Fabric's own device handling; confirm before removing.
model = model.cuda()
B = 64  # batch size for the DataLoader
transform = ParametrizedMultiCropV4(n_global_crops=gV, n_local_crops=lV)
dataset = HFDataset("frgfm/imagenette", "160px", transform=transform)
# shuffle=False so the first batch (the only one used) is the same on every run.
train_dataloader = DataLoader(dataset, batch_size=B, shuffle=False)
model, optimizer = fabric.setup(model, optimizer)
model.train()
dataloader = fabric.setup_dataloaders(train_dataloader)
# Only a single batch is drawn; it is reused for every forward pass below.
batch = next(iter(dataloader))

# Create the profiler instance; only events whose stack text contains
# "partmae_v6" are kept by extract_events_df().
profiler = MemoryProfiler(filter_keyword="partmae_v6")


# One forward pass outside the profiler — presumably a warm-up so one-time
# initialisation cost is not recorded.
_ = model(*batch)
with profiler.profile():
    for _ in range(2):
        # Label the forward pass so plot_record_functions_timeline() can find it.
        with torch.autograd.profiler.record_function("forward"):
            out = model(*batch)
        # Backward pass is currently disabled, so no "backward" events are recorded.
        # with torch.autograd.profiler.record_function("backward"):
        #     fabric.backward(out["loss"])

# Extract events into a DataFrame
df_events = profiler.extract_events_df()

# Generate plots
timeline_fig = profiler.plot_timeline()
memory_fig = profiler.plot_net_memory_usage()  # or use plot_cumulative_memory()
record_fig = profiler.plot_record_functions_timeline(record_fn_names=["forward", "backward"])

# Combine all three plots into one figure.
combined_fig = profiler.create_combined_plot(timeline_fig, memory_fig, record_fig, height=1200, row_heights=[0.1, 0.5, 0.4])
combined_fig.show()


In [2]:
# Dump the CUDA allocator history to a pickle file. Recording is started by
# MemoryProfiler.profile() via torch.cuda.memory._record_memory_history, so
# this must run after the profiled block. The path is relative to the current
# working directory. Uses a private torch API.
torch.cuda.memory._dump_snapshot("../../mem.pickle")
