In [12]:
from proteovae.models.base import Encoder, Guide, Decoder
from proteovae.models import GuidedConfig, GuidedVAE

import torch 
import torch.nn as nn


Specify the configuration of the guided model:

In [10]:
# Dimensions shared by every component built in this notebook.
input_dim = 32   # width of each input sample
latent_dim = 10  # size of the latent code
guided_dim = 1   # presumably the number of latent dims the guide supervises -- confirm against GuidedConfig

# Bundle the dimensions into the config object handed to GuidedVAE below.
g = GuidedConfig(
    input_dim=input_dim,
    latent_dim=latent_dim,
    guided_dim=guided_dim,
)

In [11]:
# Library encoder with a single 16-unit hidden layer.
enc = Encoder(
    input_dim=input_dim,
    latent_dim=latent_dim,
    hidden_dims=[16],
)
enc  # bare expression so the notebook renders the module layout

Encoder(
  (linear_block): Sequential(
    (0): Linear(in_features=32, out_features=16, bias=True)
    (1): ReLU()
  )
  (fc_mu): Linear(in_features=16, out_features=10, bias=True)
  (fc_logvar): Linear(in_features=16, out_features=10, bias=True)
)

In [16]:
class CustomDecoder(nn.Module):
    """Small MLP decoder that maps latent codes back to input space.

    Layout: ``Linear(latent_dim, 2 * latent_dim) -> Tanh ->
    Linear(2 * latent_dim, input_dim)``.

    Args:
        latent_dim: Size of the latent vector fed to the decoder. When
            omitted, the notebook-level ``latent_dim`` is used, so
            ``CustomDecoder()`` keeps working as before.
        input_dim: Size of the reconstructed output. When omitted, the
            notebook-level ``input_dim`` is used.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, *, latent_dim=None, input_dim=None, **kwargs):
        super().__init__(**kwargs)

        # Prefer explicit sizes; only fall back to the notebook globals when
        # the caller did not pass them. This makes the class reusable and
        # independent of cell execution order.
        if latent_dim is None:
            latent_dim = globals()["latent_dim"]
        if input_dim is None:
            input_dim = globals()["input_dim"]

        hidden_dim = 2 * latent_dim
        # Attribute name `fwd_block` is kept: it appears in repr/state_dict keys.
        self.fwd_block = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        """Apply the decoder stack to ``x`` and return the result."""
        return self.fwd_block(x)
    

# Build the decoder with no constructor arguments.
custom_dec = CustomDecoder()
custom_dec  # bare expression so the notebook renders the module layout

CustomDecoder(
  (fwd_block): Sequential(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): Tanh()
    (2): Linear(in_features=20, out_features=32, bias=True)
  )
)

In [19]:
# Assemble the guided VAE from the config, the library encoder, the custom
# decoder and a guide head.
# The guide's first argument is tied to `guided_dim` instead of a literal 1 so
# it cannot drift from the config (assumes that argument is the guide's input
# width, as the printed Linear(in_features=1, ...) suggests -- confirm against
# Guide). The second argument, 2, is the classifier's output width.
guided_vae = GuidedVAE(
    config=g,
    encoder=enc,
    decoder=custom_dec,
    guide=Guide(guided_dim, 2),
)
guided_vae  # bare expression so the notebook renders the module layout

GuidedVAE(
  (encoder): Encoder(
    (linear_block): Sequential(
      (0): Linear(in_features=32, out_features=16, bias=True)
      (1): ReLU()
    )
    (fc_mu): Linear(in_features=16, out_features=10, bias=True)
    (fc_logvar): Linear(in_features=16, out_features=10, bias=True)
  )
  (decoder): CustomDecoder(
    (fwd_block): Sequential(
      (0): Linear(in_features=10, out_features=20, bias=True)
      (1): Tanh()
      (2): Linear(in_features=20, out_features=32, bias=True)
    )
  )
  (guide): Guide(
    (classifier): Sequential(
      (0): Linear(in_features=1, out_features=2, bias=True)
    )
  )
)