Skip to content

Commit 569c0e7

Browse files
authored
Update model_training.py
1 parent faca863 commit 569c0e7

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

Malaria/model_training.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,22 @@
5252
labels2 = labels1[n]
5353

5454
# Splitting the dataset into the Training set and Test set
55-
X_train, X_valid, y_train, y_valid = train_test_split(data2, labels2, test_size=0.2, random_state=0)
55+
X_train, X_valid, y_train,y_valid = train_test_split(data2,
56+
labels2, test_size=0.2, random_state=0)
5657
X_trainF = X_train.astype('float32')
57-
X_validF = X_valid.astype('float32')
58+
X_validF = X_valid.astype('float32')
5859
y_trainF = to_categorical(y_train)
5960
y_validF = to_categorical(y_valid)
6061

6162
classifier = Sequential()
6263
# CNN layers
63-
classifier.add(Conv2D(32, kernel_size=(3, 3), input_shape=(36, 36, 3), activation='relu'))
64+
classifier.add(Conv2D(32, kernel_size=(3, 3),
65+
input_shape=(36, 36, 3), activation='relu'))
6466
classifier.add(MaxPooling2D(pool_size=(2, 2)))
6567
classifier.add(BatchNormalization(axis=-1))
6668
classifier.add(Dropout(0.5)) # Dropout prevents overfitting
67-
classifier.add(Conv2D(32, kernel_size=(3, 3), input_shape=(36, 36, 3), activation='relu'))
69+
classifier.add(Conv2D(32, kernel_size=(3, 3),
70+
input_shape=(36, 36, 3), activation='relu'))
6871
classifier.add(MaxPooling2D(pool_size=(2, 2)))
6972
classifier.add(BatchNormalization(axis=-1))
7073
classifier.add(Dropout(0.5))
@@ -73,8 +76,11 @@
7376
classifier.add(BatchNormalization(axis=-1))
7477
classifier.add(Dropout(0.5))
7578
classifier.add(Dense(units=2, activation='softmax'))
76-
classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
77-
history = classifier.fit(X_trainF, y_trainF, batch_size=120, epochs=15, verbose=1, validation_data=(X_validF, y_validF))
79+
classifier.compile(optimizer='adam',
80+
loss='categorical_crossentropy', metrics=['accuracy'])
81+
history = classifier.fit(X_trainF, y_trainF,
82+
batch_size=120, epochs=15,
83+
verbose=1, validation_data=(X_validF, y_validF))
7884
classifier.summary()
7985

8086
y_pred = classifier.predict(X_validF)

0 commit comments

Comments
 (0)