# *Data-driven collective variables for enhanced sampling* - **Training**

This notebook contains the code used in the paper "Data-driven collective variables for enhanced sampling" by Bonati, Rizzi and Parrinello (2019).

# Setup and methods

### Modules

In [0]:
!pip3 install torch==1.1.0 torchvision==0.3.0 -f https://download.pytorch.org/whl/torch_stable.html

In [0]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

In [0]:
import os
import numpy as np
import scipy 
import matplotlib.pyplot as plt
import pandas as pd
import itertools
import progressbar

## Datasets

In [0]:
class ColvarDataset(Dataset):
    """Dataset wrapping several index-aligned sequences.

    ``colvar_list`` is a list of sequences (in this notebook: inputs, labels
    and weights). Item ``idx`` is a tuple holding the ``idx``-th entry of each
    sequence, in the order the sequences were given.
    """

    def __init__(self, colvar_list):
        self.colvar = colvar_list
        self.nstates = len(colvar_list)

    def __len__(self):
        # The first sequence defines the length of the dataset.
        first = self.colvar[0]
        return len(first)

    def __getitem__(self, idx):
        return tuple(self.colvar[k][idx] for k in range(self.nstates))
    
#useful for cycling over the test dataset even if it is smaller than the training set
def cycle(iterable):
    """Yield the items of ``iterable`` endlessly, restarting after each full pass."""
    while True:
        yield from iterable

## NN architecture

In [0]:
##################################
# Define Networks
##################################

class NN_DeepLDA(nn.Module):
    """Feed-forward encoder followed by a fixed linear projection.

    ``l`` lists the layer widths, input first. Every ``nn.Linear`` except the
    last one is followed by an in-place ReLU. The projection vector is not
    trained; it is stored with ``set_lda`` and applied in ``forward``.
    """

    def __init__(self, l):
        super(NN_DeepLDA, self).__init__()

        # Encoder architecture: one Linear per consecutive pair of widths,
        # with a ReLU after every layer but the last. The layout is printed
        # while it is built.
        n_layers = len(l) - 1
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(l[:-1], l[1:])):
            print(n_in, ' --> ', n_out, end=' ')
            layers.append(nn.Linear(n_in, n_out))
            if idx == n_layers - 1:
                print("")
            else:
                layers.append(nn.ReLU(True))
                print("(relu)")

        self.nn = nn.Sequential(*layers)

        # Input normalization stays off until set_norm() is called.
        self.normIn = False

    def set_norm(self, Mean: torch.Tensor, Range: torch.Tensor):
        """Enable input normalization with the given per-feature Mean and Range."""
        self.Mean = Mean
        self.Range = Range
        self.normIn = True

    def normalize(self, x: Variable):
        """Return ``(x - Mean) / Range`` with Mean/Range expanded over the batch."""
        target_shape = (x.size(0), x.size(1))

        mean = self.Mean.unsqueeze(0).expand(*target_shape)
        rng = self.Range.unsqueeze(0).expand(*target_shape)

        return x.sub(mean).div(rng)

    def get_hidden(self, x: Variable, svd=False, svd_vectors=False, svd_eigen=False, training=False) -> (Variable):
        """Encoder output for ``x`` (normalized first if set_norm was called).

        The keyword arguments other than ``x`` are accepted but not used.
        """
        inputs = self.normalize(x) if self.normIn else x
        return self.nn(inputs)

    def set_lda(self, x: torch.Tensor):
        """Store ``x`` as a non-trainable parameter with a leading axis of size 1."""
        self.lda = nn.Parameter(x.unsqueeze(0), requires_grad=False)

    def get_lda(self) -> (torch.Tensor):
        """Return the stored projection (available only after set_lda)."""
        return self.lda

    def apply_lda(self, x: Variable) -> (Variable):
        """Project ``x`` onto the stored vector using a bias-free linear map."""
        return torch.nn.functional.linear(x, self.lda)

    def forward(self, x: Variable) -> (Variable):
        """Encode ``x`` and project the hidden representation."""
        hidden = self.get_hidden(x, svd=False)
        return self.apply_lda(hidden)

    def get_cv(self, x: Variable) -> (Variable):
        """Alias of ``forward``."""
        return self.forward(x)
    
    

## Loss function

