
learned temperature stagnates at a low value (a high value is expected) #2

Open
ldv1 opened this issue Sep 9, 2022 · 0 comments

ldv1 commented Sep 9, 2022

Hi,

Big thanks for your PyTorch implementation of logavgexp!

I noticed that logavgexp easily reproduces the max operator (the temperature nicely goes to 0), but it has trouble reproducing the mean operator: in the following example, the temperature stagnates at 0.35. Do you have an explanation for that, or ways to circumvent this issue?
Thanks!

import torch
from logavgexp_pytorch import LogAvgExp

torch.manual_seed(12345)

B = 10
N = 20
x = torch.randn(B, N)
# y, _ = x.max(dim=-1, keepdim=True)  # fitting the max works fine
y = x.mean(dim=-1, keepdim=True)      # fitting the mean stagnates

logavgexp = LogAvgExp(
    temp = 1,
    dim = 1,
    learned_temp = True,
    keepdim = True)

optimizer = torch.optim.Adam(logavgexp.parameters(), lr=0.01)
loss_func = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.5,
                                                       patience=100,
                                                       verbose=True)

for i in range(10000):
    prediction = logavgexp(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step(loss)
    print(f"ite: {i}, loss: {loss.item():.2e}, temperature: {logavgexp.temp.exp().item():.4f}")
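For context on why the mean target is hard: with temperature t, logavgexp is t · log((1/n) · Σᵢ exp(xᵢ/t)). By Jensen's inequality this always upper-bounds the mean and reaches it only in the limit t → ∞, with the gap shrinking roughly like var(x)/(2t), so the gradient signal pushing the temperature higher keeps weakening. A small standalone sketch (plain PyTorch, not the library module, so the parameterization may differ from `LogAvgExp`) illustrates the slow asymptotic approach:

```python
import math
import torch

def logavgexp(x, t, dim=-1):
    # t * log( (1/n) * sum(exp(x / t)) ), computed stably via logsumexp
    n = x.shape[dim]
    return t * (torch.logsumexp(x / t, dim=dim) - math.log(n))

torch.manual_seed(0)
x = torch.randn(10, 20)

# error vs. the true mean shrinks only asymptotically as t grows
for t in (0.01, 0.35, 1.0, 10.0, 100.0):
    err = (logavgexp(x, t) - x.mean(dim=-1)).abs().mean().item()
    print(f"t = {t:6.2f}  mean |logavgexp - mean| = {err:.4f}")
```

Since doubling the temperature only roughly halves the residual error, the MSE loss surface flattens out in the temperature direction, which would be consistent with the stagnation you observe.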
    
