In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
# import pdb
from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm import tqdm
import matplotlib.pyplot as plt

# Prefer the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Create DART Model

In [None]:
class DartModel(nn.Module):
    """Per-atom model built from four descriptor embeddings (i, j, k, l).

    Each descriptor family is embedded by its own two-layer MLP
    (``3 -> lh -> lo`` with CELU(alpha=0.1)). The j/k/l embeddings are summed
    over dim 2, added to the i embedding, and the result is passed through the
    interaction MLP ``li -> 256 -> 128 -> 32 -> 1``.

    Descriptors whose last-axis sum is exactly zero are masked out: their
    embeddings, and the final output of masked i-entries, are set to zero.
    NOTE(review): a genuine descriptor whose components happen to sum to
    exactly zero is masked as well.

    Args:
        lh: Hidden width of each embedding MLP.
        lo: Output width of each embedding MLP.
        li: Input width of the interaction MLP. The summed embeddings have
            width ``lo``, so ``li`` must equal ``lo`` for ``forward`` to run.
    """

    def __init__(self, lh, lo, li):
        super(DartModel, self).__init__()
        # Attribute names are kept unchanged so existing state_dicts still load.
        self.fi1 = nn.Linear(3, lh)
        self.fi2 = nn.Linear(lh, lo)

        self.fj1 = nn.Linear(3, lh)
        self.fj2 = nn.Linear(lh, lo)

        self.fk1 = nn.Linear(3, lh)
        self.fk2 = nn.Linear(lh, lo)

        self.fl1 = nn.Linear(3, lh)
        self.fl2 = nn.Linear(lh, lo)

        self.inter1 = nn.Linear(li, 256)
        self.inter2 = nn.Linear(256, 128)
        self.inter3 = nn.Linear(128, 32)
        self.inter4 = nn.Linear(32, 1)

    @staticmethod
    def _padding_mask(desc):
        """Boolean mask (last dim kept with size 1): False where ``desc`` sums to exactly 0."""
        # Built from `desc` itself, so it always lives on the same device as the input.
        return desc.sum(dim=-1, keepdim=True) != 0

    @staticmethod
    def _embed(desc, first, second, mask):
        """Run ``desc`` through ``first`` and ``second`` (CELU after each) and zero masked rows."""
        hidden = F.celu(first(desc), 0.1)
        hidden = F.celu(second(hidden), 0.1)
        return hidden * mask.to(hidden.dtype)

    def forward(self, ai, aj, ak, al):
        """Compute one scalar per i-entry.

        Args:
            ai: Rank-3 tensor ``(batch, atoms, 3)``.
            aj, ak, al: Rank-4 tensors ``(batch, atoms, neighbours, 3)``; the
                neighbour axis (dim 2) is summed out and may differ in length
                between the three inputs.

        Returns:
            Tensor of shape ``(batch, atoms, 1)``; masked i-entries are 0.
        """
        # Masks must be computed from the raw descriptors, before embedding.
        ai_mask = self._padding_mask(ai)
        aj_mask = self._padding_mask(aj)
        ak_mask = self._padding_mask(ak)
        al_mask = self._padding_mask(al)

        ai = self._embed(ai, self.fi1, self.fi2, ai_mask)
        aj = self._embed(aj, self.fj1, self.fj2, aj_mask)
        ak = self._embed(ak, self.fk1, self.fk2, ak_mask)
        al = self._embed(al, self.fl1, self.fl2, al_mask)

        # Sum all interactions. The k embedding is included exactly once; the
        # previous code added the l embedding twice and dropped k entirely.
        atm = ai + aj.sum(dim=2) + ak.sum(dim=2) + al.sum(dim=2)
        atm = F.celu(self.inter1(atm), 0.1)
        atm = F.celu(self.inter2(atm), 0.1)
        atm = F.celu(self.inter3(atm), 0.1)
        atm = self.inter4(atm)
        # inter4 has a bias, so masked i-entries must be zeroed again here.
        return atm * ai_mask.to(atm.dtype)

# Instantiate the network with lh=128, lo=128, li=128 and move its parameters to `device`.
Dart_model = DartModel(128, 128, 128).to(device)

