In [6]:
#Loading Libraries
import numpy as np
import torch

import os
from os import listdir
from os.path import isfile, join
from PIL import Image

import torch.optim as optim
from torch.autograd import grad
import time

import torch.utils.data as data_utils
from torch.nn import CrossEntropyLoss
from torch import nn
from torch.optim import Adam, lr_scheduler

from torchvision.datasets import CIFAR100
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from scipy.io import loadmat
from scipy.io import savemat

In [7]:
# Run on the GPU when one is available; fall back to the CPU instead of failing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision releases lack InterpolationMode; use the PIL constant instead.
    BICUBIC = Image.BICUBIC

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px):
    """Build the image preprocessing pipeline for ``n_px`` x ``n_px`` inputs.

    Steps: bicubic resize to ``n_px``, center crop to ``n_px``, conversion to RGB,
    conversion to a tensor, and per-channel normalisation with the fixed mean/std
    constants below.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)
# Image side length handed to _transform when the CIFAR-100 datasets are built.
nx = 144

import clip

print(clip.available_models())
# RN50x4 CLIP model; feature extraction uses its `model.visual` backbone.
model, preprocess = clip.load('RN50x4', device)

# Functional ReLU alias.
relu = torch.nn.functional.relu
def features(net, x):
    """Run the ResNet visual trunk ``net`` on ``x`` and average-pool the output.

    Pipeline: three conv -> bn -> ReLU stem stages, ``net.avgpool``, ``layer1`` to
    ``layer4``, then ``net.avgpool`` applied twice more. The caller squeezes dims
    (2, 3), so with nx = 144 this presumably yields a 1x1 spatial map.
    NOTE(review): other input sizes may leave a larger map -- confirm before changing nx.
    """
    h = x.type(net.conv1.weight.dtype)  # match the backbone's parameter dtype
    stem = ((net.conv1, net.bn1), (net.conv2, net.bn2), (net.conv3, net.bn3))
    for conv_layer, norm_layer in stem:
        h = torch.nn.functional.relu(norm_layer(conv_layer(h)))
    h = net.avgpool(h)
    for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
        h = stage(h)
    return net.avgpool(net.avgpool(h))

# Download (or reuse the cached copy of) CIFAR-100 train and test splits.
_cifar_root = os.path.expanduser("~/.cache")
cifar100_train, cifar100_test = (
    CIFAR100(root=_cifar_root, download=True, train=is_train, transform=_transform(nx))
    for is_train in (True, False)
)

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
Files already downloaded and verified
Files already downloaded and verified


In [None]:
from tqdm import tqdm
def generate_features(dataset, model, batch_size=100):
    """Extract pooled visual features for every sample in ``dataset``.

    Batches are run through ``features(model.visual, ...)`` on the global ``device``
    without gradients. The results are moved to the CPU and concatenated once at the end.

    Args:
        dataset: Map-style dataset yielding ``(image_tensor, label)`` pairs.
        model: CLIP model whose ``.visual`` trunk is passed to ``features``.
        batch_size (int, optional): DataLoader batch size. Defaults to 100.

    Returns:
        tuple: ``(feat, labels)``. ``feat`` is an (N x d) CPU tensor. ``labels`` is an (N,)
        CPU tensor cast to float32, matching the label tensors used in the rest of the notebook.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    # `DataLoader` was never imported by name; the module is available as `data_utils`.
    loader = data_utils.DataLoader(dataset, batch_size=batch_size)
    feature_chunks = []
    label_chunks = []
    for images, labs in tqdm(loader):
        images = images.to(device)
        with torch.no_grad():
            f = features(model.visual, images).squeeze((2, 3))
        feature_chunks.append(f.cpu())
        label_chunks.append(labs.cpu())
    if not feature_chunks:
        raise ValueError("dataset is empty; no features to extract")
    # Concatenate once instead of re-allocating the growing tensor every batch.
    feat = torch.cat(feature_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0).to(torch.float32)
    return feat, labels

# Base directory for the per-class .mat files read back by the data-loading cell.
FEATURE_DIR = r'D:\Cifar100\features\Clip'
N_CLASSES = 100

features_train, labels_train = generate_features(cifar100_train, model)
features_test, labels_test = generate_features(cifar100_test, model)

# Save one file per class for BOTH splits, in the layout the loader expects:
#   <FEATURE_DIR>\train\train{i}.mat  and  <FEATURE_DIR>\val\val{i}.mat
# (previously only val was written, and to <FEATURE_DIR>\val{i}.mat, which the
# loader never reads).
for split, split_feats, split_labels in (('train', features_train, labels_train),
                                         ('val', features_test, labels_test)):
    split_dir = os.path.join(FEATURE_DIR, split)
    os.makedirs(split_dir, exist_ok=True)
    for i in range(N_CLASSES):
        name = os.path.join(split_dir, '{}{}.mat'.format(split, i))
        x = split_feats[split_labels == i, :]
        savemat(name, {'feature': x.float().numpy()})

