# Prediction based on Linear model, GMM_class and GMM_patient

In [None]:
import pickle
import scipy
import scipy.io
import os
import numpy as np
import scanpy as sc
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import sklearn.model_selection as sks
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
from matplotlib.pyplot import rc_context
import sklearn.metrics as skm
import collections
from sklearn.model_selection import KFold
import torch
import traceback
import random
import pathlib
import sklearn.mixture
import math
from lib import *

In [None]:
class Linear(torch.nn.Module):
    """Zero-initialised linear scorer whose per-row outputs are summed.

    The wrapped ``torch.nn.Linear`` layer is applied to every row of the
    input and the results are summed over the second-to-last dimension, so a
    ``(rows, dim)`` input yields a single row of scores.

    Args:
        dim: number of input features.
        states: number of output states.  With ``states == 1`` the layer has
            one output; otherwise it has ``states - 1`` outputs and a constant
            zero score is prepended for the first state.
    """

    def __init__(self, dim, states):
        super(Linear, self).__init__()
        self.states = states
        out_features = states if states == 1 else states - 1
        self.layer = torch.nn.Linear(dim, out_features)
        # Start from an all-zero model: every state gets the same score.
        self.layer.weight.data.zero_()
        if self.layer.bias is not None:
            self.layer.bias.data.zero_()

    def forward(self, x):
        """Return the summed scores, shape ``(..., 1, states)`` (or ``(..., 1, 1)``)."""
        summed = torch.sum(self.layer(x), dim=-2, keepdim=True)
        if self.states == 1:
            return summed
        # Build the fixed zero score from `summed` so it lives on the same
        # device and has the same dtype (a plain torch.zeros(1, 1) breaks on CUDA).
        reference = summed.new_zeros(tuple(summed.shape[:-1]) + (1,))
        return torch.cat([reference, summed], dim=-1)
class Aggregator(torch.nn.Module):
    """Collapse the first dimension of the input to its mean, keeping a size-1 axis."""

    def __init__(self):
        super(Aggregator, self).__init__()

    def forward(self, x):
        # Average over dim 0, then put a leading axis of length 1 back in front.
        averaged = torch.mean(x, dim=0)
        return averaged.unsqueeze(0)


In [None]:

