In [13]:
import matplotlib.pyplot as plt
import numpy as np
import os
import csv
from sklearn.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Directory holding the saved per-sample arrays (undersampled_<i>.npy,
# output_<i>.npy, ground_truth_<i>.npy) and the directory that receives the
# per-sample figures and the statistics CSV.
results_dir = r"D:\Class Project\209\results"
output_dir = r"D:\Class Project\209\results\ReIm"
output_dir = os.path.join(output_dir, "4layer_nonorm_output_0.75")
os.makedirs(output_dir, exist_ok=True)  # Ensure the output directory exists

# CSV file to record statistics: one row per sample (a "Mean" row may be
# appended after the loop).
stats_file = os.path.join(output_dir, "test_statistics.csv")
fieldnames = ["Sample Index", "MSE Before", "MSE After", "SSIM Before", "SSIM After", "PSNR Before", "PSNR After"]

# Per-sample metrics, kept so dataset-level means can be computed afterwards.
all_mse_before = []
all_mse_after = []
all_ssim_before = []
all_ssim_after = []
all_psnr_before = []
all_psnr_after = []


def _minmax_normalize(arr):
    """Rescale *arr* linearly so its minimum maps to 0 and its maximum to 1.

    A constant array (max == min) would otherwise divide by zero and turn every
    pixel into NaN, poisoning all downstream metrics; it is mapped to zeros.
    """
    arr = np.asarray(arr)
    lo = np.min(arr)
    span = np.max(arr) - lo
    if span == 0:
        return np.zeros(arr.shape, dtype=float)
    return (arr - lo) / span


def _sample_indices(directory, prefix="undersampled_", suffix=".npy"):
    """Return the sorted integer indices of files named ``<prefix><int><suffix>``.

    Parsing the index from each filename (instead of counting files and
    assuming 0..n-1) tolerates gaps in the numbering and ignores unrelated
    files that merely share the prefix.
    """
    indices = []
    for name in os.listdir(directory):
        if name.startswith(prefix) and name.endswith(suffix):
            stem = name[len(prefix):len(name) - len(suffix)]
            if stem.isdecimal():
                indices.append(int(stem))
    return sorted(indices)


# Open the CSV file for writing
with open(stats_file, mode="w", newline="") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    sample_indices = _sample_indices(results_dir)
    num_samples = len(sample_indices)  # kept for any later cell that reads it

    # Iterate through all the saved samples
    for index in sample_indices:
        # Load results for the current sample
        undersampled = np.load(os.path.join(results_dir, f"undersampled_{index}.npy"))
        output = np.load(os.path.join(results_dir, f"output_{index}.npy"))
        ground_truth = np.load(os.path.join(results_dir, f"ground_truth_{index}.npy"))

        # The network output is compared by magnitude; both the output and the
        # undersampled input are then min-max rescaled to [0, 1].
        # NOTE(review): ground_truth is NOT rescaled, so these metrics are only
        # comparable if it is already in [0, 1] — confirm.
        # NOTE(review): undersampled is used without np.abs — presumably it is
        # already real-valued; confirm.
        output = _minmax_normalize(np.abs(output))
        undersampled = _minmax_normalize(undersampled)

        # SSIM/PSNR need the dynamic range of the reference image.
        data_range = ground_truth.max() - ground_truth.min()

        # Calculate statistics ("before" = undersampled input, "after" = reconstruction)
        mse_before = mean_squared_error(ground_truth, undersampled)
        mse_after = mean_squared_error(ground_truth, output)
        ssim_before = ssim(ground_truth, undersampled, data_range=data_range)
        ssim_after = ssim(ground_truth, output, data_range=data_range)
        psnr_before = psnr(ground_truth, undersampled, data_range=data_range)
        psnr_after = psnr(ground_truth, output, data_range=data_range)

        # Append to overall statistics
        all_mse_before.append(mse_before)
        all_mse_after.append(mse_after)
        all_ssim_before.append(ssim_before)
        all_ssim_after.append(ssim_after)
        all_psnr_before.append(psnr_before)
        all_psnr_after.append(psnr_after)

        # Write to CSV
        writer.writerow({
            "Sample Index": index,
            "MSE Before": mse_before,
            "MSE After": mse_after,
            "SSIM Before": ssim_before,
            "SSIM After": ssim_after,
            "PSNR Before": psnr_before,
            "PSNR After": psnr_after
        })

        # Plot images and residuals side by side, one panel each.
        panels = [
            ("Undersampled", undersampled),
            ("Reconstructed", output),
            ("Ground Truth", ground_truth),
            ("Residual initial", undersampled - ground_truth),
            ("Residual after", output - ground_truth),
            ("Residual difference", output - undersampled),
        ]
        fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
        try:
            for ax, (title, image) in zip(axes, panels):
                ax.set_title(title)
                ax.imshow(image, cmap='gray')

            # Save the plot as a TIFF file
            save_path = os.path.join(output_dir, f"sample_{index}.tiff")
            fig.savefig(save_path, format='tiff')
        finally:
            # Release the figure even if saving fails, so figures don't pile up.
            plt.close(fig)
        print(f"Saved plot for sample {index} as {save_path}")

# Average every per-sample metric over the evaluated samples. Each CSV column
# is paired with the list the evaluation loop appended to, in fieldnames order.
_per_sample_metrics = {
    "MSE Before": all_mse_before,
    "MSE After": all_mse_after,
    "SSIM Before": all_ssim_before,
    "SSIM After": all_ssim_after,
    "PSNR Before": all_psnr_before,
    "PSNR After": all_psnr_after,
}

if all_mse_before:
    mean_stats = {column: np.mean(values) for column, values in _per_sample_metrics.items()}

    print("Mean Statistics:")
    for column, value in mean_stats.items():
        print(f"{column}: {value}")

    # Append the mean row below the per-sample rows written during the loop.
    with open(stats_file, mode="a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writerow({"Sample Index": "Mean", **mean_stats})

    print(f"Saved test statistics to {stats_file}")
else:
    # np.mean([]) warns ("Mean of empty slice") and returns nan; report the
    # empty case explicitly instead of appending a row of NaNs to the CSV.
    mean_stats = {column: float("nan") for column in _per_sample_metrics}
    print("No samples were evaluated; skipping mean statistics.")

# Individual names kept for any later cell that reads them.
mean_mse_before = mean_stats["MSE Before"]
mean_mse_after = mean_stats["MSE After"]
mean_ssim_before = mean_stats["SSIM Before"]
mean_ssim_after = mean_stats["SSIM After"]
mean_psnr_before = mean_stats["PSNR Before"]
mean_psnr_after = mean_stats["PSNR After"]


Saved plot for sample 0 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_0.tiff
Saved plot for sample 1 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_1.tiff
Saved plot for sample 2 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_2.tiff
Saved plot for sample 3 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_3.tiff
Saved plot for sample 4 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_4.tiff
Saved plot for sample 5 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_5.tiff
Saved plot for sample 6 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_6.tiff
Saved plot for sample 7 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_7.tiff
Saved plot for sample 8 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\sample_8.tiff
Saved plot for sample 9 as D:\Class Project\209\results\ReIm\4layer_nonorm_output_0.75\samp