In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from matplotlib.animation import FuncAnimation


In [2]:
# Physical constants
hbar = 1.054571817e-34  # Reduced Planck's constant (J·s)
m_e = 9.10938356e-31    # Electron mass (kg)
mu_B = 9.274009994e-24  # Bohr magneton (J/T)

# Simulation parameters
nx, ny = 200, 200       # Number of spatial grid points
dx, dy = 1e-9, 1e-9     # Spatial resolution (m)
dt = 4e-16              # Time step (s)
nt = 2800                # Number of time steps

# Magnetic field parameters
B0 = 1.5                # Base magnetic field (T)
G = 8e9                 # Magnetic field gradient (T/m)

# Initial electron parameters
x0, y0 = 50e-9, 100e-9  # Initial position (m)
sigma_x = sigma_y = 5e-9  # Initial wavepacket spread (m)
p0 = np.array([1e5 * m_e, 0.0])  # Initial momentum (kg·m/s)


In [3]:
# Create spatial grids
x = np.linspace(0, (nx - 1) * dx, nx)
y = np.linspace(0, (ny - 1) * dy, ny)
X, Y = np.meshgrid(x, y, indexing='ij')

# Convert grids to torch tensors
X_torch = torch.from_numpy(X).float()
Y_torch = torch.from_numpy(Y).float()


In [4]:
# Initialize the spinor wavefunction ψ = [ψ_up, ψ_down]^T
# Start with a Gaussian wavepacket centered at (x0, y0)
def gaussian_wavepacket(x, y, x0, y0, sigma_x, sigma_y, p0):
    exponent = -((x - x0)**2 / (2 * sigma_x**2) + (y - y0)**2 / (2 * sigma_y**2))
    phase = (p0[0] * (x - x0) + p0[1] * (y - y0)) / hbar
    envelope = torch.exp(torch.tensor(exponent, dtype=torch.float32))
    phase_factor = torch.exp(1j * torch.tensor(phase, dtype=torch.float32))
    return envelope * phase_factor

# Initial spinor components (equal superposition)
psi_up = gaussian_wavepacket(X_torch, Y_torch, x0, y0, sigma_x, sigma_y, p0) / np.sqrt(2)
psi_down = gaussian_wavepacket(X_torch, Y_torch, x0, y0, sigma_x, sigma_y, p0) / np.sqrt(2)

# Combine into a single tensor
psi = torch.stack([psi_up, psi_down], dim=-1)  # Shape: (nx, ny, 2)


  envelope = torch.exp(torch.tensor(exponent, dtype=torch.float32))
  phase_factor = torch.exp(1j * torch.tensor(phase, dtype=torch.float32))


In [5]:
# Define the kinetic operator using finite differences
kx = torch.fft.fftfreq(nx, d=dx) * 2 * np.pi
ky = torch.fft.fftfreq(ny, d=dy) * 2 * np.pi
KX, KY = torch.meshgrid(kx, ky, indexing='ij')
K_squared = KX**2 + KY**2

# Kinetic energy operator in momentum space
T_operator = torch.exp(-1j * hbar * K_squared * dt / (2 * m_e))

# Magnetic field along z with gradient in y
Bz = B0 + G * (Y_torch - y0)

# Pauli matrices
sigma_z = torch.tensor([[1, 0], [0, -1]], dtype=torch.cfloat)

# Potential energy operator due to magnetic field
V = -mu_B * Bz.unsqueeze(-1).unsqueeze(-1) * sigma_z  # Shape: (nx, ny, 2, 2)

# Exponential of the potential energy operator
V_operator = torch.linalg.matrix_exp(-1j * V * dt / hbar)


In [6]:
# Prepare figure and axis for animation
fig, ax = plt.subplots(figsize=(8, 6))
extent = (0, nx*dx*1e9, 0, ny*dy*1e9)  # Convert to nm or mm as needed

# Initialize the image with the first frame
prob_density = (psi.conj() * psi).real.sum(dim=-1)
prob_density /= prob_density.max() + 1e-16  # Normalize
im = ax.imshow(prob_density.numpy(),
               extent=extent,
               origin='lower',
               aspect='auto',
               cmap='viridis',
               vmin=0, vmax=1)  # Set color limits
ax.set_xlabel('z position (nm)')
ax.set_ylabel('y position (nm)')
title = ax.set_title('')

# Function to update each frame
def animate(t):
    global psi

    for _ in range(50):  # Update 50 steps per frame
        # Apply potential operator (spin-dependent)
        psi = torch.einsum('...ab,...b->...a', V_operator, psi)

        # Fourier transform to momentum space
        psi_ft = torch.fft.fft2(psi, dim=(0, 1))

        # Apply kinetic operator
        psi_ft[..., 0] *= T_operator
        psi_ft[..., 1] *= T_operator

        # Inverse Fourier transform back to position space
        psi = torch.fft.ifft2(psi_ft, dim=(0, 1))

    # Calculate probability density
    prob_density = (psi.conj() * psi).real.sum(dim=-1)

    # Normalize the probability density
    prob_density /= prob_density.max() + 1e-16  # Add small number to prevent division by zero

    # Update image data
    im.set_data(prob_density.numpy())
    # Color limits are already set to [0, 1], no need to update clim
    title.set_text(f'Time: {t * 50 * dt * 1e15:.1f} fs')
    return [im, title]

# Create the animation
frames = nt // 50  # Number of frames
anim = FuncAnimation(fig, animate, frames=frames, interval=50, blit=True)

# Save the animation as a GIF
anim.save('stern_gerlach_simulation.gif', writer='pillow', fps=10)

plt.close()