def train_classifier(Xtrain, Xvalid, Xtest, classifier, regularize=None,
                     eta=1e-8, iterations=3, stochastic=True, cuda=False, state=None, regression=False,path='',
                     suff=None):
    """Train ``classifier`` with SGD and keep the weights with the lowest loss.

    Every iteration runs a training pass over ``Xtrain`` followed by an
    evaluation pass over ``Xvalid``; ``Xtest`` is only visited during the
    last iteration.  After each iteration the metrics of the validation pass
    (or, when ``Xvalid`` is empty, of the training pass) are compared with the
    best seen so far, and the state dict with the lowest mean loss is kept and
    loaded back into ``classifier`` before returning.

    Args:
        Xtrain, Xvalid, Xtest: sequences of samples whose first two items are
            ``(x, y)``; further items are ignored.  ``x`` may be a
            ``torch.Tensor``, a ``numpy.ndarray`` or an object exposing
            ``tocoo()`` (e.g. a SciPy sparse matrix).  ``y`` is the target for
            the train/validation sets and is not used for the test set.
        classifier: ``torch.nn.Module`` to optimise (modified in place).
        regularize: optional callable ``regularize(classifier)`` returning a
            tensor penalty that is added to the loss.
        eta: SGD learning rate (momentum is fixed at 0.9).
        iterations: number of passes over the data.
        stochastic: if true, take an optimiser step after every training
            sample; otherwise accumulate the loss and step once per pass.
        cuda: move the model and data to the GPU when one is available.
        state: optional lookup used to translate a predicted index into a
            label during the test pass.
        regression: use MSE loss and report R2 instead of cross-entropy / AUC.
        path: unused; kept for backward compatibility.
        suff: unused; accepted because some notebook cells still pass it.

    Returns:
        ``(classifier, best_res)``.  ``best_res`` holds ``loss``, ``accuracy``,
        ``soft``, ``auc`` and ``r2`` of the selected iteration, or only
        ``{"loss": inf}`` when no metrics could be computed.
    """
    import copy  # local import: `copy` is not imported at the top of the file

    use_cuda = torch.cuda.is_available() and cuda
    if use_cuda:
        classifier.cuda()

    if regression:
        criterion = torch.nn.modules.MSELoss()
    else:
        criterion = torch.nn.modules.CrossEntropyLoss()

    splits = {"train": Xtrain, "valid": Xvalid, "test": Xtest}
    # Converted (x, y) pairs are cached so tensors are built once per sample.
    batch = {name: [None] * len(samples) for name, samples in splits.items()}

    optimizer = torch.optim.SGD(classifier.parameters(), lr=eta, momentum=0.9)

    best_res = {"loss": float("inf")}
    best_model = None

    for iteration in range(iterations):
        print(iteration)
        stages = ["train", "valid"]
        if iteration == iterations - 1:
            stages.append("test")

        # Metrics of the most recent non-empty train/valid pass of this iteration.
        res = None
        for dataset in stages:
            X = splits[dataset]
            n = len(X)
            total = 0.
            correct = 0
            prob = 0.
            loss = 0.
            y_score = []
            y_true = []
            for (start, (x, y, *_)) in enumerate(X):
                if batch[dataset][start] is None:
                    if isinstance(x, torch.Tensor):
                        pass
                    elif isinstance(x, np.ndarray):
                        x = torch.Tensor(x)
                    else:
                        x = x.tocoo()
                        indices = torch.from_numpy(np.vstack((x.row, x.col))).long()
                        values = torch.Tensor(x.data)
                        x = torch.sparse_coo_tensor(indices, values, x.shape)

                    # Test targets are left untouched; they are never fed to the loss.
                    if dataset != "test":
                        if regression:
                            y = torch.FloatTensor([y])
                        else:
                            y = torch.LongTensor([y])

                    if use_cuda:
                        x = x.cuda()
                        if dataset != "test":
                            y = y.cuda()

                    batch[dataset][start] = (x, y)
                else:
                    x, y = batch[dataset][start]

                z = classifier(x)
                if regression:
                    if len(z.shape) == 2:
                        if z.shape[1] == 1:
                            z = z[:, 0]
                        elif z.shape[1] == 2:
                            z = z[:, 1]
                    y_score.append(z[0].item())
                else:
                    # Score used for the AUC: margin of state 1 over state 0.
                    y_score.append((z[0, 1] - z[0, 0]).item())
                pred = torch.argmax(z)

                if dataset != "test":
                    y_true.append(y.item())
                    if not regression:
                        prob += (torch.exp(z[0, y] - logsumexp(z))).detach().cpu().numpy()[0]
                        correct += torch.sum(pred == y).cpu().numpy()
                    l = criterion(z, y)
                    if stochastic:
                        loss = l
                    else:
                        loss = loss + l
                    total += l.detach().cpu().numpy()
                    if regularize is not None:
                        loss = loss + regularize(classifier)

                if dataset == "train" and (stochastic or start + 1 == len(X)):
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                elif dataset == "test":
                    # Predicted label for this sample.  It is not reported
                    # anywhere at the moment (per-sample logging is disabled).
                    pred = pred.cpu().numpy()
                    if len(pred.shape) != 0:
                        pred = pred[0]
                    if state is not None:
                        pred = state[pred]

            if dataset != "test" and n != 0:
                res = {
                    "loss": total / float(n),
                    "accuracy": correct / float(n),
                    "soft": prob / float(n),
                }
                if any(map(math.isnan, y_score)):
                    res["auc"] = float("nan")
                    res["r2"] = float("nan")
                elif regression:
                    res["auc"] = float("nan")
                    # r2_score takes no `multi_class` argument.
                    res["r2"] = sklearn.metrics.r2_score(y_true, y_score)
                else:
                    res["auc"] = sklearn.metrics.roc_auc_score(y_true, y_score, multi_class='ovo')
                    res["r2"] = float("nan")

            # Model selection happens once per iteration, after the validation
            # stage.  With an empty validation set `res` still holds the
            # training metrics, so selection falls back to the training loss.
            if dataset == "valid" and res is not None and res["loss"] <= best_res["loss"]:
                best_res = res
                best_model = copy.deepcopy(classifier.state_dict())

    print("**********************************************")
    print(best_res)
    # No snapshot exists when there was nothing to evaluate or the loss was NaN.
    if best_model is not None:
        classifier.load_state_dict(best_model)

    return classifier, best_res


