In [None]:
import os

from pytorch_lightning.loggers import NeptuneLogger

# Neptune credentials come from the NEPTUNE_API_TOKEN environment variable
# rather than being hard-coded: a token embedded in a notebook leaks to anyone
# who can read the file (the previously embedded token should be revoked).
run = NeptuneLogger(
    project="andreasm/ufed",
    api_token=os.environ.get("NEPTUNE_API_TOKEN"),
    log_model_checkpoints=False,
)

In [None]:
from network import FE_network
from model import Estimator
from dataset import Dataset
from torch.utils.data import DataLoader, random_split
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [None]:
import pandas as pd

# Peek at the first rows of the simulation output (CVs and their forces).
pd.read_csv(filepath_or_buffer='./output_with_forces.csv').head(n=5)

In [None]:
# Seed python/numpy/torch so the train/validation split and the network
# initialisation are reproducible across runs.
SEED = 42
pl.seed_everything(SEED)

dataset = Dataset('./output_with_forces.csv', ['omega', 'phi', 'psi'], ['F_omega', 'F_phi', 'F_psi'])

# --- hyper-parameters ---
weight_decay = 0.01
spectral_norm = False
learning_rate = 0.001
batch_size = 256
input_dim = dataset.cvs.shape[1]
width = 50
depths = 5
acti = 'Softplus'  # alternatives: 'Silu', 'Softplus'
valid_ratio = 0.001
patience = 20  # only used if the EarlyStopping line below is enabled
max_epochs = 100

net = FE_network(input_dim, width, depths, acti, spectral_norm)

accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
model = Estimator(net, run, accelerator=accelerator, weight_decay=weight_decay, lr=learning_rate)
# Optional (disabled) initialisation of the whitening layers from the untrained
# network's predictions:
# forces_pred = model.net.forward(torch.Tensor(dataset.cvs)).detach().cpu().numpy()
# mean_force = torch.Tensor(forces_pred.mean(axis=0, keepdims=True))
# std_force = torch.Tensor(forces_pred.std(axis=0, keepdims=True))
# net.whitening_layer_model.set_both(mean_force, std_force)
# net.whitening_layer_output.set_both(torch.Tensor(dataset.force_mean), torch.Tensor(dataset.force_std))

# Hold out a small validation split. Keep at least one sample so the validation
# DataLoader below gets a positive batch size.
n_val = max(1, int(len(dataset) * valid_ratio))
train_data, val_data = random_split(dataset, [len(dataset) - n_val, n_val])

loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Validate on the held-out split only. Using the full `dataset` here would mix
# training samples into the validation loss.
loader_val = DataLoader(val_data, batch_size=len(val_data), shuffle=False)

trainer = pl.Trainer(accelerator=accelerator, logger=model.neptune_logger, max_epochs=max_epochs)
# trainer.callbacks += [EarlyStopping(monitor="val_loss", mode="min", patience=patience)]
trainer.fit(model, loader_train, loader_val)

