In [None]:
# Reload imported modules (e.g. lib/, model, sm_model) before each cell executes, so
# edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('lib')
from lib.read_data import *
import time
import torch
import torch.functional as F
from model import *
from sm_model import *
# Make float64 the default floating dtype for newly created tensors (e.g. torch.rand below).
torch.set_default_dtype(torch.float64)

In [None]:
N = 128
DIM = 3
# Path to the trained checkpoint file (despite the name, this is a file, not a directory).
outdir = f"output/output_{DIM}D_{N}/checkpt_mixedBCs_M10_ritz1600_rhs800.tar"
# Alternative 2D single-frame checkpoint:
# outdir = f"output/output_single_2D_{N}/checkpt_frame_1_rhs_400_2D.tar"
# map_location='cpu' so loading does not depend on the device the checkpoint was saved
# from; the model is moved to the GPU explicitly after its weights are loaded.
checkpt = torch.load(outdir, map_location='cpu')

In [None]:
# for name in checkpt['model_state_dict']:
#     print(name)

In [None]:
# List the entries stored in the checkpoint dict (e.g. model weights and loss histories).
checkpt.keys()

In [None]:
# Full training-loss history from the checkpoint (displayed unabridged, so it can be long).
checkpt['training_loss']

In [None]:
# Training and validation loss histories from the checkpoint, in separate panels.
# Each list is plotted against its own index (the two lengths may differ).
train_loss = checkpt['training_loss']
valid_loss = checkpt['validation_loss']
fig, axes = plt.subplots(2)
axes[0].plot(train_loss, label='Training')
axes[1].plot(valid_loss, label='Validation', c='orange')
for ax in axes:
    ax.set_xlabel('Entry index')
    ax.set_ylabel('Loss')
    ax.legend()
fig.tight_layout()
fig.savefig("train_loss.png")
plt.show()

In [None]:
# Rebuild the 3D network and restore the trained weights. The constructor arguments must
# match the architecture the checkpoint was saved with (load_state_dict is strict by default).
# NOTE(review): the meaning of (3, 3) lives in model/sm_model; confirm there.
model = SmallSMModelDn3D(3, 3)
model.load_state_dict(checkpt['model_state_dict'])
# The model is only used for inference below; eval() disables training-time behaviour
# (dropout, batch-norm statistic updates) if the architecture has any.
model = model.cuda().eval()

In [None]:
# Evaluation scene. Note: this rebinds N (it was 128 for the checkpoint path).
N = 256
frame = 1
scene = 'dambreak'
scene_dir = f"data/{scene}_N{N}_200_3D"

# Per-cell flags for this frame; the index arrays below select flag values 2 (fluid)
# and 3 (air).
flags = read_flags(f"{scene_dir}/flags_{frame}.bin")
fluids = np.argwhere(flags == 2)
air = np.argwhere(flags == 3)

# Pressure field and the `div_v_star` field stored for the same frame.
sol = load_vector(f"{scene_dir}/pressure_{frame}.bin")
rhs = load_vector(f"{scene_dir}/div_v_star_{frame}.bin")

In [None]:
# Preprocessed flag channels for this frame (per the filename, 3 binary channels),
# loaded directly onto the GPU and viewed as (3, N, ..., N) with DIM spatial axes.
flags_path = f"data/{scene}_N{N}_200_3D/preprocessed/{frame}/flags_binary_3.pt"
image_shape = (3, *([N] * DIM))
image = torch.load(flags_path, map_location='cuda').view(image_shape)

In [None]:
# Smoke-test the model on a random input of shape (1, 1, N, ..., N).
# no_grad: inference only, so no autograd graph is kept for an N**DIM volume.
# The input is allocated directly on the GPU instead of on the host and then copied.
with torch.no_grad():
    x = torch.rand((1, 1) + (N,) * DIM, device='cuda')
    y = model(image, x)

In [None]:
# NOTE(review): these are 2D dambreak files (no `_3D` suffix), loaded with the N set for
# the 3D scene above. They also rebind `flags` (previously the array from read_flags) to
# whatever torch.load returns. Confirm this 2D section is meant to run in this kernel.
b = torch.load(f"data/dambreak_N{N}_200/preprocessed/{frame}/b_100_fourier.pt")
flags = torch.load(f"data/dambreak_N{N}_200/preprocessed/{frame}/flags.pt")

In [None]:
# The loaded `b` field on the N x N grid (transposed so the first index runs
# horizontally), followed by its value range.
b_grid = b.numpy().reshape(N, N)
plt.imshow(b_grid.T, origin='lower', cmap='jet')
b.max(), b.min()