In [None]:
# Run the linear baseline for one fixed seed.  This cell needs `linear`,
# `Xtrain` and `Xtest`, which are defined by cells further down.
for seed in (42,):
    linear(Xtrain, Xtest, seed, 1000)

In [None]:

def get_data(Xtrain,Xtest,set_clusters,n,adata_c,adata1_c,markers):
    """Restrict train and test samples to the top marker genes of some clusters.

    Args:
        Xtrain, Xtest: sequences of tuples whose first item is a sparse matrix
            with one column per gene (it must support ``[:, idx]`` and
            ``todense()``); the remaining items are passed through unchanged.
        set_clusters: cluster ids whose markers are used.
        n: number of genes kept per cluster, ranked by descending
            ``avg_log2FC``.
        adata_c: gene names of the training columns.
        adata1_c: gene names of the test columns.
        markers: DataFrame with ``cluster``, ``gene`` and ``avg_log2FC`` columns.

    Returns:
        ``(Xtrain, Xtest)`` as lists of tuples holding dense matrices.  Both
        are restricted to the selected genes present in *both* gene lists, and
        their columns are in the same gene order (the order of ``adata_c``).
    """
    print("get data")
    df = markers.loc[markers["cluster"].isin(set_clusters), :]
    # Top-n genes per cluster.  Sorting once and taking the head of each group
    # avoids groupby.apply, whose handling of the grouping column depends on
    # the pandas version.
    ranked = df.sort_values("avg_log2FC", ascending=False, kind="stable")
    feat = ranked.groupby("cluster", sort=False).head(n)
    print("data got")

    train_genes = pd.Index(adata_c)
    test_genes = pd.Index(adata1_c)

    # Selected genes that exist in the test set, then their training columns.
    wanted = test_genes[test_genes.isin(feat["gene"].values)]
    idx_tr = np.flatnonzero(train_genes.isin(wanted))

    # Look each kept training gene up in the test set so that both matrices
    # end up with the same genes in the same column order.
    first_pos = {}
    for pos, gene in enumerate(test_genes):
        first_pos.setdefault(gene, pos)
    idx_te = np.array([first_pos[gene] for gene in train_genes[idx_tr]], dtype=int)

    Xtrain = [(x[0][:, idx_tr].todense(), *x[1:]) for x in Xtrain]
    Xtest = [(x[0][:, idx_te].todense(), *x[1:]) for x in Xtest]
    return Xtrain, Xtest


In [None]:
def linear(Xtrain, Xtest, seed, ite):
    """Train the aggregated linear classifier on all of Xtrain and score Xtest.

    Reads the notebook globals ``state``, ``path``, ``label`` and ``n``.

    Args:
        Xtrain: training samples; ``Xtrain[0][0].shape[1]`` gives the number
            of input features.
        Xtest: held-out samples, evaluated after training.
        seed: value passed to ``torch.manual_seed``.
        ite: number of training iterations.

    Returns:
        ``(model, res)``: the trained classifier and the metrics returned by
        the evaluation pass over ``Xtest``.
    """
    print("start linear")
    torch.manual_seed(seed)
    # Not called `linear`, so the function's own name is not shadowed.
    classifier = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    print('##Trian_100%')
    model, res = train_classifier(Xtrain, [], [], classifier, eta=1e-3, iterations=ite, state=state,
                                  regression=False, path=path+"/"+label+"/"+str(n)+"_")
    print('##Test')
    # eta=0 with an empty training set: a pure evaluation pass in which Xtest
    # is scored through the validation slot of train_classifier.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0,
                                  stochastic=True, regression=False, path=path+"/"+label)
    return model, res


