In [1]:
import argparse

import torch.optim as optim
import torch.utils.data.sampler as sampler

from auto_lambda import AutoLambda
from create_network import *
from create_dataset import *
from create_network import MTANDeepLabv3, MTLDeepLabv3
from utils import *
from misc import genWeights

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class trainerDense:
    """Trainer for dense-prediction multi-task / auxiliary learning experiments.

    Expected call order: construct the object, then ``initialize()``,
    ``define_dataset()``, ``choose_task_weighting()`` and finally ``train()``.
    ``define_dataset()`` must run before ``choose_task_weighting()`` when the
    weighting is ``'autol'`` because that branch builds a second data loader
    from the stored training set and batch size.

    Helper names such as ``create_task_flags``, ``TaskMetric``,
    ``compute_loss``, ``get_weight_str`` and the gradient utilities are not
    defined in this class; they are expected to come from the star imports at
    the top of the notebook.
    """

    def __init__(self,
                mode='none',
                port='none',
                network='split',
                weight='equal',
                grad_method='none',
                gpu=0,
                with_noise=False,
                autol_init=0.1,
                autol_lr=1e-4,
                task='all',
                dataset='nyuv2',
                seed = 0,
                total_epoch = 50):
        """Store the experiment configuration; no heavy work happens here.

        Args:
            mode: stored as-is; not read anywhere else in this class.
            port: stored as-is; not read anywhere else in this class.
            network: ``'split'`` or ``'mtan'`` (see ``initialize``).
            weight: task-weighting scheme that ``train`` dispatches on:
                ``'equal'``, ``'dwa'``, ``'uncert'``, ``'autol'`` or
                ``'combinations'``.
            grad_method: ``'none'``, ``'graddrop'``, ``'pcgrad'`` or ``'cagrad'``.
            gpu: CUDA device index, used only when CUDA is available.
            with_noise: forwarded to ``create_task_flags`` for the training tasks.
            autol_init: initial value handed to ``AutoLambda``.
            autol_lr: learning rate of the Auto-Lambda meta optimizer.
            task: primary task name forwarded to ``create_task_flags``.
            dataset: dataset name (``'nyuv2'`` or ``'cityscapes'``).
            seed: RNG seed; converted with ``int()``.
            total_epoch: number of epochs per training run.
        """
        self.mode = mode
        self.port = port
        self.network = network
        self.weight = weight
        self.grad_method = grad_method
        self.gpu = gpu
        self.with_noise = with_noise
        self.autol_init = autol_init
        self.autol_lr = autol_lr
        self.task = task
        self.dataset = dataset
        self.seed = int(seed)
        self.device = None
        self.train_tasks = None

        self.total_epoch = total_epoch

    def initialize(self):
        """Seed the RNGs, create the logging folder, task flags and the model.

        Raises:
            ValueError: if ``self.network`` is neither ``'split'`` nor ``'mtan'``.
        """
        # Use the configured seed (the default of 0 keeps the previous behaviour).
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        # Folder where train() stores the per-run loss/weight logs.
        os.makedirs('logging', exist_ok=True)

        self.device = torch.device("cuda:{}".format(int(self.gpu)) if torch.cuda.is_available() else "cpu")

        # Training always covers all tasks; the primary task(s) may be a subset.
        self.train_tasks = create_task_flags('all', self.dataset, with_noise=bool(self.with_noise))
        self.pri_tasks = create_task_flags(self.task, self.dataset, with_noise=False)

        # Join the individual task names (not self.task repeated once per task).
        train_tasks_str = ' + '.join(task.title() for task in self.train_tasks.keys())
        pri_tasks_str = ' + '.join(task.title() for task in self.pri_tasks.keys())

        print('Dataset: {} | Training Task: {} | Primary Task: {} in Multi-task / Auxiliary Learning Mode with {}'
            .format(self.dataset.title(), train_tasks_str, pri_tasks_str, self.network.upper()))
        print('Applying Multi-task Methods: Weighting-based: {} + Gradient-based: {}'
            .format(self.weight.title(), self.grad_method.upper()))

        if self.network == 'split':
            self.model = MTLDeepLabv3(self.train_tasks).to(self.device)
        elif self.network == 'mtan':
            self.model = MTANDeepLabv3(self.train_tasks).to(self.device)
        else:
            raise ValueError("Unknown network '{}'; expected 'split' or 'mtan'".format(self.network))

    def choose_task_weighting(self, weight):
        """Create the weighting state, optimizer and scheduler for ``weight``.

        ``train`` dispatches on ``self.weight``, so ``weight`` should match the
        value given to the constructor.

        Args:
            weight: ``'uncert'``, ``'dwa'``, ``'equal'``, ``'autol'`` or
                ``'combinations'``.

        Raises:
            ValueError: for an unknown ``weight``.
            RuntimeError: for ``'autol'`` when ``define_dataset`` has not run yet.
        """
        num_tasks = len(self.train_tasks)

        if weight == 'uncert':
            # Kept on self because train() reads self.logsigma for the loss and the log.
            self.logsigma = torch.tensor([-0.7] * num_tasks, requires_grad=True, device=self.device)
            self.params = list(self.model.parameters()) + [self.logsigma]
            self.logsigma_ls = np.zeros([self.total_epoch, num_tasks], dtype=np.float32)

        elif weight in ['dwa', 'equal']:
            self.T = 2.0  # temperature used by DWA
            self.lambda_weight = np.ones([self.total_epoch, num_tasks])
            self.params = self.model.parameters()

        elif weight == 'autol':
            if not hasattr(self, 'train_set') or not hasattr(self, 'batch_size'):
                raise RuntimeError("call define_dataset() before choose_task_weighting('autol')")
            self.params = self.model.parameters()
            self.autol = AutoLambda(self.model, self.device, self.train_tasks, self.pri_tasks, self.autol_init)
            self.meta_weight_ls = np.zeros([self.total_epoch, num_tasks], dtype=np.float32)
            self.meta_optimizer = optim.Adam([self.autol.meta_weights], lr=self.autol_lr)

        elif weight == 'combinations':
            # Used to try out several fixed weight combinations.
            nw = 4  # number of weights handed to genWeights
            self.my_weight = genWeights(nw, num_tasks, self.device)
            self.lambda_weight = np.ones([self.total_epoch, num_tasks])
            print('primer lambda ', self.lambda_weight.shape)
            self.params = self.model.parameters()

        else:
            raise ValueError("Unknown task weighting '{}'".format(weight))

        self.optimizer = optim.SGD(self.params, lr=0.1, weight_decay=1e-4, momentum=0.9)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.total_epoch)

        # A second loader over the training set with its own shuffling order,
        # used for the Auto-Lambda meta update.
        if weight == 'autol':
            self.val_loader = torch.utils.data.DataLoader(
                dataset=self.train_set,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=4
            )

    def define_dataset(self, dataset):
        """Build the train/test datasets and their data loaders.

        Stores ``train_set``, ``test_set``, ``batch_size``, ``train_loader`` and
        ``test_loader`` on the instance.

        Args:
            dataset: ``'nyuv2'`` or ``'cityscapes'``.

        Raises:
            ValueError: for an unknown dataset name.
        """
        if dataset == 'nyuv2':
            dataset_path = 'dataset/nyuv2'
            train_set = NYUv2(root=dataset_path, train=True, augmentation=True)
            test_set = NYUv2(root=dataset_path, train=False)
            self.batch_size = 4
        elif dataset == 'cityscapes':
            dataset_path = 'dataset/cityscapes'
            train_set = CityScapes(root=dataset_path, train=True, augmentation=True)
            test_set = CityScapes(root=dataset_path, train=False)
            self.batch_size = 4
        else:
            raise ValueError("Unknown dataset '{}'; expected 'nyuv2' or 'cityscapes'".format(dataset))

        # Kept on self so the 'autol' weighting can build its extra loader.
        self.train_set = train_set
        self.test_set = test_set

        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4
        )

        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            shuffle=False
        )

    def apply_gradient_methods(self, grad_method):
        """Allocate the buffers needed by the gradient-based methods.

        Does nothing when ``grad_method`` is ``'none'``. Otherwise it records the
        number of elements of every parameter of the model's shared modules and
        allocates one gradient column per training task.
        """
        if grad_method != 'none':
            self.rng = np.random.default_rng()
            self.grad_dims = []
            for mm in self.model.shared_modules():
                for param in mm.parameters():
                    self.grad_dims.append(param.data.numel())
            # One flattened shared-gradient column per task.
            self.grads = torch.zeros(sum(self.grad_dims), len(self.train_tasks), device=self.device)

    def update_weights(self, weights, index):
        """Broadcast one weight combination to every epoch row of ``lambda_weight``.

        Only acts when ``self.weight == 'combinations'``. ``index`` is accepted
        for interface compatibility but is not used.
        """
        if self.weight == 'combinations':
            self.lambda_weight = np.tile(weights, (self.total_epoch, 1))

    def train(self):
        """Train and evaluate the multi-task network.

        For ``weight == 'combinations'`` one full run of ``total_epoch`` epochs is
        executed per column of ``self.my_weight``; otherwise a single run is
        executed. After each run the collected metrics and weights are saved
        under ``logging/``. In combinations mode the file name carries the
        combination index so that later combinations do not overwrite earlier
        ones.

        Raises:
            ValueError: if ``self.grad_method`` is not a supported method.
        """
        if self.grad_method not in ('none', 'graddrop', 'pcgrad', 'cagrad'):
            # Previously an unknown method silently trained nothing (no optimizer step).
            raise ValueError("Unknown grad_method '{}'".format(self.grad_method))

        if self.weight == 'combinations':
            combos = self.my_weight
        else:
            combos = np.ones([1, 1])

        num_tasks = len(self.train_tasks)

        # Iterate over the weight combinations, or run once if there are none.
        # NOTE(review): model, optimizer and scheduler are not re-created between
        # combinations, so each combination continues from the state left by the
        # previous one - confirm this is intended.
        for j in range(combos.shape[-1]):

            self.apply_gradient_methods(self.grad_method)

            self.train_batch = len(self.train_loader)
            self.test_batch = len(self.test_loader)
            self.train_metric = TaskMetric(self.train_tasks, self.pri_tasks, self.batch_size, self.total_epoch, self.dataset)
            self.test_metric = TaskMetric(self.train_tasks, self.pri_tasks, self.batch_size, self.total_epoch, self.dataset, include_mtl=True)

            print('\nProbando combinación de pesos: ', combos[:, j])
            if self.weight == 'combinations':
                # Move to host memory first: Tensor.numpy() only works on CPU tensors.
                combo = combos[:, j].detach().cpu().numpy()
                self.update_weights(combo, j)
                print('lambda ', self.lambda_weight)
                print('lambda shape', self.lambda_weight.shape)
                print('w[:, j].numpy() ', combo)

            log = None  # filled at the end of every epoch, saved after the run

            for index in range(self.total_epoch):
                print('Epoca: ', index)

                # Dynamic Weight Average
                if self.weight == 'dwa':
                    if index == 0 or index == 1:
                        self.lambda_weight[index, :] = 1.0
                    else:
                        # Ratio between the entries of the two previous epochs in
                        # column 0 of the training metric (presumably the loss;
                        # verify against TaskMetric).
                        ratios = [
                            self.train_metric.metric[t][index - 1, 0] / self.train_metric.metric[t][index - 2, 0]
                            for t in self.train_tasks
                        ]
                        ratios = torch.softmax(torch.tensor(ratios) / self.T, dim=0)
                        self.lambda_weight[index] = num_tasks * ratios.numpy()

                # Iterate over all training batches.
                self.model.train()
                train_dataset = iter(self.train_loader)

                if self.weight == 'autol':
                    val_dataset = iter(self.val_loader)

                for k in range(self.train_batch):
                    train_data, train_target = next(train_dataset)
                    train_data = train_data.to(self.device)
                    train_target = {task_id: train_target[task_id].to(self.device) for task_id in self.train_tasks.keys()}

                    # Update the meta weights with Auto-Lambda.
                    if self.weight == 'autol':
                        val_data, val_target = next(val_dataset)
                        val_data = val_data.to(self.device)
                        val_target = {task_id: val_target[task_id].to(self.device) for task_id in self.train_tasks.keys()}

                        self.meta_optimizer.zero_grad()
                        self.autol.unrolled_backward(train_data, train_target, val_data, val_target,
                                                    self.scheduler.get_last_lr()[0], self.optimizer)
                        self.meta_optimizer.step()

                    # Update the network parameters using the task weights.
                    self.optimizer.zero_grad()
                    train_pred = self.model(train_data)
                    train_loss = [compute_loss(train_pred[i], train_target[task_id], task_id) for i, task_id in enumerate(self.train_tasks)]

                    train_loss_tmp = [0] * num_tasks

                    if self.weight in ['equal', 'dwa', 'combinations']:
                        train_loss_tmp = [lam * train_loss[i] for i, lam in enumerate(self.lambda_weight[index])]

                    if self.weight == 'uncert':
                        train_loss_tmp = [1 / (2 * torch.exp(ls)) * train_loss[i] + ls / 2 for i, ls in enumerate(self.logsigma)]

                    if self.weight == 'autol':
                        train_loss_tmp = [mw * train_loss[i] for i, mw in enumerate(self.autol.meta_weights)]

                    loss = sum(train_loss_tmp)

                    if self.grad_method == 'none':
                        loss.backward()
                        self.optimizer.step()
                    else:
                        # Gradient-based methods: collect the shared-module gradient
                        # of every task, combine them, write the result back.
                        # grad2vec / graddrop / pcgrad / cagrad / overwrite_grad are
                        # not methods of this class; they are expected to be
                        # module-level helpers from `from utils import *` - verify.
                        for i in range(num_tasks):
                            train_loss_tmp[i].backward(retain_graph=True)
                            grad2vec(self.model, self.grads, self.grad_dims, i)
                            self.model.zero_grad_shared_modules()

                        if self.grad_method == 'graddrop':
                            g = graddrop(self.grads)
                        elif self.grad_method == 'pcgrad':
                            g = pcgrad(self.grads, self.rng, num_tasks)
                        else:  # 'cagrad'
                            g = cagrad(self.grads, num_tasks, 0.4, rescale=1)

                        overwrite_grad(self.model, g, self.grad_dims, num_tasks)
                        self.optimizer.step()

                    self.train_metric.update_metric(train_pred, train_target, train_loss)

                train_str = self.train_metric.compute_metric()
                self.train_metric.reset()

                # Evaluate on the test data.
                self.model.eval()
                with torch.no_grad():
                    test_dataset = iter(self.test_loader)
                    for k in range(self.test_batch):
                        test_data, test_target = next(test_dataset)
                        test_data = test_data.to(self.device)
                        test_target = {task_id: test_target[task_id].to(self.device) for task_id in self.train_tasks.keys()}

                        test_pred = self.model(test_data)
                        test_loss = [compute_loss(test_pred[i], test_target[task_id], task_id) for i, task_id in enumerate(self.train_tasks)]

                        self.test_metric.update_metric(test_pred, test_target, test_loss)

                test_str = self.test_metric.compute_metric()
                self.test_metric.reset()

                self.scheduler.step()

                print('Epoch {:04d} | TRAIN:{} || TEST:{} | Best: {} {:.4f}'
                    .format(index, train_str, test_str, self.task.title(), self.test_metric.get_best_performance(self.task)))

                if self.weight == 'autol':
                    self.meta_weight_ls[index] = self.autol.meta_weights.detach().cpu().numpy()
                    log = {'train_loss': self.train_metric.metric, 'test_loss': self.test_metric.metric,
                           'weight': self.meta_weight_ls}

                    print(get_weight_str(self.meta_weight_ls[index], self.train_tasks))

                if self.weight in ['dwa', 'equal', 'combinations']:
                    log = {'train_loss': self.train_metric.metric, 'test_loss': self.test_metric.metric,
                           'weight': self.lambda_weight}

                    print(get_weight_str(self.lambda_weight[index], self.train_tasks))

                if self.weight == 'uncert':
                    self.logsigma_ls[index] = self.logsigma.detach().cpu().numpy()
                    log = {'train_loss': self.train_metric.metric, 'test_loss': self.test_metric.metric,
                           'weight': self.logsigma_ls}

                    print(get_weight_str(1 / (2 * np.exp(self.logsigma_ls[index])), self.train_tasks))

            if log is not None:
                run_tag = '{}_{}_{}_{}_{}_{}_'.format(
                    self.network, self.dataset, self.task, self.weight, self.grad_method, self.seed)
                if self.weight == 'combinations':
                    # One file per combination; a single shared name would be
                    # overwritten by every later combination.
                    run_tag += 'comb{}_'.format(j)
                np.save('logging/mtl_dense_{}.npy'.format(run_tag), log)





