# Simulating a virtual Soft X-ray Tomography system

In [1]:
import numpy as np
from numpy.linalg import norm, inv
from numpy.random import uniform
from numpy import sin, cos, tan, arctan2, sqrt, hypot, pi as π
import matplotlib.pyplot as plt
import seaborn
seaborn.set_style("darkgrid")  # adds seaborn style to charts, eg. grid
plt.style.use("dark_background")  # inverts colors to dark theme
plt.rcParams['font.family'] = 'monospace' # sets font to monospace
np.set_printoptions(precision=3) # set precision for printing numpy arrays
from time import time, sleep
from tqdm import tqdm

In [2]:
# parameters of the square (R, Z) simulation grid
KHRES = 2 # [#] ratio between the high resolution grid and the subsampled low resolution grid
RES = 32*KHRES # [#] points per side of the high resolution grid (low resolution side is RES//KHRES)
L = 1.0 # [m] length of the grid in the r/x direction (square grid)
R0 = 1.5 # [m] grid start in the r/x direction
Z0 = -0.5 # [m] grid start in the z/y direction
R1, Z1 = R0+L, Z0+L # [m] grid ends in the r/x and z/y direction
RM, ZM = 0.5*(R0+R1), 0.5*(Z0+Z1) # [m] grid center in the x/r and z direction
# calculated constants
R = np.linspace(R0, R1, RES) # [m] RES node coordinates along r/x, both ends included
Z = np.linspace(Z0, Z1, RES) # [m] RES node coordinates along z/y, both ends included
assert np.isclose(R1-R0, Z1-Z0), "grid must be square"
# NOTE(review): np.linspace with RES points spaces the nodes by L/(RES-1), not L/RES — confirm L/RES is intended
δ = L/RES # [m] nominal grid spacing (used for padding and as the integration step)
RR, ZZ = np.meshgrid(R, Z) # (RES, RES) arrays with the R and Z coordinate of every node
rr, zz = RR[::KHRES, ::KHRES], ZZ[::KHRES, ::KHRES] # low resolution grid: every KHRES-th node
RZ = np.stack((RR, ZZ), axis=-1) # (RES, RES, 2) array holding the (R, Z) pair of every node
FW = np.array([[RM+L/2*cos(θ), ZM+L/2*sin(θ)] for θ in np.linspace(0, 2*π, 100)]) # [m] first wall: 100 points on the circle of radius L/2 around (RM, ZM)

In [3]:
# helper functions
def wrap_angle(α):
    """Wrap angle(s) α [rad] into the interval [-π, π]; scalars and arrays are both accepted."""
    sin_α = np.sin(α)
    cos_α = np.cos(α)
    # arctan2 of the (sin, cos) pair returns the equivalent angle in [-π, π]
    return np.arctan2(sin_α, cos_α)

def gaussian(v, μ=np.array([R0+L/2, Z0+L/2]), Σ=np.array([[L/4,0],[0,L/4]]), polar=False):
    """Evaluate the unnormalised 2D gaussian exp(-0.5 (v-μ)ᵀ Σ⁻¹ (v-μ)) at every point of v.

    v:     array whose last axis holds the 2 coordinates of each point
    μ:     mean, 2 values
    Σ:     covariance matrix, 2x2 (it is inverted on every call)
    polar: if True, the second coordinate is an angle and its offset from μ is wrapped with wrap_angle
    returns an array shaped like v without its last axis (value 1 where v == μ)
    """
    out_shape = v.shape[:-1] # shape to restore at the end
    offsets = v.reshape(-1, 2) - μ # one (Δ0, Δ1) row per point; a new array, v is not modified
    if polar:
        offsets[:, 1] = wrap_angle(offsets[:, 1]) # angular distance must not jump at ±π
    mahalanobis_sq = np.sum((offsets @ inv(Σ)) * offsets, axis=-1) # row-wise quadratic form
    return np.exp(-0.5*mahalanobis_sq).reshape(out_shape)

