# GMM

In [44]:
import os
import pickle

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture

class GMM_MODEL:
    """Pair of Gaussian mixture models, one per cohort ("control" / "diseased").

    `fit` trains one scikit-learn `GaussianMixture` per cohort on per-cell
    marker values, `generate_patients` samples synthetic patients from the
    chosen cohort's model, and `save` / `load` persist both models with pickle.
    """

    def __init__(self, n_components=8):
        """
        PARAMETERS
        ----------
        n_components : int
            number of mixture components `fit` uses for each cohort model
            (defaults to 8, the value that used to be hard-coded in `fit`).
        """
        self.n_components = n_components
        # None means "not fitted / not loaded yet". save() and
        # generate_patients() check for it so they fail with a clear message
        # instead of an AttributeError.
        self.model_control = None
        self.model_diseased = None

    def fit(self, data):
        """
        Fit one GaussianMixture per cohort.

        PARAMETERS
        ----------
        data : DataFrame
            must contain a "group" column with both "control" and "diseased"
            rows; every column other than "id" and "group" is used as a marker.

        RAISES
        ------
        ValueError
            if either cohort has no rows in `data`.
        """
        # Same marker selection for both cohorts. Pass the same expression as
        # `column_names` to generate_patients so generated columns line up.
        markers = data.columns.difference(["id", "group"])

        fitted = {}
        for group in ("control", "diseased"):
            cells = data.loc[data["group"] == group, markers]
            if cells.shape[0] == 0:
                raise ValueError(f"data contains no rows with group == {group!r}")
            fitted[group] = GaussianMixture(n_components=self.n_components).fit(X=cells)

        # Assign only after both fits succeed, so a failure never leaves a
        # half-updated pair of models behind.
        self.model_control = fitted["control"]
        self.model_diseased = fitted["diseased"]

    def save(self, filepath):
        """
        Pickle both cohort models into the directory `filepath`.

        Writes the files `<filepath>/control` and `<filepath>/diseased`,
        creating the directory (and any parents) if it does not exist.

        RAISES
        ------
        RuntimeError
            if either model has not been fitted or loaded yet.
        """
        models = {
            "control": getattr(self, "model_control", None),
            "diseased": getattr(self, "model_diseased", None),
        }
        if any(model is None for model in models.values()):
            raise RuntimeError("fit or load the models before saving")
        os.makedirs(filepath, exist_ok=True)
        for name, model in models.items():
            with open(os.path.join(filepath, name), "wb") as f:
                pickle.dump(model, f)

    def load(self, filepath):
        """
        Load both cohort models previously written by `save`.

        PARAMETERS
        ----------
        filepath : str
            should point to a directory containing the two pickled models,
            named "control" and "diseased".

        RAISES
        ------
        FileNotFoundError
            if `filepath` is not a directory or a model file is missing.

        NOTE: this uses pickle. Only load directories you trust, because
        unpickling can execute arbitrary code.
        """
        if not os.path.isdir(filepath):
            raise FileNotFoundError(f"model directory not found: {filepath!r}")
        loaded = {}
        for name in ("control", "diseased"):
            with open(os.path.join(filepath, name), "rb") as f:
                loaded[name] = pickle.load(f)
        # Assign only once both files were read successfully.
        self.model_control = loaded["control"]
        self.model_diseased = loaded["diseased"]

    def generate_patients(self, nr_markers=12, nr_cells=20000, nr_patients=20, column_names=None, group=None):
        """
        Sample synthetic patients from the model of one cohort.

        PARAMETERS:
        ----------
        nr_markers : int
            nr of markers; must match the number of features the model was
            fitted on.

        nr_cells : int
            number of cells per patient

        nr_patients : int
            number of patients to generate

        column_names :
            Dataframe.columns, names for the markers

        group : str
            "control" or "diseased"

        RETURNS:
        -------
        patients : dataframe
            nr_patients * nr_cells rows: the marker columns, plus an int32
            "id" column (1..nr_patients, nr_cells rows each) and a "group"
            column holding `group`.

        RAISES:
        ------
        ValueError
            if `group` is invalid or `nr_markers` does not match the model.
        RuntimeError
            if the requested model has not been fitted or loaded.
        """
        if group not in ("control", "diseased"):
            raise ValueError("group must be one of [control, diseased]")

        attr = "model_control" if group == "control" else "model_diseased"
        model = getattr(self, attr, None)
        if model is None:
            raise RuntimeError("load the models first")

        # Fitted scikit-learn estimators expose n_features_in_. Checking it up
        # front turns a mismatch into a clear error rather than a numpy
        # broadcasting failure inside the loop.
        n_features = getattr(model, "n_features_in_", None)
        if n_features is not None and n_features != nr_markers:
            raise ValueError(
                f"nr_markers={nr_markers} but the {group} model was fitted on "
                f"{n_features} markers"
            )

        patients = np.empty(shape=(nr_patients * nr_cells, nr_markers))
        for i in range(nr_patients):
            # sample() returns (X, component_labels); only X is kept. NOTE: in
            # scikit-learn the rows of X appear grouped by component rather
            # than shuffled - confirm before relying on row order.
            patients[nr_cells * i : nr_cells * (i + 1)] = model.sample(nr_cells)[0]

        patients_df = pd.DataFrame(patients, columns=column_names)
        # 1-based patient ids, one contiguous run of nr_cells rows per patient.
        patients_df["id"] = np.repeat(np.arange(1, nr_patients + 1, dtype="int32"), nr_cells)
        patients_df["group"] = group
        return patients_df
        