In [135]:
nc = 100  # number of classes
d = 2560  # number of features


# Loading data
def load_data(path, file):
    """Read the ``'feature'`` array from the .mat file ``path + file`` as a float32 tensor."""
    contents = loadmat(path + file)
    return torch.tensor(contents['feature']).float()

def _load_split(folder, prefix):
    """Stack the per-class files ``<folder><prefix>{j}.mat`` for j in 0..nc-1.

    Returns ``(features, labels)``. ``features`` is (N x d). ``labels`` is a float32 (N,)
    tensor holding each row's class index, the same dtype the original loops produced.
    """
    feature_chunks, label_chunks = [], []
    for j in range(nc):
        x = load_data(folder, '{}{}.mat'.format(prefix, j))
        feature_chunks.append(x)
        label_chunks.append(torch.full((x.shape[0],), j, dtype=torch.float32))
    feats = torch.cat(feature_chunks, dim=0)
    assert feats.shape[1] == d, f"expected {d} features per row, got {feats.shape[1]}"
    return feats, torch.cat(label_chunks, dim=0)


trFeatures, trY = _load_split(r'D:\Cifar100\features\Clip\train/', 'train')
valFeatures, valY = _load_split(r'D:\Cifar100\features\Clip\val/', 'val')

In [136]:
# Sanity check: total number of validation labels loaded across all class files.
valY.shape

torch.Size([10000])

In [96]:
def fs_ppca(trFeatures, trY, nc, d, q=20):
    """
    Performs Feature Selection using Probabilistic Principal Component Analysis (PPCA)
    for each class and returns a list of feature indices sorted by their relevance.

    For each class, the ML loading matrix A_ML = U_q (S_q^2 - sigma^2 I)^{1/2} is built
    from the SVD of the centered data. Features are ranked by the squared norm of their
    row in A_ML. Singular values are used unnormalised; scaling every eigenvalue and
    sigma^2 by the same constant does not change the ranking.

    Args:
        trFeatures (torch.Tensor): Training features (N x d), where N is the number of samples
                                     and d is the number of features.
        trY (torch.Tensor): Training labels (N).
        nc (int): Number of classes.
        d (int): Number of features.
        q (int, optional): Number of principal components to retain. Defaults to 20.

    Returns:
        torch.Tensor: A tensor of shape (nc x d) containing feature indices sorted in
                      descending order of their relevance for each class.
    """
    index_list = torch.empty((0, d), dtype=torch.int32)
    for i in range(nc):
        data = trFeatures[trY == i, :]
        mu = torch.mean(data, dim=0)
        data_centered = (data - mu).t()  # (d x n_i)

        # Thin SVD: only the first q left singular vectors are used, so building the
        # full d x d U is wasted work.
        U, S, _ = torch.linalg.svd(data_centered, full_matrices=False)
        S_squared = S**2

        # Estimate the noise variance (sig_ML)
        sig_ML = (1 / (d - q)) * torch.sum(S_squared[q:])

        # (diag(S_q^2) - sig_ML I)^{1/2} is diagonal, so multiplying by it is a column
        # scaling of U_q. The clamp stops round-off from making the sqrt argument
        # slightly negative, which would otherwise give NaN.
        scale = torch.sqrt(torch.clamp(S_squared[:q] - sig_ML, min=0))
        A_ML = U[:, :q] * scale

        # Row sums of squares = per-feature relevance
        row_ss = torch.sum(A_ML**2, dim=1)

        # Indices sorted by relevance, descending
        indices = torch.argsort(-row_ss)
        index_list = torch.cat((index_list, indices.unsqueeze(0)), dim=0)

    return index_list

In [97]:
def EMv2(X, S_x, W, Psi, r, n, max_iter=30):
    """
    Performs the Expectation-Maximization (EM) algorithm for the factor model
    x = W z + e with diagonal noise covariance diag(Psi).

    Args:
        X (torch.Tensor): Centered data matrix (d x n).
        S_x (torch.Tensor): Sample covariance matrix (d x d).
        W (torch.Tensor): Initial loading matrix (d x r).
        Psi (torch.Tensor): Initial noise variances (d); assumed positive.
        r (int): Number of latent factors.
        n (int): Number of samples.
        max_iter (int, optional): Number of EM iterations. Defaults to 30.

    Returns:
        tuple: Updated loading matrix W (d x r) and noise variances Psi (d).
    """
    eye_r = torch.eye(r, dtype=W.dtype, device=W.device)
    for _ in range(max_iter):
        # beta = W^T (Psi + W W^T)^{-1}. By the Woodbury identity this equals
        # (I_r + W^T Psi^{-1} W)^{-1} W^T Psi^{-1}, so only an r x r system is solved
        # instead of inverting a d x d matrix every iteration.
        Wt_Psi_inv = W.t() / Psi  # (r x d): column j divided by Psi_j
        beta = torch.linalg.solve(eye_r + Wt_Psi_inv @ W, Wt_Psi_inv)

        # E-step statistics
        E_gam = beta @ X                                            # (r x n)
        E_ggam = n * eye_r - n * beta @ W + E_gam @ E_gam.t()       # (r x r)

        # M-step
        W_new = (X @ E_gam.t()) @ torch.linalg.inv(E_ggam)
        # diag(S_x - W_new @ beta @ S_x) without forming the d x d product.
        M4 = beta @ S_x
        Psi_new = torch.diagonal(S_x) - torch.einsum('ik,ki->i', W_new, M4)
        Psi = torch.clamp(Psi_new, min=1e-6)  # keep the noise variances positive
        W = W_new
    return (W, Psi)

