#### Data Generation / Collection

In [None]:
import numpy as np
import scipy.io as sio
import os
import matplotlib.pyplot as plt
import glob

# Load high-resolution vorticity snapshots generated by the MATLAB code.
matlab_dir = '/home/diya/Projects/super_resolution/flow_super_resolution/dataset/train_data_ml_spatio-temporal_fukami_paper/data_HR/matlab_data'
# NOTE(review): files are ordered lexicographically by name; if the numeric
# suffixes are not zero-padded (data1, data10, data2, ...) this is not the
# snapshot (time) order — confirm the naming if sample order matters.
hr_files = sorted(glob.glob(os.path.join(matlab_dir, 'data*')))
if not hr_files:
    # Fail here rather than with a confusing shape error further down the pipeline.
    raise FileNotFoundError(f"No MATLAB files matching 'data*' found in {matlab_dir}")

# Extract the vorticity field 'omg' (as named in the paper) from every file.
high_res_data = [sio.loadmat(mat_path)['omg'] for mat_path in hr_files]

# np.stack requires every snapshot to have the same shape and raises otherwise,
# instead of np.array's ragged/object-array fallback on older NumPy versions.
high_res_data = np.stack(high_res_data)  # Shape: [n_samples, 128, 128]
print(high_res_data.shape)

(256, 128, 128)


#### Create Low-Resolution Data using Average Pooling

In [2]:
from scipy.ndimage import zoom

