In [9]:
import os
from lib.data_load import load_data
from models.LeNet import LeNet
from models.AlexNet import Alexnet
from trainer.train import train_func
from custom_losses.dice import dice_loss, dice_coefficient
from tensorflow.keras.metrics import Recall, Precision
from tensorflow.keras.losses import BinaryFocalCrossentropy

In [10]:
# Relative path of the folder that holds the dataset.
data_dir = '../data'

# One entry per split, each located directly under data_dir.
# (Presumably index files consumed by load_data — confirm in lib/data_load.)
train_list, test_list, val_list = (
    os.path.join(data_dir, split_file)
    for split_file in ('train_list', 'test_list', 'val_list')
)

In [11]:
# Voxelisation settings passed to load_data (ligand value first, then grid value).
# NOTE(review): the exact meaning of LIGAND_VOXEL_NUM is defined in
# lib/data_load and is not visible from this notebook — confirm there.
LIGAND_VOXEL_NUM = 8
# The model input grid has 2 * GR_VOXEL_NUM + 1 voxels along each axis
# (see input_shape in the hyperparameter cell).
GR_VOXEL_NUM = 10

In [12]:
# Load every split with the same voxel settings. The generator is consumed in
# order, so load_data is called for train, then test, then val, and each call
# must return a (data, labels) pair.
(train_data, train_labels), (test_data, test_labels), (val_data, val_labels) = (
    load_data(split_list, LIGAND_VOXEL_NUM, GR_VOXEL_NUM)
    for split_list in (train_list, test_list, val_list)
)

In [13]:
# Report the array shapes of every split, one line per array, as a quick
# sanity check that data and labels line up.
print(
    *(
        f'{split} {kind} shape:  {array.shape}'
        for split, kind, array in (
            ('Train', 'data', train_data),
            ('Train', 'labels', train_labels),
            ('Test', 'data', test_data),
            ('Test', 'labels', test_labels),
            ('Val', 'data', val_data),
            ('Val', 'labels', val_labels),
        )
    ),
    sep='\n',
)

Train data shape:  (33598, 21, 21, 21, 1)
Train labels shape:  (33598,)
Test data shape:  (16754, 21, 21, 21, 1)
Test labels shape:  (16754,)
Val data shape:  (17329, 21, 21, 21, 1)
Val labels shape:  (17329,)


In [14]:
# --- Model input -------------------------------------------------------------
# Each sample is a cubic grid with 2 * GR_VOXEL_NUM + 1 voxels per axis and a
# single trailing channel (matches the (N, 21, 21, 21, 1) arrays loaded above).
grid_side = GR_VOXEL_NUM * 2 + 1
input_shape = (grid_side, grid_side, grid_side, 1)

# --- Training settings (all passed to train_func) ----------------------------
epochs = 300
batch_size = 32
n_base = 32  # presumably the base number of filters/units in the model — confirm
learning_rate = 1e-5
# NOTE(review): same value as `epochs`; if train_func uses this as an
# early-stopping patience, early stopping can never trigger — confirm intent.
early_stopping = 300
BN = True  # presumably toggles batch normalisation in the model — confirm
dropout = 0.4

# --- Model, loss and metrics -------------------------------------------------
model_func = LeNet
loss = dice_loss
metrics = ['accuracy', dice_coefficient, Recall(), Precision()]

# --- Checkpointing -----------------------------------------------------------
# The checkpoint folder is derived from the voxel settings and the selected
# model instead of being typed out by hand, so switching model_func or the
# voxel numbers cannot silently write into another run's folder. With the
# current settings this yields exactly
# "./checkpoints/LIGAND_VOXEL_NUM_8/GR_VOXEL_NUM_10/LeNet".
checkpoint_dir = (
    f"./checkpoints/LIGAND_VOXEL_NUM_{LIGAND_VOXEL_NUM}"
    f"/GR_VOXEL_NUM_{GR_VOXEL_NUM}/{model_func.__name__}"
)
# "{epoch:04d}" is kept as a literal placeholder in the path template
# (presumably filled in per epoch by the checkpoint callback inside
# train_func — confirm).
checkpoint_path = checkpoint_dir + "/cp-{epoch:04d}.weights.h5"
model_checkpoint = True

