# In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

# Seed PyTorch's global RNG so weight initialisation, DataLoader shuffling and
# dropout masks repeat from run to run.
torch.random.manual_seed(0)

class Autoencoder(torch.nn.Module):
    """Fully connected autoencoder with a symmetric encoder/decoder.

    Encoder widths: input_size -> 1024 -> 512 -> 256 -> 128 -> encoding_size,
    with a ReLU after every Linear (including the bottleneck, so codes are >= 0).
    The decoder mirrors those widths, using ReLU between layers and a final
    Sigmoid, so reconstructions lie in (0, 1).

    ``forward`` returns the reconstruction; use ``self.encoder`` to get codes.
    """

    def __init__(self, input_size, encoding_size):
        super().__init__()
        widths = [input_size, 1024, 512, 256, 128, encoding_size]

        def stack(dims, last_activation):
            # Alternate Linear/activation, swapping in `last_activation` after the
            # final Linear. Linears are created in order, so seeded init is unchanged.
            layers = []
            n_linear = len(dims) - 1
            for idx, (d_in, d_out) in enumerate(zip(dims, dims[1:])):
                layers.append(torch.nn.Linear(d_in, d_out))
                layers.append(last_activation() if idx == n_linear - 1 else torch.nn.ReLU())
            return torch.nn.Sequential(*layers)

        self.encoder = stack(widths, torch.nn.ReLU)
        self.decoder = stack(widths[::-1], torch.nn.Sigmoid)

    def forward(self, x):
        """Encode then decode ``x`` (last dim = input_size); returns the reconstruction."""
        return self.decoder(self.encoder(x))

class CompactCNN(torch.nn.Module):
    """Shallow temporal CNN classifier that outputs log-probabilities.

    Pipeline: Conv2d(1 -> channels, kernel (1, kernelLength)) -> Batchlayer -> ELU
    -> Dropout(p=0.7) -> global average pooling -> Linear(channels -> classes)
    -> LogSoftmax(dim=1).

    Args:
        classes: number of output classes.
        channels: number of convolution filters.
        kernelLength: temporal kernel width; the input's last dimension must be at
            least this long (the convolution has no padding).
        encoding_size: accepted for signature compatibility; not used.
    """

    def __init__(self, classes=2, channels=32, kernelLength=15, encoding_size=64):
        super().__init__()
        self.kernelLength = kernelLength

        # Feature extractor. Construction order fixes which seeded draws each layer gets.
        self.conv1 = torch.nn.Conv2d(1, channels, (1, kernelLength))
        self.batch1 = Batchlayer(channels)
        self.elu1 = torch.nn.ELU()
        self.dropout1 = torch.nn.Dropout(0.7)

        # Classifier head. `softmax` is a LogSoftmax, so outputs are log-probabilities.
        self.fc = torch.nn.Linear(channels, classes)
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, inputdata):
        """Map an (N, 1, H, W) batch to (N, classes) log-probabilities."""
        features = self.conv1(inputdata)
        for stage in (self.batch1, self.elu1, self.dropout1):
            features = stage(features)

        # Global average pooling collapses H and W, leaving (N, channels, 1, 1).
        pooled = torch.nn.functional.adaptive_avg_pool2d(features, (1, 1))
        logits = self.fc(pooled.view(pooled.size(0), -1))
        return self.softmax(logits)


class Batchlayer(torch.nn.Module):
    """Batch-statistics normalisation followed by a learnable per-channel affine map.

    The input is normalised by ``normalizelayer`` (statistics of the current
    batch; no running averages are kept, so eval mode behaves like train mode).
    It is then scaled by ``gamma`` and shifted by ``beta``, both of shape
    (1, dim, 1, 1).
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.empty(1, dim, 1, 1))
        self.beta = torch.nn.Parameter(torch.empty(1, dim, 1, 1))
        # Both start uniform in [-0.1, 0.1]; gamma is drawn first, then beta.
        with torch.no_grad():
            self.gamma.uniform_(-0.1, 0.1)
            self.beta.uniform_(-0.1, 0.1)

    def forward(self, input):
        normalized = normalizelayer(input)
        # (1, dim, 1, 1) parameters broadcast over batch, height and width.
        return normalized * self.gamma + self.beta

class CompactCNNWithAutoencoder(torch.nn.Module):
    """An ``Autoencoder`` chained in front of a ``CompactCNN`` classifier.

    ``forward`` flattens each sample and runs it through the complete
    autoencoder (encoder *and* decoder). It then classifies the resulting
    ``sampleLength``-wide reconstruction, laid out as (N, 1, 1, sampleLength).

    NOTE(review): the CNN sees the reconstruction, not the ``encoding_size``-wide
    code, because ``Autoencoder.forward`` returns decoder output. Routing
    through ``self.autoencoder.encoder`` instead would also require
    ``kernelLength <= encoding_size``. With the defaults (125 vs 64) the
    convolution would reject the input.
    """

    def __init__(self, classes=2, channels=32, kernelLength=125, sampleLength=1751, encoding_size=64):
        super().__init__()
        # Construction order (autoencoder, then CNN) determines the seeded draws each gets.
        self.autoencoder = Autoencoder(sampleLength, encoding_size)
        self.compact_cnn = CompactCNN(
            classes=classes,
            channels=channels,
            kernelLength=kernelLength,
            encoding_size=encoding_size,
        )

    def forward(self, inputdata):
        """Classify ``inputdata`` (first dim = batch); returns CompactCNN's log-probabilities."""
        flat = inputdata.view(inputdata.size(0), -1)
        reconstruction = self.autoencoder(flat)
        as_image = reconstruction.view(reconstruction.size(0), 1, 1, -1)
        return self.compact_cnn(as_image)

