You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Loss used for both phases; the positional True is passed through unchanged.
loss_fn = CutMixCrossEntropyLoss(True)

if __name__ == "__main__":
    # Train/validate `model` for a fixed number of epochs, printing the
    # per-phase loss and accuracy after every phase.
    #
    # Relies on names defined elsewhere in the file: model, optimizer,
    # scheduler, dataloaders, dataset_sizes, set_parameters_requires_grad.
    set_parameters_requires_grad(model, True)
    epochs = 25
    use_cuda = torch.cuda.is_available()  # loop-invariant, so query it once

    for epoch in range(epochs):
        print('Epoch ', epoch,'/',epochs-1)
        print('-'*15)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                if use_cuda:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                # zero the parameter gradients
                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = loss_fn(outputs, labels)

                    # Backpropagate and update parameters only in training mode.
                    if is_train:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)

                # Fix: argmax without `dim` collapses the whole label tensor to
                # a single flat index, so every prediction was compared against
                # one scalar. Reduce over the class dimension per sample.
                # NOTE(review): presumably CutMix yields (batch, classes) soft
                # labels while plain loaders yield (batch,) class indices;
                # confirm against the datasets in use.
                if labels.dim() > 1:
                    targets = torch.argmax(labels, dim=1)
                else:
                    targets = labels
                # .item() keeps the running total a plain Python number.
                running_corrects += torch.sum(preds == targets).item()

            # Step the scheduler once per training phase (i.e. once per epoch).
            if is_train:
                scheduler.step()

            epoch_loss = running_loss / float(dataset_sizes[phase])
            epoch_acc = running_corrects / float(dataset_sizes[phase])
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

    # NOTE(review): the pasted source lost its indentation; this call is placed
    # after the epoch loop on the assumption that it runs once when training
    # finishes. Confirm against the original script.
    optimizer.swap_swa_sgd()
The text was updated successfully, but these errors were encountered:
For the loss function I am getting the error:
setting up the training function
The text was updated successfully, but these errors were encountered: