In [37]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.tree import DecisionTreeRegressor

def RF_complete(data_t, data_s_list, feature_t, n_tree=50, n_feature=5, f_sample=0.3, n_best_tree=5):
    '''
    TODO:
    - data_s_list might be a list of dataframes... Not just a list corresponding to one thing
    '''
    K = len(data_s_list)
    print(f"{K} additional datasets used for prediction.")

    if feature_t not in data_t.columns:
        print("Feature-of-interest not found in data_t! Please check column names of input data.")
        return None
    else:
        print("Feature-of-interest located!")

    mu = []
    sigma = []

    for k in range(0, K+1):
        data_assist = data_s_list #.iloc[k-1] # Indexing is in terms of rows

        if feature_t in data_assist.columns:
            f_t_ind = list(data_assist.columns).index(feature_t) # Finds col/index of current sera that matches virus of interest
        else:
            print(f"feature_t not found in assisting data {k}!")
            continue

        if not data_t.columns.equals(data_assist.columns):
            print(f"Features not matched for assisting data {k}! Skipped to next data.")
            continue

        trans_true_err = []
        trans_pred_err = []
        print(f"data_assist shape: {data_assist.shape}")
        print(f"Iterating over {data_assist.shape[1]}")
        # Should range be from 0? Is this just cuz of R's indexing?
        for j in range(0, data_assist.shape[1]): # 1->0? Assuming this iterates over cols
            feature_trans = data_assist.columns[j]
            if not data_t.iloc[:, j].isna().sum() > 0:
                rf_1t1 = RF_complete_1t1(data_assist, data_t, feature_t=feature_trans, n_tree=n_tree,
                                         n_feature=n_feature, f_sample=f_sample, k=k)

                if rf_1t1 is not None:
                    trans_true_err.extend(rf_1t1["true_err"])
                    trans_pred_err.extend(rf_1t1["pred_err"])

        if not trans_true_err:
            continue

        lm_coeff = np.polyfit(trans_pred_err, trans_true_err, 1)
        a = lm_coeff[0]
        b = lm_coeff[1]
        c = np.sqrt(np.mean((a * np.array(trans_pred_err) + b - np.array(trans_true_err)) ** 2))

        def f_transfer(x):
            return max(x, a * x + b + c)

        print([f"a={round(a, 3)}", f"b={round(b, 3)}", f"c={round(c, 3)}"])

        rf_1t1 = RF_complete_1t1(data_assist, data_t, feature_t=feature_t, n_best_tree=n_best_tree, n_tree=n_tree,
                                 n_feature=n_feature, f_sample=f_sample, k=k)

        mu.append(rf_1t1["mu"])
        sigma.append(f_transfer(np.mean(rf_1t1["pred_err"])))

    A = 0
    B = 0
    tt = 0
    for k in range(K):
        if sigma[k] is not None:
            tt += 1
            A += mu[k] / sigma[k] ** 2
            B += 1 / sigma[k] ** 2

    print(f"{tt} assisting data used for prediction.")
    return {"predictions": A / B, "errors": 1 / np.sqrt(B)}

def RF_complete_1t1(data_assist, data_t, feature_t, n_tree=50, n_feature=5, f_sample=0.3, n_best_tree=5, k=1):
    if feature_t in data_assist.index:
        f_t_ind = data_assist.index.get_loc(feature_t)
    else:
        print(f"feature_t not found in assisting data {k}!")
        return None
    # IM WORKING ON THIS ...
    if (data_assist.apply(lambda x: x.count(), axis=0) / data_assist.shape[0] > 0.8).sum() > n_feature:
        f_ind = data_assist.index[
            (data_assist.apply(lambda x: x.count(), axis=0) / data_assist.shape[0] > 0.8)].tolist()

        if feature_t in f_ind:
            f_ind.remove(feature_t)
    else:
        print(f"n_feature too large for assisting data {k}! Skipped to next data.")
        return None

    f_tmp_ind = [data_t.columns.get_loc(f) for f in f_ind if f in data_t.columns]
    f_feasible = [f for f in data_t.columns[f_tmp_ind] if data_t[f].count() > 2]

    if len(f_feasible) < 2:
        print(f"n_feature too large for assisting data {k}! Skipped to next data.")
        return None

    data_assist = data_assist.dropna(subset=[feature_t])

    RMSE = []
    f_sel_ind = []
    tree = []

    for i in range(n_tree):
        f_sel_ind.append(np.random.choice(f_ind, n_feature, replace=True))
        sample_sel = np.random.choice(data_assist.shape[0], int(data_assist.shape[0] * f_sample), replace=True)
        data_train = data_assist.iloc[sample_sel, f_sel_ind[i] + [f_t_ind]]

        colm_t = data_train.apply(lambda x: x.mean(), axis=1)
        data_train = data_train - np.outer(np.ones(data_train.shape[1]), colm_t)
        data_train.columns = f_sel_ind[i] + ["target"]

        tree.append(DecisionTreeRegressor(min_samples_split=5))
        tree[i].fit(data_train.iloc[:, :-1], data_train["target"])

        data_test = data_assist.iloc[~sample_sel, f_sel_ind[i] + [f_t_ind]]
        colm_t = data_test.apply(lambda x: x.mean(), axis=1)
        data_test = data_test - np.outer(np.ones(data_test.shape[1]), colm_t)
        pred_t = tree[i].predict(data_test.iloc[:, :-1])
        RMSE.append(np.sqrt(np.mean((pred_t - data_test["target"]) ** 2)))

    pred_list = np.zeros((data_t.shape[0], n_best_tree))

    for i in range(n_best_tree):
        j = np.argsort(RMSE)
        f_t_ind = data_t.columns.get_loc(feature_t)
        f_t_sel_ind = [data_t.columns.get_loc(f) for f in f_sel_ind[j[i]]]

        data_test = data_t.iloc[:, f_t_sel_ind + [f_t_ind]]
        colm_t = data_test.apply(lambda x: x.mean(), axis=1)
        data_test = data_test - np.outer(np.ones(data_test.shape[1]), colm_t)
        data_test.columns = f_sel_ind[j[i]] + ["target"]

        pred_t = tree[j[i]].predict(data_test.iloc[:, :-1])
        pred_t[np.where(data_test.iloc[:, :-1].isna().sum(axis=1) > 0)] = np.nan
        pred


In [38]:
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn

In [39]:
def non_trivialize_df(df):
    # Search columns for trivial feature (virus)
    dropped = 0
    for idx, col in enumerate(df.columns):
        uniques = df[col].unique()
        num_unique = len(uniques)
        if num_unique == 1 and 'sera_table' not in col and "*" in uniques: # Critereon for a missing virus
#             print(idx, col, num_unique, uniques)
            df = df.drop(columns=col)
            dropped += 1 # Count dropped virus
    # Search rows for trivial entry (serum)
    for idx in list(df.index):
        num_unique = len(df.loc[idx].unique())
        if num_unique == 3: # Serum name and table name are 2 constant features.
#             print(idx, num_unique)
            df = df.drop(index=idx)
    df = df.replace("*", np.nan)
    df = df.set_index('Measurements (Sera in Rows/Viruses in Columns)')
    df = df.drop(['sera_table'], axis=1)
    df = df.map(lambda x: pd.to_numeric(x, errors='coerce') )
    # Return data and number of dropped viruses
    return df, dropped


def plot_heatmap(df):
    mask = df.isnull()
    g = sns.heatmap(df, cmap='gray',mask=mask)
    g.set_facecolor('xkcd:salmon')
    return g


def plot_heatmap_subplot(table_dict):
#     keys = sorted(list(table_dict.keys()))
    keys = ["TableS1", "TableS3", "TableS5", "TableS6", "TableS13", "TableS14"]
    N = len(keys)
    fig, axs = plt.subplots(nrows=1, ncols=N, figsize=(30, 5))
    for idx, key in enumerate(keys):
        df, dropped_viruses = table_dict[key]
        percent_missing = sum(list(df.isna().sum())) / (df.shape[0] * df.shape[1]) # Proportion of missing vals in table
        df = np.log(df)
        sera_num = df.shape[0]
        virus_num = df.shape[1]
        mask = df.isnull()
        g = sns.heatmap(df,
                        xticklabels=False,
                        yticklabels=False,
                        cmap='gray',
                        mask=mask,
                        ax=axs[idx])
        g.set_facecolor('xkcd:salmon')
        axs[idx].set_xlabel(f'{sera_num} Sera x {virus_num} Viruses\n{percent_missing * 100:.2f}% Missing')
        axs[idx].set_title(key)
    return fig, axs

In [40]:
'''
data_t: DF of sera (table specific) X Viruses (No viruses dropped yet) to make prediction on
data_s_list: list of columns of data table for virus to train on ... Its the whole df I think
feature_t: A particular virus to predict (And I assume we omit this from training)
n_tree=50
n_feature=5
f_sample=0.3
n_best_tree=5
'''