def LFA_st(u, s, q, d):
    """
    Initializes the parameters for Linear Factor Analysis (LFA) from an SVD.

    Args:
        u (torch.Tensor): Left singular vectors (d x m, m >= q).
        s (torch.Tensor): Squared singular values (eigenvalue-like), presumably sorted
            descending as returned by torch.linalg.svd.
        q (int): Number of latent factors.
        d (int): Number of features.

    Returns:
        tuple: Initial loading matrix A_ML (d x q) and isotropic noise variances Psi (d).
    """
    sig_ML = (1 / (d - q)) * torch.sum(s[q:])
    # (diag(s_q) - sig_ML I)^{1/2} is diagonal, i.e. a column scaling of u_q. Clamp so
    # round-off (or unsorted input) cannot feed a negative number to the sqrt (NaN).
    scale = torch.sqrt(torch.clamp(s[:q] - sig_ML, min=0))
    A_ML = u[:, :q] * scale
    Psi = sig_ML * torch.ones(d)
    return (A_ML, Psi)

def fs_lfa(trFeatures, trY, nc, d, q=20):
    """
    Performs Feature Selection using Linear Factor Analysis (LFA) for each class.

    Each class is fitted with EM (``EMv2``), initialised by ``LFA_st``. Features are
    ranked by SNR = (row energy of the loadings) / (noise variance).

    Args:
        trFeatures (torch.Tensor): Training features (N x d).
        trY (torch.Tensor): Training labels (N).
        nc (int): Number of classes.
        d (int): Number of features.
        q (int, optional): Number of latent factors to retain. Defaults to 20.

    Returns:
        torch.Tensor: A tensor of shape (nc x d) containing feature indices sorted in
                      descending order of their relevance (SNR) for each class.
    """
    index_list = torch.empty((0, d), dtype=torch.int32)
    for i in range(nc):
        data = trFeatures[trY == i, :]
        n = data.shape[0]
        mu = torch.mean(data, dim=0)
        data_centered_t = (data - mu).t()  # (d x n)
        # Sample covariance plus a small ridge, presumably for conditioning of the EM updates.
        S = torch.cov(data_centered_t) + 0.01 * torch.eye(d)
        # Thin SVD: LFA_st reads only the first q left singular vectors and the
        # singular values, so the full d x d U is unnecessary.
        u, s, _ = torch.linalg.svd(data_centered_t, full_matrices=False)
        W_st, Psi_st = LFA_st(u, s**2, q, d)
        W_ML, Psi_ML = EMv2(data_centered_t, S, W_st, Psi_st, q, n)
        snr = torch.sum(W_ML**2, dim=1) / Psi_ML
        indices = torch.argsort(-snr)
        index_list = torch.cat((index_list, indices.unsqueeze(0)), dim=0)
    return index_list

In [98]:
# ELF
def ELF_st(X, r, d):
    """
    Initializes parameters for the Exploratory Latent Factor (ELF) model.

    Args:
        X (torch.Tensor): Centered data matrix (n x d).
        r (int): Number of latent factors.
        d (int): Number of features.

    Returns:
        tuple: Initial factor scores Gamma = U_r S_r (n x r) and unit noise variances Psi (d).
    """
    # Thin SVD: only the leading r left singular vectors and values are used. The full
    # decomposition would also build a d x d V^T.
    u, s, _ = torch.linalg.svd(X, full_matrices=False)
    Gamma = u[:, :r] * s[:r]
    Psi = torch.ones(d, dtype=torch.float32)
    return (Gamma, Psi)

