In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
%matplotlib inline
%matplotlib widget
import pickle
#from plot_utils import comp_dyn
from FourierGridMethod import Fourier_Grid,normalize_wfn,side_wp
from simulation import potential,alpha, random_perturbation, chirped_pulse, complete_simulation
import os
# Base directory for the 'result/' and 'model/' paths used below. It is the
# kernel's current working directory, so those paths only resolve when the
# notebook is started from the folder that contains them.
cwd = os.getcwd()

In [2]:
# Load the saved MCMC run.
# NOTE(security): pickle.load can execute arbitrary code; only open result
# files this project produced itself.
with open(cwd + '/result/MCMC-result.pickle', 'rb') as pickle_file:
    data_dict = pickle.load(pickle_file)

# Unpack the stored entries (dict.get leaves a name as None if its key is absent).
x_min, x_max, x_grid, D, out, V_tot_au = (
    data_dict.get(key) for key in ('x_min', 'x_max', 'x_grid', 'D', 'out', 'V')
)

In [3]:
# --- Grids -----------------------------------------------------------------
# NOTE(review): this replaces the x_min / x_max / x_grid just read from the
# pickle with a rebuilt grid; confirm it matches the grid D and out were
# sampled on, since they are normalized against this x_grid later.
# x: phi_d from -90 to +90 degrees, expressed in radians.
x_size = 128
x_min = -90*np.pi/180
x_max = 90*np.pi/180
x_grid = np.linspace(x_min, x_max, num=x_size)

# y: 0 .. 7.25 on 300 samples (presumably time, given the mesh name t_dim).
y_size = 300
y_min = 0
y_max = 7.25
y_grid = np.linspace(y_min, y_max, num=y_size)

# 2-D mesh (rows follow y_grid, columns follow x_grid) for 3-D plots and for
# building perturbations.
x_dim, t_dim = np.meshgrid(x_grid, y_grid)

# --- Constants -------------------------------------------------------------
nstates = 20
# meV -> atomic units (1 Hartree = 27.211 eV = 27211 meV).
ev_scale = 1/(27.211*1e3)
# Presumably moments of inertia of the two fragments (units not shown here);
# Irel is their reduced value IBr*IF/(IBr+IF).
IBr = 8911925
IF = 1864705
Irel = IBr*IF/(IBr+IF)

# --- Potential -------------------------------------------------------------
# Double-well potential on the (x, y) grid, in meV.
V_pot = potential(x_grid, y_grid)

# Perturbation term (meV): coherence along x and y, and amplitude variance.
# NOTE(review): no random seed is set here, so phi_perturb (and V_vib_au)
# presumably differ on every run.
phi_coh_x = 12.2
phi_coh_y = 0.27
phi_amp_var = 11.9
phi_perturb = random_perturbation(x_size, y_size, phi_coh_x, phi_coh_y, phi_amp_var)

# Atomic-unit potentials: V without the perturbation, V_vib_au with it.
V = ev_scale*V_pot
V_vib_au = ev_scale*(V_pot+phi_perturb)

In [4]:
from scipy.integrate import simps
from itertools import cycle
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation
import numpy as np

In [5]:
def truncate(n, decimals=0):
    """Return ``n`` truncated toward zero to ``decimals`` decimal places, as a string.

    Zero is special-cased to a fixed-width string: ``'0.'`` followed by
    ``decimals`` zeros (``truncate(0, 2) == '0.00'``). Any other value is
    rendered with ``str()`` of the truncated float, so trailing zeros are not
    padded (``truncate(1.5, 2) == '1.5'``).

    Truncation is performed exactly on the shortest decimal representation of
    ``n`` (what ``repr`` prints). The former ``int(n * 10**decimals)`` form was
    exposed to binary rounding: ``0.29 * 100 == 28.999999999999996`` made
    ``truncate(0.29, 2)`` return ``'0.28'``.

    Parameters
    ----------
    n : real number
        Value to truncate.
    decimals : int, default 0
        Number of decimal places to keep (negative values truncate to tens,
        hundreds, ...).

    Raises
    ------
    OverflowError, ValueError
        For infinite / NaN input respectively, as the former ``int()`` call did.
    """
    from fractions import Fraction
    import math

    if n == 0:
        return '0.' + '0' * decimals

    value = float(n)
    # Exact decimal value of the shortest repr. For inf/nan, Fraction(value)
    # raises OverflowError/ValueError -- the same errors int() used to raise.
    exact = Fraction(repr(value)) if math.isfinite(value) else Fraction(value)
    scale = Fraction(10) ** decimals
    # int() on a Fraction truncates toward zero, like int() on a float.
    return str(float(int(exact * scale) / scale))

