In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from custom_functions import processing
import numpy as np
import pandas as pd

In [None]:
# Read the spreadsheet and discard the 'Foldername' column in one step
dataset = pd.read_excel('data_v10.xlsx').drop(columns=['Foldername'])

In [None]:
# Names of the columns that remain once the columns listed below are removed
# from `dataset`; original column order is preserved.
feature_names = dataset.columns.drop([
    'Fried_State', 'Fried_Score', 'Frailty_State', 'Frailty_Score',
    'Item_1', 'Item_2', 'Item_3', 'Item_4', 'Item_5',
    'Weight_Diff', 'HADS_D_Score', 'walk_time_4m', 'EXAMCLIN02', 'grip',
]).tolist()


In [None]:
class FocalLoss(nn.Module):
    """Focal loss computed from raw class logits.

    Per sample the loss is ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability assigned to the target class. With the
    defaults (``gamma=0``, ``alpha=None``) this is plain cross-entropy.

    Args:
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
        alpha: Optional class weighting. A single number ``a`` becomes the
            two-class weight vector ``[a, 1 - a]``; a list is used as one
            weight per class; ``None`` disables weighting. Any other object
            (e.g. an existing tensor) is stored unchanged.
        size_average: If True, return the mean over samples; otherwise the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.tensor([alpha, 1 - alpha], dtype=torch.float32)
        if isinstance(alpha, list):
            self.alpha = torch.tensor(alpha, dtype=torch.float32)
        self.size_average = size_average

    def forward(self, input, target):
        """Return the reduced focal loss.

        Args:
            input: Logits with the class dimension at index 1. Inputs with more
                than two dimensions are flattened so every trailing position
                becomes its own row.
            target: Integer class indices, one per row after flattening.

        Returns:
            A scalar tensor (mean or sum, depending on ``size_average``).
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                          # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))     # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        # The class axis is always 1 here; naming it avoids the deprecated
        # implicit-dimension behaviour of log_softmax.
        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target).view(-1)

        # p_t is detached on purpose (as in the original code): the modulating
        # factor acts as a constant weight and gradients flow through log(p_t) only.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Keep the weight vector on the same device / dtype as the logits.
            if self.alpha.dtype != input.dtype or self.alpha.device != input.device:
                self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

In [None]:
class MultiTaskModel(nn.Module):
    """Two-headed MLP with a shared trunk.

    The trunk maps ``input_dim`` features to a 64-dimensional representation
    (Linear -> ReLU -> Linear -> ReLU). Two independent linear heads then each
    produce 2 logits: one for Frailty_State and one for Fried_State.

    Args:
        input_dim: Number of input features per sample.
        verbose: If True, print the shared-layer activations on every forward
            pass (debugging aid). Off by default so training loops stay quiet.
    """

    def __init__(self, input_dim, verbose=False):
        super().__init__()
        self.verbose = verbose

        # Shared layers for feature extraction
        self.shared_layers = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # Task-specific layer for Frailty_State
        self.frailty_head = nn.Linear(64, 2)

        # Task-specific layer for Fried_State
        self.fried_head = nn.Linear(64, 2)

    @staticmethod
    def _to_numpy(tensor):
        """Return a NumPy view/copy of `tensor`, detached from autograd and moved to CPU."""
        return tensor.detach().cpu().numpy()

    def forward(self, x):
        """Run both task heads on top of the shared trunk.

        Args:
            x: Float tensor whose last dimension is ``input_dim``.

        Returns:
            A 5-tuple ``(frailty_out, fried_out, common_features_space,
            frailty_features_space, fried_features_space)``. The first two are
            logit tensors that stay attached to the autograd graph; the last
            three are NumPy arrays holding the shared activations and the two
            heads' outputs respectively.
        """
        shared_out = self.shared_layers(x)
        if self.verbose:
            print("Shared Layer Output:", shared_out)  # Prints the common features

        frailty_out = self.frailty_head(shared_out)
        fried_out = self.fried_head(shared_out)

        # .cpu() makes the NumPy conversion work when the model lives on an accelerator.
        common_features_space = self._to_numpy(shared_out)
        frailty_features_space = self._to_numpy(frailty_out)
        fried_features_space = self._to_numpy(fried_out)

        return frailty_out, fried_out, common_features_space, frailty_features_space, fried_features_space

    def print_important_features(self, top_k=5, names=None):
        """Print, per first-layer neuron, the inputs with the largest-magnitude weights.

        Args:
            top_k: How many features to list for each neuron.
            names: Feature names indexed like the input columns. When omitted,
                the module-level ``feature_names`` list is used. Inputs without
                a name are labelled ``feature_<index>``.
        """
        if names is None:
            names = feature_names  # module-level list built earlier in the notebook

        # Plain Python floats: cheap to sort and printed without the tensor(...) wrapper.
        first_layer_weights = self.shared_layers[0].weight.detach().cpu().tolist()
        for i, neuron_weights in enumerate(first_layer_weights):
            print(f"Neuron {i+1} in first hidden layer:")
            labelled = [
                (weight, names[j] if j < len(names) else f"feature_{j}")
                for j, weight in enumerate(neuron_weights)
            ]
            # Rank by magnitude so strongly negative weights count as important too;
            # the signed value is still what gets printed.
            labelled.sort(key=lambda pair: abs(pair[0]), reverse=True)
            for weight, feature in labelled[:top_k]:
                print(f"{feature}: {weight}")