# #fake gaussian
# def gaussian(v, μ=np.array([R0+L/2, Z0+L/2]), Σ=np.array([[L/4,0],[0,L/4]]), polar=False):
#     rshape = v.shape[:-1] # save the original shape
#     v = v.reshape(-1, 2) # flatten the input to 2D
#     d = v-μ # difference vector
#     if polar: d[:,1] = wrap_angle(d[:,1]) # wrap angles
#     g = np.sqrt(0.5*np.sum(d @ inv(Σ) * d, axis=-1)) # fake gaussian formula
#     r = g.reshape(rshape) # return the result in the original shape
#     return r

def create_line(c, θ, n=10*RES):
    """Sample the straight line through point c with angle θ, keeping only the samples inside the first wall.

    c: (x, y) point the line passes through
    θ: [rad] angle of the line
    n: number of samples generated before clipping
    The samples are evenly spaced along x when |cos θ| > |sin θ| and along y otherwise, over the
    grid extent padded by δ/2 — so they are not evenly spaced along the line itself.
    returns an (m, 2) array of (x, y) samples with m <= n; a warning is printed when m == 0
    """
    cx, cy = c[0], c[1]
    cos_θ, sin_θ = cos(θ), sin(θ)
    if np.abs(cos_θ) > np.abs(sin_θ): # shallow line (less than 45 degrees): x is the free variable
        slope = sin_θ/cos_θ
        x = np.linspace(R0-δ/2, R1+δ/2, n)
        y = slope*x + (cy - slope*cx)
    else: # steep line (45 degrees or more): y is the free variable
        inv_slope = cos_θ/sin_θ
        y = np.linspace(Z0-δ/2, Z1+δ/2, n)
        x = inv_slope*y + (cx - inv_slope*cy)
    # keep only the samples inside the first wall, a circle of radius L/2 (+δ/2 padding) around (RM, ZM)
    # (alternative left out: clip to the padded square grid instead)
    inside = (x-RM)**2 + (y-ZM)**2 <= (L/2+δ/2)**2
    if not inside.any():
        print(f'Warning: line outside: c={c}, θ={θ:.2f}')
    return np.column_stack((x[inside], y[inside]))

def line_mask(c, θ, n=16*RES):
    """Rasterise the line through c with angle θ onto the RES x RES grid as sparse weights.

    c: (x, y) point the line passes through
    θ: [rad] angle of the line
    n: number of samples requested from create_line
    Every sample returned by create_line spreads a total weight of RES/n over its 4 nearest
    grid nodes, proportionally to the inverse of their distance.
    returns (mask, mask_idxs): the non-zero weights and the tuple returned by np.where holding
    their flat indices into the flattened (RES*RES) grid

    NOTE(review): create_line spaces its samples evenly along x or y, not along the line, so the
    weights carry no 1/|cos θ| (or 1/|sin θ|) arc-length factor — confirm this is intended.
    """
    lin = create_line(c, θ, n)
    mask = np.zeros(RES*RES)
    grid = RZ.reshape(-1, 2) # read-only use, so no copy is needed
    # for l in lin: mask[np.argmin(norm(grid-l, axis=-1))] += RES/n # easy way
    for l in lin: # more accurate way (still extremely inefficient)
        d = norm(grid-l, axis=-1) # distance from the sample to every grid node
        idxs = np.argsort(d)[:4] # the 4 closest nodes, nearest first
        d = d[idxs]
        if d[0] == 0:
            # the sample sits exactly on a node: 1/d would be inf and the normalised weights NaN,
            # which `mask > 0` below would silently drop; give the whole weight to that node instead
            w = (d == 0).astype(float)
        else:
            w = 1/d # inverse-distance weights
        w /= w.sum() # normalise so each sample contributes exactly RES/n
        mask[idxs] += w*RES/n
    mask_idxs = np.where(mask > 0)
    mask = mask[mask_idxs]
    return mask, mask_idxs

