In [None]:
import pandas as pd
import torch
import torchdms
from torchdms.binarymap import DataFactory
from torchdms.model import SingleSigmoidNet

%matplotlib notebook

In [None]:
# Load the pre-built binary maps. The container is indexed by integer position
# further down; that position presumably is the number of mutations per variant.
# NOTE(review): if from_pickle_file unpickles this path, only point it at trusted files.
bmap_by_mutation_count = torchdms.binarymap.from_pickle_file(
    "../_ignore/bmap_by_mutation_count.pkl"
)

def optional_data_factory(bmap):
    """Wrap ``bmap`` in a ``DataFactory`` when it is truthy.

    Args:
        bmap: The object to hand to ``DataFactory``, or a falsy placeholder
            (for example ``None``) when there is nothing to wrap.

    Returns:
        ``DataFactory(bmap)`` if ``bmap`` is truthy, otherwise ``None``.

    NOTE(review): the check is truthiness-based, not ``is None``; confirm that
    a real binary map can never evaluate as falsy.
    """
    return DataFactory(bmap) if bmap else None

# One entry per position in bmap_by_mutation_count: a DataFactory where a map
# is present, None otherwise. Entries are fetched by integer index (not by
# iterating the container) so the lookup semantics stay exactly as written.
factory_by_mutation_count = [
    optional_data_factory(bmap_by_mutation_count[mutation_count])
    for mutation_count in range(len(bmap_by_mutation_count))
]

# Fit on the data set stored at index 2 and evaluate on the one at index 1
# (presumably double and single mutants respectively -- confirm against how
# the pickle was built).
training_factory = factory_by_mutation_count[2]
testing_factory = factory_by_mutation_count[1]

# optional_data_factory returns None for falsy entries; fail here with a clear
# message instead of an AttributeError on None further down.
assert training_factory is not None, "no training data at index 2 of factory_by_mutation_count"
assert testing_factory is not None, "no testing data at index 1 of factory_by_mutation_count"

# Seed torch's global RNG before the network is built so that a fresh
# Restart & Run All starts from the same RNG state every time. This presumably
# covers weight initialisation and any torch-driven shuffling during training;
# verify if other RNGs (numpy, random) are involved.
torch.manual_seed(0)

# Input width is taken from the training data; the hidden layer has one unit.
net = SingleSigmoidNet(input_size=training_factory.feature_count(), hidden1_size=1)

In [None]:
# Training hyperparameters: learning_rate feeds the optimizer, while
# epoch_count and batch_size are passed to the training call.
batch_size = 500
epoch_count = 300
learning_rate = 1e-3

In [None]:
%%time

# Put the network into training mode, then set up the loss and the optimizer.
net.train()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# train_network hands back the network together with the recorded losses
# (presumably one training-loss value per epoch, given get_train_loss=True).
net, losses = torchdms.model.train_network(
    net, training_factory, criterion, optimizer, epoch_count, batch_size, get_train_loss=True
)

# Plot the recorded losses against their position in the sequence.
pd.Series(losses).plot()

In [None]:
# The training cell leaves the network in training mode. Switch to inference
# mode and disable autograd tracking so the held-out predictions are computed
# the way a deployed model would compute them, without building a graph.
net.eval()
with torch.no_grad():
    results = pd.DataFrame({
        "Observed": testing_factory.Y.numpy(),
        # transpose()[0] selects the first output column; this assumes the
        # network returns a 2-D (n_samples, 1) tensor -- TODO confirm.
        "Predicted": net(testing_factory.X).detach().numpy().transpose()[0],
    })

In [None]:
# Scatter of predicted against observed values, one point per test row.
ax = results.plot(kind="scatter", x="Observed", y="Predicted")

In [None]:
# Same comparison as a hexagonal-bin plot.
# The title is hard-coded; keep it in sync with the data-set indices chosen
# for training_factory and testing_factory.
ax = results.plot(
    kind="hexbin",
    x="Observed",
    y="Predicted",
    title="Training on 2x mutants, testing on 1x",
)

In [None]:
# Correlation between the observed and predicted columns, read from the
# correlation matrix by label rather than by position (DataFrame.corr default method).
results.corr().loc["Observed", "Predicted"]