In [18]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Pick the compute device: the first CUDA GPU when one is visible to PyTorch,
# otherwise fall back to the CPU. The choice is echoed for the reader.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [19]:
# Read the first sheet of the miRNA signal workbook into a DataFrame.
df1 = pd.read_excel(
    io="/home/gddaslab/mxp140/sclerosis_project/miRNA_signal_hsa_number2.xlsx",
    sheet_name="Sheet1",
    engine="openpyxl",
)

# Remove the two identifier columns, then skip the first 10 of the columns
# that remain; whatever is left is what the rest of the notebook works on.
df = df1.drop(columns=["ID", "Transcript_ID"]).iloc[:, 10:]

# Map each column-name prefix to an integer class id.
labels = {"aHC": 0, "sMS": 1, "aMS": 2, "aPOMS": 3, "sPOMS": 4, "pBar": 5}

# Build one target label per column of `df`, in column order.
# A column whose name matches none of the prefixes used to be skipped silently,
# which left `y` shorter than (and misaligned with) the columns of `df`.
# Such columns are now collected and reported instead.
y = []
unlabeled_columns = []
for col in df.columns:
    # str() guards against non-string column headers coming out of the workbook.
    matched_prefix = next(
        (key for key in labels if str(col).startswith(key)), None
    )
    if matched_prefix is None:
        unlabeled_columns.append(col)
    else:
        y.append(labels[matched_prefix])

if unlabeled_columns:
    raise ValueError(
        "Columns without a known class prefix "
        f"{sorted(labels)}: {unlabeled_columns}"
    )

In [20]:
# Convert the DataFrame to a NumPy feature matrix. The transpose turns every
# column of `df` into one row of `X`, so rows of `X` line up with the labels
# that were built per column.
X = df.T.to_numpy()

# `y` is already the plain list of integer labels and needs no conversion
# (the former `y = y` was a no-op). Check that there is exactly one label per
# row of `X` so a misaligned label list fails here rather than downstream.
assert len(y) == X.shape[0], (
    f"label/sample mismatch: {len(y)} labels for {X.shape[0]} rows in X"
)

In [21]:
# Standardize features (optional but recommended)
# NOTE(review): the scaler is only instantiated here; no fit/transform call on
# it appears in the visible cells, so `X` is used unscaled below. Confirm that
# this matches how the saved model was trained.
scaler = StandardScaler()

In [22]:
class SoftmaxRegression(nn.Module):
    """Linear classifier that maps a feature vector to one score per class.

    The forward pass returns the raw output of the linear layer; no softmax is
    applied inside the module.
    """

    def __init__(self, input_dim, output_dim):
        """Create the single linear layer.

        Args:
            input_dim: Number of input features per sample.
            output_dim: Number of output scores (one per class).
        """
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Return the linear layer's output for the batch `x`."""
        logits = self.linear(x)
        return logits


class ElasticNetLoss(nn.Module):
    """Cross-entropy loss plus an elastic-net penalty on a model's parameters.

    The penalty is ``alpha * (l1_ratio * L1 + (1 - l1_ratio) * L2)`` where L1 is
    the sum of absolute values and L2 the sum of squares, both taken over every
    tensor yielded by ``model.parameters()``.
    """

    def __init__(self, model, alpha=1.0, l1_ratio=0.5):
        """Store the model to regularize and the penalty hyper-parameters.

        Args:
            model: Module whose parameters are penalized.
            alpha: Overall multiplier applied to the penalty term.
            l1_ratio: Weight of the L1 part; the L2 part gets ``1 - l1_ratio``.
        """
        super().__init__()
        self.model = model
        self.alpha = alpha
        self.l1_ratio = l1_ratio
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        """Return cross-entropy of (outputs, targets) plus the penalty."""
        # Accumulate both norms in a single pass over the parameters.
        l1_norm = 0
        l2_norm = 0
        for weight in self.model.parameters():
            l1_norm = l1_norm + weight.abs().sum()
            l2_norm = l2_norm + weight.pow(2).sum()

        l1_term = self.l1_ratio * l1_norm
        l2_term = (1 - self.l1_ratio) * l2_norm
        elastic_net_penalty = self.alpha * (l1_term + l2_term)

        ce_loss = self.cross_entropy_loss(outputs, targets)
        return ce_loss + elastic_net_penalty