flu_df = pd.read_csv("../CrossStudyCompletion/Matrix Completion in R/InfluenzaData.csv", sep=',')
sera = flu_df['Measurements (Sera in Rows/Viruses in Columns)'].tolist()
sera_tables = [i[i.index('Table'):] for i in sera]
table_keys = set(sera_tables)
flu_df['sera_table'] = sera_tables

flu_table_dict = {table_key: flu_df.loc[flu_df['sera_table']==table_key] for table_key in table_keys}
filtered_flu_table_dict = {key: non_trivialize_df(arg) for key, arg in list(flu_table_dict.items())}

predict_table = 'TableS14'
train_table = 'TableS13'
data_df = flu_df.set_index("Measurements (Sera in Rows/Viruses in Columns)")
data_df = data_df.map(lambda x: pd.to_numeric(x, errors='coerce') )
data_df = np.log(data_df)
data_df['sera_table'] = sera_tables

data_t = data_df.loc[data_df['sera_table']==predict_table].drop(['sera_table'], axis=1)
data_s = data_df.loc[data_df['sera_table']==train_table].drop(['sera_table'], axis=1)
data_s_list = [[col] for col in list(data_s.columns)]
feature_t = "A/PANAMA/2007/99"

In [41]:
data_df

Unnamed: 0_level_0,A/AUCKLAND/20/2003,A/AUCKLAND/5/96,A/BANGKOK/1/97,A/BEIJING/32/92,A/BRISBANE/10/2007,A/BRISBANE/22/94,A/BRISBANE/22/96,A/BRISBANE/3/2005,A/BRISBANE/342/2003,A/BRISBANE/5/2002,...,RG145K_NL/178/95,VN015/EL134/2008,VN016/EL135/2008,VN017/EL140/2008,VN018/EL204/2009,VN019/EL442/2010,VN020/EL443/2010,VN021/EL444/2010,VN053/2010,sera_table
Measurements (Sera in Rows/Viruses in Columns),Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Serum_A/WELLINGTON/25/93_TableS1,1.609438,3.688879,3.688879,6.461468,1.609438,7.154615,5.075174,1.609438,1.609438,1.609438,...,5.075174,,,,,,,,,TableS1
Serum_A/WELLINGTON/96/93_TableS1,1.609438,2.995732,3.688879,5.075174,1.609438,7.847763,5.075174,1.609438,1.609438,1.609438,...,5.768321,,,,,,,,,TableS1
Serum_A/VICTORIA/9/94_TableS1,1.609438,3.688879,3.688879,5.075174,1.609438,8.540910,5.768321,1.609438,1.609438,1.609438,...,1.609438,,,,,,,,,TableS1
Serum_A/JOHANNESBURG/33/94_TableS1,2.995732,2.995732,4.382027,3.688879,1.609438,7.154615,5.768321,1.609438,1.609438,1.609438,...,5.768321,,,,,,,,,TableS1
Serum_A/SHANDONG/9/93_TableS1,2.302585,2.995732,4.382027,5.768321,1.609438,7.154615,5.075174,1.609438,1.609438,1.609438,...,4.382027,,,,,,,,,TableS1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SubjectA028_Post_TableS14,,2.302585,,,,2.995732,,,,,...,,,3.688879,,,,,,2.995732,TableS14
SubjectA020_Post_TableS14,,2.302585,,,,2.995732,,,,,...,,,2.995732,,,,,,1.609438,TableS14
SubjectB025_Post_TableS14,,4.382027,,,,5.768321,,,,,...,,,2.995732,,,,,,2.995732,TableS14
SubjectB040_Post_TableS14,,2.995732,,,,4.382027,,,,,...,,,3.688879,,,,,,2.995732,TableS14


In [34]:
feature_t in list(data_s.columns)

True

In [35]:
data_s