In [0]:
# -- loss function --
def LDAloss_cholesky(H, label, test_routines=False):
    """Deep-LDA loss: minus the largest generalized eigenvalue of (S_b, S_w).

    The generalized problem ``S_b v = lambda S_w v`` is reduced to a standard
    symmetric one through the Cholesky factor ``L`` of the regularized S_w:
    ``(L^-1 S_b L^-T) w = lambda w`` with ``v = L^-T w``.

    Reads two module-level globals: ``categ`` (number of classes, labelled
    ``0 .. categ-1``) and ``lambdA`` (value added to the diagonal of S_w).

    Args:
        H: 2-D tensor of hidden representations, one row per sample.
        label: 1-D tensor of class indices aligned with the rows of ``H``.
        test_routines: accepted for compatibility; not used.

    Returns:
        ``(loss, eig_values, max_eig_vector, S_b, S_w)`` where ``eig_values``
        are in the order returned by ``torch.symeig``, ``loss`` is minus the
        last of them, ``max_eig_vector`` is the corresponding generalized
        eigenvector scaled to unit norm, and ``S_w`` includes the ``lambdA``
        regularization.
    """
    #sizes
    N, d = H.shape

    # Mean centered observations for entire population
    H_bar = H - torch.mean(H, 0, True)
    #Total scatter matrix (cov matrix over all observations)
    S_t = H_bar.t().matmul(H_bar) / (N - 1)

    # Within scatter matrix, allocated on the same device / dtype as H so the
    # function does not depend on the notebook-level `device` and `dtype`.
    S_w = H.new_zeros((d, d))

    #Loop over classes to compute means and covs
    for i in range(categ):
        #check which elements belong to class i
        H_i = H[torch.nonzero(label == i).view(-1)]
        # count number of elements
        N_i = H_i.shape[0]
        # A class needs at least two samples: with none there is nothing to
        # add, and with one the (N_i - 1) denominator would be zero and turn
        # the whole matrix into NaN.
        if N_i < 2:
            continue
        # compute mean centered obs of class i
        H_i_bar = H_i - torch.mean(H_i, 0, True)

        #LDA
        S_w += H_i_bar.t().matmul(H_i_bar) / ((N_i - 1) * categ)

    # HLDA variant (disabled in the original code): accumulate the
    # pseudo-inverse of each class covariance instead, and use the
    # pseudo-inverse of that sum as S_w.

    S_b = S_t - S_w

    # Regularize S_w so that the Cholesky factorization below is well defined.
    S_w = S_w + lambdA * torch.eye(d, device=H.device, dtype=H.dtype)

    ## Generalized eigenvalue problem: S_b * v_i = lambda_i * Sw * v_i

    # (1) use cholesky decomposition for S_w
    # torch.cholesky / torch.symeig are the APIs of the torch 1.1.0 release
    # installed at the top of this notebook.
    L = torch.cholesky(S_w,upper=False)

    # (2) define new matrix using cholesky decomposition
    L_t = torch.t(L)
    L_ti = torch.inverse(L_t)
    L_i = torch.inverse(L)
    S_new = torch.matmul(torch.matmul(L_i,S_b),L_ti)

    # (3) solve  S_new * w_i = lambda_i * w_i
    eig_values, eig_vectors = torch.symeig(S_new,eigenvectors=True)
    # eigenvectors are returned as columns; transpose to index them by row
    eig_vectors = eig_vectors.t()

    # (4) take the last eigenpair and map the eigenvector back: v = L^-T w
    max_eig_vector = eig_vectors[-1]
    max_eig_vector = torch.matmul(L_ti,max_eig_vector)
    norm=max_eig_vector.pow(2).sum().sqrt()
    max_eig_vector.div_(norm)

    loss = - eig_values[-1]

    return loss, eig_values, max_eig_vector, S_b, S_w

In [0]:
def check_LDA_cholesky(loader, model):
    """Evaluate the LDA eigenproblem on ``loader`` without tracking gradients.

    Each batch is encoded with ``model.get_hidden`` and passed to
    ``LDAloss_cholesky``. The eigenvalues and leading eigenvector of the
    last batch are returned, so a single-batch loader covers the whole
    dataset. Uses the module-level ``device``.
    """
    with torch.no_grad():
        for batch in loader:
            inputs = batch[0].float().to(device)
            labels = batch[1].long().to(device)
            hidden = model.get_hidden(inputs)
            lda_out = LDAloss_cholesky(hidden, labels)
            eig_values, eig_vector = lda_out[1], lda_out[2]
    return eig_values, eig_vector

## Encode (for analysis)

In [0]:
def encode_hidden(loader,model,batch,n_hidden,device):
    """Encode a labelled dataset into hidden space and split it by class.

    Args:
        loader: iterable of batches whose first two entries are inputs and labels.
        model: object exposing ``get_hidden(x, svd=False)``.
        batch: nominal batch size; kept for backward compatibility and no
            longer needed, since batches of any size are concatenated.
        n_hidden: width of the hidden representation.
        device: device on which the inputs are evaluated.

    Returns:
        ``(sA, sB)``: float64 arrays of shape ``(n_samples, n_hidden)`` with
        the hidden representations of the samples labelled 0 and 1.
    """
    hidden_chunks = []
    label_chunks = []
    with torch.no_grad():
        for data in loader:
            x,lab = data[0].float(),data[1].long()
            x = Variable(x).to(device)
            cv = model.get_hidden(x,svd=False)
            #cv = model.apply_pca(cv)
            hidden_chunks.append(cv.detach().cpu().numpy().reshape(-1, n_hidden))
            label_chunks.append(lab.detach().cpu().numpy().reshape(-1))

    # Concatenating per-batch results (instead of filling a fixed
    # (n_batches, batch, n_hidden) buffer) keeps a shorter final batch intact.
    if hidden_chunks:
        s = np.concatenate(hidden_chunks, axis=0).astype(np.float64)
        l = np.concatenate(label_chunks, axis=0).astype(np.float64)
    else:
        s = np.empty((0, n_hidden))
        l = np.empty((0,))

    sA = s[l==0]
    sB = s[l==1]

    return sA,sB

