In [1]:
# Restrict this kernel to GPU 1; this must run before torch is imported (next cell).
%env CUDA_VISIBLE_DEVICES = 1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
%load_ext autoreload
%autoreload 2
import os
import sys; 
sys.path.extend(['/home/meet/FlowMatchingTests/conditional-flow-matching/'])
sys.path.extend(['..'])

import matplotlib.pyplot as plt
import numpy as np
import torch

from tqdm import tqdm
from torchcfm.conditional_flow_matching import *
from physics_flow_matching.unet.unet import GuidedUNetModelWrapper as UNetModel
from physics_flow_matching.inference_scripts.utils import grad_cost_func_parallel, cost_func_parallel, sample_noise
from physics_flow_matching.inference_scripts.cond import infer_parallel
from physics_flow_matching.inference_scripts.uncond import infer
from resdiual import calculate_kuramoto_sivashinsky_residual, two_point_corr
from physics_flow_matching.inference_scripts.cond import d_flow, d_flow_ssag, infer_grad, infer_gradfree, flow_daps, infer_grad_dpmc
from physics_flow_matching.inference_scripts.utils import cost_func, cost_func_exp, ssag_get_norm_params, ssag_sample, sample_noise, grad_cost_func
from physics_flow_matching.multi_fidelity.synthetic.dists.base import get_distribution

In [3]:
# Load the KS data at the chosen fidelity. Normalization statistics come from the
# non-test file and are applied to the test file.
fid = "high"
ks_dir = "/home/meet/FlowMatchingTests/conditional-flow-matching/physics_flow_matching/multi_fidelity/synthetic/ks"
data = np.load(f"{ks_dir}/{fid}_fid.npy")
test_data = np.load(f"{ks_dir}/{fid}_fid_test.npy")
stat_axes = (0, 2, 3)  # reduce over every axis except axis 1 (presumably channels — confirm)
m = data.mean(axis=stat_axes, keepdims=True)
std = data.std(axis=stat_axes, keepdims=True)
X = (test_data - m) / std

In [4]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
def meas_func(x, *, n_obs=128, mode="temporal", **kwargs):
    """Measurement operator: return only the observed slice of ``x``.

    Works on anything that supports NumPy-style ``...`` slicing (NumPy arrays
    and torch tensors alike).

    Args:
        x: array/tensor whose last two axes hold the field. Per the original
            inline comments, the second-to-last axis is time and the last axis
            is space (TODO confirm this axis convention).
        n_obs: number of leading entries kept along the observed axis.
        mode: ``"temporal"`` (default, the original behavior) keeps the first
            ``n_obs`` entries of the second-to-last axis (temporal inpainting).
            ``"spatial"`` keeps the first ``n_obs`` entries of the last axis
            (spatial inpainting).
        **kwargs: accepted and ignored, as in the original signature.

    Returns:
        The sliced ``x``.

    Raises:
        ValueError: if ``mode`` is not ``"temporal"`` or ``"spatial"``.
    """
    if mode == "temporal":
        return x[..., :n_obs, :]
    if mode == "spatial":
        return x[..., :n_obs]
    raise ValueError(f"unknown measurement mode: {mode!r}")

In [6]:
# Observed part of every normalized test trajectory, on `device`. Cast to float32
# (whatever dtype X has) to match the ground-truth tensor built later with `.float()`
# and the float32 starting noise, so the guidance cost is not computed in mixed precision.
meas = torch.from_numpy(meas_func(X)).float().to(device)

In [7]:
# Build the UNet and restore the weights saved for experiment `exp` at checkpoint `iteration`.
exp = "lf_hf"
iteration = 9
print(f"Loading model for experiment {exp}, iteration {iteration}")
ot_cfm_model = UNetModel(dim=[1, 256, 256],
                        channel_mult=None,
                        num_channels=128,
                        num_res_blocks=2,
                        num_head_channels=64,
                        attention_resolutions="40",
                        dropout=0.0,
                        use_new_attention_order=True,
                        use_scale_shift_norm=True,
                        guide_func=None
                        )
ckpt_path = f"/home/meet/FlowMatchingTests/conditional-flow-matching/physics_flow_matching/multi_fidelity/exps/{exp}/exp_gaussian_ot/saved_state/checkpoint_{iteration}.pth"
# map_location="cpu": a checkpoint saved on another GPU index (or loaded on a CPU-only
# machine) would otherwise fail to deserialize. The model is moved to `device` below.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
state = torch.load(ckpt_path, map_location="cpu")
ot_cfm_model.load_state_dict(state["model_state_dict"])
ot_cfm_model.to(device)
ot_cfm_model.eval();

Loading model for experiment lf_hf, iteration 9


### Gradient-guided conditional sampling (with DAPS comparison)

