In [1]:
import os
# Switch the working directory so the relative paths used by later cells
# ('./ESPF/...', 'train.csv') resolve against the project folder.
# NOTE(review): the path looks like a Colab Google Drive mount point -- confirm
# that Drive is mounted before this cell runs.
os.chdir('/content/drive/MyDrive/mini')

In [2]:
!pip install subword-nmt



In [3]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import json

from sklearn.preprocessing import OneHotEncoder

from subword_nmt.apply_bpe import BPE #서브워드 분할 함수,
import codecs

# --- Protein sub-word resources -------------------------------------------------
vocab_path = './ESPF/protein_codes_uniprot.txt'
# BPE loads the merge rules inside its constructor, so the codes file can be
# closed as soon as the object is built. The encoding is pinned to UTF-8 instead
# of relying on the platform default (cp1252 on Windows, UTF-8 on Linux).
with codecs.open(vocab_path, encoding='utf-8') as bpe_codes_protein:
    # No BPE training happens here; merges=-1 applies every merge rule in the file.
    # separator='' appends no continuation marker to the sub-words; process_line()
    # still returns them space-delimited, so a plain .split() tokenises the result.
    pbpe = BPE(bpe_codes_protein, merges=-1, separator='')
sub_csv = pd.read_csv('./ESPF/subword_units_map_uniprot.csv')

idx2word_p = sub_csv['index'].values  # protein sub-sequences ("words"), position == index
words2idx_p = dict(zip(idx2word_p, range(0, len(idx2word_p))))  # sub-word -> integer index

# --- Drug sub-word resources ----------------------------------------------------
vocab_path = './ESPF/drug_codes_chembl.txt'
with codecs.open(vocab_path, encoding='utf-8') as bpe_codes_drug:
    dbpe = BPE(bpe_codes_drug, merges=-1, separator='')
sub_csv = pd.read_csv('./ESPF/subword_units_map_chembl.csv')

idx2word_d = sub_csv['index'].values  # drug sub-structures, position == index
words2idx_d = dict(zip(idx2word_d, range(0, len(idx2word_d))))  # sub-word -> integer index

# NOTE(review): drug2emb_encoder pads/truncates to 50 tokens, not to this
# module-level max_d; confirm which value the model configuration expects.
max_d = 205
max_p = 545


def _subwords_to_padded_ids(tokens, vocab, max_len):
    """Turn sub-word tokens into a fixed-length index vector plus its mask.

    Parameters
    ----------
    tokens : list of str
        Sub-word tokens of one sequence.
    vocab : mapping
        Sub-word -> integer index.
    max_len : int
        Output length; shorter sequences are right-padded with 0, longer ones
        are truncated.

    Returns
    -------
    (ids, mask) : tuple of 1-D integer ``numpy.ndarray`` of length ``max_len``
        ``mask`` is 1 on positions holding a real token and 0 on padding.
    """
    try:
        # dtype=int keeps an empty token list integer-typed (the NumPy default
        # for an empty list would be float64).
        ids = np.asarray([vocab[tok] for tok in tokens], dtype=int)
    except KeyError:
        # A single out-of-vocabulary token collapses the whole sequence to the
        # one-element fallback [0]; only the lookup failure is caught so that
        # unrelated errors are no longer silently swallowed.
        ids = np.array([0])

    n_real = min(len(ids), max_len)

    padded = np.zeros(max_len, dtype=ids.dtype)
    padded[:n_real] = ids[:n_real]

    # Distinguishes real tokens from padding.
    mask = np.zeros(max_len, dtype=int)
    mask[:n_real] = 1

    return padded, mask


def protein2emb_encoder(x, max_p=545):
    """Encode a protein sequence string as padded sub-word indices.

    ``x`` is segmented with the module-level ``pbpe`` and looked up in
    ``words2idx_p``. ``max_p`` (default 545) is the fixed output length.
    Returns ``(indices, input_mask)``, both arrays of length ``max_p``.
    """
    tokens = pbpe.process_line(x).split()
    return _subwords_to_padded_ids(tokens, words2idx_p, max_p)


