In [1]:
import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn.metrics import log_loss
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn import model_selection

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [130]:
train_features = pd.read_csv('/content/drive/My Drive/moa/train_features.csv')
train_targets_scored = pd.read_csv('/content/drive/My Drive/moa/train_targets_scored.csv')
sample_submission = pd.read_csv('/content/drive/My Drive/moa/sample_submission.csv')
test_features = pd.read_csv('/content/drive/My Drive/moa/test_features.csv')

In [131]:
GENES = [col for col in train_features.columns if col.startswith('g-')]
CELLS = [col for col in train_features.columns if col.startswith('c-')]

In [132]:
train_features["cp_type"].value_counts()
# Can we drop cp_type column? ctl_vehicle is 8% from total.

trt_cp         21948
ctl_vehicle     1866
Name: cp_type, dtype: int64

In [133]:
train = train_features.merge(train_targets_scored, on='sig_id')
test = test_features
# If we choose to drop train_features[train['cp_type']=='ctl_vehicle'], uncomment.
# train = train[train['cp_type']!='ctl_vehicle'].reset_index(drop=True)
# test = test_features[test_features['cp_type']!='ctl_vehicle'].reset_index(drop=True)

target = train[train_targets_scored.columns]
train = train.drop('cp_type', axis=1)               # train["cp_type"].unique() = 'trt_cp'. We cant pass cp_type without encode.
train = train.drop('sig_id', axis=1)

# target                      # 23814 rows × 207 columns. # Its actually the same as train_targets_scored, if we didnt preprocess anythig.

In [134]:
class MoADataset:
    def __init__(self, features, targets):
        self.features = features
        self.targets = targets
        
    def __len__(self):              # len() will use the __len__ method if present to get your object for its length.  
        return (self.features.shape[0])
    
    def __getitem__(self, idx):     # docs: https://docs.python.org/3/reference/datamodel.html#object.__getitem__. In this case returns a dict.
        dct = { 
            'x' : torch.tensor(self.features[idx, :], dtype=torch.float),       # ex: np_array[0, :] -> [1,2]
            'y' : torch.tensor(self.targets[idx, :], dtype=torch.float)            
        }
        return dct
    
class TestDataset:
    def __init__(self, features):
        self.features = features
        
    def __len__(self):
        return (self.features.shape[0])
    
    def __getitem__(self, idx):
        dct = {
            'x' : torch.tensor(self.features[idx, :], dtype=torch.float)
        }
        return dct
    

In [135]:
def train_fn(model, optimizer, scheduler, loss_fn, dataloader, device):
    model.train()
    final_loss = 0
    
    for data in dataloader:
        optimizer.zero_grad()
        inputs, targets = data['x'].to(device), data['y'].to(device)        # Asks for the value of "x" and "y" keys.
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()
        
        final_loss += loss.item()
        
    final_loss /= len(dataloader)
    
    return final_loss   

def valid_fn(model, loss_fn, dataloader, device):
    model.eval()
    final_loss = 0
    valid_preds = []
    
    for data in dataloader:
        inputs, targets = data['x'].to(device), data['y'].to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        
        final_loss += loss.item()
        valid_preds.append(outputs.sigmoid().detach().cpu().numpy())
        
    final_loss /= len(dataloader)
    valid_preds = np.concatenate(valid_preds)
    
    return final_loss, valid_preds

def inference_fn(model, dataloader, device):
    model.eval()
    preds = []
    
    for data in dataloader:
        inputs = data['x'].to(device)

        with torch.no_grad():
            outputs = model(inputs)
        
        preds.append(outputs.sigmoid().detach().cpu().numpy())
        
    preds = np.concatenate(preds)
    
    return preds

In [136]:
# process_data(data) uses get_dummies() to create cp_time: 24, 48, 72. cp_dose: D1, D2
def process_data(data):
    data = pd.get_dummies(data, columns=['cp_time','cp_dose'])              
   
    return data

In [137]:
# Simply target without id.
target_cols = target.drop('sig_id', axis=1).columns.tolist()
# We use this comprehension to take into account the dummies created by our process_data()
feature_cols = [col for col in process_data(train).columns if col not in target_cols]

In [138]:
# HyperParameters
DEVICE = ('cuda' if torch.cuda.is_available() else 'cpu')
EPOCHS = 2
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5
NFOLDS = 5
EARLY_STOPPING_STEPS = 10
EARLY_STOP = False