In [8]:
# Sampling configuration shared by the samplers below.
total_samples = 1000
samples_per_batch = 1
sample_shape = (1, 256, 256)
seed = 0  # fixes the starting noise so runs are reproducible
# Draw the starting noise directly on `device` from a seeded local generator. This avoids
# a CPU allocation plus copy and leaves the global RNG state untouched.
noise_gen = torch.Generator(device=device).manual_seed(seed)
# initial_points  =  get_distribution('gaussian').sample(total_samples, *sample_shape).to(device)
initial_points = torch.randn(total_samples, *sample_shape, generator=noise_gen, device=device)
ground_truth_for_cond = torch.from_numpy(X[:total_samples]).float().to(device)
# measurement_points = meas_func(ground_truth_for_cond)
# measurement_points += 0.10 * torch.randn_like(measurement_points)

In [9]:
# Every sample is conditioned on the same measured test trajectory (index `cond_idx`).
# `meas` already lives on `device`, so no extra transfer is needed.
cond_idx = 1
cond = meas[cond_idx:cond_idx + 1].repeat(total_samples, 1, 1, 1)
# cond += 0.10 * torch.randn_like(cond)

In [None]:
# Conditional sampling with infer_parallel using the gradient-based cost
# (grad_cost_func_parallel, is_grad_free=False) and the measurement operator above.
# dims_of_img reuses `sample_shape` so the image size is defined in one place.
samples_cond_grad = infer_parallel(
    cfm_model=ot_cfm_model,
    swag=False,
    samples_per_batch=samples_per_batch,
    total_samples=total_samples,
    dims_of_img=sample_shape,
    num_of_steps=300,
    grad_cost_func=grad_cost_func_parallel,
    meas_func=meas_func,
    conditioning=cond,
    conditioning_scale=1.,
    device=device,
    refine=1,
    sample_noise=sample_noise,
    use_heavy_noise=False,
    rf_start=False,
    solver='euler',
    is_grad_free=False,
    nu=None,
)

In [20]:
# DAPS sampling from the shared `initial_points`, conditioned on `cond`. Sizes come
# from the config cell instead of being re-typed, and the step count that sets the
# eta / r schedule lengths is named once.
daps_steps = 200      # num_of_steps; the eta and r schedules are built with this length
langevin_steps = 200  # passed as langevin_mc_steps (per-step Langevin iterations — TODO confirm)
samples_daps = flow_daps(fm=ExactOptimalTransportConditionalFlowMatcher(sigma=1e-3), cfm_model=ot_cfm_model,
                         samples_per_batch=samples_per_batch, total_samples=total_samples,
                         dims_of_img=sample_shape, num_of_steps=daps_steps, grad_cost_func=grad_cost_func, meas_func=meas_func,
                         conditioning=cond, device=device,
                         beta=1e-2,
                         eta=torch.linspace(1e-6, 1e-8, daps_steps),
                         r=torch.linspace(1e-2, 1e-2, daps_steps),
                         langevin_mc_steps=langevin_steps,
                         sample_noise=sample_noise, use_heavy_noise=False, start_provided=True, start_point=initial_points)

100%|██████████| 199/199 [00:37<00:00,  5.32it/s, distance=0.917]
100%|██████████| 199/199 [00:37<00:00,  5.32it/s, distance=0.976]
100%|██████████| 199/199 [00:36<00:00,  5.50it/s, distance=0.941]
100%|██████████| 199/199 [00:26<00:00,  7.47it/s, distance=0.919]
100%|██████████| 199/199 [00:26<00:00,  7.60it/s, distance=0.938]
100%|██████████| 199/199 [00:26<00:00,  7.63it/s, distance=0.956]
100%|██████████| 199/199 [00:26<00:00,  7.62it/s, distance=0.946]
100%|██████████| 199/199 [00:26<00:00,  7.64it/s, distance=0.961]
100%|██████████| 199/199 [00:25<00:00,  7.67it/s, distance=0.927]
100%|██████████| 199/199 [00:25<00:00,  7.68it/s, distance=0.912]
100%|██████████| 199/199 [00:25<00:00,  7.69it/s, distance=1.06]  
100%|██████████| 199/199 [00:26<00:00,  7.65it/s, distance=0.912]
100%|██████████| 199/199 [00:26<00:00,  7.60it/s, distance=0.915] 
100%|██████████| 199/199 [00:26<00:00,  7.44it/s, distance=0.933]
100%|██████████| 199/199 [00:25<00:00,  7.67it/s, distance=0.93] 
100%|███

KeyboardInterrupt: 

In [None]:
# Compare the first two grad-guided samples with the conditioning trajectory (test index 1).
gt = X[1, 0]
for i in range(2):
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    axes[0].imshow(gt, vmax=3, vmin=-3)
    axes[0].set_title('GT')
    a = axes[1].imshow(samples_cond_grad[i, 0], vmax=3, vmin=-3)
    axes[1].set_title('Sampled Output')
    fig.colorbar(ax=axes[1], mappable=a, fraction=0.05)
    # |error| is non-negative, so the color scale starts at 0 (vmin=-0.1 wasted half the colormap).
    b = axes[2].imshow(np.abs(gt - samples_cond_grad[i, 0]), vmax=0.1, vmin=0)
    axes[2].set_title('Absolute Difference')
    fig.colorbar(ax=axes[2], mappable=b, fraction=0.05)
    plt.tight_layout()
    plt.show()

