In [1]:
import argparse 
import sys
sys.path.append('./grover/')
import grover.util.parsing
import grover.model.models
import task.train

## Build the GROVER command-line parser and parse a *finetune* configuration
# `grover.util.utils.load_checkpoint` is called at the end of this cell. Import the
# submodule explicitly instead of relying on another grover module having imported
# it as a side effect (attribute access on a package only works for loaded submodules).
import grover.util.utils

parser = argparse.ArgumentParser()
subparser = parser.add_subparsers(title="subcommands",
                                    dest="parser_name",
                                    help="Subcommands for finetune, prediction, and fingerprint.")

# One sub-parser per GROVER subcommand; each gets its arguments registered by the
# matching helper in grover.util.parsing. Note that 'eval' reuses the finetune arguments.
parser_finetune = subparser.add_parser('finetune', help="Fine tune the pre-trained model.")
grover.util.parsing.add_finetune_args(parser_finetune)
parser_eval = subparser.add_parser('eval', help="Evaluate the results of the pre-trained model.")
grover.util.parsing.add_finetune_args(parser_eval)
parser_predict = subparser.add_parser('predict', help="Predict results from fine tuned model.")
grover.util.parsing.add_predict_args(parser_predict)
parser_fp = subparser.add_parser('fingerprint', help="Get the fingerprints of SMILES.")
grover.util.parsing.add_fingerprint_args(parser_fp)
parser_pretrain = subparser.add_parser('pretrain', help="Pretrain with unlabelled SMILES.")
grover.util.parsing.add_pretrain_args(parser_pretrain)

# The backslash-newlines continue the string literal across lines; `.split()` then
# tokenizes on any whitespace, so the indentation inside the literal is harmless.
grover_args = parser.parse_args("finetune --data_path grover/exampledata/finetune/bbbp.csv \
                        --features_path grover/exampledata/finetune/bbbp.npz \
                        --save_dir grover/model/finetune/bbbp/ \
                        --checkpoint_path grover/model/tryout/model.ep3 \
                        --dataset_type classification \
                        --split_type scaffold_balanced \
                        --ensemble_size 1 \
                        --num_folds 3 \
                        --no_features_scaling \
                        --ffn_hidden_size 200 \
                        --batch_size 32 \
                        --epochs 10 \
                        --init_lr 0.00015 \
                        --no_cuda".split())

# The return value is not used, so this presumably post-processes `grover_args`
# in place -- TODO confirm against grover.util.parsing.
grover.util.parsing.modify_train_args(grover_args)
# train_args = grover.util.parsing.get_newest_train_args()

# `print` is passed as the second positional argument and None as the third;
# the cell output below shows the loader's progress messages.
features_scaler, scaler, shared_dict, test_data, train_data, val_data = task.train.load_data(grover_args, print, None)

# Load the pretrained GROVER checkpoint using the parsed arguments.
grover_model = grover.util.utils.load_checkpoint("./grover/grover_large.pt", current_args=grover_args, logger=None)



  from .autonotebook import tqdm as notebook_tqdm


Loading data




Number of tasks = 1
Splitting data with seed 0


100%|##########| 2039/2039 [00:00<00:00, 3968.77it/s]


Class sizes
p_np 0: 23.49%, 1: 76.51%
Total size = 2,039 | train size = 1,631 | val size = 203 | test size = 205
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_q.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_q.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_k.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_k.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_v.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_v.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_q.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_q.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_k.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_k.W_h.w

In [2]:
# import grover.data
# import torch

# grover_linear_layer = torch.nn.Linear(1200, 768)
# grover_layer = grover.model.layers.GTransEncoder(args=grover_args,
#                                           hidden_size=768,
#                                           edge_fdim=768,
#                                           node_fdim=768,
#                                           dropout=grover_args.dropout,
#                                           activation=grover_args.activation,
#                                           num_mt_block=1,
#                                           num_attn_head=grover_args.num_attn_head,
#                                           atom_emb_output="atom",
#                                           bias=grover_args.bias,
#                                           cuda=grover_args.cuda)