In [51]:
# Empty model container; models are attached afterwards via gmm.load(...) or gmm.fit(...).
gmm = GMM_MODEL()

In [49]:
# NOTE(review): the execution counts (fit = In [48], save = In [49]) show this cell was
# run after gmm.fit(subsample) further down; top-to-bottom it runs before any model exists.
gmm.save(filepath="MODELS/GMM")

In [52]:
# Reload the pickled "control" and "diseased" models from MODELS/GMM into gmm.
gmm.load(filepath="MODELS/GMM")

In [28]:
# Per-cell marker table; the cells below expect "id" and "group" columns in it.
df = pd.read_csv(filepath_or_buffer="ModifiedDATA/scaled_ra.csv")

In [30]:
# Draw the same number of cells from every patient so each patient contributes
# equally to the fitted mixtures. Sampling is without replacement, so pandas
# raises ValueError if any patient has fewer than N_CELL_SAMPLES cells.
N_CELL_SAMPLES = 2000
SUBSAMPLE_SEED = 0  # fixed seed: the subsample (and models fitted on it) are reproducible
subsample = df.groupby('id').sample(n=N_CELL_SAMPLES, random_state=SUBSAMPLE_SEED)

In [48]:
# Fit one mixture per cohort on the balanced per-patient subsample.
gmm.fit(data=subsample)

In [54]:
# Marker columns selected with the same expression GMM_MODEL.fit uses, so the
# names line up with the model's feature order. Pass their count explicitly
# instead of relying on the hard-coded default nr_markers=12.
marker_columns = subsample.columns.difference(["id", "group"])
gmm.generate_patients(nr_markers=len(marker_columns), column_names=marker_columns, group="diseased")

Unnamed: 0,145Nd_CD4,146Nd_CD8a,147Sm_CD20,148Nd_CD16,151Eu_CD123,159Tb_CD11c,160Gd_CD14,169Tm_CD45RA,170Er_CD3,174Yb_HLA-DR,176Yb_CD56,209Bi_CD61,id,group
0,0.495243,4.241391,-0.012408,0.844422,-0.001307,-0.183268,0.068175,4.965247,3.555475,1.152267,0.716031,-0.052773,1,diseased
1,0.358705,4.504121,-0.052779,0.108222,0.000883,0.032805,-0.097660,2.129804,3.832658,-0.077787,-0.851410,0.073500,1,diseased
2,-0.238235,3.930953,0.010106,0.488239,-0.000597,-0.230511,0.048141,4.672152,3.470866,0.400043,-0.226148,0.022016,1,diseased
3,0.072435,5.101774,-0.114706,0.304335,0.000295,-0.019897,0.061322,1.020050,4.445771,-0.081137,-0.259879,-0.036755,1,diseased
4,0.124389,5.129063,-0.250924,0.021672,-0.000163,-0.044127,0.073385,-1.392127,4.428739,0.181459,1.111469,0.016252,1,diseased
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399995,-0.000753,-0.230076,-0.000175,0.000080,0.385122,3.175472,2.540384,0.190653,-0.257035,3.055751,0.603597,2.907183,20,diseased
399996,-0.001363,0.286860,0.001449,-0.000229,-0.075011,1.093928,0.979621,0.354613,-0.189236,-0.534822,0.190477,3.263591,20,diseased
399997,-0.001833,0.137201,-0.000950,-0.000068,0.266550,4.624211,1.918705,1.747225,0.946287,5.441156,0.210631,-1.385347,20,diseased
399998,-0.000911,-0.034270,-0.001009,0.001217,0.483815,2.138414,2.637258,1.351109,-0.203457,3.021548,0.503772,2.729601,20,diseased
