This notebook is an example of applying Federated Learning to GWAS (genome-wide association study) data.

It uses Polygenic Risk Scores (PRS) together with demographic and other non-genetic data to estimate the likelihood that an individual will be diagnosed with a specific disease. Scores are available for a variety of diseases and were computed with several different PRS methods. This example uses the scores obtained with the LDpred2 method (which relies on a linkage disequilibrium (LD) matrix) to estimate a patient's likelihood of CAD (Coronary Artery Disease).

Import the libraries

In [12]:
#!/usr/bin/python
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random
import time
from sklearn.metrics import roc_auc_score
import os
import utilities as util
import re

Load the Polygenic Risk Score and Non-genetic Data

In [None]:
### user input params
PRS_score_path = './PRS'                      # directory with <set>_score_<method>_<disease>.txt files
non_genetic_data_path = './non_genetic_data'  # root of the non-genetic factor tables
label_path = './eid_label'                    # directory with <disease>_<set>_y.txt label files
category_list = ['primary_demographics', 'lifestyle', 'physical_measures']  # non-genetic categories to load
disease = 'CAD'
PRS_Method = 'LDpred2'  # one of: DBSLMM, SBLUP, PRSice2, P+T, LDpred2
save_dir = './PRSIMD_results_MRfactors'

# Display the disease being studied
print("Disease being studied:", disease)

# Display the method used for PRS calculation
print("PRS Method used:", PRS_Method)


*(Optional)* Parse and print some information about the data

In [None]:
def loading_data(PRS_score_path, PRS_Method, Disease, non_genetic_data_path, label_path, category_list):
    """Load PRS scores, non-genetic factor tables and labels for every data split.

    Parameters
    ----------
    PRS_score_path : str
        Directory containing ``<split>_score_<PRS_Method>_<Disease>.txt`` files.
    PRS_Method : str
        PRS method name as it appears in the score file names (e.g. ``'LDpred2'``).
    Disease : str
        Disease code; used as-is in the score file names and lower-cased in the
        non-genetic data and label file paths.
    non_genetic_data_path : str
        Root directory; per-category tables are read from
        ``data_mat/<disease>/<category>_<split>_data.txt``.
    label_path : str
        Directory containing header-less ``<disease>_<split>_y.txt`` label files.
    category_list : list of str
        Non-genetic categories to concatenate column-wise, in this order.
        The concatenated table must contain the
        ``'Age when attended assessment centre'`` column.

    Returns
    -------
    tuple of dict
        ``(score_dict, non_genetic_data_dict, label_dict)``, each keyed by
        ``'train'``, ``'validation'`` and ``'test'``. Scores are float32 tensors,
        non-genetic data are DataFrames and labels are int64 tensors.

    Raises
    ------
    ValueError
        If ``category_list`` is empty.
    """
    if not category_list:
        raise ValueError("category_list must contain at least one category")

    score_dict = {}
    non_genetic_data_dict = {}
    label_dict = {}
    # The 'validation' split is stored on disk with the 'val' suffix.
    file_suffix = {'train': 'train', 'validation': 'val', 'test': 'test'}
    age_col = 'Age when attended assessment centre'
    for split, suffix in file_suffix.items():
        score, _ = util.readbyLines(
            os.path.join(PRS_score_path, split + '_score_' + PRS_Method + '_' + Disease + '.txt'),
            datatype="float")
        score_dict[split] = torch.tensor(score, dtype=torch.float32)

        # Read every category first and concatenate once, so the age column is
        # found no matter which category provides it.
        collect = [
            pd.read_table(os.path.join(non_genetic_data_path, 'data_mat', Disease.lower(),
                                       categ + '_' + suffix + '_data.txt'))
            for categ in category_list
        ]
        data = pd.concat(collect, axis=1)
        # Bin age into decades (e.g. 57 -> 5): the model uses age groups.
        data[age_col] = (data[age_col] / 10).astype('int')
        non_genetic_data_dict[split] = data

        y = pd.read_table(os.path.join(label_path, Disease.lower() + '_' + suffix + '_y.txt'),
                          header=None).values.ravel()
        label_dict[split] = torch.tensor(y.astype('int64'), dtype=torch.long)

    return score_dict, non_genetic_data_dict, label_dict

