In [105]:
import polars as pl
import numpy as np
import numpy.typing as npt
import xgboost as xgb
from rdkit.Chem import AllChem
from Bio import Align

import json
from typing import List
import heapq

In [182]:
class Model:
    def __init__(self, ids, models, config, seqs, n_jobs=-1):
        def load_model(p, n_jobs):
            r = xgb.XGBRegressor()
            r.load_model(p)
            r.set_params(n_jobs=n_jobs)
            return r
        
        # uniprot to model
        self.knownModels = {i: load_model(m, n_jobs) for i, m in zip(ids, models)}
        # uniprot to List<similar models>
        self.similarModels = dict()
        self.aligner = Align.PairwiseAligner(scoring='blastp')
        
        with open(seqs, 'r') as f:
            self.knownSeqs = dict(map(lambda l: l.strip().split('\t'), f))
                
        with open(config, 'r') as f:
            self.config = json.load(f)
        
    def predict(self, smiles, uniprot) -> float:
        fp = self.fingerprint(smiles).reshape(1, -1)
        
        if uniprot in self.knownModels:
            return self.knownModels[uniprot].predict(fp)
        else:
            seq = self.fetch_seq(uniprot)
            models = self.get_similar(seq)
            self.similarModels[uniprot] = models
            # get 3? nearest models and average prediction
            return np.mean([m.predict(fp) for m in models])
            
        
    def fingerprint(self, smiles: str) -> npt.NDArray:
        mol = AllChem.MolFromSmiles(smiles)
        c = self.config['fingerprints']
        fp = AllChem.GetMorganFingerprintAsBitVect(
            mol, 
            radius=c['radius'],
            nBits=c['bitSize'],
            useFeatures=c['useFeatures'],
            useChirality=c['useChirality'],
        )
        
        return np.array(fp, dtype=np.uint8)
    
    def fetch_seq(self, uid: str) -> str:
        url = 'https://rest.uniprot.org/uniprotkb/stream'
        params = {
            'query': f'(accession:{uid})',
            'fields': 'accession,sequence',
            'format':'tsv',
        }
        res = requests.get(url, params=params)
        # return dict(map(lambda s: s.split('\t'), res.text.splitlines()[1:]))
        return res.text.splitlines()[1].split('\t')[1]
        
    
    def get_similar(self, target: str, k=3) -> List[xgb.XGBRegressor]:
        zip_seq_model = ((self.knownSeqs[uid], model) for uid, model in self.knownModels.items())
        map_score = lambda query, model: (self.aligner.score(target, query), model)
        it = map(lambda t: map_score(*t), zip_seq_model)
        
        return [res[1] for res in heapq.nlargest(k, it, key=lambda x: x[0])]
        

In [25]:
r = xgb.XGBRegressor()
r.load_model('test/models/00004.ubj')
r.set_params(n_jobs=-1)
r.predict

<bound method XGBModel.predict of XGBRegressor(base_score=None, booster=None, callbacks=None,
             colsample_bylevel=None, colsample_bynode=None,
             colsample_bytree=None, early_stopping_rounds=None,
             enable_categorical=False, eval_metric=None, feature_types=None,
             gamma=None, gpu_id=None, grow_policy=None, importance_type=None,
             interaction_constraints=None, learning_rate=None, max_bin=None,
             max_cat_threshold=None, max_cat_to_onehot=None,
             max_delta_step=None, max_depth=2, max_leaves=14,
             min_child_weight=None, missing=nan, monotone_constraints=None,
             n_estimators=33, n_jobs=-1, num_parallel_tree=None, predictor=None,
             random_state=None, ...)>

In [86]:
import urllib
import requests

url = 'https://rest.uniprot.org/uniprotkb/stream'
params = {
    'query': '(accession:P13368) OR (accession:P20806)',
    'fields': 'accession,sequence',
    'format':'tsv',
#     'query':'P13368 P20806 Q9UM73 P97793 Q17192',
#     'columns': 'id,sequence'
}