def polar2cart(v):
    """Convert (r, θ) pairs, measured from the grid center (RM, ZM), to cartesian (x, y) pairs.

    v: array of (r, θ) pairs; the result has the same shape as v
    """
    pairs = v.reshape(-1, 2) # one (r, θ) row per point
    radius = pairs[:, 0]
    angle = pairs[:, 1]
    cart = np.column_stack((RM + radius*cos(angle), ZM + radius*sin(angle)))
    return cart.reshape(v.shape)

def cart2polar(v):
    """Convert cartesian (x, y) pairs to polar (r, θ) pairs measured from the grid center (RM, ZM).

    v: array of (x, y) pairs; the result has the same shape as v, with θ as returned by arctan2
    """
    pairs = v.reshape(-1, 2) # one (x, y) row per point
    dx = pairs[:, 0] - RM
    dy = pairs[:, 1] - ZM
    polar = np.column_stack((hypot(dx, dy), arctan2(dy, dx)))
    return polar.reshape(v.shape)

In [None]:
# test polar2cart and cart2polar: the conversion must keep the shape and round-trip in both directions
RZ_POL = cart2polar(RZ) # (r, θ) of every grid node, reused by later cells
assert RZ.shape == RZ_POL.shape, f'Expected {RZ.shape} but got {RZ_POL.shape}'
assert np.allclose(RZ_POL, cart2polar(polar2cart(RZ_POL))), 'cart2polar(polar2cart(x)) != x'
assert np.allclose(RZ, polar2cart(cart2polar(RZ))), 'polar2cart(cart2polar(x)) != x'
# plot RZ_POL: every grid node in polar coordinates (radius on the x axis, angle on the y axis)
plt.figure(figsize=(6,6))
plt.scatter(*RZ_POL.T, s=5, c='r')
plt.title('RZ_POL')
plt.xlabel('r') # the plotted data are polar coordinates, not (R, Z)
plt.ylabel('θ')
plt.show()

In [None]:
# test line: rasterise one ray through the grid center and display its mask
c, θ = (2, 0), π/2 # vertical ray; immediately overridden by the random angle on the next line
c, θ = (2, 0), uniform(0, π)
# c, θ = (2, 0), π/6

l1 = create_line(c, θ) # (m, 2) samples of the ray inside the first wall
print(l1.shape)

mask, midxs = line_mask(c, θ) # sparse weights and their flat grid indices
print(np.sum(mask))

# scatter the sparse weights back onto a dense, flattened grid
square_mask = np.zeros((RES*RES))
square_mask[midxs] = mask

# plot: mask weights in gray, the ray as a dotted yellow line, the first wall in white
plt.figure()
# plt.scatter(l1[:,0], l1[:,1], c='r', s=1)
plt.scatter(RR, ZZ, c=square_mask.reshape((RES,RES)), s=10, cmap='gray')
plt.plot(l1[:,0], l1[:,1], ':y')
plt.plot(FW[:,0], FW[:,1], 'w')
plt.xlim(R0-δ/2, R1+δ/2)
plt.ylim(Z0-δ/2, Z1+δ/2)
plt.gca().set_aspect('equal', adjustable='box')
plt.grid(False)
plt.colorbar()
plt.show()

In [None]:
# test polar gaussian and rays: integrate a gaussian defined in polar coordinates along one random ray
# #standard gaussian
# μ = np.array([R0+2*L/3, Z0+3*L/5])
# Σ = np.array([[L/40, 0], [0, L/16]])
# gauss = gaussian(RZ, μ, Σ)

# polar gaussian (μ = (r, θ) mean, Σ = covariance in (r, θ)); alternatives kept for reference
# μ, Σ = np.array([0.25, π/3]), np.array([[1/100, 0.0], [0.0, π/3]])
# μ, Σ = np.array([0.0, π/3]), np.array([[1/500, 0.0], [0.0, 100*π]])
μ, Σ = np.array([0.2, π/3]), np.array([[1/500, 0.0], [0.0, π/5]])
gauss_pol = gaussian(RZ_POL, μ, Σ, polar=True)
gauss = gauss_pol
# ray: random point inside the grid and random angle; alternatives kept for reference
# c, θ = (2, 0), π/6
# c, θ = (2, 0), uniform(0, π)
c, θ = (uniform(1.5,2.5), uniform(-.5, .5)), uniform(0, π)
mask1, midxs1 = line_mask(c, θ)
square_mask1 = np.zeros((RES*RES)) # dense, flattened version of the sparse mask
square_mask1[midxs1] = mask1

