In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
from electrolyte.model.electrolyte_d import ElectrolyteModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
import numpy as np 
sys.path.append('/vepfs/fs_users/ycjin/Delta-ML-Framework/Unimol_2_NMR_fix/descriptior/unimol_tools')
from unimol_tools import UniMolRepr, UniMolRepr_F
from unimol_tools.data import Coords2Unimol
from ase.io import read
from rdkit import Chem
import torch

class UniRepr_Generator(object):
    """Wrapper around a ``UniMolRepr_F`` generator that extracts descriptors.

    The wrapped generator is expected to return a mapping with the keys
    ``'cls_repr'`` (one molecule-level vector per sample) and
    ``'atomic_reprs'`` (one per-atom matrix per sample); every method below
    reads only these two keys.

    Atom ids handed to the methods are treated as 1-based: the code always
    indexes ``atomic_reprs`` with ``atom_id - 1``.
    """

    def __init__(self,base_type = 'mol', finetune = False):
        """Create the underlying representation generator.

        Args:
            base_type: forwarded unchanged to ``UniMolRepr_F``.
            finetune: accepted for interface compatibility; currently has no
                effect (see the note below).
        """
        # NOTE(review): the original code branched on `finetune` but both
        # branches built exactly the same UniMolRepr_F object, so the flag
        # never changed anything. `UniMolRepr` is imported by the notebook but
        # unused -- confirm whether the non-finetune path was meant to use it.
        self.generator = UniMolRepr_F(data_type='molecule', remove_hs=False, base_type = base_type, no_optimize = True)

    def UniRepr_atom(self, mol, atom_id, molecule_repr = True):
        """Return per-sample descriptors for the selected atom of each molecule.

        Args:
            mol: input passed to ``generator.get_repr``. When ``len(mol) > 1``
                it is passed as-is, otherwise it is wrapped in a list first.
            atom_id: indexable of 1-based atom ids, one per sample.
            molecule_repr: when true (multi-sample path only) the molecule
                ``cls_repr`` is combined with the atom vector via ``+``.

        Returns:
            A list with one descriptor per sample on the multi-sample path, or
            a single descriptor on the single-sample path.
        """
        if len(mol) > 1:
            uni_desc = self.generator.get_repr(mol)
            uni_descs = []
            for i in range(len(uni_desc['cls_repr'])):
                atom_vec = uni_desc['atomic_reprs'][i][atom_id[i]-1].tolist()
                if molecule_repr:
                    # NOTE(review): unlike get2train, cls_repr is not converted
                    # with .tolist() here. If cls_repr[i] is an array rather
                    # than a list, `+` adds element-wise instead of
                    # concatenating -- confirm the generator's return type.
                    uni_descs.append(uni_desc['cls_repr'][i] + atom_vec)
                else:
                    uni_descs.append(atom_vec)
            return uni_descs
        else:
            # NOTE(review): this path ignores `molecule_repr` and always
            # combines cls_repr with the atom vector (original behaviour).
            uni_desc = self.generator.get_repr([mol])
            return uni_desc['cls_repr'][0] + uni_desc['atomic_reprs'][0][atom_id[0]-1].tolist()

    def UniRepr_molecule(self, mol, atom_repr = False):
        """Return the molecule-level representation(s) for ``mol``.

        Args:
            mol: input passed unchanged to ``generator.get_repr``.
            atom_repr: when ``False`` only ``cls_repr`` is returned; otherwise
                a ``(cls_repr, atomic_reprs)`` tuple is returned.
        """
        uni_desc = self.generator.get_repr(mol)

        if atom_repr == False:
            return uni_desc['cls_repr']
        else:
            return uni_desc['cls_repr'], uni_desc['atomic_reprs']

    def get2train(self, input, desc_level, only_atom_repr = True, max_num = 500):
        """Compute descriptors for a pre-featurised batch, chunking large inputs.

        Args:
            input: mapping with key ``'atom'`` (1-based atom id per sample,
                only read when ``desc_level == 'atom'``) and key ``'unimol'``
                (mapping of sliceable per-sample arrays; ``'src_tokens'``
                determines the number of samples via ``.shape[0]``).
            desc_level: ``'atom'`` or ``'molecule'``.
            only_atom_repr: for ``desc_level == 'atom'``; when ``True`` the
                descriptor is ``cls_repr + atom vector`` (list concatenation),
                otherwise only the atom vector. The name reads inverted, but
                this is the original behaviour and is preserved.
            max_num: maximum number of samples sent to the generator at once.

        Returns:
            A list with one descriptor (a Python list) per sample, in input
            order.

        Raises:
            ValueError: if ``desc_level`` is not ``'atom'`` or ``'molecule'``,
                or if ``max_num`` is smaller than 1.
        """
        if desc_level not in ('atom', 'molecule'):
            raise ValueError('UnKnown Desc Level, u should use atom or molecule')
        if max_num < 1:
            raise ValueError('max_num must be a positive integer')

        atom_id = input['atom']
        unimol_input = input['unimol']
        total = unimol_input['src_tokens'].shape[0]

        if total <= max_num:
            # Small inputs go to the generator untouched, in a single call.
            chunks = [(0, unimol_input)]
        else:
            # range(0, total, max_num) never yields an empty trailing chunk,
            # unlike the previous `total // max_num + 1` iteration count.
            chunks = [
                (start, {k: unimol_input[k][start:start + max_num] for k in unimol_input.keys()})
                for start in range(0, total, max_num)
            ]

        uni_descs = []
        for offset, batch in chunks:
            uni_desc = self.generator.get_reprs(batch)
            for j in range(len(uni_desc['cls_repr'])):
                if desc_level == 'molecule':
                    uni_descs.append(uni_desc['cls_repr'][j].tolist())
                    continue
                # `offset + j` is the sample's position in the full input; the
                # chunk-local index alone would pick atom ids that belong to
                # the first chunk for every later chunk.
                atom_vec = uni_desc['atomic_reprs'][j][atom_id[offset + j]-1].tolist()
                if only_atom_repr == True:
                    uni_descs.append(uni_desc['cls_repr'][j].tolist() + atom_vec)
                else:
                    uni_descs.append(atom_vec)

        return uni_descs

    def get2fintune(self, input, desc_level, atom_repr = True):
        """Return the descriptor of the first sample of ``input`` as a tensor.

        Args:
            input: passed unchanged to ``generator.get_repr``; for
                ``desc_level == 'atom'`` it must also provide ``input['atom']``
                (1-based atom id).
            desc_level: ``'atom'`` or ``'molecule'``.
            atom_repr: for ``desc_level == 'atom'``; when ``True`` the
                molecule and atom representations are concatenated with
                ``torch.cat``, otherwise only the atom representation is
                returned.

        Raises:
            ValueError: if ``desc_level`` is neither ``'atom'`` nor
                ``'molecule'``.
        """
        uni_desc = self.generator.get_repr(input)
        if desc_level == 'atom':
            atom_id = input['atom']
            if atom_repr == True:
                return torch.cat([uni_desc['cls_repr'][0],uni_desc['atomic_reprs'][0][atom_id-1]],axis = 0)
            else:
                return uni_desc['atomic_reprs'][0][atom_id-1]
        elif desc_level == 'molecule':
            return uni_desc['cls_repr'][0]
        else:
            raise ValueError('UnKnown Desc Level, u should use atom or molecule')

    def get_models(self):
        """Return the tuple ``('unimol', <underlying model object>)``."""
        return 'unimol', self.generator.model

    def mols2src(self):
        """Placeholder; not implemented."""
        pass