In [None]:
# Check whether model(flags, .) is linear in its second argument (compared through the
# summed output): s = 1.4*sum(M(b)) + 2.2*sum(M(bb)) should equal ss = sum(M(1.4*b + 2.2*bb)).
# Inputs are moved to the model's device, because the model was moved to CUDA earlier while
# b/flags were loaded without map_location. Outputs are brought back to the CPU so later
# cells can plot them and call .numpy().
# NOTE(review): the inputs here are 2D (N x N) while the model built above is
# SmallSMModelDn3D; confirm which model this cell is meant to run against.
device = next(model.parameters()).device
with torch.no_grad():
    flags_in = flags.reshape(1, N, N).double().to(device)
    b_in = b.reshape(1, 1, N, N).double().to(device)
    bb = torch.rand_like(b).reshape(1, 1, N, N).double().to(device)
    y = model(flags_in, b_in).cpu()
    y1 = model(flags_in, bb).cpu()
    s = 1.4*y.sum() + 2.2*y1.sum()
    ss = model(flags_in, 2.2*bb + 1.4*b_in).sum().cpu()
s, ss

In [None]:
# Show channel 4 of the first output sample as an N x N image (transposed so the first
# index runs horizontally). NOTE(review): what channel 4 represents is not visible in
# this notebook; confirm against the model definition.
plt.imshow(y[0, 4].reshape(N, N).T, origin='lower', cmap='jet')
plt.show()

In [None]:
# [i, j] index pairs of cells whose flag equals 2 (fluid). They are kept as Python lists
# so that `[i, j] in fluid_cells` membership checks work.
fluid_cells = torch.argwhere(flags.reshape(N, N) == 2).tolist()
# Show a summary instead of dumping every pair (up to N*N entries).
len(fluid_cells), fluid_cells[:5]

In [None]:
pos_x, pos_y = 240, 230

# Report whether the probe cell and each of its four axis-aligned neighbours is fluid.
probes = [
    ('cell in fluid:', 0, 0),
    ('upper neighbor:', 0, 1),
    ('lower neighbor:', 0, -1),
    ('left neighbor:', -1, 0),
    ('right neighbor:', 1, 0),
]
for label, dx, dy in probes:
    print(label, [pos_x + dx, pos_y + dy] in fluid_cells)

In [None]:
# All output channels at the probe cell, arranged as a 3 x 3 array.
# NOTE(review): assumes `y` is (1, 9, N, N); the 9 channels presumably form a 3 x 3
# stencil around the cell. Confirm against the model.
y[0, :, pos_x, pos_y].reshape(3, 3)

In [None]:
# Plot one channel of the first output sample and print its value range.
# .detach().cpu() makes this work when `y` is a CUDA tensor or still tracks gradients
# (e.g. the output of a forward pass run outside torch.no_grad()).
z = y[0].detach().cpu().numpy()
nc = 0  # channel to display
zz = z[nc]
print(zz.min(), zz.max())
plt.imshow(zz.T, origin='lower', cmap='jet')
plt.colorbar()
plt.title(f"output channel {nc}")
plt.show()

In [None]:
# Magnitude of the vector saved as res_<i>.npy (presumably a solver residual), shown on
# the N x N grid.
i = 30
r = np.load(f"res_{i}.npy")
res_magnitude = np.abs(r).reshape(N, N)
plt.clf()
plt.imshow(res_magnitude.T, origin='lower', cmap='jet')
plt.colorbar()
plt.show()


In [None]:
# Nearest-neighbour upsampling of the 2D flag grid by an integer factor: each cell
# becomes a scale x scale block. `flags` is presumably a torch tensor here (it comes
# from torch.load and `.double()` is called on it), and Tensor.repeat has no `axis`
# argument, so convert to NumPy before using ndarray.repeat.
scale = 2
flags_np = flags.detach().cpu().numpy() if torch.is_tensor(flags) else np.asarray(flags)
flags_up = flags_np.reshape(N, N).repeat(scale, axis=0).repeat(scale, axis=1)

plt.imshow(flags_up.T, cmap='jet', origin='lower')
plt.colorbar()
plt.show()

In [None]:
from pathlib import Path

# Save the flattened upsampled flags. np.save does not create missing directories,
# so create the output directory first.
out_dir = Path(f"data/dambreak_single/{N}_{N*scale}/{frame}")
out_dir.mkdir(parents=True, exist_ok=True)
np.save(out_dir / f"flags_{frame}.npy", flags_up.ravel())

In [None]:
from lib.discrete_laplacian import box_bd, lap_with_bc

# Flat indices of air cells (flag value 3) in the upsampled grid. Note: this rebinds
# `air` from the 3D section above.
air = np.flatnonzero(flags_up.ravel() == 3)
# Grid size passed to the Laplacian builder: upsampled resolution + 2.
# NOTE(review): presumably one boundary layer per side; confirm in lib.discrete_laplacian.
grid_size = N * scale + 2
A = lap_with_bc(grid_size, 2, air=air, bd=box_bd(grid_size, 2), bd_padding=False)
A

In [None]:
from pathlib import Path

from scipy import sparse

# Save the assembled matrix A in SciPy's sparse .npz format. `sparse` is imported
# explicitly here instead of relying on a star import to provide it, and the output
# directory is created in case it does not exist yet.
out_dir = Path(f"data/dambreak_single/{N}_{N*scale}/{frame}")
out_dir.mkdir(parents=True, exist_ok=True)
sparse.save_npz(str(out_dir / f"A_{frame}.npz"), A)