# emissivity seen by the ray: distribution value times mask weight times integration step
combined = gauss.reshape(-1)[midxs1] * mask1 * δ

combined_square = np.zeros((RES*RES))
combined_square[midxs1] = combined

# sxr integration values
sxr = np.sum(combined)
print(f'SXR: {sxr:.4f}')

# plot the 4 maps
# plot gaussian polar: index the coordinate axis explicitly; `*RZ_POL.T` would also transpose the
# two grid axes and pair every colour of gauss_pol with the coordinates of the mirrored cell
plt.figure(figsize=(20,5))
plt.subplot(141)
plt.scatter(RZ_POL[..., 0], RZ_POL[..., 1], s=5, c=gauss_pol)
plt.title('Gaussian Polar')
plt.xlabel('r')
plt.ylabel('θ')
plt.colorbar()
# the same distribution on the cartesian grid
plt.subplot(142)
plt.scatter(RR, ZZ, c=gauss, s=20)
plt.clim(0, 1)
plt.title('Gaussian')
plt.axis('equal')
plt.colorbar()
plt.grid(False)
plt.xlim(R0, R1)
plt.ylim(Z0, Z1)
# ray mask
plt.subplot(143)
plt.scatter(RR, ZZ, c=square_mask1, s=10)
plt.clim(0, 1)
plt.title('Line')
plt.axis('equal')
plt.colorbar()
plt.grid(False)
plt.xlim(R0, R1)
plt.ylim(Z0, Z1)
# distribution weighted by the mask (divided by δ to undo the integration step for display)
plt.subplot(144)
plt.scatter(RR, ZZ, c=combined_square/δ, s=10)
plt.clim(0, 1)
plt.title('Combined')
plt.axis('equal')
plt.colorbar()
plt.grid(False)
plt.xlim(R0, R1)
plt.ylim(Z0, Z1)
plt.show()

In [None]:
# create a fan of rays: two fans looking at a random sum of two gaussians
N_RAYS1, MAXΘ1 = 21, π/5 # fan 1: number of rays and half aperture [rad]
N_RAYS2, MAXΘ2 = 23, π/5 # fan 2: number of rays and half aperture [rad]
C1, C2 = (2.8, 0), (2.0, -0.7) # (x, y) point every ray of fan 1 / fan 2 passes through
Θ1s = np.linspace(-MAXΘ1, MAXΘ1, N_RAYS1) # horizontal rays
Θ2s = np.linspace(-MAXΘ2, MAXΘ2, N_RAYS2) + π/2 # vertical rays

# rays*: sampled line points (for plotting); fan*: (mask, mask indexes) per ray (for integration)
rays1 = [create_line(C1, θ) for θ in Θ1s]
fan1 = [line_mask(C1, θ) for θ in Θ1s]
rays2 = [create_line(C2, θ) for θ in Θ2s]
fan2 = [line_mask(C2, θ) for θ in Θ2s]

#redefine the gaussian: sum of two axis-aligned gaussians with random mean inside the grid and random variances
μ1 = np.array([uniform(R0, R1), uniform(Z0, Z1)])
Σ1 = np.array([[uniform(0, L/16), 0], [0, uniform(0, L/16)]])
gauss1 = gaussian(RZ, μ1, Σ1)
μ2 = np.array([uniform(R0, R1), uniform(Z0, Z1)])
Σ2 = np.array([[uniform(0, L/16), 0], [0, uniform(0, L/16)]])
gauss2 = gaussian(RZ, μ2, Σ2)
gauss = gauss1 + gauss2