# mol_collator = grover.data.MolCollator(shared_dict=shared_dict, args=grover_args)

# num_workers = 4
# mol_loader = torch.utils.data.DataLoader(train_data, batch_size=grover_args.batch_size, shuffle=True,
#                     num_workers=num_workers, collate_fn=mol_collator)

# grover_model.train()
# for item in mol_loader:
#     _, batch, features_batch, mask, targets = item
#     f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch
#     # if next(model.parameters()).is_cuda:
#     #     mask, targets = mask.cuda(), targets.cuda()

#     # Run model
#     grover_model.zero_grad()
    
#     output = grover_model.grover((f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a))

#     f_atoms, f_bonds = output['atom_from_atom'], output['bond_from_atom']
#     f_atoms = grover_linear_layer(f_atoms)
#     f_bonds = grover_linear_layer(f_bonds)
#     print(f_atoms.shape, f_bonds.shape)
    
#     output = grover_layer((f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a))
#     f_atoms = output[0]
#     print(f_atoms.shape, f_bonds.shape)

#     output = grover_layer((f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a))
#     f_atoms = output[0]
#     print(f_atoms.shape, f_bonds.shape)

#     break

In [3]:
# import transformers
# import torch

# bert_path = "dmis-lab/biobert-base-cased-v1.2"
# tokenizer = transformers.AutoTokenizer.from_pretrained(bert_path)
# bert_model = transformers.AutoModel.from_pretrained(bert_path)

# additional_layer = transformers.models.bert.BertLayer(config=bert_model.config)

# text = ["I love drug summary", "I love criteria"]

# encoded_input = tokenizer(text=text, return_tensors='pt', truncation=True, max_length=512, padding=True)

# model_output  = bert_model(
#     input_ids=encoded_input['input_ids'], 
#     token_type_ids=encoded_input['token_type_ids'], 
#     attention_mask=encoded_input['attention_mask'])

# print(model_output.keys())
# last_hidden_state = model_output['last_hidden_state']
# extended_attention_mask = bert_model.get_extended_attention_mask(
#     attention_mask=encoded_input['attention_mask'], 
#     input_shape=encoded_input['input_ids'].shape)

# output = additional_layer(last_hidden_state, extended_attention_mask)
# print(len(output), output[0].shape)

# output = additional_layer(output[0], extended_attention_mask)
# print(len(output), output[0].shape)

# # model_output

In [4]:
import torch
import transformers

