Skip to content

Commit

Permalink
Merge pull request #2 from SauravMaheshkar/main
Browse files Browse the repository at this point in the history
Fix Weights and Biases Logging and paths
  • Loading branch information
hlml committed Apr 16, 2022
2 parents 35f9e47 + 4bd88b5 commit aeb8597
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 14 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
**/__pycache__/
*.pyc
venv/
wandb/
10 changes: 7 additions & 3 deletions ease_of_teaching/forget_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,15 @@
group_name = args_raw.group_vars[0] + str(getattr(args_raw, args_raw.group_vars[0]))
for var in args_raw.group_vars[1:]:
group_name = group_name + '_' + var + str(getattr(args_raw, var))
wandb.init(project="ease_of_teaching_forget",
group=args_raw.fname,
name=group_name)
wandb.init(project="fortuitous_forgetting",
group="ease_of_teaching")
for var in args_raw.group_vars:
wandb.config.update({var:getattr(args_raw, var)})
else:
wandb.init(project="fortuitous_forgetting",
group="ease_of_teaching")
for var in args_raw.group_vars:
wandb.config.update({var: getattr(args_raw, var)})

args = vars(args_raw) # convert python object to dict
# args = parser.parse() # parsed argument from CLI
Expand Down
2 changes: 1 addition & 1 deletion llf_ke/configs/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self):

# General Config
parser.add_argument(
"--data", help="path to dataset base directory", default="/home/datasets"
"--data", help="path to dataset base directory", default="./datasets"
)

parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd")
Expand Down
4 changes: 2 additions & 2 deletions llf_ke/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
datasets_dir = '/home/datasets'
checkpoints_dir = '/home/checkpoints'
datasets_dir = './datasets'
checkpoints_dir = './checkpoints'
10 changes: 7 additions & 3 deletions llf_ke/train_KE_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,15 @@ def start_KE(cfg):
group_name = cfg.group_vars[0] + str(getattr(cfg, cfg.group_vars[0]))
for var in cfg.group_vars[1:]:
group_name = group_name + '_' + var + str(getattr(cfg, var))
wandb.init(project="llf_ke",
group=cfg.group_name,
name=group_name)
wandb.init(project="fortuitous_forgetting",
group="llf_ke")
for var in cfg.group_vars:
wandb.config.update({var:getattr(cfg, var)})
else:
wandb.init(project="fortuitous_forgetting",
group="llf_ke")
for var in cfg.group_vars:
wandb.config.update({var: getattr(cfg, var)})

if cfg.seed is not None and cfg.fix_seed: #FIXING SEED LEADS TO SAME REINITIALIZATION VALUES FOR EACH GENERATION
random.seed(cfg.seed)
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pandas==1.3.5
Pillow==9.0.0
PyYAML==6.0
scipy==1.7.3
torch==1.7.0
torchvision==0.11.2
torch==1.10.2
torchvision==0.11.3
tqdm==4.62.3
wandb==0.12.10
11 changes: 8 additions & 3 deletions targeted_forgetting/mixed_group_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torchvision.datasets.mnist import *
ACT = F.relu

data_path = '/home/datasets'
data_path = './datasets'
use_cuda = torch.cuda.is_available()


Expand Down Expand Up @@ -623,11 +623,16 @@ def main():
group_name = args.group_vars[0] + str(getattr(args, args.group_vars[0]))
for var in args.group_vars[1:]:
group_name = group_name + '_' + var + str(getattr(args, var))
wandb.init(project="mixed_group_training",
group=args.fname,
wandb.init(project="fortuitous_forgetting",
group="mixed_group_training",
name=group_name)
for var in args.group_vars:
wandb.config.update({var: getattr(args, var)})
else:
wandb.init(project="fortuitous_forgetting",
group="mixed_group_training")
for var in args.group_vars:
wandb.config.update({var: getattr(args, var)})

if args.train_data == 'mnist':
trans = ([transforms.ToTensor()])
Expand Down

0 comments on commit aeb8597

Please sign in to comment.