# calculate the combined sxr integration values: one dense flattened map per ray
combined1 = np.zeros((N_RAYS1, RES*RES))
combined2 = np.zeros((N_RAYS2, RES*RES))
for i, (mask, midxs) in enumerate(fan1):
    combined1[i,midxs] = gauss.reshape(-1)[midxs] * mask * δ
for i, (mask, midxs) in enumerate(fan2):
    combined2[i,midxs] = gauss.reshape(-1)[midxs] * mask * δ

# square combined sxr integration values
# (these copy the non-zero entries of combined1/combined2, so they hold the same values)
combined_square1 = np.zeros((N_RAYS1, RES*RES))
combined_square2 = np.zeros((N_RAYS2, RES*RES))
for i, (mask, midxs) in enumerate(fan1):
    combined_square1[i,midxs] = combined1[i,midxs]
for i, (mask, midxs) in enumerate(fan2):
    combined_square2[i,midxs] = combined2[i,midxs]

# one line-integrated value per ray
sxrs1 = np.sum(combined1, axis=-1)
sxrs2 = np.sum(combined2, axis=-1)

# plot the rays over the distribution (fan 1 in red, fan 2 in blue)
plt.figure(figsize=(10,10))
plt.subplot(221)
plt.scatter(RR, ZZ, c=gauss, s=20)
plt.clim(0, 1)
for r in rays1: plt.plot(r[:,0], r[:,1], 'r:')
for r in rays2: plt.plot(r[:,0], r[:,1], 'b:')
plt.axis('equal')
plt.xlim(R0, R1), plt.ylim(Z0, Z1)
plt.colorbar()
plt.title('Rays')
plt.grid(False)
# plot the combined sxr integration values of fan 1 (summed over its rays, divided by δ for display)
plt.subplot(222)
plt.scatter(RR, ZZ, c=combined_square1.sum(axis=0)/δ, s=5)
plt.axis('equal')
plt.clim(0, 1)
plt.xlim(R0, R1), plt.ylim(Z0, Z1)
plt.grid(False)
plt.colorbar()
plt.title('Rays integration')
# same map for fan 2
plt.subplot(224)
plt.scatter(RR, ZZ, c=combined_square2.sum(axis=0)/δ, s=5)
plt.axis('equal')
plt.clim(0, 1)
plt.xlim(R0, R1), plt.ylim(Z0, Z1)
plt.grid(False)
plt.colorbar()
plt.title('Rays integration')
# bar plot the combined sxr integration values
# fan 1 profile is drawn along the vertical axis, fan 2 profile along the horizontal axis
# NOTE(review): ray positions are spread over ±π/4 here although the fans span ±π/5 (MAXΘ1/MAXΘ2) — presumably layout only, confirm
plt.subplot(223)
plt.plot(π/4-sxrs1, -np.linspace(-π/4,+π/4,N_RAYS1), 'rs-')
plt.axvline(π/4, color='r', linestyle=':')
plt.plot(-np.linspace(-π/4,+π/4,N_RAYS2), sxrs2-π/4, 'bs-')
plt.axhline(-π/4, color='b', linestyle=':')
plt.xticks([-π/4, -π/8, 0, π/8, π/4], ['-π/4', '-π/8', '0', 'π/8', 'π/4'])
plt.yticks([-π/4, -π/8, 0, π/8, π/4], ['-π/4', '-π/8', '0', 'π/8', 'π/4'])
plt.axis('equal')
plt.grid(True)
plt.title('SXR')

plt.show()

# Create a Dataset

In [8]:
# dataset parameters: geometry of the two fans of rays used to generate the dataset
NVRAYS = 21 # number of vertical rays
NHRAYS = 23 # number of horizontal rays
CV, CH = (2.0, -0.7), (2.7, 0) # centers of the fans: (x, y) point every vertical / horizontal ray passes through
MAXVANGLE = π/5 # [rad] maximum vertical angle (half aperture of the vertical fan)
MAXHANGLE = π/5 # [rad] maximum horizontal angle (half aperture of the horizontal fan)