def ELF(X0, Gamma, Psi, n, epochs, r, d, tolerance=0.1):
    """
    Performs the Exploratory Latent Factor (ELF) alternating algorithm.

    Each iteration runs four steps:
    1. Least-squares loadings W given the factors.
    2. Weighted (1/Psi) least-squares factors given W.
    3. Re-orthogonalisation so that Gamma^T Gamma = n I.
    4. Per-feature residual variances Psi, floored at 0.01.
    The loop stops early when the Frobenius change in W is below ``tolerance``.

    Args:
        X0 (torch.Tensor): Centered data matrix (n x d).
        Gamma (torch.Tensor): Initial latent factors (n x r).
        Psi (torch.Tensor): Initial noise variance (d).
        n (int): Number of samples.
        epochs (int): Maximum number of iterations (must be >= 1).
        r (int): Number of latent factors.
        d (int): Number of features.
        tolerance (float, optional): Convergence tolerance. Defaults to 0.1.

    Returns:
        torch.Tensor: Signal-to-noise ratio (SNR) for each feature (d).

    Raises:
        ValueError: If ``epochs < 1`` (no loadings would ever be computed).
    """
    if epochs < 1:
        raise ValueError("epochs must be at least 1")
    W = torch.ones((d, r))
    for ep in range(epochs):
        # Updating W
        W_new = X0.t() @ Gamma @ torch.linalg.inv(Gamma.t() @ Gamma)
        # Updating Gamma. Row-scaling by 1/Psi is the same as diag(1/Psi) @ W_new,
        # but does not build a d x d matrix.
        m1 = (1 / Psi)[:, None] * W_new
        Gamma = X0 @ m1 @ torch.linalg.inv(W_new.t() @ m1)
        # Orthogonalisation: Gamma = U S V^T -> Gamma = sqrt(n) U_r, W absorbs V S / sqrt(n)
        u, s, v = torch.linalg.svd(Gamma)
        W_new = (W_new @ (v.t() @ torch.diag(s))) / np.sqrt(n)
        Gamma = u[:, :r] * np.sqrt(n)
        # Updating weights (Psi), floored at 0.01
        Xhat = Gamma @ W_new.t()
        Psi = torch.clamp(torch.mean((X0 - Xhat)**2, dim=0), min=0.01)

        if torch.sqrt(torch.sum((W - W_new)**2)) < tolerance:
            print('ep', ep)
            break
        W = W_new
    snr = torch.sum(W_new**2, dim=1) / Psi
    return snr

epochs_elf = 10  # default maximum number of ELF iterations
r_elf = 10       # default number of latent factors
def fs_ELF(trFeatures, trY, nc, d, r=r_elf, epochs=epochs_elf):
    """
    Performs Exploratory Latent Factor (ELF) based feature selection for each class.

    Args:
        trFeatures (torch.Tensor): Training features (N x d).
        trY (torch.Tensor): Training labels (N).
        nc (int): Number of classes.
        d (int): Number of features.
        r (int, optional): Number of latent factors. Defaults to r_elf (10).
        epochs (int, optional): Maximum number of ELF iterations. Defaults to epochs_elf (10).

    Returns:
        torch.Tensor: A CPU tensor of shape (nc x d) containing feature indices sorted in
                      descending order of their relevance (SNR) for each class.
    """
    idx_list = torch.empty((0, d), dtype=torch.int32)

    for i in range(nc):
        X = trFeatures[trY == i, :]
        X0 = X - torch.mean(X, dim=0)

        Gamma, Psi = ELF_st(X0, r, d)
        n = X0.shape[0]
        snr = ELF(X0, Gamma, Psi, n, epochs, r, d)
        # Keep the ranking on the CPU alongside idx_list. Moving it to the global
        # `device` (e.g. 'cuda') made torch.cat fail on mixed devices.
        idx = torch.argsort(-snr).cpu()
        idx_list = torch.cat((idx_list, idx.unsqueeze(0)), dim=0)
    return idx_list

In [99]:
# HPCA
T_hpca = 5   # default number of HeteroPCA iterations
r_hpca = 10  # default rank
def heteroPCA(Cov, r, T, d):
    """
    Performs Heteroscedastic Principal Component Analysis (HeteroPCA).

    Starts from ``Cov`` with its diagonal zeroed. Then, ``T`` times, it fits a rank-``r``
    truncated SVD and overwrites only the diagonal with that fit's diagonal.

    Args:
        Cov (torch.Tensor): Covariance matrix (d x d).
        r (int): Number of principal components.
        T (int): Number of iterations.
        d (int): Number of features.

    Returns:
        tuple: Leading r left singular vectors (d x r) and the last rank-r fit N_tilda (d x d).
    """
    diag_idx = torch.arange(d)
    N_prev = Cov - torch.diag(torch.diagonal(Cov))
    for _ in range(T):
        u, s, v = torch.linalg.svd(N_prev)
        N_tilda = (u[:, :r] * s[:r]) @ v[:r, :]
        N_prev[diag_idx, diag_idx] = torch.diagonal(N_tilda)
    return (u[:, :r], N_tilda)

