In [1]:
import torch
import torch.nn as nn

class MassComponent(nn.Module):
    def __init__(self):
        """
            We unpack the tensor configuration in the common features
            between different mass components.

            In this code, all the components are at same redshift, so the
            redshift is not a parameter of the mass component.

            We have only the position.
        """
        super().__init__()


    def deflection_angle(self, lens_grid, z_source):
        """
           This forward computes the deflection field of the mass component
        """
        pass

    import torch

import shared_utils.units as units

import shared_utils.units as units

class SIS(MassComponent):
    def __init__(self, input_tensor=None, device="cuda", dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        
        # Store compute precision constants for reuse
        self.compute_dtype = torch.float32
        self.c = torch.tensor(units.c, device=device, dtype=self.compute_dtype)
        self.pi4 = torch.tensor(4 * torch.pi, device=device, dtype=self.compute_dtype)
        self.epsilon = torch.tensor(1e-8, device=device, dtype=self.compute_dtype)
        self.zero = torch.tensor(0.0, device=device, dtype=self.compute_dtype)
        
        if input_tensor is not None:
            # Convert input tensor to compute precision directly
            self.input_tensor = input_tensor.to(device=device, dtype=self.compute_dtype)
            
            # Extract parameters (all in compute precision)
            self.pos = self.input_tensor[:, :2]
            self.redshift = self.input_tensor[:, 2]
            self.vel_disp = self.input_tensor[:, 3]
            self.D_ls = self.input_tensor[:, 4]
            self.D_s = self.input_tensor[:, 5]
            
            # Pre-compute Einstein angle (most expensive calculation)
            vel_disp_c = self.vel_disp / self.c
            squared = vel_disp_c ** 2
            D_ratio = self.D_ls / self.D_s
            self.einstein_angle = self.pi4 * squared * D_ratio
        else:
            self.input_tensor = None
        
        # Buffers will be allocated on first use
        self._initialized_buffers = False
        self._buffer_shapes = None
    
    def _initialize_buffers(self, batch_size, height, width):
        """Initialize all buffers at once with compute precision"""
        # Only initialize once for each shape configuration
        current_shapes = (batch_size, height, width)
        if self._initialized_buffers and self._buffer_shapes == current_shapes:
            return
            
        # Create all buffers in compute precision directly
        self._x_rel = torch.empty((batch_size, height, width, 2), 
                                  device=self.device, dtype=self.compute_dtype)
        self._r_squared = torch.empty((batch_size, height, width), 
                                      device=self.device, dtype=self.compute_dtype)
        self._r = torch.empty((batch_size, height, width, 1), 
                             device=self.device, dtype=self.compute_dtype)
        self._result = torch.empty((batch_size, height, width, 2), 
                                  device=self.device, dtype=self.compute_dtype)
        
        self._initialized_buffers = True
        self._buffer_shapes = current_shapes

    @torch.jit.script_method
    def deflection_angle(self, lens_grid, z_source=None):
        """Optimized implementation using pre-allocated buffers and minimizing type conversions"""
        # Upcast input grid once if needed
        lens_grid_f32 = lens_grid if lens_grid.dtype == self.compute_dtype else lens_grid.to(dtype=self.compute_dtype)
        
        # Get dimensions
        batch_size, height, width, _ = lens_grid_f32.shape
        
        # Ensure buffers are initialized
        self._initialize_buffers(batch_size, height, width)
        
        # Calculate relative positions (in-place)
        pos_expanded = self.pos.view(batch_size, 1, 1, 2)
        torch.sub(lens_grid_f32, pos_expanded, out=self._x_rel)
        
        # Calculate squared distances (in-place)
        x_rel_x = self._x_rel[..., 0]
        x_rel_y = self._x_rel[..., 1]
        torch.mul(x_rel_x, x_rel_x, out=self._r_squared)
        torch.addcmul(self._r_squared, x_rel_y, x_rel_y, value=1.0, out=self._r_squared)
        
        # Calculate r with safe division (in-place)
        torch.sqrt(self._r_squared, out=self._r.squeeze(-1))
        torch.maximum(self._r.squeeze(-1), self.epsilon, out=self._r.squeeze(-1))
        self._r = self._r.view(batch_size, height, width, 1)
        
        # Einstein angle already in compute precision, reshape for broadcasting
        einstein_angle_expanded = self.einstein_angle.view(batch_size, 1, 1, 1)
        
        # Calculate deflection result directly (in-place)
        torch.mul(einstein_angle_expanded, self._x_rel, out=self._result)
        torch.div(self._result, self._r, out=self._result)
        
        # Convert to target precision only at the end
        if self.dtype != self.compute_dtype:
            result = self._result.to(dtype=self.dtype)
        else:
            result = self._result
            
        return result


class NFW(MassComponent):
    def __init__(self, input_tensor, device="cuda", dtype=torch.float32):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.compute_dtype = torch.float32
        
        # Convert input tensor to compute precision directly
        self.input_tensor = input_tensor.to(device=device, dtype=self.compute_dtype)
        
        # Extract parameters (all in compute precision)
        self.pos = self.input_tensor[:, :2]
        self.mass_max = self.input_tensor[:, 2]
        self.r_max_kpc = self.input_tensor[:, 3]
        self.D_l = self.input_tensor[:, 4] * 1000  # Convert to kpc immediately
        self.D_s = self.input_tensor[:, 5] * 1000  # Convert to kpc immediately
        self.D_ls = self.input_tensor[:, 6] * 1000  # Convert to kpc immediately

        # Constants in compute precision
        self.const = torch.tensor(2.16258, device=device, dtype=self.compute_dtype)
        self.pi = torch.tensor(torch.pi, device=device, dtype=self.compute_dtype)
        self.G = torch.tensor(units.G, device=device, dtype=self.compute_dtype)
        self.c_squared = torch.tensor(units.c**2, device=device, dtype=self.compute_dtype)
        self.epsilon = torch.tensor(1e-8, device=device, dtype=self.compute_dtype)
        self.four = torch.tensor(4.0, device=device, dtype=self.compute_dtype)
        self.two = torch.tensor(2.0, device=device, dtype=self.compute_dtype)
        self.one = torch.tensor(1.0, device=device, dtype=self.compute_dtype)
        self.zero = torch.tensor(0.0, device=device, dtype=self.compute_dtype)
        
        # Pre-compute derivable parameters (all in compute precision)
        # r_s calculation
        self.r_s = self.r_max_kpc / self.const
        
        # rho_s calculation
        log_term = torch.log(1.0 + self.const)
        const_term = self.const / (1.0 + self.const)
        denominator = log_term - const_term
        r_s_cubed = self.r_s ** 3
        numerator = self.mass_max / (self.four * self.pi * r_s_cubed)
        self.rho_s = numerator / denominator
        
        # Pre-compute sigma_crit
        num = self.c_squared * self.D_s
        denom = self.four * self.pi * self.G * self.D_l * self.D_ls
        self.sigma_crit = num / denom
        
        # Pre-compute kappa_s
        self.ks = (self.r_s * self.rho_s) / self.sigma_crit
        
        # Buffers will be allocated on first use
        self._initialized_buffers = False
        self._buffer_shapes = None
    
    def _initialize_buffers(self, batch_size, height, width):
        """Initialize all buffers at once with compute precision"""
        # Only initialize once for each shape configuration
        current_shapes = (batch_size, height, width)
        if self._initialized_buffers and self._buffer_shapes == current_shapes:
            return
        
        grid_shape = (batch_size, height, width)
        
        # Create all buffers in compute precision directly
        self._xrel = torch.empty((batch_size, height, width, 2), 
                               device=self.device, dtype=self.compute_dtype)
        self._rs2 = torch.empty(grid_shape, 
                              device=self.device, dtype=self.compute_dtype)
        self._rs = torch.empty(grid_shape, 
                             device=self.device, dtype=self.compute_dtype)
        self._rs_nodim = torch.empty(grid_shape, 
                                  device=self.device, dtype=self.compute_dtype)
        self._F = torch.empty(grid_shape, 
                            device=self.device, dtype=self.compute_dtype)
        self._log_term = torch.empty(grid_shape, 
                                  device=self.device, dtype=self.compute_dtype)
        self._alpha = torch.empty(grid_shape, 
                                device=self.device, dtype=self.compute_dtype)
        self._alpha_vec = torch.empty((batch_size, height, width, 2), 
                                   device=self.device, dtype=self.compute_dtype)
        
        self._initialized_buffers = True
        self._buffer_shapes = current_shapes
    
    @torch.jit.script_method
    def deflection_angle(self, lens_grid, z_source=None):
        """Optimized implementation using pre-allocated buffers and minimizing type conversions"""
        # Upcast input grid once if needed
        lens_grid_f32 = lens_grid if lens_grid.dtype == self.compute_dtype else lens_grid.to(dtype=self.compute_dtype)
        
        # Get dimensions
        batch_size, height, width, _ = lens_grid_f32.shape
        
        # Ensure buffers are initialized
        self._initialize_buffers(batch_size, height, width)
        
        # Calculate relative positions (in-place)
        pos_expanded = self.pos.view(batch_size, 1, 1, 2)
        torch.sub(lens_grid_f32, pos_expanded, out=self._xrel)
        
        # Calculate radius squared (in-place)
        x_rel_x = self._xrel[..., 0]
        x_rel_y = self._xrel[..., 1]
        torch.mul(x_rel_x, x_rel_x, out=self._rs2)
        torch.addcmul(self._rs2, x_rel_y, x_rel_y, value=1.0, out=self._rs2)
        
        # Calculate radius with safe division (in-place)
        torch.sqrt(self._rs2, out=self._rs)
        torch.maximum(self._rs, self.epsilon, out=self._rs)
        
        # Calculate rs_nodim (in-place)
        D_l_expanded = self.D_l.view(-1, 1, 1)
        r_s_expanded = self.r_s.view(-1, 1, 1)
        torch.mul(self._rs, D_l_expanded, out=self._rs_nodim)
        torch.div(self._rs_nodim, r_s_expanded, out=self._rs_nodim)
        
        # Reset F buffer (in-place)
        self._F.zero_()
        
        # Get masks efficiently
        mask1 = self._rs_nodim < 1
        mask2 = self._rs_nodim == 1
        mask3 = self._rs_nodim > 1
        
        # Handle each case
        if mask1.any():
            rs_nodim_mask1 = self._rs_nodim[mask1]
            rs_nodim_sq = rs_nodim_mask1 * rs_nodim_mask1
            sqrt_term1 = torch.sqrt(self.one - rs_nodim_sq)
            F_values = torch.atanh(sqrt_term1) / sqrt_term1
            self._F[mask1] = F_values
        
        if mask2.any():
            self._F[mask2] = 1.0
        
        if mask3.any():
            rs_nodim_mask3 = self._rs_nodim[mask3]
            rs_nodim_sq = rs_nodim_mask3 * rs_nodim_mask3
            sqrt_term3 = torch.sqrt(rs_nodim_sq - self.one)
            F_values = torch.atan(sqrt_term3) / sqrt_term3
            self._F[mask3] = F_values
        
        # Calculate log term (in-place)
        torch.div(self._rs_nodim, self.two, out=self._log_term)
        torch.log(self._log_term, out=self._log_term)
        
        # Calculate deflection magnitude (in-place)
        ks_expanded = self.ks.view(batch_size, 1, 1)
        r_s_expanded = self.r_s.view(-1, 1, 1)
        D_l_expanded = self.D_l.view(-1, 1, 1)
        
        # Calculate the log_term + F value
        torch.add(self._log_term, self._F, out=self._alpha)
        
        # Calculate full alpha formula
        torch.mul(self._alpha, self.four * ks_expanded, out=self._alpha)
        torch.mul(self._alpha, r_s_expanded, out=self._alpha)
        torch.div(self._alpha, D_l_expanded * self._rs_nodim, out=self._alpha)
        
        # Calculate deflection vector (in-place)
        alpha_expanded = self._alpha.view(batch_size, height, width, 1)
        rs_expanded = self._rs.view(batch_size, height, width, 1)
        torch.mul(alpha_expanded, self._xrel, out=self._alpha_vec)
        torch.div(self._alpha_vec, rs_expanded, out=self._alpha_vec)
        
        # Convert to target precision only at the end
        if self.dtype != self.compute_dtype:
            result = self._alpha_vec.to(dtype=self.dtype)
        else:
            result = self._alpha_vec
            
        return result

import torch
from shared_utils import _arcsec_to_rad

class ExternalPotential(MassComponent):
    def __init__(self, input_tensor, device="cuda"):
        """
        For the ExternalPotential, the input tensor is expected to be of shape [B, 4]:
            [[shear_center_x, shear_center_y, shear_strength, shear_angle_arcsec],
             [shear_center_x, shear_center_y, shear_strength, shear_angle_arcsec],
             ...]
        """
        super().__init__()
        self.device = device
        self.input_tensor = input_tensor.to(self.device)
        
        # Extract batched parameters.
        # shear_center: shape [B, 2]
        self.shear_center = self.input_tensor[:, :2]
        # shear_strength: shape [B]
        self.ss = self.input_tensor[:, 2]
        # Convert shear angle from arcseconds to radians.
        self.sa = _arcsec_to_rad(self.input_tensor[:, 3])

    def deflection_angle(self, lens_grid, precomp_dict=None, z_source=None):
        """
        Compute the batched deflection angle for the external shear.
        
        Parameters:
          lens_grid: Tensor of shape [B, H, W, 2] representing the grid in the lens plane.
          precomp_dict: Dummy argument for uniformity with other mass components.
          z_source: Dummy argument for clarity.
          
        Returns:
          A tensor of deflection angles with shape [B, H, W, 2].
        """
        # Reshape shear_center to [B, 1, 1, 2] for broadcasting.
        shear_center = self.shear_center.view(-1, 1, 1, 2)
        # Compute relative positions.
        xrel = lens_grid - shear_center
        
        # Reshape shear strength and shear angle for broadcasting.
        ss = self.ss.view(-1, 1, 1)
        sa = self.sa.view(-1, 1, 1)
        
        # Compute cosine and sine terms for the shear transformation.
        cs2 = -torch.cos(2 * sa)
        sn2 = -torch.sin(2 * sa)
        
        # Compute the deflection angle components.
        alpha_x = ss * (cs2 * xrel[..., 0] + sn2 * xrel[..., 1])
        alpha_y = ss * (sn2 * xrel[..., 0] - cs2 * xrel[..., 1])
        
        # Stack the components to form the vector field.
        alpha = torch.stack((alpha_x, alpha_y), dim=-1)
        return alpha

import torch

def _hyp2f1_series(z, r2, t, q, max_terms=15):
    """
    Batched implementation of the hypergeometric 2F1 series for the PEMD lens model.
    
    Parameters:
      z       : Complex tensor of shape [B, ...] representing the (complex) coordinate.
      r2      : Tensor of shape [B, ...] representing the squared elliptical radius.
      t       : Tensor of shape [B, ...] representing the power-law slope.
      q       : Tensor of shape [B, ...] representing the axis ratio.
      max_terms: Maximum number of terms in the series expansion.
      
    Returns:
      f       : Complex tensor of the same shape as z, containing the series evaluation.
    
    Note:
      A warning is printed if any element of q is less than 0.8, since convergence
      issues may occur in that regime.
    """
    if (q < 0.8).any():
        print("Warning: some q < 0.8 in this _hyp2f1_series implementation may not converge")
    
    # Compute qp = (1 - q^2) / q^2, with q being batched.
    qp = (1 - q**2) / q**2
    # Compute w2 = qp * r2 / z^2. Division is elementwise.
    w2 = qp * r2 / (z**2)
    
    # Compute u = 0.5 * (1 - sqrt(1 - w2))
    u = 0.5 * (1.0 - torch.sqrt(1.0 - w2))
    
    # Initialize u_n and a_n as tensors of ones with the same shape (and type) as u.
    u_n = torch.ones_like(u)  # u_n will accumulate powers of u
    a_n = torch.ones_like(u)  # a_n accumulates the coefficient product
    # Initialize the series sum.
    f = a_n * u_n
    
    # Sum the series for max_terms iterations.
    for n in range(max_terms):
        u_n = u_n * u  # Increase power: u^(n+1)
        # Compute the multiplier factor elementwise.
        # Note: the operations are broadcasted against t.
        num = (2 * n + 4) - 2 * t
        den = (2 * n + 4) - t
        a_n = a_n * (num / den)
        f = f + a_n * u_n
        
    return f



import torch

class PEMD(MassComponent):
    """
    Power-law elliptical mass distribution (PEMD).

    In this version, the Einstein radius is computed on the fly
    from the velocity dispersion and the precomputed angular distances.
    
    Input tensor is expected to be of shape [B, 8]:
      [
        [slope, pos_x, pos_y, orient, q, vel_disp, D_s, D_ls],
        ...
      ]
    where:
      - slope (t) is the power-law slope,
      - pos is the lens center,
      - orient is the lens orientation in radians,
      - q is the axis ratio,
      - vel_disp is the velocity dispersion in km/s,
      - D_s is the source angular diameter distance (in Mpc),
      - D_ls is the lens-to-source angular diameter distance (in Mpc).
    """
    def __init__(self, input_tensor, device="cuda"):
        super().__init__()
        self.device = device
        self.input_tensor = input_tensor.to(self.device)
        
        # Extract batched parameters.
        self.slope    = self.input_tensor[:, 0]        # [B]
        self.pos      = self.input_tensor[:, 1:3]        # [B, 2]
        self.th       = self.input_tensor[:, 3]          # [B] orientation (radians)
        self.q        = self.input_tensor[:, 4]          # [B]
        self.vel_disp = self.input_tensor[:, 5]          # [B] in km/s
        self.D_s      = self.input_tensor[:, 6]          # [B] in Mpc
        self.D_ls     = self.input_tensor[:, 7]          # [B] in Mpc
        
    def deflection_angle(self, lens_grid, precomp_dict=None, z_source=None):
        """
        Compute the batched deflection angle for the PEMD profile.
        
        Parameters:
          lens_grid: Tensor of shape [B, H, W, 2] representing the angular grid
                     in the lens plane.
          precomp_dict: Dummy argument (distances are provided in the input tensor).
          z_source: Dummy argument.
          
        Returns:
          A tensor of deflection angles of shape [B, H, W, 2].
        """
        # Compute Einstein radius from velocity dispersion.
        # For a SIS model: θ_E = 4π (vel_disp/c)² (D_ls/D_s)
        # We assume c ~ 3e5 km/s.
        c = torch.as_tensor(3e5, device=self.device)  # km/s
        
        # Reshape distances for broadcasting over the spatial grid.
        D_s  = self.D_s.view(-1, 1, 1)   # [B, 1, 1] in Mpc
        D_ls = self.D_ls.view(-1, 1, 1)  # [B, 1, 1] in Mpc
        vel_disp = self.vel_disp.view(-1, 1, 1)  # [B, 1, 1] in km/s
        
        theta_E = 4 * torch.pi * (vel_disp / c)**2 * (D_ls / D_s)  # [B, 1, 1]
        
        # Compute the scale parameter b.
        # The relation is taken as b = θ_E * √q.
        q = self.q.view(-1, 1, 1)  # [B, 1, 1]
        b = theta_E * torch.sqrt(q)  # [B, 1, 1]
        
        # Prepare the slope parameter for broadcasting.
        t = self.slope.view(-1, 1, 1)  # [B, 1, 1]
        
        # Rotate coordinates:
        # Subtract lens center and convert to complex form.
        pos = self.pos.view(-1, 1, 1, 2)  # [B, 1, 1, 2]
        diff = lens_grid - pos           # [B, H, W, 2]
        z = torch.view_as_complex(diff)  # [B, H, W]
        
        # Apply orientation rotation.
        th = self.th.view(-1, 1, 1)       # [B, 1, 1]
        crot = torch.exp(-1j * th)
        z = crot * z  # rotated coordinates
        
        # Compute the elliptical coordinate:
        # rs² = q² * (Re(z))² + (Im(z))².
        rs2 = (q**2) * (z.real**2) + (z.imag**2)  # [B, H, W]
        rs = torch.sqrt(rs2)
        
        # Compute factor A = b² / (q * z) * (b / rs)^(t - 2)
        A = (b**2) / (q * z) * (b / rs)**(t - 2)
        
        # Compute the hypergeometric series term.
        F = _hyp2f1_series(z, rs2, t, q)
        
        # Compute the complex deflection angle and rotate back.
        alpha_complex = torch.conj(A * F * crot)
        
        # Convert the complex deflection to a real 2-vector field.
        alpha_real = torch.view_as_real(alpha_complex)
        return alpha_real



In [2]:
import torch
import torch.nn as nn
from collections import defaultdict



class LensModel(nn.Module):
    def __init__(self, config_list, precomp_dict_list, device="cuda", dtype=torch.float32):
        """
        Initialize the LensModel with batched mass components.
        
        Parameters:
            config_list: List of dictionaries, each containing configuration for one lens system
            precomp_dict_list: List of dictionaries with precomputed quantities for each system
            device: Device to run computations on
            dtype: Data type for tensor computations (torch.float32, torch.float16, etc.)
        """
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.num_systems = len(config_list)
        self.precomp_dict_list = precomp_dict_list
        
        # Initialize component classes mapping
        self.component_classes = {
            "SIS": SIS,
            "NFW": NFW,
            "ExternalPotential": ExternalPotential,
            "PEMD": PEMD
        }
        
        # Process configurations and create batched components
        self.process_configs(config_list)
        
    def process_configs(self, config_list):
        """
        Process configuration list and create batched components
        """
        # Group components by type
        component_params = defaultdict(list)
        self.system_indices = defaultdict(list)
        
        # Collect all components and their system indices
        for sys_idx, config in enumerate(config_list):
            mass_components = config.get("mass_components", [])
            for comp_config in mass_components:
                comp_type = comp_config["type"]
                # Skip unknown component types
                if comp_type not in self.component_classes:
                    continue
                    
                # Add params and system index
                component_params[comp_type].append((comp_config["params"], self.precomp_dict_list[sys_idx]))
                self.system_indices[comp_type].append(sys_idx)
        
        # Create batched components
        self.components = {}
        for comp_type, params_list in component_params.items():
            # Build parameter tensor for this component type
            param_tensor = self.build_param_tensor(comp_type, params_list)
            
            # Create the component, passing both device and dtype
            self.components[comp_type] = self.component_classes[comp_type](
                param_tensor, device=self.device, dtype=self.dtype)
            
            # Convert system indices to tensor
            self.system_indices[comp_type] = torch.tensor(
                self.system_indices[comp_type], device=self.device)
    
    def build_param_tensor(self, comp_type, params_list):
        """
        Build parameter tensor for a specific component type with the specified dtype
        """
        param_rows = []
        
        # Process the parameters based on component type
        if comp_type == "SIS":
            # Format: [x, y, redshift, vel_disp, D_ls, D_s]
            for params, precomp in params_list:
                param_rows.append([
                    params["pos"][0], params["pos"][1], 
                    params["redshift"], params["vel_disp"],
                    precomp["D_ls"], precomp["D_s"]
                ])
                
        elif comp_type == "NFW":
            # Format: [x, y, mass_max, r_max_kpc, D_l, D_s, D_ls]
            for params, precomp in params_list:
                param_rows.append([
                    params["pos"][0], params["pos"][1],
                    params["mass_max"], params["r_max_kpc"],
                    precomp["D_l"], precomp["D_s"], precomp["D_ls"]
                ])
                
        elif comp_type == "ExternalPotential":
            # Format: [shear_center_x, shear_center_y, shear_strength, shear_angle_arcsec]
            for params, _ in params_list:
                param_rows.append([
                    params["shear_center"][0], params["shear_center"][1],
                    params["shear_strength"], params["shear_angle_arcsec"]
                ])
                
        elif comp_type == "PEMD":
            # Format: [slope, pos_x, pos_y, orient, q, vel_disp, D_s, D_ls]
            for params, precomp in params_list:
                param_rows.append([
                    params["slope"], params["pos"][0], params["pos"][1], 
                    params["orient"], params["q"], params["vel_disp"],
                    precomp["D_s"], precomp["D_ls"]
                ])
        
        # Convert to tensor with specified data type
        return torch.tensor(param_rows, device=self.device, dtype=self.dtype)
    
    def deflection_field(self, lens_grid):
        """
        Compute the deflection field for all systems
        """
        # Ensure input grid is in the correct data type
        if lens_grid.dtype != self.dtype:
            lens_grid = lens_grid.to(dtype=self.dtype)
            
        # Initialize deflection field for all systems
        batch_size, H, W, _ = lens_grid.shape
        total_deflection = torch.zeros_like(lens_grid, dtype=self.dtype)
        
        # Calculate deflections for each component type
        for comp_type, component in self.components.items():
            # Get system indices for this component type
            sys_indices = self.system_indices[comp_type]
            
            # Extract the relevant grid for each component instance
            comp_grid = lens_grid[sys_indices]
            
            # Calculate deflection for this component type
            comp_deflection = component.deflection_angle(comp_grid, z_source=None)
            
            # Add the deflection to the corresponding systems
            for i, sys_idx in enumerate(sys_indices):
                total_deflection[sys_idx] += comp_deflection[i]
        
        return total_deflection
    
    def forward(self, lens_grid):
        """
        Calculate the source plane positions for all lens systems
        """
        # Ensure input grid is in the correct data type
        if lens_grid.dtype != self.dtype:
            lens_grid = lens_grid.to(dtype=self.dtype)
            
        # Calculate the deflection field
        deflection = self.deflection_field(lens_grid)
        
        # Calculate source plane positions
        source_grid = lens_grid - deflection
        
        return source_grid

In [3]:
import torch
import torch.nn as nn

def _translate_rotate(x, xc, th_rad):
    # Function remains unchanged as it already supports batched operations
    return torch.view_as_real(torch.exp(-1j*th_rad)*torch.view_as_complex(x-xc))

class SourceModel(nn.Module):
    """
    The only initial constraint is the redshift of the source.
    """
    def __init__(self):
        super().__init__()

    def forward(self, source_grid):
        """
        Compute the brightness of the source on the source grid.
        """
        raise NotImplementedError

class GaussianBlob(SourceModel):
    """
    Memory-efficient Gaussian blob source model
    """
    def __init__(self, config_dict, precomp_dict, device):
        super().__init__()
        self.device = device
        
        # 1. Convert all parameters to tensors on device during initialization
        if isinstance(config_dict["I"], torch.Tensor):
            self.I = config_dict["I"].to(device)
        else:
            self.I = torch.tensor(config_dict["I"], device=device, dtype=torch.float32)
            
        if isinstance(config_dict["position_rad"], torch.Tensor):
            self.position = config_dict["position_rad"].to(device)
        else:
            self.position = torch.tensor(config_dict["position_rad"], device=device, dtype=torch.float32)
            
        if isinstance(config_dict["orient_rad"], torch.Tensor):
            self.orient_rad = config_dict["orient_rad"].to(device)
        else:
            self.orient_rad = torch.tensor(config_dict["orient_rad"], device=device, dtype=torch.float32)
            
        if isinstance(config_dict["q"], torch.Tensor):
            self.q = config_dict["q"].to(device)
        else:
            self.q = torch.tensor(config_dict["q"], device=device, dtype=torch.float32)
            
        if isinstance(config_dict["std_kpc"], torch.Tensor):
            self.std_kpc = config_dict["std_kpc"].to(device)
        else:
            self.std_kpc = torch.tensor(config_dict["std_kpc"], device=device, dtype=torch.float32)
            
        # Convert D_s to tensor if needed
        if isinstance(precomp_dict["D_s"], torch.Tensor):
            self.D_s = precomp_dict["D_s"].to(device) * 1000.0  # Mpc to kpc
        else:
            self.D_s = torch.tensor(precomp_dict["D_s"] * 1000.0, device=device, dtype=torch.float32)
            
        # Pre-compute std_rad
        self.std_rad = self.std_kpc / self.D_s
        
        # 2. Constants for calculations
        self.half = torch.tensor(0.5, device=device, dtype=torch.float32)
        
        # 3. Pre-allocate buffers for calculations
        self._buffers = {
            # Parameter expansion buffers
            'position_expanded': None,
            'orient_rad_expanded': None,
            'q_expanded': None,
            'std_rad_expanded': None,
            'I_expanded': None,
            
            # Calculation buffers
            'diff': None,        # For translate operation
            'complex_z': None,   # For rotation operation
            'rotated_z': None,   # For rotation result
            'xrel': None,        # For rotated coordinates
            'rs2': None,         # For squared distances
            'exp_term': None,    # For exponent calculation
            'sb': None           # For surface brightness
        }
        
        # 4. Pre-allocate rotation factors
        self._batch_sizes_seen = set()
        self._rotation_factors = {}  # Cache for exp(-i*theta) for different batch sizes
    
    def _ensure_buffers(self, batch_size, height, width):
        """Initialize or resize buffers based on input dimensions"""
        grid_shape = (batch_size, height, width)
        
        if (self._buffers['position_expanded'] is None or 
            self._buffers['position_expanded'].shape[0] != batch_size):
            
            # Parameter expansion buffers
            self._buffers['position_expanded'] = torch.empty(
                (batch_size, 1, 1, 2), device=self.device, dtype=torch.float32)
            self._buffers['orient_rad_expanded'] = torch.empty(
                (batch_size, 1, 1), device=self.device, dtype=torch.float32)
            self._buffers['q_expanded'] = torch.empty(
                (batch_size, 1, 1), device=self.device, dtype=torch.float32)
            self._buffers['std_rad_expanded'] = torch.empty(
                (batch_size, 1, 1), device=self.device, dtype=torch.float32)
            self._buffers['I_expanded'] = torch.empty(
                (batch_size, 1, 1), device=self.device, dtype=torch.float32)
            
            # Translation and rotation buffers
            self._buffers['diff'] = torch.empty(
                (batch_size, height, width, 2), device=self.device, dtype=torch.float32)
            self._buffers['complex_z'] = torch.empty(
                (batch_size, height, width), device=self.device, dtype=torch.complex64)
            self._buffers['rotated_z'] = torch.empty(
                (batch_size, height, width), device=self.device, dtype=torch.complex64)
            self._buffers['xrel'] = torch.empty(
                (batch_size, height, width, 2), device=self.device, dtype=torch.float32)
            
            # Brightness calculation buffers
            self._buffers['rs2'] = torch.empty(
                grid_shape, device=self.device, dtype=torch.float32)
            self._buffers['exp_term'] = torch.empty(
                grid_shape, device=self.device, dtype=torch.float32)
            self._buffers['sb'] = torch.empty(
                grid_shape, device=self.device, dtype=torch.float32)
            
            # Cache rotation factor for this batch size
            if batch_size not in self._batch_sizes_seen:
                self._batch_sizes_seen.add(batch_size)
                self._rotation_factors[batch_size] = None
    
    def _get_rotation_factor(self, batch_size, orient_rad):
        """Get or compute rotation factor exp(-i*theta)"""
        if self._rotation_factors[batch_size] is None:
            self._rotation_factors[batch_size] = torch.empty(
                (batch_size, 1, 1), device=self.device, dtype=torch.complex64)
        
        # Compute rotation factor (in-place when possible)
        neg_i_theta = -1j * orient_rad
        torch.exp(neg_i_theta, out=self._rotation_factors[batch_size])
        
        return self._rotation_factors[batch_size]
    
    def _optimized_translate_rotate(self, source_grid, position, orient_rad):
        """Memory-efficient implementation of translate and rotate"""
        batch_size, height, width, _ = source_grid.shape
        
        # 1. Translation (in-place)
        torch.sub(source_grid, position, out=self._buffers['diff'])
        
        # 2. View as complex for rotation
        self._buffers['complex_z'] = torch.view_as_complex(
            self._buffers['diff'].contiguous().view(batch_size, height, width, 1, 2)
            .squeeze(-2))
        
        # 3. Get rotation factor and multiply (in-place)
        rot_factor = self._get_rotation_factor(batch_size, orient_rad)
        torch.mul(self._buffers['complex_z'], rot_factor, out=self._buffers['rotated_z'])
        
        # 4. Convert back to real coordinates
        self._buffers['xrel'] = torch.view_as_real(self._buffers['rotated_z'])
        
        return self._buffers['xrel']
    
    def forward(self, source_grid):
        """Memory-efficient implementation of Gaussian blob"""
        # Get dimensions
        batch_size, height, width, _ = source_grid.shape
        
        # 1. Ensure buffers are allocated with correct shape
        self._ensure_buffers(batch_size, height, width)
        
        # 2. Expand parameters efficiently (in-place)
        # Expand position [2] → [B, 1, 1, 2]
        pos_expanded = self._buffers['position_expanded']
        pos_expanded.copy_(self.position.expand(batch_size, 2).view(batch_size, 1, 1, 2))
        
        # Expand scalars [1] → [B, 1, 1]
        orient_expanded = self._buffers['orient_rad_expanded']
        orient_expanded.copy_(self.orient_rad.expand(batch_size).view(batch_size, 1, 1))
        
        q_expanded = self._buffers['q_expanded']
        q_expanded.copy_(self.q.expand(batch_size).view(batch_size, 1, 1))
        
        std_expanded = self._buffers['std_rad_expanded']
        std_expanded.copy_(self.std_rad.expand(batch_size).view(batch_size, 1, 1))
        
        I_expanded = self._buffers['I_expanded']
        I_expanded.copy_(self.I.expand(batch_size).view(batch_size, 1, 1))
        
        # 3. Calculate rotated coordinates efficiently
        xrel = self._optimized_translate_rotate(
            source_grid, pos_expanded, orient_expanded)
        
        # 4. Calculate squared distance (in-place)
        rs2 = self._buffers['rs2']
        rs2.zero_()
        
        # rs2 = q²·x² + y²
        torch.pow(xrel[..., 0], 2, out=rs2)
        torch.mul(rs2, q_expanded**2, out=rs2)
        torch.addcmul(rs2, xrel[..., 1], xrel[..., 1], value=1.0, out=rs2)
        
        # 5. Calculate exponential term (in-place)
        exp_term = self._buffers['exp_term']
        torch.div(rs2, std_expanded**2, out=exp_term)
        torch.mul(exp_term, -self.half, out=exp_term)
        torch.exp(exp_term, out=exp_term)
        
        # 6. Calculate surface brightness (in-place)
        sb = self._buffers['sb']
        torch.mul(exp_term, I_expanded, out=sb)
        
        # Store for reference
        self.surface_brightness = sb
        
        return sb
    
    def __del__(self):
        """Clean up buffers when object is deleted"""
        # Clear all buffers
        for key in self._buffers:
            self._buffers[key] = None
        
        # Clear rotation factors
        self._rotation_factors.clear()
        self._batch_sizes_seen.clear()

In [4]:
import torch
import torch.nn as nn

from shared_utils import recursive_to_tensor


class LensingSystem(nn.Module):
    """
    Lensing system that handles multiple lens configurations in a batch
    """
    def __init__(self, config_list, device, dtype=torch.float32):
        """
        Initialize the lensing system with multiple configurations
        
        Parameters:
            config_list: List of configuration dictionaries, one per lens system
            device: Device to run computations on
        """
        super().__init__()
        self.device = device
        self.num_systems = len(config_list)
        self.dtype = dtype
        self.precision_sensitive_ops_dtype = torch.float32  #
        
        # Process configs and move to device
        self.config_list = [recursive_to_tensor(config, device, datatype=dtype) for config in config_list]
        
        # Extract lens model configs and precomputed values
        lens_model_list = [config["lens_model"] for config in self.config_list]
        precomp_dict_list = [config["precomputed"] for config in self.config_list]
        
        # Create a single lens model for all systems
        self.lens_model = LensModel(lens_model_list, precomp_dict_list, device=device, dtype=dtype)
        
        # Create source models for each system
        self.source_models = nn.ModuleList()
        source_mapping = {
            "Gaussian_blob": GaussianBlob,
        }
        
        # Create a source model for each system
        for config in self.config_list:
            source_type = config["source_model"]["type"]
            if source_type not in source_mapping:
                raise ValueError(f"Unknown source model type: {source_type}")
                
            source_model = source_mapping[source_type](
                config["source_model"]["params"],
                precomp_dict=config["precomputed"],
                device=device
            )
            self.source_models.append(source_model)
            
        # Store source redshifts
        self.source_redshifts = [config["source_model"]["params"]["redshift"] for config in self.config_list]
        
    def forward(self, lens_grid):
        """
        Forward pass for all lens systems
        
        Parameters:
            lens_grid: Tensor of shape [B, H, W, 2] where B is the batch size (number of systems)
            
        Returns:
            Tensor of shape [B, H, W] with image plane intensities for all systems
        """
        if lens_grid.dtype != self.dtype:
            lens_grid = lens_grid.to(dtype=self.dtype)
            print("converting lens grid to dtype", self.dtype)
            
        # Get source plane positions for all systems
        source_grid = self.lens_model(lens_grid)
        
        # Initialize output tensor
        output_shape = source_grid.shape[:-1]  # [B, H, W]
        output = torch.zeros(output_shape, device=self.device)
        
        # Process each system with its source model
        for i, source_model in enumerate(self.source_models):
            if i < lens_grid.shape[0]:  # Only process within batch size
                system_source_grid = source_grid[i:i+1]  # [1, H, W, 2]
                output[i] = source_model(system_source_grid).squeeze(0)
        
        return output

In [5]:
import json
from shared_utils import recursive_to_tensor, _grid_lens
from matplotlib import pyplot as plt
import torch

device="cuda"


#Print detailed memory information


def track_memory(label=""):
    print(f"\n--- MEMORY: {label} ---")
    print(torch.cuda.memory_summary())
    print(f"Allocated: {torch.cuda.memory_allocated()/1e6:.2f} MB")
    print(f"Cached: {torch.cuda.memory_reserved()/1e6:.2f} MB")


one_image_dict={
            "system_index": 0,
            "precomputed": {
                "D_l": 1717.9862002612902,
                "D_s": 1518.4022740256946,
                "D_ls": 761.5082718989717,
                "Theta_E": 6.163572379738989e-06
            },
            "lens_model": {
                "num_substructures": 1,
                "mass_components": [
                    {
                        "type": "SIS",
                        "is_substructure": False,
                        "params": {
                            "pos": [
                                0.0,
                                0.0
                            ],
                            "redshift": 1.0554628084028788,
                            "vel_disp": 296.47503651129796
                        }
                    },
                    {
                        "type": "NFW",
                        "is_substructure": True,
                        "params": {
                            "pos": [
                                0.0,
                                0.0
                            ],
                            "mass_max": 100000000000.0,
                            "r_max_kpc": 1.0,
                            "redshift": 1.0554628084028788
                        }
                    }
                ]
            },
            "source_model": {
                "type": "Gaussian_blob",
                "params": {
                    "I": 1.0,
                    "position_rad": [
                        0.0,
                        0.0
                    ],
                    "orient_rad": 0.0,
                    "q": 0.8,
                    "std_kpc": 0.1,
                    "redshift": 3.6654574221282386
                }
            }
        }



#one_image_dict=recursive_to_tensor(one_image_dict, device=device)

batch_size=30
grid = _grid_lens(6, 1000, device=device, dtype=torch.float32)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1) if len(grid.shape) == 3 else grid.repeat(batch_size, 1, 1, 1)

