Update: GDumb + minor fixes #20

Open · wants to merge 9 commits into master
14 changes: 10 additions & 4 deletions README.md
@@ -42,11 +42,17 @@ experiment runs with one task on one dataset, results would be equivalent to 'co
| non-incremental supervised learning | yes | yes | 1 |

Current available approaches include:
<div align="center">
<p align="center"><b>
Finetuning • Freezing • Joint

LwF • iCaRL • EWC • PathInt • MAS • RWalk • EEIL • LwM • DMC • BiC • LUCIR • IL2M
<div align="center"><p align="center"><b>

[Finetuning](./src/approach/finetuning.py) • [Freezing](./src/approach/freezing.py) • [Joint](./src/approach/joint.py)

[LwF](./src/approach/lwf.py) • [iCaRL](./src/approach/icarl.py) • [EWC](./src/approach/ewc.py) •
[PathInt](./src/approach/path_integral.py) • [MAS](./src/approach/mas.py) • [RWalk](./src/approach/r_walk.py) •
[EEIL](./src/approach/eeil.py) • [LwM](./src/approach/lwm.py) • [DMC](./src/approach/dmc.py) •
[BiC](./src/approach/bic.py) • [LUCIR](./src/approach/lucir.py) • [IL2M](./src/approach/il2m.py) •
[GDumb](./src/approach/gdumb.py)

</b></p>
</div>

7 changes: 7 additions & 0 deletions src/approach/README.md
@@ -182,3 +182,10 @@ can be combined with Freezing by using `--freeze-after num_task (int)`. However,
`--approach il2m`
[ICCV 2019](https://openaccess.thecvf.com/content_ICCV_2019/papers/Belouadah_IL2M_Class_Incremental_Learning_With_Dual_Memory_ICCV_2019_paper.pdf)
| [code](https://github.com/EdenBelouadah/class-incremental-learning/tree/master/il2m)

### GDumb
`--approach gdumb`
[ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123470511.pdf)
| [code](https://github.com/drimpossible/GDumb)

* `--regularization`: Use regularization (default='cutmix')
107 changes: 107 additions & 0 deletions src/approach/gdumb.py
@@ -0,0 +1,107 @@
import torch
import numpy as np
from copy import deepcopy
from argparse import ArgumentParser
from torch.utils.data.dataloader import default_collate

from utils import cutmix_data
from .incremental_learning import Inc_Learning_Appr
from datasets.exemplars_dataset import ExemplarsDataset


class Appr(Inc_Learning_Appr):
"""Class implementing the GDumb approach
described in https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123470511.pdf
Original code available at https://github.com/drimpossible/GDumb
"""

def __init__(self, model, device, nepochs=100, lr=0.05, lr_min=0.0005, lr_factor=3, lr_patience=5, clipgrad=10.0,
momentum=0.9, wd=1e-6, multi_softmax=False, wu_nepochs=1, wu_lr_factor=1, fix_bn=False,
eval_on_train=False, logger=None, exemplars_dataset=None, regularization='cutmix'):
super(Appr, self).__init__(model, device, nepochs, lr, lr_min, lr_factor, lr_patience, clipgrad, momentum, wd,
multi_softmax, wu_nepochs, wu_lr_factor, fix_bn, eval_on_train, logger,
exemplars_dataset)
self.regularization = regularization
self.init_model = None

have_exemplars = self.exemplars_dataset.max_num_exemplars + self.exemplars_dataset.max_num_exemplars_per_class
assert (have_exemplars > 0), 'Error: GDumb needs exemplars.'

@staticmethod
def exemplars_dataset_class():
return ExemplarsDataset

@staticmethod
def extra_parser(args):
"""Returns a parser containing the approach specific parameters"""
parser = ArgumentParser()
parser.add_argument('--regularization', default='cutmix', required=False,
help='Use regularization (default=%(default)s)')
return parser.parse_known_args(args)

def train_loop(self, t, trn_loader, val_loader):
"""Contains the epochs loop"""
# 1. GDumb resets the network before learning a new task, relying only on the exemplars stored so far
if t == 0:
# Keep the randomly initialized model from time step 0
self.init_model = deepcopy(self.model)
else:
# Reinitialize the model (backbone) for every task from time step 1
self.model.model = deepcopy(self.init_model.model)
for layer in self.model.heads.children():
layer.reset_parameters()

# EXEMPLAR MANAGEMENT -- select training subset from current task and exemplar memory
aux_loader = torch.utils.data.DataLoader(trn_loader.dataset + self.exemplars_dataset,
batch_size=trn_loader.batch_size,
shuffle=True,
num_workers=trn_loader.num_workers,
pin_memory=trn_loader.pin_memory)
self.exemplars_dataset.collect_exemplars(self.model, aux_loader, val_loader.dataset.transform)

# Set new set of exemplars as the only data to train on
trn_loader = torch.utils.data.DataLoader(self.exemplars_dataset,
batch_size=trn_loader.batch_size,
shuffle=True,
num_workers=trn_loader.num_workers,
pin_memory=trn_loader.pin_memory)

# FINETUNING TRAINING -- contains the epochs loop
super().train_loop(t, trn_loader, val_loader)

def train_epoch(self, t, trn_loader):
"""Runs a single epoch"""
self.model.train()
if self.fix_bn and t > 0:
self.model.freeze_bn()
for images, targets in trn_loader:
# Get exemplars
if len(self.exemplars_dataset) > 0:
# 2. Balanced batches
exemplar_indices = torch.randperm(len(self.exemplars_dataset))[:trn_loader.batch_size]
images_exemplars, targets_exemplars = default_collate([self.exemplars_dataset[i]
for i in exemplar_indices])
images = torch.cat((images, images_exemplars), dim=0)
targets = torch.cat((targets, targets_exemplars), dim=0)
Review comment (Contributor) on lines +115 to +123:
@mmasana if the trn_loader is overridden with the exemplars_dataset, why do we do the permutation here instead of simply using the images/targets it yields? It's more like drawing samples with repetition than balanced batches, right?


# 3. Apply cutmix as regularization
do_cutmix = self.regularization == 'cutmix' and np.random.rand(1) < 0.5 # cutmix_prob (Sec.4)
if do_cutmix:
images, targets_a, targets_b, lamb = cutmix_data(x=images, y=targets, alpha=1.0) # cutmix_alpha (Sec.4)
# Forward current model
outputs = self.model(images.to(self.device))
loss = lamb * self.criterion(t, outputs, targets_a.to(self.device))
loss += (1.0 - lamb) * self.criterion(t, outputs, targets_b.to(self.device))
else:
outputs = self.model(images.to(self.device))
loss = self.criterion(t, outputs, targets.to(self.device))

# Backward
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
self.optimizer.step()

def criterion(self, t, outputs, targets):
"""Returns the loss value"""
return torch.nn.functional.cross_entropy(torch.cat(outputs, dim=1), targets)
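The `cutmix_data` helper imported from `utils` is not part of this diff. Below is a minimal sketch of a helper matching the call site above (keyword arguments `x`, `y`, `alpha`; returns the mixed images, both target sets and `lamb`), assuming the standard CutMix recipe rather than the repository's exact implementation:

```python
import numpy as np
import torch


def cutmix_data(x, y, alpha=1.0):
    """Paste a random patch from a shuffled copy of the batch into each image.

    Returns the mixed images, the original targets, the targets of the donor
    images, and the mixing coefficient corrected for the actual patch area.
    """
    lamb = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    targets_a, targets_b = y, y[index]

    # Sample a box covering roughly (1 - lamb) of the image area
    H, W = x.size(2), x.size(3)
    cut_ratio = np.sqrt(1.0 - lamb)
    cut_h, cut_w = int(H * cut_ratio), int(W * cut_ratio)
    cy, cx = np.random.randint(H), np.random.randint(W)
    y1, y2 = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    x1, x2 = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)

    # Overwrite the patch with pixels from the shuffled batch
    x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    # Adjust lamb to the exact fraction of original pixels that survived
    lamb = 1.0 - ((y2 - y1) * (x2 - x1) / (H * W))
    return x, targets_a, targets_b, lamb
```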
8 changes: 5 additions & 3 deletions src/approach/lwm.py
@@ -160,9 +160,11 @@ def eval(self, t, val_loader):

def attention_distillation_loss(self, attention_map1, attention_map2):
"""Calculates the attention distillation loss"""
attention_map1 = torch.norm(attention_map1, p=2, dim=1)
attention_map2 = torch.norm(attention_map2, p=2, dim=1)
return torch.norm(attention_map2 - attention_map1, p=1, dim=1).sum(dim=1).mean()
attention_map1 = torch.nn.functional.normalize(attention_map1.view(attention_map1.size(0), -1),
p=2, dim=1, eps=1e-12, out=None)
attention_map2 = torch.nn.functional.normalize(attention_map2.view(attention_map1.size(0), -1),
p=2, dim=1, eps=1e-12, out=None)
return torch.norm(attention_map2 - attention_map1, p=1, dim=1).mean()

def cross_entropy(self, outputs, targets, exp=1.0, size_average=True, eps=1e-5):
"""Calculates cross-entropy with temperature scaling"""
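For intuition, the rewritten `attention_distillation_loss` above flattens each attention map, L2-normalizes it, and then takes the mean L1 distance between the normalized vectors (LwM applies this to Grad-CAM-style attention maps computed elsewhere in `lwm.py`, not shown in this hunk). A self-contained sketch with toy inputs, shown only to illustrate the behaviour of the new formulation:

```python
import torch
import torch.nn.functional as F


def attention_distillation_loss(attention_map1, attention_map2):
    # Flatten each map to a vector, L2-normalize, then take the mean L1 distance
    am1 = F.normalize(attention_map1.view(attention_map1.size(0), -1), p=2, dim=1)
    am2 = F.normalize(attention_map2.view(attention_map2.size(0), -1), p=2, dim=1)
    return torch.norm(am2 - am1, p=1, dim=1).mean()


# Toy check: identical maps give zero loss, different maps give a positive value
maps_old = torch.rand(8, 1, 7, 7)  # e.g. attention maps from the previous model
maps_new = torch.rand(8, 1, 7, 7)  # attention maps from the current model
print(attention_distillation_loss(maps_old, maps_old))  # expected: 0
print(attention_distillation_loss(maps_old, maps_new))  # expected: > 0
```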
14 changes: 14 additions & 0 deletions src/datasets/dataset_config.py
@@ -26,6 +26,20 @@
'flip': True,
'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023))
},
'cifar100_fixed': {
'path': join(_BASE_DATA_PATH, 'cifar100'),
'resize': None,
'pad': 4,
'crop': 32,
'flip': True,
'normalize': ((0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)),
'class_order': [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82,
83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99
]
},
'cifar100_icarl': {
'path': join(_BASE_DATA_PATH, 'cifar100'),
'resize': None,
2 changes: 2 additions & 0 deletions src/gridsearch_config.py
@@ -65,6 +65,8 @@ def __init__(self):
'r_walk': {
'lamb': [20],
},
'gdumb': {
},
}
self.current_lr = self.params['general']['lr'][0]
self.current_tradeoff = 0
7 changes: 5 additions & 2 deletions src/main_incremental.py
@@ -4,6 +4,7 @@
import argparse
import importlib
import numpy as np
import torch.multiprocessing
from functools import reduce

import utils
@@ -91,8 +92,8 @@ def main(argv=None):
parser.add_argument('--eval-on-train', action='store_true',
help='Show train loss and accuracy (default=%(default)s)')
# gridsearch args
parser.add_argument('--gridsearch-tasks', default=-1, type=int,
help='Number of tasks to apply GridSearch (-1: all tasks) (default=%(default)s)')
parser.add_argument('--gridsearch-tasks', default=0, type=int,
help='Number of tasks to apply GridSearch (default=%(default)s)')

# Args -- Incremental Learning Framework
args, extra_args = parser.parse_known_args(argv)
@@ -120,6 +121,8 @@
else:
print('WARNING: [CUDA unavailable] Using CPU instead!')
device = 'cpu'
# Use the file_system sharing strategy to avoid 'too many open files' errors with many dataloader workers
torch.multiprocessing.set_sharing_strategy('file_system')
# Multiple gpus
# if torch.cuda.device_count() > 1:
# self.C = torch.nn.DataParallel(C)