/
CNN.py
79 lines (65 loc) · 2.34 KB
/
CNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# %%
import numpy as np
from tensorflow import keras
from models import CNN_model
import matplotlib.pyplot as plt
from os import path, makedirs
# Load the dataset from ../simp/results_merge_2.
# Each row of bc.txt and load.txt is one flattened 61x61 input field; the two
# are stacked as input channels. Each row of output.txt is one flattened
# 60x60 target field.
data_dir = '../simp/results_merge_2'
# ndmin=2 keeps a single-sample file shaped (1, n) instead of collapsing it to
# 1-D, which would otherwise break the per-sample reshapes below.
bc = np.loadtxt(path.join(data_dir, 'bc.txt'), ndmin=2)
load = np.loadtxt(path.join(data_dir, 'load.txt'), ndmin=2)
output = np.loadtxt(path.join(data_dir, 'output.txt'), ndmin=2)

input_shape = (61, 61)   # spatial size of each input field
output_shape = (60, 60)  # spatial size of each target field
num_channels = 2         # channel 0 = bc, channel 1 = load
num_samples = bc.shape[0]
# Kept for backward compatibility with code that reads `batch_size`: it holds
# the total sample count, NOT the mini-batch size passed to model.fit().
batch_size = num_samples

# Fail early with a clear message instead of an IndexError or a later shape
# mismatch inside fit().
if load.shape[0] != num_samples or output.shape[0] != num_samples:
    raise ValueError(
        f"sample count mismatch: bc={num_samples}, load={load.shape[0]}, "
        f"output={output.shape[0]}"
    )

# Vectorized equivalent of filling input_data[i, :, :, c] row by row:
# result shape is (num_samples, 61, 61, num_channels).
# A third `vol` channel could be appended to this stack (with num_channels = 3)
# once vol data is loaded.
input_data = np.stack(
    [
        bc.reshape((num_samples,) + input_shape),
        load.reshape((num_samples,) + input_shape),
    ],
    axis=-1,
)
output_data = output.reshape((num_samples,) + output_shape + (1,))

# Hold out the last `num_test` samples as a test set; training needs at least
# one remaining sample, otherwise the slices below would leave it empty.
num_test = 1000
if num_samples <= num_test:
    raise ValueError(
        f"need more than {num_test} samples to hold out a test set, "
        f"got {num_samples}"
    )
input_train = input_data[:-num_test]
output_train = output_data[:-num_test]
input_test = input_data[-num_test:]
output_test = output_data[-num_test:]

model = CNN_model(input_shape + (num_channels,))
# %%
# Defined once so the callbacks and the post-training reload cannot drift apart.
checkpoint_path = './best/cp.ckpt'
monitor_metric = "val_accuracy"
epochs = 5

# Keep only the weights of the epoch with the best validation accuracy.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor=monitor_metric,
    mode="max",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
# NOTE(review): patience (5) >= epochs (5), so this callback can never stop
# training early; raise `epochs` or lower `patience` for it to take effect.
earlyStopping_callback = keras.callbacks.EarlyStopping(
    monitor=monitor_metric,
    mode="max",
    patience=5,
    verbose=1,
)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
history = model.fit(
    input_train,
    output_train,
    epochs=epochs,
    batch_size=10,
    validation_split=0.1,
    callbacks=[checkpoint_callback, earlyStopping_callback],
)

# fit() leaves the model at its final-epoch weights; reload the best
# checkpoint so the saved model corresponds to the best validation epoch.
model.load_weights(checkpoint_path)

# Score the held-out samples, which were excluded from training/validation.
test_loss, test_accuracy = model.evaluate(input_test, output_test, verbose=0)
print(f"Test loss: {test_loss:.4f}  Test accuracy: {test_accuracy:.4f}")

# Save the model
model.save('../models/model_unet')

plot_dir = './plots'
makedirs(plot_dir, exist_ok=True)


def _plot_history(history, metric, label, filename):
    """Plot training vs. validation curves of `metric` and save to `filename`."""
    plt.figure(figsize=(10, 6))
    plt.plot(history.history[metric], label=f'Training {label}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {label}')
    plt.title(f'Training and Validation {label}')
    plt.xlabel('Epoch')
    plt.ylabel(label)
    plt.legend()
    plt.savefig(filename)  # Save the plot as an image
    plt.show()
    plt.close()  # release the figure once displayed/saved


_plot_history(history, 'loss', 'Loss', path.join(plot_dir, 'loss_plot.png'))
_plot_history(history, 'accuracy', 'Accuracy', path.join(plot_dir, 'accuracy_plot.png'))