# In[ ]:

import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, Dropout, Conv1D, MaxPooling1D, GlobalAveragePooling1D
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split

# `X` (the ECG signals) and `y` (their labels) must already be defined here,
# for example by a loader call such as:
# X, y = load_ecg_data()

# Hold out 20% of the samples for evaluation; the fixed seed makes the split
# reproducible from run to run.
# NOTE(review): if the classes are imbalanced, consider passing stratify=y so
# both splits keep the same class proportions.
(X_train, X_test,
 y_train, y_test) = train_test_split(X, y, random_state=42, test_size=0.2)

# Define the base AlexNet model
def create_alexnet(input_shape, num_classes, *, dense_units=4096, dropout_rate=0.5):
    """Build an AlexNet-style 1-D convolutional classifier.

    The stack mirrors AlexNet: five convolutions with interleaved max-pooling,
    followed by two large fully connected layers with dropout and a softmax
    output. It uses ``Conv1D``/``MaxPooling1D`` so that it consumes 1-D signals.

    Args:
        input_shape: Shape of a single sample, ``(timesteps, channels)``.
        num_classes: Number of output classes (width of the softmax layer).
        dense_units: Width of each of the two hidden fully connected layers.
        dropout_rate: Dropout fraction applied after each hidden dense layer.

    Returns:
        An uncompiled ``Sequential`` model.

    Raises:
        ValueError: If ``input_shape`` does not have exactly two dimensions.
    """
    input_shape = tuple(input_shape)
    # Conv1D consumes (timesteps, channels) per sample. Report a missing
    # channel axis here with a clear message instead of a rank error from
    # deep inside Keras.
    if len(input_shape) != 2:
        raise ValueError(
            "create_alexnet expects input_shape=(timesteps, channels) for Conv1D, "
            f"got {input_shape!r}; add a trailing channel axis to the data."
        )

    model = Sequential()
    # Declare the input explicitly. Passing `input_shape` to the first Conv1D
    # is deprecated in recent Keras versions.
    model.add(tf.keras.Input(shape=input_shape))

    # Convolutional feature extractor
    model.add(Conv1D(96, 11, activation='relu'))
    model.add(MaxPooling1D(3))
    model.add(Conv1D(256, 5, activation='relu'))
    model.add(MaxPooling1D(3))
    model.add(Conv1D(384, 3, activation='relu'))
    model.add(Conv1D(384, 3, activation='relu'))
    model.add(Conv1D(256, 3, activation='relu'))
    model.add(MaxPooling1D(3))

    # Flatten and fully connected classifier head
    model.add(Flatten())
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dropout(dropout_rate))
    model.add(Dense(dense_units, activation='relu'))
    model.add(Dropout(dropout_rate))
    model.add(Dense(num_classes, activation='softmax'))

    return model

# Create the AlexNet model for your specific input shape and number of classes.
# Conv1D expects each sample to be (timesteps, channels). If the signals are a
# plain 2-D array (samples, timesteps), add a trailing single-channel axis.
if len(X_train.shape) == 2:
    X_train = X_train[..., None]
    X_test = X_test[..., None]
input_shape = X_train.shape[1:]  # Shape of one ECG signal: (timesteps, channels)
num_classes = 5  # Replace with your actual number of classes
alexnet_model = create_alexnet(input_shape, num_classes)

# Load pre-trained weights, if available.
# NOTE: the weights must come from a model with this same Conv1D architecture;
# 2-D image AlexNet checkpoints have different kernel shapes and will not load.
# alexnet_model.load_weights('path_to_pretrained_weights.h5', by_name=True)

# Compile the model with the Adam optimizer (learning rate 0.01) and
# categorical cross-entropy loss.
# `learning_rate` is the supported keyword. Recent TensorFlow/Keras releases
# no longer honor the old `lr` alias: it is either ignored with a warning
# (silently leaving the default 0.001) or rejected.
adam_optimizer = Adam(learning_rate=0.01)
# NOTE(review): 'categorical_crossentropy' expects one-hot labels of width
# num_classes. If y holds integer class ids, use
# 'sparse_categorical_crossentropy' (or one-hot encode y). Confirm.
alexnet_model.compile(optimizer=adam_optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Train for 10 epochs with a batch size of 128, reporting held-out metrics
# after each epoch. Keep the History object for inspecting learning curves.
history = alexnet_model.fit(X_train, y_train, epochs=10, batch_size=128, validation_data=(X_test, y_test))