def encode_cv(loader,model,batch,n_cv,device):
    """Evaluate the model output (CV) on a labelled dataset and split it by class.

    Args:
        loader: iterable of batches whose first two entries are inputs and labels.
        model: callable mapping a batch of inputs to the CV values.
        batch: nominal batch size; kept for backward compatibility and no
            longer needed, since batches of any size are concatenated.
        n_cv: number of CV components per sample.
        device: device on which the inputs are evaluated.

    Returns:
        ``(sA, sB)``: float64 arrays of shape ``(n_samples, n_cv)`` with the
        CV values of the samples labelled 0 and 1.
    """
    cv_chunks = []
    label_chunks = []
    with torch.no_grad():
        for data in loader:
            x,lab = data[0].float(),data[1].long()
            x = Variable(x).to(device)
            cv = model(x)
            cv_chunks.append(cv.detach().cpu().numpy().reshape(-1, n_cv))
            label_chunks.append(lab.detach().cpu().numpy().reshape(-1))

    # Concatenating per-batch results (instead of filling a fixed
    # (n_batches, batch, n_cv) buffer) keeps a shorter final batch intact.
    if cv_chunks:
        s = np.concatenate(cv_chunks, axis=0).astype(np.float64)
        l = np.concatenate(label_chunks, axis=0).astype(np.float64)
    else:
        s = np.empty((0, n_cv))
        l = np.empty((0,))

    sA = s[l==0]
    sB = s[l==1]

    return sA,sB

def encode_cv_all(loader,model,batch,n_cv,device):
    """Evaluate the CV on a whole dataset, keeping the sample order.

    Unlike ``encode_cv`` the result is not split by class.

    Args:
        loader: iterable of batches whose first two entries are inputs and labels.
        model: object exposing ``get_cv(x)``.
        batch: nominal batch size; kept for backward compatibility and no
            longer needed, since batches of any size are concatenated.
        n_cv: number of CV components per sample.
        device: device on which the inputs are evaluated.

    Returns:
        ``(s, l)``: float64 arrays with the CV values, shape
        ``(n_samples, n_cv)``, and the matching labels, shape ``(n_samples,)``.
    """
    cv_chunks = []
    label_chunks = []
    with torch.no_grad():
        for data in loader:
            x,lab = data[0].float(),data[1].long()
            x = Variable(x).to(device)
            cv = model.get_cv(x)
            cv_chunks.append(cv.detach().cpu().numpy().reshape(-1, n_cv))
            label_chunks.append(lab.detach().cpu().numpy().reshape(-1))

    # Concatenating per-batch results (instead of filling a fixed
    # (n_batches, batch, n_cv) buffer) keeps a shorter final batch intact.
    if cv_chunks:
        s = np.concatenate(cv_chunks, axis=0).astype(np.float64)
        l = np.concatenate(label_chunks, axis=0).astype(np.float64)
    else:
        s = np.empty((0, n_cv))
        l = np.empty((0,))

    return s,l

## Plot functions (with an option to save the figures)

In [0]:
def plot_results(save=False,testing=False,accuracy=False,chem_space=False):
    """Redraw the summary panels in row 0 of the module-level ``grid``.

    Columns 0-2 always show training curve, hidden space and CV histogram.
    ``accuracy`` adds the accuracy plot in column 3; ``chem_space`` adds a
    fifth panel, but only when ``accuracy`` is also set. ``save`` and
    ``testing`` are forwarded to the individual plot functions.
    """
    # Panels in column order; each entry draws one figure.
    panels = [
        lambda: plot_training(save),
        lambda: plot_H(save, testing),
        lambda: plot_CV(save, testing),
    ]
    if accuracy:
        panels.append(lambda: plot_accuracy(save, testing))
        if chem_space:
            panels.append(lambda: plot_chem_space(save))

    # Clear every cell that is about to be redrawn.
    for col in range(len(panels)):
        with grid.output_to(0, col):
            grid.clear_cell()

    for col, draw in enumerate(panels):
        with grid.output_to(0, col):
            draw()

def plot_training(save=False):
    """Plot the leading eigenvalue per epoch for the last batch and the full set.

    Reads the module-level lists ``ep``, ``eig`` and ``eig_t``. When ``save``
    is true the figure is written to ``<tr_folder>/training.png``.
    """
    # `plt` (matplotlib.pyplot) is the plotting module imported by this
    # notebook; `pylab` was referenced before but never imported.
    # float() accepts plain numbers as well as 0-d tensors (including tensors
    # that track gradients or live on the GPU), which np.asarray cannot convert.
    epochs = np.asarray(ep)
    eig_batch = np.asarray([float(v) for v in eig])
    eig_population = np.asarray([float(v) for v in eig_t])

    plt.figure(figsize=(5, 5))
    plt.title("Deep-LDA optimization")
    plt.plot(epochs, eig_batch, '.-', c='tab:green', label='batch')
    plt.plot(epochs, eig_population, '.-', c='tab:grey', label='population')
    plt.xlabel("Epoch")
    plt.ylabel("1st Eigenvalue")
    plt.legend()
    if save:
        plt.savefig("{}/{}.png".format(tr_folder, "training"),dpi=150)

