In [1]:
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
import pandas as pd
import torch
from torch.nn import functional as F
from normalize_text import normalize
import re

# Data example

In [2]:
# utils

# Regex building blocks for split_into_sentences. Raw strings so escapes such as \s reach the
# regex engine unchanged (a plain "\s" is an invalid string escape and warns on recent Python).
alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
starters = r"(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = r"[.](com|net|org|io|gov)"
digits = r"([0-9])"

def split_into_sentences(text):
    """Split English text into sentences using rule-based heuristics.

    Periods that should not end a sentence are first masked as ``<prd>``. These include
    honorific prefixes, decimal numbers, single-letter initials, acronyms, suffixes such
    as "Inc."/"Jr.", web domains, ellipses and "Ph.D.". Every remaining ``.``, ``?`` and
    ``!`` is then marked with ``<stop>``. Finally the masks are restored and the text is
    split on ``<stop>``.

    Args:
        text: input string; newlines are treated as spaces.

    Returns:
        List of stripped sentence strings, terminal punctuation included. A trailing
        fragment without terminal punctuation is kept as its own final sentence.
    """
    text = " " + text + "  "
    text = text.replace("\n", " ")
    text = re.sub(prefixes, "\\1<prd>", text)
    text = re.sub(websites, "<prd>\\1", text)
    text = re.sub(digits + "[.]" + digits, "\\1<prd>\\2", text)
    text = text.replace("...", "<prd><prd><prd>")
    text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = re.sub(r"\s" + alphabets + "[.] ", " \\1<prd> ", text)
    text = re.sub(acronyms + " " + starters, "\\1<stop> \\2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", "\\1<prd>\\2<prd>\\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", "\\1<prd>\\2<prd>", text)
    text = re.sub(" " + suffixes + "[.] " + starters, " \\1<stop> \\2", text)
    text = re.sub(" " + suffixes + "[.]", " \\1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", " \\1<prd>", text)
    # Move terminal punctuation after a closing quote so the quote stays with its sentence.
    text = text.replace(".”", "”.")
    text = text.replace(".\"", "\".")
    text = text.replace("!\"", "\"!")
    text = text.replace("?\"", "\"?")
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")
    sentences = [s.strip() for s in text.split("<stop>")]
    # Because of the padding added above, the piece after the final <stop> is whitespace-only
    # when the text ends with terminal punctuation, so drop it in that case. Unconditionally
    # dropping it would also discard a real final sentence that lacks punctuation.
    if sentences and not sentences[-1]:
        sentences.pop()
    return sentences

def tokenize_sentences(sentences, tokenizer, max_length):
    """Normalize and tokenize sentences into fixed-length integer tensors.

    Args:
        sentences: non-empty list of sentence strings (e.g. from ``split_into_sentences``).
        tokenizer: Hugging Face-style tokenizer called on a list of strings.
        max_length: every sentence is padded/truncated to exactly this many tokens.

    Returns:
        dict mapping each tokenizer output key (e.g. ``input_ids``) to a ``torch.long``
        tensor, ready to be passed as ``model(**result)``.

    Raises:
        ValueError: if ``sentences`` is empty. Without this check the failure would
            surface later and obscurely in the tokenizer or model.
    """
    if not sentences:
        raise ValueError("tokenize_sentences() needs at least one sentence")
    norm_sents = [normalize(s) for s in sentences]
    encoded = tokenizer(norm_sents, padding='max_length', truncation=True, max_length=max_length)
    # Build integer tensors directly instead of torch.Tensor(v).long(), which round-trips
    # the token ids through float32.
    return {k: torch.tensor(v, dtype=torch.long) for k, v in encoded.items()}

def get_sum_hidden_emb(tokenized_sents, model, n_layers=4):
    """Pool a model's hidden states into one flat vector for a batch of sentences.

    Runs ``model`` without gradients and sums the last ``n_layers`` hidden-state tensors
    element-wise. It then averages the result over dimension 0 (one row per tokenized
    sentence) and flattens what remains.

    Args:
        tokenized_sents: dict of tensors passed to ``model`` as keyword arguments
            (e.g. the output of ``tokenize_sentences``).
        model: callable that, when called with ``output_hidden_states=True``, returns an
            object with a ``hidden_states`` sequence (e.g. a Hugging Face ``AutoModel``).
        n_layers: number of trailing hidden states to sum. The default of 4 matches the
            previously hard-coded behaviour. It must be >= 1, because
            ``hidden_states[-0:]`` would silently select every layer.

    Returns:
        1-D tensor. Assuming each hidden state is (num_sentences, seq_len, hidden), its
        length is seq_len * hidden.

    Raises:
        ValueError: if ``n_layers`` < 1.
    """
    if n_layers < 1:
        raise ValueError(f"n_layers must be >= 1, got {n_layers}")
    with torch.no_grad():
        hidden_states = model(**tokenized_sents, output_hidden_states=True).hidden_states
    layer_sum = torch.stack(hidden_states[-n_layers:], dim=0).sum(dim=0)
    return torch.flatten(layer_sum.mean(dim=0))