def normalizelayer(data, eps=1e-05):
    """Normalise ``data`` per channel using statistics of the current batch.

    Mean and biased variance are reduced over dims 0, 2 and 3, keeping dim 1
    (the channel axis) separate. Each channel of the result therefore has zero
    mean and, up to ``eps``, unit variance across the batch. No running
    statistics are kept.

    Args:
        data: tensor with at least 4 dims, e.g. (N, C, H, W).
        eps: added to the variance before the square root for numerical stability.

    Returns:
        Tensor with the same shape as ``data``.
    """
    reduce_dims = [0, 2, 3]
    centered = data - torch.mean(data, reduce_dims, keepdim=True)
    variance = torch.mean(centered ** 2, reduce_dims, keepdim=True)
    # keepdim=True leaves size-1 dims, so broadcasting replaces explicit expand() calls.
    return centered / torch.sqrt(variance + eps)

def get_encoded_features(data, autoencoder_model):
    """Return the bottleneck codes of ``autoencoder_model`` for ``data``.

    Each sample is flattened to a vector before encoding. If the model exposes
    an ``encoder`` submodule (as ``Autoencoder`` does), only that half is run.
    Calling the whole model would return the decoder's reconstruction rather
    than the code. Models without an ``encoder`` attribute are called directly.

    Args:
        data: array-like or tensor whose first dimension indexes samples.
        autoencoder_model: an autoencoder with an ``encoder`` attribute, or any
            module mapping (N, features) inputs to features.

    Returns:
        Tensor of shape (N, ...) computed without autograd tracking.
    """
    encode = getattr(autoencoder_model, "encoder", autoencoder_model)
    with torch.no_grad():
        data_tensor = torch.as_tensor(data, dtype=torch.float32)
        # reshape (not view) also handles non-contiguous inputs such as strided slices.
        encoded_data = encode(data_tensor.reshape(data_tensor.shape[0], -1))
    return encoded_data