In [None]:
# Hyperparameters
#   input_dim     - width of the model's first Linear layer; it has to equal the
#                   number of columns in the feature tensor passed to the model
#   learning_rate - step size handed to the optimizer
input_dim, learning_rate = 807, 1e-3

In [None]:
# Seed PyTorch's global RNG before any layer is created so that the random
# weight initialisation (and therefore training) is repeatable on a fresh
# kernel with Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Initialize the model, one loss per task, and the optimizer
model = MultiTaskModel(input_dim)
# FocalLoss() with its defaults (gamma=0, alpha=None) reduces to mean cross-entropy.
frailty_loss_fn = FocalLoss()
fried_loss_fn = FocalLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Columns left out of the model input matrix
excluded_columns = [
    'Fried_State', 'Fried_Score', 'Frailty_State', 'Frailty_Score',
    'Item_1', 'Item_2', 'Item_3', 'Item_4', 'Item_5',
    'Weight_Diff', 'HADS_D_Score', 'walk_time_4m', 'EXAMCLIN02', 'grip',
]

# Raw NumPy views of the inputs and the two label columns
feature_matrix = dataset.drop(columns=excluded_columns).values
frailty_labels = dataset['Frailty_State'].values
fried_labels = dataset['Fried_State'].values

# Imputation step (per the original comment). `processing` comes from
# custom_functions; the meaning of its positional arguments is not visible here.
feature_matrix, _ = processing(feature_matrix, 5, [1, 3, 5, 7, 9], verbose=False)

# Tensors consumed by the training loop: float features, integer class labels
X_sample = torch.from_numpy(feature_matrix).float()
y1_sample = torch.from_numpy(frailty_labels).long()
y2_sample = torch.from_numpy(fried_labels).long()


In [None]:
# One NumPy snapshot per epoch of each feature space returned by the model
common_features_list, frailty_features_list, fried_features_list = [], [], []

# Training loop: each epoch is a single full-batch step over X_sample
for epoch in range(100):  # Replace with your number of epochs
    # Clear gradients left over from the previous step
    optimizer.zero_grad()

    # Forward pass through the shared trunk and both heads
    frailty_out, fried_out, common_features, frailty_features, fried_features = model(X_sample)

    # Per-task losses, summed with equal weight
    loss_frailty = frailty_loss_fn(frailty_out, y1_sample)
    loss_fried = fried_loss_fn(fried_out, y2_sample)
    combined_loss = loss_frailty + loss_fried

    # Backpropagate and update the parameters
    combined_loss.backward()
    optimizer.step()

    # Keep this epoch's feature spaces for later inspection
    common_features_list.append(common_features)
    frailty_features_list.append(frailty_features)
    fried_features_list.append(fried_features)



In [None]:
# Weight matrix of the first shared layer after training: one row per hidden
# neuron, one column per input feature.
first_layer_weights = model.shared_layers[0].weight.data.numpy()

