In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.colors import LinearSegmentedColormap


In [None]:

# Load the delta-V grid (no header row) and the table of launch/arrival dates.
dvs_df = pd.read_csv('mars_dvs.csv', header=None)
dates_df = pd.read_csv('mars_dates.csv')

# Parse the date columns so they sort chronologically and can be fed to mdates.date2num.
dates_df['Launch Date'] = pd.to_datetime(dates_df['Launch Date'])
dates_df['Arrival Date'] = pd.to_datetime(dates_df['Arrival Date'])

# Unique, chronologically sorted axis values. Missing dates (NaT) are dropped first:
# NaT does not compare against real dates, so it would scramble the sort and inflate the counts.
unique_launch_dates = sorted(dates_df['Launch Date'].dropna().unique())
unique_arrival_dates = sorted(dates_df['Arrival Date'].dropna().unique())

# dvs_df is laid out as (launch rows, arrival columns).
num_launch_dates_dvs = dvs_df.shape[0]
num_arrival_dates_dvs = dvs_df.shape[1]

# Every row of the delta-V grid needs exactly one launch date.
if len(unique_launch_dates) != num_launch_dates_dvs:
    raise ValueError(f"Mismatch in launch date counts. dvs_df has {num_launch_dates_dvs} rows, but dates_df has {len(unique_launch_dates)} unique launch dates.")

# Every column needs an arrival date. Slicing a too-short list below would silently
# return fewer dates and only fail much later inside contourf with an opaque shape error.
if len(unique_arrival_dates) < num_arrival_dates_dvs:
    raise ValueError(
        f"Mismatch in arrival date counts. dvs_df has {num_arrival_dates_dvs} columns, "
        f"but dates_df has only {len(unique_arrival_dates)} unique arrival dates."
    )

# One arrival date per column: keep the earliest ones.
# NOTE(review): assumes any surplus arrival dates are later than the gridded ones — confirm with the data source.
relevant_arrival_dates = unique_arrival_dates[:num_arrival_dates_dvs]

# Matplotlib date numbers (float days) for use as contour coordinates.
launch_days = mdates.date2num(unique_launch_dates)
arrival_days = mdates.date2num(relevant_arrival_dates)

dvs_matrix = dvs_df.to_numpy()

# --- Axis layout: Launch Date on X (bottom), Arrival Date on Y (left) ---
# meshgrid returns arrays of shape (len(arrival_days), len(launch_days)), so the
# (launch, arrival) delta-V matrix must be transposed to (arrival, launch).
X, Y = np.meshgrid(launch_days, arrival_days)
Z_transposed = dvs_matrix.T

assert Z_transposed.shape == X.shape == Y.shape, (
    f"Grid shape mismatch: Z {Z_transposed.shape}, X {X.shape}, Y {Y.shape}"
)

# --- Colour scale: lowest delta-V red, through yellow, black from threshold_black upward ---
# NaN-aware extremes: blank cells in the CSV load as NaN, which np.min/np.max would
# propagate into every level. contourf masks NaN cells itself, so they render as gaps.
min_val_Z = np.nanmin(dvs_matrix)
max_val_Z = np.nanmax(dvs_matrix)
if not (np.isfinite(min_val_Z) and np.isfinite(max_val_Z)) or max_val_Z <= min_val_Z:
    raise ValueError(
        f"Delta-V grid has no usable value range (min={min_val_Z}, max={max_val_Z}); "
        "cannot build a colour scale."
    )

# Values at or above this are drawn black. Units follow the CSV (the labels below say m/s).
threshold_black = 20000
n_fill_levels = 50        # number of fill boundaries in the gradient below the threshold
line_spacing = 2500       # spacing between labelled contour lines

# Fill levels must be strictly increasing or contourf raises. Build them per case:
#   - data straddles or sits below the threshold: a fine gradient up to the threshold,
#     plus a single black band up to max_val_Z when the data exceeds the threshold;
#   - data entirely at/above the threshold: one band covering the whole range.
if min_val_Z < threshold_black:
    levels_fill = np.linspace(min_val_Z, threshold_black, n_fill_levels)
    if max_val_Z > threshold_black:
        levels_fill = np.append(levels_fill, max_val_Z)