In [None]:

def GMM_class(Xtrain, Xtest, seed, set_centers):
    """Select a ``train_class`` model by 5-fold cross-validation and test it.

    For every value in ``set_centers`` the training samples are split into
    five shuffled folds; a model is fitted on each training part and scored on
    the held-out part.  The single fold model with the highest validation
    accuracy (ties broken by AUC) is pickled and evaluated on ``Xtest``.

    Reads the notebook globals ``path``, ``label`` and ``n``.

    Args:
        Xtrain: list of training samples.
        Xtest: list of held-out samples.
        seed: random state for the fold split, also passed to ``train_class``.
        set_centers: iterable of candidate ``centers`` values.

    Raises:
        RuntimeError: if no fold produced a usable validation accuracy.
    """
    print("start generative class with seed: "+str(seed))
    best_model = None
    best_score = -float("inf")
    best_score_auc = -float("inf")
    best_centers = None
    best_res = None
    X_train = Xtrain
    print(X_train[0][0].shape)
    out_prefix = path+"/"+label+"/"+str(seed)+"_"+str(n)+"_"
    for centers in set_centers:
        folds = KFold(n_splits=5, shuffle=True, random_state=seed).split(X_train)
        for train_index, test_index in folds:
            # Index the list directly: np.array() on (matrix, label) tuples
            # builds a ragged object array, which recent NumPy versions reject.
            fold_train = [X_train[i] for i in train_index]
            fold_valid = [X_train[i] for i in test_index]
            model = train_class(fold_train, centers, seed)
            # Training-fold result is not used; the call is kept in case
            # eval_class has side effects under `path` -- TODO confirm.
            eval_class(model, fold_train, path=out_prefix)
            res = eval_class(model, fold_valid, path=out_prefix)

            better = res["accuracy"] > best_score
            tie_break = res["accuracy"] == best_score and res["auc"] > best_score_auc
            if better or tie_break:
                best_model = model
                best_score = res["accuracy"]
                best_score_auc = res["auc"]
                best_res = res
                best_centers = centers

    if best_model is None:
        raise RuntimeError("GMM_class: no cross-validation fold produced a usable accuracy")

    with open(path+'/gmm_class_'+str(seed)+"_"+str(best_centers)+"_"+label+"_"+str(n), 'wb') as fp:
        pickle.dump(best_model, fp)

    print("##Best Validation")
    print(best_res)
    print('best center: ' + str(best_centers))
    print("##Test")
    res = eval_class(best_model, Xtest, path=out_prefix)
    print(res)


In [None]:
def GMM_patient(Xtrain, Xtest, seed, set_centers):
    """Select a ``train_patient`` model by 5-fold cross-validation and test it.

    For every value in ``set_centers`` the training samples are split into
    five shuffled folds; a model is fitted on each training part and scored on
    the held-out part.  The single fold model with the highest validation
    accuracy (ties broken by AUC) is pickled and evaluated on ``Xtest``.

    Reads the notebook globals ``path``, ``label`` and ``n``.

    Args:
        Xtrain: list of training samples.
        Xtest: list of held-out samples.
        seed: random state for the fold split, also passed to ``train_patient``.
        set_centers: iterable of candidate ``centers`` values.

    Raises:
        RuntimeError: if no fold produced a usable validation accuracy.
    """
    print("start generative patient with seed: "+str(seed))
    best_model = None
    best_score = -float("inf")
    best_score_auc = -float("inf")
    best_centers = None
    best_res = None
    X_train = Xtrain
    print(X_train[0][0].shape)
    out_prefix = path+"/"+label+"/"+str(seed)+"_"+str(n)+"_"
    for centers in set_centers:
        folds = KFold(n_splits=5, shuffle=True, random_state=seed).split(X_train)
        for train_index, test_index in folds:
            # Index the list directly: np.array() on (matrix, label) tuples
            # builds a ragged object array, which recent NumPy versions reject.
            fold_train = [X_train[i] for i in train_index]
            fold_valid = [X_train[i] for i in test_index]
            model = train_patient(fold_train, centers, seed)
            # Training-fold result is not used; the call is kept in case
            # eval_patient has side effects under `path` -- TODO confirm.
            eval_patient(model, fold_train, path=out_prefix)
            res = eval_patient(model, fold_valid, path=out_prefix)

            better = res["accuracy"] > best_score
            tie_break = res["accuracy"] == best_score and res["auc"] > best_score_auc
            if better or tie_break:
                best_model = model
                best_score = res["accuracy"]
                best_score_auc = res["auc"]
                best_res = res
                best_centers = centers

    if best_model is None:
        raise RuntimeError("GMM_patient: no cross-validation fold produced a usable accuracy")

    with open(path+'/gmm_patient_'+str(seed)+"_"+str(best_centers)+"_"+label+"_"+str(n), 'wb') as fp:
        pickle.dump(best_model, fp)

    print("##Best Validation")
    print(best_res)
    print('best center: ' + str(best_centers))
    print("##Test")
    res = eval_patient(best_model, Xtest, path=out_prefix)
    print(res)