In [None]:
# create fans of rays (can take time)
#angles [rad]: evenly spaced within ±MAX*ANGLE, the vertical fan is rotated by π/2
ΘVs = np.linspace(-MAXVANGLE, MAXVANGLE, NVRAYS) + π/2 # vertical rays
ΘHs = np.linspace(-MAXHANGLE, MAXHANGLE, NHRAYS) # horizontal rays
#fans: one (mask, mask indexes) pair per ray, as returned by line_mask
VFAN = [line_mask(CV, θ) for θ in tqdm(ΘVs)] # vertical fan
HFAN = [line_mask(CH, θ) for θ in tqdm(ΘHs)] # horizontal fan
vrays = [create_line(CV, θ) for θ in ΘVs] # vertical rays (for plotting)
hrays = [create_line(CH, θ) for θ in ΘHs] # horizontal rays (for plotting)

In [10]:
# constants for random gaussians
# NOTE(review): create_random_means_stds defaults to max_mix=3, so MAX_MIX only applies where it is passed explicitly
MAX_MIX = 1 # max number of gaussians
K1 = 0.9 # [-] max radius of a gaussian mean, as a fraction of the grid half side L/2
K2A, K2B = 1/10, 1/4 # [-] std deviation range, as multiples of L
K3A, K3B = 0.9999, 1.0  # min and max constant for mixing gaussians
K4A, K4B = 0.4, 0.95 # min and max constant for grid size (only referenced by the commented-out create_random_subgrid)

def create_random_means_stds(max_mix=3, k1=K1, k2=(K2A, K2B), k3=(K3A, K3B)):
    """Draw the parameters of a random mixture of axis-aligned 2D gaussians.

    max_mix: maximum number of gaussians; a single gaussian is used unless max_mix > 1
    k1:      the mean radius is drawn uniformly in [0, L*k1/2) around the grid center
    k2:      (min, max) std deviation, as multiples of L, drawn independently along x and y
    k3:      (min, max) mixing coefficient
    returns (mix, μ, Σ): n mixing coefficients, n cartesian means (n, 2), n diagonal covariances (n, 2, 2)
    """
    # number of gaussians
    n = 1
    if max_mix > 1:
        n = np.random.randint(1, max_mix+1)
    mix = uniform(k3[0], k3[1], n) # mixing coefficients
    # means: drawn in polar coordinates around the grid center, then converted to cartesian
    radii = uniform(0, L*k1/2, n) # [m]
    angles = uniform(0, 2*π, n) # [rad]
    μ = polar2cart(np.stack((radii, angles), axis=-1))
    # std deviations [m] along x and y (x is drawn first)
    σ_min, σ_max = L*k2[0], L*k2[1]
    σx = uniform(σ_min, σ_max, n)
    σy = uniform(σ_min, σ_max, n)
    # diagonal covariance matrices hold the variances
    Σ = np.zeros((n, 2, 2))
    Σ[:,0,0] = σx**2
    Σ[:,1,1] = σy**2
    return mix, μ, Σ

def eval_gaussians(ps, mix, μ, Σ):
    """Evaluate the weighted sum of gaussians Σᵢ mix[i]·gaussian(ps, μ[i], Σ[i]) at the points ps.

    ps: array whose last axis holds the point coordinates; mix, μ, Σ: per-gaussian weight, mean, covariance
    returns an array shaped like ps without its last axis
    """
    assert len(mix) == len(μ) == len(Σ), 'mix, μ, Σ must be lists of the same length'
    total = np.zeros(ps.shape[:-1]) # accumulator for the mixture
    for k, mean in enumerate(μ):
        total += mix[k]*gaussian(ps, mean, Σ[k])
    return total

def create_random_gaussian_mix():
    """Return a random gaussian mixture evaluated on the full grid RZ.

    The parameters come from create_random_means_stds called with its default arguments.
    """
    return eval_gaussians(RZ, *create_random_means_stds())