def average_downsample(data, target_size=(16, 16)):
    """Downsample a stack of 2-D fields by non-overlapping block averaging.

    Each sample of shape (h, w) is cut into a target_size[0] x target_size[1]
    grid of equal blocks of size (h // target_size[0]) x (w // target_size[1]),
    and every block is replaced by its mean (i.e. average pooling with
    stride equal to the pool size).

    Parameters
    ----------
    data : array_like, shape (n_samples, h, w)
        Stack of 2-D fields to downsample.
    target_size : tuple of two ints, default (16, 16)
        Output spatial shape (th, tw). h must be divisible by th and w by tw.

    Returns
    -------
    numpy.ndarray, dtype float64, shape (n_samples, th, tw)
        Block-averaged fields.

    Raises
    ------
    ValueError
        If ``data`` is not 3-D, ``target_size`` is not positive, or the
        spatial shape is not evenly divisible by ``target_size``.
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(f"expected data of shape (n_samples, h, w), got shape {data.shape}")

    n_samples, h, w = data.shape
    th, tw = target_size
    if th <= 0 or tw <= 0:
        raise ValueError(f"target_size must be positive, got {tuple(target_size)}")
    if h % th or w % tw:
        raise ValueError(
            f"spatial shape {(h, w)} is not evenly divisible by target_size {(th, tw)}"
        )

    ph, pw = h // th, w // tw
    # Split each spatial axis into (block index, offset within block) and
    # average over the offsets — all samples at once instead of a Python loop.
    blocks = data.reshape(n_samples, th, ph, tw, pw)
    # dtype=float64 keeps the original contract of always returning float64.
    return blocks.mean(axis=(2, 4), dtype=np.float64)

# Block-average every high-resolution snapshot down to an 8x8 grid.
low_res_data_8x8 = average_downsample(data=high_res_data, target_size=(8, 8))

#### Verify the shapes of the data

In [3]:
# Expect (n_samples, 8, 8).
print(np.shape(low_res_data_8x8))

(256, 8, 8)


#### Split data into training and validation sets

In [6]:
from sklearn.model_selection import train_test_split

# Split data (80% train, 20% validation as commonly used).
# NOTE(review): train_test_split shuffles by default. If neighbouring snapshots
# are strongly time-correlated, a shuffled split puts near-duplicates of the
# validation frames into the training set — consider a contiguous split
# (shuffle=False) if validation should measure generalization in time.
X_train, X_val, y_train, y_val = train_test_split(
    low_res_data_8x8, high_res_data, test_size=0.2, random_state=42
)

curr_dir = '/home/diya/Projects/super_resolution/flow_super_resolution/dataset/train_data_ml_spatio-temporal_fukami_paper/data_HR/'

# Save the processed data. np.save does not create parent directories, so make
# sure both output folders exist first (no-op if they already do).
os.makedirs(curr_dir + 'high_res', exist_ok=True)
os.makedirs(curr_dir + 'low_res_8x8', exist_ok=True)

np.save(curr_dir + 'high_res/train.npy', y_train)
np.save(curr_dir + 'high_res/val.npy', y_val)
np.save(curr_dir + 'low_res_8x8/train.npy', X_train)
np.save(curr_dir + 'low_res_8x8/val.npy', X_val)

In [7]:
print(np.shape(X_train))  # low-res training inputs

(204, 8, 8)


In [8]:
print(np.shape(y_train))  # high-res training targets

(204, 128, 128)


In [9]:
print(np.shape(X_val))  # low-res validation inputs

(52, 8, 8)


In [10]:
print(np.shape(y_val))  # high-res validation targets

(52, 128, 128)


#### Reload the saved splits and check the shapes of the flow fields

In [12]:
# Read the four split files back from disk, in the order
# (low-res train, low-res val, high-res train, high-res val).
X_train, X_val, y_train, y_val = (
    np.load(curr_dir + relative_path)
    for relative_path in (
        'low_res_8x8/train.npy',
        'low_res_8x8/val.npy',
        'high_res/train.npy',
        'high_res/val.npy',
    )
)

# Report sample counts and full array shapes for each split.
print(f"Training data: {len(X_train)} low-res samples, shape: {X_train.shape}")
print(f"Validation data: {len(X_val)} low-res samples, shape: {X_val.shape}")
print(f"Training targets: {len(y_train)} high-res samples, shape: {y_train.shape}")
print(f"Validation targets: {len(y_val)} high-res samples, shape: {y_val.shape}")

Training data: 204 low-res samples, shape: (204, 8, 8)
Validation data: 52 low-res samples, shape: (52, 8, 8)
Training targets: 204 high-res samples, shape: (204, 128, 128)
Validation targets: 52 high-res samples, shape: (52, 128, 128)


#### Visualize the flow fields

In [None]:
def visualize_and_save_samples(low_res_samples, high_res_samples, num_samples=5, dataset_type="Training", save_dir="visualizations", seed=None):
    """
    Plot randomly chosen low-res / high-res pairs side by side and save each as a PNG.

    Parameters:
        - low_res_samples: Indexable collection (list/array) of 2-D low-resolution fields.
        - high_res_samples: Indexable collection of 2-D high-resolution fields, paired
          by index with low_res_samples.
        - num_samples: Number of pairs to save. Clamped to the number of available
          samples; pairs are drawn without replacement.
        - dataset_type: Label used in figure titles and file names ("Training", "Validation", ...).
        - save_dir: Directory the PNGs are written to; created if it does not exist.
        - seed: Optional seed controlling which pairs are drawn. None (default) draws
          from NumPy's global random state, as np.random.choice always did.

    Returns:
        list of str: Paths of the saved PNG files, in the order they were written.

    Raises:
        ValueError: If the two collections have different lengths.
    """
    n_available = len(low_res_samples)
    if len(high_res_samples) != n_available:
        raise ValueError(
            f"low_res_samples has {n_available} samples but high_res_samples has {len(high_res_samples)}"
        )
    # Asking for more pairs than exist would make a replace=False draw fail.
    num_samples = min(num_samples, n_available)

    os.makedirs(save_dir, exist_ok=True)  # Create directory if not exists

    if seed is None:
        indices = np.random.choice(n_available, num_samples, replace=False)
    else:
        indices = np.random.default_rng(seed).choice(n_available, num_samples, replace=False)

    saved_paths = []
    for i, idx in enumerate(indices):
        low_res_img = np.asarray(low_res_samples[idx])
        high_res_img = np.asarray(high_res_samples[idx])
        # Use one color scale for both panels so intensities are directly
        # comparable; otherwise each panel is stretched to its own min/max and
        # the smoothed low-res field gets a different mapping than the high-res one.
        vmin = min(np.nanmin(low_res_img), np.nanmin(high_res_img))
        vmax = max(np.nanmax(low_res_img), np.nanmax(high_res_img))

        fig, axes = plt.subplots(1, 2, figsize=(10, 5))
        try:
            # File names follow draw order, so record the dataset index in the title.
            fig.suptitle(f"{dataset_type} Sample {i+1} (index {idx}): Low-res vs High-res", fontsize=14)

            # Low-res Image (axis labels are not set: axis('off') hides them)
            axes[0].imshow(low_res_img, cmap='inferno', vmin=vmin, vmax=vmax)
            axes[0].set_title(f"Low-Res {low_res_img.shape}")
            axes[0].axis('off')

            # High-res Image
            axes[1].imshow(high_res_img, cmap='inferno', vmin=vmin, vmax=vmax)
            axes[1].set_title(f"High-Res {high_res_img.shape}")
            axes[1].axis('off')

            save_path = os.path.join(save_dir, f"{dataset_type.lower()}_sample_{i+1}.png")
            fig.savefig(save_path, bbox_inches='tight')
            saved_paths.append(save_path)
        finally:
            # Always release the figure, even if saving fails, so many-figure loops don't leak memory.
            plt.close(fig)

    print(f"Saved {num_samples} samples in '{save_dir}' directory.")
    return saved_paths

save_dir = curr_dir + 'data_visualizations/vis_8x8/'
# exist_ok=True makes this a no-op when the folder already exists, without a
# separate (racy) existence check.
os.makedirs(save_dir, exist_ok=True)

print(save_dir)

# Save every training pair (num_samples = size of the split).
train_fig = visualize_and_save_samples(X_train, y_train, num_samples=len(X_train), dataset_type="Training", save_dir=save_dir)

# Save every validation pair.
val_fig = visualize_and_save_samples(X_val, y_val, num_samples=len(X_val), dataset_type="Validation", save_dir=save_dir)

/home/diya/Projects/super_resolution/flow_super_resolution/dataset/train_data_ml_spatio-temporal_fukami_paper/data_HR/data_visualizations/vis_8x8_/
Saved 204 samples in '/home/diya/Projects/super_resolution/flow_super_resolution/dataset/train_data_ml_spatio-temporal_fukami_paper/data_HR/data_visualizations/vis_8x8_/' directory.
Saved 52 samples in '/home/diya/Projects/super_resolution/flow_super_resolution/dataset/train_data_ml_spatio-temporal_fukami_paper/data_HR/data_visualizations/vis_8x8_/' directory.