In [None]:
# Fine-tuning phase: the whole training split is one batch, so each epoch is a
# single optimiser step (presumably — assumes Estimator uses Lightning's
# automatic optimisation with one step per batch).
loader_train = DataLoader(train_data, batch_size=len(train_data), shuffle=True)
# max(1, ...) keeps the DataLoader valid even if the validation split is empty.
loader_val = DataLoader(val_data, batch_size=max(1, len(val_data)), shuffle=False)
updates = 20
# Ceiling division: the fewest epochs that give at least `updates` steps.
# `updates // n + 1` overshoots by a whole epoch whenever n divides `updates`
# (21 epochs instead of 20 here, where n == 1).
max_epochs = -(-updates // len(loader_train))
# model.set_learning_rate(0.01)
trainer = pl.Trainer(accelerator=accelerator, logger=model.neptune_logger, max_epochs=max_epochs)
# trainer.callbacks += [EarlyStopping(monitor="val_loss", mode="min", patience=patience)]
trainer.fit(model, loader_train, loader_val)

In [None]:
# plot free energy
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter

bins = 200

# Evaluation grid over the three dihedral angles (radians), flattened to one
# (omega, phi, psi) row per grid point. np.meshgrid uses its default 'xy'
# indexing, so the meshed arrays have axis 0 <-> phi, axis 1 <-> omega,
# axis 2 <-> psi; anything reshaped back to (bins, bins, bins) inherits that layout.
angles = np.linspace(-np.pi, np.pi, bins)
omega, phi, psi = (meshed.ravel() for meshed in np.meshgrid(angles, angles, angles))
grid = np.column_stack([omega, phi, psi])
print(grid.shape)

In [None]:
# compute free energy on the evaluation grid
# The grid has bins**3 points (8e6 for bins=200); evaluating them in one forward
# pass can exhaust (GPU) memory, so predict in chunks. The numpy `grid` is left
# untouched so this cell can be re-run.
chunk_size = 100_000
F_chunks = []
with torch.no_grad():
    grid_tensor = torch.as_tensor(grid, dtype=torch.float32)
    for start in range(0, grid_tensor.shape[0], chunk_size):
        grid_batch = grid_tensor[start:start + chunk_size].to(model.device)
        # assumes predict returns one value (row) per input row — TODO confirm
        F_chunks.append(model.net.predict(grid_batch).cpu())
F = torch.cat(F_chunks).numpy()

In [None]:
# Put the flat predictions back onto the (bins, bins, bins) grid, shift so the
# global minimum is zero, and report the resulting free-energy span.
F = np.subtract(F.reshape((bins,) * 3), F.min())
print(np.ptp(F))

In [None]:
# Three orthogonal 2-D slices through the free-energy grid at the grid indices
# in `reference_frame`. The grid was built with np.meshgrid's default 'xy'
# indexing, so F's axes are (phi, omega, psi); imshow draws array axis 0 along
# y and axis 1 along x.
reference_frame = np.array([5, 5, 5])
slices = [
    (F[:, :, reference_frame[2]], r'$\omega$', r'$\phi$', r'$\psi$', reference_frame[2]),
    (F[:, reference_frame[1], :], r'$\psi$', r'$\phi$', r'$\omega$', reference_frame[1]),
    (F[reference_frame[0], :, :], r'$\psi$', r'$\omega$', r'$\phi$', reference_frame[0]),
]
for image, x_label, y_label, fixed_angle, fixed_index in slices:
    plt.imshow(image, extent=[-np.pi, np.pi, -np.pi, np.pi], cmap=cm.viridis, origin='lower')
    plt.xlabel(x_label, fontsize=16)
    plt.ylabel(y_label, fontsize=16)
    plt.title(f'{fixed_angle} at grid index {fixed_index}')
    plt.colorbar()
    plt.show()

In [None]:
from scipy.special import logsumexp


def proj(F, axis=(0), T=300):
    """Marginalise a free-energy grid onto the axes not listed in ``axis``.

    The free energy is turned into Boltzmann weights exp(-beta * F), summed over
    ``axis`` and turned back into a free energy:
    ``F_proj = -log(sum(exp(-beta * F))) / beta``. The sum is done in log space
    with ``logsumexp``: large free energies would otherwise underflow to a
    probability of exactly 0 (giving ``F_proj = inf``), and large negative ones
    would overflow. Normalising the probabilities is unnecessary because it only
    shifts ``F_proj`` by a constant, which the final ``- min`` removes.

    Parameters
    ----------
    F : array_like
        Free energy on a grid, in kJ/mol (the unit implied by the gas constant below).
    axis : int or tuple of int
        Axis or axes to integrate out.
    T : float
        Temperature in K (default 300).

    Returns
    -------
    numpy.ndarray
        Projected free energy in kJ/mol, shifted so its minimum is 0.
    """
    R = 8.31446261815324e-3  # gas constant in kJ/(mol K)
    beta = 1 / (T * R)        # inverse thermal energy, in mol/kJ
    F_proj = -logsumexp(-beta * np.asarray(F), axis=axis) / beta
    return F_proj - F_proj.min()

In [None]:
np.shape(F)  # dimensions of the free-energy grid

In [None]:
# F's axes follow np.meshgrid's default 'xy' layout of the grid built from
# (omega, phi, psi): axis 0 -> phi, axis 1 -> omega, axis 2 -> psi.
# imshow draws array axis 0 along y and axis 1 along x. After one axis is summed
# out, the first remaining angle therefore goes on y and the second on x.
# (Hard-coded labels here were mismatched for summed axes 0 and 1.)
axis_labels = [r'$\phi$', r'$\omega$', r'$\psi$']
for summed_axis in range(3):
    y_label, x_label = [label for axis_index, label in enumerate(axis_labels) if axis_index != summed_axis]
    plt.imshow(proj(F, axis=summed_axis), extent=[-np.pi, np.pi, -np.pi, np.pi], cmap=cm.viridis, origin='lower')
    plt.xlabel(x_label, fontsize=16)
    plt.ylabel(y_label, fontsize=16)
    plt.colorbar()
    plt.show()

In [None]:
# Reference free-energy surface from the UFED simulation (ufedmm), used below as
# ground truth for the network estimate. Set `reestimate_fe = False` to load the
# cached result from 'fe_true.npy' instead of recomputing it.
reestimate_fe = True
if reestimate_fe:
    import ufedmm
    import pandas as pd
    nbins_fit = 10
    factor = 8

    platform = 'CPU'
    ufed = ufedmm.deserialize('ufed_object.yml')
    df = pd.read_csv('output.csv')

    # Mean temperature of the atoms and of each extended variable.
    print(df[['T[atoms] (K)'] + [f'T[{v.id}] (K)' for v in ufed.variables]].mean())

    analyzer = ufedmm.FreeEnergyAnalyzer(ufed, df)

    centers, mean_forces = analyzer.centers_and_mean_forces(nbins_fit)

    # Kernel width for the mean-force fit, as a multiple of the fitting-bin width
    # (presumably each variable is a dihedral spanning 2*pi — verify).
    delta = 2*np.pi/nbins_fit
    properties = {'Precision': 'mixed'} if platform == 'CUDA' else {}
    potential, mean_force = analyzer.mean_force_free_energy(
        centers, mean_forces, sigma=factor*delta, platform_name=platform, properties=properties,
    )

    # Evaluate on `bins` points per axis with np.meshgrid's default ('xy') layout,
    # the same layout used for the network grid, so `fe` and `F` line up
    # element-wise (assumes ufed.variables are ordered omega, phi, psi — TODO confirm).
    ranges = [(cv.min_value, cv.max_value) for cv in ufed.variables]
    x = [np.linspace(lower, upper, num=bins) for lower, upper in ranges]
    X = np.meshgrid(*x)
    Z = potential(*X)
    fe = Z - Z.min()
    np.save('fe_true.npy', fe)
else:
    fe = np.load('fe_true.npy')

In [None]:
(F.max(), F.min(), np.ptp(F))  # (max, min, span) of the network free energy

In [None]:
(fe.max(), fe.min(), np.ptp(fe))  # (max, min, span) of the reference free energy

In [None]:
def diff_fe(fe1, fe2, cutoff=80, plot=False):
    """RMS difference between two free-energy surfaces in their low-energy region.

    Both surfaces are shifted to a minimum of 0. Only grid points where BOTH lie
    below ``cutoff`` are compared. In that region each surface's mean is removed,
    because free energies are defined only up to an additive constant. The root
    mean square of the remaining difference is returned.

    Parameters
    ----------
    fe1, fe2 : array_like
        Free energies on the same grid (identical shapes).
    cutoff : float
        Points at or above this value (after shifting) in either surface are excluded.
    plot : bool
        If True, show the difference map over the compared region (other points
        are NaN). Arrays with more than two dimensions are averaged over their
        leading axes so ``imshow`` receives a 2-D image.

    Returns
    -------
    float
        RMS difference over the compared points.

    Raises
    ------
    ValueError
        If the shapes differ or no grid point is below ``cutoff`` in both surfaces.
    """
    fe1 = np.asarray(fe1)
    fe2 = np.asarray(fe2)
    if fe1.shape != fe2.shape:
        raise ValueError(f"shape mismatch: {fe1.shape} vs {fe2.shape}")
    fe1 = fe1 - fe1.min()
    fe2 = fe2 - fe2.min()
    mask = (fe1 < cutoff) & (fe2 < cutoff)
    n_points = mask.sum()
    if n_points == 0:
        raise ValueError(f"no grid point is below cutoff={cutoff} in both surfaces")
    fe1_mean_free = fe1 - fe1[mask].mean()
    fe2_mean_free = fe2 - fe2[mask].mean()
    if plot:
        import warnings
        diff_map = np.where(mask, fe1_mean_free - fe2_mean_free, np.nan)
        # Columns that are entirely outside the cutoff average to NaN; silence the
        # corresponding "Mean of empty slice" warning.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', RuntimeWarning)
            while diff_map.ndim > 2:
                diff_map = np.nanmean(diff_map, axis=0)
        plt.imshow(diff_map, extent=[-np.pi, np.pi, -np.pi, np.pi], cmap=cm.viridis, origin='lower')
        plt.colorbar()
        plt.show()
    return np.linalg.norm(fe1_mean_free[mask] - fe2_mean_free[mask]) / np.sqrt(n_points)

In [None]:
# RMS deviation between the reference (fe) and network (F) free energies below the cutoff.
print(diff_fe(fe, F, plot=False, cutoff=80))

In [None]:
# Interactive viewer: 2-D slices of a 3-D volume, selected with a slider.
# This uses random demo data under its own names (`demo_bins`, `demo_volume`)
# so it does not overwrite the notebook's `bins` and free-energy grid `F`,
# which later cells still use.
import numpy as np
import matplotlib.pyplot as plt
# Switches the backend for the rest of the session; needs a Tk-capable display.
plt.switch_backend('tkagg')
from matplotlib.widgets import Slider, Button
from matplotlib import cm


demo_bins = 100
demo_volume = np.random.randn(demo_bins, demo_bins, demo_bins)


def f(omega):
    """Return the 2-D slice of `demo_volume` at index `omega` along axis 0."""
    # Slider values may arrive as floats; array indices must be integers.
    return demo_volume[int(omega), :, :]


# Initial slice index.
init_omega = 0
fig, ax = plt.subplots()
image = ax.imshow(f(init_omega), extent=[-np.pi, np.pi, -np.pi, np.pi], cmap=cm.viridis, origin='lower')

# Make room below the image for the slider and the button.
fig.subplots_adjust(bottom=0.25)

# Horizontal slider that selects the slice index.
axfreq = fig.add_axes([0.25, 0.1, 0.65, 0.03])
omega_slider = Slider(
    ax=axfreq,
    label='Omega',
    valmin=0,
    valmax=demo_bins-1,
    valinit=init_omega,
    valstep=np.arange(demo_bins),
)


def update(val):
    """Redraw the image for the slider's current slice."""
    # AxesImage has no set_ydata (that is a Line2D method); set_data swaps the array.
    image.set_data(f(omega_slider.val))
    fig.canvas.draw_idle()


omega_slider.on_changed(update)

# Button that resets the slider to its initial value.
resetax = fig.add_axes([0.8, 0.025, 0.1, 0.04])
button = Button(resetax, 'Reset', hovercolor='0.975')


def reset(event):
    """Return the slider to its initial value."""
    omega_slider.reset()


button.on_clicked(reset)

plt.show()

In [None]:
# `diff` only exists as a (scalar) local inside diff_fe, so it is undefined here.
# Build the difference map explicitly instead: shift both surfaces to a minimum
# of 0 and subtract. The grids are 3-D, so average over leading axes to get the
# 2-D image that imshow needs.
if np.shape(fe) != np.shape(F):
    raise ValueError(f"fe {np.shape(fe)} and F {np.shape(F)} must be on the same grid")
fe_diff_map = (fe - fe.min()) - (F - F.min())
while fe_diff_map.ndim > 2:
    fe_diff_map = fe_diff_map.mean(axis=0)
plt.imshow(fe_diff_map, extent=[-np.pi, np.pi, -np.pi, np.pi], cmap=cm.viridis, origin='lower')
plt.colorbar()
plt.show()

In [None]:
# Network force predictions for every sample in the dataset. The input is sent to
# the model's device so this also works if the model was left on the GPU. No
# torch.no_grad() here: later cells call .detach(), and forward may itself rely
# on autograd — TODO confirm.
forces_pred = model.net.forward(torch.Tensor(dataset.cvs).to(model.device))

In [None]:
# Disabled: would apply net.whitening_layer_output to forces_pred. A later cell
# reads `forces_pred_test`, so that cell fails unless it is defined somewhere.
# forces_pred_test = net.whitening_layer_output(forces_pred)

In [None]:
# Sanity check: one predicted sample, the per-component mean and std of the
# predictions, and the force statistics stored on the dataset.
(
    forces_pred[1],
    forces_pred.mean(0),
    forces_pred.std(0),
    dataset.force_mean,
    dataset.force_std,
)

In [None]:
# `std_force` is not defined at this point, and `mean_force` may hold the unrelated
# second return value of ufedmm's mean_force_free_energy. Derive both from the
# current predictions: the per-component mean and population std of forces_pred,
# as (1, n_components) tensors.
# NOTE(review): confirm the output whitening layer should use prediction
# statistics rather than dataset.force_mean / dataset.force_std.
_forces_pred_np = (
    forces_pred.detach().cpu().numpy() if torch.is_tensor(forces_pred) else np.asarray(forces_pred)
)
mean_force = torch.Tensor(_forces_pred_np.mean(axis=0, keepdims=True))
std_force = torch.Tensor(_forces_pred_np.std(axis=0, keepdims=True))
net.whitening_layer_output.set_both(mean_force, std_force)

In [None]:
# `forces_pred_test` is not defined by any executed cell, so compute it here:
# pass the raw predictions through the output whitening layer, then compare one
# sample before and after.
forces_pred_test = net.whitening_layer_output(forces_pred)
forces_pred[1], forces_pred_test[1]

In [None]:
# Reference (target) forces stored on the dataset.
forces = getattr(dataset, "forces")

In [None]:
# Convert to numpy under a new name. Overwriting `forces_pred` with a numpy array
# made this cell fail when re-run (numpy arrays have no .detach()) and broke the
# earlier tensor-based cells.
forces_pred_np = forces_pred.detach().cpu().numpy()
plt.plot(forces_pred_np[:, 0])
plt.xlabel('sample index')
plt.ylabel('predicted force, component 0')
plt.show()

In [None]:
forces.mean(0)  # per-component mean of the reference forces

In [None]:
getattr(dataset, "force_std")  # per-component force std stored on the dataset

In [None]:
# Standardise the reference forces with the dataset's stored statistics. Despite
# the variable name, this divides by the std as well as removing the mean.
forces_mean_free = forces - dataset.force_mean
forces_mean_free = forces_mean_free / dataset.force_std

In [None]:
plot_stride = 100  # plot every 100th sample to keep the figure light
plt.plot(forces_mean_free[::plot_stride, 1])
plt.xlabel(f'sample index / {plot_stride}')
plt.ylabel('standardised force, component 1')
plt.show()

In [None]:
plot_stride = 100  # plot every 100th sample to keep the figure light
plt.plot(dataset.cvs[::plot_stride, 1])
plt.xlabel(f'sample index / {plot_stride}')
# Column 1 is presumably 'phi' (second CV passed to Dataset) — verify.
plt.ylabel('CV column 1')
plt.show()

In [None]:
import pandas as pd

# Raw simulation output with CVs and forces (note: reuses the name `df`).
df = pd.read_csv(filepath_or_buffer='./output_with_forces.csv')

In [None]:
df.head()  # preview only; avoid dumping the full frame into the notebook