res = requests.get(url, params=params)
# data = urllib.parse.urlencode(params)
# data = data.encode('ascii')
# request = urllib.request.Request(url, data)
# with urllib.request.urlopen(request) as res:
#     print(res)
#     res = res.read()
# data
res.text

'Entry\tSequence\nP13368\tMTMFWQQNVDHQSDEQDKQAKGAAPTKRLNISFNVKIAVNVNTKMTTTHINQQAPGTSSSSSNSQNASPSKIVVRQQSSSFDLRQQLARLGRQLASGQDGHGGISTILIINLLLLILLSICCDVCRSHNYTVHQSPEPVSKDQMRLLRPKLDSDVVEKVAIWHKHAAAAPPSIVEGIAISSRPQSTMAHHPDDRDRDRDPSEEQHGVDERMVLERVTRDCVQRCIVEEDLFLDEFGIQCEKADNGEKCYKTRCTKGCAQWYRALKELESCQEACLSLQFYPYDMPCIGACEMAQRDYWHLQRLAISHLVERTQPQLERAPRADGQSTPLTIRWAMHFPEHYLASRPFNIQYQFVDHHGEELDLEQEDQDASGETGSSAWFNLADYDCDEYYVCEILEALIPYTQYRFRFELPFGENRDEVLYSPATPAYQTPPEGAPISAPVIEHLMGLDDSHLAVHWHPGRFTNGPIEGYRLRLSSSEGNATSEQLVPAGRGSYIFSQLQAGTNYTLALSMINKQGEGPVAKGFVQTHSARNEKPAKDLTESVLLVGRRAVMWQSLEPAGENSMIYQSQEELADIAWSKREQQLWLLNVHGELRSLKFESGQMVSPAQQLKLDLGNISSGRWVPRRLSFDWLHHRLYFAMESPERNQSSFQIISTDLLGESAQKVGESFDLPVEQLEVDALNGWIFWRNEESLWRQDLHGRMIHRLLRIRQPGWFLVQPQHFIIHLMLPQEGKFLEISYDGGFKHPLPLPPPSNGAGNGPASSHWQSFALLGRSLLLPDSGQLILVEQQGQAASPSASWPLKNLPDCWAVILLVPESQPLTSAGGKPHSLKALLGAQAAKISWKEPERNPYQSADAARSWSYELEVLDVASQSAFSIRNIRGPIFGLQRLQPDNLYQLRVRAINVDGEPGEWTEPLAARTWPLGPHRLRWASRQGSVIHTNELGEGLEVQQEQLERLPGPMTMVNESVGYYVT

In [87]:
new_seqs = dict(map(lambda s: s.split('\t'), res.text.splitlines()[1:]))

In [95]:
aligner = Align.PairwiseAligner(scoring='blastp')

In [103]:
target, query = list(new_seqs.values())
aligner.score(target, query)

7767.0

In [94]:
target

