In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.distributions import Multinomial
import biom

In [2]:
# Example input: paired soil microbe and metabolite count tables in BIOM
# format, read relative to the current working directory.
microbes, metabolites = (
    biom.load_table("./soil_microbes.biom"),
    biom.load_table("./soil_metabolites.biom"),
)

In [3]:
class MicrobeMetaboliteData(Dataset):
    """Paired microbe / metabolite count tables aligned on sample ID.

    Both BIOM tables are transposed to (sample x feature) frames. The microbe
    table is restricted to (and reordered by) the samples of the metabolite
    table. Both are then stored as int64 count tensors, together with their
    per-sample relative frequencies.

    Attributes:
        microbes: int64 tensor of microbe counts, shape (samples, microbes).
        metabolites: int64 tensor of metabolite counts, shape (samples, metabolites).
        microbe_count: number of microbe features (columns of ``microbes``).
        metabolite_count: number of metabolite features (columns of ``metabolites``).
        microbe_relative_frequency: ``microbes`` with each row divided by its total.
        metabolite_relative_frequency: ``metabolites`` with each row divided by its total.
        total_microbe_observations: 0-d tensor, the sum of every microbe count.
    """

    def __init__(self, microbes: "biom.Table", metabolites: "biom.Table"):
        """Build the aligned tensors.

        Args:
            microbes: BIOM table of microbe counts (features x samples).
            metabolites: BIOM table of metabolite counts (features x samples).

        Raises:
            KeyError: a metabolite sample ID is absent from the microbe table.
            ValueError: a retained sample has a zero total in either table
                (its relative frequencies would be NaN).
        """
        # Transpose so rows are samples and columns are features.
        microbe_frame = microbes.to_dataframe().T
        metabolite_frame = metabolites.to_dataframe().T

        # Keep only microbe samples that also have metabolite results, in the
        # metabolite table's order, so row i is the same sample in both tensors.
        microbe_frame = microbe_frame.loc[metabolite_frame.index]

        # Final count tensors; both tables share one integer dtype.
        self.microbes = torch.tensor(microbe_frame.values, dtype=torch.int64)
        self.metabolites = torch.tensor(metabolite_frame.values, dtype=torch.int64)

        # Number of features in each table.
        self.microbe_count = self.microbes.shape[1]
        self.metabolite_count = self.metabolites.shape[1]

        # Per-sample relative frequencies (each row sums to 1).
        self.microbe_relative_frequency = self._relative_frequency(self.microbes, "microbe")
        self.metabolite_relative_frequency = self._relative_frequency(self.metabolites, "metabolite")

        self.total_microbe_observations = self.microbes.sum()

    @staticmethod
    def _relative_frequency(counts, label):
        """Divide each row of ``counts`` by its row total; reject all-zero rows."""
        totals = counts.sum(dim=1, keepdim=True)
        if (totals == 0).any():
            empty_rows = torch.nonzero(totals.squeeze(1) == 0).flatten().tolist()
            raise ValueError(
                f"{label} table has samples with zero total counts at rows {empty_rows}"
            )
        # Integer / integer true division yields a floating-point tensor.
        return counts / totals

    def __len__(self):
        # The dataset length is the total number of microbe observations, not
        # the number of samples. len() needs a Python int, so unwrap the tensor.
        return int(self.total_microbe_observations.item())

In [4]:
# Align the two example tables into a single dataset object.
example_data = MicrobeMetaboliteData(microbes=microbes, metabolites=metabolites)

In [5]:
# Total microbe count across all retained samples, shown as a plain int.
int(example_data.total_microbe_observations)

424846

In [6]:
class MMVec(nn.Module):
    """Microbe-embedding model that predicts metabolite probabilities.

    Each microbe index is embedded into a ``latent_dim`` vector. A linear layer
    followed by a softmax decodes that vector into a probability vector over
    metabolites. ``forward`` returns the mean multinomial log-probability of the
    observed metabolite profiles under those predictions. Higher is better, so
    the model is trained with an optimizer set to ``maximize=True``.

    Args:
        num_microbes: number of distinct microbe indices (embedding rows).
        num_metabolites: number of metabolite features (decoder outputs).
        latent_dim: size of the shared latent space.
    """

    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        # microbe index -> latent vector
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        # latent vector -> metabolite probabilities. The softmax runs over the
        # last (metabolite) axis, so any number of leading batch/sample axes
        # works. For the (batch, sample, metabolite) layout this equals dim=2.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, num_metabolites),
            nn.Softmax(dim=-1),
        )

    def forward(self, X, Y):
        """Return the mean log-probability of ``Y`` given microbe draws ``X``.

        Args:
            X: integer tensor of microbe indices. In ``train_loop`` it has shape
                (batch_size, samples): one row of per-sample draws per batch entry.
            Y: expected metabolite profiles, broadcastable against the decoder
                output (e.g. (samples, num_metabolites) relative frequencies).

        Returns:
            0-d tensor: the log-probability averaged over the draw axis, then
            over the remaining axes.
        """
        # Embed the drawn microbe indices, then decode to metabolite probabilities.
        z = self.encoder(X)
        y_pred = self.decoder(z)

        # total_count=0 with validate_args=False skips the check that each row of
        # Y sums to total_count. Y holds floating-point relative frequencies,
        # whose totals are not exact integer counts.
        forward_dist = Multinomial(total_count=0, validate_args=False, probs=y_pred)

        # Log-probability of drawing the expected profiles from the predictions;
        # one value per leading (draw, sample, ...) position.
        log_probs = forward_dist.log_prob(Y)

        # Average over the draw axis to get a per-sample value, then over samples.
        per_sample = log_probs.mean(0)
        return per_sample.mean()