class CrossEncoderLayer(torch.nn.Module):
    """One fusion block that mixes text-token states with molecule atom states.

    The block runs three sub-layers in sequence:

    1. a single ``torch.nn.TransformerEncoder`` layer over the concatenation of
       the text tokens and the atom features (joint self-attention),
    2. a ``BertLayer`` applied to the text part only,
    3. a GROVER ``GTransEncoder`` applied to the molecular graph with the
       updated atom features.

    Args:
        grover_args: parsed GROVER argument namespace; ``dropout``,
            ``activation``, ``num_attn_head``, ``bias`` and ``cuda`` are read
            from it and it is also forwarded to ``GTransEncoder``.
        bert_config: configuration object used to build the ``BertLayer``.
        hidden_dim: feature width shared by the joint transformer and the
            GROVER encoder (node and edge feature dims).
        nhead: number of attention heads in the joint transformer layer.
    """

    def __init__(self, grover_args, bert_config, hidden_dim=768, nhead=8) -> None:
        super().__init__()

        # Joint self-attention over the [text tokens ; atoms] sequence.
        self.transformer_layer = torch.nn.TransformerEncoder(
            encoder_layer=torch.nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, batch_first=True),
            num_layers=1)

        # Graph-side update: a single GROVER transformer block.
        self.grover_layer = grover.model.layers.GTransEncoder(
            args=grover_args,
            hidden_size=hidden_dim,
            edge_fdim=hidden_dim,
            node_fdim=hidden_dim,
            dropout=grover_args.dropout,
            activation=grover_args.activation,
            num_mt_block=1,
            num_attn_head=grover_args.num_attn_head,
            atom_emb_output="atom",
            bias=grover_args.bias,
            cuda=grover_args.cuda)

        # Text-side update: one BERT encoder layer.
        self.bert_layer = transformers.models.bert.BertLayer(config=bert_config)

    def forward(self, new_text_batch, new_grover_batch):
        """Run one round of joint attention followed by per-modality updates.

        Args:
            new_text_batch: tuple ``(last_hidden_state, extended_attention_mask)``.
                ``last_hidden_state`` is expected to be ``(1, seq_len, hidden_dim)``;
                the concatenation below only works when its batch dimension
                matches the single pseudo-batch created for the atoms.
            new_grover_batch: 8-tuple ``(f_atoms, f_bonds, a2b, b2a, b2revb,
                a_scope, b_scope, a2a)``; ``f_atoms`` is expected to be
                ``(n_atoms, hidden_dim)``.

        Returns:
            A pair ``(text_batch, grover_batch)`` with the same layout as the
            inputs, where ``last_hidden_state`` and ``f_atoms`` are replaced by
            their updated versions and every other element is passed through.
        """
        last_hidden_state, extended_attention_mask = new_text_batch
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = new_grover_batch

        # Remember where the text part ends so the joint output can be split again.
        len_split = last_hidden_state.shape[1]
        ## last_hidden_state = (1, seq_len, 768)
        ## f_atoms = (n_atoms, 768)
        joint_input = torch.cat([last_hidden_state, f_atoms.unsqueeze(dim=0)], dim=1)
        trans_out = self.transformer_layer(joint_input)
        last_hidden_state = trans_out[:, :len_split]
        # Remove only the batch axis added by `unsqueeze(dim=0)` above. A bare
        # `.squeeze()` would also drop the atom axis when there is a single row,
        # handing a 1-D tensor to the GROVER layer.
        f_atoms = trans_out[:, len_split:].squeeze(dim=0)

        # Per-modality refinement; element 0 of each sub-layer output is used
        # as the updated hidden states / atom features.
        last_hidden_state = self.bert_layer(last_hidden_state, extended_attention_mask)[0]
        grover_out = self.grover_layer((f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a))
        f_atoms = grover_out[0]

        return (last_hidden_state, extended_attention_mask), (f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a)

