In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cmocean
import h5py
import numpy as np
import os
import re

# Restrict this process to physical GPU 4. This has to happen before CUDA is first
# initialised by torch; the selected GPU is then addressed as cuda:0.
os.environ.update({"CUDA_VISIBLE_DEVICES": "4"})
device = torch.device("cuda", 0)
print("Using GPU:", torch.cuda.get_device_name(device.index))

In [None]:

def process_and_compare_images(origninal_image, my_up_sample, plot=False,
                               scale_factor=4, piece_size=(64, 64), **kwargs):
    '''
    Downsample an image, super-resolve it tile by tile, stitch it back and compare.

    The image is average-pooled by ``scale_factor``, cut into ``piece_size`` tiles
    (row-major), each tile is passed through ``my_up_sample`` and the upsampled tiles
    are stitched back in the same order.

    origninal_image: torch.Tensor, shape=(C, H, W), e.g. (1, 2048, 2048).
        H and W must be divisible by scale_factor * piece_size.
    my_up_sample: callable mapping one (C, h, w) tile to a (C, h * scale_factor,
        w * scale_factor) tile
    plot: bool, whether to plot the images via plot_images
    scale_factor: int, downsampling factor (default 4, i.e. 2048 -> 512)
    piece_size: (int, int), tile size in the downsampled image (default 64x64)
    **kwargs: additional keyword arguments forwarded to plot_images

    return: processed_image, mean_difference
        processed_image: torch.Tensor, same shape as origninal_image, on its device
        mean_difference: float, mean absolute difference between the original and
            processed image

    raises: ValueError if the image size is not tileable or the stitched result does
        not match the original's shape
    '''
    height, width = origninal_image.shape[-2:]
    tile_span = (scale_factor * piece_size[0], scale_factor * piece_size[1])
    if height == 0 or width == 0 or height % tile_span[0] or width % tile_span[1]:
        raise ValueError(
            f"Image size {tuple(origninal_image.shape)} is not divisible into "
            f"{tile_span[0]}x{tile_span[1]} tiles (scale_factor={scale_factor}, "
            f"piece_size={piece_size})")

    def downsample(image):
        # Non-overlapping average pooling: each scale_factor x scale_factor patch -> 1 pixel.
        return F.avg_pool2d(image.unsqueeze(0), kernel_size=scale_factor).squeeze(0)

    def split_image(image):
        # Row-major list of tiles, plus the grid shape needed to stitch them back.
        rows = range(0, image.size(1), piece_size[0])
        cols = range(0, image.size(2), piece_size[1])
        pieces = [image[:, i:i + piece_size[0], j:j + piece_size[1]]
                  for i in rows for j in cols]
        return pieces, len(rows), len(cols)

    def stitch_image(pieces, n_rows, n_cols):
        # Tile size is taken from the upsampler's output instead of being hard-coded.
        tile_h, tile_w = pieces[0].shape[-2:]
        # Allocate on the original's device so the difference below needs no transfer;
        # pieces on another device are copied in by the slice assignment.
        stitched = torch.zeros((pieces[0].size(0), n_rows * tile_h, n_cols * tile_w),
                               dtype=pieces[0].dtype, device=origninal_image.device)
        for idx, piece in enumerate(pieces):
            r, c = divmod(idx, n_cols)
            stitched[:, r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w] = piece
        return stitched

    # Process the image
    downsampled_image = downsample(origninal_image)          # e.g. 1x2048x2048 -> 1x512x512
    pieces, n_rows, n_cols = split_image(downsampled_image)  # e.g. 8x8 tiles of 1x64x64
    upsampled_pieces = [my_up_sample(piece) for piece in pieces]
    result_image = stitch_image(upsampled_pieces, n_rows, n_cols)

    if result_image.shape != origninal_image.shape:
        raise ValueError(
            f"Stitched image has shape {tuple(result_image.shape)}, expected "
            f"{tuple(origninal_image.shape)}; does my_up_sample upscale by {scale_factor}?")

    # Compare the images
    difference_image = torch.abs(origninal_image - result_image)
    mean_difference = difference_image.mean().item()

    if plot:
        plot_images(origninal_image, result_image, difference_image,
                    mean_difference=mean_difference, **kwargs)

    return result_image, mean_difference