def plot_accuracy(save=False,testing=False):
    """Plot the classification accuracy (in percent) against the epoch.

    Reads the module-level lists ``ep`` and ``acc``, plus ``acc_t`` when
    ``testing`` is true (it must then hold one entry per epoch). When ``save``
    is true the figure is written to ``<tr_folder>/accuracy.png``.
    """
    # `plt` (matplotlib.pyplot) is the plotting module imported by this
    # notebook; `pylab` was referenced before but never imported.
    plt.figure(figsize=(5, 5))
    plt.title("Classification accuracy")
    plt.plot(np.asarray(ep),100*np.asarray(acc),'.-', c='tab:cyan', label='training')
    if testing:
        plt.plot(np.asarray(ep),100*np.asarray(acc_t),'.-', c='tab:orange', label='testing')
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.legend()
    if save:
        plt.savefig("{}/{}.png".format(tr_folder, "accuracy"),dpi=150)

def plot_H(save=False,testing=False):
    """Scatter the first two hidden components of the validation set by class.

    With ``testing`` the test set (``test_meta_labels``) is overlaid. Reads
    the module-level ``model``, loaders, batch sizes, ``n_hidden``, ``device``
    and ``max_eig_vector``. When ``save`` is true the figure is written to
    ``<tr_folder>/hidden.png``.
    """
    # `plt` (matplotlib.pyplot) is the plotting module imported by this
    # notebook; `pylab` was referenced before but never imported.
    plt.figure(figsize=(5, 5))
    plt.title("LDA on Hidden-space H")
    # -- Testing and Validation histograms --
    trA,trB = encode_hidden(valid_loader_labels,model,batch_val,n_hidden,device)
    # .cpu() is required before .numpy() when the vector lives on the GPU
    eigen=max_eig_vector.detach().cpu().numpy()

    plt.scatter(trA[:,0],trA[:,1], c='tab:red', label='trA',alpha=0.3)
    plt.scatter(trB[:,0],trB[:,1], c='tab:blue', label='trB',alpha=0.3)

    if testing:
        ttA,ttB = encode_hidden(test_meta_labels,model,batch_test,n_hidden,device)
        plt.scatter(ttA[:,0],ttA[:,1], c='tab:orange', label='testA',s=0.2, alpha=0.5)
        plt.scatter(ttB[:,0],ttB[:,1], c='tab:cyan', label='testB',s=0.2, alpha=0.5)
        mIN=np.min([np.min(trA[:,0]),np.min(trB[:,0]),np.min(ttA[:,0]),np.min(ttB[:,0])])
        mAX=np.max([np.max(trA[:,0]),np.max(trB[:,0]),np.max(ttA[:,0]),np.max(ttB[:,0])])
    else:
        mIN=np.min([np.min(trA[:,0]),np.min(trB[:,0])])
        mAX=np.max([np.max(trA[:,0]),np.max(trB[:,0])])

    # Line through the origin built from the first two eigenvector components;
    # only used by the optional overlay below, which is currently disabled.
    x=np.linspace(mIN,mAX,100)
    y=-eigen[0]/eigen[1]*x+0
    #plt.plot(x,y, linewidth=2, label='DeepLDA')
    plt.legend()
    if save:
        plt.savefig("{}/{}.png".format(tr_folder, "hidden"),dpi=150)

def plot_CV(save=False,testing=False):
    """Histogram the first CV component of the validation set for each class.

    With ``testing`` the test set (``test_meta_labels``) is overlaid. All
    histograms share 100 bin edges spanning the plotted data and are
    normalized to unit area. When ``save`` is true the figure is written to
    ``<tr_folder>/histogram.png``.
    """
    sA,sB = encode_cv(valid_loader_labels,model,batch_val,n_cv,device)
    sA,sB = sA[:,0], sB[:,0]
    if testing:
        stA,stB = encode_cv(test_meta_labels,model,batch_test,n_cv,device)
        stA,stB = stA[:,0], stB[:,0]

        min_s=np.min([np.min(sA),np.min(sB),np.min(stA),np.min(stB)])
        max_s=np.max([np.max(sA),np.max(sB),np.max(stA),np.max(stB)])
    else:
        min_s=np.min([np.min(sA),np.min(sB)])
        max_s=np.max([np.max(sA),np.max(sB)])

    b=np.linspace(min_s,max_s,100)

    # `plt` (matplotlib.pyplot) is the plotting module imported by this
    # notebook; `pylab` was referenced before but never imported.
    plt.figure(figsize=(5, 5))
    plt.title("Deep-LDA CV Histogram")
    plt.hist(sA, bins=b, ls='dashed', alpha = 0.7, lw=2, color='tab:red', label='trA',density=True)
    plt.hist(sB, bins=b, ls='dashed', alpha = 0.7, lw=2, color='tab:blue', label='trB',density=True)

    if testing:
        plt.hist(stA, bins=b, ls='dashed', alpha = 0.5, lw=2, color='tab:orange', label='testA',density=True)
        plt.hist(stB, bins=b, ls='dashed', alpha = 0.5, lw=2, color='tab:cyan', label='testB',density=True)

    plt.legend()
    if save:
        plt.savefig("{}/{}.png".format(tr_folder, "histogram"), dpi=150)

