Hi,
Big thanks for your PyTorch implementation of logavgexp!
I noticed that it is easy for logavgexp to reproduce the max operator (the temperature goes nicely to 0), but it has trouble reproducing the mean operator: in the following example the temperature stagnates at 0.35. Do you have an explanation for that, or a way to circumvent this issue?
Thanks!
import torch
from logavgexp_pytorch import LogAvgExp

torch.manual_seed(12345)

B = 10
N = 20
x = torch.randn(B, N)

# Target: swap the two lines below to fit the max instead of the mean.
#y, _ = x.max(dim=-1, keepdim=True)
y = x.mean(dim=-1, keepdim=True)

logavgexp = LogAvgExp(
    temp = 1,
    dim = 1,
    learned_temp = True,
    keepdim = True)

optimizer = torch.optim.Adam(logavgexp.parameters(), lr=0.01)
loss_func = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.5,
                                                       patience=100,
                                                       verbose=True)

for i in range(10000):
    prediction = logavgexp(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The epoch argument of step() is deprecated, so only the metric is passed.
    scheduler.step(loss)
    print(f"ite: {i}, loss: {loss.item():.2e}, temperature: {logavgexp.temp.exp().item():.4f}")
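
For reference, here is a minimal sketch of the textbook log-average-exp, assuming the definition temp * log(mean(exp(x / temp))). The helper `logavgexp_ref` is hypothetical and written with `torch.logsumexp`; it is not taken from `logavgexp_pytorch`, whose internals may differ. Under this definition the output approaches the max as the temperature goes to 0 and the mean as the temperature grows large, so matching the mean should require the temperature to increase rather than shrink.

```python
import math
import torch


def logavgexp_ref(x, temp, dim=-1):
    """Hypothetical reference: temp * log(mean(exp(x / temp))) along `dim`."""
    n = x.shape[dim]
    # logsumexp(x / temp) - log(n) is a numerically stable log(mean(exp(x / temp))).
    return temp * (torch.logsumexp(x / temp, dim=dim) - math.log(n))


torch.manual_seed(12345)
# float64 is an assumption made here to limit rounding error at large temperatures.
x = torch.randn(10, 20, dtype=torch.float64)
x_max = x.max(dim=-1).values
x_mean = x.mean(dim=-1)

# For each temperature, record the largest gap to the row-wise max and mean.
errors = {}
for temp in (1e-3, 1.0, 1e3):
    out = logavgexp_ref(x, temp)
    errors[temp] = (
        (out - x_max).abs().max().item(),
        (out - x_mean).abs().max().item(),
    )
    print(f"temp={temp:g}  max|out - max|={errors[temp][0]:.2e}  "
          f"max|out - mean|={errors[temp][1]:.2e}")
```

If the library follows the same definition, initializing the learned temperature at a larger value might be worth trying for the mean target.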