Skip to content

Commit

Permalink
remove alternate training legacy
Browse files Browse the repository at this point in the history
  • Loading branch information
dandelin committed Jun 11, 2021
1 parent 7a3ad79 commit 98a51e6
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions vilt/modules/vilt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,23 +161,10 @@ def check_non_acc_grad(pl_module):


def set_task(pl_module):
if not pl_module.training:
pl_module.current_tasks = [
k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1
]
return

if not check_non_acc_grad(pl_module):
return

sampling_pools = list()
for k, v in pl_module.hparams.config["loss_names"].items():
sampling_pools.extend([k] * int(v))

g = torch.Generator()
g.manual_seed(pl_module.global_step)
idx = torch.randperm(len(sampling_pools), generator=g)[0]
pl_module.current_tasks = sampling_pools[idx]
pl_module.current_tasks = [
k for k, v in pl_module.hparams.config["loss_names"].items() if v >= 1
]
return


def set_schedule(pl_module):
Expand Down

0 comments on commit 98a51e6

Please sign in to comment.