def _make_loader(features, labels, batch_size, shuffle):
    """Wrap numpy features/labels as a DataLoader of (float32, int64) tensor pairs."""
    dataset = TensorDataset(
        torch.tensor(features, dtype=torch.float32),
        torch.tensor(labels, dtype=torch.long),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _train_autoencoder(autoencoder, loader, lr, n_epoch):
    """Fit ``autoencoder`` to reconstruct its flattened inputs; prints the mean MSE per epoch."""
    optimizer = optim.Adam(autoencoder.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()
    autoencoder.train()
    for epoch in range(n_epoch):
        total_loss = 0.0
        total_samples = 0
        for inputs, _ in loader:
            flat = inputs.view(inputs.size(0), -1)
            optimizer.zero_grad()
            reconstructed = autoencoder.decoder(autoencoder.encoder(flat))
            # The target must be the flattened batch. Comparing (B, L) with the
            # (B, 1, 1, L) loader tensor would broadcast to (B, 1, B, L) and score
            # every reconstruction against every sample in the batch.
            loss = criterion(reconstructed, flat)
            loss.backward()
            optimizer.step()
            # MSELoss returns a per-batch mean; weight it by batch size so the epoch
            # figure is the mean over all elements even when the last batch is short.
            total_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
        average_loss = total_loss / max(total_samples, 1)
        print(f"Epoch {epoch + 1}, Average MSE Loss: {average_loss:.4f}")


def _encode_samples(encoder, samples):
    """Run ``encoder`` on each sample flattened to a vector; returns numpy (no autograd)."""
    with torch.no_grad():
        flat = torch.tensor(samples, dtype=torch.float32).view(samples.shape[0], -1)
        return encoder(flat).numpy()


def _train_classifier(model, loader, lr, n_epoch):
    """Train ``model`` (which outputs log-probabilities) with NLL loss and Adam."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.NLLLoss()
    model.train()
    for _ in range(n_epoch):
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()


def _evaluate_classifier(model, features, labels):
    """Return the accuracy of ``model`` on ``features`` using one eval-mode forward pass.

    Batchlayer normalises with the statistics of whatever batch it receives, so
    the whole set is scored at once.
    """
    model.eval()
    with torch.no_grad():
        log_probs = model(torch.tensor(features, dtype=torch.float32))
    predictions = log_probs.argmax(dim=1).cpu().numpy()
    return accuracy_score(labels, predictions)


def run(xdata_path="path/eeg_data.npy", ydata_path="path/labels.npy"):
    """Train an autoencoder + CompactCNN per selected channel and report accuracy.

    For each channel number in ``indices``, the pipeline:
    1. Splits that channel's trials 80/20 (``random_state=42``).
    2. Trains an ``Autoencoder`` to reconstruct them.
    3. Encodes both splits with its encoder.
    4. Trains a ``CompactCNN`` on the codes and prints the validation accuracy.

    Args:
        xdata_path: ``.npy`` array indexed as ``[sample, channel, time-step]``
            (the first axis is aligned with the labels).
        ydata_path: ``.npy`` array of labels, one per sample, cast to int64.
            Presumably 0/1, since the classifier has 2 classes.

    Returns:
        List of ``(train_codes, val_codes)`` numpy arrays, one pair per channel
        in ``indices`` order.
    """
    lr_autoencoder = 1e-2
    lr_compactcnn = 1e-3
    batch_size = 32
    n_epoch_autoencoder = 10
    n_epoch_compactcnn = 10
    encoding_size = 64

    # 1-based channel numbers (hence ``i - 1`` when indexing).
    indices = [128, 21, 22, 32, 23, 10, 27, 12, 44, 35, 46, 36, 57, 51, 66, 69, 8, 2, 121, 123, 116, 111, 107, 103, 97, 86,
               95, 91, 61, 76, 82, 74]

    xdata = np.load(xdata_path)
    ydata = np.load(ydata_path)
    # Taken from the data instead of hard-coding 1751, so other trial lengths work.
    sample_length = xdata.shape[-1]

    reduced_features_list = []  # (train_codes, val_codes) per channel

    for i in indices:
        x_index = xdata[:, i - 1, :].reshape(xdata.shape[0], 1, 1, sample_length)
        x_train, x_test, y_train, y_test = train_test_split(x_index, ydata, test_size=0.2, random_state=42)

        # Stage 1: autoencoder on the raw single-channel trials.
        # NOTE(review): the decoder ends in Sigmoid, so reconstructions lie in (0, 1).
        # Unless the data is scaled to that range the MSE cannot approach zero;
        # confirm the preprocessing.
        autoencoder = Autoencoder(sample_length, encoding_size)
        train_loader = _make_loader(x_train, y_train, batch_size, shuffle=True)
        _train_autoencoder(autoencoder, train_loader, lr_autoencoder, n_epoch_autoencoder)

        autoencoder.eval()
        x_train_encoded = _encode_samples(autoencoder.encoder, x_train)
        x_val_encoded = _encode_samples(autoencoder.encoder, x_test)
        reduced_features_list.append((x_train_encoded, x_val_encoded))

        # Stage 2: CompactCNN on the codes, laid out as (N, 1, 1, encoding_size).
        x_train_encoded = x_train_encoded.reshape(x_train_encoded.shape[0], 1, 1, -1)
        x_val_encoded = x_val_encoded.reshape(x_val_encoded.shape[0], 1, 1, -1)

        compactcnn_model = CompactCNN(classes=2, channels=32, kernelLength=15, encoding_size=encoding_size)
        train_loader_encoded = _make_loader(x_train_encoded, y_train, batch_size, shuffle=True)
        _train_classifier(compactcnn_model, train_loader_encoded, lr_compactcnn, n_epoch_compactcnn)

        acc = _evaluate_classifier(compactcnn_model, x_val_encoded, y_test)
        print(f"Accuracy (Index {i}): {acc:.4f}")

    return reduced_features_list

if __name__ == "__main__":
    # Script entry point: run the full per-channel training/evaluation pipeline.
    run()