In [None]:
import pandas as pd
import torch
import torchdms
import torchdms.model
from dms_variants.binarymap import BinaryMap
from torchdms.analysis import Analysis
from torchdms.binarymap import DataFactory

%matplotlib notebook

In [None]:
# Load the functional-score table together with the wild-type sequence.
# NOTE(review): presumably this unpickles the file; unpickling can execute arbitrary
# code, so only point this at a file you trust.
pickle_path = "../_ignore/aa_func_scores_and_wtseq.pkl"
aa_func_scores, wtseq = torchdms.binarymap.from_pickle_file(pickle_path)

Let's build a held-out test set containing `per_mutation_count_variants` variants for each mutation count from 1 to `max_mutation_count`; all remaining variants are used for training.

In [None]:
# Number of variants to hold out for each mutation count, and the largest
# mutation count that is represented in the held-out set.
per_mutation_count_variants = 500
max_mutation_count = 5

# Row labels of `aa_func_scores`, grouped by the number of amino-acid substitutions.
indices_split_by_mutation_count = {
    mutation_count:
        aa_func_scores[aa_func_scores["n_aa_substitutions"] == mutation_count].index
    for mutation_count in range(1, 1 + max_mutation_count)}

# Boolean mask over the rows of `aa_func_scores`; True marks a held-out variant.
# The mask shares the frame's own index, so the label-based assignment below (and any
# later `aa_func_scores.loc[indicator]`) lines up even when that index is not 0..n-1.
indicator = pd.Series(False, index=aa_func_scores.index)
for indices in indices_split_by_mutation_count.values():
    # Take the first `per_mutation_count_variants` rows of each mutation count.
    indicator.loc[indices[:per_mutation_count_variants]] = True

# Fails when some mutation count has too few variants (or when row labels repeat).
assert indicator.sum() == per_mutation_count_variants * max_mutation_count, (
    "expected exactly per_mutation_count_variants held-out rows per mutation count")

Record the mutation count of each variant in our held-out set:

In [None]:
# Mutation count of every held-out variant, listed in the row order of `aa_func_scores`,
# which is the order a boolean selection with `indicator` returns the rows in.
# Reading the counts from the selected rows keeps each label attached to its own variant
# even when the table is not sorted by mutation count; a list laid out as blocks of
# 1s, 2s, ... only matches a table that happens to be sorted that way.
# NOTE(review): assumes downstream results keep this row order -- confirm.
indicator_classes = aa_func_scores.loc[indicator.values, "n_aa_substitutions"].tolist()

In [None]:
def bmap_split_of_indicator(indicator):
    """Split `aa_func_scores` into two BinaryMaps according to a boolean indicator.

    Returns a pair: the map built from the rows where `indicator` is true, followed by
    the map built from the rows where it is false. Both maps are built with
    `expand=True` and the notebook-level `wtseq`.
    """
    def build_map(row_mask):
        # One BinaryMap over the rows of the global table picked out by `row_mask`.
        return BinaryMap(aa_func_scores.loc[row_mask, :], expand=True, wtseq=wtseq)

    complement = [not flag for flag in indicator]
    selected_map = build_map(indicator)
    remaining_map = build_map(complement)
    return selected_map, remaining_map


# The indicated rows form the test set; every other row is used for training.
test_data, train_data = bmap_split_of_indicator(indicator)

In [None]:
# Seed torch's global random number generator so that anything drawing from it
# (presumably the model's initial weights -- confirm) is the same on every
# Restart-and-Run-All.
torch.manual_seed(0)

# Single-hidden-unit sigmoid network over the binary encoding of the training variants.
model = torchdms.model.SingleSigmoidNet(input_size=train_data.binarylength, hidden1_size=1)
analysis = Analysis(model, train_data)

In [None]:
%%time

criterion = torch.nn.MSELoss()
analysis.train(criterion, 300)
pd.Series(analysis.losses).plot()

In [None]:
# Evaluate on the held-out variants and attach the recorded mutation counts as an
# extra `n_aa_mutations` column (one entry per row of the evaluation table).
results = analysis.evaluate(test_data).assign(n_aa_mutations=indicator_classes)

In [None]:
# Observed vs. predicted values, one point per row, coloured by `n_aa_mutations`.
ax = results.plot(
    kind="scatter",
    x="Observed",
    y="Predicted",
    c=results["n_aa_mutations"],
    cmap="viridis",
)

In [None]:
# Same comparison as a hexagonal-bin density plot.
ax = results.plot(kind="hexbin", x="Observed", y="Predicted")

In [None]:
# Correlation between observed and predicted values. The two columns are selected by
# name, so the result does not depend on where they sit in `results` or on which other
# columns (such as `n_aa_mutations`) the frame carries.
results["Observed"].corr(results["Predicted"])