In [15]:
# Class weights that compensate for the imbalance between the two classes.
# Assumes train_labels holds binary 0/1 labels, so that their sum is the number
# of positive samples — TODO confirm against load_data.
total = train_labels.shape[0]
pos = train_labels.sum()
neg = total - pos

# Without this guard a split that lacks one class would divide by zero below
# and pass an infinite weight into training (or fail with an unrelated error).
if pos == 0 or neg == 0:
    raise ValueError(
        "Cannot compute class weights: train_labels must contain both classes "
        f"(positives={pos}, negatives={neg})."
    )

# Each class is weighted inversely to its frequency. The total / 2 factor makes
# the weights sum to `total` over all training samples (average weight of 1).
weight_for_0 = (1 / neg) * (total / 2.0)
weight_for_1 = (1 / pos) * (total / 2.0)
class_weight = {0: weight_for_0, 1: weight_for_1}
print(class_weight)

{0: 0.6850024465829392, 1: 1.8513334802733084}


In [16]:
# Build and train the model selected in the hyperparameter cell.
# train_func returns three values; `clf` is the model used for prediction in
# the next cell. `clf_hist` and `clf_eval` are presumably the training history
# and the test-set evaluation — confirm in trainer/train.
clf, clf_hist, clf_eval = train_func(
    # data splits
    x_train=train_data,
    y_train=train_labels,
    x_val=val_data,
    y_val=val_labels,
    x_test=test_data,
    y_test=test_labels,
    # model definition
    model_func=model_func,
    input_shape=input_shape,
    n_base=n_base,
    BN=BN,
    dropout=dropout,
    # objective
    loss=loss,
    metrics=metrics,
    class_weight=class_weight,
    # optimisation schedule
    epochs=epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    early_stopping=early_stopping,
    # checkpointing
    model_checkpoint=model_checkpoint,
    checkpoint_path=checkpoint_path,
)

Epoch 1/300
[1m1050/1050[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 32ms/step - accuracy: 0.3284 - dice_coefficient: 0.4370 - loss: 0.5553 - precision_1: 0.2783 - recall_1: 0.9295 - val_accuracy: 0.5368 - val_dice_coefficient: 0.4581 - val_loss: 0.5419 - val_precision_1: 0.3228 - val_recall_1: 0.7534
Epoch 2/300
[1m1050/1050[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 29ms/step - accuracy: 0.5709 - dice_coefficient: 0.4949 - loss: 0.5009 - precision_1: 0.3647 - recall_1: 0.7593 - val_accuracy: 0.6794 - val_dice_coefficient: 0.4898 - val_loss: 0.5102 - val_precision_1: 0.4086 - val_recall_1: 0.5922
Epoch 3/300
[1m1050/1050[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 29ms/step - accuracy: 0.6578 - dice_coefficient: 0.5242 - loss: 0.4692 - precision_1: 0.4177 - recall_1: 0.6940 - val_accuracy: 0.7112 - val_dice_coefficient: 0.4934 - val_loss: 0.5066 - val_precision_1: 0.4426 - val_recall_1: 0.5373
Epoch 4/300
[1m1050/1050[0m [32m━━━━━━━━━━━━━━━━━━━━

In [None]:
# Run the model returned by train_func over the whole test split and keep the
# raw outputs for the checks in the following cells.
prediction = clf.predict(test_data)

[1m524/524[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step


In [11]:
# Round every prediction to the nearest integer and add them up. If the model
# outputs probabilities in [0, 1] (assumed — confirm), this is the number of
# test samples predicted as positive.
prediction.round().sum()

2417.0

In [5]:
# F1 score (harmonic mean of precision and recall) for a single operating point.
# NOTE(review): both inputs are hard-coded; their origin is not visible in this
# notebook — presumably copied from an evaluation log, confirm before reuse.
precision, recall = 0.756, 0.3592

f1 = 2 * precision * recall / (precision + recall)
f1

0.48700717360114776