from catalog_manager import CatalogManager
my_catalog = CatalogManager(catalog_name_input="SIS_10e7_sub_train")
config_list=my_catalog.catalog["SL_systems"][0: batch_size]
config_list=[recursive_to_tensor(config, device) for config in config_list]
# config_list=[]
# for i in range(batch_size):
#     config_list.append(one_image_dict)
# #print(json.dumps(one_image_dict, indent=4))

config_list=[recursive_to_tensor(config, device) for config in config_list]

def run():
    #with torch.no_grad():
    lensing_system=LensingSystem(config_list, device=device, dtype=torch.float32)
    image=lensing_system(grid)
    return image
image=run()

print(torch.cuda.memory_summary())
torch.cuda.reset_peak_memory_stats()

torch.cuda.reset_accumulated_memory_stats()  
torch.cuda.empty_cache() 

plt.figure(figsize=(15, 15))
#plt.imshow(image[1].cpu().detach().numpy())
print(image.shape)



# batch_sizes = [1, 10, 20, 30, 40]
# for batch_size in batch_sizes:
#     torch.cuda.empty_cache()  # Clear cache before each test
#     track_memory(f"Before batch_size={batch_size}")
    
#     # Create system with batch_size
#     grid = _grid_lens(6, 1000, device=device)
#     grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)
#     config_list = [one_image_dict] * batch_size
    