else:
    levels_fill = np.array([min_val_Z, max_val_Z])

# contourf normalises colours over [first level, last level], so the threshold's
# position in the colormap must be measured over that same span.
color_lo, color_hi = levels_fill[0], levels_fill[-1]
norm_threshold = (threshold_black - color_lo) / (color_hi - color_lo)
norm_threshold = float(np.clip(norm_threshold, 0.0, 1.0))

# Colormap stops as (position in [0, 1], colour):
#   red at the bottom, yellow at 70% of the way to the threshold,
#   black exactly at the threshold and for everything above it.
cmap_colors_list = [(0.0, 'red')]
if norm_threshold > 0.0:  # no room for a gradient if the threshold is at the bottom
    cmap_colors_list.append((0.7 * norm_threshold, 'yellow'))
cmap_colors_list.append((norm_threshold, 'black'))
if norm_threshold < 1.0:  # from_list needs a stop at 1.0
    cmap_colors_list.append((1.0, 'black'))

custom_cmap = LinearSegmentedColormap.from_list("porkchop_custom_cmap", cmap_colors_list)

# Contour lines at round values (starting near the minimum, rounded to the nearest 1000),
# always including the black threshold itself.
levels_lines = np.arange(round(min_val_Z / 1000) * 1000, threshold_black, line_spacing)
if threshold_black not in levels_lines:
    levels_lines = np.append(levels_lines, threshold_black)
levels_lines = np.sort(levels_lines)


# Draw the porkchop: filled delta-V bands plus labelled iso-delta-V lines.
fig, ax = plt.subplots(figsize=(14, 10))

contourf_plot = ax.contourf(X, Y, Z_transposed, levels=levels_fill, cmap=custom_cmap)
# White lines stay legible over the black high-delta-V region.
contour_plot = ax.contour(X, Y, Z_transposed, levels=levels_lines, colors='white', linewidths=0.7)

# Label every contour line inline with its integer delta-V value.
ax.clabel(contour_plot, inline=True, fontsize=9, fmt='%1.0f')

# Colour bar ticks match the contour-line values so lines and bar read together.
cbar = fig.colorbar(contourf_plot, ax=ax, label='Delta-V (m/s)')
cbar.set_ticks(levels_lines)

ax.set_xlabel('Launch Date')
ax.set_ylabel('Arrival Date')
# Raw string: '\D' is not a valid Python escape (SyntaxWarning on 3.12+),
# while mathtext needs the literal backslash in '\Delta'.
ax.set_title(r'Mars Porkchop Plot ($\Delta V$ Contours)')

# Calendar span of each axis in whole days, used to choose tick density.
date_span_launch = pd.Timedelta(unique_launch_dates[-1] - unique_launch_dates[0]).days
date_span_arrival = pd.Timedelta(relevant_arrival_dates[-1] - relevant_arrival_dates[0]).days

def set_date_locator(span_days):
    """Build a major-tick locator suited to a date axis covering ``span_days`` days.

    - more than two years  -> one tick per year
    - more than one year   -> one tick every third month
    - otherwise            -> one tick per month

    Despite the name, this only constructs the locator; the caller attaches it
    with ``axis.set_major_locator``.
    """
    days_per_year = 365
    if span_days > 2 * days_per_year:
        return mdates.YearLocator()
    if span_days > days_per_year:
        return mdates.MonthLocator(interval=3)
    return mdates.MonthLocator()

# Tick density follows each axis's own date span; labels are ISO dates.
for date_axis, span in ((ax.xaxis, date_span_launch), (ax.yaxis, date_span_arrival)):
    date_axis.set_major_locator(set_date_locator(span))
    date_axis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))

# Clip the view to exactly the gridded launch/arrival window.
ax.set_xlim(np.min(launch_days), np.max(launch_days))
ax.set_ylim(np.min(arrival_days), np.max(arrival_days))

# Slant launch-date labels so neighbours don't overlap; keep arrival labels horizontal.
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)

plt.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.savefig('mars_porkchop_plot_custom.png')
print("Porkchop plot with custom axes and colormap saved as 'mars_porkchop_plot_custom.png'")