def init_params(m):
    """Initialise a single module in place; intended for ``nn.Module.apply``.

    ``torch.nn.Linear`` modules get Kaiming-normal weights (``a=1.0``) and, when
    a bias exists, an all-zero bias. Every other module type is left untouched.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.kaiming_normal_(m.weight, a=1.0)
    if m.bias is None:
        return
    torch.nn.init.zeros_(m.bias)

# Re-initialise the model by calling `init_params` on the model and each of its sub-modules.
# As the last expression of the cell, the value returned by `apply` is also displayed.
Dart_model.apply(init_params)

## Pre-processing of data

In [None]:
class sep_ijkl_dataset(Dataset):
    """In-memory dataset of zero-padded i/j/k/l descriptors and target energies.

    All loading, padding and the transfer to ``device`` happen once in
    ``__init__``; ``__getitem__`` only indexes the stored tensors.
    """
    def __init__(self, file):
        """
        Args:
            file: Path (or file object) passed to ``np.load``. The loaded
                archive must provide an ``"ener"`` entry (an iterable of
                iterables of energies) and a ``"desc"`` entry that, for every
                sample, holds four items: ``[0]`` the i descriptors and
                ``[1]``/``[2]``/``[3]`` the j/k/l descriptor sequences.
                (Assumes every descriptor has 3 components, matching the
                zero padding built below — TODO confirm against the data.)
        """
        # NOTE(review): allow_pickle=True can execute arbitrary code while loading;
        # only use this with trusted files.
        self.data = np.load(file, allow_pickle=True)
        self.ener = self.data["ener"]
        batch = self.data["desc"]
        # Flatten one level of nesting so the energies become a 1-D float tensor on `device`.
        self.ener = torch.tensor([j for i in self.ener for j in i], dtype=torch.float, device=device)
        # Number of samples is taken from the flattened energies; assumes "desc" has
        # one entry per flattened energy — TODO confirm.
        batch_size = len(self.ener)
        # Flat list holding four lengths per sample, in the order i, j, k, l.
        max_atoms = []

        for batch_idx in range(batch_size):
            lol = batch[batch_idx]
            max_atoms.append(len(lol[0])) # find longest sequence
            max_atoms.append(max([len(i) for i in lol[1]])) # find longest sequence
            max_atoms.append(max([len(i) for i in lol[2]])) # find longest sequence
            max_atoms.append(max([len(i) for i in lol[3]])) # find longest sequence
        # De-interleave the flat list into one per-sample length list per family.
        iic = max_atoms[0::4]
        jjc = max_atoms[1::4]
        kkc = max_atoms[2::4]
        llc = max_atoms[3::4]

        des_j = []
        des_k = []
        des_l = []
        for i in range(batch_size):
            # How far this sample falls short of the dataset-wide maxima.
            const_atom_count_i = max(iic) - iic[i]
            const_atom_count_j = max(jjc) - jjc[i]
            const_atom_count_k = max(kkc) - kkc[i]
            const_atom_count_l = max(llc) - llc[i]
            # All-zero filler sequences: one per missing i-entry.
            a_j = torch.zeros(const_atom_count_i, const_atom_count_j, 3)
            a_k = torch.zeros(const_atom_count_i, const_atom_count_k, 3)
            a_l = torch.zeros(const_atom_count_i, const_atom_count_l, 3)
            # Pad this sample's sequences (plus the zero fillers) to a common length.
            # pad_sequence is called without batch_first here, so the sequence index
            # is not the leading dimension of the result; the transpose below fixes that.
            # Inside the comprehensions `i` is rebound, but `batch[i]` is evaluated
            # with the outer loop index.
            des_j.append(pad_sequence([torch.tensor(i) for i in batch[i][1]] + [i for i in a_j]))
            des_k.append(pad_sequence([torch.tensor(i) for i in batch[i][2]] + [i for i in a_k]))
            des_l.append(pad_sequence([torch.tensor(i) for i in batch[i][3]] + [i for i in a_l]))
        
        # Pad the i descriptors across samples, drop all size-1 dimensions, cast to float.
        # NOTE(review): squeeze() without a dim would also drop the sample dimension
        # if the file held a single sample — confirm this cannot happen.
        self.des_i = pad_sequence([torch.tensor(batch[i][0]) for i in range(batch_size)], batch_first=True).squeeze().float().to(device)
        # Pad across samples (sample index first), then swap dims 1 and 2.
        des_j = pad_sequence(des_j, batch_first=True)
        self.des_j = torch.transpose(des_j, 1, 2).float().to(device)

        des_k = pad_sequence(des_k, batch_first=True)
        self.des_k = torch.transpose(des_k, 1, 2).float().to(device)

        des_l = pad_sequence(des_l, batch_first=True)
        self.des_l = torch.transpose(des_l, 1, 2).float().to(device)
        
    def __len__(self):
        """Return the number of samples (one per flattened energy)."""
        return len(self.ener)
    
    def __getitem__(self, idx):
        """Return a dict with keys ``atm_i``, ``atm_j``, ``atm_k``, ``atm_l`` and ``energy`` for ``idx``."""
        sample = {"atm_i": self.des_i[idx], "atm_j": self.des_j[idx], "atm_k": self.des_k[idx], "atm_l": self.des_l[idx], "energy": self.ener[idx]}
        return sample

## Create train/validation/test split and dataloaders

In [None]:
# Load and pad the whole dataset once; it is shared by all three loaders below.
desc_data = sep_ijkl_dataset("../data/small_dataset.npz")

# Fractions of the dataset reserved for validation and test; the rest is used for training.
validation_split = .1
test_split = .1
shuffle_dataset = True
random_seed= 42

# Creating data indices for training, validation and test splits:
dataset_size = len(desc_data)
indices = list(range(dataset_size))
splitv = int(np.floor(validation_split * dataset_size))
splitt = int(np.floor(test_split * dataset_size))
if shuffle_dataset :
    # Seeding NumPy's global RNG makes the shuffled split reproducible.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# Layout of `indices`: [validation | test | train].
train_indices, val_indices, test_indices = indices[splitt+splitv:], indices[:splitv], indices[splitv:splitt+splitv]

# Creating data samplers and loaders:
# Each sampler draws only from its own index subset of `desc_data`.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
test_sampler = SubsetRandomSampler(test_indices)

trainloader = DataLoader(desc_data, batch_size=32, sampler=train_sampler)
validloader = DataLoader(desc_data, batch_size=32, sampler=valid_sampler)
testloader = DataLoader(desc_data, batch_size=32, sampler=test_sampler)

In [None]:
# L1 (mean absolute error) loss between predicted and reference energies.
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(Dart_model.parameters(), lr=0.001)
# Reduce the learning rate when the monitored metric stops decreasing for 25 scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm against the installed version.
scheduler = ReduceLROnPlateau(optimizer, mode="min", patience=25, verbose=True, eps=1e-09)

## Training loop

In [None]:
# Per-epoch loss histories; `train` / `train_and_evaluate` append to these lists.
epochal_train_losses = []
epochal_val_losses  = []
# Upper bound on the epoch counter used by `train_and_evaluate`.
num_epochs = 1500
# Print a progress line every `epoch_freq` epochs.
epoch_freq = 1
       
def test(Dart_model, testloader):
    mae = torch.nn.L1Loss()
    rmse = torch.nn.MSELoss()
    pred_energy = torch.tensor([], device="cuda")
    real_energy = torch.tensor([], device="cuda")
    cluster_size = torch.tensor([], device="cuda")
    Dart_model.eval()
    with torch.no_grad():
        for batch in testloader:
            energy = Dart_model(batch["atm_i"], batch["atm_j"], batch["atm_k"], batch["atm_l"])
            energy = energy.sum(axis=1).squeeze(dim=1)
            pred_energy = torch.cat((pred_energy, energy))
            real_energy = torch.cat((real_energy, batch["energy"]))
            cluster_size = torch.cat((cluster_size, batch["atm_i"][:,0].sum(axis=1)))
        results = torch.stack((cluster_size, real_energy, pred_energy), axis=1)
        test_loss = mae(pred_energy, real_energy)
        rmse_loss = torch.sqrt(rmse(pred_energy, real_energy))
        print("Test MAE = ", test_loss.item(), "Test RMSE = ", rmse_loss.item())
        return results, test_loss, rmse_loss
    
def train(Dart_model, optimizer, epochal_train_losses, criterion):
    """Run one optimisation pass over the module-level ``trainloader``.

    For every batch the per-atom outputs are summed over dim 1 and compared
    with the reference energies via ``criterion``. The mean of the batch
    losses is appended to ``epochal_train_losses``; nothing is returned.
    """
    Dart_model.train()
    running_loss = 0.00
    num_batches = 0
    for num_batches, batch in enumerate(trainloader, start=1):
        optimizer.zero_grad()
        per_atom = Dart_model(batch["atm_i"], batch["atm_j"], batch["atm_k"], batch["atm_l"])
        predicted = per_atom.sum(axis=1)
        # Targets get a trailing dimension so they line up with `predicted`.
        target = batch["energy"].unsqueeze(dim=1)
        loss = criterion(predicted, target)
        loss.backward()
        optimizer.step()

        # Accumulate a detached CPU copy so no autograd graph is kept alive.
        running_loss += loss.detach().cpu()
    running_loss /= num_batches
    epochal_train_losses.append(running_loss)

def train_and_evaluate(Dart_model, optimizer, scheduler, criterion, start_epoch=1, restart=None):
    """Alternate training and validation epochs until ``num_epochs`` or LR-based early stop.

    Each epoch calls ``train`` (which appends to ``epochal_train_losses``),
    evaluates on the module-level ``validloader`` without gradient tracking,
    appends the mean validation loss to ``epochal_val_losses`` and steps
    ``scheduler`` with it. The loop stops early once the learning rate of the
    first parameter group drops below 1e-8.

    Args:
        Dart_model: Model to optimise.
        optimizer: Optimiser updating ``Dart_model``'s parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        criterion: Loss applied to (summed prediction, reference energy).
        start_epoch: First epoch number to run (default 1).
        restart: When truthy, restore model/optimizer state from
            ``<log_dir>/last.pth.tar`` and continue from the stored epoch.

    Returns:
        None. Results are recorded in the module-level loss lists.
    """
    if restart:
        import os  # local import: `os` is not imported at the top of the notebook

        # NOTE(review): `log_dir` and `load_checkpoint` are not defined anywhere in
        # the visible notebook; they must exist before `restart` can be used.
        restore_path = os.path.join(log_dir, "last.pth.tar")
        checkpoint = load_checkpoint(restore_path, Dart_model, optimizer)
        # Presumably the epoch to resume from — verify against the checkpoint writer.
        start_epoch = checkpoint["epoch"]

    best_val = 100000.00
    early_stopping_learning_rate = 1.0E-8

    # Honour `start_epoch`; it used to be accepted but ignored.
    for epoch in range(start_epoch, num_epochs + 1):
        learning_rate = optimizer.param_groups[0]['lr']
        if learning_rate < early_stopping_learning_rate:
            break

        ############ training #############
        train(Dart_model, optimizer, epochal_train_losses, criterion)

        ############ validation #############
        n = 0
        val_loss = 0.0
        Dart_model.eval()
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for batch in validloader:
                energy = Dart_model(batch["atm_i"], batch["atm_j"], batch["atm_k"], batch["atm_l"])
                energy = energy.sum(axis=1)
                batch_loss = criterion(energy, batch["energy"].unsqueeze(dim=1))
                val_loss += batch_loss.detach().cpu()
                n += 1
        val_loss /= n
        epochal_val_losses.append(val_loss)
        scheduler.step(val_loss)

        # Track the best validation loss seen so far. Nothing consumes it yet
        # (no checkpointing is implemented in this function).
        is_best = val_loss <= best_val
        if is_best:
            best_val = val_loss
        if epoch % epoch_freq == 0:
            print("Epoch: {: <5} Train: {: <20} Val: {: <20}".format(epoch, epochal_train_losses[-1], val_loss))

## Let's start training

In [None]:
# Run the full training/validation loop; progress is printed every `epoch_freq` epochs.
train_and_evaluate(Dart_model, optimizer, scheduler, criterion)

## Testing the model on unseen data (test-set)

In [None]:
# Evaluate on the held-out test loader; `results` is reused by the plotting cells below.
results, test_mae, test_rmse = test(Dart_model, testloader)

In [None]:
# plt.plot(np.arange(0, len(epochal_train_losses[10:]), 1), epochal_train_losses[10:], label='Training Loss')
# plt.plot(np.arange(0, len(epochal_train_losses[10:]), 1), epochal_val_losses[10:], label='validation loss')
# plt.legend(loc='best')
# plt.xlabel("Epochs")
# plt.ylabel("Loss")
# plt.show()

In [None]:
# Learning curves: per-epoch training and validation loss on a shared epoch axis.
fig, ax = plt.subplots()
epoch_axis = np.arange(len(epochal_train_losses))
ax.plot(epoch_axis, epochal_train_losses, label='Training Loss')
ax.plot(epoch_axis, epochal_val_losses, label='validation loss')
ax.legend(loc='best')
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
plt.show()

In [None]:
# Sort the test results by column 0 (cluster size) and convert them to NumPy.
# The tensor check keeps this cell re-runnable: on a second run `results` is
# already an ndarray, which has no `.cpu()`.
if torch.is_tensor(results):
    results = results[results[:, 0].argsort()].cpu().numpy()
else:
    results = results[results[:, 0].argsort()]

# Histogram of the per-sample absolute error |real - predicted|.
plt.hist(abs(results[:,1]-results[:,2]), density=True)
plt.xlabel("MAE")
plt.ylabel("Freq")
plt.show()

In [None]:
# Mean absolute error per cluster size. Requires `results` to be sorted by
# column 0 (done in the previous cell) so equal sizes form contiguous groups.
cluster_sizes, first_rows = np.unique(results[:, 0], return_index=True)
some_res = np.split(results[:, 1:], first_rows[1:])
diff = [abs(i[:,0]-i[:,1]).mean() for i in some_res]
plt.title("Cluster size vs MAE")
# Use the sizes actually present in the data rather than a fixed 31..40 range,
# which only worked when exactly ten sizes occurred. The stray
# 'Training Loss' label was dropped; no legend is drawn.
plt.bar(cluster_sizes, diff)
plt.xlabel("Cluster size")
plt.ylabel("MAE")
plt.show()

In [None]:
# Distribution of the cluster-size proxy (sum of the first atm_i row) over the training split.
trainset = desc_data[train_indices]["atm_i"]
lol = trainset[:, 0].sum(axis=1)
sizzle = lol.tolist()
fig, ax = plt.subplots()
ax.hist(sizzle, bins=10)
ax.set_title("Train set Cluster distribution, size = {}".format(len(sizzle)))
ax.set_xlabel("Cluster size")
ax.set_ylabel("Freq")
plt.show()
# plt.savefig("cluster_distribution_trainset.png", bbox_inches='tight')

In [None]:
# Distribution of the cluster-size proxy (sum of the first atm_i row) over the test split.
# The variable is still called `trainset` although it holds the test samples.
trainset = desc_data[test_indices]["atm_i"]
lol = trainset[:, 0].sum(axis=1)
sizzle = lol.tolist()
fig, ax = plt.subplots()
ax.hist(sizzle, bins=39)
ax.set_title("Test set Cluster distribution, size = {}".format(len(sizzle)))
ax.set_xlabel("Cluster size")
ax.set_ylabel("Freq")
plt.show()
# plt.savefig("cluster_distribution_testset.png", bbox_inches='tight')

In [None]:
# Distribution of the cluster-size proxy (sum of the first atm_i row) over the validation split.
# The variable is still called `trainset` although it holds the validation samples.
trainset = desc_data[val_indices]["atm_i"]
lol = trainset[:, 0].sum(axis=1)
sizzle = lol.tolist()
fig, ax = plt.subplots()
ax.set_title("Validation set Cluster distribution, size = {}".format(len(sizzle)))
ax.hist(sizzle, bins=39)
ax.set_xlabel("Cluster size")
ax.set_ylabel("Freq")
plt.show()
# plt.savefig("cluster_distribution_validationset.png", bbox_inches='tight')