In [3]:
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset
import torch.optim as optim
import pickle
from sklearn.preprocessing import StandardScaler
import torch
import random

class ElectrolyteDataset(Dataset):
    """Dataset pairing electrolyte compositions with scaled property targets.

    The CSV is sliced by column position: column 0 is the cation
    concentration, column 1 the anion concentration, columns ``2:-4`` the
    solvent concentrations and the last four columns the properties, which are
    standardised with a ``StandardScaler`` fitted on the whole file.

    Each item keeps a fixed-size subset of solvents (all non-zero ones plus
    randomly chosen zero-concentration ones) and shuffles the component order,
    so repeated access to the same index returns different orderings.
    """

    def __init__(self, data_path,
                 solvent_structure_path='/vepfs/fs_users/ycjin/electrolyte/data/structure/solvent_input_dict.pkl',
                 anion_structure_path='/vepfs/fs_users/ycjin/electrolyte/data/structure/anion_input_dict.pkl'):
        """Load the CSV and structure pickles and pre-compute descriptors.

        Args:
            data_path: path of the composition/property CSV file.
            solvent_structure_path: pickle holding the solvent structure
                inputs; defaults to the previously hard-coded location.
            anion_structure_path: pickle holding the anion structure inputs;
                defaults to the previously hard-coded location.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.finetune_unimol_model = False

        csv_data = pd.read_csv(data_path)
        self.scaler = StandardScaler()
        self.search_space = csv_data.columns
        self.data = csv_data.to_numpy()
        self.property = self.scaler.fit_transform(self.data[:,-4:])
        self.cation_conc = self.data[:,:1]
        self.anion_conc = self.data[:,1:2]
        self.solvent_conc = self.data[:,2:-4]

        self.cation_types = self.search_space[:1]
        self.anion_types = self.search_space[1:2]
        self.solvent_types = self.search_space[2:-4]

        self.solvent_structure = self._load_pickle(solvent_structure_path)
        self.anion_structure = self._load_pickle(anion_structure_path)

        self.unirepr_Gen = UniRepr_Generator(finetune = True)

        # When the Uni-Mol model is frozen the descriptors never change, so
        # they are computed once here instead of on every __getitem__ call.
        if not self.finetune_unimol_model:
            self.solvent_descs = self.trans_Solvent2Desc().detach()
            self.anion_descs = self.trans_Anion2Desc().detach()
            self.cation_descs = self.trans_Cation2Desc()

    @staticmethod
    def _load_pickle(path):
        """Load one pickle file, closing the handle even if loading fails."""
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only point this at trusted, locally produced files.
        with open(path, 'rb') as f_read:
            return pickle.load(f_read)

    @staticmethod
    def _select_rows(conc, descs, idx):
        """Apply the same row selection/permutation to ``conc`` and ``descs``."""
        conc = torch.index_select(conc, 0, idx).float()
        # The index must live on the same device as the tensor it indexes.
        descs = torch.index_select(descs, 0, idx.to(descs.device)).float()
        return conc, descs

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, index):
        """Build one randomised training sample.

        Returns:
            A tuple ``(property, solvent_descs, anion_descs, cation_descs,
            conc, conc_clip)`` where ``conc`` concatenates the solvent, anion
            and cation concentrations (in that order) and ``conc_clip`` holds
            the number of entries of each of those three groups.
        """
        target = torch.tensor(self.property[index]).float()

        solvent_conc = torch.tensor(self.solvent_conc[index])
        anion_conc = torch.tensor(self.anion_conc[index])
        cation_conc = torch.tensor(self.cation_conc[index])

        if self.finetune_unimol_model:
            solvent_descs = self.trans_Solvent2Desc()
            anion_descs = self.trans_Anion2Desc()
            cation_descs = self.trans_Cation2Desc()
        else:
            solvent_descs = self.solvent_descs.float()
            anion_descs = self.anion_descs.float()
            cation_descs = self.cation_descs.float()

        # Keep the solvents that are present plus random absent ones.
        solvent_conc, solvent_descs = self._select_rows(
            solvent_conc, solvent_descs, self.gen_remain_index(solvent_conc))

        # Shuffle each group so the model cannot rely on component order.
        solvent_conc, solvent_descs = self._select_rows(
            solvent_conc, solvent_descs, self.shuffle_idx(solvent_conc.shape[0]))
        anion_conc, anion_descs = self._select_rows(
            anion_conc, anion_descs, self.shuffle_idx(anion_conc.shape[0]))
        cation_conc, cation_descs = self._select_rows(
            cation_conc, cation_descs, self.shuffle_idx(cation_conc.shape[0]))

        conc = torch.cat([solvent_conc,anion_conc,cation_conc]).float()

        conc_cilp = torch.tensor([solvent_conc.shape[0], anion_conc.shape[0], cation_conc.shape[0]]).int()

        return target.to(self.device), solvent_descs, anion_descs, cation_descs.to(self.device), conc.to(self.device), conc_cilp

    def _structures_to_desc(self, structures):
        """Stack the molecule-level descriptor of every entry of ``structures``."""
        desc = []
        for k in structures.keys():
            desc.append(self.unirepr_Gen.get2fintune(structures[k],'molecule',False))
        return torch.stack(desc)

    def trans_Solvent2Desc(self):
        """Return the stacked molecule-level descriptors of all solvents."""
        return self._structures_to_desc(self.solvent_structure)

    def trans_Anion2Desc(self):
        """Return the stacked molecule-level descriptors of all anions."""
        return self._structures_to_desc(self.anion_structure)

    def trans_Cation2Desc(self):
        """Return the stacked hand-written descriptors of the cation columns."""
        desc = []
        for k in self.cation_types:
            # NOTE(review): only 'Li' is recognised; any other column name is
            # skipped silently (and torch.stack fails if nothing matched).
            # The meaning of the 7-element vector is not visible here.
            if k == 'Li':desc.append(torch.tensor([2,0,0,0,0,0,0]))
        return torch.stack(desc)

    def gen_remain_index(self,conc,remain = 4):
        """Pick the indices of the components kept for one sample.

        All non-zero entries of ``conc`` are kept, followed by randomly
        sampled zero entries until ``remain`` indices are reached.

        If ``conc`` has more than ``remain`` non-zero entries all of them are
        kept and no padding is added; if there are too few zero entries to
        reach ``remain`` every zero entry is used. Both cases previously
        raised ``ValueError`` inside ``random.sample``.

        Returns:
            1-D ``long`` tensor: non-zero indices first, then the padding.
        """
        zero_idx = torch.where(conc == 0)[0]
        nonzero_idx = torch.where(conc != 0)[0]
        n_pad = min(max(remain - nonzero_idx.shape[0], 0), zero_idx.shape[0])
        padding = torch.tensor(random.sample(zero_idx.tolist(), n_pad), dtype=torch.long)
        return torch.cat([nonzero_idx, padding.to(nonzero_idx.device)])

    def shuffle_idx(self,lens):
        """Return a random permutation of ``range(lens)`` as a tensor."""
        return torch.randperm(lens)
    
# Build the dataset from the composition CSV; construction also loads the
# structure pickles and pre-computes the molecular descriptors.
dataset = ElectrolyteDataset('/vepfs/fs_users/ycjin/electrolyte/data/raw_data/data_percent.csv')
# Pull one sample to read off tensor shapes for the model configuration below.
# NOTE(review): `property` shadows the Python builtin of the same name for the
# rest of the notebook; later cells read it, so it is left unchanged here.
property, solvent_descs, anion_descs, cation_descs, conc, conc_clip = dataset[0]

2024-01-03 06:02:15 | unimol_tools/models/unimol.py | 114 | INFO | Uni-Mol(QSAR) | Loading pretrained weights from /vepfs/fs_users/ycjin/Delta-ML-Framework/Unimol_2_NMR_fix/descriptior/unimol_tools/unimol_tools/weights/mol_pre_all_h_220816.pt


In [4]:
# Model dimensions are read from the sample drawn above: the property vector
# length and the feature (second) dimension of each descriptor matrix.
init_dict = {
    'property_input_dim': property.shape[0],
    'solute_anion_input_dim': anion_descs.shape[1],
    'solute_cation_input_dim': cation_descs.shape[1],
    'solvent_input_dim': solvent_descs.shape[1],
    'hidden_dim': 256
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
property_input = property.to(device)
# The solvent and anion descriptors are passed as returned by the dataset,
# without an explicit device move.
solvent = solvent_descs
solute_anion = anion_descs
solute_cation = cation_descs.to(device)
model = ElectrolyteModel(**init_dict).to(device)
# One forward pass on the single sample -- presumably a smoke test that the
# configured dimensions are consistent; `predict` is not used afterwards in
# the visible cells.
predict = model(property_input, solvent, solute_anion, solute_cation)

In [5]:
# Batches of 16 samples, reshuffled every epoch. Collation is delegated to the
# model's own batch_collate_fn (defined outside this notebook); the training
# loop unpacks each batch as a (features, labels) pair.
dataloader = DataLoader(dataset, batch_size=16, collate_fn=model.batch_collate_fn, shuffle=True)

In [9]:
MSE_LOSS_FUNC = nn.MSELoss()
def my_loss(pred,tgt,clip):
    """Batch loss: the mean over samples of each sample's MSE.

    Args:
        pred: sequence of per-sample prediction tensors.
        tgt: sequence of per-sample target tensors, aligned with ``pred``.
        clip: per-sample group sizes; accepted but currently unused because
            the constraint terms described below are disabled.

    Returns:
        ``(total_loss, mse)``. With the constraint terms disabled both
        entries carry the same value.
    """
    # Samples are scored one by one rather than stacked first; MSELoss is
    # called as (tgt, pred), exactly as before.
    per_sample_mse = [MSE_LOSS_FUNC(tgt[k], pred[k]) for k in range(len(tgt))]
    batch_mse = torch.stack(per_sample_mse).mean()

    # Disabled constraint terms, kept for reference:
    #   ion_loss = |sum(pred[i][clip[i][0]:clip[i][0]+clip[i][1]])
    #               - sum(pred[i][clip[i][0]+clip[i][1]:clip[i][0]+clip[i][1]+clip[i][2]])|
    #   sum_loss = |(sum(pred[i][clip[i][0]:clip[i][0]+clip[i][1]])
    #               + sum(pred[i][clip[i][0]+clip[i][1]:clip[i][0]+clip[i][1]+clip[i][2]]))/2
    #               + sum(pred[i][:clip[i][0]]) - 1|
    #   total_loss = mean(ion_loss) + mean(sum_loss) + mean(conc_loss)
    return batch_mse, batch_mse

In [10]:
# `criterion` is an alias of my_loss; note that the training loop below calls
# my_loss directly rather than going through this name.
criterion = my_loss
# Adam over all model parameters with a learning rate of 1e-4.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [11]:
model.train()

# Train for up to 2000 epochs; each epoch prints the mean batch loss and MSE.
for epoch in range(2000):
    total_loss = []
    mse_loss = []
    for features, labels in dataloader:
        # Clear gradients *before* the forward/backward pass. With the reset
        # placed after optimizer.step(), interrupting the cell between
        # backward() and zero_grad() (e.g. KeyboardInterrupt) left stale
        # gradients that were added into the first step of the next run.
        optimizer.zero_grad()

        pred = []
        tgt = []
        clip = []

        # The model is applied sample by sample; each label entry provides the
        # target at position 0 and the group sizes at position 1.
        for i,feature in enumerate(features):
            output = model(feature[0],feature[1],feature[2],feature[3])
            pred.append(output)
            tgt.append(labels[i][0])
            clip.append(labels[i][1])

        loss,mse = my_loss(pred, tgt, clip)
        loss.backward()
        optimizer.step()

        total_loss.append(loss.item())
        mse_loss.append(mse.item())

    print(f"Epoch {epoch} loss: {sum(total_loss)/len(total_loss)}, mse: {sum(mse_loss)/len(mse_loss)}")

Epoch 0 loss: 0.07860474288463593, mse: 0.07860474288463593
Epoch 1 loss: 0.07851559370756149, mse: 0.07851559370756149
Epoch 2 loss: 0.07917868793010711, mse: 0.07917868793010711
Epoch 3 loss: 0.07736984640359879, mse: 0.07736984640359879
Epoch 4 loss: 0.07553595453500747, mse: 0.07553595453500747
Epoch 5 loss: 0.07572001218795776, mse: 0.07572001218795776
Epoch 6 loss: 0.07561949044466018, mse: 0.07561949044466018
Epoch 7 loss: 0.07511886805295945, mse: 0.07511886805295945
Epoch 8 loss: 0.07602327764034271, mse: 0.07602327764034271
Epoch 9 loss: 0.07279193848371505, mse: 0.07279193848371505
Epoch 10 loss: 0.07698667049407959, mse: 0.07698667049407959
Epoch 11 loss: 0.07385232895612717, mse: 0.07385232895612717
Epoch 12 loss: 0.07451834753155709, mse: 0.07451834753155709
Epoch 13 loss: 0.07340994328260422, mse: 0.07340994328260422
Epoch 14 loss: 0.0760337084531784, mse: 0.0760337084531784
Epoch 15 loss: 0.07534253150224686, mse: 0.07534253150224686
Epoch 16 loss: 0.07080193012952804, 

KeyboardInterrupt: 

In [12]:
# Inspect the per-sample predictions of the most recent batch handled by the
# training loop.
pred

[tensor([0.0062, 0.0154, 0.8993, 0.0029, 0.0923, 0.1063], device='cuda:0',
        grad_fn=<AbsBackward0>),
 tensor([0.0009, 0.8553, 0.0144, 0.0053, 0.1603, 0.1729], device='cuda:0',
        grad_fn=<AbsBackward0>),
 tensor([0.0131, 0.8388, 0.0147, 0.0036, 0.1774, 0.1920], device='cuda:0',
        grad_fn=<AbsBackward0>),
 tensor([0.0430, 0.7623, 0.0084, 0.0175, 0.2031, 0.1836], device='cuda:0',
        grad_fn=<AbsBackward0>),
 tensor([9.0601e-03, 4.9292e-04, 8.2450e-01, 1.3358e-02, 1.4362e-01, 1.5000e-01],
        device='cuda:0', grad_fn=<AbsBackward0>),
 tensor([0.7801, 0.0142, 0.0323, 0.0066, 0.2125, 0.2170], device='cuda:0',
        grad_fn=<AbsBackward0>),
 tensor([0.7727, 0.0146, 0.0154, 0.0010, 0.2224, 0.2310], device='cuda:0',
        grad_fn=<AbsBackward0>),
 tensor([0.0131, 0.0016, 0.7535, 0.0009, 0.2437, 0.2594], device='cuda:0',
        grad_fn=<AbsBackward0>),
 tensor([0.7302, 0.0172, 0.0592, 0.0124, 0.1568, 0.1502], device='cuda:0',
        grad_fn=<AbsBackward0>),
 ten

In [13]:
# Inspect the matching per-sample targets of that same batch for comparison.
tgt

[tensor([0.0000, 0.0000, 0.8701, 0.0000, 0.1299, 0.1299], device='cuda:0'),
 tensor([0.0000, 0.8571, 0.0000, 0.0000, 0.1429, 0.1429], device='cuda:0'),
 tensor([0.0000, 0.8000, 0.0000, 0.0000, 0.2000, 0.2000], device='cuda:0'),
 tensor([0.7500, 0.0000, 0.0000, 0.0000, 0.2500, 0.2500], device='cuda:0'),
 tensor([0.0000, 0.0000, 0.7826, 0.0870, 0.1304, 0.1304], device='cuda:0'),
 tensor([0.8276, 0.0000, 0.0000, 0.0000, 0.1724, 0.1724], device='cuda:0'),
 tensor([0.7436, 0.0000, 0.0000, 0.0000, 0.2564, 0.2564], device='cuda:0'),
 tensor([0.0000, 0.0000, 0.7500, 0.0000, 0.2500, 0.2500], device='cuda:0'),
 tensor([0.8889, 0.0000, 0.0000, 0.0000, 0.1111, 0.1111], device='cuda:0'),
 tensor([0.0000, 0.0000, 0.0000, 0.8000, 0.2000, 0.2000], device='cuda:0'),
 tensor([0.6087, 0.0000, 0.2609, 0.0000, 0.1304, 0.1304], device='cuda:0'),
 tensor([0.0000, 0.0000, 0.8113, 0.0000, 0.1887, 0.1887], device='cuda:0'),
 tensor([0.8000, 0.0000, 0.0000, 0.0000, 0.2000, 0.2000], device='cuda:0'),
 tensor([0.0