## 1.Load data (all cells, Monocytes, Monocytes+neutrophils, Neutrophils)

In [None]:
# Load every "<file><subset>.pkl" pickle under `path` into `data`, keyed by
# "<file><subset>".
data = collections.defaultdict(list)
path= "../../chr2chr1/"
files= ['Xtest','Xall','Ctest','Call','state']
ext= ['','_mono','_mono_neu','_neu']
for fl in files:
    for ex in ext:
        car = fl + ex
        with open(path + car + ".pkl", "rb") as f:
            buf = pickle.load(f)
        # The first two entries of an 'Xall' pickle are joined into one array.
        if fl == 'Xall':
            buf = np.concatenate([buf[0], buf[1]])
        data[car] = buf

## Load marker genes

In [None]:
# Marker-gene table; genes are ranked later by the magnitude of their fold change.
markers = pd.read_csv('../scripts/marker_cohort2')
markers["avg_log2FC"] = markers["avg_log2FC"].abs()

## 2.Prediction based on Monocytes 

### 2.1.Using top 50 genes per cluster

In [None]:
# Monocyte subset, top 50 marker genes per cluster.
label = 'mono'
n = 50
set_clusters = [7, 11, 3, 4, 6]
adata_c, adata1_c = data['Call_mono'], data['Ctest_mono']
state = data['state']
Xtrain, Xtest = get_data(data['Xall_mono'], data['Xtest_mono'], set_clusters, n,
                         adata_c, adata1_c, markers)


### 2.1.1.Linear model

In [None]:
# Linear model, single seed, 1000 training iterations.
for seed in (42,):
    linear(Xtrain, Xtest, seed, 1000)

### 2.1.2.GMM_class

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 2.1.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 2.2.Using top 20 genes per cluster

