# Example: profiling raw vs. JIT-traced `render_batch`

In [1]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity

# 1. Setup as before.
# Every tensor below is created on the current CUDA device; this cell assumes a GPU.
device = torch.device("cuda")
def make_grid(H, W, device):
    """Build an (H, W, 2) grid of image-plane coordinates spanning [-1, 1].

    Entry ``[i, j]`` holds ``(x_j, y_i)``, where ``x`` runs over ``W`` evenly
    spaced columns and ``y`` over ``H`` evenly spaced rows, both from -1 to 1.
    """
    # Broadcast the 1-D axes against each other instead of calling meshgrid:
    # x varies along columns, y varies along rows.
    x_axis = torch.linspace(-1, 1, W, device=device).expand(H, W)
    y_axis = torch.linspace(-1, 1, H, device=device).unsqueeze(1).expand(H, W)
    return torch.stack((x_axis, y_axis), dim=-1)
# Shared 256x256 image-plane coordinate grid; render_batch reads this global.
grid = make_grid(H=256, W=256, device=device)

def render_batch(einstein_radii, src_centers, src_sigma, coords=None):
    """Render a batch of lensed Gaussian sources on a 2-D image-plane grid.

    Each image-plane position x is mapped to the source plane as
    ``beta = x - theta_E * x / |x|``. This is a radial deflection of constant
    magnitude ``theta_E`` centred on the origin (singular-isothermal-sphere
    form). The source is an unnormalised circular Gaussian with peak value 1.

    Args:
        einstein_radii: B values of ``theta_E``, one per image (viewed as (B,1,1,1)).
        src_centers: source-plane centres, viewable as (B, 1, 1, 2) as (x, y).
        src_sigma: B Gaussian widths.
        coords: optional (H, W, 2) image-plane coordinates as (x, y). If
            omitted, the notebook-level ``grid`` is used (original behaviour).

    Returns:
        (B, H, W) tensor of surface brightness.
    """
    if coords is None:
        coords = grid
    B = einstein_radii.shape[0]
    # The deflection direction depends only on pixel position, not on the batch
    # index, so compute it once on (H, W, 2) and let broadcasting add the B axis.
    r = torch.norm(coords, dim=-1, keepdim=True) + 1e-6  # eps avoids 0/0 at the origin
    direction = coords / r
    theta_E = einstein_radii.view(B, 1, 1, 1)
    beta = coords - theta_E * direction                   # (B, H, W, 2)
    dx = beta - src_centers.view(B, 1, 1, 2)
    # The squared distance below is (B, H, W), so sigma must be (B, 1, 1).
    # A (B, 1, 1, 1) view would broadcast the result to (B, B, H, W).
    sig = src_sigma.view(B, 1, 1)
    return torch.exp(-0.5 * torch.sum(dx * dx, dim=-1) / (sig ** 2))

# Batch of B lenses with increasing theta_E and sources spread along x.
B = 8
einstein_radii = torch.linspace(0.1, 0.5, B, device=device)
src_centers    = torch.zeros(B, 2, device=device)
# Build the x-offsets directly on the device (avoids an implicit host->device copy).
src_centers[:, 0] = torch.linspace(-0.3, 0.3, B, device=device)
src_sigma      = torch.full((B,), 0.1, device=device)

traced_fn = torch.jit.trace(render_batch, (einstein_radii, src_centers, src_sigma))

# Implementations to compare, keyed by the label they get in the profiler trace.
variants = {
    "RAW_render_batch": render_batch,
    "JIT_render_batch": traced_fn,
}
N_WARMUP = 20    # untimed calls per variant before profiling
N_PROFILE = 50   # profiled calls per variant

# 2. Warm up both (same interleaved order as the profiled variants).
for _ in range(N_WARMUP):
    for fn in variants.values():
        fn(einstein_radii, src_centers, src_sigma)
torch.cuda.synchronize()

# 3. Profile both versions in separate labelled regions.
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
with profile(activities=activities, record_shapes=True) as prof:
    for label, fn in variants.items():
        with record_function(label):
            for _ in range(N_PROFILE):
                fn(einstein_radii, src_centers, src_sigma)
            # CUDA kernels launch asynchronously: synchronize *inside* the range so
            # the labelled span covers GPU completion, not just CPU-side launches.
            torch.cuda.synchronize()

# 4. Export trace (Chrome trace format).
TRACE_PATH = "trace_raw_vs_jit.json"
prof.export_chrome_trace(TRACE_PATH)
print(f"Trace saved to {TRACE_PATH}")


Trace saved to trace_raw_vs_jit.json


STAGE:2025-05-22 10:29:14 5966:5966 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2025-05-22 10:29:14 5966:5966 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2025-05-22 10:29:14 5966:5966 output_json.cpp:417] Completed Stage: Post Processing


# Start by analyzing the most expensive functions in batch preparation

In [1]:
from deep_learning import NoNoiseDataset 
from deep_learning import custom_dataloader

In [2]:
# Alternative dataset configuration, not used here (kept for reference):
#   AlmaSinglePsfDataset(catalog_dict=catalog_dict,
#                        psf_name="devon_first_advice_psf_3_pix_16_arcsec",
#                        noise_std=0.0,
#                        broadcasting=False)
dataset = NoNoiseDataset(
    catalog_name="conor_similar_cat_val.json",
    grid_width_arcsec=8.0,
    grid_pixel_side=1000,
    broadcasting=True,
    final_transform=False,  # dataset reports it is not cropping/rotating in this mode
)



Using device: cuda
Currently this dataloader is calculating the images in float32


In [3]:
# shuffle=False: presumably a fixed batch order, so repeated profiling runs see the same batches.
dataloader = custom_dataloader(
    dataset=dataset,
    batch_size=300,
    shuffle=False,
)

In [4]:
iterator=iter(dataloader)

def warmup():
    for i in range (3):
        batch, _ = next(iterator)
    
def target_eval():
    for i in range (10):
        batch, _ = next(iterator)
        

%prun target_eval()

Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
Currently not cropping and rotating, set final_transform in the intialization to True to use it
 

         624334 function calls (479754 primitive calls) in 11.464 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      120   10.062    0.084   10.062    0.084 {built-in method torch.tensor}
       10    0.921    0.092    0.959    0.096 util.py:3(_hyp2f1_series)
    86808    0.140    0.000    0.140    0.000 {built-in method torch.as_tensor}
119704/3000    0.071    0.000    0.276    0.000 util.py:4(recursive_to_tensor)
29926/3000    0.035    0.000    0.274    0.000 util.py:7(<dictcomp>)
       60    0.034    0.001    0.034    0.001 {built-in method torch.sqrt}
       10    0.033    0.003   11.459    1.146 no_noise_dataset.py:59(get_batch)
      270    0.028    0.000    0.028    0.000 {built-in method torch.stack}
     3000    0.024    0.000    0.024    0.000 {method 'long' of 'torch._C._TensorBase' objects}
302040/301120    0.024    0.000    0.025    0.000 {built-in method builtins.isinstance}
       30    0.021    0.001    