#     lensing_system = LensingSystem(config_list, device=device)
#     track_memory(f"After system creation batch_size={batch_size}")
    
#     image = lensing_system(grid)
#     track_memory(f"After forward pass batch_size={batch_size}")
    
#     del lensing_system, grid, image, config_list


# batch_sizes = [1, 10, 20, 30, 40]
# for batch_size in batch_sizes:
#     # Create system with batch_size
#     grid = _grid_lens(6, 100, device=device)
#     grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)
#     config_list = [one_image_dict] * batch_size
    
#     lensing_system = LensingSystem(config_list, device=device)
    
#     image = lensing_system(grid)
    
#     del lensing_system, grid, image, config_list

TypeError: 'ScriptMethodStub' object is not callable

In [None]:
import torch
import time
import matplotlib.pyplot as plt
from shared_utils import _grid_lens

# Define batch sizes to test
#batch_sizes = [1, 10, 20, 40, 60, 80, 100, 150, 200, 400, 512, 1024]
batch_sizes = torch.arange(1, 2, 2)
# Arrays to store timing results
forward_times = []
device="cuda"

from catalog_manager import CatalogManager


my_catalog = CatalogManager(catalog_name_input="SIS_10e7_sub_train")
config_list=my_catalog.catalog["SL_systems"][0:]


