<a href="https://colab.research.google.com/github/juergen-ka/juergen-ka/blob/main/Partial_waves.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title QED Prerequisites Scattering 8-Partial Waves
#@markdown This notebook visualizes partial waves as demonstrated in XylyXylyX's
#@markdown video **QED Prerequisites Scattering 8-Partial Waves** https://youtu.be/aHFHAGcAj1c.<br>

#@markdown $e^{ikr\cos\theta} = \sum_{l=0}^{\infty} i^l \,(2l + 1)\, j_l(kr)\, P_l(\cos\theta) \qquad|\; kr=n\lambda_B$

import numpy as np
import plotly.graph_objects as go
from scipy.special import spherical_jn as jl 

#@markdown ---
#@markdown **Parameters:**<br>
#@markdown Show individual partial wave or superposition
# Colab form field (dropdown): the value is passed as the `mode` argument of PartialWave.
mode = "superposition" #@param ["individual", "superposition"]
#@markdown Order L of the spherical Bessel function and Legendre polynomial
# Colab form field (slider, 0..50): the order passed to PartialWave; in
# "superposition" mode the partial waves of order 0 up to L are summed.
L = 10 #@param {type:"slider", min:0, max:50, step:1}


def _partial_wave_term(order: int, r: np.ndarray, cos_theta: np.ndarray) -> np.ndarray:
  """Return the complex expansion term i^order (2*order + 1) j_order(r) P_order(cos_theta)."""
  legendre = np.polynomial.legendre.Legendre([0] * order + [1])   # P_order
  phase = (1, 1j, -1, -1j)[order % 4]                              # exact i^order, no float round-off
  return phase * (2 * order + 1) * jl(order, r) * legendre(cos_theta)


def PartialWave(l: int, x: np.ndarray, y: np.ndarray, mode: str = 'individual') -> tuple[np.ndarray, str]:
  """
  Calculate partial wave(s) of a plane wave propagating in the +x direction.

  The plane wave is expanded as
      exp(i r cos(theta)) = sum_l i^l (2l + 1) j_l(r) P_l(cos(theta)),
  where r = sqrt(x^2 + y^2) plays the role of kr and theta is the angle to
  the +x axis.

  If mode is 'individual', only the term of order l is computed. If mode is
  'superposition', the terms of order 0 to l are summed.

  For individual mode the term i^l (2l+1) j_l P_l is purely real for even l and
  purely imaginary for odd l. The function returns the non-vanishing part (real
  part for even l, imaginary part for odd l), including the sign of i^l.
  For superposition mode it returns the imaginary part, except for l = 0, which
  is a single real-valued term; in that case its real part is returned.

  Input:
  l:      order of Bessel function and Legendre polynomial (non-negative)
  x:      x-values of the meshgrid (any shape broadcastable with y)
  y:      y-values of the meshgrid
  mode:   'individual' or 'superposition'

  Output:
  wave:   real or imaginary part of the values at the grid points
          (broadcast shape of x and y)
  title:  plot title describing what is shown

  Raises:
  ValueError: if mode is not 'individual'/'superposition' or l is negative.
  """
  if mode not in ('individual', 'superposition'):
    raise ValueError(f"mode must be either 'individual' or 'superposition'. Found mode={mode}")
  if l < 0:
    # [0]*l + [1] would silently collapse to P_0 for negative l.
    raise ValueError(f"order l must be non-negative. Found l={l}")

  theta = np.arctan2(y, x)      # angle to the +x axis, so x is the propagation direction
  r = np.sqrt(x**2 + y**2)      # distance of each point (x, y) from the origin
  cos_theta = np.cos(theta)     # hoisted: shared by every Legendre evaluation

  if mode == 'individual':
    wave = _partial_wave_term(l, r, cos_theta)
    if l % 2 == 0:              # i^l = +-1: the term is purely real
      return wave.real, f'Individual partial wave of order {l}. Showing real part.<br>Propagation in x-direction.'
    # i^l = +-i: the term is purely imaginary
    return wave.imag, f'Individual partial wave of order {l}. Showing imaginary part.<br>Propagation in x-direction.'

  # Superposition: sum the terms of order 0..l.
  wave = sum(_partial_wave_term(li, r, cos_theta) for li in range(l + 1))
  if l == 0:
    # A single (real) term; not really a superposition.
    return wave.real, 'Individual partial wave of order 0. Showing real part.<br>Propagation in x-direction.'
  return wave.imag, f'Superposition of partial waves of order 0 to order {l}. Showing imaginary part.<br>Propagation in x-direction.'

# -------------------------------- MAIN ----------------------------------------
# Square grid of 200 x 200 points; each coordinate is the dimensionless kr.
x = np.linspace(-50, 50, 200)
y = np.linspace(-50, 50, 200)
x, y = np.meshgrid(x, y)
z, title = PartialWave(L, x, y, mode=mode)

# Symmetric z-range from the largest |z|: np.max(z) alone can be smaller than
# |min(z)| (clipping negative lobes) or even <= 0 (degenerate range). An
# all-zero surface falls back to a unit extent.
z_extent = float(np.max(np.abs(z)))
if z_extent == 0.0:
    z_extent = 1.0

fig = go.Figure(
    data=[
        go.Surface(
            x=x,
            y=y,
            z=z,
            colorscale='YlOrBr',
            cmin=-0.5, cmax=1.,
            showscale=False,
        )
    ],
    layout=dict(
        title=title,
        height=1000,
        width=1000,
        scene=dict(
            xaxis=dict(title=r'x = kr = nλ<sub>B</sub>', showbackground=False, showticklabels=False),
            yaxis=dict(title=r'y = kr = nλ<sub>B</sub>', showbackground=False, showticklabels=False),
            # Doubling the extent flattens the surface visually.
            zaxis=dict(range=[-2 * z_extent, 2 * z_extent], showbackground=False, showticklabels=False),
        ),
        margin=dict(l=0, r=400),
    ),
)
fig.show()