def plot_images(original, processed, difference, normalize=True, **kwargs):
    '''
    Plot the original image, processed image, difference image, and 1D power spectrum comparison.

    original, processed, difference: torch.Tensor (any device) or array-like,
        shape=(1, H, W) or (H, W), e.g. (1, 2048, 2048); singleton axes are squeezed
    normalize: bool, whether to min-max normalize each image to [0, 1] before plotting
        (the power spectra are computed from the images as plotted)
    **kwargs: optional extras
        mean_difference: float, appended to the figure title when given
        title: str, the title of the figure; if save is True this is typically the h5
            path, whose base name (without extension, lower-cased) names the PNG
        save: bool, whether to save the figure as '<name>_comparison.png' at 300 dpi
    '''

    def radial_profile(data, center):
        # Mean of `data` over rings of integer radius around center = (x, y).
        y, x = np.indices(data.shape)
        r = np.sqrt((x - center[0])**2 + (y - center[1])**2).astype(int)
        tbin = np.bincount(r.ravel(), data.ravel())
        nr = np.bincount(r.ravel())
        return tbin / nr

    def compute_1d_power_spectrum(frame):
        # |FFT|^2 with zero frequency shifted to the centre, then radially averaged:
        # index i of the result is the ring of integer radius i (index 0 = DC term).
        power_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(frame)))**2
        center = (power_spectrum.shape[1] // 2, power_spectrum.shape[0] // 2)
        return radial_profile(power_spectrum, center)

    def to_2d_array(image):
        # Host-side (H, W) array for imshow/FFT; works for CUDA tensors and
        # for normalize=False (imshow rejects a (1, H, W) tensor).
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        array = np.squeeze(np.asarray(image))
        if normalize:
            low, high = array.min(), array.max()
            # A constant image would divide by zero; show zeros instead of NaNs.
            array = (array - low) / (high - low) if high > low else np.zeros_like(array, dtype=float)
        return array

    original = to_2d_array(original)
    processed = to_2d_array(processed)
    difference = to_2d_array(difference)

    # Compute the 1D power spectrum for both images
    power_spectrum_1d_1 = compute_1d_power_spectrum(original)
    power_spectrum_1d_2 = compute_1d_power_spectrum(processed)

    # Bin 0 is the DC (k = 0) term and cannot be drawn on a log axis; plot bins k >= 1
    # at their own wavenumber rather than shifting every bin up by one.
    k = np.arange(1, len(power_spectrum_1d_1))
    comparison_line = 100000 * k**(-3.0)  # reference 1/k^3 slope, arbitrary amplitude

    fig, axes = plt.subplots(1, 4, figsize=(20, 6))

    panels = [('Original Image', original),
              ('Processed Image', processed),
              ('Difference Image', difference)]
    for ax, (panel_title, panel_image) in zip(axes[:3], panels):
        ax.set_title(panel_title)
        ax.imshow(panel_image, cmap=cmocean.cm.balance)
        ax.axis('off')

    ax = axes[3]
    ax.loglog(k, power_spectrum_1d_1[1:], linewidth=3, label='Original Power Spectrum')
    ax.loglog(k, power_spectrum_1d_2[1:], label='Processed Power Spectrum')
    ax.loglog(k, comparison_line, 'g--', lw=2, label=r'$k^{-3}$')
    ax.set_title('1D Power Spectrum (Log-Log)')
    ax.set_xlabel('Log(Radial Distance)')
    ax.set_ylabel('Log(Power)')
    ax.legend()

    title = kwargs.get('title', 'Image Comparison')
    if 'mean_difference' in kwargs:
        title += ' (Mean Difference: {:.4f})'.format(kwargs['mean_difference'])
    fig.suptitle(title, fontsize=16)
    fig.tight_layout()

    if kwargs.get('save', False):
        # Name the PNG after the title's base name, e.g. '/a/b/16000_x.h5' -> '16000_x'.
        stem = os.path.splitext(os.path.basename(str(kwargs.get('title', 'image'))))[0]
        filename = stem.lower() + '_comparison.png'
        fig.savefig(filename, dpi=300)
        print(f"Figure saved as {filename}")

    plt.show()



In [None]:
torch.manual_seed(0)  # seed the global torch RNG so this random test image is reproducible
image = torch.randn(1, 2048, 2048)

###! Put upsample diffusion forward here
def my_up_sample(image):
    """Placeholder upsampler: bilinearly resize one (C, h, w) tile to (C, 256, 256)."""
    batched = image.unsqueeze(0)  # F.interpolate works on (N, C, H, W) batches
    upsampled = F.interpolate(batched, size=(256, 256), mode='bilinear', align_corners=False)
    return upsampled.squeeze(0)

# Round-trip the random image (downsample -> tile-wise upsample -> stitch) and plot it.
result_image, difference = process_and_compare_images(
    image,
    my_up_sample,
    plot=True,
    normalize=True,
)
print(f"Diff: {difference}")

In [None]:
file_path = '/data/rdl/NSTK/16000_2048_2048_seed_3407.h5'
with h5py.File(file_path, 'r') as file:
    # First entry of the 'w' dataset, with a leading channel axis added.
    image = torch.tensor(file['w'][0]).unsqueeze(0)
    print(image.shape)
    
###! Put upsample diffusion forward here
def my_up_sample(image):
    """Bilinear stand-in for the diffusion upsampler: (C, h, w) -> (C, 256, 256)."""
    # [None] adds the batch axis interpolate needs; [0] drops it again (batch size is 1).
    return F.interpolate(image[None], size=(256, 256), mode='bilinear',
                         align_corners=False)[0]

result_image, difference = process_and_compare_images(
    image,
    my_up_sample,
    plot=True,
    normalize=True,
    save=True,
    title=file_path,  # the h5 path is the figure title and the saved PNG's name stem
)
print(f"Diff: {difference}")

In [None]:
def plot_energy_spectrum(img1, img2, label):
    '''
    Show two images side by side with a log-log comparison of their 1D power spectra.

    img1, img2: torch.Tensor (any device) or array-like with the same spatial size;
        singleton axes are squeezed, e.g. (1, H, W) -> (H, W)
    label: str, appended to the figure title
    '''

    def radial_profile(data, center):
        # Mean of `data` over rings of integer radius around center = (x, y).
        y, x = np.indices(data.shape)
        r = np.sqrt((x - center[0])**2 + (y - center[1])**2).astype(int)
        tbin = np.bincount(r.ravel(), data.ravel())
        nr = np.bincount(r.ravel())
        return tbin / nr

    def compute_1d_power_spectrum(frame):
        # |FFT|^2 with zero frequency shifted to the centre, then radially averaged:
        # index i of the result is the ring of integer radius i (index 0 = DC term).
        power_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(frame)))**2
        center = (power_spectrum.shape[1] // 2, power_spectrum.shape[0] // 2)
        return radial_profile(power_spectrum, center)

    def to_numpy(img):
        # Plain .numpy() fails for CUDA tensors and tensors that require grad.
        if torch.is_tensor(img):
            img = img.detach().cpu().numpy()
        return np.squeeze(np.asarray(img))

    # Create a figure with a 1x3 grid layout for the subplots (two images and one spectrum plot)
    fig, axes = plt.subplots(ncols=3, figsize=(15, 5))

    # Compute the 1D power spectrum for both images
    img1 = to_numpy(img1)
    img2 = to_numpy(img2)
    power_spectrum_1d_1 = compute_1d_power_spectrum(img1)
    power_spectrum_1d_2 = compute_1d_power_spectrum(img2)

    # Bin 0 is the DC (k = 0) term and cannot be drawn on a log axis; plot bins k >= 1
    # at their own wavenumber rather than shifting every bin up by one.
    k = np.arange(1, len(power_spectrum_1d_1))

    # Reference k^-3 slope (arbitrary amplitude)
    comparison_line = 100000 * k**(-3.0)

    # Plot the first image
    ax1 = axes[0]
    ax1.imshow(img1, cmap=cmocean.cm.balance)
    ax1.set_title('Original Image')
    ax1.axis('off')

    # Plot the second image
    ax2 = axes[1]
    ax2.imshow(img2, cmap=cmocean.cm.balance)
    ax2.set_title('Processed Image')
    ax2.axis('off')

    # Plot both 1D power spectra on the same plot in log-log scale
    ax3 = axes[2]
    ax3.loglog(k, power_spectrum_1d_1[1:], label='Original Power Spectrum')
    ax3.loglog(k, power_spectrum_1d_2[1:], label='Processed Power Spectrum')
    ax3.loglog(k, comparison_line, 'g--', lw=2, label=r'$k^{-3}$')
    ax3.set_title('1D Power Spectrum (Log-Log)')
    ax3.set_xlabel('Log(Radial Distance)')
    ax3.set_ylabel('Log(Power)')
    ax3.legend()

    # Add a figure title with the label provided
    fig.suptitle(f'Energy Spectrum Analysis: {label}', fontsize=14)

    # Adjust layout to prevent overlapping
    fig.tight_layout()
    plt.show()


In [None]:
# Compare the h5 snapshot loaded above with its tile-wise reconstruction.
plot_energy_spectrum(img1=image, img2=result_image, label='Original vs. Processed Image')

In [None]:
import os, sys
import torch


import numpy as np

from unet import UNet
from diffusion_model import GaussianDiffusionModel

from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch_ema import ExponentialMovingAverage
from tqdm import tqdm
from plotting import plot_samples
import h5py
import scipy.stats
import matplotlib.pyplot as plt
import cmocean

from PIL import Image
from PIL import ImageDraw,ImageFont

    


In [None]:
print(f"os.environ['CUDA_VISIBLE_DEVICES']: {os.environ['CUDA_VISIBLE_DEVICES']}")

# UNet(...) returns three modules, unpacked here as the base UNet plus the
# low-resolution and forecast heads that GaussianDiffusionModel wraps.
unet_model, lowres_head, future_head = UNet(image_size=256, in_channels=1, out_channels=1,
                                            base_width=64,
                                            num_pred_steps=3,
                                            Reynolds_number=True)

model = GaussianDiffusionModel(base_model=unet_model.cuda(),
                               lowres_model=lowres_head.cuda(),
                               forecast_model=future_head.cuda(),
                               betas=(1e-4, 0.02),
                               n_T=10,
                               prediction_type='v',
                               sampler='ddim')

checkpoint_path = '/data/rdl/NSTK/checkpoint_ddim_v_multinode_64.pt'

# map_location remaps every stored tensor onto the visible CUDA device, so loading does
# not depend on which GPU index the checkpoint was saved from. torch.load unpickles the
# file, so only point it at trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location='cuda')
# Only model weights are restored; the optimizer state is not needed for sampling.
model.base_model.load_state_dict(checkpoint["basemodel"])
model.lowres_model.load_state_dict(checkpoint["lowres_model"])
model.forecast_model.load_state_dict(checkpoint["forecast_model"])

# EMA shadow weights, swapped in via ema.average_parameters() when sampling.
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
ema.load_state_dict(checkpoint["ema"])

In [None]:
# Default Reynolds-number conditioning, read as a global by generate_single_sample.
# NOTE(review): main() below computes Re / 40000 (e.g. 16000 -> 0.4), while this default
# is the raw value 16000.0 — confirm which scale the checkpoint was trained with.
Reynolds_number = torch.tensor([16000.0]).to('cuda')
# Random (1, 64, 64) low-resolution tile for a quick smoke test of the sampler.
image = torch.randn(1, 64, 64)

def generate_single_sample(image, reynolds_number=None):
    '''
    Super-resolve one low-resolution tile with the EMA weights of the diffusion model.

    image: torch.Tensor, a single (C, h, w) tile (a batch axis is added here),
        e.g. (1, 64, 64) as produced by process_and_compare_images
    reynolds_number: optional tensor passed to model.sample as the Reynolds-number
        conditioning; defaults to the notebook-global `Reynolds_number`
    return: torch.Tensor on the CPU, shape (1, 256, 256) — channel 0 of the first sample

    Uses the notebook globals `model` and `ema`; unlike before, it no longer rebinds
    the global `Reynolds_number`.
    '''
    if reynolds_number is None:
        reynolds_number = Reynolds_number
    reynolds_number = reynolds_number.to('cuda')

    model.eval()
    cond = image.unsqueeze(0).to('cuda')  # (1, C, h, w) conditioning batch
    with torch.no_grad(), ema.average_parameters():
        predictions = model.sample(cond.shape[0],
                                   (1, 256, 256),
                                   cond, None, reynolds_number, 'cuda', superres=True)
    # Keep channel 0 of the first sample and restore a leading channel axis.
    return predictions[0, 0].unsqueeze(0).cpu()

# Smoke test: the (1, 64, 64) tile should come back super-resolved to (1, 256, 256).
print(image.shape)
pred = generate_single_sample(image)
print(pred.shape)

In [None]:
def main(file_path, **kwargs):
    '''
    Run the tile-wise diffusion super-resolution round trip on the first 'w' entry of one h5 file.

    file_path: str, path to an h5 file with a 'w' dataset whose base name starts with the
        Reynolds number followed by '_', e.g. '.../16000_2048_2048_seed_3407.h5'
    **kwargs: forwarded to process_and_compare_images / plot_images (e.g. title, save)

    return: (re_number, difference)
        re_number: float, the parsed Reynolds number divided by 40000
        difference: float, mean absolute difference between original and reconstruction

    raises: ValueError if no Reynolds number can be parsed from the file name
    '''
    global Reynolds_number

    # Parse from the base name only, so directory names cannot match by accident.
    match = re.match(r'(\d+)_', os.path.basename(file_path))
    if match is None:
        raise ValueError(f"Cannot parse a Reynolds number from file name: {file_path}")
    re_number = int(match.group(1)) / 40000.0

    with h5py.File(file_path, 'r') as file:
        image = torch.tensor(file['w'][0]).unsqueeze(0)

    # generate_single_sample reads the *global* Reynolds_number. A local variable here
    # would be silently ignored, and every file would be sampled with the same Re.
    # Set the global for this run and restore the previous value afterwards.
    previous_reynolds_number = globals().get('Reynolds_number')
    Reynolds_number = torch.tensor([re_number]).to('cuda')
    try:
        print(f"Processing file: {file_path}, Re = {Reynolds_number.item():.3f}")
        _, difference = process_and_compare_images(
            image, generate_single_sample, plot=True, normalize=True, **kwargs)
    finally:
        if previous_reynolds_number is not None:
            Reynolds_number = previous_reynolds_number

    print(f"Diff: {difference}")
    return re_number, difference
    
# Single-file test run before the full sweep below.
file_path = '/data/rdl/NSTK/16000_2048_2048_seed_3407.h5'
main(file_path, title='Test Run')

In [None]:
DATA_DIR = '/data/rdl/NSTK'

# h5 file names (relative to DATA_DIR); the leading number is the Reynolds number.
file_names = ['1000_2048_2048_seed_3407.h5', '32000_2048_2048_seed_987.h5', '600_2048_2048_seed_3407.h5',
              '16000_2048_2048_seed_3407.h5', '8000_2048_2048_seed_3407.h5', '16000_2048_2048_seed_2150.h5',
              '32000_2048_2048_seed_2150.h5', '8000_2048_2048_seed_2150.h5', '8000_2048_2048_seed_1000.h5',
              '1000_2048_2048_seed_2150.h5', '12000_2048_2048_seed_1000.h5', '12000_2048_2048_seed_3407.h5',
              '12000_2048_2048_seed_2150.h5', '4000_2048_2048_seed_2150.h5', '2000_2048_2048_seed_2150.h5',
              '4000_2048_2048_seed_3407.h5', '2000_2048_2048_seed_3407.h5', '6000_2048_2048_seed_2150.h5',
              '6000_2048_2048_seed_3407.h5', '24000_2048_2048_seed_2150.h5', '24000_2048_2048_seed_3407.h5']
file_names.sort()

# (Reynolds number, mean absolute difference) per file.
results = []
for file_name in file_names:
    file_path = os.path.join(DATA_DIR, file_name)
    re_number, difference = main(file_path, title=file_path, save=True)
    # main returns Re / 40000; scale back to the raw Reynolds number.
    results.append((re_number * 40000, difference))

print(results)
    

In [None]:
# Sort in place by Reynolds number, then list one (Re, mean difference) pair per line.
results.sort(key=lambda entry: entry[0])
for pair in results:
    print(pair)