In [6]:
import os
from data_loader.SingleDataLoader import SingleDataLoader
from data_loader.DoubleDataLoader import DoubleDataLoader
from models.ResNet_revised import ResNet_revised
from trainer.train import train_func
from trainer.aug_train import aug_train_func
from lib.path import get_training_data_dir
from custom_losses.dice import dice_loss, dice_coefficient
from tensorflow.keras.metrics import Recall, Precision
from tensorflow.keras.losses import BinaryCrossentropy

In [7]:
# Root folder holding the split manifests, relative to the notebook's working directory.
data_dir = '../data'

# One manifest file per split, resolved against data_dir in train / test / val order.
train_list, test_list, val_list = (
    os.path.join(data_dir, manifest)
    for manifest in ('train.txt', 'test.txt', 'val.txt')
)

In [8]:
# Selectors that identify which prepared dataset variant this notebook trains on.
# The same constants are reused later to build the checkpoint sub-path.
DATA_TYPE1 = 'gr'
# DATA_TYPE2 = 'Protein'
DATA_VOXEL_NUM = 10
LIGAND_VOXEL_NUM = 8
CLASSIFYING_RULE = 'WaterClassifyingRuleEmbedding'
LIGAND_POCKET_DEFINER = 'LigandPocketDefinerOriginal'

# Resolve the dataset directory for that variant (arguments are positional,
# in the order the project helper expects).
training_data_dir1 = get_training_data_dir(
    DATA_TYPE1,
    DATA_VOXEL_NUM,
    CLASSIFYING_RULE,
    LIGAND_POCKET_DEFINER,
    LIGAND_VOXEL_NUM,
)

In [9]:
# Loader bound to the dataset directory held in training_data_dir1; the next
# cell uses it to read the train / test / val splits.
data_loader = SingleDataLoader(training_data_dir1)

In [10]:
# Load every split with the same loader; each call returns a (data, labels) pair.
# Splits are read in train -> test -> val order.
(
    (train_data, train_labels),
    (test_data, test_labels),
    (val_data, val_labels),
) = (
    data_loader.load_data(split_file)
    for split_file in (train_list, test_list, val_list)
)

In [11]:
# Sanity check: report the array shape of every split's inputs and labels.
for _caption, _array in (
    ('Train data shape: ', train_data),
    ('Train labels shape: ', train_labels),
    ('Test data shape: ', test_data),
    ('Test labels shape: ', test_labels),
    ('Val data shape: ', val_data),
    ('Val labels shape: ', val_labels),
):
    print(_caption, _array.shape)

Train data shape:  (56269, 21, 21, 21, 1)
Train labels shape:  (56269,)
Test data shape:  (13354, 21, 21, 21, 1)
Test labels shape:  (13354,)
Val data shape:  (11995, 21, 21, 21, 1)
Val labels shape:  (11995,)


In [12]:
# --- Model input -----------------------------------------------------------
# Each spatial axis has edge length 2 * DATA_VOXEL_NUM + 1; the channel count
# is read from the loaded training array so the two cannot drift apart.
grid_size = DATA_VOXEL_NUM * 2 + 1
input_shape = (grid_size, grid_size, grid_size, train_data.shape[-1])

# --- Optimisation hyper-parameters (all forwarded to the trainer) ------------
epochs = 100
batch_size = 128
n_base = 32
learning_rate = 1e-4
early_stopping = 40  # presumably the early-stopping patience in epochs - TODO confirm in trainer
BN = True
dropout = 0.5

# --- Model / trainer identity -----------------------------------------------
model_func = ResNet_revised
MODEL_NAME = model_func.__name__
TRAINER_NAME = 'normal_train'

# --- Loss and metrics --------------------------------------------------------
losses = [BinaryCrossentropy(), dice_loss]
loss = losses[0]  # index 0 = BinaryCrossentropy; switch to losses[1] for dice_loss
metrics = ['accuracy', dice_coefficient, Recall(), Precision()]

# --- Output locations --------------------------------------------------------
# Sub-path that encodes the full experiment configuration.
path_type = f'/{DATA_TYPE1}/data_voxel_num_{DATA_VOXEL_NUM}/{LIGAND_POCKET_DEFINER}/ligand_pocket_voxel_num_{LIGAND_VOXEL_NUM}/{CLASSIFYING_RULE}/{MODEL_NAME}/{TRAINER_NAME}/'
# path_type = f'/{DATA_TYPE1}_{DATA_TYPE2}/data_voxel_num_{DATA_VOXEL_NUM}/{LIGAND_POCKET_DEFINER}/ligand_pocket_voxel_num_{LIGAND_VOXEL_NUM}/{CLASSIFYING_RULE}/{MODEL_NAME}/{TRAINER_NAME}/'

