In [7]:
import numpy as np
import plotly.graph_objs as go
from astropy.io import fits
from astropy.wcs import WCS
import matplotlib.pyplot as plt
from scipy.constants import pi

# Input spectral cube. NOTE(review): hard-coded, machine-specific path --
# consider taking it from the command line.
fits_file_path = "C:/Users/ASUS/Downloads/MC-100_new.fits"

# Read everything needed inside the context manager so the file handle is
# released, and copy the primary-HDU data so the array no longer refers to
# the (possibly memory-mapped) file once it is closed.
with fits.open(fits_file_path) as hdulist:
    primary_data = hdulist[0].data
    if primary_data is None:
        raise ValueError(f"The primary HDU of {fits_file_path!r} contains no data")
    fits_data = np.array(primary_data)
    wcs = WCS(hdulist[0].header)

# Replace non-finite pixels: NaN -> 0, +inf/-inf -> largest/smallest *finite*
# value. np.nanmax/np.nanmin would themselves return +/-inf whenever
# infinities are present, so the bounds are taken over finite pixels only.
finite_pixels = fits_data[np.isfinite(fits_data)]
finite_max = finite_pixels.max() if finite_pixels.size else 0.0
finite_min = finite_pixels.min() if finite_pixels.size else 0.0
fits_data = np.nan_to_num(fits_data, nan=0.0, posinf=finite_max, neginf=finite_min)
del finite_pixels  # drop the temporary copy before the per-slice work
cmap = plt.get_cmap('viridis')

# Galactic longitude of the line of sight, in degrees.
l = 336.911
# (360 - l) converted to radians; the distance formulas below use sin/cos of
# this angle.
lu = (360. - l) * pi / 180.
Ro = 8.5   # Sun-Galactic-centre distance (kpc, matching the z-axis label)
Vo = 220   # circular velocity at the Sun; presumably km/s -- confirm
tp = Ro * np.cos(lu)  # tangent-point distance; near/far solutions are tp -/+ dR
Vc = Vo    # flat rotation curve: the same circular velocity at every radius

# Ask which kinematic-distance solution to use. Re-prompt until the answer is
# recognised instead of silently treating any typo as "far".
distance_choice = ""
while distance_choice not in ("near", "far"):
    distance_choice = input("Choose 'near' or 'far' for distance calculation: ").strip().lower()

brightness_levels = []
distances = []

# Velocity of every channel, read from the cube's spectral WCS axis (the
# original code averaged the slice *intensity* here, which is not a velocity).
# NOTE(review): assumes WCS axis 3 (numpy axis 0) is a velocity axis -- confirm
# CTYPE3. The values come back in that axis' CUNIT and are converted to km/s
# to match Vo; the conversion raises if the axis is not a velocity.
channel_index = np.arange(fits_data.shape[0])
channel_zero = np.zeros_like(channel_index)
_, _, channel_spec = wcs.all_pix2world(channel_zero, channel_zero, channel_index, 0)
channel_velocities = channel_spec * wcs.wcs.cunit[2].to("km/s")

ro_sin_l = Ro * np.sin(lu)

for i, vel in enumerate(channel_velocities):
    image = fits_data[i, :, :]
    max_value = np.nanmax(image)
    # Skip slices with no positive signal (also skips a NaN maximum).
    if not max_value > 0:
        continue
    normalized_image = image / max_value
    brightness = np.nanmean(normalized_image)
    clamped_brightness = np.clip(brightness, 0, 1)
    colormap_image = cmap(normalized_image)[:, :, :3]

    # Flat rotation curve: R = Ro*Vc*sin(l) / (|V| + Vo*sin(l)).
    R = Vc / (np.abs(vel) / ro_sin_l + (Vo / Ro))
    # |V| beyond the terminal velocity gives R < Ro*sin(l), which has no real
    # solution; place such channels at the tangent point instead of NaN.
    dR = np.sqrt(max(R**2 - ro_sin_l**2, 0.0))
    neardist = tp - dR
    fardist = tp + dR
    distance = neardist if distance_choice == 'near' else fardist
    distances.append(distance)

    brightness_levels.append((clamped_brightness, colormap_image, distance))

# Order the slices by their distance: closest first for 'near', farthest
# first for 'far' (the sort is stable, so ties keep their channel order).
sorted_brightness_levels = list(brightness_levels)
sorted_brightness_levels.sort(key=lambda level: level[2], reverse=distance_choice == 'far')

mesh_data = []

# World coordinates of the four spatial corners (spectral pixel 0); their
# min/max are used for the RA/Dec axis ranges of the plot.
pixel_corners = np.array([
    [0, 0, 0],
    [fits_data.shape[2] - 1, 0, 0],
    [0, fits_data.shape[1] - 1, 0],
    [fits_data.shape[2] - 1, fits_data.shape[1] - 1, 0]
])
wcs_corners = wcs.all_pix2world(pixel_corners, 0)
ra_corners = [wcs_corner[0] for wcs_corner in wcs_corners]
dec_corners = [wcs_corner[1] for wcs_corner in wcs_corners]

# World coordinates of every spatial pixel, computed once because all slices
# share the same spatial grid. Taking them from the WCS keeps each slice in
# its true orientation. A linspace(min, max) grid mirrors the image whenever
# CDELT1/CDELT2 is negative (RA usually decreases with column index) and
# ignores any rotation.
n_rows, n_cols = fits_data.shape[1], fits_data.shape[2]
pix_x, pix_y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
ra_grid, dec_grid, _ = wcs.all_pix2world(pix_x, pix_y, np.zeros_like(pix_x), 0)
x_flat = ra_grid.ravel()
y_flat = dec_grid.ravel()

for brightness, image, distance in sorted_brightness_levels:
    # One flat mesh per slice at z = distance. Vertex colours are the
    # colormapped pixels, and opacity is the slice's mean normalised brightness.
    slice_data = go.Mesh3d(
        x=x_flat,
        y=y_flat,
        z=np.full(x_flat.shape, distance),
        vertexcolor=image.reshape(-1, 3),
        opacity=brightness,
    )
    mesh_data.append(slice_data)

# Every slice may have been skipped (no positive peak value); fail with a
# clear message instead of an opaque error from min() on an empty list.
if not distances:
    raise ValueError("No slice has a positive peak value; nothing to plot.")

fig = go.Figure(data=mesh_data)
distance_label = "Near distance (kpc)" if distance_choice == 'near' else "Far distance (kpc)"
fig.update_layout(
    scene=dict(
        aspectmode='cube',
        xaxis=dict(
            title="Right Ascension (RA)",
            range=[min(ra_corners), max(ra_corners)]
        ),
        yaxis=dict(
            title="Declination (Dec)",
            range=[min(dec_corners), max(dec_corners)]
        ),
        zaxis=dict(
            title=distance_label,
            # nan-aware, so a stray NaN distance cannot corrupt the axis range
            range=[np.nanmin(distances), np.nanmax(distances)]
        )
    )
)

fig.show()