def fs_HPCA(trFeatures, trY, nc, d, r=r_hpca, T=T_hpca):
    """
    Performs feature selection with HeteroPCA (``heteroPCA``) for each class.

    The ridge-regularised class covariance is passed to ``heteroPCA``. The data are
    projected onto the returned rank-r subspace, and the factors are rescaled so that
    Gamma^T Gamma = n I. Features are ranked by loading energy divided by mean squared
    reconstruction residual.

    Args:
        trFeatures (torch.Tensor): Training features (N x d).
        trY (torch.Tensor): Training labels (N).
        nc (int): Number of classes.
        d (int): Number of features.
        r (int, optional): Number of principal components for HPCA. Defaults to r_hpca (10).
        T (int, optional): Number of iterations for HPCA. Defaults to T_hpca (5).

    Returns:
        torch.Tensor: A CPU tensor of shape (nc x d) containing feature indices sorted in
                      descending order of their relevance (SNR-like metric) for each class.
    """
    idx_list = torch.empty((0, d), dtype=torch.int32)
    for i in range(nc):
        X = trFeatures[trY == i, :]
        X_centered = X - torch.mean(X, dim=0)  # (n x d)
        Cov = torch.cov(X_centered.t()) + 0.01 * torch.eye(d)
        u, _ = heteroPCA(Cov, r, T, d)

        # Project onto the HeteroPCA subspace, then rotate and rescale so the factor
        # scores satisfy Gamma^T Gamma = n I (the loadings absorb the singular values).
        Gamma = X_centered @ u
        n = X_centered.shape[0]
        u_gamma, s_gamma, v_gamma = torch.linalg.svd(Gamma)
        W_new = (u @ (v_gamma.t() @ torch.diag(s_gamma))) / np.sqrt(n)
        Gamma_new = u_gamma[:, :r] * np.sqrt(n)

        X_hat = Gamma_new @ W_new.t()
        Psi = torch.mean((X_centered - X_hat)**2, dim=0)
        snr_e = torch.sum(W_new**2, dim=1) / Psi
        # Keep on the CPU alongside idx_list. Moving it to the global `device`
        # (e.g. 'cuda') made torch.cat fail on mixed devices.
        idx = torch.argsort(-snr_e).cpu()
        idx_list = torch.cat((idx_list, idx.unsqueeze(0)), dim=0)
    return idx_list


In [100]:
class MahalanobisClassifier:
    """
    Nearest-class-mean classifier using a per-class (squared) Mahalanobis distance.

    Class ``i`` uses its own feature subset ``index_list[i]``, its mean, and a
    ridge-regularised inverse covariance. Each sample is assigned to the class with
    the smallest distance.
    """
    def __init__(self, device='cpu'):
        """
        Initializes the MahalanobisClassifier.

        Args:
            device (str, optional): The device to perform computations on ('cpu' or 'cuda'). Defaults to 'cpu'.
        """
        self.device = device
        self.means = []
        self.inv_covariances = []
        # Class whose mean _mahalanobis_distance compares against; set by classify().
        self.current_class = 0

    def _calculate_class_parameters(self, index_list, features_train, trY, nc, nf):
        """
        Calculates the mean and inverse covariance matrix for the selected features of each class.

        Args:
            index_list (torch.Tensor): Tensor of feature indices for each class (nc x nf).
            features_train (torch.Tensor): Training features (N x d).
            trY (torch.Tensor): Training labels (N).
            nc (int): Number of classes.
            nf (int): Number of selected features.
        """
        self.means = []
        self.inv_covariances = []
        # Ridge term keeps each class covariance invertible; built once, on self.device.
        ridge = 0.1 * torch.eye(nf, device=self.device)
        for i in range(nc):
            # Index on the features' own device so index tensors living elsewhere still work.
            indices = torch.as_tensor(index_list[i], device=features_train.device)
            # Move to self.device BEFORE adding the ridge; previously a CPU covariance
            # was added to a device eye, which fails when device != 'cpu'.
            features_selected = features_train[trY == i, :][:, indices].to(self.device)
            self.means.append(torch.mean(features_selected, dim=0))
            cov = torch.cov(features_selected.t()) + ridge
            self.inv_covariances.append(torch.linalg.inv(cov))

    def _mahalanobis_distance(self, x, inv_cov):
        """
        Squared Mahalanobis distance of ``x`` to the mean of ``self.current_class``.

        Args:
            x (torch.Tensor): One sample (nf,) or a batch of samples (n x nf).
            inv_cov (torch.Tensor): Inverse covariance matrix (nf x nf).

        Returns:
            torch.Tensor: A scalar for a single sample, or (n,) distances for a batch.
        """
        diff = x.to(self.device) - self.means[self.current_class]
        # Row-wise diff^T inv_cov diff; for a 1-D diff this equals (diff @ inv_cov) @ diff.
        return torch.sum((diff @ inv_cov) * diff, dim=-1)

    def classify(self, features_valid, index_list, nc, n_val):
        """
        Classifies validation features based on Mahalanobis distance to class means.

        Args:
            features_valid (torch.Tensor): Validation features (N_val x d).
            index_list (torch.Tensor): Tensor of feature indices for each class (nc x nf).
            nc (int): Number of classes.
            n_val (int): Number of validation samples (the first n_val rows are classified).

        Returns:
            torch.Tensor: Predicted class labels for the validation set (on the CPU).
        """
        out = torch.zeros(n_val, nc, device=self.device)
        for i in range(nc):
            self.current_class = i
            indices = torch.as_tensor(index_list[i], device=features_valid.device)
            x_valid_selected = features_valid[:n_val][:, indices]
            # One batched computation per class replaces n_val per-sample Python calls.
            out[:, i] = self._mahalanobis_distance(x_valid_selected, self.inv_covariances[i])
        return torch.argmin(out, dim=1).cpu()

