In [1]:
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

from common import (
    delta_g_path,
    id_seq_path
)
from char_table import CharTable

In [2]:
# load processed files
df_delta_g = pd.read_csv(delta_g_path)
df_id_seq = pd.read_csv(id_seq_path)

MAXLEN = int(np.percentile(df_id_seq["seqs"].map(lambda x: len(x)), 90))

In [3]:
MAXLEN

434

In [4]:
all_symbol_set = set()
id_seq_dict = {}
for idx, row in df_id_seq.iterrows():
    seqs = row["seqs"]
    id_seq_dict[row["id"]] = seqs
    all_symbol_set.update(seqs)

In [5]:
ctable = CharTable(all_symbol_set)

In [6]:
inputs = []
outputs = []
for idx, row in df_delta_g.iterrows():
    skip_flag = False
    for key in ["wt_name_0", "wt_name_1", "mut_name_0", "mut_name_1"]:
        seqs = id_seq_dict[row[key]]
        if len(seqs) > MAXLEN:
            skip_flag = True
            break
    if skip_flag:
        continue

    # wt
    input_item = []
    for key in ["wt_name_0", "wt_name_1"]:
        seqs = id_seq_dict[row[key]]
        delta_g = row["delta_g_wt"]
        encoded = ctable.encode(seqs, len(seqs)).tolist()
        for _ in range(MAXLEN - len(encoded)):
            encoded.append([0.0] * len(all_symbol_set))
        input_item.append(encoded)
    if input_item:
        inputs.append(input_item)
        outputs.append(delta_g)
    
    # mut
    input_item = []
    for key in ["mut_name_0", "mut_name_1"]:
        seqs = id_seq_dict[row[key]]
        if len(seqs) > MAXLEN:
            break
        delta_g = row["delta_g_mut"]
        encoded = ctable.encode(seqs, len(seqs)).tolist()
        for _ in range(MAXLEN - len(encoded)):
            encoded.append([0.0] * len(all_symbol_set))
        input_item.append(encoded)
    
    if input_item:
        inputs.append(input_item)
        outputs.append(delta_g)

In [7]:
inputs = np.array(inputs).reshape(len(inputs), MAXLEN, len(all_symbol_set), 2)
outputs = np.array(outputs)
len(inputs) == len(outputs)

True

In [8]:
print(inputs.shape, outputs.shape)

(8006, 434, 20, 2) (8006,)


In [9]:
train_size = int(len(inputs) * 0.9)

train_x = inputs[:train_size]
train_y = outputs[:train_size]

test_x = inputs[train_size:]
test_y = outputs[train_size:]

In [10]:
print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)

(7205, 434, 20, 2) (7205,) (801, 434, 20, 2) (801,)


In [11]:
input_shape = (MAXLEN, len(all_symbol_set), 2)
learning_rate = 0.001
batch_size = 192
epochs = 50

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(5, 5), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(64, activation="relu"),
        layers.Dense(1),
    ]
)
model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate, amsgrad=True, epsilon=1e-6), loss="mse")
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 432, 18, 32)       608       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 216, 9, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 212, 5, 64)        51264     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 106, 2, 64)       0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 13568)             0         
                                                                 
 dropout (Dropout)           (None, 13568)             0

2021-11-21 22:50:25.742035: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-21 22:50:25.771951: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-21 22:50:25.772115: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:939] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-21 22:50:25.772574: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags

In [12]:
model.fit(train_x, train_y, batch_size=batch_size, epochs=epochs, validation_split=0.1)

Epoch 1/50


2021-11-21 22:50:27.536980: I tensorflow/stream_executor/cuda/cuda_dnn.cc:366] Loaded cuDNN version 8204


Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


<keras.callbacks.History at 0x7f00628e2d00>

In [13]:
score = model.evaluate(test_x, test_y, verbose=0)
score

4.877392292022705

In [14]:
test_y[0]

-8.633059487634572

In [15]:
model.predict(test_x)

array([[-11.047307 ],
       [-11.14163  ],
       [-11.115495 ],
       [-11.14163  ],
       [-11.173584 ],
       [-11.14163  ],
       [-11.171309 ],
       [-11.14163  ],
       [-11.117297 ],
       [-11.14163  ],
       [-11.153526 ],
       [-11.14163  ],
       [-11.125926 ],
       [-11.14163  ],
       [-11.131533 ],
       [-11.14163  ],
       [-11.215392 ],
       [-11.14163  ],
       [-11.129463 ],
       [-11.14163  ],
       [-11.095613 ],
       [-11.14163  ],
       [-11.147159 ],
       [-11.14163  ],
       [-11.049715 ],
       [-11.14163  ],
       [-11.174175 ],
       [-11.14163  ],
       [-11.070275 ],
       [-11.14163  ],
       [-11.047307 ],
       [-11.14163  ],
       [-11.149408 ],
       [-11.14163  ],
       [-11.127148 ],
       [-11.14163  ],
       [-11.153315 ],
       [-11.14163  ],
       [-11.139096 ],
       [ -8.604914 ],
       [ -8.583326 ],
       [ -8.604914 ],
       [ -8.648729 ],
       [ -9.804913 ],
       [ -9.802313 ],
       [ -