In [1]:
import pandas as pd
from collections import Counter
from sklearn import tree
from sklearn.model_selection import train_test_split, cross_validate

# Предобработка

### Нормальизация счетов k-меров

In [2]:
# read raw data
df = pd.read_csv('raw_counts.tsv', sep='\t').set_index('id')
# normalize raw counts: kmer_count / (cen12h1_count + cen12h2_count) * 100
for i in [col for col in df.columns if 'ch' in col]:
    df[i] = df[i] / (df['cen12h1'] + df['cen12h2']) * 100
df = df.drop(['cen12h1', 'cen12h2'], axis=1)
df.head()

Unnamed: 0_level_0,sex,Family status,Superpopulation code,Population code,cenhap1,cenhap2,ch1,ch2a,ch2b,ch2c,ch3,ch4,ch5,ch6,ch7,ch8
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
HG00403,male,father,EAS,CHS,1,1,2.540179,0.003064,0.0,0.024513,0.001532,0.004596,0.618958,0.281902,0.626618,0.075072
HG00404,female,mother,EAS,CHS,1,1,0.837865,0.002116,0.002116,0.033853,0.004232,0.0,0.85056,0.512028,0.793432,0.033853
HG00405,female,child,EAS,CHS,1,1,2.068852,0.004101,0.0,0.018454,0.00205,0.008202,0.832462,0.520801,0.764799,0.114822
HG00406,male,father,EAS,CHS,1,1,2.543106,0.0,0.0,0.038626,0.001545,0.00309,0.801866,0.273469,0.636549,0.007725
HG00407,female,mother,EAS,CHS,1,1,1.63328,0.001567,0.0,0.010972,0.003135,0.003135,0.841719,0.319759,0.733565,0.014107


### Фильтрация данных

In [3]:
# merge sorted cenhaps
df['cenhap1'] = df['cenhap1'].apply(lambda x: str(int(x)))
df['cenhap2'] = df['cenhap2'].apply(lambda x: str(int(x)))
df['cenhap_merged'] = df['cenhap1'] + '-' + df['cenhap2']
df['cenhap_merged'] = df['cenhap_merged'].apply(lambda x: '-'.join(sorted(x.split('-'))))

In [4]:
# remove ch9
df = df[~df['cenhap_merged'].str.contains('9')]
# remove ch10
df = df[~df['cenhap_merged'].str.contains('10')]
# remove single pairs 
cenhap_pairs_freq = Counter(df['cenhap_merged'])
print('Удалены уникальные пары:')
for i in cenhap_pairs_freq:
    if cenhap_pairs_freq[i] == 1: 
        print(i)
        df = df[df['cenhap_merged'] != i]

Удалены уникальные пары:
3-6
3-3
3-7
7-7


### Получение тренировочной и тестовой выборок

In [5]:
# separate X and y
X = df[[col for col in df.columns if 'ch' in col]]
y = df['cenhap_merged'] 
# balanced train_test_split 
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, stratify=y, random_state=3)

In [6]:
# look at pairs frequency
total = pd.DataFrame(Counter(df.cenhap_merged).most_common(), columns=['cenhaps', 'total']).set_index('cenhaps')
train = pd.DataFrame(Counter(y_train).most_common(), columns=['cenhaps', 'train']).set_index('cenhaps')
test = pd.DataFrame(Counter(y_test).most_common(), columns=['cenhaps', 'test']).set_index('cenhaps')
pair_counts = pd.concat([total, train, test], axis=1)

# Обучение 

### Подбор глубины дерева

In [7]:
ml_results = pd.DataFrame(index=['depth', 'train', 'test'])
for i in range(1, 16):
    clf = tree.DecisionTreeClassifier(max_depth=i, random_state=6)
    clf.fit(X_train, y_train)
    to_concat = pd.DataFrame([i, round(clf.score(X_train, y_train), 2), round(clf.score(X_test, y_test), 2)], index=['depth', 'train', 'test'])
    ml_results = pd.concat([ml_results, to_concat], axis=1)
ml_results.to_clipboard()
ml_results

Unnamed: 0,0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.10,0.11,0.12,0.13,0.14
depth,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0
train,0.5,0.64,0.78,0.82,0.87,0.92,0.94,0.96,0.98,0.99,1.0,1.0,1.0,1.0,1.0
test,0.5,0.63,0.77,0.8,0.85,0.86,0.89,0.91,0.93,0.95,0.95,0.96,0.95,0.95,0.95


### Анализ предсказаний и ошибок

In [8]:
clf = tree.DecisionTreeClassifier(max_depth=11, random_state=6)
clf = clf.fit(X_train, y_train)

In [9]:
# feature importance
importance = {}
features = X_train.columns
for i in range(len(features)):
    importance[features[i]] = clf.feature_importances_[i]
importance = pd.DataFrame(importance, index=['feature_importance']).T
importance.sort_values('feature_importance', ascending=False, inplace=True)
importance

Unnamed: 0,feature_importance
ch1,0.339907
ch4,0.205004
ch2b,0.088965
ch8,0.088416
ch2c,0.061898
ch7,0.056911
ch5,0.052336
ch3,0.050856
ch6,0.030347
ch2a,0.025359


In [10]:
# check prediction results
results = pd.concat([X_test, y_test], axis=1)
results['predicted'] = clf.predict(X_test)
results['prob'] = [max(i) for i in clf.predict_proba(X_test)]
results['hit?'] = results['cenhap_merged'] == results['predicted']
mismatches = results[results['hit?']==False]

# count mismatches
mismatches['pair'] = results['cenhap_merged'] + ':' + results['predicted']
Counter(mismatches.pair)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  mismatches['pair'] = results['cenhap_merged'] + ':' + results['predicted']


Counter({'1-4:1-1': 5,
         '1-2:1-1': 4,
         '2-4:3-4': 2,
         '1-8:2-8': 2,
         '4-4:1-4': 2,
         '2-4:4-4': 1,
         '4-4:2-4': 1,
         '2-2:1-2': 1,
         '4-7:1-4': 1,
         '1-1:1-4': 1,
         '2-4:1-4': 1,
         '1-4:1-6': 1,
         '1-5:2-5': 1,
         '2-3:2-2': 1,
         '4-6:3-4': 1,
         '1-1:1-2': 1,
         '1-4:2-4': 1,
         '1-4:4-4': 1,
         '2-8:1-8': 1})

In [11]:
# take a look at frequency of mismathed pairs in dataset
mm_pairs_freq = pd.DataFrame(Counter(mismatches['cenhap_merged']).most_common()).set_index(0)
all_pairs_freq = pd.DataFrame(Counter(y).most_common()).set_index(0)
test_pairs_freq = pd.DataFrame(Counter(y_test).most_common()).set_index(0)
pd.concat([all_pairs_freq, test_pairs_freq, mm_pairs_freq], axis=1).dropna()

Unnamed: 0_level_0,1,1,1
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
1-1,561,185,2.0
1-4,336,111,8.0
1-2,270,89,4.0
2-2,164,54,1.0
4-4,100,33,3.0
2-8,63,21,1.0
2-4,50,16,4.0
1-8,41,13,2.0
2-3,34,11,1.0
1-5,23,8,1.0