def calculate_accuracy(classifier, features_train, features_valid, index_list, valY, trY, nc, nf):
    """
    Fit ``classifier`` on the training split, then score it on the validation split.

    Args:
        classifier (MahalanobisClassifier): Classifier to fit and evaluate.
        features_train (torch.Tensor): Training features (N x d).
        features_valid (torch.Tensor): Validation features (N_val x d).
        index_list (torch.Tensor): Feature indices per class (nc x nf).
        valY (torch.Tensor): Validation labels (N_val).
        trY (torch.Tensor): Training labels (N).
        nc (int): Number of classes.
        nf (int): Number of selected features.

    Returns:
        float: Fraction of correct validation predictions, rounded to 4 decimals.
    """
    classifier._calculate_class_parameters(index_list, features_train, trY, nc, nf)
    n_val = features_valid.shape[0]
    predictions = classifier.classify(features_valid, index_list, nc, n_val)
    n_correct = (predictions == valY.cpu()).sum().item()
    return np.round(n_correct / n_val, 4)

In [101]:
# Sanity check: total number of training labels loaded across all class files.
trY.shape

torch.Size([50000])

In [37]:
# Per-method results: feature rankings (index_dic) and wall-clock seconds (time_dic).
index_dic, time_dic = {}, {}

In [38]:
# Feature Selection using PPCA (wall-clock timed)
start = time.time()
index_lst = fs_ppca(trFeatures, trY, nc, d, q=20)
end = time.time()
index_dic['PPCA'] = index_lst
time_dic['PPCA'] = end - start

In [44]:
# Wall-clock seconds recorded so far for each feature-selection method.
time_dic

{'PPCA': 8.657537698745728}

In [41]:
# Feature Selection using LFA (timed, like the PPCA cell, so time_dic covers every method)
start = time.time()
index_lst = fs_lfa(trFeatures, trY, nc, d, q=20)
index_dic['LFA'] = index_lst
time_dic['LFA'] = time.time() - start


In [46]:
# Feature Selection using ELF (timed, like the PPCA cell, so time_dic covers every method)
start = time.time()
index_lst = fs_ELF(trFeatures, trY, nc, d, r=r_elf, epochs=epochs_elf)
index_dic['ELF'] = index_lst
time_dic['ELF'] = time.time() - start


In [50]:
# Feature Selection using HPCA (timed, like the PPCA cell, so time_dic covers every method)
start = time.time()
index_lst = fs_HPCA(trFeatures, trY, nc, d, r=r_hpca, T=T_hpca)
index_dic['HPCA'] = index_lst
time_dic['HPCA'] = time.time() - start

In [137]:
# Standardise both splits with training-set statistics.
mu = torch.mean(trFeatures, dim=0)
std = torch.std(trFeatures, dim=0)
# A constant feature has std == 0, and dividing by it yields inf/NaN that poison every
# downstream covariance and loss. Leave such features unscaled (they centre to 0).
std = torch.where(std > 0, std, torch.ones_like(std))

trfeatures = (trFeatures - mu) / std
valfeatures = (valFeatures - mu) / std

In [105]:
nf_list = [750, 1000, 1250, 1500, 1750, 2000, 2250, 2560]
accuracy_results = {}
classifier = MahalanobisClassifier()

# Score every feature-selection ranking at each feature budget.
for method, indices in index_dic.items():
    accuracy_list = []
    for nf in nf_list:
        print(f"Evaluating {method} with {nf} features...")
        index_sel = indices[:, :nf]
        accuracy = calculate_accuracy(classifier, trfeatures, valfeatures, index_sel, valY, trY, nc, nf)
        accuracy_list.append(accuracy)
        print(f"{method} ({nf} features) Accuracy: {accuracy}")
    accuracy_results[method] = accuracy_list


print("\nClassification Accuracy for Different Number of Features:")
for method, accuracies in accuracy_results.items():
    print(f"{method}:")
    for nf, acc in zip(nf_list, accuracies):
        print(f"  {nf} features: {acc}")

Evaluating PPCA with 750 features...
PPCA (750 features) Accuracy: 0.6547
Evaluating PPCA with 1000 features...
PPCA (1000 features) Accuracy: 0.6936
Evaluating PPCA with 1250 features...
PPCA (1250 features) Accuracy: 0.7083
Evaluating PPCA with 1500 features...
PPCA (1500 features) Accuracy: 0.7212
Evaluating PPCA with 1750 features...
PPCA (1750 features) Accuracy: 0.7254
Evaluating PPCA with 2000 features...
PPCA (2000 features) Accuracy: 0.7301
Evaluating PPCA with 2250 features...
PPCA (2250 features) Accuracy: 0.7283
Evaluating PPCA with 2560 features...
PPCA (2560 features) Accuracy: 0.7281
Evaluating LFA with 750 features...
LFA (750 features) Accuracy: 0.6478
Evaluating LFA with 1000 features...
LFA (1000 features) Accuracy: 0.6736
Evaluating LFA with 1250 features...
LFA (1250 features) Accuracy: 0.6933
Evaluating LFA with 1500 features...
LFA (1500 features) Accuracy: 0.7096
Evaluating LFA with 1750 features...
LFA (1750 features) Accuracy: 0.7168
Evaluating LFA with 2000 f