def eval_on_fan(d, f): # d distribution, f fan: [(mask, mask indexes), ...]
    """Integrate the distribution d along every ray of the fan f.

    d: square distribution; its flattened form is indexed by the mask indexes
    f: fan, a sequence of (mask, mask indexes) pairs, one per ray
    returns an array with one value per ray: the sum of d * mask * δ over the ray's mask
    """
    sxr = np.zeros(len(f)) # one integral per ray
    assert d.shape[0] == d.shape[1], 'distribution must be square'
    flat = d.reshape(-1) # flatten once, the mask indexes address the flattened grid
    for k, (weights, cells) in enumerate(f):
        sxr[k] = np.sum(flat[cells]*weights*δ)
    return sxr

# def create_random_subgrid(k4=(K4A, K4B)):
#     w, h = L * uniform(k4[0], k4[1]), L * uniform(k4[0], k4[1]) # [m] random subgrid size
#     r0, z0 = R0 + uniform(0, L-w), Z0 + uniform(0, L-h) # [m] random start
#     r, z = np.linspace(r0, r0+w, RES), np.linspace(z0, z0+h, RES) # create the grid
#     rr, zz = np.meshgrid(r, z)
#     return rr, zz

In [None]:
# test: generate one random distribution and integrate it along both fans
mix, μ, Σ = create_random_means_stds()
g = eval_gaussians(RZ, mix, μ, Σ) # high resolution distribution on the full grid
sxrh = eval_on_fan(g, HFAN) # one line integral per horizontal ray
sxrv = eval_on_fan(g, VFAN) # one line integral per vertical ray
g_lr = g[::KHRES, ::KHRES] # low resolution distribution: every KHRES-th node

# plot g and sxrh
plt.figure(figsize=(15,5))
# left: high resolution distribution with both fans drawn on top
plt.subplot(131)
plt.scatter(RR, ZZ, c=g, s=20)
# plot rays
for r in vrays: plt.plot(r[:,0], r[:,1], 'b:')
for r in hrays: plt.plot(r[:,0], r[:,1], 'r:')
plt.axis('equal')
plt.xlim(R0, R1), plt.ylim(Z0, Z1)
plt.colorbar()
plt.title('Distribution HR')
plt.grid(False)
# middle: low resolution distribution
plt.subplot(132)
plt.scatter(rr, zz, c=g_lr, s=20)
plt.axis('equal')
plt.xlim(R0, R1), plt.ylim(Z0, Z1)
plt.colorbar()
plt.title('Distribution LR')
plt.grid(False)
# right: line integrals against the ray angle (vertical fan shifted by -π/2 to share the axis)
plt.subplot(133)
plt.plot(ΘVs-π/2, sxrv, 'bs', label='V')
plt.plot(ΘHs, sxrh, 'rs', label='H')
# plt.plot(ΘVsHR-π/2, sxrvhr, 'b:', label='V HR')
# plt.plot(ΘHsHR, sxrhhr, 'r:', label='H HR')
plt.xticks([-π/4, -π/8, 0, π/8, π/4], ['-π/4', '-π/8', '0', 'π/8', 'π/4'])
plt.grid(True)
plt.legend()
plt.title('SXR')
plt.show()

In [None]:
# create dataset (n samples)
N = 100_000 # number of samples in the training set

