# Performing ensemble learning

In this notebook we use our ensemble of `random forest` models on the data located at `.../pca_engineered_datasets/pca32_95comps/train_splits/split_xx/random_forest`. Each of the directories of this form contain an ensemble of models.  

In [1]:
%%time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys
import os
import gc
import shutil
import csv
import requests
from tqdm import tqdm
from collections import OrderedDict, defaultdict, Counter
import seaborn as sns
import json
from sklearn.ensemble import RandomForestClassifier
#from sklearn.decomposition import PCA
#from joblib import dump, load
import joblib
import xgboost as xgb
from sklearn.metrics import accuracy_score

CPU times: user 563 ms, sys: 49.2 ms, total: 612 ms
Wall time: 637 ms


## Loading data

In [2]:
#parent_dir = "/content/genetic_engineering_attribution"
parent_dir = "/home/rio/data_sets/genetic_engineering_attribution"

#### Data

In [3]:
%%time
### pca directory
pca_dir = os.path.join(parent_dir,"pca")

### pca engineered data sets
pca_engineered_datasets_dir = os.path.join(parent_dir,"pca_engineered_datasets")

### pca_32_95comp dir
pca_32_95comp_dir = os.path.join(pca_engineered_datasets_dir,"pca_32_95comp")

### train, val and test dir
train_val_test_sequence_id_dir = os.path.join(pca_32_95comp_dir,"train_val_test_sequence_id")
train_dir = os.path.join(train_val_test_sequence_id_dir,"train")
val_dir = os.path.join(train_val_test_sequence_id_dir,"val")
test_dir = os.path.join(train_val_test_sequence_id_dir,"test")
### paths to csvs
#train_path = os.path.join(pca_32_95comp_dir,"train.csv")
#val_path = os.path.join(pca_32_95comp_dir,"val.csv")
#test_path = os.path.join(pca_32_95comp_dir,"test.csv")

### train splits dir
train_splits_dir = os.path.join(pca_32_95comp_dir,"train_splits")

### Printing directories
print("train_splits_dir: ",train_splits_dir)
print("train_dir: ",train_dir)
print("val_dir: ",val_dir)
print("test_dir: ",test_dir)

### loading dataframes
#df_train = pd.read_csv(train_path,index_col=0)
#df_val = pd.read_csv(val_path,index_col=0)
#df_test = pd.read_csv(test_path,index_col=0)

### Printing shapes:
#print(f"Shape of df_train: {df_train.shape}")
#print(f"Shape of df_val: {df_val.shape}")
#print(f"Shape of df_test: {df_test.shape}")

train_splits_dir:  /home/rio/data_sets/genetic_engineering_attribution/pca_engineered_datasets/pca_32_95comp/train_splits
train_dir:  /home/rio/data_sets/genetic_engineering_attribution/pca_engineered_datasets/pca_32_95comp/train_val_test_sequence_id/train
val_dir:  /home/rio/data_sets/genetic_engineering_attribution/pca_engineered_datasets/pca_32_95comp/train_val_test_sequence_id/val
test_dir:  /home/rio/data_sets/genetic_engineering_attribution/pca_engineered_datasets/pca_32_95comp/train_val_test_sequence_id/test
CPU times: user 430 µs, sys: 57 µs, total: 487 µs
Wall time: 312 µs


## Selecting features and targets

#### features

In [4]:
features = ['pca_0', 'pca_1', 'pca_2', 'pca_3', 'pca_4', 'pca_5', 'pca_6', 'pca_7', 'pca_8', 'pca_9', 
            'pca_10', 'pca_11', 'pca_12', 'pca_13', 'pca_14', 'pca_15', 'pca_16', 'pca_17', 'pca_18', 
            'pca_19', 'pca_20', 'pca_21', 'pca_22', 'pca_23', 'pca_24', 'pca_25', 'pca_26', 'pca_27', 
            'pca_28', 'pca_29', 'pca_30', 'pca_31', 'pca_32', 'pca_33', 'pca_34', 'pca_35', 'pca_36', 
            'pca_37', 'pca_38', 'pca_39', 'pca_40', 'pca_41', 'pca_42', 'pca_43', 'pca_44', 'pca_45', 
            'pca_46', 'pca_47', 'pca_48', 'pca_49', 'pca_50', 'pca_51', 'pca_52', 'pca_53', 'pca_54', 
            'pca_55', 'pca_56', 'pca_57', 'pca_58', 'pca_59', 'pca_60', 'pca_61', 'pca_62', 'pca_63', 
            'pca_64', 'pca_65', 'pca_66', 'pca_67', 'pca_68', 'pca_69', 'pca_70', 'pca_71', 'pca_72', 
            'pca_73', 'pca_74', 'pca_75', 'pca_76', 'pca_77', 'pca_78', 'pca_79', 'pca_80', 'pca_81', 
            'pca_82', 'pca_83', 'pca_84', 'pca_85', 'pca_86', 'pca_87', 'pca_88', 'pca_89', 'pca_90', 
            'pca_91', 'pca_92', 'pca_93', 'pca_94', 
            'bacterial_resistance_ampicillin', 'bacterial_resistance_chloramphenicol', 
            'bacterial_resistance_kanamycin', 'bacterial_resistance_other', 
            'bacterial_resistance_spectinomycin', 'copy_number_high_copy', 
            'copy_number_low_copy', 'copy_number_unknown', 'growth_strain_ccdb_survival', 
            'growth_strain_dh10b', 'growth_strain_dh5alpha', 'growth_strain_neb_stable', 
            'growth_strain_other', 'growth_strain_stbl3', 'growth_strain_top10', 
            'growth_strain_xl1_blue', 'growth_temp_30', 'growth_temp_37', 'growth_temp_other',
            'selectable_markers_blasticidin', 'selectable_markers_his3', 'selectable_markers_hygromycin', 
            'selectable_markers_leu2', 'selectable_markers_neomycin', 'selectable_markers_other',
            'selectable_markers_puromycin', 'selectable_markers_trp1', 'selectable_markers_ura3', 
            'selectable_markers_zeocin', 'species_budding_yeast', 'species_fly', 'species_human', 
            'species_mouse', 'species_mustard_weed', 'species_nematode', 'species_other', 'species_rat', 
            'species_synthetic', 'species_zebrafish']
print("Features to be used: ")
print(features)
print("Number of features: ", len(features))