In [138]:
# Standardised features paired with their labels; train_and_evaluate builds its loaders from these.
train_data, test_data = (data_utils.TensorDataset(feats, labels)
                         for feats, labels in ((trfeatures, trY), (valfeatures, valY)))
######################################################################################################################

class FullyConnected(nn.Module):
    """A single linear layer mapping ``input_dim`` features to ``output_dim`` logits."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)

    def _init_weights(self, module):
        """Zero the weight (and bias, if any) of ``module`` in place when it is an ``nn.Linear``.

        Not invoked by ``__init__``; it only takes effect if applied explicitly.
        """
        if not isinstance(module, nn.Linear):
            return
        module.weight.data.zero_()
        if module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x):
        return self.fc1(x)

######################################################################################################################
def CorrectPred(outputs, responses):
    """Count rows whose softmax argmax over ``outputs`` equals the integer label in ``responses``."""
    probabilities = torch.nn.functional.softmax(outputs, dim=1).detach().cpu().numpy()
    labels = responses.long().cpu().numpy()
    return np.sum(probabilities.argmax(axis=1) == labels)

def complement(total_dim, indices):
    """Return, in ascending order, the indices in ``range(total_dim)`` that are not in ``indices``."""
    keep = torch.ones(total_dim, dtype=torch.bool)
    keep[indices] = False
    (remaining,) = torch.nonzero(keep, as_tuple=True)
    return remaining
######################################################################################################################
def FSA(input_dim, w1, k, mu_fsa, current_epoch, M_s, epochs):
    """Feature Selection with Annealing step for one epoch.

    The feature budget M shrinks from ``input_dim`` (epoch 0) to ``k``. Once
    2 * epoch >= epochs it is exactly ``k``. M is recorded in ``M_s[current_epoch]``.
    From epoch 1 on, whenever the budget did not grow, the columns of ``w1`` outside the
    M largest squared column norms are zeroed IN PLACE.

    Returns:
        tuple: ``(w1, M_s)`` -- the same objects that were passed in, possibly modified.
    """
    e = current_epoch
    shrink = max(0, (epochs - 2 * e) / (2 * e * mu_fsa + epochs))
    M = int(np.floor(k + (input_dim - k) * shrink))
    M_s[e] = M
    if e <= 0 or M_s[e] > M_s[e - 1]:
        return (w1, M_s)
    column_energy = (w1 ** 2).sum(dim=0)
    dropped = torch.argsort(-column_energy)[M:]
    w1[:, dropped] = 0
    return (w1, M_s)

def TISP(w1, k):
    """Group soft-thresholding that keeps the ``k`` feature columns of ``w1`` with the largest L2 norm.

    The threshold lamb is the (k+1)-th largest column norm. Each column is scaled by
    max(0, 1 - lamb / ||column||), and all columns ranked below k are set to zero.
    The previous version subtracted lamb from every entry, which shifts weights rather
    than shrinking them. It also read the globals ``d``/``nc`` instead of ``w1.shape``.

    Args:
        w1 (torch.Tensor): Weight matrix (n_classes x n_features).
        k (int): Number of feature columns to keep.

    Returns:
        torch.Tensor: A new thresholded weight matrix with the same shape as ``w1``.
        If ``k >= n_features`` there is no threshold column, so an unchanged copy is returned.
    """
    n_features = w1.shape[1]
    if k >= n_features:
        return w1.clone()
    col_norms = torch.sqrt(torch.sum(w1**2, dim=0))
    order = torch.argsort(-col_norms)
    lamb = col_norms[order[k]]
    # Guard against 0/0 for all-zero columns; they stay zero either way.
    safe_norms = col_norms.clamp_min(torch.finfo(w1.dtype).tiny)
    factor = torch.clamp(1 - lamb / safe_norms, min=0)
    w_new = w1 * factor
    w_new[:, order[k:]] = 0  # enforce exactly k survivors, even under ties
    return w_new
    
def _evaluate(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly."""
    correct_test = 0
    total_test = 0
    model.eval()
    with torch.no_grad():
        for data_test, Y_test in loader:
            data_test = data_test.to(device)
            Y_test = Y_test.to(device)
            correct_test += CorrectPred(model(data_test), Y_test)
            total_test += Y_test.size(0)
    return correct_test / total_test


