In [None]:
import pandas as pd
import torch
import torchdms
import torchdms.model
from dms_variants.binarymap import BinaryMap
from torchdms.analysis import Analysis

%matplotlib notebook

In [None]:
# Pickle holding the functional-score table and the wildtype sequence.
# Only one dataset is loaded: previously "../_ignore/aa_func_scores_and_wtseq.pkl"
# was read first and then immediately overwritten by this one, so that load was
# wasted work. Point data_path at that file instead to analyze the other dataset.
# NOTE(review): presumably from_pickle_file unpickles the file -- only load
# pickles from a trusted source.
data_path = "../_ignore/gb1.pkl"
[aa_func_scores, wtseq] = torchdms.binarymap.from_pickle_file(data_path)

# Number of substitutions per variant: "aa_substitutions" is split on
# whitespace and the resulting tokens are counted (an empty string gives 0).
aa_func_scores["n_aa_substitutions"] = [len(s.split()) for s in aa_func_scores["aa_substitutions"]]

Let's hold out a test set by randomly sampling `per_mutation_count_variants` variants for each mutation count up to `mutation_count_limit`; all remaining variants are used for training.

In [None]:
# Size of the held-out test set: this many variants are drawn for every
# mutation count present in the data, up to and including mutation_count_limit.
per_mutation_count_variants = 250
mutation_count_limit = 5
# Fixed seed so the train/test split is the same after Restart & Run All.
split_seed = 0

# Reset the flag first so that re-running this cell starts from a clean split.
aa_func_scores["in_test"] = False

sampled_mutation_counts = []
# groupby sorts its keys by default, so groups arrive in increasing mutation
# count and we can stop at the first one past the limit.
for mutation_count, grouped in aa_func_scores.groupby("n_aa_substitutions"):
    if mutation_count > mutation_count_limit:
        break
    # sample() raises ValueError if a group has fewer rows than requested.
    to_put_in_test = grouped.sample(
        n=per_mutation_count_variants, random_state=split_seed
    ).index
    aa_func_scores.loc[to_put_in_test, "in_test"] = True
    sampled_mutation_counts.append(mutation_count)

# Largest mutation count that can appear in the test set (kept for later cells).
max_mutation_count = min(mutation_count_limit, max(aa_func_scores["n_aa_substitutions"]))
# Check against the number of groups actually sampled, so the check does not
# depend on mutation counts starting at 1 or being contiguous.
assert aa_func_scores["in_test"].sum() == per_mutation_count_variants * len(sampled_mutation_counts)

Build one `BinaryMap` for the held-out test set and one for the training set:

In [None]:
def bmap_split(in_test):
    """Build a BinaryMap from one side of the train/test split.

    Selects the rows of the notebook-level ``aa_func_scores`` frame whose
    "in_test" flag equals ``in_test`` and passes them to ``BinaryMap``
    together with the notebook-level ``wtseq``.

    Args:
        in_test: True to select the held-out rows, False for the rest.

    Returns:
        The BinaryMap built from the selected rows.
    """
    row_mask = aa_func_scores["in_test"] == in_test
    selected_variants = aa_func_scores.loc[row_mask]
    return BinaryMap(selected_variants, expand=True, wtseq=wtseq)

# Held-out variants go into test_data; everything else is used for training.
test_data = bmap_split(in_test=True)
train_data = bmap_split(in_test=False)

# Every row flagged as held out must be represented in the test BinaryMap.
assert aa_func_scores["in_test"].sum() == test_data.nvariants

In [None]:
# Seed torch's global RNG before the model is built so that repeated runs of
# the notebook give the same result.
# NOTE(review): assumes torchdms draws its randomness (weight initialization,
# any shuffling during training) from torch's default generator -- confirm.
torch.manual_seed(0)

# One input per column of the binary variant encoding and a single hidden unit.
model = torchdms.model.SingleSigmoidNet(input_size=train_data.binarylength, hidden1_size=1)
# The analysis object is built from the model and the training data only; the
# held-out data is passed in later, at evaluation time.
analysis = Analysis(model, train_data)

In [None]:
%%time

# Mean-squared-error loss used as the training criterion.
criterion = torch.nn.MSELoss()

# NOTE(review): 300 is passed positionally as the second argument of train();
# presumably the number of training epochs -- confirm against torchdms.
analysis.train(criterion, 300)

# Plot the losses recorded on the analysis object during training.
loss_history = pd.Series(analysis.losses)
loss_history.plot()

In [None]:
results = analysis.evaluate(test_data)

# Attach the mutation count of each held-out variant by position rather than
# by index label. This does not depend on what index `results` carries, and a
# length mismatch raises an error instead of silently filling in NaN.
# NOTE(review): assumes the rows of `results` are in the same order as the
# held-out rows of aa_func_scores -- confirm against Analysis.evaluate.
results["n_aa_substitutions"] = (
    aa_func_scores.loc[aa_func_scores["in_test"], "n_aa_substitutions"].to_numpy()
)

# Observed vs. predicted score, colored by the number of substitutions.
results.plot.scatter(x="Observed", y="Predicted", c=results["n_aa_substitutions"], cmap='viridis')

In [None]:
# Pearson correlation between observed and predicted scores. The columns are
# named explicitly so the value does not depend on column order in `results`
# and no correlations with unrelated columns are computed.
results["Observed"].corr(results["Predicted"])