class CrossMol(torch.nn.Module):
    """Text/molecule fusion model built on a GROVER checkpoint and a BERT encoder.

    The molecule is first encoded by the pretrained GROVER model and projected
    to the BERT hidden width; the text is encoded by the pretrained BERT model.
    Both representations are then passed through ``num_ca_layers`` stacked
    ``CrossEncoderLayer`` blocks.

    Args:
        grover_args: parsed GROVER argument namespace, used when loading the
            checkpoint and when building each ``CrossEncoderLayer``.
        grover_path: path of the GROVER checkpoint to load.
        bert_path: Hugging Face model name or path for the tokenizer and encoder.
        num_ca_layers: number of stacked ``CrossEncoderLayer`` blocks.
    """

    def __init__(self, grover_args, grover_path='./grover/grover_large.pt',
                 bert_path='dmis-lab/biobert-base-cased-v1.2', num_ca_layers=2):
        super().__init__()
        self.num_ca_layers = num_ca_layers
        self.grover_args = grover_args

        self.grover_model = grover.util.utils.load_checkpoint(path=grover_path, current_args=grover_args, logger=None)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path=bert_path)
        self.bert_model = transformers.AutoModel.from_pretrained(pretrained_model_name_or_path=bert_path)

        # Take the fusion width from the loaded text encoder instead of hard-coding
        # 768, so the projection and the cross layers always match the text states
        # (768 for the default BioBERT base model).
        hidden_dim = self.bert_model.config.hidden_size
        # 1200 is the input width of the GROVER atom/bond outputs -- presumably
        # the hidden size of the "large" checkpoint; TODO confirm for other checkpoints.
        self.grover_linear_layer = torch.nn.Linear(1200, hidden_dim)

        self.cross_encoder_layers = torch.nn.ModuleList([
            CrossEncoderLayer(grover_args=grover_args, bert_config=self.bert_model.config, hidden_dim=hidden_dim)
            for _ in range(self.num_ca_layers)
        ])

    def forward(self, text_batch, grover_batch):
        """Encode text and molecule, then fuse them through the cross layers.

        Args:
            text_batch: raw text (string or list of strings) handed to the tokenizer.
            grover_batch: 8-tuple ``(f_atoms, f_bonds, a2b, b2a, b2revb,
                a_scope, b_scope, a2a)`` in the layout consumed by GROVER.

        Returns:
            ``(new_text_batch, new_grover_batch)``, where ``new_text_batch`` is
            ``(last_hidden_state, extended_attention_mask)`` and
            ``new_grover_batch`` has the same 8-tuple layout as ``grover_batch``
            with ``f_atoms`` / ``f_bonds`` replaced by projected, fused features.
        """
        # ========== Initial pass through grover molecule encoder ==========
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = grover_batch
        grover_output = self.grover_model.grover((f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a))
        # The same linear projection is shared by atom and bond features.
        f_atoms = self.grover_linear_layer(grover_output['atom_from_atom'])
        f_bonds = self.grover_linear_layer(grover_output['bond_from_atom'])
        new_grover_batch = (f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a)

        # ========== Initial pass through bert text encoder ==========
        encoded_input = self.tokenizer(text=text_batch, return_tensors='pt', truncation=True, max_length=512, padding=True)
        # The tokenizer always produces CPU tensors; move them to wherever the
        # text encoder lives so the model also works after `.to(device)` / `.cuda()`.
        encoded_input = encoded_input.to(self.bert_model.device)
        bert_output = self.bert_model(
            input_ids=encoded_input['input_ids'],
            token_type_ids=encoded_input['token_type_ids'],
            attention_mask=encoded_input['attention_mask'])
        last_hidden_state = bert_output['last_hidden_state']
        # The stand-alone BertLayer in each cross block needs the broadcastable
        # additive mask rather than the raw 0/1 attention mask.
        extended_attention_mask = self.bert_model.get_extended_attention_mask(
            attention_mask=encoded_input['attention_mask'],
            input_shape=encoded_input['input_ids'].shape)
        new_text_batch = (last_hidden_state, extended_attention_mask)

        # ========== Cross Attention Encoder Layers ==========
        for layer in self.cross_encoder_layers:
            new_text_batch, new_grover_batch = layer(new_text_batch, new_grover_batch)

        return new_text_batch, new_grover_batch


# Build the fusion model with the default GROVER checkpoint and BioBERT paths;
# construction loads both pretrained models (see the loading messages below).
crossmol = CrossMol(grover_args=grover_args)

Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_q.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_q.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_k.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_k.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_v.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.0.mpn_v.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_q.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_q.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_k.act_func.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_k.W_h.weight".
Loading pretrained parameter "grover.encoders.edge_blocks.0.heads.1.mpn_v.act_func.weight".
Loading pretr

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.2 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [35]:
### ========== retrieval metrics ==========
import numpy as np
import sklearn.metrics

def get_ranks(embedding1, embedding2):
    """Return the 1-based retrieval rank of the matching item for each query.

    Row ``k`` of ``embedding1`` is treated as a query whose correct match is
    row ``k`` of ``embedding2`` (i.e. the ground truth lies on the diagonal of
    the similarity matrix). Candidates are ordered by cosine similarity,
    highest first.

    Args:
        embedding1: 2-D array of query embeddings.
        embedding2: 2-D array of candidate embeddings.

    Returns:
        ``np.ndarray`` with one rank per query row (1 = best).
    """
    # similarity[i, j] = cosine similarity between query i and candidate j
    similarity = sklearn.metrics.pairwise.cosine_similarity(embedding1, embedding2)

    true_ranks = []
    for query_idx, scores in enumerate(similarity):
        # Candidate indices for this query row, most similar first.
        best_first = np.argsort(scores)[::-1]
        # Position of the true candidate in that ordering (0-based).
        position = np.flatnonzero(best_first == query_idx)[0]
        true_ranks.append(position + 1)
    return np.array(true_ranks)