def train_and_evaluate(input_dim, output_dim, method, k, epochs=30, learning_rate=0.01, mu_fsa=20, device='cpu'):
    """
    Train a linear classifier on ``train_data`` with embedded feature selection.

    After each epoch the weight matrix is pruned: by ``FSA`` when ``method == 'FSA'``,
    by ``TISP`` when ``method == 'TISP'``; any other value trains without pruning. The
    batch size starts at 128 and doubles every 10 epochs. Test accuracy on ``test_data``
    is measured every 5 epochs and after the final epoch.

    Args:
        input_dim (int): Number of input features.
        output_dim (int): Number of classes.
        method (str): 'FSA', 'TISP', or anything else for no selection.
        k (int): Number of features to keep.
        epochs (int, optional): Training epochs. Defaults to 30.
        learning_rate (float, optional): Adam learning rate. Defaults to 0.01.
        mu_fsa (float, optional): FSA annealing parameter. Defaults to 20.
        device (str, optional): Device to train on. Defaults to 'cpu'.

    Returns:
        tuple: (test accuracies in evaluation order, total training seconds excluding evaluation).
    """
    model = FullyConnected(input_dim, output_dim).to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    loss_fn = CrossEntropyLoss().to(device)

    acc_history_train = []
    acc_history_test = []
    elapsed_time_total = 0
    M_s = np.zeros(epochs)  # FSA feature budget per epoch

    batchsize = 128  # Initial batch size
    train_loader = data_utils.DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=0)
    test_loader = data_utils.DataLoader(test_data, batch_size=batchsize, shuffle=True, num_workers=0)

    for epoch in range(epochs):
        correct_train = 0
        total_train = 0
        print(f"Epoch {epoch+1}/{epochs}")
        model.train()
        start_time = time.time()
        for data, Y in train_loader:
            data = data.to(device)
            Y = Y.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, Y.long())
            correct_train += CorrectPred(outputs, Y)
            total_train += Y.size(0)
            loss.backward()
            optimizer.step()

        acc_history_train.append(correct_train / total_train)

        w1 = model.fc1.weight.detach()
        if method == 'FSA':
            w1, M_s = FSA(input_dim, w1, k, mu_fsa, epoch, M_s, epochs)
        elif method == 'TISP':
            w1 = TISP(w1, k)

        # Write the pruned weights back IN PLACE. Assigning a new nn.Parameter to
        # model.fc1.weight (as before) left the optimizer holding the old Parameter,
        # so the model's weight no longer received optimizer updates.
        # FSA already edits the shared storage, so its result needs no copy.
        with torch.no_grad():
            if w1.data_ptr() != model.fc1.weight.data_ptr():
                model.fc1.weight.copy_(w1)
        elapsed_time_total += time.time() - start_time

        # Evaluate periodically AND after the last epoch, so the reported accuracy
        # belongs to the final (fully pruned) model.
        if epoch % 5 == 0 or epoch == epochs - 1:
            acc_test = _evaluate(model, test_loader, device)
            acc_history_test.append(acc_test)

        if (epoch + 1) % 10 == 0:
            batchsize *= 2
            train_loader = data_utils.DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=0)
            test_loader = data_utils.DataLoader(test_data, batch_size=batchsize, shuffle=True, num_workers=0)
            print(f"  Batch size increased to: {batchsize}")

    print(f"Total Training Time: {elapsed_time_total:.2f}s")
    print(f"  Test Accuracy: {acc_test:.4f}")
    return acc_history_test, elapsed_time_total


In [None]:
methods = ['FSA', 'TISP']
def get_accuracy_for_nf(train_loader, test_loader, d, nc, nf_list, method_list=None, device='cpu'):
    """
    Train and evaluate each selection method at every feature budget in ``nf_list``.

    ``train_loader``/``test_loader`` are accepted for backward compatibility but are
    unused, because ``train_and_evaluate`` builds its own loaders.

    Args:
        d (int): Number of input features.
        nc (int): Number of classes.
        nf_list (list[int]): Feature budgets to evaluate.
        method_list (list[str], optional): Methods to run; defaults to the module-level ``methods``.
        device (str, optional): Device passed to ``train_and_evaluate``. Defaults to 'cpu'.

    Returns:
        tuple: ``(accuracy_results, times)``. Both are dicts keyed by method, holding one
        entry per budget in ``nf_list`` order: final test accuracy and training seconds.
        (Previously only the last run's training time was returned and the collected
        times were discarded.)
    """
    if method_list is None:
        method_list = methods
    accuracy_results = {}
    times = {}  # not named `time`, which would shadow the time module
    for m in method_list:
        accuracy_results[m] = []
        times[m] = []
        for nf in nf_list:
            print(f"\nTraining and evaluating {m} with {nf} features...")
            acc_test_hist, training_time = train_and_evaluate(
                d, nc, m, nf, epochs=30, learning_rate=0.01, mu_fsa=20, device=device)
            accuracy_results[m].append(acc_test_hist[-1])
            times[m].append(training_time)
    return accuracy_results, times


In [None]:
# train_and_evaluate builds its own DataLoaders, so the loader arguments are unused.
# Passing None avoids depending on `train_loader`/`test_loader` globals that no cell defines.
accuracy, training_time = get_accuracy_for_nf(None, None, d, nc, nf_list)


Training and evaluating with 750 features...
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
  Batch size increased to: 256
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