In [None]:
# Side-by-side view of the first two grad-guided samples (channel 0).
fig, axes = plt.subplots(1, 2, figsize=(50, 5))
for idx, ax in enumerate(axes):
    ax.imshow(samples_cond_grad[idx, 0])
    ax.set_title(f'Sample {idx+1}')
    ax.axis('off')

In [None]:
# Per-sample two-point correlation (last element returned by two_point_corr) for channel 0
# of every grad-guided sample.
corr_grad = [two_point_corr(sample[0], np.arange(256), 0, 1)[-1] for sample in samples_cond_grad]

In [None]:
# Stack the per-sample correlation curves into one array, one row per sample.
corr_grad = np.stack(corr_grad)

In [None]:
# Mean two-point correlation over the grad-guided samples vs. the ground truth (test index 1).
# The original set labels but never called plt.legend(), and mislabelled these samples as MCMC.
corr_grad_mean = corr_grad.mean(axis=0)
corr_gt = two_point_corr(X[1, 0], np.arange(256), 0, 1)[-1]
lags = np.arange(len(corr_grad_mean))
plt.plot(lags, corr_grad_mean, label='Grad Mean', color='red', linewidth=3)
plt.plot(lags, corr_gt, label='GT', linewidth=3, color='black')
plt.legend()
plt.show()

In [None]:
# KS residual of the grad-guided samples after undoing the normalization, channel 0 only.
# NOTE(review): the meaning of the positional 0.2 / 0.245 is not visible here — confirm
# against calculate_kuramoto_sivashinsky_residual.
calculate_kuramoto_sivashinsky_residual(samples_cond_grad[:, 0] * std[:, 0] + m[:, 0], 0.2, 0.245)

In [None]:
# MSE between the measurement and the observed region of the first grad-guided sample.
# meas_func (not a hard-coded slice) keeps this in sync with the measurement operator.
np.square(meas[1:2].cpu().numpy() - meas_func(samples_cond_grad[:1])).mean()

In [None]:
# Compare the first two DAPS samples with the conditioning trajectory (test index 1).
gt = X[1, 0]
for i in range(2):
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    axes[0].imshow(gt, vmax=3, vmin=-3)
    axes[0].set_title('GT')
    a = axes[1].imshow(samples_daps[i, 0], vmax=3, vmin=-3)
    axes[1].set_title('Sampled Output')
    fig.colorbar(ax=axes[1], mappable=a, fraction=0.05)
    # |error| is non-negative, so the color scale starts at 0 (vmin=-0.1 wasted half the colormap).
    b = axes[2].imshow(np.abs(gt - samples_daps[i, 0]), vmax=0.1, vmin=0)
    axes[2].set_title('Absolute Difference')
    fig.colorbar(ax=axes[2], mappable=b, fraction=0.05)
    plt.tight_layout()
    plt.show()

In [None]:
# Side-by-side view of the first two DAPS samples (channel 0).
fig, axes = plt.subplots(1, 2, figsize=(50, 5))
for idx, ax in enumerate(axes):
    ax.imshow(samples_daps[idx, 0])
    ax.set_title(f'Sample {idx+1}')
    ax.axis('off')

# Per-sample two-point correlation for the DAPS samples.
# NOTE(review): this reuses the name `corr_grad`, overwriting the grad-guided correlations
# computed earlier; the next cell reads this name, so rename both together.
corr_grad = [two_point_corr(sample[0], np.arange(256), 0, 1)[-1] for sample in samples_daps]
    



In [None]:
# At this point `corr_grad` holds the DAPS per-sample correlations (filled by the previous cell).
corr_grad = np.stack(corr_grad, axis=0)
# Mean two-point correlation over the DAPS samples vs. the ground truth (test index 1).
# The original set labels but never called plt.legend().
corr_daps_mean = corr_grad.mean(axis=0)
corr_gt = two_point_corr(X[1, 0], np.arange(256), 0, 1)[-1]
lags = np.arange(len(corr_daps_mean))
plt.plot(lags, corr_daps_mean, label='DAPS Mean', color='red', linewidth=3)
plt.plot(lags, corr_gt, label='GT', linewidth=3, color='black')
plt.legend()
plt.show()

In [None]:
# KS residual of the DAPS samples after undoing the normalization, channel 0 only.
# NOTE(review): the meaning of the positional 0.2 / 0.245 is not visible here — confirm
# against calculate_kuramoto_sivashinsky_residual.
calculate_kuramoto_sivashinsky_residual(samples_daps[:, 0] * std[:, 0] + m[:, 0], 0.2, 0.245)

In [None]:
# MSE between the measurement and the observed region of the first DAPS sample.
# meas_func (not a hard-coded slice) keeps this in sync with the measurement operator.
np.square(meas[1:2].cpu().numpy() - meas_func(samples_daps[:1])).mean()