In [3]:
# Short smoke run: 2 epochs per weight combination.
t = trainerDense(weight='combinations', total_epoch=2)
t.initialize()
# NOTE: calling choose_task_weighting before define_dataset is fine for
# 'combinations'; the 'autol' branch reads self.batch_size, which
# define_dataset sets, so for 'autol' the dataset must be defined first.
t.choose_task_weighting(weight='combinations')
t.define_dataset(dataset='nyuv2')

Dataset: Nyuv2 | Training Task: All + All + All | Primary Task: All + All + All in Multi-task / Auxiliary Learning Mode with SPLIT
Applying Multi-task Methods: Weighting-based: Combinations + Gradient-based: NONE
primer lambda  (2, 3)


In [4]:
# Run training and evaluation; with weight='combinations' this performs one
# full run of total_epoch epochs per weight combination.
t.train()


Probando combinación de pesos:  tensor([0., 0., 1.], dtype=torch.float64)
lambda  [[0. 0. 1.]
 [0. 0. 1.]]
lambda shape (2, 3)
w[:, j].numpy()  [0. 0. 1.]
Epoca:  0


Epoch 0000 | TRAIN: Seg 2.6003 0.0252 Depth 2.3921 2.3921 Normal 1.0794 43.4219 || TEST: Seg 2.6041 0.0273 Depth 2.7049 2.7049 Normal 1.0318 41.5606 | All -1.9901 | Best: All -1.9901
Task Weighting | Seg 0.0000 Depth 0.0000 Normal 1.0000 
Epoca:  1
Epoch 0001 | TRAIN: Seg 2.6008 0.0248 Depth 2.3801 2.3801 Normal 1.0142 40.9814 || TEST: Seg 4.9981 0.0197 Depth 2.7440 2.7440 Normal 1.1098 44.7186 | All -2.0679 | Best: All -1.9901
Task Weighting | Seg 0.0000 Depth 0.0000 Normal 1.0000 

Probando combinación de pesos:  tensor([0., 1., 0.], dtype=torch.float64)
lambda  [[0. 1. 0.]
 [0. 1. 0.]]
lambda shape (2, 3)
w[:, j].numpy()  [0. 1. 0.]
Epoca:  0
Epoch 0000 | TRAIN: Seg 2.6058 0.0242 Depth 2.4177 2.4177 Normal 0.9947 40.2480 || TEST: Seg 2.6055 0.0250 Depth 2.7202 2.7202 Normal 0.9826 39.6289 | All -1.9729 | Best: All -1.9729
Task Weighting | Seg 0.0000 Depth 1.0000 Normal 0.0000 
Epoca:  1
Epoch 0001 | TRAIN: Seg 2.5901 0.0294 Depth 1.8378 1.8378 Normal 1.0668 43.1756 || TEST: Seg 2.61