Merge pull request #37 from RasmusOrsoe/main
Added trainer for multiple datasets (robustness test)
RasmusOrsoe committed Oct 21, 2021
2 parents 675fa11 + 4eb0050 commit c2cf966
Showing 1 changed file with 21 additions and 33 deletions.

src/gnn_reco/components/utils.py
@@ -69,12 +69,17 @@ def GetBestParams(self):
         return self.best_params
 
 class PiecewiseLinearScheduler(object):
-    def __init__(self, training_dataset, start_lr, max_lr, end_lr, max_epochs):
+    def __init__(self, training_dataset_length, start_lr, max_lr, end_lr, max_epochs):
+        try:
+            self.dataset_length = len(training_dataset_length)
+            print('Passing a dataset as training_dataset_length to PiecewiseLinearScheduler is deprecated. Please pass an integer.')
+        except TypeError:
+            self.dataset_length = training_dataset_length
         self._start_lr = start_lr
         self._max_lr = max_lr
         self._end_lr = end_lr
-        self._steps_up = int(len(training_dataset)/2)
-        self._steps_down = len(training_dataset)*max_epochs - self._steps_up
+        self._steps_up = int(self.dataset_length/2)
+        self._steps_down = self.dataset_length*max_epochs - self._steps_up
         self._current_step = 0
         self._lr_list = self._calculate_lr_list()
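
The scheduler now takes the number of training batches as a plain integer; the try/except shim keeps old call sites that pass a dataset or dataloader working, with a deprecation notice. A minimal usage sketch (the dataset, loader, and learning-rate values below are illustrative, not from this commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(10000, 4))      # stand-in dataset
train_loader = DataLoader(dataset, batch_size=128)

# New style: pass the batch count as an integer.
scheduler = PiecewiseLinearScheduler(
    training_dataset_length=len(train_loader),
    start_lr=1e-5, max_lr=1e-3, end_lr=1e-6, max_epochs=30)

# Old style (deprecated but still accepted): passing the loader itself.
# scheduler = PiecewiseLinearScheduler(train_loader, 1e-5, 1e-3, 1e-6, 30)
```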

@@ -94,15 +99,12 @@ def get_next_lr(self):
         self._current_step = self._current_step + 1
         return lr
 
-class MultipleDatabasesTrainer(object):
-    def __init__(self, databases, selections, pulsemap, batch_size, FEATURES, TRUTH, num_workers, optimizer, n_epochs, loss_func, target, device, scheduler=None, patience=10, early_stopping=True):
-        self.databases = databases
-        self.selections = selections
-        self.pulsemap = pulsemap
-        self.batch_size = batch_size
-        self.FEATURES = FEATURES
-        self.TRUTH = TRUTH
-        self.num_workers = num_workers
+class MultipleDatasetsTrainer(object):
+    def __init__(self, training_loaders, validation_loaders, num_training_batches, num_validation_batches, optimizer, n_epochs, loss_func, target, device, scheduler=None, patience=10, early_stopping=True):
+        self.validation_dataloaders = validation_loaders
+        self.training_dataloaders = training_loaders
+        self.num_training_batches = num_training_batches
+        self.num_validation_batches = num_validation_batches
         if early_stopping:
             self._early_stopping_method = EarlyStopping(patience=patience)
         self.optimizer = optimizer
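
With this change the trainer consumes prebuilt dataloaders and precomputed batch counts instead of database paths. A wiring sketch, assuming `train_loaders` and `val_loaders` are lists of dataloaders (see the setup sketch after the next hunk) and that `optimizer`, `loss_func`, and `model` follow the surrounding code; the target name and epoch count are illustrative:

```python
num_training_batches = sum(len(loader) for loader in train_loaders)
num_validation_batches = sum(len(loader) for loader in val_loaders)

trainer = MultipleDatasetsTrainer(
    training_loaders=train_loaders,
    validation_loaders=val_loaders,
    num_training_batches=num_training_batches,
    num_validation_batches=num_validation_batches,
    optimizer=optimizer,
    n_epochs=30,                 # illustrative
    loss_func=loss_func,
    target='energy',             # illustrative target name
    device='cuda:0',
    patience=10)

trained_model = trainer(model)   # __call__ trains and loads the best parameters
```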
@@ -112,31 +114,13 @@ def __init__(self, databases, selections, pulsemap, batch_size, FEATURES, TRUTH,
         self.target = target
         self.device = device
 
-        self._setup_dataloaders()
-
     def __call__(self, model):
         trained_model = self._train(model)
         self._load_best_parameters(model)
         return trained_model
 
-    def _setup_dataloaders(self):
-        self.training_dataloaders = []
-        self.validation_dataloaders = []
-        for i in range(len(self.databases)):
-            db = self.databases[i]
-            selection = self.selections[i]
-            training_dataloader, validation_dataloader = make_train_validation_dataloader(db, selection, self.pulsemap, self.batch_size, self.FEATURES, self.TRUTH, self.num_workers)
-            self.training_dataloaders.append(training_dataloader)
-            self.validation_dataloaders.append(validation_dataloader)
-        return
-
-    def _count_minibatches(self):
-        training_batches = 0
-        for i in range(len(self.training_dataloaders)):
-            training_batches += len(self.training_dataloaders[i])
-        return training_batches
-
     def _train(self, model):
-        training_batches = self._count_minibatches()
+        training_batches = self.num_training_batches
         for epoch in range(self.n_epochs):
             acc_loss = torch.tensor([0], dtype=float).to(self.device)
             iteration = 1
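
The dataloader construction removed here moves to the caller. Assuming `make_train_validation_dataloader` keeps the signature shown in the deleted code, the external setup could look like:

```python
train_loaders, val_loaders = [], []
for db, selection in zip(databases, selections):
    training_dataloader, validation_dataloader = make_train_validation_dataloader(
        db, selection, pulsemap, batch_size, FEATURES, TRUTH, num_workers)
    train_loaders.append(training_dataloader)
    val_loaders.append(validation_dataloader)

# The batch counts handed to the trainer are then the summed loader
# lengths, as in the wiring sketch above.
```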
@@ -155,18 +139,21 @@ def _train(self, model):
                         self.optimizer.param_groups[0]['lr'] = self.scheduler.get_next_lr().item()
                 acc_loss += loss
                 iteration += 1
-                pbar.update(iteration)
+                pbar.update(1)
                 pbar.set_description('epoch: %s || loss: %s' % (epoch, acc_loss.item()/iteration))
             validation_loss = self._validate(model)
             pbar.set_description('epoch: %s || loss: %s || valid loss: %s' % (epoch, acc_loss.item()/iteration, validation_loss.item()))
             if self._early_stopping_method.step(validation_loss, model):
                 print('EARLY STOPPING: %s' % epoch)
                 break
         return model
 
     def _validate(self, model):
         acc_loss = torch.tensor([0], dtype=float).to(self.device)
         model.eval()
         iterations = 1
+        pbar_valid = tqdm(total=self.num_validation_batches, unit='batches')
+        pbar_valid.set_description('Validating..')
         for validation_dataloader in self.validation_dataloaders:
             for batch_of_graphs in validation_dataloader:
                 batch_of_graphs.to(self.device)
@@ -175,6 +162,7 @@ def _validate(self, model):
                 loss = self.loss_func(out, batch_of_graphs, self.target)
                 acc_loss += loss
                 iterations += 1
+                pbar_valid.update(1)
         return acc_loss/iterations
 
     def _load_best_parameters(self, model):
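
The progress-bar changes fix a genuine bug: `tqdm.update(n)` advances the bar by `n` steps, so `pbar.update(iteration)` with a growing counter overshoots the total, while `pbar.update(1)` advances exactly one batch per minibatch. A standalone illustration:

```python
from tqdm import tqdm

total_batches = 100
pbar = tqdm(total=total_batches, unit='batches')
for _ in range(total_batches):
    # ... process one batch ...
    pbar.update(1)   # advance by exactly one batch; update(iteration) would overshoot
pbar.close()
```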
@@ -249,7 +237,7 @@ def __init__(self, dataloader, target, device, output_column_names, post_process
         self.device = device
         self.post_processing_method = post_processing_method
     def __call__(self, model):
-        self.model = model
+        self.model = model.to(self.device)
         self.model.eval()
         self.model.predict = True
         if self.post_processing_method == None:
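
The predictor fix moves the model onto the target device before inference, so a model restored on CPU no longer fails when it receives GPU tensors (or vice versa). A generic, self-contained sketch of the pattern; the model and batch are stand-ins, not from this repository:

```python
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = torch.nn.Linear(4, 1)        # stand-in model for illustration
batch = torch.randn(8, 4)            # stand-in batch for illustration

model = model.to(device)             # the same fix as in the diff: move weights first
model.eval()                         # inference mode
with torch.no_grad():
    out = model(batch.to(device))    # inputs must live on the same device
```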