Unnamed: 0_level_0,A/AUCKLAND/20/2003,A/AUCKLAND/5/96,A/BANGKOK/1/97,A/BEIJING/32/92,A/BRISBANE/10/2007,A/BRISBANE/22/94,A/BRISBANE/22/96,A/BRISBANE/3/2005,A/BRISBANE/342/2003,A/BRISBANE/5/2002,...,NL/823/92,RG145K_NL/178/95,VN015/EL134/2008,VN016/EL135/2008,VN017/EL140/2008,VN018/EL204/2009,VN019/EL442/2010,VN020/EL443/2010,VN021/EL444/2010,VN053/2010
Measurements (Sera in Rows/Viruses in Columns),Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SubjectA028_Pre_TableS13,,3.688879,,,,5.075174,,,,,...,3.688879,,,1.609438,,,,,,2.302585
SubjectA004_Pre_TableS13,,3.688879,,,,5.075174,,,,,...,2.995732,,,2.995732,,,,,,2.995732
SubjectB029_Pre_TableS13,,4.382027,,,,5.768321,,,,,...,4.382027,,,3.688879,,,,,,3.688879
SubjectB030_Pre_TableS13,,2.995732,,,,4.382027,,,,,...,3.688879,,,1.609438,,,,,,2.302585
SubjectB007_Pre_TableS13,,2.302585,,,,2.302585,,,,,...,3.688879,,,4.382027,,,,,,3.688879
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SubjectA007_Post_TableS13,,2.995732,,,,5.075174,,,,,...,3.688879,,,2.302585,,,,,,2.302585
SubjectA035_Post_TableS13,,3.688879,,,,5.075174,,,,,...,5.075174,,,2.995732,,,,,,2.995732
SubjectA030_Post_TableS13,,3.688879,,,,5.075174,,,,,...,5.075174,,,2.995732,,,,,,2.995732
SubjectB028_Post_TableS13,,1.609438,,,,1.609438,,,,,...,1.609438,,,2.995732,,,,,,2.995732


In [44]:
data_assist.columns

Index(['A/AUCKLAND/20/2003', 'A/AUCKLAND/5/96', 'A/BANGKOK/1/97',
       'A/BEIJING/32/92', 'A/BRISBANE/10/2007', 'A/BRISBANE/22/94',
       'A/BRISBANE/22/96', 'A/BRISBANE/3/2005', 'A/BRISBANE/342/2003',
       'A/BRISBANE/5/2002', 'A/CALIFORNIA/7/2004', 'A/CANBERRA/1/96',
       'A/CANBERRA/9/97', 'A/CHRISTCHURCH/1/96', 'A/CHRISTCHURCH/68/99',
       'A/FUJIAN/140/2000', 'A/FUJIAN/411/2002', 'A/FUKUOKA/55/2002',
       'A/JOHANNESBURG/33/94', 'A/NANCHANG/933/95', 'A/NEW_YORK/55/2004',
       'A/NEWCALEDONIA/12/2004', 'A/PANAMA/2007/99', 'A/PERTH/16/2009',
       'A/PERTH/27/2007', 'A/PERTH/5/97', 'A/PERTH/9/97',
       'A/PHILIPPINES/427/2002_MDCK', 'A/PHILIPPINES/472/2002_EGG',
       'A/SHANDONG/9/93', 'A/SINGAPORE/37/2004', 'A/SOUTH_AUSTRALIA/53/2001',
       'A/SOUTH_AUSTRALIA/84/2002', 'A/SYDNEY/228/2000', 'A/SYDNEY/5/97',
       'A/TASMANIA/1/97', 'A/TOWNSVILLE/2/99', 'A/TOWNSVILLE/36/2003',
       'A/TOWNSVILLE/4/2002', 'A/URUGUAY/716/2007', 'A/VICTORIA/1/93',
       'A/VICTOR

In [45]:
K = len(data_s) # 160
K
data_assist = data_s
f_t_ind = list(data_assist.columns).index(feature_t)
data_t.columns.equals(data_t.columns)
f_t_ind = data_assist.columns.get_loc(feature_t)
data_assist
(data_assist.apply(lambda x: x.count()) / data_assist.shape[1] > 0.8).sum() > n_feature

NameError: name 'n_feature' is not defined

In [28]:
RF_complete(data_t = data_t,
            data_s_list=data_s,
            feature_t=feature_t)

160 additional datasets used for prediction.
Feature-of-interest located!
data_assist shape: (81,)


IndexError: tuple index out of range

In [None]:
feature_t in list(flu_df.columns)

In [None]:
filtered_flu_table_dict['TableS14'][0]

In [None]:
# substrRight <- function(x, n){
#   substr(x, nchar(x)-n+1, nchar(x))
# }

# # load data
# data = read.csv("InfluenzaData.csv", header=T)
# dim(data)
# #get indices of the six vaccination datasets
# table.ind = as.character(substrRight(unlist(data[1]),3))
# table(table.ind)
# #create data matrix
# data=data[,-1]
data2=matrix(as.numeric(unlist(data)), ncol=81)
colnames(data2)=colnames(data)
rownames(data2)=rownames(data)
data2=log10(data2)

#Example 1: predict antibody responses against the virus "A.PANAMA.2007.99" in Table S14 using Table S13.

data.t = data2[which(table.ind==unique(table.ind)[6]), ]
data.s.list = list(data2[which(table.ind==unique(table.ind)[5]), ])
feature.t = "A.PANAMA.2007.99"

#run the RF.complete function
out = RF.complete(data.t, data.s.list, feature.t)