def print_ranks(ranks):
    """Print summary retrieval metrics for a collection of 1-based ranks.

    Prints the mean rank (two decimals), the fraction of queries whose rank is
    within each cutoff (hits@1/10/20/100/500/1000), and the mean reciprocal rank.

    Args:
        ranks: array-like of positive ranks (1 = correct item retrieved first).
    """
    # Accept plain sequences as well as arrays; the comparisons and the
    # reciprocal below rely on element-wise NumPy semantics.
    ranks = np.asarray(ranks)

    # The original passed the template and the value as two separate print
    # arguments, so the literal "{:.2}" was printed; actually format the value.
    print("Mean rank: {:.2f}".format(np.mean(ranks)))
    for cutoff in (1, 10, 20, 100, 500, 1000):
        print("Hits at {}:".format(cutoff), np.mean(ranks <= cutoff))
    print("MRR:", np.mean(1/ranks))

# np.random.seed(seed=0)
# e2 = np.random.random(size=(1000,768))
# e3 = np.random.random(size=(1000,768))
# # for i in range(len(e2)):
# #     e2[i] += e3[i]*.1 # correlate text and mol embeddings

# test_output = np.load('test_output_no_ca.npy', allow_pickle=True)
# Load the saved test-set outputs. NOTE: allow_pickle=True unpickles object
# arrays on load, so only use this with files you produced yourself.
test_output = np.load('test_output.npy', allow_pickle=True)
print(test_output.shape)

# Each of the four columns holds a sequence of arrays; join every column into one array.
e1_mol, e1_text, e2_text, e3_mol = (
    np.concatenate(test_output[:, column]) for column in range(4)
)

# Evaluate both retrieval directions with the same metric code:
#   1) retrieve mol from text
#   2) retrieve text from mol
for queries, candidates in ((e2_text, e3_mol), (e3_mol, e2_text)):
    ranks_tmp = get_ranks(queries, candidates)
    print_ranks(ranks_tmp)



(413, 4)
Mean rank: {:.2} 61.95728567100878
Hits at 1: 0.07028173280823993
Hits at 10: 0.32747652226598
Hits at 20: 0.46834292638594366
Hits at 100: 0.831869130566495
Hits at 500: 0.9890942138745834
Hits at 1000: 0.9954559224477431
MRR: 0.15397264884356102
Mean rank: {:.2} 58.93547409875795
Hits at 1: 0.08239927294759164
Hits at 10: 0.36746440472584063
Hits at 20: 0.5004544077552256
Hits at 100: 0.8433807936988791
Hits at 500: 0.9900030293850348
Hits at 1000: 0.9963647379581945
MRR: 0.1735390775481024


In [42]:
import numpy as np

# Toggle which saved link-prediction run (with / without cross attention files) is evaluated.
use_ca = False
u_list = np.load('link_pred_ca={}_u_list.npy'.format(use_ca))
v_list = np.load('link_pred_ca={}_v_list.npy'.format(use_ca))
# NOTE(review): labels are always read from the ca=True file, independent of
# `use_ca` -- presumably both runs share the same pair order; confirm.
labels = np.load('link_pred_ca=True_labels.npy')

# Keep only the positive pairs, then cap at the first 1000 of them so the
# rank computation runs on a 1000 x 1000 similarity matrix.
is_positive = labels == 1
u_list = u_list[is_positive][:1000]
v_list = v_list[is_positive][:1000]
print(u_list.shape, v_list.shape, labels.shape)

# Row k of u_list is scored against every row of v_list; the true match is row k.
ranks_tmp = get_ranks(u_list, v_list)
print_ranks(ranks_tmp)


(1000, 1200) (1000, 1200) (230332,)
Mean rank: {:.2} 473.425
Hits at 1: 0.001
Hits at 10: 0.012
Hits at 20: 0.03
Hits at 100: 0.143
Hits at 500: 0.535
Hits at 1000: 1.0
MRR: 0.008686032770797715