# `disease` is the configuration value set in the user-parameters cell.
score, non_genetic_data, label = loading_data(PRS_score_path, PRS_Method, disease, non_genetic_data_path, label_path, category_list)

print("PRS score and non-genetic data loaded successfully.")
# Display the first few rows of the loaded data
print("PRS Score (train):", score['train'][:5])
print("Non-genetic data (train):", non_genetic_data['train'].head())
# Display the labels for the training set
print("Labels (train):", label['train'][:5])

# Display the shape of the loaded data
print("Shape of PRS score (train):", score['train'].shape)
print("Shape of non-genetic data (train):", non_genetic_data['train'].shape)
print("Shape of labels (train):", label['train'].shape)
# Display the categories of non-genetic data
print("Categories of non-genetic data:", category_list)



In [None]:
def get_data_info_for_model_construction(non_genetic_data, Disease, non_genetic_data_path, category_list):
    """Derive the factor layout the model needs from the factor description files.

    For every category, reads ``<disease>_<category>.txt`` from
    ``non_genetic_data_path``. Each row describes one factor, whose name is in
    the ``field_name`` column. It then counts how many columns of
    ``non_genetic_data['train']`` belong to each factor. A column belongs to a
    factor when the text before its first ``':'`` equals the factor's
    ``field_name``.

    Returns
    -------
    n_factors : int
        Total number of factor rows across all category files.
    n_cols_factor : list of int
        Number of data columns per factor, in file order.
    """
    factor_tables = [
        pd.read_table(os.path.join(non_genetic_data_path, Disease.lower() + '_' + category + '.txt'))
        for category in category_list
    ]
    factor_info = pd.concat(factor_tables, axis=0)

    # The prefix before ':' names the factor a column belongs to (e.g. 'Sex:1').
    column_prefixes = [column.split(':')[0] for column in non_genetic_data['train'].columns.values]
    n_cols_factor = [
        sum(1 for prefix in column_prefixes if prefix == field)
        for field in factor_info['field_name'].values
    ]
    return factor_info.shape[0], n_cols_factor

# `disease` is the configuration value set in the user-parameters cell.
n_factors, n_cols_factor = get_data_info_for_model_construction(non_genetic_data, disease, non_genetic_data_path, category_list)

# Display the number of factors and their column counts
print("Number of factors:", n_factors)
print("Number of columns for each factor:", n_cols_factor)
print("Double check the sum of factors:", sum(n_cols_factor))
# The model slices its non-genetic input by n_cols_factor, so the total must
# match the number of non-genetic columns in the data.
n_data_cols = non_genetic_data['train'].shape[1]
if sum(n_cols_factor) != n_data_cols:
    print(f"WARNING: factor columns sum to {sum(n_cols_factor)} but the non-genetic data has {n_data_cols} columns")



Generate the merged CSV file to add to our hospital node.

In this example, the CSV file and the node have already been created.

In [None]:
# Concatenate score, non_genetic_data, and label into a single DataFrame and save as CSV.
# Column order matters for the training plan: the PRS score must be the first
# column, and the label is read from the 'target' column.
# Example for the 'train' set; repeat for 'validation' and 'test' as needed.
n_rows = len(non_genetic_data['train'])
# pd.concat(axis=1) silently pads mismatched lengths with NaN, so fail loudly instead.
if len(score['train']) != n_rows or len(label['train']) != n_rows:
    raise ValueError(
        f"row count mismatch: {len(score['train'])} scores, {n_rows} non-genetic rows, "
        f"{len(label['train'])} labels")

score_df = pd.DataFrame(score['train'].numpy(), columns=['PRS_score'])
label_df = pd.DataFrame(label['train'].numpy(), columns=['target'])
merged_df = pd.concat([score_df.reset_index(drop=True), non_genetic_data['train'].reset_index(drop=True), label_df.reset_index(drop=True)], axis=1)

# Save to CSV
merged_df.to_csv('train_merged.csv', index=False)
print('Merged CSV saved as train_merged.csv')

# Display the shape (rows, columns) of the merged table
print(merged_df.shape)


