In [None]:
!pip install proteovae

In [23]:
# Wildcard import kept first so the explicit imports below take precedence over
# any names it happens to re-export. Provides GuidedConfig and GuidedVAE.
from proteovae.models import *
import proteovae.models.base as base
from proteovae.trainers import BaseTrainer
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Lambda
from torch.utils.data import DataLoader
import torch
from torch import nn

In [19]:
# ToTensor converts each image to a float tensor scaled to [0, 1]; flattening
# turns it into a 784-long vector for the fully connected encoder.
# `torch.flatten` is passed directly rather than wrapped in a lambda: lambdas
# cannot be pickled, which breaks DataLoaders with num_workers > 0 under the
# "spawn" start method. One shared transform also avoids duplicating it below.
flatten_transform = Compose([ToTensor(), Lambda(torch.flatten)])

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=flatten_transform,
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=flatten_transform,
)

In [22]:
BATCH_SIZE = 64

# Only the training set is shuffled. Evaluation does not benefit from random
# ordering, and a fixed order keeps test passes reproducible.
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Sanity check: the first training batch is BATCH_SIZE flattened 28x28 images.
images_batch, _ = next(iter(train_dataloader))
assert images_batch.shape == (BATCH_SIZE, 784), images_batch.shape
images_batch.shape

torch.Size([64, 784])


In [24]:
# Hyper-parameters for proteovae's GuidedVAE.
config = GuidedConfig(
    input_dim=784,  # 28 * 28 FashionMNIST pixels after flattening
    latent_dim=10,
    device="cuda" if torch.cuda.is_available() else "cpu",
    guided_dim=1,
    # NOTE(review): eta and gamma are both 0. If they weight the guide-related
    # loss terms, the guide receives no training signal — confirm intended.
    eta=0,
    gamma=0,
    beta=1.0,  # presumably the weight of the KL term — confirm in proteovae docs
)

# Fully connected encoder: input_dim -> 128 -> 64, followed by mu / logvar
# heads of width latent_dim.
enc = base.Encoder(
    input_dim=config.input_dim,
    latent_dim=config.latent_dim,
    hidden_dims=[128, 64],
)

class CustomDecoder(base.Decoder):
    """``base.Decoder`` whose output is squashed element-wise by a sigmoid.

    Construction is inherited unchanged from ``base.Decoder``. The sigmoid maps
    reconstructions into [0, 1], the same range as the ``ToTensor``-scaled
    pixels they are compared against.
    """

    def forward(self, x):
        """Decode ``x`` with the parent decoder, then apply ``torch.sigmoid``.

        Args:
            x: latent tensor; presumably shaped (batch, latent_dim) — TODO
                confirm against ``base.Decoder.forward``.

        Returns:
            Tensor with the parent decoder's output shape and values in [0, 1].
        """
        # torch.sigmoid is the canonical API; nn.functional.sigmoid was
        # deprecated in its favour.
        return torch.sigmoid(super().forward(x))
    
dec = CustomDecoder(
    output_dim=config.input_dim,  # reconstruct the full flattened image
    latent_dim=config.latent_dim,
    hidden_dims=[64, 128],  # mirror image of the encoder's hidden widths
)

In [25]:
model = GuidedVAE(
    config=config,
    encoder=enc,
    decoder=dec,
    # The guide's input width (its Linear in_features) must equal
    # config.guided_dim, so derive it from the config instead of a literal 1.
    # NOTE(review): 10 is presumably the number of FashionMNIST classes rather
    # than latent_dim (both are 10 here) — confirm against base.Guide.
    guide=base.Guide(config.guided_dim, 10),
)
model

GuidedVAE(
  (encoder): Encoder(
    (linear_block): Sequential(
      (0): Linear(in_features=784, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=64, bias=True)
      (3): ReLU()
    )
    (fc_mu): Linear(in_features=64, out_features=10, bias=True)
    (fc_logvar): Linear(in_features=64, out_features=10, bias=True)
  )
  (decoder): CustomDecoder(
    (linear_block): Sequential(
      (0): Linear(in_features=10, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=784, bias=True)
    )
  )
  (guide): Guide(
    (classifier): Sequential(
      (0): Linear(in_features=1, out_features=10, bias=True)
    )
  )
)

In [26]:
LEARNING_RATE = 1e-3

# model.parameters() yields the parameters of every registered submodule
# (encoder, decoder and guide), so one Adam optimizer updates all of them.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

trainer = BaseTrainer(model, optimizer)