num_features=len(feature_cols)
num_targets=len(target_cols)
hidden_size=1024

In [139]:
class Model(nn.Module):
    def __init__(self, num_features, num_targets, hidden_size):
        super(Model, self).__init__()
        self.batch_norm1 = nn.BatchNorm1d(num_features)
        self.dropout1 = nn.Dropout(0.2)
        self.dense1 = nn.utils.weight_norm(nn.Linear(num_features, hidden_size))
        
        self.batch_norm2 = nn.BatchNorm1d(hidden_size)
        self.dropout2 = nn.Dropout(0.5)
        self.dense2 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        
        self.batch_norm3 = nn.BatchNorm1d(hidden_size)
        self.dropout3 = nn.Dropout(0.5)
        self.dense3 = nn.utils.weight_norm(nn.Linear(hidden_size, num_targets))
    
    def forward(self, x):
        x = self.batch_norm1(x)
        x = self.dropout1(x)
        x = F.relu(self.dense1(x))
        
        x = self.batch_norm2(x)
        x = self.dropout2(x)
        x = F.relu(self.dense2(x))
        
        x = self.batch_norm3(x)
        x = self.dropout3(x)
        x = self.dense3(x)
        
        return x

In [140]:
folds = train.copy()
folds = folds.sample(frac=1).reset_index(drop=True)             # to randomize
kf = model_selection.KFold(n_splits=5)
for fold, (t_idx, v_idx) in enumerate(kf.split(X=folds)):
    folds.loc[v_idx, 'kfold'] = fold
folds['kfold'] = folds['kfold'].astype(int)           # Otherwise 0.0, 1.0, 2.0, 3.0...

train = process_data(folds)

test_df = process_data(test)
x_test  = test_df[feature_cols].values
test_dataset = TestDataset(x_test)
testloader = torch.utils.data.DataLoader(test_dataset)

predictions = np.zeros((3982,206))

In [141]:
def run_training(fold):
    trn_idx = train[train['kfold'] != fold].index
    val_idx = train[train['kfold'] == fold].index

    train_df = train[train['kfold'] != fold].reset_index(drop=True)
    valid_df = train[train['kfold'] == fold].reset_index(drop=True)

    x_train, y_train  = train_df[feature_cols].values, train_df[target_cols].values
    x_valid, y_valid =  valid_df[feature_cols].values, valid_df[target_cols].values

    train_dataset = MoADataset(x_train, y_train)
    valid_dataset = MoADataset(x_valid, y_valid)

    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)      # It has a len of 187. 187 * 128(BATCH_SIZE) = 23936. Contains the 23814 rows of the train_df.
    validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = Model(
        num_features=num_features,
        num_targets=num_targets,
        hidden_size=hidden_size,
    )

    model.to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=1e3,
                                                max_lr=1e-2, epochs=EPOCHS, steps_per_epoch=len(trainloader))
    loss_fn = nn.BCEWithLogitsLoss()

    # oof = np.zeros((len(train), target.iloc[:, 1:].shape[1]))
    best_loss = np.inf                  # Represents a positive infinite

    for epoch in range(EPOCHS):
        train_loss = train_fn(model, optimizer, scheduler, loss_fn, trainloader, DEVICE)
        print(f"EPOCH: {epoch}, train_loss: {train_loss}")
        valid_loss, valid_preds = valid_fn(model, loss_fn, validloader, DEVICE)
        print(f"EPOCH: {epoch}, valid_loss: {valid_loss}")
        
        if valid_loss < best_loss:
            print(f"updating best model on Fold={fold}") 
            best_loss = valid_loss
            # oof[val_idx] = valid_preds
            torch.save(model.state_dict(), f"FOLD{fold}_.pth")
    
    fold_preds = inference_fn(model, testloader, DEVICE)
    global predictions
    predictions = predictions + fold_preds
    print(predictions)    
      
for run_k_fold in range(5):              # 5 folds
    run_training(run_k_fold)

predictions /= 5