def drug2emb_encoder(x, max_d=50):
    """Encode a SMILES string as padded sub-word indices.

    ``x`` is segmented with the module-level ``dbpe`` and looked up in
    ``words2idx_d``. ``max_d`` (default 50) is the fixed output length; the
    parameter shadows the module-level ``max_d``, which is not used here.
    Returns ``(indices, input_mask)``, both arrays of length ``max_d``.
    """
    tokens = dbpe.process_line(x).split()
    return _subwords_to_padded_ids(tokens, words2idx_d, max_d)


class BIN_Data_Encoder(Dataset):
    """Dataset that yields one encoded drug/protein pair together with its label.

    Parameters
    ----------
    list_IDs : sequence
        Sample ids. Each id is used both as a positional row number into
        ``df_dti`` (via ``iloc``) and as a key into ``labels`` (via ``[]``).
    labels : indexable
        Labels, indexed by the ids in ``list_IDs``.
    df_dti : pandas.DataFrame
        Frame with at least the columns ``'SMILES'`` and ``'Target Sequence'``.
    """

    def __init__(self, list_IDs, labels, df_dti):
        """Keep references to the ids, the labels and the interaction frame."""
        self.labels = labels
        self.list_IDs = list_IDs
        self.df = df_dti

    def __len__(self):
        """Return the number of samples, i.e. the number of ids."""
        return len(self.list_IDs)

    def __getitem__(self, index):
        """Encode and return the sample at position ``index`` of ``list_IDs``.

        Returns the tuple ``(d_v, p_v, input_mask_d, input_mask_p, y)``: the
        drug encoding, the protein encoding, their two masks, and the label.
        """
        sample_id = self.list_IDs[index]

        # Materialise the row once instead of once per column.
        row = self.df.iloc[sample_id]

        d_v, input_mask_d = drug2emb_encoder(row['SMILES'])
        p_v, input_mask_p = protein2emb_encoder(row['Target Sequence'])

        y = self.labels[sample_id]
        return d_v, p_v, input_mask_d, input_mask_p, y

In [4]:
# Load the training split (path is relative to the working directory set in the
# first cell); later cells read its 'SMILES' and 'Target Sequence' columns.
dftrain=pd.read_csv('train.csv')