def create_dataset(n):
    """Generate n random emissivity distributions with their SXR line integrals and save them.

    For every sample a random gaussian mixture is evaluated on the high resolution grid,
    integrated along the horizontal (HFAN) and vertical (VFAN) fans, and subsampled by KHRES.
    The arrays are written, together with the grid coordinates, to 'data/sxr_ds_{n}.npz'.

    n: number of samples to generate
    """
    import os # local import: only needed to make sure the output folder exists
    os.makedirs('data', exist_ok=True) # fail early (or create it) instead of after the whole generation loop
    emiss_hr = np.zeros((n, RES, RES), dtype=np.float32) # emissivity high resolution
    emiss_lr = np.zeros((n, RES//KHRES, RES//KHRES), dtype=np.float32) # emissivity low resolution
    sxrh = np.zeros((n, NHRAYS), dtype=np.float32) # line integrals along the horizontal fan
    sxrv = np.zeros((n, NVRAYS), dtype=np.float32) # line integrals along the vertical fan
    for i in tqdm(range(n), desc=f'Creating {n} dataset'):
        # max_mix must be the mixture-size constant MAX_MIX; the angle MAXHANGLE was passed here by mistake
        # (with MAX_MIX = 1 both values select a single gaussian, so generated data are unchanged)
        mix, μ, Σ = create_random_means_stds(max_mix=MAX_MIX, k1=K1, k2=(K2A, K2B), k3=(K3A, K3B))
        emiss_hr[i] = eval_gaussians(RZ, mix, μ, Σ)
        sxrh[i] = eval_on_fan(emiss_hr[i], HFAN)
        sxrv[i] = eval_on_fan(emiss_hr[i], VFAN)
        emiss_lr[i] = emiss_hr[i, ::KHRES, ::KHRES]
    # save the dataset
    np.savez(f'data/sxr_ds_{n}.npz', emiss_hr=emiss_hr, emiss_lr=emiss_lr, sxrh=sxrh, sxrv=sxrv, RR=RR, ZZ=ZZ, rr=rr, zz=zz)

create_dataset(N) # full-size set, saved as data/sxr_ds_{N}.npz
create_dataset(N//10) # 10x smaller set, saved as data/sxr_ds_{N//10}.npz

In [None]:
# load the dataset (the smaller, N//10 samples file)
data = np.load(f'data/sxr_ds_{N//10}.npz')
print(data.files) # names of the arrays stored in the archive
# extract the data: high/low resolution emissivities and horizontal/vertical line integrals
emiss_hr, emiss_lr, sxrh, sxrv = data['emiss_hr'], data['emiss_lr'], data['sxrh'], data['sxrv']
print(emiss_hr.shape, emiss_lr.shape, sxrh.shape, sxrv.shape)

In [None]:
# plot the dataset: N_PLOTS random samples of the loaded (N//10 samples) dataset, one figure each
N_PLOTS = 30
idxs = np.random.randint(0, N//10, N_PLOTS) # random sample indices, repetitions are possible
for i in idxs:
    plt.figure(figsize=(15,3))
    mind, maxd = np.min(emiss_hr[i]), np.max(emiss_hr[i]) # shared color limits for both emissivity panels
    # distribution (high resolution)
    plt.subplot(131)
    plt.scatter(RR, ZZ, c=emiss_hr[i], s=6)
    plt.axis('equal')
    plt.xlim(R0, R1), plt.ylim(Z0, Z1)
    plt.colorbar()
    plt.clim(mind, maxd)
    plt.title(f'Distribution {i}')
    plt.grid(False)
    # subgrid (low resolution, every KHRES-th node)
    plt.subplot(132)
    plt.scatter(rr, zz, c=emiss_lr[i], s=3)
    plt.axis('equal')
    plt.xlim(R0, R1), plt.ylim(Z0, Z1)
    plt.colorbar()
    plt.clim(mind, maxd)
    plt.title(f'Subgrid {i}')
    plt.grid(False)
    # sxr: line integrals against the ray angle (vertical fan shifted by -π/2 to share the axis)
    plt.subplot(133)
    plt.plot(ΘVs-π/2, sxrv[i], 'bs', label='V')
    plt.plot(ΘHs, sxrh[i], 'rs', label='H')
    # plt.plot(ΘVsHR-π/2, sxrv_hr[i], 'b:', label='V HR')
    # plt.plot(ΘHsHR, sxrh_hr[i], 'r:', label='H HR')
    plt.xticks([-π/4, -π/8, 0, π/8, π/4], ['-π/4', '-π/8', '0', 'π/8', 'π/4'])
    plt.grid(True)
    plt.legend()
    plt.title(f'SXR {i}')
    plt.show()
    plt.close() # release the figure before drawing the next sample