In [1]:

import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.regularizers import L2

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import StratifiedKFold


SEED = 123

NO_CLASSES = 20
N_FOLDS = 3

MAX_LENGTH = 100
BATCH_SIZE = 256


train = pd.read_csv('Train.csv')


test = pd.read_csv('Test.csv')


In [2]:

train['seq_char_count'] = train['SEQUENCE'].apply(lambda x: len(x))

codes = {code for seq in train['SEQUENCE'] for code in seq}


def create_dict(codes):
  char_dict = {}
  for index, val in enumerate(codes):
    char_dict[val] = index+1

  return char_dict

char_dict = create_dict(codes)

print(char_dict)
print("Dict Length:", len(char_dict))


def integer_encoding(data):
  """
  - Encodes code sequence to integer values.
  - 20 common amino acids are taken into consideration
    and rest 4 are categorized as 0.
  """
  
  encode_list = []
  for row in data['SEQUENCE'].values:
    row_encode = []
    for code in row:
      row_encode.append(char_dict.get(code, 0))
    encode_list.append(np.array(row_encode))
  
  return encode_list
  
train_encode = integer_encoding(train) 

train_pad = pad_sequences(train_encode, maxlen=MAX_LENGTH, padding='post', truncating='post')

print(train_pad.shape)




{'B': 1, 'K': 2, 'N': 3, 'I': 4, 'R': 5, 'L': 6, 'V': 7, 'M': 8, 'G': 9, 'P': 10, 'T': 11, 'Q': 12, 'U': 13, 'S': 14, 'D': 15, 'F': 16, 'X': 17, 'C': 18, 'W': 19, 'Z': 20, 'E': 21, 'A': 22, 'H': 23, 'Y': 24}
Dict Length: 24
(858777, 100)


In [3]:
# One hot encoding of sequences
X = to_categorical(train_pad)
print(X.shape) 

X = train_pad

y = train['LABEL'].str.replace('[A-Za-z]', '').astype(int)



(858777, 100, 25)


In [4]:

## Define model
x_input = tf.keras.Input(shape=(MAX_LENGTH,), name='Input')


emb = tf.keras.layers.Embedding(25, 64, input_length=MAX_LENGTH)(x_input)
bi_rnn = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, kernel_regularizer=L2(0.01), 
                                                            recurrent_regularizer=L2(0.01), 
                                                            bias_regularizer=L2(0.01)))(emb)
x = tf.keras.layers.Dropout(0.3)(bi_rnn)


# softmax classifier
x_output = tf.keras.layers.Dense(NO_CLASSES, activation='softmax')(x)

model = tf.keras.Model(inputs=x_input, outputs=x_output, name="Model_1")
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()




Model: "Model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
Input (InputLayer)           [(None, 100)]             0         
_________________________________________________________________
embedding (Embedding)        (None, 100, 64)           1600      
_________________________________________________________________
bidirectional (Bidirectional (None, 128)               66048     
_________________________________________________________________
dropout (Dropout)            (None, 128)               0         
_________________________________________________________________
dense (Dense)                (None, 20)                2580      
Total params: 70,228
Trainable params: 70,228
Non-trainable params: 0
_________________________________________________________________


In [7]:

es = tf.keras.callbacks.EarlyStopping(monitor='val_acc', patience=3)

skf = StratifiedKFold(n_splits=N_FOLDS)

score_list = []
for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_val = X[train_index], X[test_index]
    y_train, y_val = y[train_index], y[test_index]
    y_train = to_categorical(y_train)
    y_val = to_categorical(y_val)
    history = model.fit(X_train, y_train, epochs=50, batch_size=BATCH_SIZE,
                       validation_data=(X_val, y_val), callbacks=[es])
    best_score = max(history.history['acc'])
    score_list.append(best_score)


TRAIN: [281833 281851 281886 ... 858774 858775 858776] TEST: [     0      1      2 ... 292767 293060 293082]
Epoch 1/50
 259/2237 [==>...........................] - ETA: 7:25 - loss: 2.4247 - accuracy: 0.2962

KeyboardInterrupt: 