# Schwarzschild Black Hole (From Video Tutorial)

This notebook is based on the YouTube tutorial: [Black Hole Ray Tracing](https://www.youtube.com/watch?v=8-B6ryuBkCM)

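It provides a streamlined implementation of the Schwarzschild black hole visualization. For every pixel, a light ray is traced backward from the camera along a light-like (null) geodesic. The pixel is colored according to whether the ray falls through the event horizon, strikes the thin accretion disk, or escapes to the background.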

## Quick Start Guide

This notebook is optimized for quick experimentation. For detailed physics explanations, see `BlackHoleAccretion.ipynb`.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numba import njit, prange
import time

# Quick configuration.
# RS = 1.0 fixes the length unit, so every radius/distance below is
# numerically expressed in Schwarzschild radii.

# Output image and camera optics
WIDTH = 800                 # image width, pixels
HEIGHT = 600                # image height, pixels
FOV = 60.0                  # vertical field of view, degrees

# Black hole and accretion disk geometry
RS = 1.0                    # Schwarzschild radius: rays with r < RS are drawn black
R_ISCO = 3.0 * RS           # inner disk edge (innermost stable circular orbit)
R_DISK_MAX = 12.0 * RS      # outer disk edge

# Camera placement
CAM_DIST = 30.0             # distance of the camera from the black hole
CAM_PITCH = 25.0            # camera elevation above the disk plane, degrees

# Ray integration
STEP_SIZE = 0.05            # fixed RK4 step along each ray
MAX_STEPS = 5000            # per-ray step budget before giving up

## Core Physics Functions

In [None]:
@njit
def get_derivatives(state):
    """Return d(state)/d(lambda) for a light ray in its orbital plane.

    ``state`` holds ``[r, dr/dlambda, phi, dphi/dlambda]``, and the result
    uses the same layout. The equations are the Schwarzschild photon
    equations with the black hole at r = 0 and radius ``RS``.
    """
    r = state[0]
    radial_speed = state[1]
    angular_speed = state[3]

    # Angular acceleration; equivalent to keeping r**2 * dphi/dlambda constant.
    angular_accel = -2.0 / r * radial_speed * angular_speed
    # Radial acceleration. The bracket changes sign at the photon sphere
    # r = 1.5 * RS: outside it the ray is pushed outward, inside it inward.
    radial_accel = r * (angular_speed ** 2) * (1.0 - (1.5 * RS) / r)

    return np.array([radial_speed, radial_accel, angular_speed, angular_accel])

@njit
def rk4_step(state, h):
    k1 = get_derivatives(state)
    k2 = get_derivatives(state + 0.5 * h * k1)
    k3 = get_derivatives(state + 0.5 * h * k2)
    k4 = get_derivatives(state + h * k3)
    return state + (h / 6.0) * (k1 + 2*k2 + 2*k3 + k4)

@njit
def texture_lookup(r, phi, x_hit, y_hit, z_hit):
    ring_noise = 0.6 * np.sin(r * 2.0) + 0.4 * np.sin(r * 5.0)
    brightness = 0.6 + 0.4 * ring_noise
    doppler = 1.0 + 0.5 * (x_hit / r)
    dist_norm = max(0.0, min(1.0, (r - R_ISCO) / (R_DISK_MAX - R_ISCO)))
    return np.array([
        1.0 * brightness * doppler,
        (0.8 - 0.6 * dist_norm) * brightness * doppler,
        (0.3 - 0.3 * dist_norm) * brightness * doppler
    ])

@njit
def ray_march(ray_origin, ray_dir):
    pos, vel = ray_origin, ray_dir
    L = np.cross(pos, vel)
    L_norm = np.linalg.norm(L)
    if L_norm < 1e-6:
        return np.array([0.0, 0.0, 0.0])
    
    r_hat = pos / np.linalg.norm(pos)
    n_hat = L / L_norm
    phi_hat = np.cross(n_hat, r_hat)
    
    r_init = np.linalg.norm(pos)
    state = np.array([r_init, np.dot(vel, r_hat), 0.0, np.dot(vel, phi_hat) / r_init])
    
    for i in range(MAX_STEPS):
        old_r, old_phi = state[0], state[2]
        state = rk4_step(state, STEP_SIZE)
        new_r, new_phi = state[0], state[2]
        
        if new_r < RS:
            return np.array([0.0, 0.0, 0.0])
        if new_r > 50.0:
            return np.array([0.02, 0.02, 0.05])
        
        pos_3d_new = new_r * (np.cos(new_phi) * r_hat + np.sin(new_phi) * phi_hat)
        pos_3d_old = old_r * (np.cos(old_phi) * r_hat + np.sin(old_phi) * phi_hat)
        y0, y1 = pos_3d_old[1], pos_3d_new[1]
        
        if (y0 > 0 and y1 < 0) or (y0 < 0 and y1 > 0):
            fraction = abs(y0) / (abs(y0) + abs(y1))
            hit_r = old_r + (new_r - old_r) * fraction
            if R_ISCO < hit_r < R_DISK_MAX:
                hit_pos = pos_3d_old + (pos_3d_new - pos_3d_old) * fraction
                return texture_lookup(hit_r, 0, hit_pos[0], hit_pos[1], hit_pos[2])
    
    return np.array([0.0, 0.0, 0.0])

@njit(parallel=True)
def render_image(width, height, cam_pos, cam_target, fov):
    image = np.zeros((height, width, 3))
    aspect_ratio = width / height
    fov_rad = np.radians(fov)
    half_height = np.tan(fov_rad / 2.0)
    half_width = aspect_ratio * half_height
    
    w = (cam_target - cam_pos) / np.linalg.norm(cam_target - cam_pos)
    up = np.array([0.0, 1.0, 0.0])
    u = np.cross(w, up) / np.linalg.norm(np.cross(w, up))
    v = np.cross(u, w)
    
    for y in prange(height):
        for x in prange(width):
            ndc_x = (x + 0.5) / width * 2.0 - 1.0
            ndc_y = 1.0 - (y + 0.5) / height * 2.0
            pixel_screen_pos = cam_pos + w + u * (ndc_x * half_width) + v * (ndc_y * half_height)
            ray_dir = (pixel_screen_pos - cam_pos) / np.linalg.norm(pixel_screen_pos - cam_pos)
            image[y, x] = ray_march(cam_pos, ray_dir)
    
    return image

# @njit without explicit signatures compiles lazily, on the first call with
# concrete argument types, so at this point the functions are only defined.
print("✓ All functions defined (Numba compiles them on first call)")

## Render and Display

In [None]:
# Setup camera: CAM_DIST from the black hole at the origin, CAM_PITCH degrees
# above the disk plane (y = 0), looking at the origin.
pitch_rad = np.radians(CAM_PITCH)
cam_pos = np.array([0.0, CAM_DIST * np.sin(pitch_rad), -CAM_DIST * np.cos(pitch_rad)])
cam_target = np.array([0.0, 0.0, 0.0])

# Warm up JIT: the first call triggers Numba compilation. A tiny 10x10 render
# with the same argument types keeps that cheap, so the timing below measures
# rendering only.
print("Compiling...")
_ = render_image(10, 10, cam_pos, cam_target, FOV)

# Render. perf_counter() is a monotonic high-resolution clock intended for
# elapsed-time measurement; time.time() can jump with wall-clock changes.
print(f"Rendering {WIDTH}×{HEIGHT}...")
t0 = time.perf_counter()
img = render_image(WIDTH, HEIGHT, cam_pos, cam_target, FOV)
print(f"Done in {time.perf_counter()-t0:.2f}s")

# Display. Disk colours can exceed 1, so clip to imshow's valid float range.
plt.figure(figsize=(10, 8))
plt.imshow(np.clip(img, 0, 1))
plt.axis('off')
plt.title("Schwarzschild Black Hole\n(Based on YouTube Tutorial)", fontsize=12)
plt.tight_layout()
# Save before show(): some backends clear the figure once it is closed.
plt.savefig('schwarzschild_bh.jpg', dpi=150, bbox_inches='tight')
plt.show()

## References

- **YouTube Tutorial**: https://www.youtube.com/watch?v=8-B6ryuBkCM
- For detailed physics explanations, see `BlackHoleAccretion.ipynb`
- For rotating black holes, see `KerrBlackHole.ipynb`