# Loop through batch sizes
for batch_size in batch_sizes:
    #print(f"Testing batch size {batch_size}...")
    
    # Clear cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    # Create system with batch_size
    grid = _grid_lens(6, 1000, device=device)
    grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)
    config_list=my_catalog.catalog["SL_systems"][0: batch_size]
    
    lensing_system = LensingSystem(config_list, device=device)
    
    # Time the forward pass
    start_time = time.time()
    image = lensing_system(grid)
    torch.cuda.synchronize()  # Make sure GPU completes
    elapsed = time.time() - start_time
    
    # Store timing result
    forward_times.append(elapsed)
    #print(f"  Forward pass took {elapsed:.4f} seconds")
    
    # Clean up
    del lensing_system, grid, image, config_list

# Plot results
# plt.figure(figsize=(10, 6))
# plt.plot(batch_sizes, forward_times, 'o-')
# plt.xlabel('Batch Size')
# plt.ylabel('Time (seconds)')
# plt.title('Forward Pass Time vs Batch Size')
# plt.grid(True)
# plt.savefig('batch_timing.png')
# plt.show()


batch_size = 2
grid = _grid_lens(6, 2000, device=device)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)
config_list = [one_image_dict] * batch_size

lensing_system = LensingSystem(config_list, device=device)

