# This notebook generates the multi-fidelity (MF) Allen-Cahn data

In [None]:
# Load libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.gaussian_process.kernels import Matern
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from timeit import default_timer

import matplotlib.animation as animation

In [None]:
# Set parameters

np.random.seed(123)
nx = 65                 # grid points along x
ny = 65                 # grid points along y
num_samples = 4000      # number of initial conditions / trajectories
batch_size = 20         # trajectories integrated together per solver call
domain_length = 3       # periodic domain is [0, domain_length) in x and y
dx = domain_length/nx
dy = domain_length/ny
# linspace(..., endpoint=False) always yields exactly nx (ny) periodic grid
# points; np.arange with a float step can return an extra point from round-off.
x = np.linspace(0, domain_length, nx, endpoint=False)
y = np.linspace(0, domain_length, ny, endpoint=False)
dt = 0.04
T = 10
# round() guards against T/dt evaluating to e.g. 249.999... and int() truncating.
nstep = int(round(T/dt))
epsilon = 0.001         # coefficient multiplying the Laplacian in the solver
step = (dx,dy)
# nstep+1 time stamps (t=0 included), one per stored snapshot.
t = np.linspace(0, nstep*dt, nstep+1)
device = 'cpu'

In [None]:
# Load initial conditions

# NOTE(review): the sample index appears to be the last axis of 'initial';
# permute(2, 0, 1) moves axis 2 to the front. Confirm against the .mat file.
in_data = sio.loadmat("data/initac2d.mat")
initial_fields = in_data['initial']
u_init = torch.tensor(initial_fields).permute(2, 0, 1)

In [None]:
# presumably (num_samples, nx, ny) — TODO confirm against the parameters
u_init.size()

In [None]:
# Show the first eight initial conditions on a 2x4 grid of panels.
fig1 = plt.figure(figsize=(12,4), dpi=100)
plt.subplots_adjust(wspace=0.25)
for panel in range(1, 9):
    plt.subplot(2, 4, panel)
    plt.imshow(u_init[panel - 1].cpu(), interpolation='gaussian', cmap='jet')
    plt.colorbar()