In [11]:
# Display the training frame.
dftrain

Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,DrugBank ID,Gene,Label,SMILES,Target Sequence
0,0,3,4,DB08533,P49862,0.0,CC1=CN=C2N1C=CN=C2NCC1=CC=NC=C1,MARSLLLPLQILLLSLALETAGEEAQGDKIIDGAPCARGSHPWQVA...
1,1,4,5,DB00755,P48443,1.0,C\C(\C=C\C1=C(C)CCCC1(C)C)=C/C=C/C(/C)=C/C(O)=O,MYGNYSHFMKFPAGYGGSPGHTGSTSMSPSAALSTGKPMDSHPSYT...
2,2,5,6,DB00361,O60218,0.0,[H][C@@]12N(C)C3=CC(OC)=C(C=C3[C@@]11CCN3CC=C[...,MATFVELSTKAKMPIVGLGTWKSPLGKVKEAVKVAIDAGYRHIDCA...
3,3,7,8,DB01136,P08588,1.0,COC1=CC=CC=C1OCCNCC(O)COC1=CC=CC2=C1C1=CC=CC=C1N2,MGAGVLVLGASEPGNLSSAAPLPDGAATAARLLVPASPPASLLPPA...
4,4,8,9,DB06963,Q9Y691,0.0,[H][C@@](C)(NC1=CC2=C(C=N1)C(C)=NN2C1=CC=CC(CC...,MFIWTSGRTSSSYRHDEKRNIYQKIRDHDLLDKRKTVTALKAGEDR...
...,...,...,...,...,...,...,...,...
19233,19233,27549,27580,DB00339,P08684,1.0,NC(=O)C1=NC=CN=C1,MALIPDLAMETWLLLAVSLVLLYLYGTHSHGLFKKLGIPGPTPLPF...
19234,19234,27550,27581,DB07321,P09467,1.0,COC1=CC2=C(OC(NS(=O)(=O)C3=CC(Cl)=CC=C3Cl)=N2)...,MADQAPFDTDVNTLTRFVMEEGRKARGTGELTQLLNSLCTAVKAIS...
19235,19235,27552,27583,DB00543,P35348,1.0,ClC1=CC2=C(OC3=CC=CC=C3N=C2N2CCNCC2)C=C1,MVFLSGNASDSSNCTQPPAPVNISKAILLGVILGGLILFGVLGNIL...
19236,19236,27553,27584,DB07328,P27487,1.0,[H][C@]1(CNC(=O)NC2=CC=C(C=C2)C(=O)OC)CC[C@]([...,MKTPWKVLLGLLGAAALVTIITVPVVLLNKGTDDATADSRKTYTLT...


In [17]:
# Mean SMILES length in characters over the training frame. The row count is
# taken from the frame itself instead of a hard-coded 19237, and the lengths are
# computed column-wise rather than through label-based dftrain['SMILES'][i].
length = int(dftrain['SMILES'].str.len().sum())

print(length / len(dftrain))

57.37557831262671


In [18]:
# Mean protein sequence length in characters over the training frame. The row
# count is taken from the frame itself instead of a hard-coded 19237, and the
# lengths are computed column-wise rather than through a per-row label lookup.
length = int(dftrain['Target Sequence'].str.len().sum())

print(length / len(dftrain))

551.0095649009721


In [8]:
# `x` was not defined anywhere in the notebook (leftover kernel state). The
# recorded segmentation is that of the first training protein, so bind it
# explicitly to keep the cell runnable on a fresh kernel.
x = dftrain['Target Sequence'].iloc[0]
t1 = pbpe.process_line(x)

In [9]:
# Show the BPE segmentation: sub-words come back as one space-delimited string.
t1

'M ARS LLL PL QI LLL SLAL ET AGEE AQ GD KI IDG AP CAR GS HP W QV ALL SG NQL HC GGVL VN ER WVL TAA HC KM NE YTV HLG SD TL GD RR AQ RIK ASKS FRH PG YS TQ TH VNDL ML VKL NS QARL SS MV KK VRL PS RCE PPG TT C TVSG WG TTTS PD VT FPS DL MC VD VKL ISP QD CTK VY KD LLE NS ML CAG IP DS KK NAC NG DS GGPL VC RGTL QGL VS WG TF PCG QP ND PGVY TQ VC KF TK WIN DT MKK H R'

In [10]:
# Inspect the protein sub-word -> index mapping (the whole dict is printed).
words2idx_p

{'L': 0,
 'V': 1,
 'S': 2,
 'W': 3,
 'Q': 4,
 'E': 5,
 'I': 6,
 'T': 7,
 'Y': 8,
 'G': 9,
 'A': 10,
 'R': 11,
 'U': 12,
 'M': 13,
 'P': 14,
 'H': 15,
 'Z': 16,
 'D': 17,
 'F': 18,
 'K': 19,
 'O': 20,
 'B': 21,
 'C': 22,
 'X': 23,
 'N': 24,
 'LL': 25,
 'AA': 26,
 'AL': 27,
 'VL': 28,
 'GL': 29,
 'EL': 30,
 'SL': 31,
 'GG': 32,
 'SS': 33,
 'EE': 34,
 'TL': 35,
 'DL': 36,
 'RL': 37,
 'IL': 38,
 'AV': 39,
 'KL': 40,
 'AG': 41,
 'VV': 42,
 'AE': 43,
 'KK': 44,
 'SG': 45,
 'AI': 46,
 'PL': 47,
 'AR': 48,
 'AD': 49,
 'AS': 50,
 'QL': 51,
 'TG': 52,
 'AK': 53,
 'VE': 54,
 'NL': 55,
 'FL': 56,
 'VI': 57,
 'VG': 58,
 'AT': 59,
 'KE': 60,
 'RR': 61,
 'VD': 62,
 'VS': 63,
 'PG': 64,
 'IE': 65,
 'PE': 66,
 'IG': 67,
 'ID': 68,
 'VT': 69,
 'RE': 70,
 'IS': 71,
 'AQ': 72,
 'DG': 73,
 'VK': 74,
 'DE': 75,
 'PS': 76,
 'YL': 77,
 'RG': 78,
 'IT': 79,
 'AF': 80,
 'NG': 81,
 'KG': 82,
 'AP': 83,
 'VR': 84,
 'TT': 85,
 'IK': 86,
 'FG': 87,
 'SE': 88,
 'AN': 89,
 'VP': 90,
 'HL': 91,
 'IN': 92,
 'ML': 93,
 