Encoder is executed 2 times in VAE #105

Closed

tatsuhiko-inoue opened this issue Jan 16, 2020 · 2 comments

Comments
@tatsuhiko-inoue

When I implement and run a VAE like the following, the Encoder is executed twice.

import torch
import torch.nn as nn
from pixyz.distributions import Normal
from pixyz.losses import KullbackLeibler
from pixyz.models import VAE
import torch.optim as optim

class Encoder(Normal):
    def __init__(self):
        super().__init__(cond_var=["x"], var=["z"], name="q")
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        print("Encoder")
        return {"loc": self.linear(x), "scale": 1.0}

class Decoder(Normal):
    def __init__(self):
        super().__init__(cond_var=["z"], var=["x"], name="p")

    def forward(self, z):
        print("Decoder")
        return {"loc": z, "scale": 1.0}

def prior():
    return Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), var=["z"], features_shape=[10], name="p_{prior}")

q = Encoder()
p = Decoder()

prior = prior()
kl = KullbackLeibler(q, prior)

mdl = VAE(q, p, regularizer=kl, optimizer=optim.Adam, optimizer_params={"lr":1e-3})

x = torch.zeros((10, 10))
loss = mdl.train({"x": x})

Output:

Encoder
Decoder
Encoder

It looks as if the Encoder is executed once for the KL divergence term and once more for the reconstruction term. Since running the Encoder twice makes training take correspondingly longer, I would like it to run only once. Is there a way to do this?
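For reference, this is roughly the single-pass computation being asked for, written in plain PyTorch rather than through pixyz's loss API (the helper name single_pass_elbo and the fixed unit scale are illustrative assumptions, not part of pixyz): the encoder's loc/scale are computed once and reused both for the reparameterized sample in the reconstruction term and for the analytic KL against the standard-normal prior.

import math
import torch
import torch.nn as nn

encoder_linear = nn.Linear(10, 10)          # stands in for Encoder.linear
x = torch.zeros((10, 10))

def single_pass_elbo(x):
    loc = encoder_linear(x)                 # encoder forward pass runs exactly once
    scale = torch.ones_like(loc)
    z = loc + scale * torch.randn_like(loc) # reparameterization trick

    # Reconstruction term: log N(x | z, 1), summed over features
    recon = (-0.5 * ((x - z) ** 2 + math.log(2 * math.pi))).sum(-1)

    # Analytic KL( N(loc, scale) || N(0, 1) ), summed over features
    kl = (0.5 * (loc ** 2 + scale ** 2 - 2 * torch.log(scale) - 1)).sum(-1)

    return -(recon - kl).mean()             # negative ELBO as the minimization target

loss = single_pass_elbo(x)
loss.backward()

Achieving the same effect inside pixyz would require the library to share the encoder's output between the two loss terms, which appears to be what the fix referenced below addresses.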

@masa-su
Owner

masa-su commented Mar 17, 2020

Thank you. The problem of the Encoder being executed twice is being addressed in this pull request: #109.

@masa-su
Owner

masa-su commented Dec 14, 2021

Resolved (fully addressed in v0.3.3).

@masa-su masa-su closed this as completed Dec 14, 2021