In [27]:
def comp_dyn(xmin, xmax, x, D, D_true, V, frames=300, play=True, dt=0.024189):
    """Animate a predicted density against a reference density and the potential.

    Frame ``f`` draws ``D[f]`` (red, "FNO") and ``D_true[f]`` (black,
    "Split-operator") on the left |psi|^2 axis, ``V[f]`` (blue) on a twin right
    axis labelled in meV, and a caption ``"t = <dt*f> ps"``.

    Parameters
    ----------
    xmin, xmax : float
        x-axis limits, in the same units as ``x`` (the axis is labelled in degrees).
    x : array-like
        Grid shared by all curves.
    D, D_true : indexable sequence of 1-D arrays
        Predicted and reference |psi|^2 per frame; each ``D[f]`` must match ``x``.
    V : indexable sequence of 1-D arrays
        Potential per frame, drawn on a fixed 0-120 axis.
    frames : int or iterable, default 300
        Passed to ``FuncAnimation``. An integer is capped at the number of
        frames available in ``D``, ``D_true`` and ``V`` so the animation cannot
        index past the end of the data (e.g. halfway through ``.save``).
    play : bool, default True
        If true, show the figure; otherwise close it so only the returned
        animation object remains (e.g. for ``.save``).
    dt : float, default 0.024189
        Time per frame used in the caption. NOTE(review): the caption says "ps",
        but 0.024189 matches the atomic time unit in fs -- confirm units.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Keep a reference to it; an animation that is garbage-collected stops.
    """
    import numbers

    fs = 5    # axis-label font size
    ts = 30   # tick-label font size
    xLimits = [xmin, xmax]
    yLimits = [0, 0.16]   # |psi|^2 axis
    dLimits = [0, 120]    # potential axis (meV label)

    # Never ask FuncAnimation for more frames than every input can supply.
    if isinstance(frames, numbers.Integral):
        frames = min(frames, len(D), len(D_true), len(V))

    fig, ax = plt.subplots(figsize=(10, 8), tight_layout=True)
    ax.set_xlim(xLimits)
    ax.set_ylim(yLimits)
    ax.set_xlabel(r'$\phi_d$ (degree)', fontsize=fs)
    ax.set_ylabel(r"$|\psi|^2$", fontsize=fs)
    ax.grid()

    # Potential on a second y-axis that shares the x-axis with ``ax``.
    ax2 = ax.twinx()
    ax2.set_xlim(xLimits)
    ax2.set_ylim(dLimits)
    ax2.set_xlabel(r'$\phi_d$ (degree)', fontsize=fs)
    ax2.set_ylabel(r"$V$ (meV)", fontsize=fs)
    # Size tick labels on these axes only; plt.rc() would change the global
    # defaults for every figure created afterwards.
    ax.tick_params(labelsize=ts)
    ax2.tick_params(labelsize=ts)

    # Initial frame.
    line, = ax.plot(x, D[0], 'r-', label="FNO", lw=2)
    lineS2, = ax2.plot(x, V[0], 'b-', label="potential", lw=2)
    lineS3, = ax.plot(x, D_true[0], 'k-', label="Split-operator", lw=2)

    time_text = ax.text(0.1, 0.95, "", ha='left', va='top', transform=ax.transAxes,
                        fontsize=25, color='red')

    def animate(f):
        """Move every curve and the caption to frame ``f``."""
        line.set_ydata(D[f])
        lineS2.set_ydata(V[f])
        lineS3.set_ydata(D_true[f])
        time_text.set_text('t = ' + ("%.3f" % (dt*f)) + ' ps')
        # Return every artist modified above so blitting (if enabled) redraws all.
        return line, lineS2, lineS3, time_text

    ani = FuncAnimation(fig, animate, frames=frames, interval=100, blit=False)

    if play:
        plt.show()
    else:
        plt.close(fig)

    return ani