Features to be used: 
['pca_0', 'pca_1', 'pca_2', 'pca_3', 'pca_4', 'pca_5', 'pca_6', 'pca_7', 'pca_8', 'pca_9', 'pca_10', 'pca_11', 'pca_12', 'pca_13', 'pca_14', 'pca_15', 'pca_16', 'pca_17', 'pca_18', 'pca_19', 'pca_20', 'pca_21', 'pca_22', 'pca_23', 'pca_24', 'pca_25', 'pca_26', 'pca_27', 'pca_28', 'pca_29', 'pca_30', 'pca_31', 'pca_32', 'pca_33', 'pca_34', 'pca_35', 'pca_36', 'pca_37', 'pca_38', 'pca_39', 'pca_40', 'pca_41', 'pca_42', 'pca_43', 'pca_44', 'pca_45', 'pca_46', 'pca_47', 'pca_48', 'pca_49', 'pca_50', 'pca_51', 'pca_52', 'pca_53', 'pca_54', 'pca_55', 'pca_56', 'pca_57', 'pca_58', 'pca_59', 'pca_60', 'pca_61', 'pca_62', 'pca_63', 'pca_64', 'pca_65', 'pca_66', 'pca_67', 'pca_68', 'pca_69', 'pca_70', 'pca_71', 'pca_72', 'pca_73', 'pca_74', 'pca_75', 'pca_76', 'pca_77', 'pca_78', 'pca_79', 'pca_80', 'pca_81', 'pca_82', 'pca_83', 'pca_84', 'pca_85', 'pca_86', 'pca_87', 'pca_88', 'pca_89', 'pca_90', 'pca_91', 'pca_92', 'pca_93', 'pca_94', 'bacterial_resistance_ampicillin', 'b

#### targets

Below we map our target values to integers, to make things simpler.

In [5]:
all_targets = [None] + sorted(['00Q4V31T', '012VT4JK', '028IO5W2', '03GRNN7N', '03Y3W51H', '09MQV1TY', '0A4AHRCT', '0A9M05NC', '0B9GCUVV', '0CL7QVG8', '0CML4B5I', '0DTHTJLJ', '0FFBBVE1', '0HWCWFNU', '0L3Y6ZB2', '0M44GDO8', '0MDYJM3H', '0N3V9P9M', '0NP55E93', '0PJ91ZT6', '0R296F9R', '0T2AZBD6', '0URA80CN', '0VRP2DI6', '0W6O08VX', '0WHP4PPK', '0XPTGGLP', '0XS4FHP3', '0Y24J5G2', '10TEBWK2', '11TTDKTM', '131RRHBV', '13LZE1F7', '14PBN8C2', '15D0Z97U', '15S88O4Q', '18C9J8EH', '19CAUKJB', '1AP294AT', '1B9BJ2IP', '1BE35FI1', '1CIHYCE4', '1DJ9L58E', '1DTDCRUO', '1EDZ6CA7', '1HCQTAYT', '1HK4VXP8', '1IXFZ3HO', '1K11RCST', '1KC6XYO6', '1KNFJ6KQ', '1KZHNVYR', '1LBGAU5Z', '1NXRMDN6', '1OQJ21E9', '1OWZDF82', '1PA232PA', '1PIGWQFY', '1Q1IUY3G', '1S515B69', '1TC200QC', '1TI4HS4X', '1UOA7CA1', '1UREJUSJ', '1UU0CHTK', '1VPOX8VI', '1VQS4WNS', '1X0VC0O1', '1XU60MET', '1ZC8RPN1', '20ABQYHS', '20CEB9KE', '216DWMG6', '21ZFBX5E', '24SL2992', '25UVYUID', '26KK8UM5', '27OS3BTP', '28D4D4QM', '298AMR5C', '29D6Q091', '2AQG6I31', '2BAFY4GP', '2CJHRNWD', '2FCX4O0X', '2GGU2QA2', '2GSZMU46', '2GTLIT33', '2H37WPKA', '2HNZZYDB', '2JPNC9X6', '2KDACBQT', '2L336TQL', '2M3CXS8N', '2MCB7LXW', '2MQ2NPMA', '2NEXWXMT', '2PY8K6GU', '2Q33W599', '2SSVM7H9', '2TVMHQTW', '2TXY439E', '2VP4JPB9', '2VTLZHDS', '2VX4F6RC', '2XC1478M', '2XX0N87I', '2Y9L13L4', '2YCH1PUI', '2YLQA8OZ', '303BN0Z0', '318RH8P0', '330L4OIV', '33AR5KVE', '343M819H', '34TE1Q0A', '35MKXPL0', '36W150XW', '36XLYYGZ', '37VO60SB', '384ASNLB', '38MDETY1', '38MEQ4SU', '39LLQ2PB', '39TEZ0C3', '39TPBOL7', '3BGLF8BC', '3C2VZQ2R', '3C952KY7', '3D9CMQ4V', '3EARN0Z7', '3EYBG174', '3EZXYI3U', '3FPH0N6R', '3FW33G68', '3GEXBRC0', '3KCEM7V4', '3L314D8W', '3LSNTL1N', '3MDRJUI2', '3MX1D3LD', '3N169DM2', '3NSJ6N02', '3O1GIAV7', '3QP4D23X', '3RK54JUW', '3TLD81QQ', '3TUFYWQN', '3TXFYNKG', '3X2GGDHW', '3XE0BJDW', '3YAQWNBK', '3YEGUN04', '3YYEC52Y', '40MD0YZ3', '40ZI3TDN', '443NZOSB', '448QVC4C', '44N2CYI9', '459BZKP3', '4648UZGD', '46AZ97U9', '48F0EUVN', '49571DXY', '49YZILWR', '4CKAV3LS', '4DGGCYVE', '4DGMNDIC', '4E7187A9', '4GF31RCS', '4GHCND6Z', '4IADYZ8R', '4IDTMY10', '4J7KEYE2', '4KSHU5M7', '4LCFACE1', '4LQ8L195', '4M3XG8RC', '4O39WLXM', '4O5RQHEF', '4PKCMX7O', '4QK5ZDHA', '4QU07FT7', '4RCA1UZG', '4RHLX089', '4S1LIWGV', '4TIT4L5F', '4U5LAAN5', '4VHMF1RI', '4WAQ4VFB', '4WRI77CU', '4X2RTV2D', '4Y4DT3SL', '4ZYW54M8', '50NBGIOB', '52Y9GFGK', '54C6PEBH', '54ZFOPSF', '558GIQ68', '55HTZ7T0', '579G0TJI', '57FHO8YC', '57NGF1YS', '58BSUZQB', '5ASQZ0OT', '5AUVXXDU', '5BNUT8AW', '5BTY65G6', '5CBNCRST', '5FUDT1QA', '5H71LUBY', '5K2PTY6L', '5KXWXV9G', '5LH9NUMK', '5OBD73W0', '5OF7OYEA', '5OFUVG9U', '5PC2F8NE', '5PR9OSRS', '5Q9ETXJL', '5QLBIUXN', '5QY2HU8J', '5SCOFTY2', '5SGMS705', '5V3Z108E', '5W2PCT95', '5X9VNAN3', '5Z4CMIY5', '5ZB8I3T0', '5ZW05824', '60HBQEP8', '62PKSARW', '638UYIQC', '64FFXH4M', '65CCBIXK', '669R7ER0', '66XSSS3Q', '685KTH3G', '68OY1RK5', '69M351P4', '6AT20D5S', '6DBY872A', '6E28DNQK', '6KT0EAKX', '6LQ0W02R', '6NCTAA30', '6NKNB308', '6NULQ6KP', '6PS2LHCV', '6PXRABDR', '6QBXXYN4', '6QUCW04X', '6SBB6IL2', '6T9SGGS1', '6TT5CXVI', '6TTWEXT3', '6UGWNYCX', '6UI9XACW', '6UXF7L28', '6WD2LIHN', '6WT1F4RJ', '6XVBD39G', '6YSX60MZ', '7039MMH2', '709K4VRB', '7185O9V8', '71R7TM8L', '738FBTIL', '73RKEO3U', '747XMBIJ', '74RXUGS4', '74TS5KG4', '78QGAL01', '78XDAJNS', '7ANCD9AK', '7DMNXU84', '7E63E5RD', '7F905YRZ', '7GWW4637', '7IHPTKFF', '7KG191H8', '7MUAYEHW', '7NGLQ1CA', '7O3PWIL0', '7OV5K86R', '7PWA4ZJN', '7QEORFJN', '7QF2VB5B', '7QWHL2C6', '7SW79VAJ', '7T28F53W', '7TYZHD5J', '7UU8O65I', '7WKS90AG', '7X3RSRT5', '7XPDUYJE', '7XU8ACPI', '7YSTNZME', '7ZV0Z1T9', '81QAZACE', '82NXGO4K', '862RYK1K', '86ET7WW4', '88E6O06E', '8ABA3MWO', '8BF8ANNO', '8C0T09C6', '8C9737JL', '8D4D6M5V', '8ECLELF1', '8EKC599S', '8F0XPAZX', '8FT6HD4D', '8FZMCIFG', '8G29TDOS', '8H6M75LF', '8HI3GY44', '8HW91I4K', '8HZXGARR', '8IPYO6SS', '8JKDTT0Y', '8K0HZBL0', '8MUKKVMF', '8MW998Z0', '8N5EPD5C', '8OBT3FSQ', '8ORZZFA7', '8RIKS696', '8SW7WFE6', '8T12OXHS', '8US76O46', '8VCFY56I', '8VI1RY3M', '8VLB2R3D', '8WAY3T1E', '8Z6SANMH', '8ZB94ICE', '8ZB99KHH', '904V6V2S', '909V5A2H', '91Y7NKBM', '91Z8RRSB', '92WF5WVN', '93R70J1L', '93WIIL7Y', '97FR69TQ', '97PR85CP', '99A19JAD', '9DBCRJYM', '9DKQF2I2', '9DRMDPIZ', '9G5XH4HI', '9GDHC3D0', '9HPM9NFY', '9HRDSOST', '9IVIPDX5', '9JRKFKVC', '9KHXMSMW', '9KV8R3HP', '9LSH625Y', '9MC0DPDJ', '9MC1YKKZ', '9MEFUZQN', '9MG50RM7', '9MZBKXJF', '9PWYZMNS', '9QQZ79I6', '9R765PJF', '9SJCUIKS', '9SSQ1FSY', '9U0DELRD', '9WEGTUIJ', '9WQQKFVK', '9XE0FL8P', '9Y5EWA8O', '9YM3QINZ', '9ZTEQPA4', 'A0ADXLZU', 'A0Z7XCDN', 'A1738D1Z', 'A18S09P2', 'A1A8EROR', 'A1J0YXZX', 'A2A1R52R', 'A2U1AIC1', 'A332O9JW', 'A3FZPLM1', 'A3QUOXIX', 'A44GW57T', 'A4BM0B6A', 'A6RCKKER', 'A768XIWP', 'A78F2YFJ', 'A7CK3WNB', 'A810BWR5', 'A8FZHMOS', 'A9G8OKRG', 'AATDRXYQ', 'AAURK3RG', 'ABMAPCYN', 'ABWCZWFU', 'ACO8WWPF', 'ADB7SAPN', 'AG93GZYN', 'AHMVJ2VP', 'AL7N3DL2', 'AM8AJH2H', 'AMSPTQVJ', 'AMV4U0A0', 'AOCCEP3S', 'AOFJN8HX', 'AOFPYGHC', 'AOKRU4AF', 'AOQQU910', 'AR433PVR', 'AS30HPUK', 'AUCMR8HU', 'AUUSW2YZ', 'AUZNSS79', 'AV7ONIVD', 'AWWC1KIV', 'B131HDBV', 'B17J3JSX', 'B1I4L0XW', 'B25KOPVH', 'B2BULVFH', 'B4L9R8JU', 'B517ID6W', 'B832TQ6U', 'B8FC99WI', 'B8YR9IIK', 'B9H5SLHK', 'BBTA1L43', 'BBZJCYJ0', 'BD9EXLDM', 'BDQOSDFG', 'BDSEVK9M', 'BH7HW7XH', 'BHKOO62U', 'BHNI9DCI', 'BHW9ILRC', 'BJKTDFN4', 'BL2TLVFC', 'BLC9WIIM', 'BLFM4YKK', 'BLNELN02', 'BN8BMXPM', 'BNFZZTKX', 'BP2X9ITX', 'BPT27UPE', 'BQJ79YS3', 'BSEEWS00', 'BSH6LB19', 'BTQL3UFQ', 'BV6PVSO5', 'BV8D4RYV', 'BWFN4ZI7', 'BXMEKONO', 'BY5IEG4O', 'BZBNZDNS', 'C1BIUBL5', 'C35C2C2W', 'C4W63WJ2', 'CA0MBQ9S', 'CAO2H0WE', 'CAQEITX6', 'CB714TAM', 'CBCQST29', 'CBFKYZ9S', 'CBKRHK4I', 'CDM3SRRP', 'CDU1LWN3', 'CEATO4LM', 'CENOJ84D', 'CFDEOSH4', 'CFOET28L', 'CFQ9PAJA', 'CHTQ7QLX', 'CJFLQNE1', 'CK1M5UHL', 'CKDZNQV2', 'CLO7VQ12', 'CNX48K3H', 'COEMYLH1', 'COVE5WRD', 'CRP30ATM', 'CTJGWLX0', 'CTLP20Y9', 'CWZP8AQK', 'CY64689U', 'CYCSYMQ3', 'CZUGPH88', 'D0EKC82X', 'D0NFHXL2', 'D0YWREJ5', 'D10S0UDQ', 'D1BZRMOB', 'D2N5DOSQ', 'D3KJQCYH', 'D4PJE56U', 'D4Q1QMRJ', 'D63K976U', 'D7L6VZNV', 'D8MRQA91', 'D8OQ3YNK', 'DD0JBK3T', 'DE6NAU7D', 'DEFNZK0A', 'DEWKAO5I', 'DGE8LLAJ', 'DGQ2L6KM', 'DJW5U56I', 'DKA65CRR', 'DLSU0QRX', 'DN01XVIU', 'DQGG01WF', 'DRFCUPZO', 'DSE2G8LF', 'DY0KIZZ9', 'DZ2XFGQS', 'E3CE5WE9', 'E3CRPQL7', 'E3FFACSU', 'E4EF2K0A', 'E4T4IQMG', 'E59C5N01', 'E5OB5QF1', 'E6G69ESA', 'E6TPDVWA', 'E7CPRIYW', 'E7EZD62E', 'E8100WU0', 'E8GMEHFW', 'EA2DKNTD', 'EBF1G8Z7', 'ED0OS5OF', 'EEC8D29F', 'EFKGYR79', 'EI8B4WEC', 'EJ3T17DB', 'EJXP2QAW', 'EKHYS325', 'EKXAPD70', 'EL9FN1LB', 'ELF2BN3S', 'ELX1D1DS', 'EMJXDINV', 'EMNH5MYX', 'EN78WKI4', 'EOQAQ9X1', 'EPDX32D3', 'EQPB3YTZ', 'ER1IJR80', 'ETR2SP13', 'EW4ZXWSN', 'EXQZ5V7S', 'EYOJGC9T', 'EZ40BRHE', 'EZL4HNHH', 'EZMV5TKG', 'F0ESSJYM', 'F0MOWJYA', 'F1X6DMDH', 'F3D2JAYU', 'F3S4VUQI', 'F50DBVIK', 'F8I0DT7Z', 'F8LNIZ27', 'FCI1HZ3G', 'FEBWERSN', 'FH8TEJI1', 'FHR8UUYO', 'FHZYKEUV', 'FJTJ4KY0', 'FLHGDG0P', 'FLSWA4NU', 'FLU9ZT18', 'FMJ19E48', 'FN1RKQ2M', 'FN38BX60', 'FNKCHGB7', 'FNM1Z945', 'FPH5H8JT', 'FQ8V2QHL', 'FRFT0H8N', 'FRK40JVP', 'FRX9XJYW', 'FSR0IC6I', 'FVYCRUFK', 'FWOZ05UZ', 'FXBIP7LS', 'FXRWH0M9', 'FZ37IFWH', 'G2P73NZ0', 'G4UJDFPK', 'G57JANUL', 'G6MP6EIN', 'G7MXLRV8', 'G81LO0AZ', 'G8QWQL1C', 'GB45D1XV', 'GBX3MNVS', 'GDV3S3ZG', 'GHG5MDER', 'GJKR73YA', 'GJPI1WIV', 'GKY6ZB15', 'GKY7BZOQ', 'GLOJFBA0', 'GLUZC5HC', 'GM3HKY2J', 'GS8G1IFF', 'GSK9JT39', 'GSNU5TXL', 'GT4RHNUE', 'GTVTUGVY', 'GUCIE6TT', 'GUWYJRRS', 'GWJ0A1IK', 'GWP6E8FA', 'GYCOAVYS', 'GYCY8LCF', 'GZMPRX5J', 'H0WSDLJE', 'H12S8X2Q', 'H1G4FFR7', 'H20JGHP0', 'H3D82ATM', 'H3RWZ7UR', 'H48Y5BOY', 'H5Y73UHQ', 'H9RBDN30', 'HB3OQUA5', 'HCW1Y9QM', 'HGN5HD65', 'HGPS0FQN', 'HHSIC4NY', 'HI7ZNYCK', 'HJNGSDJ5', 'HK78MCH7', 'HNGYSI62', 'HODOBX62', 'HQC2OFGM', 'HRFD8R1G', 'HRWBEBRE', 'HT51BMN1', 'HTXABMRS', 'HV6GZXC3', 'HVAG84XI', 'HVBBJM37', 'HVN93I56', 'HVXSID0M', 'HVZMFFNW', 'HX2XDS73', 'HX5NMCPJ', 'HY9DN23J', 'HZ5C2E4C', 'I0J54PBT', 'I16TS2B4', 'I1RQMFZC', 'I2ATV1DI', 'I2N7C27Y', 'I3UODLOR', 'I5L6E1U2', 'I5RNBXF3', 'I6B3VKYD', 'I7FXTVDP', 'I8U0Q5FP', 'I9MWC6I3', 'IBBLXRDR', 'ICDP084U', 'ICRBJL24', 'ID37U3DA', 'IDXJ25FE', 'IGHBC70Q', 'IH12MVU4', 'IIWFYXGG', 'IJEA3NUI', 'IL47R85Z', 'ILKPIFSA', 'IM2JLO1B', 'IMFV7GM3', 'IMVSI4VW', 'INDCDVP0', 'INELF20P', 'INJ6L6NB', 'IO2FYB6G', 'IO56YRTG', 'IOKPSO7K', 'IOOQONCI', 'IOPR6B78', 'IP9XMFII', 'IPV1W17S', 'IPVYEI8G', 'IQPZXRU2', 'IS75OD95', 'ISMP5LYF', 'IUJPYIRX', 'IYKXT23R', 'IZD0O5Q0', 'IZSQDCWP', 'J0NVCXDJ', 'J1UFMOCR', 'J339EI56', 'J3752QSY', 'J3L1KD1J', 'J3YKGOCX', 'J5WRC3DJ', 'J648LM1S', 'J70NZZIW', 'J7PWRE94', 'J9M11KX1', 'JAEI655A', 'JB8JTFSG', 'JC35D8WT', 'JC6LUZLT', 'JCHNPTSF', 'JDENEZ6I', 'JICWX3AS', 'JJBJFUAT', 'JK9C0VN8', 'JKUCC6UK', 'JL1OZP2G', 'JMJD18BP', 'JN497K3S', 'JNB98WP1', 'JNU5CAOV', 'JO1WTZOB', 'JPI7LZJ3', 'JPO7CTQP', 'JQ4YBT3Z', 'JQ7Z5Q44', 'JQJ499YN', 'JRBK08H6', 'JRDHZ51W', 'JRRTJ3GV', 'JS1KUAD6', 'JS59HL6M', 'JSEGAB8K', 'JT4GYL2P', 'JUC55NLK', 'JUYW4QZ1', 'JVWQ5HEJ', 'JWVCJ3UR', 'JWYYB1L5', 'JXDP2C4M', 'JYZ82A2B', 'JZ1RSLKQ', 'JZ2KQL0P', 'JZS556ZA', 'JZTRRSKQ', 'K1DU5H0C', 'K1K1AESM', 'K212MH7P', 'K25LXPOI', 'K3QD4AHX', 'K4AGNZ3R', 'K57LN37R', 'K83DA8K5', 'KB0YFLBH', 'KD7N9YDF', 'KDW3ZVWJ', 'KDZ388UF', 'KF32BDPB', 'KFWFMIUK', 'KG943QKP', 'KGMINGSB', 'KH4VOX9Q', 'KJJYCUJ7', 'KKG07XA9', 'KKIO1X0Z', 'KM3OV97R', 'KMPCXZUY', 'KMSH5BSO', 'KRS7ST1L', 'KSFFKSV7', 'KU0G64D0', 'KUGU9MQC', 'KUH39TQR', 'KV5TCH8S', 'KVLIE219', 'KWH2Y6KA', 'L0FS3EPM', 'L27ULB0P', 'L2HRYP1A', 'L2UTYYJT', 'L3OPGJO5', 'L3RQSW75', 'L3SSKU27', 'L5AMS3QT', 'L657W1BK', 'L76WWQ74', 'L78GOBQS', 'L905DK46', 'LDCSZOKC', 'LF9AQIHZ', 'LFQ6YRHV', 'LGEAIIK8', 'LGTP4O86', 'LHMKC873', 'LHNLO8Q8', 'LKC4LOOM', 'LKR5NGJZ', 'LKVB0S84', 'LL11R5T6', 'LM6LV3JB', 'LNTF6KP8', 'LPBA27LH', 'LPQY1SEL', 'LQ6K46C8', 'LU684LJ9', 'LUHRMKEB', 'LUI0TOT2', 'LVXSGLT6', 'LWQ8FULT', 'LXBPBCS3', 'LXOZJ3TV', 'LXPTXE5K', 'LYY8P69T', 'M1CZ7MK8', 'M2HPA1EK', 'M2R84KMY', 'M2W28OUV', 'M3B15QGL', 'M3MFQNC7', 'M46L0EBU', 'M4V0NJ97', 'M59DNUXD', 'M9265ASV', 'M9PHW06O', 'MB9HHEPN', 'MBQUJESG', 'MDCIP8E0', 'MEKV5BRI', 'MEVIH0XF', 'MFZHQ165', 'MGQBELNN', 'MH0GC0GY', 'MIUE47ZL', 'MJR1CR7U', 'ML1YCDCG', 'ML5W6LDB', 'MLGLKKI7', 'MMU3QFIP', 'MNV2YSWZ', 'MOCIAZ0D', 'MQKR83SM', 'MQQTIYIC', 'MQRIDTFZ', 'MULMC195', 'MUO5QBB6', 'MV1CMX4O', 'MXV7CSHI', 'MZOM2K35', 'N0CP1NI7', 'N0FDUY5E', 'N5LOOJSR', 'N5X3YG2I', 'N764BFJU', 'N7BY4DKZ', 'N8FNYI0A', 'N8X63KYC', 'N9I581ZL', 'NBCZC85X', 'ND7I48LA', 'ND88CY09', 'NDDT3NOB', 'NDZT8PV3', 'NHNLVWDR', 'NIKHJTWP', 'NIRCF0RK', 'NK0S2WH6', 'NKPC0Z4Q', 'NKRRLD5O', 'NMQKJMH3', 'NNNIMDVI', 'NPWC1BXV', 'NQVW27OC', 'NR26DCAB', 'NRRH4BON', 'NT9Y0D19', 'NTLCS343', 'NUOEY3LD', 'NUSJ1NGL', 'NUYVBFLU', 'NWE84W10', 'NWKWVAIA', 'NX7I9PQG', 'NYI75N90', 'O1LMIA6M', 'O3M287V6', 'O4VJ2EV7', 'O55K40VQ', 'O5PJEO54', 'O69KS0OS', 'O7NEA7KO', 'O8E18PJ4', 'OAEZWMZR', 'OAPTL0AF', 'OB97CO94', 'OCJ3W2EF', 'OEGM98R5', 'OERPDTWW', 'OG01U0FT', 'OJ9HCGTB', 'OKI0Z2UO', 'OKK933IV', 'OKWROFEH', 'OL1HWRRD', 'OL59ZZX5', 'OML0TEF3', 'ON2CU60C', 'ON9AXMKF', 'ONPQ2I44', 'OOKK1JHN', 'OPPRIPN9', 'OUA1CRWO', 'OUJLF506', 'OVPHRVOD', 'OYRI4NVE', 'P361G1OD', 'P3Q11IAK', 'P4H26KKX', 'P8PW7Q1Q', 'PEUBDA2B', 'PFI6E05S', 'PFNRAGJP', 'PGWZZALU', 'PHQEJTNO', 'PIT16TZ9', 'PJYVLL0Z', 'PKC5LJ6W', 'PMCWG8N5', 'PNWFSSF0', 'POKTJVRL', 'PONI61NE', 'POZMOX9T', 'PQZ6Z3YJ', 'PRU3JF6Y', 'PRYT0A2P', 'PS6MZN15', 'PSY58O49', 'PUECZ8ZI', 'PV7QTHJV', 'PW7GT7TE', 'PXT3AJ7C', 'PY8VPVM5', 'PYX7I7X5', 'Q1D88JO2', 'Q1M9RXYR', 'Q21CAL4Z', 'Q2K8NHZY', 'Q2LO2OGN', 'Q35PXLRT', 'Q3O4J4HB', 'Q5V3EKJC', 'QASMCASJ', 'QC3VEU4P', 'QEOKKUF1', 'QJ5LYZHA', 'QJJAG1IV', 'QJMUUPFK', 'QL3AU1NN', 'QNE79S52', 'QNKGHIRB', 'QNQQVRNB', 'QPA31HRW', 'QQFF3LO5', 'QQR3SE8Y', 'QR91QBR2', 'QSLQZQH2', 'QT44Y8VV', 'QTIRUM0G', 'QUFMTUB3', 'QUUKEGL5', 'QV09SDY8', 'QV71AJ91', 'QVAHXT35', 'QVAZPYQ8', 'QYBCIW4J', 'QYZ57QTQ', 'QZ1V5GME', 'QZ8BT14M', 'QZD4I9UW', 'R1BX2NZI', 'R1OFLDKQ', 'R2O5C424', 'R3AAYF7V', 'R3QOGZZF', 'R5B3KVZI', 'R67AMR4P', 'R6QNKUC4', 'R830GQGO', 'RASRCD7I', 'RBL3SN1I', 'RBLPDV4R', 'RBMLZBYW', 'RD5YXSBA', 'RD62G56Y', 'RE7IER1C', 'REKW7MRF', 'RF45YZMF', 'RFUY4U4W', 'RFYO6TO0', 'RGD51NW1', 'RHH1X0A2', 'RHSAJGR1', 'RIEIBCRF', 'RKJHZGDQ', 'RNSK8HLJ', 'RP37N5WN', 'RQUURTUT', 'RRIG3SH3', 'RSMDF425', 'RYUA3GVO', 'RZCRWMTU', 'RZPGGEG4', 'RZPT9APG', 'RZT9JPDV', 'S0Z5J1EW', 'S15Z6XG6', 'S2PFIP6S', 'S2ZYVBUF', 'S5CBU2AX', 'S7345IVO', 'S768X16I', 'S824JJ06', 'SAONBMNO', 'SBQXQOPV', 'SBWHI6Y6', 'SCKCR39J', 'SD7VPKVQ', 'SDNECLRB', 'SEAEY0CN', 'SEH3FI81', 'SEVOI9NR', 'SEX60YJE', 'SFPE2DX4', 'SGAZ5VOA', 'SGIINS2G', 'SHKNA9S1', 'SIUTK5SR', 'SIUXBYDS', 'SLG5DZG2', 'SLVO27W6', 'SM3HAKL8', 'SNNICLKQ', 'SNZP9G8K', 'SOPNMXWX', 'SQB9N47Y', 'SR345GAS', 'SRZSX1LR', 'SSVDNEY9', 'ST2DCNR0', 'SU06AE5D', 'SUUFTUWK', 'SW00LEHT', 'SWHE2RH1', 'SYQSKHN2', 'SZ0MR59K', 'T18CGW8H', 'T3KHULCH', 'T4J4YRDK', 'T5R7YFPH', 'T8R673OI', 'T9LSOTV6', 'T9ZHWQE9', 'TBJE6V15', 'TBUHVONI', 'TCKOTGYJ', 'TD593FIM', 'TE1TWCPZ', 'TFTOGJOD', 'TGPPSF7M', 'THD393NW', 'THW6JGC7', 'TI21BGNU', 'TIAPP57M', 'TJLVHJ87', 'TK932JM1', 'TKLYRWYO', 'TNR495LD', 'TQAA3UHV', 'TRM5SRRW', 'TTU1NVDI', 'TU2W2LCB', 'TUO2TVTX', 'TVQC1R4D', 'TWH1XFPL', 'TWV05PEP', 'TYJN7K7A', 'TYQ2T01H', 'TZ8JAEO6', 'TZL79DYX', 'U0U7F3EW', 'U2C1NG0D', 'U2C2VVY8', 'U2OZU4IY', 'U2VWRM3F', 'U2ZEEFLD', 'U3QRAT06', 'U47IUY9C', 'U49ISLNE', 'U5966IDO', 'U5ZJCLCX', 'U69N21WU', 'U6DS14AT', 'U6TNOS7M', 'U74I1JYB', 'U8FRHWSV', 'U8SWTHB5', 'UAY0HW9A', 'UBO7MS4D', 'UBWK5LJH', 'UBXL2EGE', 'UC094GDG', 'UCC4KYQL', 'UCVUALGM', 'UEZVPK90', 'UFAQZXPY', 'UFEO02VM', 'UFTYVG6Y', 'UH5Z524P', 'UHU62P41', 'UJNF3UO2', 'UJSK2U9A', 'UK4B4I7A', 'UKG1R822', 'ULOHU3PC', 'ULVU086L', 'UMDZG9XM', 'UMM76IOX', 'UMOD7PGG', 'UNAGKRY0', 'UNE947CO', 'UO4MVLJS', 'UP3750KB', 'UQUIUCVA', 'URO46KFW', 'URY1ZVZI', 'US8KF8X3', 'UVXQ3O4K', 'UWWS6RWO', 'UXK3D4GF', 'UYCX4ZJS', 'UYLJZRPN', 'UYPE34HA', 'V04Z48C3', 'V1YVL2DL', 'V3JDHWOB', 'V4A28VLV', 'V4RKPN30', 'V5C3CWTK', 'V6X2Z58S', 'V8MF2IKQ', 'VAGUTU8C', 'VB04AEHZ', 'VDSDXJ71', 'VDYHUCQB', 'VE48SF8D', 'VFCTUL5J', 'VFOEJ2CS', 'VGCXUCRO', 'VGWO9SBA', 'VHPX9GYO', 'VJU9EYFE', 'VKN3L279', 'VKU9G6Y5', 'VMU0L6UM', 'VO0ATBFS', 'VOT8OKU2', 'VRZZPHI4', 'VW6ZY2L1', 'VYW7T8YY', 'VZLS9GCK', 'W184Y53L', 'W1STLS0T', 'W2DYAZID', 'W7WRIFC0', 'W9QZOUW7', 'WAL364PD', 'WB78G3XF', 'WBGCVIO8', 'WD8MHX8N', 'WDNYZZHJ', 'WG42FGWA', 'WG7S6W2T', 'WHLUO40S', 'WK162QYQ', 'WK4NBYSB', 'WKRC8NSD', 'WKYJ6R7D', 'WL3FJI96', 'WL8VMHWG', 'WM3Q8LBC', 'WM9JWC4B', 'WNEX0Y1X', 'WP6H3E2T', 'WQ1DVVYG', 'WQBN4WGH', 'WRDZ1CVS', 'WSHPKJ3H', 'WTFS8JV2', 'WTYMIZ88', 'WUARWGNF', 'WUR2UJYP', 'WWDAZG6C', 'WX0HMR4F', 'WZX61W39', 'WZZLL8O4', 'X0VJJXGQ', 'X2PFPX2S', 'X4WO7LHO', 'X4YNMN9Z', 'X6497O49', 'X6LFEBK7', 'X920R0YN', 'X9RNN0YD', 'XCWSW5T9', 'XD80LQN2', 'XE4D68OI', 'XHQPAVRU', 'XLYFD8RW', 'XOEVMQZT', 'XP1SRNTB', 'XP5B8615', 'XPQ9IYZC', 'XR7GR7UE', 'XRENDLF1', 'XSA3Y2H6', 'XTKRJ8N6', 'XU8GASLQ', 'XV32YHEZ', 'XY9JOM6L', 'XYB5NWR4', 'Y060M6TK', 'Y324NGPN', 'Y3HA6UDE', 'Y4G53L4X', 'Y4X5JU76', 'Y575VUS1', 'Y5YH740Y', 'Y620TYKH', 'Y6EC9YQA', 'Y73L2QKM', 'Y81SHRRC', 'YCD71LRY', 'YCNWCC0Z', 'YCY2FFYZ', 'YDPNP1KR', 'YE9BU3J3', 'YEA0ZZZP', 'YEZ30YUQ', 'YFSGJUTL', 'YGFI5B9G', 'YGFIQ8SA', 'YHUR7HZ6', 'YHX2594T', 'YKXRSB4N', 'YL8AOR9Q', 'YLS2HEMR', 'YMHGXK99', 'YMWK7JKH', 'YP4WCV92', 'YQ3L8TWE', 'YQITW66D', 'YTGT3GEX', 'YTOOMPZ8', 'YW85XPTE', 'YWQZUSA8', 'YWZAEK5A', 'YXKFDH6S', 'YY5Y32CI', 'YZX8R26H', 'Z1C99MVU', 'Z1Y066QU', 'Z6LWLWFZ', 'Z7YFK3I0', 'Z7ZKDLZG', 'Z80NVAXF', 'Z8BWVZZX', 'ZAYLY2YU', 'ZB6DPIG5', 'ZB862XHR', 'ZBQD50GN', 'ZC07UYVV', 'ZCU48L3S', 'ZEAZQ1QQ', 'ZEB7PDQK', 'ZEBTRK7D', 'ZEJOQQJF', 'ZELU1VMX', 'ZFBSIW7Q', 'ZGY1YZ7P', 'ZH6LR5MO', 'ZIGUIE0J', 'ZIJRW95G', 'ZK6YBV02', 'ZLSXM0KN', 'ZMCRIYYJ', 'ZMEZU4BS', 'ZMUIMBDX', 'ZOI7FJEN', 'ZQ5A6IY9', 'ZQNGGY33', 'ZSHS4VJZ', 'ZT1IP3T6', 'ZU6860XU', 'ZU6TVFFU', 'ZU75P59K', 'ZUI6TDWV', 'ZWFD8OHC', 'ZX06ZDZN', 'ZZJVE4HO'])
print("All target values (these are the lab_ids), including None: ")
print(all_targets)
print("\n")
print("Number of different target values (lab_ids): ", len(all_targets))

All target values (these are the lab_ids), including None: 
[None, '00Q4V31T', '012VT4JK', '028IO5W2', '03GRNN7N', '03Y3W51H', '09MQV1TY', '0A4AHRCT', '0A9M05NC', '0B9GCUVV', '0CL7QVG8', '0CML4B5I', '0DTHTJLJ', '0FFBBVE1', '0HWCWFNU', '0L3Y6ZB2', '0M44GDO8', '0MDYJM3H', '0N3V9P9M', '0NP55E93', '0PJ91ZT6', '0R296F9R', '0T2AZBD6', '0URA80CN', '0VRP2DI6', '0W6O08VX', '0WHP4PPK', '0XPTGGLP', '0XS4FHP3', '0Y24J5G2', '10TEBWK2', '11TTDKTM', '131RRHBV', '13LZE1F7', '14PBN8C2', '15D0Z97U', '15S88O4Q', '18C9J8EH', '19CAUKJB', '1AP294AT', '1B9BJ2IP', '1BE35FI1', '1CIHYCE4', '1DJ9L58E', '1DTDCRUO', '1EDZ6CA7', '1HCQTAYT', '1HK4VXP8', '1IXFZ3HO', '1K11RCST', '1KC6XYO6', '1KNFJ6KQ', '1KZHNVYR', '1LBGAU5Z', '1NXRMDN6', '1OQJ21E9', '1OWZDF82', '1PA232PA', '1PIGWQFY', '1Q1IUY3G', '1S515B69', '1TC200QC', '1TI4HS4X', '1UOA7CA1', '1UREJUSJ', '1UU0CHTK', '1VPOX8VI', '1VQS4WNS', '1X0VC0O1', '1XU60MET', '1ZC8RPN1', '20ABQYHS', '20CEB9KE', '216DWMG6', '21ZFBX5E', '24SL2992', '25UVYUID', '26KK8UM5', '27OS3BTP

#### Target dictionary

In [6]:
targets_dict = {t: ix for ix, t in enumerate(all_targets)}
print("targets_dict: ")
print(targets_dict)

targets_dict: 
{None: 0, '00Q4V31T': 1, '012VT4JK': 2, '028IO5W2': 3, '03GRNN7N': 4, '03Y3W51H': 5, '09MQV1TY': 6, '0A4AHRCT': 7, '0A9M05NC': 8, '0B9GCUVV': 9, '0CL7QVG8': 10, '0CML4B5I': 11, '0DTHTJLJ': 12, '0FFBBVE1': 13, '0HWCWFNU': 14, '0L3Y6ZB2': 15, '0M44GDO8': 16, '0MDYJM3H': 17, '0N3V9P9M': 18, '0NP55E93': 19, '0PJ91ZT6': 20, '0R296F9R': 21, '0T2AZBD6': 22, '0URA80CN': 23, '0VRP2DI6': 24, '0W6O08VX': 25, '0WHP4PPK': 26, '0XPTGGLP': 27, '0XS4FHP3': 28, '0Y24J5G2': 29, '10TEBWK2': 30, '11TTDKTM': 31, '131RRHBV': 32, '13LZE1F7': 33, '14PBN8C2': 34, '15D0Z97U': 35, '15S88O4Q': 36, '18C9J8EH': 37, '19CAUKJB': 38, '1AP294AT': 39, '1B9BJ2IP': 40, '1BE35FI1': 41, '1CIHYCE4': 42, '1DJ9L58E': 43, '1DTDCRUO': 44, '1EDZ6CA7': 45, '1HCQTAYT': 46, '1HK4VXP8': 47, '1IXFZ3HO': 48, '1K11RCST': 49, '1KC6XYO6': 50, '1KNFJ6KQ': 51, '1KZHNVYR': 52, '1LBGAU5Z': 53, '1NXRMDN6': 54, '1OQJ21E9': 55, '1OWZDF82': 56, '1PA232PA': 57, '1PIGWQFY': 58, '1Q1IUY3G': 59, '1S515B69': 60, '1TC200QC': 61, '1TI4HS4

#### reverse target dictionary

In [7]:
reverse_targets_dict = {ix: t for ix, t in enumerate(all_targets)}
print("reverse_targets_dict: ")
print(reverse_targets_dict)

reverse_targets_dict: 
{0: None, 1: '00Q4V31T', 2: '012VT4JK', 3: '028IO5W2', 4: '03GRNN7N', 5: '03Y3W51H', 6: '09MQV1TY', 7: '0A4AHRCT', 8: '0A9M05NC', 9: '0B9GCUVV', 10: '0CL7QVG8', 11: '0CML4B5I', 12: '0DTHTJLJ', 13: '0FFBBVE1', 14: '0HWCWFNU', 15: '0L3Y6ZB2', 16: '0M44GDO8', 17: '0MDYJM3H', 18: '0N3V9P9M', 19: '0NP55E93', 20: '0PJ91ZT6', 21: '0R296F9R', 22: '0T2AZBD6', 23: '0URA80CN', 24: '0VRP2DI6', 25: '0W6O08VX', 26: '0WHP4PPK', 27: '0XPTGGLP', 28: '0XS4FHP3', 29: '0Y24J5G2', 30: '10TEBWK2', 31: '11TTDKTM', 32: '131RRHBV', 33: '13LZE1F7', 34: '14PBN8C2', 35: '15D0Z97U', 36: '15S88O4Q', 37: '18C9J8EH', 38: '19CAUKJB', 39: '1AP294AT', 40: '1B9BJ2IP', 41: '1BE35FI1', 42: '1CIHYCE4', 43: '1DJ9L58E', 44: '1DTDCRUO', 45: '1EDZ6CA7', 46: '1HCQTAYT', 47: '1HK4VXP8', 48: '1IXFZ3HO', 49: '1K11RCST', 50: '1KC6XYO6', 51: '1KNFJ6KQ', 52: '1KZHNVYR', 53: '1LBGAU5Z', 54: '1NXRMDN6', 55: '1OQJ21E9', 56: '1OWZDF82', 57: '1PA232PA', 58: '1PIGWQFY', 59: '1Q1IUY3G', 60: '1S515B69', 61: '1TC200QC', 

## Writing the pipeline that generates predictions

Model directories are of the form `.../pca_engineered_datasets/pca32_95comps/train_splits/split_xx/random_forest`

#### Generate directories 

In [8]:
def generate_dir(directory,delete_dir=True):
    if not os.path.isdir(directory):
        print(f"Creating directory {directory}")
        os.makedirs(directory, exist_ok=True)
    elif delete_dir:
        print(f"Directory {directory} already exists. Deleting an recreating.")
        shutil.rmtree(directory)
        os.makedirs(directory, exist_ok=True)
    else:
        print(f"Directory {directory} already exists. I will either overwrite or add files to it.")

#### Prediction functions

In [26]:
def predict_single_df(df,model,features,reverse_targets_dict,all_lab_ids=all_targets[1:],average=True):
    ### features
    x = df.loc[:,features].values
    ### predictions
    y_pred = model.predict_proba(x)
    ### classes (lab_ids)
    in_lab_ids = [reverse_targets_dict[c] for c in model.classes_]
    out_lab_ids = [lab_id for lab_id in all_lab_ids if lab_id not in in_lab_ids] 
    ### probabilities dict
    probs_dict = {"sequence_id":df.sequence_id.values.tolist()}
    for lab_id in all_lab_ids:
        if lab_id in in_lab_ids:
            probs_dict[lab_id] = y_pred[:,in_lab_ids.index(lab_id)]#.tolist()
        else:
            probs_dict[lab_id] = y_pred[:,in_lab_ids.index(None)]/len(out_lab_ids)#.tolist()
    probs_dict = pd.DataFrame(probs_dict)
    ### ordering columns
    probs_dict = probs_dict.loc[:,["sequence_id"]+all_lab_ids]
    if average:
        probs_dict = probs_dict.groupby("sequence_id",as_index=False).mean() 
    return probs_dict

def predict_seq_id_dir(seq_id_dir,model,features,reverse_targets_dict,all_lab_ids=all_targets[1:],verbose=False):
    seq_id_files = sorted(os.listdir(seq_id_dir))
    if verbose:
        pbar = tqdm(seq_id_files)
    else:
        pbar = seq_id_files
    concat_dfs = []
    for f in pbar:
        if verbose:
            pbar.set_description(f"Processing sequence {f}")
        df_seq = pd.read_csv(os.path.join(seq_id_dir,f),index_col=0)
        df_pred = predict_single_df(df_seq,model,features,reverse_targets_dict,all_lab_ids,average=True)
        concat_dfs.append(df_pred)
    df_pred = pd.concat(concat_dfs)
    return df_pred

def predict_single_split(seq_id_dir,models_dir,features,reverse_targets_dict,n_models_samples=None,
                         all_lab_ids=all_targets[1:],
                  savedir=None,delete_dir=False,verbose=False):
    if savedir is not None:
        generate_dir(savedir,delete_dir)
    model_joblibs = os.listdir(models_dir)
    if n_models_samples is not None:
        model_joblibs = np.random.choice(model_joblibs,n_models_samples)
    if verbose:
        pbar = tqdm(model_joblibs)
    else:
        pbar = model_joblibs
    for mj in pbar:
        if verbose:
            pbar.set_description(f"Predicting with {mj}")
        model = joblib.load(os.path.join(models_dir,mj))
        df_pred = predict_seq_id_dir(seq_id_dir,model,features,reverse_targets_dict,
                                     all_lab_ids,verbose=False)
        savepath = os.path.join(savedir,mj.split(".")[0]+".csv")
        df_pred.to_csv(savepath,index=False)

def predict_several_splits(seq_id_dir,splits_dir,model_name,features,reverse_targets_dict,
                           n_models_samples=None,all_lab_ids=all_targets[1:],splits_range=None,
                           random_splits=False,
                  savedir=None,delete_dir=False,verbose=False):
    splits = sorted(os.listdir(splits_dir))
    if splits_range is not None:
        start,end = splits_range 
        splits = splits[start:end]
    elif random_splits:
        splits = np.random.choice(random_splits)
    if verbose:
        pbar = tqdm(splits)
    else:
        pbar = splits
    for s in pbar:
        if verbose:
            pbar.set_description(f"Processing split {s}")
        models_dir = os.path.join(splits_dir,s,model_name)
        split_savedir = os.path.join(savedir,model_name,"splits",s)
        predict_single_split(seq_id_dir,models_dir,features,reverse_targets_dict,n_models_samples,all_lab_ids,
                  split_savedir,delete_dir,verbose=True)
    
        
            
        
    

    

## Predicting

#### Predicting on test set

In [27]:
%%time
np.random.seed(6533)
seq_id_dir = test_dir
splits_dir = train_splits_dir
model_name = "random_forest"
n_models_samples = 3 #THIS SHOULD BE NONE!!!!!!!
all_lab_ids=all_targets[1:]
splits_range=(0,1)
random_splits=False
savedir = os.path.join(pca_32_95comp_dir,"predictions","test")
delete_dir = False
verbose = True
predict_several_splits(seq_id_dir,splits_dir,model_name,features,reverse_targets_dict,n_models_samples,
                           all_lab_ids,splits_range,random_splits,savedir,delete_dir,verbose)

Processing split split_00:   0%|          | 0/1 [00:00<?, ?it/s]

Creating directory /home/rio/data_sets/genetic_engineering_attribution/pca_engineered_datasets/pca_32_95comp/predictions/test/random_forest/splits/split_00


Processing split split_00: 100%|██████████| 1/1 [31:38<00:00, 1898.51s/it]

CPU times: user 31min 34s, sys: 3.73 s, total: 31min 38s
Wall time: 31min 38s





## Ensembling predictions

In [46]:
#def mean_prediction(preds):
all_lab_ids=all_targets[1:]
def ensemble_single_split_predictions(split_pred_dir,all_lab_ids,decimals=4,savedir=None):
    preds = sorted(os.listdir(split_pred_dir))
    pbar = tqdm(preds)
    df_mean = None
    N = len(preds)
    for p in pbar:
        pbar.set_description(f"Loading file {p}")
        df_pred = pd.read_csv(os.path.join(split_pred_dir,p),index_col="sequence_id")
        if df_mean is None:
            df_mean = 1/N*df_pred
        else:
            df_mean += 1/N*df_pred
    df_mean.reset_index(inplace=True)
    if decimals is not None:
        df_mean.loc[:,all_lab_ids] = np.round(df_mean.loc[:,all_lab_ids].values,decimals=decimals)
        
    if savedir is not None:
        generate_dir(savedir,delete_dir=False)
        savepath = os.path.join(savedir,split_pred_dir.split("/")[-1]+".csv")
        df_mean.to_csv(savepath,index=False)
    return df_mean
    

#### Ensembling



In [50]:
%%time
split_pred_dir = os.path.join(pca_32_95comp_dir,"predictions","test","random_forest","splits","split_00")
decimals=8
savedir = os.path.join(pca_32_95comp_dir,"submissions","random_forest","splits")
ens_df = ensemble_single_split_predictions(split_pred_dir,all_lab_ids,decimals,savedir)

Loading file GWP6E8FA_LXOZJ3TV_L76WWQ74_A3FZPLM1_9HRDSOST_G81LO0AZ_D10S0UDQ_RKJHZGDQ_50NBGIOB.csv: 100%|██████████| 3/3 [00:10<00:00,  3.47s/it]


Directory /home/rio/data_sets/genetic_engineering_attribution/pca_engineered_datasets/pca_32_95comp/submissions/random_forest/splits already exists. I will either overwrite or add files to it.
CPU times: user 29.4 s, sys: 789 ms, total: 30.2 s
Wall time: 31.9 s


#### Loading submission

In [42]:
%%time
submission_path = os.path.join(pca_32_95comp_dir,"submissions","random_forest","splits","split_00.csv")
ens_df = pd.read_csv(submission_path)
ens_df.shape

CPU times: user 3.39 s, sys: 52 ms, total: 3.44 s
Wall time: 3.44 s


(18816, 1315)

In [43]:
ens_df

Unnamed: 0,sequence_id,00Q4V31T,012VT4JK,028IO5W2,03GRNN7N,03Y3W51H,09MQV1TY,0A4AHRCT,0A9M05NC,0B9GCUVV,...,ZQNGGY33,ZSHS4VJZ,ZT1IP3T6,ZU6860XU,ZU6TVFFU,ZU75P59K,ZUI6TDWV,ZWFD8OHC,ZX06ZDZN,ZZJVE4HO
0,002IS,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,...,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000392,0.002058,0.000540
1,004JQ,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,...,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.002440,0.000606,0.000596
2,007Q0,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,...,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.007291,0.000458,0.000634
3,00BHU,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,...,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000687,0.000287,0.000425
4,00BPP,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,...,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.001227,0.000393,0.000547
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18811,ZZIKO,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,...,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000309,0.001142,0.000462
18812,ZZNXV,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,...,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.003875,0.000208,0.000291
18813,ZZPM5,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,...,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.031832,0.000665,0.000296
18814,ZZPZ5,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,...,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.037494,0.000328,0.000279


In [36]:
ens_df.shape

(18816, 1315)

In [54]:
test_path = os.path.join(parent_dir,"original_data","test_values.csv" )
df_test = pd.read_csv(test_path)

In [55]:
df_test

Unnamed: 0,sequence_id,sequence,bacterial_resistance_ampicillin,bacterial_resistance_chloramphenicol,bacterial_resistance_kanamycin,bacterial_resistance_other,bacterial_resistance_spectinomycin,copy_number_high_copy,copy_number_low_copy,copy_number_unknown,...,species_budding_yeast,species_fly,species_human,species_mouse,species_mustard_weed,species_nematode,species_other,species_rat,species_synthetic,species_zebrafish
0,E0VFT,AGATCTATACATTGAATCAATATTGGCAATTAGCCATATTAGTCAT...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
1,TTRK5,GCGCGCGTTGACATTGATTATTGACTAGTTATTAATAGTAATCAAT...,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,2Z7FZ,GCTTAAGCGGTCGACGGATCGGGAGATCTCCCGATCCCCTATGGTG...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,VJI6E,ATGATGATGATGTCCCTGAACAGCAAGCAGGCGTTTAGCATGCCGC...,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,721FI,GGTACCGAGCTCTTACGCGTGCTAGCCATACTATCAGCCACTTGTG...,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18811,4GGKP,CCCGGGGTTATTAATAGTAATCAATTACGGGGTCATTAGTTCATAG...,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
18812,37SHJ,CGAAAAGCCCTGACAACCCTTGTTCCTAAAAAGGAATAAGCGTTCG...,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
18813,JS1MB,GAGCGGCCGCCACTGTGCTGGATATCTGCAGAATTCCACCACACTG...,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
18814,N642G,GCTTTNCTCCGGTGTCACTCCCAGGTCCAACTGCACCTCGGTTCTA...,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [65]:
not_in = []
for seq in tqdm(ens_df.sequence_id):
    if seq not in df_test.sequence_id.values:
        not_in.append(seq)
not_in

100%|██████████| 18816/18816 [00:04<00:00, 4192.75it/s]


['4299',
 '9436',
 '0.0',
 '2.1e+37',
 '389000.0',
 '4560000.0',
 '48800000.0',
 '5.0000000000000004e+44',
 '8.000000000000004e+201']

In [66]:
not_in = []
for seq in tqdm(df_test.sequence_id):
    if seq not in ens_df.sequence_id.values:
        not_in.append(seq)
not_in

100%|██████████| 18816/18816 [00:04<00:00, 4152.44it/s]


['8E201',
 '21E36',
 '0E013',
 '488E5',
 '09436',
 '5E044',
 '456E4',
 '04299',
 '389E3']

In [69]:
t = pd.read_csv(os.path.join(train_val_test_sequence_id_dir,"test","8E201.csv"))

In [72]:
t.info(verbose=True,null_counts=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 139 columns):
 #   Column                                Non-Null Count  Dtype  
---  ------                                --------------  -----  
 0   Unnamed: 0                            100 non-null    int64  
 1   sequence_id                           100 non-null    float64
 2   lab_id                                0 non-null      float64
 3   sequence                              100 non-null    object 
 4   seq_length                            100 non-null    int64  
 5   pca_0                                 100 non-null    float64
 6   pca_1                                 100 non-null    float64
 7   pca_2                                 100 non-null    float64
 8   pca_3                                 100 non-null    float64
 9   pca_4                                 100 non-null    float64
 10  pca_5                                 100 non-null    float64
 11  pca_6              

In [63]:
len(ens_df.sequence_id.unique())

18816

In [62]:
np.sum(ens_df.sequence_id == '8E201')

0

In [37]:
ens_df

Unnamed: 0,sequence_id,00Q4V31T,012VT4JK,028IO5W2,03GRNN7N,03Y3W51H,09MQV1TY,0A4AHRCT,0A9M05NC,0B9GCUVV,...,ZQNGGY33,ZSHS4VJZ,ZT1IP3T6,ZU6860XU,ZU6TVFFU,ZU75P59K,ZUI6TDWV,ZWFD8OHC,ZX06ZDZN,ZZJVE4HO
0,002IS,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,...,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000540,0.000392,0.002058,0.000540
1,004JQ,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,...,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.000596,0.002440,0.000606,0.000596
2,007Q0,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,...,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.000634,0.007291,0.000458,0.000634
3,00BHU,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,...,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000425,0.000687,0.000287,0.000425
4,00BPP,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,...,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.000547,0.001227,0.000393,0.000547
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18811,ZZIKO,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,...,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000462,0.000309,0.001142,0.000462
18812,ZZNXV,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,...,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.000291,0.003875,0.000208,0.000291
18813,ZZPM5,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,...,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.000296,0.031832,0.000665,0.000296
18814,ZZPZ5,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,...,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.000279,0.037494,0.000328,0.000279


In [44]:
ens_df.max(axis=1)

0        0.100216
1        0.119106
2        0.063625
3        0.162366
4        0.079060
           ...   
18811    0.137209
18812    0.196369
18813    0.159507
18814    0.156420
18815    0.038794
Length: 18816, dtype: float64

In [21]:
d = os.path.join(pca_32_95comp_dir,"predictions","test","random_forest","split_00")
d

'/home/rio/data_sets/genetic_engineering_attribution/pca_engineered_datasets/pca_32_95comp/predictions/test/random_forest/split_00'

In [22]:
d.split("/")[-1]

'split_00'

In [12]:
predictions_dir

NameError: name 'predictions_dir' is not defined