In [24]:
# Query abstracts: each entry is one full abstract string that is sentence-split, embedded,
# and compared against the precomputed comparison embeddings further down.
abs_list = [
    r"The shuttling of soluble sodium polysulfides (Na2Sn) and sluggish conversion kinetics are major roadblocks toward the practical realization of sodium–sulfur (Na–S) batteries. To undertake the challenges, we use first-principles calculations to design bifunctional electrocatalysts to achieve engineered interfaces with sulfur-based cathode materials. We illustrate the detailed behavior of Na2Sn adsorption, sulfur reduction reactions (SRRs), and catalytic decomposition on transition-metal (TM)-based single-atom catalysts (SACs) embedded on MoS2 substrates (SACs@MoS2). We observe that SACs doped on sulfur substitution and molybdenum top sites result in adequate binding energies to immobilize higher-order Na2Sn species. We found the d-band center as an important “descriptor” in dictating polysulfide adsorption energies and catalytic activities on SACs@MoS2. We elucidate that the larger upward shift of the d-band center toward the Fermi level and the involved higher number of vacant antibonding states are directly correlated to the adsorption strength of the Na2Sn. The V and Ni SACs are found to exhibit higher and lower binding energies, respectively, consistent with the d-band theory. Furthermore, the SACs that are electron-deficient sites demonstrate bifunctional electrocatalytic activity through reduced free energy for SRR and lower the barrier for Na2S decomposition in favor of accelerated electrode kinetics during discharge and charge processes, respectively. The electronic structure calculations reveal a significantly reduced band gap of the pristine and Na2Sn-adsorbed SACs@MoS2 due to mid-gap states, majorly stemming from TM-d orbitals, thus expected to improve the electronic conductivity of the substrates. The insight developed on the role of SACs in tailoring the polysulfides’ chemistry at the interfaces in relation to their d-band center is an important step toward the rational design of cathode materials for high-performance Na–S batteries."
]

# 1. Load BERT and MLP

In [4]:
# BERT: tokenizer and encoder are loaded from the same MatSciBERT checkpoint
bert_checkpoint = 'm3rg-iitd/matscibert'
tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)
model = AutoModel.from_pretrained(bert_checkpoint)

Some weights of the model checkpoint at m3rg-iitd/matscibert were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at m3rg-iitd/matscibert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
Y

In [5]:
# BERT setting
# Every sentence is padded/truncated to exactly this many tokens (padding='max_length' in
# tokenize_sentences). The flattened embedding from get_sum_hidden_emb therefore scales with
# it, so it presumably must stay 128 to match in_dim of the pretrained MLP below — TODO confirm.
max_length = 128