EPOCH: 0, train_loss: 0.1403922475811919
EPOCH: 0, valid_loss: 0.01940150406995886
updating best model on Fold=0
EPOCH: 1, train_loss: 0.018808677955061796
EPOCH: 1, valid_loss: 0.018533614660172087
updating best model on Fold=0
[[0.00140438 0.00181455 0.00177666 ... 0.00169914 0.00157534 0.00210482]
 [0.0010858  0.00136895 0.00182723 ... 0.00149622 0.00329186 0.00186699]
 [0.00087761 0.00052244 0.00103695 ... 0.00087652 0.00130898 0.00114798]
 ...
 [0.00114555 0.00100715 0.00110471 ... 0.00159549 0.00307395 0.00131794]
 [0.00126276 0.00106065 0.00151006 ... 0.00107573 0.00116278 0.00211597]
 [0.00092951 0.00114269 0.00131762 ... 0.00140431 0.00208056 0.00141018]]
EPOCH: 0, train_loss: 0.140029293641248
EPOCH: 0, valid_loss: 0.01926086048938726
updating best model on Fold=1
EPOCH: 1, train_loss: 0.018792859334873672
EPOCH: 1, valid_loss: 0.018364380187305965
updating best model on Fold=1
[[0.00268737 0.00354354 0.00328327 ... 0.00330219 0.00365534 0.00412797]
 [0.00225714 0.00313579 0.

In [142]:
sample_submission = pd.read_csv('/content/drive/My Drive/moa/sample_submission.csv')

In [143]:
y = pd.DataFrame(data=predictions)
y.columns = target_cols

In [144]:
sub = sample_submission.drop(columns=target_cols)
frames = [sub, y]
sub = pd.concat(frames, axis=1)

In [145]:
submission = sub.to_csv("submission.csv",index=False)
teste = pd.read_csv("submission.csv")
teste

Unnamed: 0,sig_id,5-alpha_reductase_inhibitor,11-beta-hsd1_inhibitor,acat_inhibitor,acetylcholine_receptor_agonist,acetylcholine_receptor_antagonist,acetylcholinesterase_inhibitor,adenosine_receptor_agonist,adenosine_receptor_antagonist,adenylyl_cyclase_activator,adrenergic_receptor_agonist,adrenergic_receptor_antagonist,akt_inhibitor,aldehyde_dehydrogenase_inhibitor,alk_inhibitor,ampk_activator,analgesic,androgen_receptor_agonist,androgen_receptor_antagonist,anesthetic_-_local,angiogenesis_inhibitor,angiotensin_receptor_antagonist,anti-inflammatory,antiarrhythmic,antibiotic,anticonvulsant,antifungal,antihistamine,antimalarial,antioxidant,antiprotozoal,antiviral,apoptosis_stimulant,aromatase_inhibitor,atm_kinase_inhibitor,atp-sensitive_potassium_channel_antagonist,atp_synthase_inhibitor,atpase_inhibitor,atr_kinase_inhibitor,aurora_kinase_inhibitor,...,protein_synthesis_inhibitor,protein_tyrosine_kinase_inhibitor,radiopaque_medium,raf_inhibitor,ras_gtpase_inhibitor,retinoid_receptor_agonist,retinoid_receptor_antagonist,rho_associated_kinase_inhibitor,ribonucleoside_reductase_inhibitor,rna_polymerase_inhibitor,serotonin_receptor_agonist,serotonin_receptor_antagonist,serotonin_reuptake_inhibitor,sigma_receptor_agonist,sigma_receptor_antagonist,smoothened_receptor_antagonist,sodium_channel_inhibitor,sphingosine_receptor_agonist,src_inhibitor,steroid,syk_inhibitor,tachykinin_antagonist,tgf-beta_receptor_inhibitor,thrombin_inhibitor,thymidylate_synthase_inhibitor,tlr_agonist,tlr_antagonist,tnf_inhibitor,topoisomerase_inhibitor,transient_receptor_potential_channel_antagonist,tropomyosin_receptor_kinase_inhibitor,trpv_agonist,trpv_antagonist,tubulin_inhibitor,tyrosine_kinase_inhibitor,ubiquitin_specific_protease_inhibitor,vegfr_inhibitor,vitamin_b,vitamin_d_receptor_agonist,wnt_inhibitor
0,id_0004d9e33,0.001219,0.001695,0.001457,0.010081,0.017208,0.003812,0.002288,0.005983,0.001107,0.014914,0.018806,0.003097,0.001024,0.002245,0.000989,0.000971,0.002484,0.004664,0.005051,0.001791,0.002332,0.002554,0.001106,0.001438,0.001253,0.001177,0.001121,0.001223,0.003205,0.001935,0.001566,0.002354,0.002188,0.000978,0.000921,0.001142,0.004203,0.001362,0.003603,...,0.003469,0.001257,0.003102,0.001050,0.001061,0.003556,0.000967,0.001854,0.001325,0.001232,0.011406,0.020664,0.002095,0.001805,0.001942,0.001091,0.016058,0.001321,0.002330,0.000886,0.001280,0.003147,0.001443,0.001505,0.001790,0.001681,0.000971,0.001750,0.003217,0.001405,0.001007,0.001614,0.002768,0.006160,0.003242,0.001003,0.008424,0.001496,0.001846,0.001703
1,id_001897cda,0.000974,0.001683,0.001936,0.007444,0.011880,0.005587,0.002342,0.007115,0.001501,0.014217,0.015911,0.004770,0.001484,0.002840,0.001507,0.001389,0.002333,0.004168,0.005180,0.002476,0.002219,0.002644,0.001428,0.002124,0.001523,0.001517,0.001368,0.002306,0.003861,0.002703,0.001526,0.003180,0.002271,0.001567,0.001327,0.002068,0.006619,0.002292,0.008526,...,0.004899,0.002259,0.002617,0.004366,0.001645,0.004023,0.001449,0.002131,0.003393,0.001736,0.011377,0.025690,0.002300,0.001972,0.002842,0.001558,0.013640,0.001516,0.004545,0.001303,0.001787,0.002416,0.001227,0.001773,0.001831,0.001693,0.001499,0.002640,0.006517,0.001744,0.001399,0.002510,0.002687,0.010270,0.006757,0.001415,0.014903,0.001616,0.003091,0.002267
2,id_002429b5b,0.000748,0.000547,0.001109,0.007365,0.008920,0.001894,0.002201,0.002065,0.000670,0.008350,0.012754,0.000741,0.000553,0.000850,0.000880,0.000964,0.001220,0.002591,0.001287,0.001436,0.001215,0.003253,0.000570,0.001792,0.000628,0.000685,0.000679,0.000776,0.002758,0.001446,0.000873,0.001769,0.002147,0.000619,0.000636,0.000575,0.002123,0.000569,0.000896,...,0.003135,0.001007,0.001566,0.002918,0.000720,0.001357,0.000695,0.001061,0.001364,0.001066,0.007775,0.008841,0.001686,0.001025,0.000890,0.001520,0.006072,0.001312,0.001352,0.000830,0.000646,0.001189,0.001455,0.000680,0.001097,0.001011,0.000828,0.001058,0.002133,0.000652,0.000717,0.000650,0.001562,0.004057,0.001425,0.000612,0.002203,0.001260,0.001577,0.001325
3,id_00276f245,0.000725,0.000995,0.001426,0.006804,0.008114,0.003143,0.001670,0.003749,0.001024,0.009819,0.012396,0.002346,0.000923,0.002036,0.001002,0.001094,0.001331,0.003152,0.002418,0.001591,0.001551,0.002021,0.000923,0.001487,0.000930,0.000971,0.000816,0.001433,0.002511,0.001760,0.000851,0.002063,0.001561,0.000963,0.000856,0.001250,0.004313,0.001279,0.004168,...,0.003582,0.001456,0.001830,0.004450,0.001032,0.002376,0.000956,0.001454,0.002049,0.000912,0.007084,0.015397,0.001846,0.001086,0.001601,0.001213,0.008665,0.001075,0.002888,0.000916,0.001106,0.001611,0.000874,0.001114,0.001364,0.001069,0.000997,0.001416,0.003838,0.000960,0.000862,0.001282,0.002007,0.009754,0.003840,0.000870,0.009858,0.001150,0.002006,0.001719
4,id_0027f1083,0.001435,0.000830,0.001231,0.008723,0.012680,0.002166,0.002730,0.002622,0.000801,0.011267,0.018331,0.001178,0.000749,0.001086,0.001100,0.000947,0.001870,0.003642,0.002194,0.001786,0.001647,0.004360,0.000701,0.001933,0.000853,0.000839,0.001028,0.000747,0.002982,0.001738,0.001552,0.002205,0.002786,0.000745,0.000764,0.000682,0.002515,0.000741,0.001277,...,0.003819,0.001072,0.002504,0.002190,0.000890,0.001933,0.000801,0.001661,0.001326,0.001789,0.011163,0.012132,0.002002,0.001836,0.001212,0.001747,0.010434,0.001711,0.001885,0.000918,0.000867,0.001944,0.002314,0.000888,0.001655,0.001792,0.000869,0.001632,0.002953,0.000938,0.000951,0.001030,0.001874,0.003957,0.001531,0.000775,0.002210,0.001639,0.001589,0.001398
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3977,id_ff7004b87,0.000815,0.001745,0.001424,0.006447,0.009276,0.003688,0.001267,0.005380,0.001265,0.011393,0.011630,0.007181,0.001334,0.003926,0.000869,0.000896,0.001613,0.003623,0.004110,0.001369,0.002030,0.001310,0.001218,0.001114,0.001180,0.001097,0.000974,0.001562,0.001844,0.001515,0.000999,0.002062,0.001321,0.001099,0.000883,0.002112,0.005990,0.002417,0.011279,...,0.003186,0.001195,0.002451,0.001981,0.001148,0.003723,0.000979,0.001716,0.001926,0.000789,0.007125,0.017520,0.001635,0.001273,0.001959,0.000989,0.014155,0.000902,0.004638,0.000751,0.001644,0.002647,0.000715,0.001536,0.001618,0.001215,0.000896,0.001714,0.004462,0.001286,0.000915,0.002308,0.002448,0.015452,0.006116,0.000967,0.022595,0.001141,0.001660,0.001615
3978,id_ff925dd0d,0.001803,0.001192,0.001258,0.008481,0.019224,0.003062,0.003143,0.003657,0.000868,0.015002,0.021788,0.001463,0.000801,0.001213,0.001116,0.000955,0.002959,0.004695,0.004012,0.002022,0.001790,0.004815,0.000826,0.002026,0.001046,0.001053,0.001103,0.000852,0.004080,0.002156,0.002027,0.002428,0.003161,0.000817,0.000875,0.000739,0.002997,0.000812,0.001417,...,0.004138,0.001227,0.003002,0.001513,0.000935,0.002607,0.000850,0.001641,0.001298,0.002020,0.013552,0.018315,0.002149,0.002224,0.001561,0.001432,0.013908,0.001640,0.002033,0.000937,0.000990,0.002795,0.002601,0.001112,0.001726,0.002153,0.000936,0.001768,0.003182,0.001202,0.000946,0.001201,0.002005,0.003254,0.001854,0.000875,0.002650,0.001609,0.001719,0.001400
3979,id_ffb710450,0.001092,0.000972,0.001124,0.006341,0.017261,0.004056,0.002616,0.004635,0.000806,0.013957,0.019750,0.001414,0.000791,0.001096,0.001175,0.000875,0.002324,0.003896,0.004249,0.001982,0.001353,0.003988,0.000800,0.001788,0.000946,0.000950,0.000971,0.001082,0.004056,0.002393,0.001571,0.002522,0.002697,0.000864,0.000845,0.000827,0.003673,0.000853,0.001796,...,0.003926,0.001505,0.002216,0.002331,0.000957,0.002582,0.000873,0.001501,0.001692,0.001894,0.013088,0.025715,0.001928,0.001830,0.001672,0.001193,0.012410,0.001427,0.002033,0.000940,0.000894,0.001998,0.001465,0.001051,0.001343,0.001663,0.000991,0.002182,0.003740,0.001117,0.000941,0.001306,0.001633,0.003430,0.002462,0.000859,0.003119,0.001356,0.002160,0.001406
3980,id_ffbb869f2,0.001162,0.001280,0.001781,0.009510,0.017216,0.002831,0.003254,0.002749,0.001156,0.007804,0.013928,0.001353,0.000891,0.002108,0.000864,0.001920,0.002970,0.004673,0.001757,0.001263,0.002635,0.003526,0.000964,0.002678,0.001216,0.001561,0.000778,0.001273,0.003163,0.001547,0.001195,0.001792,0.002676,0.000924,0.000940,0.000816,0.002510,0.001008,0.000825,...,0.004446,0.001279,0.002209,0.003667,0.000810,0.001733,0.000903,0.001093,0.001335,0.000891,0.009533,0.014655,0.001588,0.000882,0.001813,0.001783,0.004531,0.001243,0.002319,0.001140,0.001274,0.002868,0.001977,0.001415,0.002253,0.001261,0.001077,0.000673,0.002789,0.001086,0.000818,0.000852,0.002749,0.004810,0.003416,0.000850,0.005649,0.001608,0.001734,0.001797