In [7]:
# Seed torch's global RNG before weight initialisation so Restart & Run All
# reproduces the same initial model and training draws.
torch.manual_seed(42)

latent_dim = 15  # size of the shared latent (embedding) space
mmvec_model = MMVec(example_data.microbe_count, example_data.metabolite_count, latent_dim)

In [8]:
def train_loop(dataset, model, optimizer, batch_size, epochs=None, log_every=100):
    """Train ``model`` on repeated multinomial draws of microbes.

    Each step does the following:
      1. Draws ``batch_size`` microbe indices per sample (with replacement,
         weighted by ``dataset.microbe_relative_frequency``).
      2. Scores those draws against ``dataset.metabolite_relative_frequency``
         with ``model``.
      3. Takes one optimizer step on the returned value.

    The value is a log-probability, so the optimizer should maximise it
    (e.g. ``Adam(..., maximize=True)``).

    Args:
        dataset: object exposing ``total_microbe_observations`` (0-d tensor),
            ``microbe_relative_frequency`` (samples x microbes) and
            ``metabolite_relative_frequency``.
        model: callable ``model(draws, metabolite_frequencies) -> 0-d tensor``.
        optimizer: torch optimizer over ``model``'s parameters.
        batch_size: number of draws per sample in each step.
        epochs: number of passes. One pass is
            ceil(total_microbe_observations / batch_size) steps. If None, the
            module-level ``epochs`` variable is used, which is how earlier
            versions of this function behaved.
        log_every: print the current value every ``log_every`` steps.

    Raises:
        NameError: ``epochs`` is None and no global ``epochs`` is defined.
    """
    if epochs is None:
        # Backward compatibility: this function used to read the notebook global.
        try:
            epochs = globals()["epochs"]
        except KeyError:
            raise NameError(
                "train_loop: pass epochs= explicitly (no global `epochs` defined)"
            ) from None

    # All samples are scored together in every step, so batching is handled here
    # rather than by a DataLoader. Ceiling division covers every observation; the
    # last step of a pass over-samples slightly when the total is not a multiple
    # of batch_size.
    total_observations = int(dataset.total_microbe_observations.item())
    n_batches = -(-total_observations // batch_size)

    for batch in range(n_batches * epochs):
        # torch.multinomial returns (samples, batch_size); transpose to
        # (batch_size, samples) so each row holds one draw per sample. This may
        # later move into a sampler/collate_fn.
        draws = torch.multinomial(
            dataset.microbe_relative_frequency, batch_size, replacement=True
        ).T

        # Calling the module runs its forward pass (and any hooks).
        lp = model(draws, dataset.metabolite_relative_frequency)

        optimizer.zero_grad()
        lp.backward()
        optimizer.step()

        if batch % log_every == 0:
            print(f"loss: {lp.item()}\nBatch #: {batch}")

In [None]:
# Optimiser settings. Adam runs with maximize=True because the model returns
# a log-probability to increase, not a loss to decrease.
learning_rate, batch_size, epochs = 1e-3, 500, 25
optimizer = torch.optim.Adam(
    mmvec_model.parameters(),
    lr=learning_rate,
    maximize=True,
)

# Run training. train_loop also picks up `epochs` from this namespace.
train_loop(dataset=example_data, model=mmvec_model, optimizer=optimizer, batch_size=batch_size)

loss: -4.114527225494385
Batch #: 0
loss: -3.6144325733184814
Batch #: 100
loss: -3.0469698905944824
Batch #: 200
loss: -2.70939564704895
Batch #: 300
loss: -2.5499744415283203
Batch #: 400
loss: -2.473045587539673
Batch #: 500
loss: -2.4374732971191406
Batch #: 600
loss: -2.421781539916992
Batch #: 700
loss: -2.4101920127868652
Batch #: 800
loss: -2.4041030406951904
Batch #: 900
loss: -2.4012131690979004
Batch #: 1000
loss: -2.397974967956543
Batch #: 1100
loss: -2.3931915760040283
Batch #: 1200
loss: -2.3923048973083496
Batch #: 1300
loss: -2.389982223510742
Batch #: 1400
loss: -2.3868303298950195
Batch #: 1500
loss: -2.3855628967285156
Batch #: 1600
loss: -2.382643222808838
Batch #: 1700
loss: -2.381664991378784
Batch #: 1800
loss: -2.3774473667144775
Batch #: 1900
loss: -2.378610372543335
Batch #: 2000
loss: -2.3776485919952393
Batch #: 2100
loss: -2.376375675201416
Batch #: 2200
loss: -2.3723671436309814
Batch #: 2300
loss: -2.372851848602295
Batch #: 2400
loss: -2.373134136199951