def plot_chem_space(save):
    """Scatter ``de_cc`` against ``de_oh1`` coloured by the first CV component.

    The CV is evaluated on ``test_meta_ord_labels`` and both coordinate
    arrays are truncated to the number of encoded samples. When ``save`` is
    true the figure is written to ``<tr_folder>/chem_space.png``.

    NOTE(review): ``test_meta_ord_labels``, ``batch_test``, ``de_cc``,
    ``de_oh1`` and ``cm_fessa`` are not defined in this notebook and must
    exist before this function is called.
    """
    s,l = encode_cv_all(test_meta_ord_labels,model,batch_test,n_cv,device)
    s = s[:,0]

    x = de_cc
    y = de_oh1

    x = x[:len(s)]
    y = y[:len(s)]

    # `plt` (matplotlib.pyplot) is the plotting module imported by this
    # notebook; `pylab` was referenced before but never imported.
    plt.figure(figsize=(5,5))
    scat = plt.scatter(x,y,c=s,cmap=cm_fessa,s=1.)
    plt.colorbar(scat)
    if save:
        plt.savefig("{}/{}.png".format(tr_folder, "chem_space"), dpi=150)


# Ala2

### Load data

In [0]:
# NOTE(review): `main_folder` is not defined in this notebook; it has to be set
# (with a trailing path separator) before this cell is run.
folder=main_folder+"ala2/unbiased/"

# Flags read below. They are defined here so that the cell runs on a fresh
# kernel; `normalize` has the same value that the training cell assigns.
normalize = True
all_input = False

n_dist=45
n_input=n_dist

# Column 0 of each file is skipped; the next n_dist columns are the inputs.
# Both files are read from `folder` (the previously used `folderA`/`folderB`
# were never defined).
distA=np.loadtxt(folder+"INPUTS.A",usecols=range(1,n_dist+1))
distB=np.loadtxt(folder+"INPUTS.B",usecols=range(1,n_dist+1))

print(distA.shape)

if normalize:
    # normalize inputs: centre on the mid-range and scale by the half-range
    Max=np.amax(np.concatenate([distA,distB],axis=0),axis=0)
    Min=np.amin(np.concatenate([distA,distB],axis=0),axis=0)

    Mean=(Max+Min)/2.
    Range=(Max-Min)/2.
    # leave (nearly) constant inputs unscaled to avoid dividing by ~0
    Range[Range<1e-6]=1.

    if all_input:
        #do not normalize angles (any column after the first n_dist)
        Mean[n_dist:]=0.
        Range[n_dist:]=1.

# create labels: 0 for state A, 1 for state B
lA=np.zeros_like(distA[:,0])
lB=np.ones_like(distB[:,0])

dist=np.concatenate([distA,distB],axis=0)
dist_label=np.concatenate([lA,lB],axis=0)

# shuffle inputs and labels with the same permutation
p = np.random.permutation(len(dist))
dist, dist_label = dist[p], dist_label[p]

#assign equal weights for testing
w=np.ones_like(dist_label)

train_data=20000
batch_tr=2000
train_labels=ColvarDataset([dist[:train_data],dist_label[:train_data],w[:train_data]])
train_loader_labels=DataLoader(train_labels, batch_size=batch_tr,shuffle=True)

# The validation set is the same slice as the training set, served as one batch.
valid_data=20000
batch_val=20000
valid_labels=ColvarDataset([dist[:valid_data],dist_label[:valid_data],w[:valid_data]])
#valid_labels=ColvarDataset([dist[train_data:train_data+test_data],dist_label[train_data:train_data+test_data]])
valid_loader_labels=DataLoader(valid_labels, batch_size=batch_val)

### Training

In [0]:
#type
dtype = torch.float32

#parameters
categ = 2
eig_num = 1

# whether to use CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

n_input=n_dist
n_hidden=5
n_cv=1
nodes=[n_input,30,15,n_hidden]
normalize = True

print("===== NN =====")
model = NN_DeepLDA(nodes)
if normalize:
    model.set_norm(torch.tensor(Mean,dtype=dtype,device=device),torch.tensor(Range,dtype=dtype,device=device))
print("======================")
model.to(device)
if torch.cuda.is_available():
    print("using CUDA acceleration")
    print("========================")