# Report, neuron by neuron, the five inputs with the largest absolute weight
for neuron_number, neuron_weights in enumerate(first_layer_weights, start=1):
    print(f"Neuron {neuron_number}")

    # (|weight|, name) pairs in descending order; ties fall back to the name
    feature_contributions = sorted(
        ((abs(value), name) for value, name in zip(neuron_weights, feature_names)),
        reverse=True,
    )

    print("Top contributing original features:")
    for magnitude, name in feature_contributions[:5]:
        print(f"{name}: {magnitude}")


Neuron 1
Top contributing original features:
FROPCOM0004_SQ001_SQ002__850_mg: 0.17891156673431396
FROPCOM0004_SQ003_SQ003__0_1_1: 0.15518584847450256
FROPCOM0004_SQ002_SQ001__ATORVASTADINE: 0.1537536233663559
FROPCOM0004_SQ001_SQ002__5_mg: 0.14717112481594086
FROPCOM0004_SQ004_SQ002__30_mg___10_mg___500_mg: 0.14482593536376953
Neuron 2
Top contributing original features:
FROPCOM0004_SQ001_SQ001__PRAVASTINE: 0.16537535190582275
FROPCOM0004_SQ003_SQ003__1_0_0_si_besoin: 0.16301164031028748
FROPCOM0004_SQ003_SQ002__10_mg_LP: 0.16201867163181305
FROPCOM0004_SQ003_SQ001__METFORMINE: 0.15564174950122833
FROPCOM0004_SQ004_SQ002__1000_mg: 0.1544954627752304
Neuron 3
Top contributing original features:
FROPCOM0004_SQ003_SQ001__BISOCE: 0.0620579831302166
FROPCOM0007_SQ021_: 0.06181642785668373
FROPCOM0004_SQ006_SQ002__40_mg_: 0.06064973771572113
Gender: 0.058897167444229126
FROPCOM0004_SQ002_SQ003__1_0_0: 0.05877196043729782
Neuron 4
Top contributing original features:
FROPCOM0004_SQ002_SQ001__P

In [None]:
# Display the shared-layer feature arrays collected during training (one entry per epoch).
# NOTE(review): this renders the entire list; a summary such as its length or the
# last entry's shape may be enough — confirm what is actually needed here.
common_features_list

In [None]:
# Print the top-ranked input features for every neuron of the first shared layer
model.print_important_features()

FROPCOM0004_SQ001_SQ002__850_mg: 0.17891156673431396
FROPCOM0004_SQ003_SQ003__0_1_1: 0.15518584847450256
FROPCOM0004_SQ002_SQ001__ATORVASTADINE: 0.1537536233663559
FROPCOM0004_SQ001_SQ002__5_mg: 0.14717112481594086
FROPCOM0004_SQ004_SQ002__30_mg___10_mg___500_mg: 0.14482593536376953
FROPCOM0004_SQ002_SQ001__ALPRAZOLAM: 0.14448899030685425
FROPCOM0004_SQ004_SQ003__si_besoin: 0.13941089808940887
FROPCOM0004_SQ001_SQ001__OMEPRAZOLE: 0.1369500309228897
FROPCOM0004_SQ004_SQ003__0_1_0: 0.1355278193950653
FROPCOM0004_SQ001_SQ001__GINKGO: 0.1355104148387909
FROPCOM0004_SQ004_SQ002__1000_mg: 0.1544954627752304
FROPCOM0004_SQ002_SQ001__PERINDOPRIL: 0.14695857465267181
FROPCOM0004_SQ004_SQ001__PARACETAMOL: 0.14160621166229248
FROPCOM0004_SQ001_SQ001__CETIRIZINE: 0.1374562680721283
FROPCOM0004_SQ004_SQ003__si_besoin: 0.13676141202449799
FROPCOM0004_SQ002_SQ001__ATORVASTATINE: 0.12111683189868927
FROPCOM0004_SQ002_SQ001__FLUCONAZOLE: 0.10650429874658585
FROPCOM0004_SQ002_SQ002__5_mg_: 0.10386402159