In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Reproducibility: seed PyTorch's RNGs, force cuDNN onto deterministic kernels
# and switch off its autotuner (which may choose different algorithms per run).
# NOTE(review): this does not make every CUDA op deterministic; see
# torch.use_deterministic_algorithms for the stricter switch.
SEED = 0
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# ResNet-50 with no pretrained weights requested, used as a feature extractor:
# replacing the classification head with Identity makes forward() return the
# pooled backbone features, which are regressed against the 2048-wide `y`.
model = models.resnet50()
model.fc = nn.Identity()

# Nesterov SGD; hyperparameters gathered in one place for easy tweaking.
sgd_config = {"lr": 0.2, "momentum": 0.9, "weight_decay": 0, "nesterov": True}
optimizer = torch.optim.SGD(model.parameters(), **sgd_config)

In [4]:
# Synthetic batch: inputs and regression targets are drawn on the CPU (from the
# seeded generator) and then copied to the GPU alongside the model.
device = torch.device("cuda")
batch_size = 128

model = model.to(device)
# NOTE: `idx` is not read by any later cell here, but torch.randperm consumes
# the CPU RNG, so removing this line would change the x / y drawn below.
idx = torch.randperm(batch_size).to(device)
x = torch.randn(batch_size, 3, 224, 224).to(device)
y = torch.randn(batch_size, 2048).to(device)

In [5]:
# Baseline: one forward pass per step. Both loss terms are computed from the
# same output tensor, so backward() traverses a single autograd graph.
print("single forward pass")
for step in range(10):
    optimizer.zero_grad()

    features = model(x)
    loss1 = features.mean()
    loss2 = F.mse_loss(features, y)
    loss = loss1 + loss2
    print(f"loss1: {loss1.item():.4f}, loss2: {loss2.item():.4f}, loss: {loss.item():.4f}")

    loss.backward()
    optimizer.step()

single forward pass
loss1: 1.0347, loss2: 2.1001, loss: 3.1349
loss1: 1.0329, loss2: 2.0982, loss: 3.1311
loss1: 1.0304, loss2: 2.0929, loss: 3.1233
loss1: 1.0268, loss2: 2.0860, loss: 3.1128
loss1: 1.0226, loss2: 2.0767, loss: 3.0992
loss1: 1.0177, loss2: 2.0654, loss: 3.0831
loss1: 1.0123, loss2: 2.0536, loss: 3.0659
loss1: 1.0053, loss2: 2.0404, loss: 3.0457
loss1: 0.9972, loss2: 2.0242, loss: 3.0214
loss1: 0.9859, loss2: 2.0005, loss: 2.9864


In [6]:
# Double forward pass: the same batch goes through the network twice and each
# loss term is taken from its own output, so backward() sums gradients coming
# from two separate autograd graphs instead of one.
#
# Rebuild the model and optimizer from the same seed first. Otherwise this cell
# continues from the weights *and* SGD momentum buffers left behind by the
# single-pass cell, so the two runs would not start from the same state and
# their loss curves would not be comparable. A fresh Restart & Run All would
# also not match the saved output below, whose first step equals the
# single-pass run's first step.
torch.manual_seed(0)  # same seed as the seeding cell, reseeded right before construction
model = models.resnet50()
model.fc = nn.Identity()
model = model.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=0, nesterov=True)

print("double forward pass")
for _ in range(10):
    z1 = model(x)
    # NOTE(review): in train mode every forward also updates BatchNorm running
    # statistics, so they are updated twice per step here. Train-mode outputs
    # use batch statistics, so z1 and z2 should agree up to floating-point noise.
    z2 = model(x)
    loss1 = torch.mean(z1)
    loss2 = F.mse_loss(z2, y)
    loss = loss1 + loss2
    print(f"loss1: {loss1.item():.4f}, loss2: {loss2.item():.4f}, loss: {loss.item():.4f}")

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

double forward pass
loss1: 1.0347, loss2: 2.1001, loss: 3.1349
loss1: 1.0328, loss2: 2.0982, loss: 3.1311
loss1: 1.0303, loss2: 2.0923, loss: 3.1227
loss1: 1.0268, loss2: 2.0860, loss: 3.1128
loss1: 1.0225, loss2: 2.0763, loss: 3.0988
loss1: 1.0176, loss2: 2.0659, loss: 3.0835
loss1: 1.0119, loss2: 2.0546, loss: 3.0665
loss1: 1.0045, loss2: 2.0381, loss: 3.0426
loss1: 0.9952, loss2: 2.0197, loss: 3.0149
loss1: 0.9793, loss2: 1.9920, loss: 2.9713