# -- Optimization --
lambdA=0.05
lrate = 0.0001
l2_reg = 1e-5
act_reg = 2./lambdA

num_epochs=50
print_ep=5

#define arrays and values
ep = []
eig = []
eig_t = []
acc = []
acc_t = []
init_epoch = 0
best_result = 0
best_value = 0
best_vectors = 0

#OPTIMIZERS
opt = torch.optim.Adam(model.parameters(), lr=lrate, weight_decay=l2_reg)

#format output
float_formatter = lambda x: "%.6f" % x
np.set_printoptions(formatter={'float_kind':float_formatter})

# grid settings --> ONLY FOR COLAB
#from matplotlib import pylab
#from google.colab import widgets

#grid = widgets.Grid(1,4)

# One placeholder per header label (the previous string had two too many).
print('[{:>3}/{:>3}] {:>10} {:>10} {:>10}'.format('ep','tot','eig_tr','eig_test','reg'))

# -- Training --
for epoch in range(num_epochs):
    for data in train_loader_labels:
        # =================get data===================
        X,y = data[0].float().to(device),data[1].long().to(device)
        # =================forward====================
        H = model.get_hidden(X)
        # =================lda loss===================
        lossg, eig_values, max_eig_vector, Sb, Sw = LDAloss_cholesky(H, y)
        # =================reg loss===================
        # mean squared norm of the hidden representation over the batch
        reg_loss = H.pow(2).sum().div( H.size(0) )
        # Lorentzian-shaped term, most negative when reg_loss equals 1
        reg_loss_lor = - act_reg / (1+(reg_loss-1).pow(2))
        # =================backprop===================
        opt.zero_grad()
        lossg.backward(retain_graph=True)
        reg_loss_lor.backward()
        opt.step()

    #Compute LDA over entire dataset
    test_eig_values, test_eig_vector = check_LDA_cholesky(valid_loader_labels, model)
    model.set_lda(test_eig_vector)

    #Compute accuracy
    # NOTE(review): `classify` is not defined in this notebook; it must be
    # provided before this cell is run.
    accu_train = classify(valid_loader_labels,train_loader_labels,model)
    #accu_test = classify(valid_loader_labels,test_meta_labels,model)

    #save results (as plain floats, so they can be plotted and do not keep
    #the autograd graph alive)
    ep.append(epoch+init_epoch+1)
    eig.append(eig_values[-1].item())
    eig_t.append(test_eig_values[-1].item())
    acc.append(accu_train)
    #acc_t.append(accu_test)

    # progress line: the three scalars are formatted, the eigenvector is
    # printed as-is after them
    print('[{:3d}/{:3d}] {:10.2f} {:10.2f} {:10.2f}'.format(
              init_epoch+epoch+1, init_epoch+num_epochs,
              eig_values[-1].item(), test_eig_values[-1].item(), reg_loss.item()),
          test_eig_vector.cpu().numpy())

    # Plot only when the Colab output grid above has been created. No test set
    # is evaluated in this cell (acc_t stays empty), hence testing=False.
    if (epoch+1)%print_ep == 0 and 'grid' in globals():
        plot_results(testing=False, accuracy=True)

    # Track the best epoch by the largest eigenvalue: symeig returns the
    # eigenvalues in ascending order, so that is the last entry.
    if test_eig_values[-1] > best_result:
        best_result = test_eig_values[-1]
        best_value = test_eig_values
        best_vectors = test_eig_vector
        #torch.save(model, "model_DeepLDA.pt")

print("--------------")
print("-- Eigenvalues [Last // Best] --")
print(test_eig_values,best_value)
print("-- LDA Eigenvector [Last // Best]--")
print(test_eig_vector,best_vectors)


#### Analyze hidden space

In [0]:
# Hidden representation of the validation set, split by class
trA,trB = encode_hidden(valid_loader_labels,model,batch_val,n_hidden,device)

print(trA.shape)

# One figure per hidden component, plotted against the sample index.
# `plt` (matplotlib.pyplot) is the plotting module imported by this notebook;
# `pylab` was referenced before but never imported.
for i in range(trA.shape[1]):
    plt.figure(figsize=(5, 5))
    plt.title("h_"+str(i))
    plt.plot(trA[:,i], c='tab:red', label='trA',alpha=0.7)
    plt.plot(trB[:,i], c='tab:blue', label='trB',alpha=0.7)
    plt.legend()

In [0]:
# Scatter plot for every pair (i < j) of hidden components.
# `plt` (matplotlib.pyplot) is the plotting module imported by this notebook;
# `pylab` was referenced before but never imported.
for i in range(trA.shape[1]):
    for j in range(i+1,trA.shape[1]):
        plt.figure(figsize=(5, 5))
        plt.title("h_"+str(i)+" vs h_"+str(j))
        plt.scatter(trA[:,i],trA[:,j], c='tab:red', label='trA',alpha=0.7)
        plt.scatter(trB[:,i],trB[:,j], c='tab:blue', label='trB',alpha=0.7)
        plt.legend()

#### Analyze CV

In [0]:
# Evaluate the CV on the validation set. The inputs must go to the device the
# model lives on; a hard-coded 'cpu' fails when the model was moved to CUDA.
trA,trB = encode_cv(valid_loader_labels,model,batch_val,n_cv,device)

a = trA[:,0]
b = trB[:,0]

# mean / standard deviation and range of the CV in each class
print("==A==")
print(np.mean(a),np.std(a))
print(np.amin(a),np.amax(a))

print("==B==")
print(np.mean(b),np.std(b))
print(np.amin(b),np.amax(b))

### Export model

In [0]:
# == Set output folder
tr_folder=main_folder+"ala2/"
# same effect as `mkdir -p`, without depending on a shell
os.makedirs(tr_folder, exist_ok=True)

# == Plot and save results ==
# NOTE(review): `widgets` (google.colab) is only imported in a commented-out
# line of the training cell, and testing=True needs `test_meta_labels` and
# `batch_test`, which are not defined in this notebook.
grid = widgets.Grid(1,3)
plot_results(save=True,testing=True)

# == Create fake dataloader ==
# One sample is enough to trace the model; it has to sit on the same device
# as the model.
fake_loader = DataLoader(train_labels, batch_size=1,shuffle=False)
fake_input = next(iter(fake_loader ))[0].float().to(device)

# == Export model ==
mod = torch.jit.trace(model, fake_input)
mod.save(tr_folder+"model.pt")
print("@@ exported model in: ",tr_folder+"model.pt" )

# == SAVE LDA COEFFICIENTS ==
# .cpu() is needed before .numpy() when the model lives on the GPU; the
# context manager closes the file even if the write fails.
with open(tr_folder+"lda.dat", "w") as f:
    f.write(str(model.get_lda().cpu().numpy()))

# == EXPORT CHECKPOINT ==
torch.save({
            'epoch': num_epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            }, tr_folder+"checkpoint")

print("@@ checkpoint: ",tr_folder+"checkpoint" )

# Aldol reaction

#### Load Data

In [0]:
# Folder holding the unbiased aldol data (files INPUTS.R and INPUTS.P)
folder = main_folder + "aldol/unbiased/"

normalize = False

n_input = 40

# -- Loading and preprocessing --
# Column 0 of each file is skipped; the next n_input columns are the inputs.
input_cols = range(1, n_input + 1)
distA = np.loadtxt(folder + "INPUTS.R", usecols=input_cols)
distB = np.loadtxt(folder + "INPUTS.P", usecols=input_cols)

print("A",distA.shape)
print("B",distB.shape)

# create labels (0 for A, 1 for B) and stack the two states
lA = np.zeros_like(distA[:,0])
lB = np.ones_like(distB[:,0])

dist = np.concatenate([distA, distB], axis=0)
dist_label = np.concatenate([lA, lB], axis=0)

if normalize:
    # normalize inputs: mid-range and half-range of every column, taken over
    # both states together
    Max = np.amax(dist, axis=0)
    Min = np.amin(dist, axis=0)

    Mean = (Max + Min) / 2.
    Range = (Max - Min) / 2.
    # leave (nearly) constant inputs unscaled
    Range[Range < 1e-6] = 1.

# shuffle inputs and labels with one shared permutation
p = np.random.permutation(len(dist))
dist = dist[p]
dist_label = dist_label[p]

#assign equal weights for testing
w = np.ones_like(dist_label)

train_data = 10000
batch_tr = 2000
train_labels = ColvarDataset([dist[:train_data], dist_label[:train_data], w[:train_data]])
train_loader_labels = DataLoader(train_labels, batch_size=batch_tr, shuffle=True)

# The validation set reuses the training slice and is served as one batch.
valid_data = train_data
batch_val = train_data
valid_labels = ColvarDataset([dist[:valid_data], dist_label[:valid_data], w[:valid_data]])
valid_loader_labels = DataLoader(valid_labels, batch_size=batch_val)


#### Training

In [0]:
#type
dtype = torch.float32

#parameters
categ = 2
eig_num = 1

# whether to use CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

n_hidden=3
n_cv=1
nodes=[n_input,20,10,n_hidden]
normalize = False

print("===== NN =====")
model = NN_DeepLDA(nodes)
if normalize:
    model.set_norm(torch.tensor(Mean,dtype=dtype,device=device),torch.tensor(Range,dtype=dtype,device=device))
print("======================")
model.to(device)
if torch.cuda.is_available():
    print("using CUDA acceleration")
    print("========================")

# -- Optimization --
lambdA = 0.05
lrate = 0.0001
l2_reg = 1e-5
act_reg = 2./lambdA

num_epochs=50
print_ep=5

#define arrays and values
ep = []
eig = []
eig_t = []
acc = []
acc_t = []
init_epoch = 0
best_result = 0
# initialised so the summary below works even if no epoch improves on 0
best_value = 0
best_vectors = 0

#OPTIMIZERS
opt = torch.optim.Adam(model.parameters(), lr=lrate, weight_decay=l2_reg)

#format output
float_formatter = lambda x: "%.6f" % x
np.set_printoptions(formatter={'float_kind':float_formatter})

# grid settings --> ONLY FOR COLAB
#from matplotlib import pylab
#from google.colab import widgets

#grid = widgets.Grid(1,5)

# One placeholder per header label (the previous string had two too many).
print('[{:>3}/{:>3}] {:>10} {:>10} {:>10}'.format('ep','tot','eig_tr','eig_test','reg'))

# -- Training --
for epoch in range(num_epochs):
    for data in train_loader_labels:
        # =================get data===================
        X,y = data[0].float().to(device),data[1].long().to(device)
        # =================forward====================
        H = model.get_hidden(X)
        # =================lda loss===================
        lossg, eig_values, max_eig_vector, Sb, Sw = LDAloss_cholesky(H, y)
        # =================reg loss===================
        # mean squared norm of the hidden representation over the batch
        reg_loss = H.pow(2).sum().div( H.size(0) )
        # Lorentzian-shaped term, most negative when reg_loss equals 1
        reg_loss_lor = - act_reg / (1+(reg_loss-1).pow(2))
        # =================backprop===================
        opt.zero_grad()
        lossg.backward(retain_graph=True)
        reg_loss_lor.backward()
        opt.step()

    #Compute LDA over entire dataset
    test_eig_values, test_eig_vector = check_LDA_cholesky(valid_loader_labels, model)
    model.set_lda(test_eig_vector)
    #Compute accuracy
    # NOTE(review): `classify` is not defined in this notebook; it must be
    # provided before this cell is run.
    accu_train = classify(valid_loader_labels,train_loader_labels,model)
    #accu_test = classify(valid_loader_labels,test_meta_labels,model)
    #save results (as plain floats, so they can be plotted and do not keep
    #the autograd graph alive); eig_t is filled so that it stays aligned
    #with `ep` for plot_training
    ep.append(epoch+init_epoch+1)
    eig.append(eig_values[-1].item())
    eig_t.append(test_eig_values[-1].item())
    acc.append(accu_train)
    #acc_t.append(accu_test)

    # progress line: the three scalars are formatted, the eigenvector is
    # printed as-is after them
    print('[{:3d}/{:3d}] {:10.2f} {:10.2f} {:10.2f}'.format(
              init_epoch+epoch+1, init_epoch+num_epochs,
              eig_values[-1].item(), test_eig_values[-1].item(), reg_loss.item()),
          test_eig_vector.cpu().numpy())

    if (epoch+1)%print_ep == 0:
        #grid = widgets.Grid(1,5)
        #plot_results(testing=True, accuracy=True,chem_space=True)
        #print(Sw)
        pass  # plotting is disabled; `pass` keeps the block syntactically valid

    # Track the best epoch by the largest eigenvalue: symeig returns the
    # eigenvalues in ascending order, so that is the last entry.
    if test_eig_values[-1] > best_result:
        best_result = test_eig_values[-1]
        best_value = test_eig_values
        best_vectors = test_eig_vector
        #torch.save(model, "model_DeepLDA.pt")

print("--------------")
print("-- Eigenvalues [Last // Best] --")
print(test_eig_values,best_value)
print("-- LDA Eigenvector [Last // Best]--")
print(test_eig_vector,best_vectors)

#### Export model

In [0]:
# == Set output folder
tr_folder=main_folder+"aldol/"
# same effect as `mkdir -p`, without depending on a shell
os.makedirs(tr_folder, exist_ok=True)

# == Plot and save results ==
# NOTE(review): `widgets` (google.colab) is only imported in a commented-out
# line of the training cell. testing=True / chem_space=True need
# `test_meta_labels`, `test_meta_ord_labels`, `batch_test`, `de_cc`, `de_oh1`
# and `cm_fessa`, none of which is defined in this notebook, and the accuracy
# panel with testing=True needs `acc_t` to hold one entry per epoch.
grid = widgets.Grid(1,5)
plot_results(save=True,testing=True,accuracy=True,chem_space=True)

# == Create fake dataloader ==
# One sample is enough to trace the model; it has to sit on the same device
# as the model.
fake_loader = DataLoader(train_labels, batch_size=1,shuffle=False)
fake_input = next(iter(fake_loader ))[0].float().to(device)

# == Export model ==
mod = torch.jit.trace(model, fake_input)
mod.save(tr_folder+"model.pt")
print("@@ exported model in: ",tr_folder+"model.pt" )

# == SAVE LDA COEFFICIENTS ==
# .cpu() is needed before .numpy() when the model lives on the GPU; the
# context manager closes the file even if the write fails.
with open(tr_folder+"lda.dat", "w") as f:
    f.write(str(model.get_lda().cpu().numpy()))

# == EXPORT CHECKPOINT ==
torch.save({
            'epoch': num_epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            }, tr_folder+"checkpoint")

print("@@ checkpoint: ",tr_folder+"checkpoint" )