Define the Training Plan with the customized Logistic Regression Model

***(TODO)*** Improve the model

In [28]:
from fedbiomed.common.training_plans import TorchTrainingPlan

class LogisticRegressionTrainingPlan(TorchTrainingPlan):
    """Fed-BioMed training plan mixing a PRS logistic model with a non-genetic factor model.

    Each input row holds the PRS score in column 0, followed by the non-genetic
    columns grouped factor by factor as described by ``n_cols_factor``. This
    is the layout of the merged CSV written earlier in this notebook.
    """

    # model for the genetic factor
    class logistic_regression_model(nn.Module):
        """Equal-weight average of a PRS logistic model and a factor-wise softmax model.

        Args:
            disease: disease code (stored on the module, not used in the computation).
            n_factors: number of non-genetic factors; must equal ``len(n_cols_factor)``.
            n_cols_factor: number of consecutive data columns belonging to each
                factor. ``sum(n_cols_factor)`` must equal the number of
                non-genetic input columns.

        Raises:
            ValueError: if ``len(n_cols_factor) != n_factors``.
        """

        def __init__(self, disease, n_factors, n_cols_factor):
            super().__init__()

            # Genetic part: scalar logistic regression on the PRS score. These
            # stay (1, 1) parameters so that parameter shapes are unchanged.
            self.w = nn.Parameter(torch.randn(1, 1))
            self.b = nn.Parameter(torch.randn(1, 1))

            self.activation = torch.tanh
            self.disease = disease
            self.n_factors = n_factors
            # Plain ints are required by torch.split in forward().
            self.n_cols_factor = [int(n) for n in n_cols_factor]
            if len(self.n_cols_factor) != n_factors:
                raise ValueError(
                    f"n_cols_factor has {len(self.n_cols_factor)} entries but n_factors is {n_factors}")

            n_cols = sum(self.n_cols_factor)
            # Per-class weight for every non-genetic column ...
            self.W = nn.Parameter(torch.randn(2, n_cols))
            # ... and per-class weight for every factor, plus one bias slot.
            self.Gamma = nn.Parameter(torch.randn(2, n_factors + 1))

        def forward(self, f_data):
            """Map a (batch, 1 + n_cols) input to (batch, 2) class probabilities.

            Raises:
                ValueError: if the number of non-genetic columns differs from
                    ``sum(n_cols_factor)``.
            """
            score = f_data[:, 0]    # PRS score, (batch,)
            f_data = f_data[:, 1:]  # non-genetic columns, (batch, n_cols)

            n_cols = self.W.shape[1]
            if f_data.shape[1] != n_cols:
                raise ValueError(
                    f"expected {n_cols} non-genetic columns (sum of n_cols_factor), got {f_data.shape[1]}")

            # Genetic part. Flatten the (1, 1) parameters so the result is
            # (batch,) rather than (1, batch). It then stacks to (batch, 2),
            # the same shape as the non-genetic part.
            logit_out = torch.sigmoid(self.w.reshape(-1) * score + self.b.reshape(-1))
            logit_genetic = torch.stack((1 - logit_out, logit_out), dim=1)

            # Non-genetic part: per-class weighted columns, (batch, 2, n_cols).
            # Broadcasting over the class axis replaces an explicit repeat.
            eleW_product = self.W * f_data.unsqueeze(1)
            # Sum each factor's columns -> one value per (sample, class, factor).
            phi = torch.stack(
                [chunk.sum(dim=2) for chunk in torch.split(eleW_product, self.n_cols_factor, dim=2)],
                dim=2)
            phi = self.activation(phi)
            # Constant 1 feature so Gamma's last column acts as a per-class bias.
            # It is created with phi's dtype and on phi's device.
            phi_bias = phi.new_ones(phi.shape[0], phi.shape[1], 1)
            phi = torch.cat((phi, phi_bias), dim=2)

            # Normalise over the two classes (numerically stable exp / sum(exp)).
            logit_non_genetic = torch.softmax(torch.sum(self.Gamma * phi, dim=2), dim=1)

            return logit_genetic * 0.5 + logit_non_genetic * 0.5

    def init_dependencies(self):
        """Import statements the node needs to rebuild this training plan."""
        deps = ["from torchvision import datasets, transforms",
                "from torchvision.transforms import ToTensor",
                'from torch.optim import Adam, AdamW, SGD',
                "import torch.nn.functional as F",
                "import pandas as pd",
                "import numpy as np",
                ]

        return deps

    def init_model(self, model_args: dict):
        """Build the mixed PRS / non-genetic model, using the CAD defaults for missing args."""
        model = self.logistic_regression_model(disease=model_args.get('disease', 'CAD'),
                                               n_factors=model_args.get('n_factors', 32),
                                               n_cols_factor=model_args.get('n_cols_factor', [1, 1, 3, 2, 2, 5, 2, 1, 2, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
                                               )
        return model

    def init_optimizer(self, optimizer_args):
        """Build Adam over the model parameters.

        Optional keys in ``optimizer_args``: ``lr`` (default 0.001, Adam's own
        default), ``weight_decay`` (default 0.0001) and ``amsgrad`` (default True).
        """
        optimizer = Adam(self.model().parameters(),
                         lr=optimizer_args.get('lr', 0.001),
                         weight_decay=optimizer_args.get('weight_decay', 0.0001),
                         amsgrad=optimizer_args.get('amsgrad', True))
        return optimizer

    def training_data(self):
        """Load the node's CSV: 'target' is the label and every other column is a feature."""
        dataset = pd.read_csv(self.dataset_path, delimiter=',')

        target_col = ['target']
        # Select the features by name rather than by position, so the label can
        # never leak into the inputs. The PRS score must remain the first column.
        regressors_col = [col for col in dataset.columns if col not in target_col]

        return DataManager(dataset=dataset[regressors_col], target=dataset[target_col])

    def training_step(self, data, target):
        """Negative log-likelihood of the true class under the model's mixed probabilities."""
        predictions = self.model()(data)
        # The model returns probabilities, not logits. F.cross_entropy would
        # apply a second softmax on top of them, so take the log and use NLL.
        log_probs = torch.log(predictions.clamp_min(1e-12))
        # The target may arrive as a (batch, 1) column; NLL needs (batch,) class indices.
        target = target.reshape(-1).long()
        logistic_re_loss = F.nll_loss(log_probs, target)
        return logistic_re_loss

Define the Model and Training Arguments

In [None]:
# Model hyper-parameters forwarded to `init_model`. This is left empty, so the
# training plan falls back to its built-in defaults. Uncomment the entries to
# pass the values computed earlier in this notebook instead.
model_args = dict(
    # disease=disease,
    # n_factors=n_factors,
    # n_cols_factor=n_cols_factor,
)

# Local training configuration passed to the experiment.
training_args = dict(
    loader_args=dict(batch_size=128),
    optimizer_args=dict(weight_decay=0.0001, amsgrad=True),
    # num_updates=2,
    epochs=300,
    dry_run=False,
    log_interval=10,
    test_ratio=0.1,
    test_on_global_updates=True,
    test_on_local_updates=True,
)

Define and Run the experiment with Polygenic Risk Score

In [27]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

# Dataset tags passed to the experiment (presumably used to select the node datasets).
tags = ['lr_train']
# Upper bound on the number of federated rounds.
num_rounds = 5

exp = Experiment(
    tags=tags,
    training_plan_class=LogisticRegressionTrainingPlan,
    model_args=model_args,
    training_args=training_args,
    aggregator=FedAverage(),
    round_limit=num_rounds,
    tensorboard=True,
)

2025-07-17 17:41:13,628 fedbiomed INFO - Starting researcher service...

2025-07-17 17:41:13,630 fedbiomed INFO - Waiting 3s for nodes to connect...

2025-07-17 17:41:13,631 fedbiomed ERROR - Researcher gRPC server has stopped. Please try to restart: Failed to bind to address localhost:50051; set GRPC_VERBOSITY=debug environment variable to see detailed error message.


--------------------
Fed-BioMed researcher stopped due to exception:
ErrorNumbers.FB628: Error while getting all nodes connected:  Communication client is not initialized.
--------------------


FedbiomedSilentTerminationError: 



In [None]:
# Start federated training for the experiment defined above.
exp.run()