# Time the forward pass
%prun image = lensing_system(grid)

          308 function calls (300 primitive calls) in 0.139 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.079    0.079    0.133    0.133 271530330.py:115(deflection_field)
        1    0.033    0.033    0.051    0.051 2291391825.py:283(deflection_angle)
        2    0.015    0.008    0.015    0.008 {built-in method torch.tensor}
       23    0.001    0.000    0.001    0.000 {method 'copy_' of 'torch._C.TensorBase' objects}
        1    0.001    0.001    0.002    0.002 2291391825.py:85(deflection_angle)
       44    0.001    0.000    0.001    0.000 {built-in method torch.empty}
        2    0.001    0.001    0.005    0.002 2768220853.py:165(forward)
        7    0.001    0.000    0.001    0.000 {method 'pow' of 'torch._C.TensorBase' objects}
        1    0.000    0.000    0.139    0.139 2877471766.py:57(forward)
       28    0.000    0.000    0.000    0.000 {method 'view' of 'torch._C.TensorBase' objects}
  

In [None]:
lensing_system = LensingSystem(config_list, device=device)

In [None]:
import json
from shared_utils import recursive_to_tensor, _grid_lens
from matplotlib import pyplot as plt
import torch
from lensing_system import LensingSystem as LensingSystemNonBatched

