In [1]:
from dask.distributed import LocalCluster, Client
import numpy as np
import pandas as pd
import janitor


# Start a local Dask cluster and connect a client to it; `client` is used
# below to scatter data and submit model-fitting tasks.
cluster = LocalCluster()
client = Client(address=cluster)

In [2]:
from utils import molecular_weights, featurize_sequence_

In [3]:
drugs = ['ATV', 'DRV', 'FPV', 'IDV', 'LPV', 'NFV', 'SQV', 'TPV']

# Keep rows with weight == 1.0 whose sequence is exactly 99 residues long,
# featurize each sequence, and log10-transform every drug column.
data = (
    pd.read_csv("data/hiv-protease-data-expanded.csv", index_col=0)
    .query("weight == 1.0")
    .transform_column("sequence", len, "seq_length")
    .query("seq_length == 99")
    .transform_column("sequence", featurize_sequence_, "features")
    .transform_columns(drugs, np.log10)
)
# np.vstack on an empty column fails with an unhelpful message; fail clearly instead.
assert not data.empty, "no 99-residue sequences with weight == 1.0 in the input CSV"

# One feature row per sequence, indexed like `data` so later joins on a drug
# column line up with the right sequences.
features = pd.DataFrame(np.vstack(data['features'])).set_index(data.index)

# Last expression of the cell, so Jupyter actually renders the preview.
data.head(3)

In [4]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score


def fit_model(data, features, target, n_estimators=500, random_state=None):
    """
    Fit a random-forest regressor for one target column.

    :param data: DataFrame containing the column ``target``; it is joined to
        ``features`` on the index.
    :param features: DataFrame of feature columns, indexed like ``data``.
    :param target: name of the column in ``data`` to regress on (e.g. 'ATV').
    :param n_estimators: number of trees. Defaults to 500 so the fitted model
        is the same size as the forest scored by ``cross_validate``.
    :param random_state: forwarded to ``RandomForestRegressor``; set it for
        reproducible fits.
    :returns: the fitted ``RandomForestRegressor``.
    """
    model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)

    # Keep only rows where both the features and the target are present.
    resistance_data = features.join(data[target]).dropna()
    # Plain pandas split (equivalent to pyjanitor's get_features_targets), so
    # Dask workers don't have to import janitor to register its accessors.
    X = resistance_data.drop(columns=target)
    y = resistance_data[target]

    model.fit(X, y)
    return model


def cross_validate(data, features, target, n_estimators=500, cv=5, random_state=None):
    """
    Estimate the out-of-sample mean squared error of a random forest on one target.

    :param data: DataFrame containing the column ``target``; it is joined to
        ``features`` on the index.
    :param features: DataFrame of feature columns, indexed like ``data``.
    :param target: name of the column in ``data`` to regress on (e.g. 'ATV').
    :param n_estimators: number of trees in the evaluated forest.
    :param cv: cross-validation splitting strategy passed to ``cross_val_score``
        (an int means that many folds).
    :param random_state: forwarded to ``RandomForestRegressor`` for reproducibility.
    :returns: array with one mean squared error per fold.
    """
    model = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)

    # Keep only rows where both the features and the target are present.
    resistance_data = features.join(data[target]).dropna()
    # Plain pandas split (equivalent to pyjanitor's get_features_targets), so
    # Dask workers don't have to import janitor to register its accessors.
    X = resistance_data.drop(columns=target)
    y = resistance_data[target]

    # sklearn scorers are "higher is better", so MSE comes back negated; flip it.
    return -cross_val_score(model, X, y, scoring='neg_mean_squared_error', cv=cv)


def predict(model, sequence):
    """
    Predict with a fitted model for a single sequence.

    :param model: fitted sklearn regressor (e.g. one returned by ``fit_model``).
    :param sequence: A string, should be 99 characters long, using only
        residue letters that are keys of ``molecular_weights``.
    :returns: the output of ``model.predict`` on a one-row feature matrix.
    :raises AssertionError: if the sequence has the wrong length or contains
        letters not present in ``molecular_weights``.
    """
    assert len(sequence) == 99, f"expected 99 residues, got {len(sequence)}"
    # Subset, not equality: a valid sequence need not contain every residue letter.
    unknown = set(sequence) - set(molecular_weights)
    assert not unknown, f"unknown residues in sequence: {sorted(unknown)}"

    seqfeat = featurize_sequence_(sequence)
    # sklearn expects a 2-D (n_samples, n_features) array. In the data-loading
    # cell each sequence's features form one row of the vstacked matrix, so
    # reshape this single sequence's features into one row.
    return model.predict(np.reshape(np.asarray(seqfeat), (1, -1)))
    
    

# Send both frames to the cluster once; tasks receive the resulting futures
# rather than the frames themselves.
dataf = client.scatter(data)
featuresf = client.scatter(features)

# For each drug, submit one fitting task and one cross-validation task.
# Both dicts map drug name -> Dask future.
models, scores = {}, {}
for drug in drugs:
    task_args = (dataf, featuresf, drug)
    models[drug] = client.submit(fit_model, *task_args)
    scores[drug] = client.submit(cross_validate, *task_args)

# Wait for the fitted models; `scores` is left as futures and gathered later.
models = client.gather(models)

In [7]:
import gzip
import pickle as pkl
from pathlib import Path

# Write each fitted model to data/models/<drug>.pkl.gz. Create the directory
# first so a fresh checkout doesn't fail with FileNotFoundError.
Path("data/models").mkdir(parents=True, exist_ok=True)

for name, model in models.items():
    with gzip.open(f"data/models/{name}.pkl.gz", 'wb') as f:
        pkl.dump(model, f)

In [8]:
# Resolve the cross-validation futures (per-drug arrays of fold MSEs) and
# persist them to data/scores.pkl.gz.
scores = client.gather(scores)
with gzip.open("data/scores.pkl.gz", mode="wb") as out_file:
    pkl.dump(scores, out_file)