In [6]:
class MLP(nn.Module):
    """Bias-free feed-forward network with ReLU activations.

    Layout: input -> hidden, (num_layers - 2) x hidden -> hidden, hidden -> output.
    ``forward`` applies dropout before every layer except the output layer.
    ``retrieve_emb`` returns the activations of the last hidden layer, without dropout.

    Args:
        num_layers: total number of linear layers; must be >= 2.
        input_dim: size of the input features.
        hidden_dim: width of every hidden layer (and of ``retrieve_emb``'s output).
        output_dim: size of ``forward``'s output.
        dropout: dropout probability used by ``forward``.

    Raises:
        ValueError: if ``num_layers`` < 2. Previously such values silently built a
            2-layer network anyway.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, dropout=0):
        super().__init__()
        if num_layers < 2:
            raise ValueError(f"num_layers must be >= 2, got {num_layers}")
        # Layer i maps dims[i] -> dims[i + 1]. Registration order (and hence state_dict
        # keys "linears.<i>.weight") matches the original construction, so saved
        # checkpoints still load.
        dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = nn.ModuleList(
            nn.Linear(d_in, d_out, bias=False) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.dropout = nn.Dropout(dropout)

    def _hidden(self, x, use_dropout):
        """Run x through all but the output layer, with ReLU after each (optional dropout before)."""
        h = x
        for layer in self.linears[:-1]:
            if use_dropout:
                h = self.dropout(h)
            h = F.relu(layer(h))
        return h

    def forward(self, x):
        """Full network: hidden stack (with dropout) followed by the linear output layer."""
        return self.linears[-1](self._hidden(x, use_dropout=True))

    def retrieve_emb(self, x):
        """Last hidden-layer activations (post-ReLU, no dropout), used as the projected embedding."""
        return self._hidden(x, use_dropout=False)

In [115]:
# load MLP (hyper-parameters must match the architecture stored in the checkpoint)
n_layer = 4
# Flattened size of get_sum_hidden_emb's output: max_length (128) * BERT hidden size.
# 98304 = 128 * 768, which presumes a 768-wide encoder — TODO confirm against model.config.
in_dim = 98304
out_dim = 6
hidden_dim = 128
mlp_checkpoint = "./mlp_models/080valacc.pt"

mlp = MLP(n_layer, in_dim, hidden_dim, out_dim, dropout=0)
# map_location="cpu" lets a checkpoint saved from CUDA tensors deserialize on a machine
# without a GPU; load_state_dict then copies the weights into the CPU-resident parameters.
mlp.load_state_dict(torch.load(mlp_checkpoint, map_location="cpu"))

<All keys matched successfully>

# 2. Get BERT Embeddings

In [32]:
# Per abstract: sentence split -> tokenize -> pooled BERT embedding.
# Each stage finishes for every abstract before the next stage starts.
split_abs_list = list(map(split_into_sentences, abs_list))
tokenized_list = [
    tokenize_sentences(sentences, tokenizer, max_length)
    for sentences in split_abs_list
]
emb_list = [get_sum_hidden_emb(encoded, model) for encoded in tokenized_list]

# 3. Get MLP Augmented Embeddings

In [116]:
# Project each abstract's BERT vector through the MLP's hidden stack.
# eval() disables dropout (and returns the module itself); no_grad() skips autograd bookkeeping.
with torch.no_grad():
    aug_embs = mlp.eval().retrieve_emb(torch.stack(emb_list, dim=0))

# 4. Top-K

In [34]:
def get_similarities(embs_1, embs_2, eps=1e-8):
    """Cosine similarity between every row of ``embs_1`` and every row of ``embs_2``.

    Args:
        embs_1: 2-D tensor (n_1, d) of query embeddings.
        embs_2: 2-D tensor (n_2, d) of candidate embeddings.
        eps: lower bound applied to the product of row norms. Without it, an all-zero row
            (possible for post-ReLU MLP embeddings) gives 0/0 = NaN, and NaNs make the
            later sort in get_pair_ranks order results arbitrarily. Zero rows now score 0.

    Returns:
        dict mapping each row index i of ``embs_1`` to a list of n_2 Python floats,
        the cosine similarities of that row to every row of ``embs_2``.
    """
    dots = torch.mm(embs_1, embs_2.T)
    norm_1 = torch.norm(embs_1, dim=1)
    norm_2 = torch.norm(embs_2, dim=1)
    norms = torch.mm(norm_1.view(-1, 1), norm_2.view(1, -1)).clamp_min(eps)
    similarities = dots / norms
    return {i: row for i, row in enumerate(similarities.tolist())}

def get_pair_ranks(sim_list):
    """Return the indices of ``sim_list`` ordered from the highest value to the lowest.

    Equal values keep their original relative order: Python's sort stays stable
    even with ``reverse=True``.
    """
    positions = range(len(sim_list))
    return sorted(positions, key=sim_list.__getitem__, reverse=True)

def create_similarities_df_with_doi(compared_df, emb_list_aim, emb_list_compared, K=None):
    """Rank comparison abstracts by cosine similarity to each query embedding.

    Args:
        compared_df: DataFrame with at least 'class', 'title' and 'doi' columns. Rows are
            selected positionally, so row i must describe row i of ``emb_list_compared``.
        emb_list_aim: 2-D tensor of query embeddings, one row per query.
        emb_list_compared: 2-D tensor of comparison embeddings.
        K: keep only the top-K hits per query; None keeps every comparison row.

    Returns:
        DataFrame with columns id, class, title, doi, sim. It holds one block of rows per
        query (``id`` is the query's row index in ``emb_list_aim``), and each block is
        sorted by descending ``sim`` with an index restarting at 0. If there are no
        queries, an empty frame with those columns is returned.
    """
    sim = get_similarities(emb_list_aim, emb_list_compared)

    print("Calculating topk...")
    topk = {}
    for key, sim_row in sim.items():
        ranked = get_pair_ranks(sim_row)
        topk[key] = ranked if K is None else ranked[:K]

    print("Constructing df...")
    df_list = []
    for key, ranked in topk.items():
        # The id column is broadcast from a scalar, so its length always equals the number
        # of hits actually returned. Sizing it from K instead misaligns rows (NaN padding)
        # whenever K exceeds the number of comparison embeddings.
        hits = (
            compared_df[['class', 'title', 'doi']]
            .iloc[ranked]
            .reset_index(drop=True)
            .assign(id=key, sim=[sim[key][i] for i in ranked])
        )
        df_list.append(hits[['id', 'class', 'title', 'doi', 'sim']])

    if not df_list:
        # No query rows: pd.concat([]) would raise, so return an empty frame with the same columns.
        return pd.DataFrame(columns=['id', 'class', 'title', 'doi', 'sim'])
    return pd.concat(df_list)

In [185]:
# load abstracts for comparison: precomputed MLP embeddings, one file per corpus subset
EMB_DIR = './embs'

def _load_aug_emb(name):
    """Load ./embs/mlp_<name>.pt onto the CPU, where aug_embs is computed in this notebook."""
    # map_location keeps every tensor on one device so the torch.mm in get_similarities
    # cannot hit a device mismatch, and lets CUDA-saved files load without a GPU.
    return torch.load(f'{EMB_DIR}/mlp_{name}.pt', map_location='cpu')

aug_nas = _load_aug_emb('nas')
aug_sa = _load_aug_emb('sa')
aug_lis = _load_aug_emb('lis')
aug_lissa = _load_aug_emb('lissa')
aug_ir1 = _load_aug_emb('ir1')
aug_ir2 = _load_aug_emb('ir2')
aug_sup = _load_aug_emb('sup')

# NOTE(review): aug_nassa is loaded but is not part of aug_compare_embs below. Kept in case
# a later cell uses it; confirm whether it should be included or dropped.
aug_nassa = _load_aug_emb('nassa')

# Row order here must match the row order of the metadata CSV, because
# create_similarities_df_with_doi selects DataFrame rows positionally by embedding index.
aug_compare_embs = torch.cat([aug_nas, aug_sa, aug_lis, aug_lissa, aug_ir1, aug_ir2, aug_sup])

In [181]:
# Metadata (class, title, doi, ...) for the comparison abstracts. create_similarities_df_with_doi
# indexes it positionally with row indices of aug_compare_embs, so its row order is assumed to
# match the torch.cat order of the embedding files — TODO confirm.
all_df_with_class = pd.read_csv('./abs_data/all_abs_wo_nasa_v1.csv')

In [186]:
# retrieve top-k: the k most similar comparison abstracts for every query abstract
k = 30
topk_df = create_similarities_df_with_doi(
    compared_df=all_df_with_class,
    emb_list_aim=aug_embs,
    emb_list_compared=aug_compare_embs,
    K=k,
)

Calculating topk...
Constructing df...


In [187]:
# One row per (query id, ranked hit), sorted by descending cosine similarity within each query;
# the index restarts at 0 for each query.
topk_df

Unnamed: 0,id,class,title,doi,sim
0,0,supplement,"Insights on forming N,O-coordinated Cu single-...",10.1038/s41467-020-20769-x,0.999997
1,0,supplement,Fluorine-tuned single-atom catalysts with dens...,10.1016/j.apcatb.2020.119591,0.999997
2,0,supplement,"Fe Isolated Single Atoms on S, N Codoped Carbo...",10.1002/adma.201800588,0.999996
3,0,supplement,Coordination-tuned Fe single-atom catalyst for...,10.1016/j.cej.2021.134270,0.999996
4,0,supplement,Axial coordination regulation of MOF-based sin...,10.1007/s12274-022-4467-3,0.999996
5,0,,Platinum atoms and nanoparticles embedded poro...,10.3390/ma13071513,0.999948
6,0,,Fabricating Single-Atom Catalysts from Chelati...,10.1002/adma.201808193,0.999946
7,0,,Theoretical Screening of Transition Metal-N4-D...,10.1021/acscatal.2c00307,0.999924
8,0,,ZnO monolayer supported single atom catalysts ...,10.1016/j.apsusc.2021.149682,0.999907
9,0,,Dual single-cobalt atom-based carbon electroca...,10.1016/j.apcatb.2021.120092,0.999896