device="cuda"



#one_image_dict=recursive_to_tensor(one_image_dict, device=device)

batch_size=2000
grid = _grid_lens(6, 100, device=device, dtype=torch.float32)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1) if len(grid.shape) == 3 else grid.repeat(batch_size, 1, 1, 1)

from catalog_manager import CatalogManager
import time
my_catalog = CatalogManager(catalog_name_input="SIS_10e7_sub_train")
config_list=my_catalog.catalog["SL_systems"][0: batch_size]
config_list=[recursive_to_tensor(config, device) for config in config_list]


def run_batched():
    print("running batched")
    #empty gpu cache
    torch.cuda.empty_cache()
    #with torch.no_grad():
    %prun lensing_system=LensingSystem(config_list, device=device, dtype=torch.float32)
    image=lensing_system(grid)
    print("finished batched")
    return image

def run_non_batched():
    print("running non batched")
    #empty gpu cache
    torch.cuda.empty_cache()
    #with torch.no_grad():
    images=[]
    for i in range(batch_size):
        lensing_system=LensingSystemNonBatched(config_list[i], device=device)
        images.append(lensing_system(grid[i]))
    #stack the images
    image=torch.stack(images)
    print("finished non batched")
    return image


# Measure running time for batched execution
num_repeats = 1
batched_times = []
non_batched_times = []

