Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bboylyg committed Jan 10, 2022
1 parent d61e4d7 commit 6907ea2
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions main.py
Expand Up @@ -28,9 +28,9 @@ def train_step(opt, train_loader, nets, optimizer, criterions, epoch):
activation1_t, activation2_t, activation3_t, _ = tnet(img)

cls_loss = criterionCls(output_s, target)
at3_loss = criterionAT(activation3_s, activation3_t).detach() * opt.beta3
at2_loss = criterionAT(activation2_s, activation2_t).detach() * opt.beta2
at1_loss = criterionAT(activation1_s, activation1_t).detach() * opt.beta1
at3_loss = criterionAT(activation3_s, activation3_t.detach()) * opt.beta3
at2_loss = criterionAT(activation2_s, activation2_t.detach()) * opt.beta2
at1_loss = criterionAT(activation1_s, activation1_t.detach()) * opt.beta1
at_loss = at1_loss + at2_loss + at3_loss + cls_loss

prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
Expand Down Expand Up @@ -88,9 +88,9 @@ def test(opt, test_clean_loader, test_bad_loader, nets, criterions, epoch):
activation1_s, activation2_s, activation3_s, output_s = snet(img)
activation1_t, activation2_t, activation3_t, _ = tnet(img)

at3_loss = criterionAT(activation3_s, activation3_t).detach() * opt.beta3
at2_loss = criterionAT(activation2_s, activation2_t).detach() * opt.beta2
at1_loss = criterionAT(activation1_s, activation1_t).detach() * opt.beta1
at3_loss = criterionAT(activation3_s, activation3_t.detach()) * opt.beta3
at2_loss = criterionAT(activation2_s, activation2_t.detach()) * opt.beta2
at1_loss = criterionAT(activation1_s, activation1_t.detach()) * opt.beta1
at_loss = at3_loss + at2_loss + at1_loss
cls_loss = criterionCls(output_s, target)

Expand Down Expand Up @@ -201,4 +201,4 @@ def main():
train(opt)

if (__name__ == '__main__'):
main()
main()

0 comments on commit 6907ea2

Please sign in to comment.