In [23]:
# Display the dimensions of the feature matrix: rows come from the columns of
# `df` (via the transpose), columns from its rows.
X.shape

(31, 4570)

In [24]:
# Load the model: rebuild the classifier with dimensions taken from the data,
# then restore the trained weights from disk.
input_dim = X.shape[1]  # one input per column of X
output_dim = len(set(y))  # one output per distinct label value in y
model = SoftmaxRegression(input_dim, output_dim)

# map_location="cpu" lets a checkpoint saved from a GPU load on a machine
# without CUDA, and matches the freshly built model, which lives on the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load("softmax_classifier_wo_pHC.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # Set the model to evaluation mode (also displays the model)

SoftmaxRegression(
  (linear): Linear(in_features=4570, out_features=6, bias=True)
)

In [25]:
# Collect the model's learnable tensors into a list, printing each tensor
# followed by its shape as it is gathered.
parameters = []
for param in model.parameters():
    print(param, param.shape, sep="\n")
    parameters += [param]

Parameter containing:
tensor([[-4.3036e-05,  4.0040e-05, -1.3318e-05,  ..., -2.8754e-05,
          3.7461e-05, -4.3856e-05],
        [ 3.5014e-05, -6.9638e-05, -6.4691e-06,  ...,  5.0804e-05,
          2.1818e-05,  3.7107e-05],
        [ 3.8762e-05, -1.8031e-05, -5.3990e-06,  ..., -5.5664e-06,
         -1.0579e-05, -3.0961e-05],
        [ 1.7438e-05,  4.1048e-05,  4.3546e-05,  ...,  3.7342e-05,
         -1.4173e-05,  2.5940e-02],
        [-4.3082e-06,  3.1483e-05,  9.1464e-06,  ...,  2.0700e-02,
          2.9029e-05, -1.9243e-05],
        [-2.3558e-05,  1.8010e-05,  1.6457e-05,  ..., -2.0188e-03,
         -4.2766e-05, -2.7076e-05]], requires_grad=True)
torch.Size([6, 4570])
Parameter containing:
tensor([-1.6991e-05, -1.3707e-05, -2.3782e-05,  5.3495e-05, -6.8918e-04,
        -5.0310e-05], requires_grad=True)
torch.Size([6])


In [26]:
# Evaluation: rank features by the magnitude of their learned weights.
model.eval()

# Read the weight matrix directly from the linear layer rather than from the
# position of an entry in the `parameters` list. `.detach()` returns a tensor
# without autograd tracking and leaves the model's parameter untouched; the
# former in-place `.detach_()` modified the parameter itself. `.cpu()` keeps
# the NumPy conversion valid even if the model has been moved to a GPU.
# No autograd operations follow, so a `torch.no_grad()` context is not needed.
weights = model.linear.weight.detach().cpu().numpy()

# Importance of a feature = sum over output classes of |weight|.
feature_importance = np.abs(weights)
aggregated_importance = np.sum(feature_importance, axis=0)

# argsort is ascending, so reverse it to get most-important-first.
ranking_indices_for_miRNA = np.argsort(aggregated_importance)[::-1]

top_k = 10  # number of highest-ranked features to keep
top_indices = ranking_indices_for_miRNA[:top_k]

# Columns of X for the selected features (all samples).
top_miRNA_signals = X[:, top_indices]
# Transcript_ID values at the selected positions; despite the variable name
# these are looked up in the Transcript_ID column, not per-patient data.
top_miRNA_patient = df1["Transcript_ID"].values[top_indices]

In [29]:
# Display the dimensions of the extracted weight matrix
# (rows = outputs of the linear layer, columns = input features).
weights.shape

(6, 4570)

In [None]:
# For each row of `weights`, the column indices of its 5 largest values,
# listed in ascending order of weight (the largest is last). This sorts the
# signed weights, not their absolute values.
np.argsort(weights, axis=1)[:, -5:]

In [27]:
# Show the Transcript_ID values of the selected top-ranked features as a
# single-column table, in ranking order.
print(pd.DataFrame(top_miRNA_patient))

                 0
0   hsa-miR-127-3p
1     hsa-mir-4500
2      hsa-miR-607
3  hsa-miR-2276-5p
4  hsa-miR-487a-5p
5   hsa-miR-10b-5p
6     hsa-mir-5003
7     hsa-miR-4316
8   hsa-miR-18a-3p
9      hsa-mir-875