for _ in range(num_repeats):
    # Measure running time for batched execution
    start_time_batched = time.time()
    images = run_batched()
    elapsed_time_batched = time.time() - start_time_batched
    batched_times.append(elapsed_time_batched)

    # Measure running time for non-batched execution
    start_time_non_batched = time.time()
    images_non_batched = run_non_batched()
    elapsed_time_non_batched = time.time() - start_time_non_batched
    non_batched_times.append(elapsed_time_non_batched)

# Calculate average execution times
avg_batched_time = sum(batched_times) / num_repeats
avg_non_batched_time = sum(non_batched_times) / num_repeats

print(f"Average batched execution time: {avg_batched_time:.4f} seconds")
print(f"Average non-batched execution time: {avg_non_batched_time:.4f} seconds")

plt.figure(figsize=(4, 4))
#plt.imshow(images[0].cpu().detach().numpy())
print(image.shape)

running batched
 finished batched
running non batched
finished non batched
Average batched execution time: 8.5191 seconds
Average non-batched execution time: 11.0849 seconds
torch.Size([2, 2000, 2000])


<Figure size 400x400 with 0 Axes>

         593538 function calls (494995 primitive calls) in 3.367 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2015    1.298    0.001    1.298    0.001 {built-in method torch.tensor}
    42282    0.554    0.000    0.555    0.000 {built-in method torch.as_tensor}
     2005    0.452    0.000    0.470    0.000 module.py:472(__init__)
     2000    0.328    0.000    1.585    0.001 2768220853.py:25(__init__)
    24049    0.139    0.000    0.276    0.000 module.py:1897(__setattr__)
60376/2000    0.130    0.000    0.793    0.000 util.py:4(recursive_to_tensor)
237295/213222    0.074    0.000    0.156    0.000 {built-in method builtins.isinstance}
    24049    0.066    0.000    0.089    0.000 parameter.py:10(__instancecheck__)
16094/2000    0.062    0.000    0.788    0.000 util.py:7(<dictcomp>)
    12007    0.053    0.000    0.053    0.000 {method 'to' of 'torch._C.TensorBase' objects}
        2    0.039    0.019    0.878    0

asdf§