'MTMFWQQNVDHQSDEQDKQAKGAAPTKRLNISFNVKIAVNVNTKMTTTHINQQAPGTSSSSSNSQNASPSKIVVRQQSSSFDLRQQLARLGRQLASGQDGHGGISTILIINLLLLILLSICCDVCRSHNYTVHQSPEPVSKDQMRLLRPKLDSDVVEKVAIWHKHAAAAPPSIVEGIAISSRPQSTMAHHPDDRDRDRDPSEEQHGVDERMVLERVTRDCVQRCIVEEDLFLDEFGIQCEKADNGEKCYKTRCTKGCAQWYRALKELESCQEACLSLQFYPYDMPCIGACEMAQRDYWHLQRLAISHLVERTQPQLERAPRADGQSTPLTIRWAMHFPEHYLASRPFNIQYQFVDHHGEELDLEQEDQDASGETGSSAWFNLADYDCDEYYVCEILEALIPYTQYRFRFELPFGENRDEVLYSPATPAYQTPPEGAPISAPVIEHLMGLDDSHLAVHWHPGRFTNGPIEGYRLRLSSSEGNATSEQLVPAGRGSYIFSQLQAGTNYTLALSMINKQGEGPVAKGFVQTHSARNEKPAKDLTESVLLVGRRAVMWQSLEPAGENSMIYQSQEELADIAWSKREQQLWLLNVHGELRSLKFESGQMVSPAQQLKLDLGNISSGRWVPRRLSFDWLHHRLYFAMESPERNQSSFQIISTDLLGESAQKVGESFDLPVEQLEVDALNGWIFWRNEESLWRQDLHGRMIHRLLRIRQPGWFLVQPQHFIIHLMLPQEGKFLEISYDGGFKHPLPLPPPSNGAGNGPASSHWQSFALLGRSLLLPDSGQLILVEQQGQAASPSASWPLKNLPDCWAVILLVPESQPLTSAGGKPHSLKALLGAQAAKISWKEPERNPYQSADAARSWSYELEVLDVASQSAFSIRNIRGPIFGLQRLQPDNLYQLRVRAINVDGEPGEWTEPLAARTWPLGPHRLRWASRQGSVIHTNELGEGLEVQQEQLERLPGPMTMVNESVGYYVTGDGLLHCINLVHSQWGCPISEPLQH

In [112]:
dfModels = pl.scan_csv('test/metrics/*.csv').collect()

dfModels['uniprot'], dfModels['output_model']

(shape: (5,)
 Series: 'uniprot' [str]
 [
 	"Q9HBH9"
 	"Q06418"
 	"Q9C098"
 	"Q9P1W9"
 	"Q05655"
 ],
 shape: (5,)
 Series: 'output_model' [str]
 [
 	"test/models/00…
 	"test/models/00…
 	"test/models/00…
 	"test/models/00…
 	"test/models/00…
 ])

In [183]:
m = Model(dfModels['uniprot'], dfModels['output_model'], 'configs/hu-b2048-r2-kikd.json', 'data/map-uniprot-seq.tsv')

In [184]:
# target = m.fetch_seq('Q2M2I8')
# m.fingerprint('CS(=O)(=O)Nc1cccc(c1)-c1ccc2c(NC(=O)C3CC3)n[nH]c2c1')
# zip_seq_model = ((m.knownSeqs[uid], model) for uid, model in m.knownModels.items())
# map_score = lambda query, model: (m.aligner.score(target, query), model)

In [185]:
m.predict('CS(=O)(=O)Nc1cccc(c1)-c1ccc2c(NC(=O)C3CC3)n[nH]c2c1', 'Q2M2I8')

5.902614

In [160]:
list(map(lambda t: map_score(*t), zip_seq_model))

[(-388.0,
  XGBRegressor(base_score=None, booster=None, callbacks=None,
               colsample_bylevel=None, colsample_bynode=None,
               colsample_bytree=None, early_stopping_rounds=None,
               enable_categorical=False, eval_metric=None, feature_types=None,
               gamma=None, gpu_id=None, grow_policy=None, importance_type=None,
               interaction_constraints=None, learning_rate=None, max_bin=None,
               max_cat_threshold=None, max_cat_to_onehot=None,
               max_delta_step=None, max_depth=2, max_leaves=14,
               min_child_weight=None, missing=nan, monotone_constraints=None,
               n_estimators=33, n_jobs=-1, num_parallel_tree=None, predictor=None,
               random_state=None, ...)),
 (-221.0,
  XGBRegressor(base_score=None, booster=None, callbacks=None,
               colsample_bylevel=None, colsample_bynode=None,
               colsample_bytree=None, early_stopping_rounds=None,
               enable_categorical

In [186]:
import pickle

with open('test.pkl', 'wb') as f:
    pickle.dump(m, f)