In [19]:
import scipy.integrate as integrate
def normalize(psi, x, uniform = True):
    """Scale ``psi`` so that its Simpson-rule integral over ``x`` equals 1.

    The array is divided by its integral (not by the square root of one), which
    is the appropriate normalization for a density such as |psi|^2.

    Parameters
    ----------
    psi : array-like
        Samples on the grid ``x``.
    x : array-like
        Sample points, same length as ``psi``.
    uniform : bool, optional
        Unused; kept so existing callers that pass it keep working.

    Returns
    -------
    ``psi / integral``

    Raises
    ------
    ValueError
        If the integral is zero or not finite (division would fill the result
        with inf/nan).
    """
    # scipy.integrate.simps was removed in SciPy 1.14; simpson (SciPy >= 1.6)
    # replaces it. ``x`` is passed by keyword, which both accept.
    simpson = getattr(integrate, 'simpson', None) or integrate.simps
    area = simpson(psi, x=x)
    if not np.isfinite(area) or area == 0:
        raise ValueError(f"cannot normalize: integral of psi over x is {area}")
    return psi / area

In [23]:
# Normalize every frame of D and out to unit integral over phi_d in degrees.
# NOTE(review): normal_D / normal_out are aliases of D / out, not copies, so
# this rewrites D and out in place; the plotting and loss cells that follow
# use D and out and therefore see the normalized data.
normal_D, normal_out = D, out
x_deg = x_grid*(180/np.pi)
for target in (normal_D, normal_out):
    for i in range(len(target)):
        target[i] = normalize(target[i], x_deg)

In [30]:
# Animate the FNO prediction (out) against the split-operator reference (D)
# for 150 frames, with phi_d in degrees and the potential converted back to
# meV. play=False closes the figure so the animation is only written to disk.
t = comp_dyn(
    xmin=x_min*(180/np.pi),
    xmax=x_max*(180/np.pi),
    x=x_grid*(180/np.pi),
    D=out,
    D_true=D,
    V=V/ev_scale,
    frames=150,
    play=False,
    dt=0.0241667*2,  # NOTE(review): per-frame step; confirm units vs the "ps" caption
)
writervideo = animation.FFMpegWriter(fps=20)

In [29]:
# Encode the animation with the FFmpeg writer (needs an ffmpeg executable on
# PATH) and store it in the same result/ folder as the pickled run.
t.save(filename=cwd + '/result/MCMC-result.mp4', writer=writervideo)

In [None]:
# Load the trained model and count its trainable parameters.
# The checkpoint is used as a whole module (model.parameters() below), so it
# is presumably a fully pickled nn.Module rather than a state_dict. Since
# PyTorch 2.6, torch.load defaults to weights_only=True, which rejects such
# files; request full unpickling explicitly when the flag is supported.
# NOTE(security): weights_only=False executes arbitrary pickled code -- only
# load checkpoints this project produced. Unpickling also needs the model's
# class (presumably from FNO2D) to be importable.
import inspect

model_path = cwd + '/model/data-gaussian-pulse-10000.pt'
load_kwargs = (
    {'weights_only': False}
    if 'weights_only' in inspect.signature(torch.load).parameters
    else {}
)
model = torch.load(model_path, **load_kwargs)
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
from FNO2D import H1Loss

In [None]:
# H1Loss comes from the project's FNO2D module (presumably an H1 / Sobolev-norm
# loss, judging by the name); d=2 presumably means inputs are treated as 2-D
# fields -- TODO confirm against FNO2D.
myloss = H1Loss(d=2)

In [None]:
# H1 loss with the split-operator reference (D) in the first argument slot.
# NOTE(review): if H1Loss is a relative loss it is not symmetric, so this can
# differ from myloss(out, D) in the next cell -- confirm which order FNO2D's
# H1Loss expects for (prediction, target).
myloss(D, out)

In [None]:
# H1 loss with the FNO prediction (out) first and the reference (D) second --
# presumably the (prediction, target) order; confirm against FNO2D.H1Loss.
myloss(out,D)

In [None]:
# Mean squared error between the FNO prediction (out) and the split-operator
# reference (D); F is the torch.nn.functional alias imported at the top.
F.mse_loss(out, D, reduction='mean')