In [None]:
# Rebuild the monocyte matrices with the top 20 marker genes per cluster.
n = 20
Xtrain, Xtest = get_data(data['Xall_mono'], data['Xtest_mono'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 2.2.1.Linear model

In [None]:
for seed in [42]:
    print("start")
    torch.manual_seed(seed)
    # Named `linear_model` so the notebook-level `linear()` helper is not overwritten.
    linear_model = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    model, res = train_classifier(Xtrain, [], [], linear_model, eta=1e-3, iterations=500, state=state, regression=False)
    # Evaluation-only pass (eta=0, one iteration): Xtest goes through the
    # validation slot.  The unsupported `suff` keyword has been dropped.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0, stochastic=True, regression=False)
    print(res)

### 2.2.2.GMM class 

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 2.2.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 2.3.Using top 100 genes per cluster

In [None]:
# Rebuild the monocyte matrices with the top 100 marker genes per cluster.
n = 100
Xtrain, Xtest = get_data(data['Xall_mono'], data['Xtest_mono'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 2.3.1.Linear model

In [None]:
for seed in [42]:
    print("start")
    torch.manual_seed(seed)
    # Named `linear_model` so the notebook-level `linear()` helper is not overwritten.
    linear_model = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    model, res = train_classifier(Xtrain, [], [], linear_model, eta=1e-3, iterations=500, state=state, regression=False)
    # Evaluation-only pass (eta=0, one iteration): Xtest goes through the
    # validation slot.  The unsupported `suff` keyword has been dropped.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0, stochastic=True, regression=False)
    print(res)

### 2.3.2.GMM class 

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 2.3.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

## 3.Prediction based on Monocytes+Neutrophils 

### 3.1.Using top 50 genes per cluster

In [None]:
# Monocyte + neutrophil subset, top 50 marker genes per cluster.
label = 'mono_neu'
n = 50
set_clusters = [7, 11, 3, 4, 6, 9, 14]
adata_c, adata1_c = data['Call_mono_neu'], data['Ctest_mono_neu']
state = data['state']
Xtrain, Xtest = get_data(data['Xall_mono_neu'], data['Xtest_mono_neu'], set_clusters, n,
                         adata_c, adata1_c, markers)


### 3.1.1.Linear model

In [None]:
# Linear model, single seed, 1000 training iterations.
for seed in (42,):
    linear(Xtrain, Xtest, seed, 1000)

### 3.1.2.GMM_class

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 3.1.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 3.2.Using top 20 genes per cluster

In [None]:
# Rebuild the monocyte + neutrophil matrices with the top 20 marker genes per cluster.
n = 20
Xtrain, Xtest = get_data(data['Xall_mono_neu'], data['Xtest_mono_neu'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 3.2.1.Linear model

In [None]:
for seed in [42]:
    print("start")
    torch.manual_seed(seed)
    # Named `linear_model` so the notebook-level `linear()` helper is not overwritten.
    linear_model = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    model, res = train_classifier(Xtrain, [], [], linear_model, eta=1e-3, iterations=500, state=state, regression=False)
    # Evaluation-only pass (eta=0, one iteration): Xtest goes through the
    # validation slot.  The unsupported `suff` keyword has been dropped.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0, stochastic=True, regression=False)
    print(res)

### 3.2.2.GMM class 

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 3.2.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 3.3.Using top 100 genes per cluster

In [None]:
# Rebuild the monocyte + neutrophil matrices with the top 100 marker genes per cluster.
n = 100
Xtrain, Xtest = get_data(data['Xall_mono_neu'], data['Xtest_mono_neu'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 3.3.1.Linear model

In [None]:
for seed in [42]:
    print("start")
    torch.manual_seed(seed)
    # Named `linear_model` so the notebook-level `linear()` helper is not overwritten.
    linear_model = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    model, res = train_classifier(Xtrain, [], [], linear_model, eta=1e-3, iterations=500, state=state, regression=False)
    # Evaluation-only pass (eta=0, one iteration): Xtest goes through the
    # validation slot.  The unsupported `suff` keyword has been dropped.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0, stochastic=True, regression=False)
    print(res)

### 3.3.2.GMM class 

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 3.3.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

## 4.Prediction based on Neutrophils 

### 4.1.Using top 50 genes per cluster

In [None]:
# Neutrophil subset, top 50 marker genes per cluster.
label = 'neu'
n = 50
set_clusters = [9, 14]
adata_c, adata1_c = data['Call_neu'], data['Ctest_neu']
state = data['state']
Xtrain, Xtest = get_data(data['Xall_neu'], data['Xtest_neu'], set_clusters, n,
                         adata_c, adata1_c, markers)


### 4.1.1.Linear model

In [None]:
# Linear model, single seed, 1000 training iterations.
for seed in (42,):
    linear(Xtrain, Xtest, seed, 1000)

### 4.1.2.GMM_class

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 4.1.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 4.2.Using top 20 genes per cluster

In [None]:
# Rebuild the neutrophil matrices with the top 20 marker genes per cluster.
n = 20
Xtrain, Xtest = get_data(data['Xall_neu'], data['Xtest_neu'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 4.2.1.Linear model

In [None]:
for seed in [42]:
    print("start")
    torch.manual_seed(seed)
    # Named `linear_model` so the notebook-level `linear()` helper is not overwritten.
    linear_model = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    model, res = train_classifier(Xtrain, [], [], linear_model, eta=1e-3, iterations=500, state=state, regression=False)
    # Evaluation-only pass (eta=0, one iteration): Xtest goes through the
    # validation slot.  The unsupported `suff` keyword has been dropped.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0, stochastic=True, regression=False)
    print(res)

### 4.2.2.GMM class 

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 4.2.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 4.3.Using top 100 genes per cluster

In [None]:
# Rebuild the neutrophil matrices with the top 100 marker genes per cluster.
n = 100
Xtrain, Xtest = get_data(data['Xall_neu'], data['Xtest_neu'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 4.3.1.Linear model

In [None]:
for seed in [42]:
    print("start")
    torch.manual_seed(seed)
    # Named `linear_model` so the notebook-level `linear()` helper is not overwritten.
    linear_model = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    model, res = train_classifier(Xtrain, [], [], linear_model, eta=1e-3, iterations=500, state=state, regression=False)
    # Evaluation-only pass (eta=0, one iteration): Xtest goes through the
    # validation slot.  The unsupported `suff` keyword has been dropped.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0, stochastic=True, regression=False)
    print(res)

### 4.3.2.GMM class 

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 4.3.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

## 5.Prediction based on All cells 

### 5.1.Using top 50 genes per cluster

In [None]:
# All cells: markers of every cluster, top 50 genes per cluster.
label = 'all'
n = 50
set_clusters = np.unique(markers["cluster"])
adata_c, adata1_c = data['Call'], data['Ctest']
state = data['state']
Xtrain, Xtest = get_data(data['Xall'], data['Xtest'], set_clusters, n,
                         adata_c, adata1_c, markers)


### 5.1.1.Linear model

In [None]:
# Linear model, single seed, 1000 training iterations.
for seed in (42,):
    linear(Xtrain, Xtest, seed, 1000)

### 5.1.2.GMM_class

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 5.1.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 5.2.Using top 20 genes per cluster

In [None]:
# Rebuild the all-cell matrices with the top 20 marker genes per cluster.
n = 20
Xtrain, Xtest = get_data(data['Xall'], data['Xtest'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 5.2.1.Linear model

In [None]:
# Linear model, single seed, 1000 training iterations.
for seed in (42,):
    linear(Xtrain, Xtest, seed, 1000)

### 5.2.2.GMM class 

In [None]:
# GMM_class for a single seed and a single `centers` value (21).
for seed in (0,):
    GMM_class(Xtrain, Xtest, seed, [21])

### 5.2.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])

### 5.3.Using top 100 genes per cluster

In [None]:
# Rebuild the all-cell matrices with the top 100 marker genes per cluster.
n = 100
Xtrain, Xtest = get_data(data['Xall'], data['Xtest'], set_clusters, n,
                         adata_c, adata1_c, markers)

### 5.3.1.Linear model

In [None]:
for seed in [42]:
    print("start")
    torch.manual_seed(seed)
    # Named `linear_model` so the notebook-level `linear()` helper is not overwritten.
    linear_model = torch.nn.Sequential(Aggregator(), Linear(Xtrain[0][0].shape[1], len(state)))
    model, res = train_classifier(Xtrain, [], [], linear_model, eta=1e-3, iterations=500, state=state, regression=False)
    # Evaluation-only pass (eta=0, one iteration): Xtest goes through the
    # validation slot.  The unsupported `suff` keyword has been dropped.
    model, res = train_classifier([], Xtest, [], model, regularize=None, iterations=1, eta=0, stochastic=True, regression=False)
    print(res)

### 5.3.2.GMM class 

In [None]:
# GMM_class for five seeds, with candidate `centers` values 21 and 5.
for seed in (0, 42, 10, 1234, 4321):
    GMM_class(Xtrain, Xtest, seed, [21, 5])

### 5.3.3.GMM patient 

In [None]:
# for seed in [0,42,10,1234,4321]:
#     GMM_patient(Xtrain, Xtest, seed, [21,5])