In [None]:
# Split the samples into num_samples//batch_size chunks along the sample axis.
u_init_batch = u_init.tensor_split(num_samples // batch_size, dim=0)


In [None]:
# Sanity check: tensor_split keeps sample order, so the first batch should
# begin with the same fields plotted from u_init.
fig2 = plt.figure(figsize=(12,4), dpi=100)
plt.subplots_adjust(wspace=0.25)
first_batch = u_init_batch[0]
for panel in range(1, 9):
    plt.subplot(2, 4, panel)
    plt.imshow(first_batch[panel - 1].cpu(), interpolation='gaussian', cmap='jet')
    plt.colorbar()

In [None]:
# Define the laplacian in Fourier domain

# Angular wavenumbers: full spectrum along x, non-negative half along y, which
# matches the layout returned by torch.fft.rfft2 in ac2d_solver.
kx = 2*torch.pi*torch.fft.fftfreq(x.shape[0],d=dx).to(device)
ky = 2*torch.pi*torch.fft.rfftfreq(y.shape[0],d=dy).to(device)
# indexing='ij' is the historical default; passing it keeps the (x, y) axis
# order and avoids torch's deprecation warning.
kxx,kyy = torch.meshgrid(kx,ky,indexing='ij')
# Both grids have shape (nx, ny//2+1); add a leading axis to broadcast over the
# batch. unsqueeze(0) stays correct when nx != ny, unlike reshape(1, ny, -1).
kxx = kxx.unsqueeze(0)
kyy = kyy.unsqueeze(0)
# Fourier symbol of epsilon * Laplacian: -epsilon * |k|^2.
lapl = -epsilon*(kxx**2+kyy**2)

In [None]:
# Define the Allen-Cahn solver

def ac2d_solver(u,laplace,dt):
    """Advance ``u`` by one explicit Euler step of the Allen-Cahn equation.

    The diffusion term is evaluated spectrally: ``laplace`` multiplies the
    ``rfft2`` of ``u`` (taken over the last two axes) and the product is
    transformed back to the size of ``u``. The reaction term is ``u - u**3``.

    Args:
        u: real field(s); the last two axes are the spatial grid.
        laplace: Fourier-space multiplier broadcastable to ``rfft2(u)``.
        dt: time step.

    Returns:
        ``u + dt * (diffusion + u - u**3)``, same shape as ``u``.
    """
    grid_shape = (u.size(-2), u.size(-1))
    spectrum = torch.fft.rfft2(u)
    diffusion = torch.fft.irfft2(laplace*spectrum, s=grid_shape)
    # Keep the original left-to-right evaluation order: (diffusion + u) - u**3.
    rate = diffusion + u - u**3
    return u + dt*rate

In [None]:
# Number of batches the integration loop will process (tensor_split always
# returns exactly num_samples//batch_size chunks).
len(u_init_batch)

In [None]:
# Integrate each batch of initial conditions with the fine-grid solver.
# u_store ends up as (samples, nstep+1, nx, ny): time index 0 is the initial
# condition and index i is the state after i steps of size dt.
u_store = []
for batch, u0 in enumerate(u_init_batch):
    t1 = default_timer()

    u = u0.to(device)
    # Size from the actual chunk so an uneven split cannot break the copies.
    # Snapshots are stored in torch's default dtype.
    u_batch = torch.zeros(u.shape[0], nstep+1, nx, ny)
    # Slot 0 must hold the initial condition; otherwise torch.zeros leaves an
    # all-zero first frame in every trajectory.
    u_batch[:, 0, :, :] = u.cpu()
    for i in range(nstep):
        u = ac2d_solver(u, lapl, dt)
        u_batch[:, i+1, :, :] = u.cpu()

    u_store.append(u_batch)
    t2 = default_timer()

    if batch % 25 == 0:
        print('Batch-{}, Time-{:0.4f}'.format(batch, t2-t1))

u_store = torch.cat(u_store)

In [None]:
# Subsample the time steps

# Keep every subsample_factor-th snapshot (t=0 included), so consecutive kept
# frames are subsample_factor*dt apart.
subsample_factor = 5
keep = slice(None, None, subsample_factor)
u_largedt = u_store[:, keep, :, :]
t_largedt = t[keep]
dt_large = t_largedt[1] - t_largedt[0]

In [None]:
# Plot every 10th subsampled frame of one trajectory in a single row.
fig3 = plt.figure(figsize=(12,1), dpi=100)
plt.subplots_adjust(wspace=0.35)
sample = 0
index = 0
for frame in range(0, u_largedt.shape[1], 10):
    index += 1
    plt.subplot(1, 6, index)
    plt.imshow(u_largedt[sample, frame, :, :].cpu(), interpolation='gaussian', cmap='jet')
    plt.colorbar()

In [None]:
# (samples, subsampled time steps, nx, ny)
u_largedt.size()

In [None]:
# Subsample the spatial dimension
r = 2  # keep every r-th grid point along x and y

u_low = u_largedt[..., ::r, ::r]
x_low, y_low = torch.tensor(x[::r]), torch.tensor(y[::r])

In [None]:
# (samples, subsampled time steps, coarse x, coarse y)
u_low.size()

In [None]:
# Define the Laplacian in Fourier domain for low-fidelity u(x,y)
device = 'cpu'

# Coarse-grid spacings as plain floats (fftfreq's `d` is a scalar argument).
# NOTE(review): with nx=65 (odd) and r=2 the coarse grid has 33 points spaced
# 2*dx, so the period implied by the FFT is 33*2*dx ~= 3.046, not 3. Confirm
# this mismatch is acceptable for the low-fidelity model.
dxl = float(x_low[1] - x_low[0])
dyl = float(y_low[1] - y_low[0])
kxl = 2*torch.pi*torch.fft.fftfreq(x_low.shape[0], d=dxl).to(device)
kyl = 2*torch.pi*torch.fft.rfftfreq(y_low.shape[0], d=dyl).to(device)
# indexing='ij' keeps the (x, y) axis order and avoids the deprecation warning.
kxxl,kyyl = torch.meshgrid(kxl,kyl,indexing='ij')
# Both grids are (n_x_low, n_y_low//2+1); unsqueeze(0) adds the batch axis and
# stays correct when the coarse grid is not square.
kxxl = kxxl.unsqueeze(0)
kyyl = kyyl.unsqueeze(0)
lapl_low = -epsilon*(kxxl**2 + kyyl**2)

In [None]:
# Shape of the flattened low-fidelity solver input (every frame but the last).
u_low[:, :-1].reshape(-1, len(x_low), len(y_low)).size()

In [None]:
# Advance every coarse frame except the last (it has no successor in the data)
# by one step of size dt_large with the low-fidelity solver, then upsample to
# the fine grid for comparison with u_largedt[:, 1:].
n_low_x, n_low_y = x_low.shape[0], y_low.shape[0]
u_low_in = u_low[:,:-1,:,:].reshape(-1,n_low_x,n_low_y).to(device)
t1 = default_timer()
u_low_out = ac2d_solver(u_low_in,lapl_low,dt_large)
t2 = default_timer()
# Restore the (sample, time, x, y) layout.
u_low_out = u_low_out.reshape(-1,t_largedt.shape[0]-1,n_low_x,n_low_y).cpu()
# F.interpolate treats the time axis as channels and resizes only x and y.
# The target is the fine-grid size (nx, ny) rather than a hard-coded 65.
# With align_corners=True and 33 -> 65 points, coarse node j lands on fine node 2j.
u_low_upscale = F.interpolate(u_low_out,size=(nx,ny),mode='bicubic',align_corners=True)

print('Total_time-{}'.format(t2-t1))

In [None]:
# (samples, subsampled time steps - 1, nx, ny)
u_low_upscale.size()

In [None]:
# Final-time comparison for one sample: upsampled low-fidelity prediction
# (left) versus the high-fidelity reference (right), on a shared color range.
fig4 = plt.figure(figsize=(6,2), dpi=100)
plt.subplots_adjust(wspace=0.25)
sample = 100

for panel, field in enumerate((u_low_upscale, u_largedt), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(field[sample, -1, :, :], interpolation='gaussian', cmap='jet', vmin=-1, vmax=1)
    plt.colorbar()

In [None]:
# Mean-squared error between the high-fidelity frames 1..end and the upsampled
# low-fidelity one-step predictions of those same frames.
mse = torch.mean((u_largedt[:, 1:, :, :] - u_low_upscale)**2)
print(f'Error-{mse}')

In [None]:
# Collect the arrays and scalars written to the .mat file.
datadict = dict(
    uhr=u_largedt.cpu().numpy(),               # fine-grid, time-subsampled states
    ulr_nextstep=u_low_upscale.cpu().numpy(),  # upsampled low-fidelity next-step predictions
    time=t_largedt,
    x=x,
    y=y,
    dtlarge=dt_large,
    epsilon=epsilon,
)

In [None]:
# Write the dataset into the same data/ directory the initial conditions came from.
sio.savemat(file_name='data/ac2dlowhighres_1.mat', mdict=datadict)