# path_type already carries a leading and a trailing '/', so plain string
# concatenation produced './checkpoints//.../normal_train//cp-...'. Strip the
# slashes and let os.path.join insert exactly one separator per component.
# The '{epoch:04d}' placeholder is intentionally left unformatted here;
# presumably the trainer's checkpoint callback fills it in - TODO confirm.
checkpoint_path = os.path.join(
    './checkpoints',
    path_type.strip('/'),
    'cp-{epoch:04d}.weights.h5',
)
model_checkpoint = True

2024-07-29 20:40:26.999006: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2343] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [13]:
# Balanced class weights: each class is weighted inversely to its frequency and
# scaled by total / 2, so the summed weight over all samples still equals `total`.
pos = train_labels.sum()  # count of positives - assumes labels are encoded as 0/1
neg = train_labels.shape[0] - pos
total = train_labels.shape[0]

# A training set containing only one class makes one of the counts zero; the
# reciprocal below would then raise ZeroDivisionError or (for NumPy scalars)
# silently yield an infinite weight. Fail early with a clear message.
if pos == 0 or neg == 0:
    raise ValueError(
        f'Cannot compute class weights: found {pos} positive and {neg} '
        f'negative training labels; both classes must be present.'
    )

weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)
class_weight = {0: weight_for_0, 1: weight_for_1}
print(class_weight)

{0: 1.0062410586552217, 1: 0.9938358825815112}


In [14]:
# Disabled alternative to the train_func cell below: the same hyper-parameters
# passed to aug_train_func, plus num_rotations / angle_unit (presumably
# rotation-based augmentation - TODO confirm in trainer.aug_train).
# Unlike train_func it takes no test split and unpacks two return values.
# Uncomment this cell and skip the train_func cell to use it.
# clf, clf_hist = aug_train_func(
#                                 x_train=train_data,
#                                 y_train=train_labels,
#                                 x_val=val_data,
#                                 y_val=val_labels,
#                                 input_shape=input_shape,
#                                 model_func=model_func,
#                                 loss=loss,
#                                 metrics=metrics,
#                                 epochs=epochs,
#                                 batch_size=batch_size,
#                                 num_rotations=1,
#                                 angle_unit=45,
#                                 n_base=n_base,
#                                 learning_rate=learning_rate,
#                                 early_stopping=early_stopping,
#                                 checkpoint_path=checkpoint_path,
#                                 model_checkpoint=model_checkpoint,
#                                 class_weight=class_weight,
#                                 BN = BN,
#                                 dropout=dropout
#                             )

In [15]:
# Train with the configuration defined in the cells above. The trainer hands
# back three values, unpacked here as the model, its history and an evaluation
# result (presumably computed on the test split - TODO confirm in trainer.train).
clf, clf_hist, clf_eval = train_func(
    # data splits
    x_train=train_data,
    y_train=train_labels,
    x_val=val_data,
    y_val=val_labels,
    x_test=test_data,
    y_test=test_labels,
    # model definition
    model_func=model_func,
    input_shape=input_shape,
    n_base=n_base,
    BN=BN,
    dropout=dropout,
    # objective and optimisation
    loss=loss,
    metrics=metrics,
    class_weight=class_weight,
    learning_rate=learning_rate,
    batch_size=batch_size,
    epochs=epochs,
    early_stopping=early_stopping,
    # checkpointing
    model_checkpoint=model_checkpoint,
    checkpoint_path=checkpoint_path,
)

Epoch 1/100
[1m239/440[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m3:48[0m 1s/step - accuracy: 0.5475 - dice_coefficient: 0.5322 - loss: 0.9198 - precision: 0.5510 - recall: 0.5370

KeyboardInterrupt: 

In [None]:
# Run the trained model over the test inputs.
# NOTE(review): `clf` is only defined if the training cell ran to completion;
# the saved run of that cell ended in KeyboardInterrupt, so re-run it first.
prediction = clf.predict(test_data)

In [None]:
# Round every prediction to the nearest integer and sum the result - presumably
# the number of test samples predicted positive (assumes outputs lie in [0, 1];
# TODO confirm against the model's final activation).
prediction.round().sum()