# Analysis and Sampling of Molecular Simulations with Adversarial Autoencoders
---
1. [Packages import](#1.-Packages-import)
2. [Adversarial autoencoder](#2.-Adversarial-autoencoder)
3. [Intcoord](#3.-Intcoord)
4. [Force field](#4.-Force-field)
5. [Execution & visualization](#5.-Execution-&-visualization)

## 1. Packages import

In [None]:
# Import packages

from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras import backend as kb
import keras as krs

from scipy.stats import multivariate_normal, gaussian_kde
from tensorflow.keras.optimizers import Adam
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import mdtraj as md
import numpy as np
import logging

In [None]:
# Input files: reference conformation (PDB), trajectory (XTC) and the
# GROMACS topology that supplies atom types, charges and connectivity.

conf = "md1_2box.pdb"          # input conformation
traj = "md1_2fitskip20.xtc"    # input trajectory
topol = "topol.top"            # input topology

## 2. Adversarial autoencoder

In [None]:
# Define the adversarial autoencoder (the class keeps its historical name GAN).

class GAN():
    """Adversarial autoencoder (AAE) over internal-coordinate feature vectors.

    An encoder maps a feature vector of shape ``mol_shape`` to a
    ``latent_dim``-dimensional point and a decoder maps it back.  A
    discriminator learns to separate encoded points from samples of a
    standard-normal prior, and the encoder is additionally trained (through
    ``combined``) to make the discriminator label its outputs as prior samples.

    Parameters
    ----------
    mol_shape : tuple of int, optional
        Shape of one feature vector; must equal the length of the vectors
        produced by ``intcoord`` for the topology in use. Default ``(383,)``.
    latent_dim : int, optional
        Dimension of the latent space. Default 2.
    """

    def __init__(self, mol_shape=(383,), latent_dim=2):
        self.mol_shape = tuple(mol_shape)
        self.latent_dim = latent_dim
        optimizer = Adam(0.0002, 0.5)

        # Discriminator: compiled while trainable, so its own train_on_batch
        # calls keep updating its weights.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

        # Reconstruction path: features -> latent -> features.
        mol_inp = Input(shape=self.mol_shape)
        low = self.encoder(mol_inp)
        mol_out = self.decoder(low)
        self.autoencoder = Model(mol_inp, mol_out)
        self.autoencoder.compile(loss='mean_squared_error', optimizer=optimizer)

        # Adversarial path: freeze the discriminator *inside this model only*
        # (the trainable flag is captured at compile time) so that
        # combined.train_on_batch updates the encoder alone instead of also
        # teaching the discriminator to call encoded points "valid".
        self.discriminator.trainable = False
        validity = self.discriminator(low)
        self.combined = Model(mol_inp, validity)
        self.combined.compile(loss='mean_squared_error', optimizer=optimizer)

    def build_encoder(self):
        """Return the encoder: features -> 256 (sigmoid) -> 512 (relu) -> latent (linear)."""
        model = Sequential()
        model.add(Dense(256, input_dim=np.prod(self.mol_shape), activation="sigmoid"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.latent_dim, activation='linear'))
        model.summary(print_fn=logging.info)
        mol = Input(shape=self.mol_shape)
        lowdim = model(mol)
        return Model(mol, lowdim)

    def build_decoder(self):
        """Return the decoder: latent -> 1024 -> 512 -> 512 (relu) -> features (sigmoid)."""
        model = Sequential()
        model.add(Dense(1024, input_dim=self.latent_dim, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))
        # Sigmoid output matches the [0, 1]-scaled features produced by intcoord.
        model.add(Dense(np.prod(self.mol_shape), activation='sigmoid'))
        model.add(Reshape(self.mol_shape))
        model.summary(print_fn=logging.info)
        lowdim = Input(shape=(self.latent_dim,))
        mol = model(lowdim)
        return Model(lowdim, mol)

    def build_discriminator(self):
        """Return the discriminator: latent point -> probability it came from the prior."""
        model = Sequential()
        model.add(Flatten(input_shape=(self.latent_dim,)))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary(print_fn=logging.info)
        mol = Input(shape=(self.latent_dim,))
        validity = model(mol)
        return Model(mol, validity)

    @staticmethod
    def _decode_angle(cos_feature, sin_feature):
        """Recover an angle in [-pi, pi] from its ((cos+1)/2, (sin+1)/2) encoding.

        This inverts the dihedral encoding written by ``intcoord``.  The
        cosine argument is clipped to [-1, 1] so rounding cannot yield NaN.
        """
        value = np.arccos(np.clip(2.0 * cos_feature - 1.0, -1.0, 1.0))
        if (2.0 * sin_feature - 1.0) < 0.0:
            value = -value
        return value

    def sample_state(self, epoch):
        """Run a 10000-step Metropolis Monte Carlo walk in the latent space.

        Starting from a random N(0, 1) latent point, each step proposes a
        Gaussian move (sigma 0.5), decodes it, evaluates ``forcefield`` on the
        decoded features and accepts it by the Metropolis criterion.  Every
        step logs ``step, energy, acceptance factor, phi, psi``.

        ``epoch`` is unused; it is kept for interface compatibility.
        """
        # NOTE(review): feature indices 72/73 and 102/103 are presumably the
        # (cos, sin) features of the backbone phi/psi dihedrals for this
        # particular topology — verify if the topology changes.
        x = np.random.normal(0, 1, (1, self.latent_dim))
        mol = self.decoder.predict(x)
        phi = self._decode_angle(mol[0, 72], mol[0, 73])
        psi = self._decode_angle(mol[0, 102], mol[0, 103])
        pot = forcefield(mol[0])
        logging.info(pot)
        for i in range(10000):
            newx = x + np.random.normal(0, 0.5, (1, self.latent_dim))
            newmol = self.decoder.predict(newx)
            newphi = self._decode_angle(newmol[0, 72], newmol[0, 73])
            newpsi = self._decode_angle(newmol[0, 102], newmol[0, 103])
            newpot = forcefield(newmol[0])
            # kT = 8.314 * 0.3 — presumably R [J/(mol K)] * 300 K in kJ/mol (~2.494).
            metro = np.exp((pot - newpot) / 8.314 / 0.3)
            # Downhill moves are always accepted; the random number is drawn
            # only for uphill moves (short-circuit `or`).
            if newpot < pot or np.random.rand(1) < metro:
                x, pot, phi, psi = newx, newpot, newphi, newpsi
            logging.info(f"{i+1}, {pot}, {metro}, {phi}, {psi}")

    def train(self, epochs, batch_size=128):
        """Train on every frame of ``traj`` and write latent coordinates to ``lows.txt``.

        The trajectory is superposed onto the reference structure ``conf`` and
        each frame is converted to a feature vector with ``intcoord``.  Each of
        the ``epochs`` iterations draws random mini-batches (with replacement)
        and performs one reconstruction step, one adversarial encoder step and
        two discriminator steps (prior samples as real, encoded points as fake).
        Progress is logged every 100 iterations.  Call ``sample_state``
        afterwards to sample the trained latent space.

        Parameters
        ----------
        epochs : int
            Number of mini-batch iterations.
        batch_size : int, optional
            Mini-batch size. Default 128.

        Raises
        ------
        ValueError
            If the trajectory has no frames or the feature length does not
            match ``mol_shape``.
        """
        refpdb = md.load_pdb(conf)
        # mdtraj cannot read GROMACS .top files, so take the atom topology from
        # the reference PDB; superpose() below already requires both to
        # contain the same atoms.
        X = md.load(traj, top=refpdb)
        X.superpose(refpdb)
        if X.n_frames == 0:
            raise ValueError(f"trajectory {traj!r} contains no frames")
        X_train = np.array([intcoord(list(frame)) for frame in X.xyz])
        nexpected = int(np.prod(self.mol_shape))
        if X_train.shape[1] != nexpected:
            raise ValueError(
                f"intcoord produced {X_train.shape[1]} features per frame, "
                f"but the model expects mol_shape={self.mol_shape}"
            )
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # Batch A: its encodings are the "fake" latent samples.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            gen_lows = self.encoder.predict(X_train[idx])
            # Samples from the N(0, 1) prior are the "real" latent samples.
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Batch B: used for the reconstruction and adversarial encoder steps.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            mols = X_train[idx]
            ae_loss = self.autoencoder.train_on_batch(mols, mols)
            c_loss = self.combined.train_on_batch(mols, valid)
            d_loss_real = self.discriminator.train_on_batch(noise, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_lows, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            if epoch % 100 == 0:
                output = f"{epoch} [D loss: {d_loss[0]},acc.: {100*d_loss[1]}]" + \
                      f"[AE loss: {ae_loss}] [C loss: {c_loss}]"
                logging.info(output)
        newlows = self.encoder.predict(X_train)
        np.savetxt("lows.txt", newlows)


In [None]:
# Configure logging (file "gan.log" plus console) and define helper functions.

# Drop any handlers a previous run of this cell installed, so basicConfig
# below actually takes effect.
logging.getLogger().handlers = []
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%H:%M:%S',
    handlers=[
        logging.FileHandler("gan.log", mode="w"),
        logging.StreamHandler(),
    ],
)

def dist(a1, a2):
    """Return the Euclidean distance ``|a1 - a2|`` between two NumPy points."""
    separation = a1 - a2
    return np.linalg.norm(separation)


def angle(a1, a2, a3):
    """Return the bond angle a1-a2-a3 (vertex at a2) in radians, in [0, pi].

    The cosine is clipped to [-1, 1] before ``np.arccos``.  For (nearly)
    collinear atoms, rounding can push it slightly outside that range, and
    ``arccos`` would then return NaN.
    """
    v1 = a1 - a2
    v2 = a3 - a2
    cosine = np.dot(v1, v2) / np.linalg.norm(v1) / np.linalg.norm(v2)
    return np.arccos(np.clip(cosine, -1.0, 1.0))


def dihedral(a1, a2, a3, a4):
    """Return the dihedral angle a1-a2-a3-a4 in radians, in [-pi, pi].

    Uses the atan2 formulation.  ``n1`` and ``n2`` are the unit normals of the
    planes (a1, a2, a3) and (a2, a3, a4).  The cosine is ``n1 . n2`` and the
    sine is ``(n1 x n2) . b2_hat``, where ``b2`` is the central bond a2->a3.
    Collinear atoms make a normal undefined and give NaN.
    """
    b1 = a2 - a1
    b2 = a3 - a2
    b3 = a4 - a3
    n1 = np.cross(b1, b2)
    n1 = n1 / np.linalg.norm(n1)
    n2 = np.cross(b2, b3)
    n2 = n2 / np.linalg.norm(n2)
    b2_unit = b2 / np.linalg.norm(b2)
    x = np.dot(n1, n2)
    y = np.dot(np.cross(n1, n2), b2_unit)
    return np.arctan2(y, x)


def parse_topology_file(topology_file):
    """Read atoms, bonds, pairs, angles and dihedrals from a GROMACS topology.

    Only the ``[ atoms ]``, ``[ bonds ]``, ``[ pairs ]``, ``[ angles ]`` and
    ``[ dihedrals ]`` directives are read.  Any other directive is skipped,
    and so is everything after an ``#ifdef`` line up to the next directive.
    Comments (``;`` to end of line) are ignored, and headers may use any
    spacing inside the brackets (``[atoms]`` or ``[ atoms ]``).  Atom
    indices are converted from 1-based to 0-based.

    Parameters
    ----------
    topology_file : str or path-like
        Path of the ``.top`` file.

    Returns
    -------
    atoms : list of [str, float]
        ``[atom type, charge]`` for each 8-column ``[ atoms ]`` line.
    bonds, pairs : list of [int, int]
        From 3-column lines (two atoms plus function type).
    angles : list of [int, int, int]
        From 4-column lines.
    dihedrals : list of [int, int, int, int, int]
        Four atom indices followed by the (unchanged) function type, from
        5-column lines.
    """
    atoms = []
    bonds = []
    pairs = []
    angles = []
    dihedrals = []
    section = None

    with open(topology_file, "r", encoding="utf-8", errors="replace") as topol_file:
        for line in topol_file:
            if line.lstrip().startswith("#ifdef"):
                # Conditional blocks (e.g. position restraints) are not parsed.
                section = None
                continue
            fields = line.split(";", 1)[0].split()
            if not fields:
                continue
            header = "".join(fields)
            if header.startswith("[") and header.endswith("]"):
                section = header[1:-1]
                continue

            if section == "atoms":
                if len(fields) == 8:
                    atoms.append([fields[1], float(fields[6])])
            elif section == "bonds":
                if len(fields) == 3:
                    bonds.append([int(fields[0]) - 1, int(fields[1]) - 1])
            elif section == "pairs":
                if len(fields) == 3:
                    pairs.append([int(fields[0]) - 1, int(fields[1]) - 1])
            elif section == "angles":
                if len(fields) == 4:
                    angles.append([int(f) - 1 for f in fields[:3]])
            elif section == "dihedrals":
                if len(fields) == 5:
                    dihedrals.append([int(f) - 1 for f in fields[:4]] + [int(fields[4])])
    return atoms, bonds, pairs, angles, dihedrals


## 3. Intcoord

In [None]:
def intcoord(coords, topology_file=None):
    """Convert Cartesian atom positions into the model's scaled feature vector.

    Features are emitted in this fixed order.  ``forcefield`` reads them back
    in the same order, so the two functions must stay in sync:

    1. every topology angle, scaled as ``(theta - 0.75*theta0) / (0.5*theta0)``
       with the reference angle ``theta0`` taken from ``angletypes``;
    2. for each funct-4 dihedral: ``(cos+1)/2`` and ``(sin+1)/2`` once;
       for each funct-9 dihedral: the same pair once per matching entry of
       ``dihetypesp`` (multi-term torsions therefore repeat the pair);
    3. every interatomic distance ``d(i, j)`` with ``j < i``, divided by
       ``maxdist``.

    Parameters
    ----------
    coords : sequence of array-like
        Per-atom positions indexable by the topology's 0-based atom indices
        (presumably nm, as loaded by mdtraj — confirm for other callers).
    topology_file : str, optional
        GROMACS topology to read; defaults to the module-level ``topol``.

    Returns
    -------
    list of float
        The feature vector.

    Raises
    ------
    ValueError
        If a topology angle has no entry in ``angletypes``.  Its reference
        angle would otherwise be 0, and the scaling would divide by zero.
    """
    angletypes = [
        ["CT", "C", "O", 120.400, 669.440],
        ["CT", "C", "N", 116.600, 585.760],
        ["N", "C", "O", 122.900, 669.440],
        ["CT", "CT", "HC", 109.500, 418.400],
        ["CT", "CT", "H1", 109.500, 418.400],
        ["CT", "CT", "N", 109.700, 669.440],
        ["HC", "CT", "HC", 109.500, 292.880],
        ["H1", "CT", "N", 109.500, 418.400],
        ["H1", "CT", "H1", 109.500, 292.880],
        ["C", "CT", "HC", 109.500, 418.400],
        ["C", "CT", "N", 110.100, 527.184],
        ["C", "CT", "CT", 111.100, 527.184],
        ["C", "CT", "H1", 109.500, 418.400],
        ["C", "N", "CT", 121.900, 418.400],
        ["C", "N", "H", 120.000, 418.400],
        ["CT", "N", "H", 118.040, 418.400],
    ]

    dihetypesp = [
        ["C", "N", "CT", "C", 9, 0.0, 1.12968, 2.0],
        ["C", "N", "CT", "C", 9, 0.0, 1.75728, 3.0],
        ["N", "CT", "C", "N", 9, 180.0, 1.88280, 1.0],
        ["N", "CT", "C", "N", 9, 180.0, 6.61072, 2.0],
        ["N", "CT", "C", "N", 9, 180.0, 2.30120, 3.0],
        ["CT", "CT", "N", "C", 9, 0.0, 8.36800, 1.0],
        ["CT", "CT", "N", "C", 9, 0.0, 8.36800, 2.0],
        ["CT", "CT", "N", "C", 9, 0.0, 1.67360, 3.0],
        ["CT", "CT", "C", "N", 9, 0.0, 0.83680, 1.0],
        ["CT", "CT", "C", "N", 9, 0.0, 0.83680, 2.0],
        ["CT", "CT", "C", "N", 9, 0.0, 1.67360, 3.0],
        ["H", "N", "C", "O", 9, 180.0, 10.46000, 2.0],
        ["H", "N", "C", "O", 9, 0.0, 8.36800, 1.0],
        ["H1", "CT", "C", "O", 9, 0.0, 3.34720, 1.0],
        ["H1", "CT", "C", "O", 9, 180.0, 0.33472, 3.0],
        ["HC", "CT", "C", "O", 9, 0.0, 3.34720, 1.0],
        ["HC", "CT", "C", "O", 9, 180.0, 0.33472, 3.0],
        ["HC", "CT", "CT", "HC", 9, 0.0, 0.62760, 3.0],
        ["N", "C", "CT", "HC", 9, 0.0, 0.00000, 0.0],
        ["CT", "C", "N", "H", 9, 180.0, 10.46000, 2.0],
        ["CT", "C", "N", "CT", 9, 180.0, 10.46000, 2.0],
        ["O", "C", "N", "CT", 9, 180.0, 10.46000, 2.0],
        ["H1", "CT", "N", "C", 9, 0.0, 0.00000, 0.0],
        ["H1", "CT", "N", "H", 9, 0.0, 0.00000, 0.0],
        ["CT", "CT", "N", "H", 9, 0.0, 0.00000, 0.0],
        ["C", "CT", "N", "H", 9, 0.0, 0.00000, 0.0],
        ["N", "CT", "CT", "HC", 9, 0.0, 0.65084, 3.0],
        ["H1", "CT", "CT", "HC", 9, 0.0, 0.65084, 3.0],
        ["C", "CT", "CT", "HC", 9, 0.0, 0.65084, 3.0],
        ["O", "C", "CT", "N", 9, 0.0, 0.00000, 0.0],
        ["N", "C", "CT", "H1", 9, 0.0, 0.00000, 0.0],
        ["O", "C", "CT", "CT", 9, 0.0, 0.00000, 0.0],
    ]

    if topology_file is None:
        topology_file = topol
    atoms, bonds, pairs, angles, dihedrals = parse_topology_file(topology_file)
    natoms = len(atoms)

    output = []

    # 1. Angles, scaled relative to their reference value a0 (radians).
    for triple in angles:
        type1, type2, type3 = (atoms[k][0] for k in triple)
        a0 = 0.0
        for onetype in angletypes:
            # Angle types are symmetric: X-Y-Z also matches Z-Y-X.
            if onetype[1] == type2 and (onetype[0], onetype[2]) in (
                (type1, type3),
                (type3, type1),
            ):
                a0 = np.pi * onetype[3] / 180.0
        if a0 == 0.0:
            raise ValueError(
                f"no reference angle for atom types {type1}-{type2}-{type3} "
                f"(atoms {triple[0] + 1}-{triple[1] + 1}-{triple[2] + 1})"
            )
        aa = angle(coords[triple[0]], coords[triple[1]], coords[triple[2]])
        output.append((aa - 0.75 * a0) / 0.5 / a0)

    # 2. Dihedrals as (cos+1)/2, (sin+1)/2 pairs.
    for quad in dihedrals:
        names = tuple(atoms[k][0] for k in quad[:4])
        funct = quad[4]
        if funct == 4:
            nterms = 1
        elif funct == 9:
            # One feature pair per matching (multi-term) proper-torsion entry,
            # matched in either direction; forcefield consumes them the same way.
            nterms = sum(
                1 for onetype in dihetypesp
                if tuple(onetype[:4]) in (names, names[::-1])
            )
        else:
            continue
        dihe = dihedral(
            coords[quad[0]], coords[quad[1]], coords[quad[2]], coords[quad[3]]
        )
        cos_feature = (np.cos(dihe) + 1.0) / 2.0
        sin_feature = (np.sin(dihe) + 1.0) / 2.0
        for _ in range(nterms):
            output.append(cos_feature)
            output.append(sin_feature)

    # 3. All pairwise distances, lower triangle in row order.
    maxdist = 1.0
    for i in range(natoms):
        for j in range(i):
            output.append(dist(coords[i], coords[j]) / maxdist)

    return output


## 4. Force field

In [None]:
def forcefield(intcoords, topology_file=None):
    """Evaluate an Amber-style potential energy from an internal-coordinate vector.

    ``intcoords`` must be laid out exactly as ``intcoord`` produces it:

    1. scaled angles;
    2. (cos, sin) feature pairs for each dihedral term;
    3. all pairwise distances ``d(i, j)`` for ``j < i``.

    The energy sums:

    - harmonic angle bending;
    - improper (funct 4) and proper (funct 9) periodic torsions;
    - short-range and scaled 1-4 Coulomb and Lennard-Jones terms.

    Bond stretching is not evaluated; ``bonds`` are used only for exclusions.

    Parameters
    ----------
    intcoords : sequence of float
        Feature vector (e.g. a decoder output row).
    topology_file : str, optional
        GROMACS topology to read; defaults to the module-level ``topol``.

    Returns
    -------
    float
        Total potential energy (presumably kJ/mol, following the GROMACS
        parameter units used in the tables — confirm).

    Raises
    ------
    ValueError
        If an atom type in the topology has no Lennard-Jones parameters.
    """
    # Lennard-Jones parameters per atom type: name, sigma, epsilon.
    ljtypes = [
        ["C", 0.339967, 0.359824],
        ["CT", 0.339967, 0.457730],
        ["H", 0.106908, 0.0656888],
        ["HC", 0.264953, 0.0656888],
        ["H1", 0.247135, 0.0656888],
        ["N", 0.325000, 0.711280],
        ["O", 0.295992, 0.878640],
    ]

    # Angle types: three atom types, reference angle (degrees), force constant.
    angletypes = [
        ["CT", "C", "O", 120.400, 669.440],
        ["CT", "C", "N", 116.600, 585.760],
        ["N", "C", "O", 122.900, 669.440],
        ["CT", "CT", "HC", 109.500, 418.400],
        ["CT", "CT", "H1", 109.500, 418.400],
        ["CT", "CT", "N", 109.700, 669.440],
        ["HC", "CT", "HC", 109.500, 292.880],
        ["H1", "CT", "N", 109.500, 418.400],
        ["H1", "CT", "H1", 109.500, 292.880],
        ["C", "CT", "HC", 109.500, 418.400],
        ["C", "CT", "N", 110.100, 527.184],
        ["C", "CT", "CT", 111.100, 527.184],
        ["C", "CT", "H1", 109.500, 418.400],
        ["C", "N", "CT", 121.900, 418.400],
        ["C", "N", "H", 120.000, 418.400],
        ["CT", "N", "H", 118.040, 418.400],
    ]

    # Dihedral types: four atom types, funct, phase (degrees), k, multiplicity.
    dihetypesi = [
        ["C", "CT", "N", "H", 4, 180.0, 4.60240, 2.0],
        ["C", "CT", "N", "O", 4, 180.0, 4.60240, 2.0],
        ["CT", "N", "C", "O", 4, 180.0, 43.93200, 2.0],
    ]

    dihetypesp = [
        ["C", "N", "CT", "C", 9, 0.0, 1.12968, 2.0],
        ["C", "N", "CT", "C", 9, 0.0, 1.75728, 3.0],
        ["N", "CT", "C", "N", 9, 180.0, 1.88280, 1.0],
        ["N", "CT", "C", "N", 9, 180.0, 6.61072, 2.0],
        ["N", "CT", "C", "N", 9, 180.0, 2.30120, 3.0],
        ["CT", "CT", "N", "C", 9, 0.0, 8.36800, 1.0],
        ["CT", "CT", "N", "C", 9, 0.0, 8.36800, 2.0],
        ["CT", "CT", "N", "C", 9, 0.0, 1.67360, 3.0],
        ["CT", "CT", "C", "N", 9, 0.0, 0.83680, 1.0],
        ["CT", "CT", "C", "N", 9, 0.0, 0.83680, 2.0],
        ["CT", "CT", "C", "N", 9, 0.0, 1.67360, 3.0],
        ["H", "N", "C", "O", 9, 180.0, 10.46000, 2.0],
        ["H", "N", "C", "O", 9, 0.0, 8.36800, 1.0],
        ["H1", "CT", "C", "O", 9, 0.0, 3.34720, 1.0],
        ["H1", "CT", "C", "O", 9, 180.0, 0.33472, 3.0],
        ["HC", "CT", "C", "O", 9, 0.0, 3.34720, 1.0],
        ["HC", "CT", "C", "O", 9, 180.0, 0.33472, 3.0],
        ["HC", "CT", "CT", "HC", 9, 0.0, 0.62760, 3.0],
        ["N", "C", "CT", "HC", 9, 0.0, 0.00000, 0.0],
        ["CT", "C", "N", "H", 9, 180.0, 10.46000, 2.0],
        ["CT", "C", "N", "CT", 9, 180.0, 10.46000, 2.0],
        ["O", "C", "N", "CT", 9, 180.0, 10.46000, 2.0],
        ["H1", "CT", "N", "C", 9, 0.0, 0.00000, 0.0],
        ["H1", "CT", "N", "H", 9, 0.0, 0.00000, 0.0],
        ["CT", "CT", "N", "H", 9, 0.0, 0.00000, 0.0],
        ["C", "CT", "N", "H", 9, 0.0, 0.00000, 0.0],
        ["N", "CT", "CT", "HC", 9, 0.0, 0.65084, 3.0],
        ["H1", "CT", "CT", "HC", 9, 0.0, 0.65084, 3.0],
        ["C", "CT", "CT", "HC", 9, 0.0, 0.65084, 3.0],
        ["O", "C", "CT", "N", 9, 0.0, 0.00000, 0.0],
        ["N", "C", "CT", "H1", 9, 0.0, 0.00000, 0.0],
        ["O", "C", "CT", "CT", 9, 0.0, 0.00000, 0.0],
    ]

    if topology_file is None:
        topology_file = topol
    atoms, bonds, pairs, angles, dihedrals = parse_topology_file(topology_file)
    natoms = len(atoms)

    ii = 0  # read cursor into intcoords; advances in intcoord's output order

    # --- Angle bending ---------------------------------------------------
    vangles = 0.0
    for triple in angles:
        type1, type2, type3 = (atoms[k][0] for k in triple)
        a0 = 0.0
        kk = 0.0
        for onetype in angletypes:
            # Angle types are symmetric: X-Y-Z also matches Z-Y-X.
            if onetype[1] == type2 and (onetype[0], onetype[2]) in (
                (type1, type3),
                (type3, type1),
            ):
                a0 = onetype[3]
                kk = onetype[4]
        # Undo intcoord's scaling; a0 is in degrees here, so aa is too, and
        # pi^2/180^2 converts the squared deviation to rad^2.
        aa = intcoords[ii] * 0.5 * a0 + 0.75 * a0
        ii += 1
        vangles = vangles + 0.5 * kk * np.pi ** 2 * (aa - a0) ** 2 / 180.0 / 180.0

    # --- Torsions ----------------------------------------------------------
    # Only the cosine feature is decoded; the angle's sign is irrelevant for
    # cos(per*phi + shift) with shift in {0, 180} degrees.
    vtorsionsp = 0.0
    vtorsionsi = 0.0
    for quad in dihedrals:
        names = tuple(atoms[k][0] for k in quad[:4])
        funct = quad[4]
        if funct == 4:
            # Improper: last matching type wins; always consumes one feature pair.
            shift = kk = per = 0.0
            for onetype in dihetypesi:
                if tuple(onetype[:4]) in (names, names[::-1]):
                    shift, kk, per = onetype[5], onetype[6], onetype[7]
            dihe = np.arccos(np.clip(2.0 * intcoords[ii] - 1.0, -1.0, 1.0))
            ii += 2
            vtorsionsi = vtorsionsi + kk * (
                1.0 + np.cos(per * dihe + np.pi * shift / 180.0)
            )
        elif funct == 9:
            # Proper: one feature pair and one energy term per matching entry.
            for onetype in dihetypesp:
                if tuple(onetype[:4]) in (names, names[::-1]):
                    shift, kk, per = onetype[5], onetype[6], onetype[7]
                    dihe = np.arccos(np.clip(2.0 * intcoords[ii] - 1.0, -1.0, 1.0))
                    ii += 2
                    vtorsionsp = vtorsionsp + kk * (
                        1.0 + np.cos(per * dihe + np.pi * shift / 180.0)
                    )

    # --- Non-bonded ----------------------------------------------------------
    ljparams = {name: (sig, eps) for name, sig, eps in ljtypes}
    unknown = sorted({atom[0] for atom in atoms} - set(ljparams))
    if unknown:
        raise ValueError(
            f"no Lennard-Jones parameters for atom type(s): {', '.join(unknown)}"
        )

    # Exclusion/scaling lookups keyed by unordered atom pair.  Precedence is
    # angle ends (1-3) > listed pairs (1-4) > bonds (1-2) > everything else.
    angle_ends = {frozenset((triple[0], triple[2])) for triple in angles}
    pairs14 = {frozenset(pair[:2]) for pair in pairs}
    bonded = {frozenset(pair[:2]) for pair in bonds}

    maxdist = 1.0
    coulsr = 0.0
    coul14 = 0.0
    vdwsr = 0.0
    vdw14 = 0.0
    for i in range(natoms):
        sigmai, epsiloni = ljparams[atoms[i][0]]
        for j in range(i):
            sigmaj, epsilonj = ljparams[atoms[j][0]]
            # Lorentz-Berthelot combination rules.
            sigma = (sigmai + sigmaj) / 2.0
            epsilon = np.sqrt(epsiloni * epsilonj)
            key = frozenset((i, j))
            if key in angle_ends:
                intersr, inter14, interljsr, interlj14 = 0.0, 0.0, 0.0, 0.0
            elif key in pairs14:
                # Scaled 1-4 interactions (0.8333 Coulomb, 0.5 LJ).
                intersr, inter14, interljsr, interlj14 = 0.0, 0.8333, 0.0, 0.5
            elif key in bonded:
                intersr, inter14, interljsr, interlj14 = 0.0, 0.0, 0.0, 0.0
            else:
                intersr, inter14, interljsr, interlj14 = 1.0, 0.0, 1.0, 0.0
            dd = intcoords[ii] * maxdist
            ii += 1
            # 138.9354859 is presumably the Coulomb constant in GROMACS units.
            coulsr = coulsr + 138.9354859 * intersr * atoms[i][1] * atoms[j][1] / dd
            coul14 = coul14 + 138.9354859 * inter14 * atoms[i][1] * atoms[j][1] / dd
            lj = sigma ** 12 / dd ** 12 - sigma ** 6 / dd ** 6
            vdwsr = vdwsr + 4.0 * interljsr * epsilon * lj
            vdw14 = vdw14 + 4.0 * interlj14 * epsilon * lj

    vpot = coulsr + coul14 + vdwsr + vdw14 + vangles + vtorsionsp + vtorsionsi
    return vpot

## 5. Execution & visualization

In [None]:
# Build the adversarial autoencoder and train it; GAN.train writes the
# encoded (latent) coordinates of every trajectory frame to lows.txt.
# train() returns nothing, so its result is not bound.

if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=10000, batch_size=132)


In [None]:
# Scatter plot of the latent coordinates written by GAN.train (lows.txt):
# first column on the x axis, second column on the y axis.

with open("lows.txt") as ifile:
    rows = [line.split() for line in ifile]
x = [float(fields[0]) for fields in rows]
y = [float(fields[1]) for fields in rows]

plt.figure(figsize=(100, 50), dpi=100)
plt.scatter(x, y, c='r', marker='x', label='lows', s=3000)
plt.legend(loc='upper left', prop={'size': 80})

plt.show()

In [None]:
# Visualization of the low-dimensional (latent) space: a KDE density map of the
# encoded frames, then the same points coloured by each reference observable.

# define input files
lows = np.loadtxt('lows.txt')
rama_ala = np.loadtxt('rama_ala_reduced.txt', usecols=(0, 1))
angever1 = np.loadtxt('angever1.txt')
angever2 = np.loadtxt('angever2.txt')
angever3 = np.loadtxt('angever3.txt')


cvs = (lows[:, 0], lows[:, 1])
analysis_files = {
    'rama0': rama_ala[:, 0],
    'rama1': rama_ala[:, 1],
    'ang1': angever1[:, 1],
    'ang2': angever2[:, 1],
    'ang3': angever3[:, 1],
}

# set multiplot parameters
# The encoder ends in a linear layer and is trained against an N(0, 1) prior,
# so latent coordinates are not confined to [0, 1]. Derive the plotting window
# from the data instead, padded by 5 % (or by 0.05 if a coordinate is constant).
xmin, xmax = float(np.min(cvs[0])), float(np.max(cvs[0]))
ymin, ymax = float(np.min(cvs[1])), float(np.max(cvs[1]))
xpad = 0.05 * (xmax - xmin) or 0.05
ypad = 0.05 * (ymax - ymin) or 0.05
xmin, xmax = xmin - xpad, xmax + xpad
ymin, ymax = ymin - ypad, ymax + ypad

# "seaborn-white" was renamed "seaborn-v0_8-white" in Matplotlib 3.6 and the
# old name was removed in 3.8; use whichever name this installation provides.
plt.style.use(
    "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available
    else "seaborn-white"
)
fig = plt.figure(figsize=(20, 12))
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([X.ravel(), Y.ravel()])
values = np.vstack([cvs[0], cvs[1]])
kernel = gaussian_kde(values)
dens = np.reshape(kernel(positions).T, X.shape)


# plot first graph: KDE density (rot90 puts the y axis upright for imshow)
ax1 = plt.subplot(2, 3, 1)
ax1.imshow(np.rot90(dens), cmap="hsv", extent=[xmin, xmax, ymin, ymax],
           aspect="auto")
ax1.set_ylabel('CV2')
ax1.set_xlabel('CV1')


# plot every other graph: latent points coloured by each observable
for i, (name, data) in enumerate(analysis_files.items(), start=2):
    ax = plt.subplot(2, 3, i)
    ax.set_xlim([xmin, xmax])
    ax.set_ylim([ymin, ymax])
    ax.set_title(name)
    ax.set_ylabel('CV2')
    ax.set_xlabel('CV1')
    ax.scatter(cvs[0], cvs[1], s=5, c=data, cmap="hsv")

# A figure-level title; plt.title here would overwrite the last panel's title.
fig.suptitle("Low Dimensional Space - Analysis")
plt.savefig('analysis.png')