In [None]:
    # Create (if missing) the log directory of the run named args.name.
    cur_logs_path = os.path.join(LOGS_DIR, args.name)
    os.makedirs(cur_logs_path, exist_ok=True)
    
    # Create (if missing) the checkpoint directory of the same run.
    cur_ckpt_path = os.path.join(CKPT_DIR, args.name)
    os.makedirs(cur_ckpt_path, exist_ok=True)

In [None]:
# Experiment configuration consumed by BaseExpRunner.
#
# Each models_dict entry holds the network object under 'model' and an
# optional checkpoint path under 'load_ckpt'. BaseExpRunner.load_ckpts()
# skips an entry whose 'load_ckpt' is None, so None means "do not restore".
models_dict = {
    'mmnet': {
        'model': mmnet,
        # TODO: the original cell left this value blank; None = no restore.
        'load_ckpt': None,
    },
    'faceid': {
        'model': faceid,
        'load_ckpt': "/home/safin/FaceReID/ckpt/joint_dnfr_16.04/faceid/weights_30",
    },
    # NOTE(review): the key reads 'argcmargin' while the model is `arcmargin`,
    # and the checkpoint path below is the faceid one -- both look like
    # copy-paste slips; confirm before relying on either.
    'argcmargin': {
        'model': arcmargin,
        'load_ckpt': "/home/safin/FaceReID/ckpt/joint_dnfr_16.04/faceid/weights_30",
    },
}

# TODO: the loss objects were never filled in; None is a placeholder only.
losses_dict = {
    'L1': None,
    'FaceID': None,
}

# TODO: the optimizer object was never filled in; None is a placeholder only.
optimizers_dict = {
    'general': None,
}

log_dict = {}

class BaseExpRunner():
    """Skeleton of a multi-model training experiment.

    The runner restores models from checkpoints, prepares per-experiment
    checkpoint/log directories and drives the epoch loop. Subclasses supply
    the per-batch work by overriding global_forward().

    A module-level ``stop_flag`` (if one exists and is truthy) interrupts
    training between batches and between epochs.
    """

    def __init__(self, name, models_dict, schedulers_dict, optimizers_dict,
                 losses_dict, multigpu_mode=False):
        """
        Args:
            name (str): Experiment name; used as the sub-directory of
                LOGS_DIR / CKPT_DIR and as a suffix of log file names.
            models_dict (dict): model name -> dict with a 'model' entry and
                an optional 'load_ckpt' path. A 'ckpt_path' entry is added
                to every inner dict by create_ckpt_dirs().
            schedulers_dict (dict): name -> object exposing step().
            optimizers_dict (dict): stored as-is; not used by this base class.
            losses_dict (dict): stored as-is; not used by this base class.
            multigpu_mode (bool): forwarded unchanged to load_model() and
                save_model().
        """
        self.name = name
        self.cur_logs_path = os.path.join(LOGS_DIR, name)
        self.cur_ckpt_path = os.path.join(CKPT_DIR, name)

        self.models_dict = models_dict
        self.schedulers_dict = schedulers_dict
        self.optimizers_dict = optimizers_dict
        self.losses_dict = losses_dict
        self.logs_dict = {}
        self.multigpu_mode = multigpu_mode

        # Number of epochs completed so far; also numbers the checkpoints.
        self.cur_epoch = 0

        self.load_ckpts()
        self.create_ckpt_dirs()
        self.create_log_dirs()

    @staticmethod
    def _stop_requested():
        """Return True when a module-level ``stop_flag`` exists and is truthy."""
        return bool(globals().get('stop_flag', False))

    def load_ckpts(self):
        """Replace each model that has a 'load_ckpt' path with load_model()'s result."""
        for model_dict in self.models_dict.values():
            model_ckpt_path = model_dict.get('load_ckpt')
            if model_ckpt_path is not None:
                model_dict['model'] = load_model(
                    model_dict['model'], model_ckpt_path, self.multigpu_mode)

    def create_ckpt_dirs(self):
        """Create <cur_ckpt_path>/<model name>/ and record it as 'ckpt_path'."""
        for model_name, model_dict in self.models_dict.items():
            model_ckpt_path = os.path.join(self.cur_ckpt_path, model_name)
            model_dict['ckpt_path'] = model_ckpt_path
            os.makedirs(model_ckpt_path, exist_ok=True)

    def create_log_dirs(self):
        """Create the log directory and register one log record per model."""
        os.makedirs(self.cur_logs_path, exist_ok=True)
        for model_name in self.models_dict:
            # The model name is part of the file name so that the records do
            # not overwrite one another when save_logs() writes them out.
            log_file = model_name + "_loss_" + self.name
            self.logs_dict[model_name] = {
                'path': os.path.join(self.cur_logs_path, log_file),
                'data': [],
            }

    def save_ckpt(self, epoch):
        """Save every model to <ckpt_path>/weights_<epoch> via save_model()."""
        for model_dict in self.models_dict.values():
            weights_path = os.path.join(model_dict['ckpt_path'], "weights_%d" % epoch)
            save_model(model_dict['model'], weights_path, self.multigpu_mode)

    def save_logs(self):
        """Write each log record's 'data' list to its 'path' with np.save."""
        for log in self.logs_dict.values():
            np.save(log['path'], np.asarray(log['data']))

    def global_forward(self, sample, batch_idx):
        """Process one batch; must be overridden by subclasses."""
        raise NotImplementedError("Each trainer should define global_forward() method")

    def schedulers_step(self):
        """Call step() on every registered scheduler."""
        for scheduler in self.schedulers_dict.values():
            scheduler.step()

    def train_epoch(self, train_dataloader):
        """Feed every batch of train_dataloader to global_forward()."""
        for batch_idx, sample in enumerate(train_dataloader):
            if self._stop_requested():
                break
            self.global_forward(sample, batch_idx)

    def train(self, train_loader, n_epochs):
        """Run up to n_epochs epochs, saving checkpoints and logs after each.

        Checkpoints are numbered with self.cur_epoch, so a second call to
        train() continues the numbering instead of overwriting weights_0...
        """
        for _ in range(n_epochs):
            if self._stop_requested():
                break
            # NOTE(review): schedulers are stepped before the epoch's
            # optimizer updates, as in the original; PyTorch >= 1.1 expects
            # scheduler.step() after them -- confirm the intended order.
            self.schedulers_step()
            self.train_epoch(train_loader)

            self.save_ckpt(self.cur_epoch)
            self.save_logs()
            self.cur_epoch += 1

In [None]:
import torch
# from tensorboardX import SummaryWriter
import os
from utils import prepare_path
from networks import networks_dict

# Root directories (relative paths) under which Trainer creates the
# per-experiment log and checkpoint folders.
logs_dir = "logs"
ckpt_dir = "ckpt"

# Model name -> network type. Trainer.init_models() looks the values of such
# a mapping up in networks_dict; presumably this dict is what gets passed to
# Trainer -- verify against the caller.
models = {
    "faceid": "sphereface",
    "denoiser": "udnet"
}

class Trainer():
    """Training harness (work in progress).

    Two designs are interleaved in this class: a multi-model registry
    (self.models_dict / self.opt_dict) and a single-model loop built around
    self.model / self.criterion / self.opt.

    NOTE(review): nothing in this class assigns self.model, defines
    validate(), or defines the free names `optimizer`, `load_model`,
    `Brains` and `dice_score`; they must come from elsewhere before
    init_models() / run_experiments() can run end to end.
    """

    def __init__(self, models, exp_name, losses, dataloader):
        """
        Args:
            models (dict): Maps a model name to a key of ``networks_dict``.
            exp_name (str): Experiment name; sub-directory of ``ckpt_dir``.
            losses: Stored as-is; not used by the methods of this class.
            dataloader: Stored as-is; not used by the methods of this class.
        """
        self.model_is_initialized = False
        self.models = models
        self.losses = losses
        self.dataloader = dataloader
        self.models_dict = {}
        self.models_ckpt_path = {}
        self.ckpt_folder_path = os.path.join(ckpt_dir, exp_name)

        # Optimizer name -> iterable of model names whose parameters it
        # updates; read by init_models(), which fills self.opt_dict.
        self.opts = {}
        self.opt_dict = {}

        # TensorBoard-style writer; stays None (logging disabled) because the
        # SummaryWriter import and construction are commented out.
        self.writer = None
        # Overwritten by init_models() according to its `cuda` argument.
        self.device = torch.device('cuda')

    def init_models(self, cuda=True, last_ckpt=None):
        """Instantiate the networks and optimizers and optionally restore weights.

        Args:
            cuda (bool): Place models and the criterion on the GPU when True,
                on the CPU otherwise.
            last_ckpt (str | None): Path of a state dict for self.model; when
                given, the per-model restore loop runs as well.
        """
        import itertools  # local: the visible import cell does not provide it

        self.device = torch.device('cuda' if cuda else 'cpu')

        for model_name, model_type in self.models.items():
            self.models_dict[model_name] = networks_dict[model_type]()

        for opt_name, model_names in self.opts.items():
            param_groups = [self.models_dict[model_name].parameters()
                            for model_name in model_names]
            self.opt_dict[opt_name] = torch.optim.Adam(
                itertools.chain(*param_groups), lr=1e-4)

        if last_ckpt is not None:
            for model_name, model in self.models.items():
                ckpt_path = os.path.join(self.ckpt_folder_path, model_name)
                # NOTE(review): `model` is the network *type* taken from
                # self.models and `optimizer` is not defined in this scope;
                # the intended arguments could not be determined.
                load_model(model, optimizer, ckpt_path)
            state_dict = torch.load(last_ckpt)
            self.model.load_state_dict(state_dict)

        for k in self.models_dict:
            self.models_dict[k] = self.models_dict[k].to(self.device)

        # NOTE(review): self.model is not assigned anywhere in this class.
        self.model = self.model.to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss(
            weight=torch.tensor([1., 8.])).to(self.device)
        self.opt = torch.optim.Adam(self.model.parameters(), lr=1e-4)

        self.model_is_initialized = True

    def run_experiments(self, exp_name, n_epochs, batch_size=2):
        """Train for n_epochs, validating after each and saving every 10th epoch.

        Args:
            exp_name (str): Name of the run's log and checkpoint directories.
            n_epochs (int): Number of epochs.
            batch_size (int): Batch size of the training loader.
        """
        assert self.model_is_initialized, "Model had not been initialized! Use init_models()!"

        cur_logs_path = os.path.join(logs_dir, exp_name)
        os.makedirs(cur_logs_path, exist_ok=True)

#         self.writer = SummaryWriter(cur_logs_path)

        cur_ckpt_path = os.path.join(ckpt_dir, exp_name)
        os.makedirs(cur_ckpt_path, exist_ok=True)

        for k in self.models_dict:
            model_ckpt_path = os.path.join(cur_ckpt_path, k)
            os.makedirs(model_ckpt_path, exist_ok=True)
            self.models_ckpt_path[k] = model_ckpt_path

        train_data = Brains(crop_size=(64,64), proper_crop_proba=0.5)
        train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=20)

        val_data = Brains(folder='../scanmasks-val/', crop=False, proper_crop_proba=-1)
        val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=True, drop_last=False, num_workers=0)

        global_train_i = 0
        for i in range(n_epochs):
            print("Training: epoch %d" % i)
            _, global_train_i = self.train_epoch(i, train_dataloader, global_train_i)
            # NOTE(review): validate() is not defined in this class.
            self.validate(i, val_dataloader)

            if i % 10 == 0:
                # Saved inside the run's checkpoint directory created above.
                torch.save(self.model.state_dict(),
                           os.path.join(cur_ckpt_path, "model_%d_epochs" % i))

    def train_epoch(self, n_epoch, dataloader, i):
        """Run one optimisation pass over dataloader.

        Args:
            n_epoch (int): Epoch index used as the step of per-epoch scalars.
            dataloader: Iterable of (X_batch, masks_batch) pairs.
            i (int): Global iteration counter used as the step of
                per-iteration scalars.

        Returns:
            tuple: (mean cross-entropy loss over the epoch, updated counter).
        """
        self.model.train(True)

        epoch_ce_loss = []
        epoch_dice = []

        for X_batch, masks_batch in dataloader:
            probs = self.model(X_batch.to(self.device))
            _, preds = probs.max(dim=1, keepdim=True)

            loss = self.criterion(probs, masks_batch.squeeze(1).to(self.device))
            dice = dice_score(preds.float(), masks_batch.to(self.device).float()).mean()

            if self.writer is not None:
                self.writer.add_scalar('Cross entropy loss per iter', loss.item(), i)
                self.writer.add_scalar('Dice score', dice.item(), i)
                # Log an image only when at least one mask in the batch is non-empty.
                bool_mask, _ = masks_batch.reshape((*masks_batch.shape[:3],-1)).max(dim=-1)
                if torch.any(bool_mask.reshape((1,-1)).byte()==1):
                    self.writer.add_image('Input image vs mask vs pred', torch.cat((X_batch[bool_mask.byte()][0], masks_batch[bool_mask.byte()][0].float(), preds[bool_mask.byte()][0].cpu().float()), dim=1), i)

            # train on batch
            self.opt.zero_grad()
            loss.backward()
            self.opt.step()
            i += 1

            epoch_ce_loss.append(loss.detach().cpu().numpy())
            epoch_dice.append(dice.detach().cpu().numpy())

        total_ce_loss = np.mean(epoch_ce_loss)
        if self.writer is not None:
            self.writer.add_scalar('Train cross entropy loss per epoch', total_ce_loss, n_epoch)
            self.writer.add_scalar('Train dice per epoch', np.mean(epoch_dice), n_epoch)
        return total_ce_loss, i

In [None]:



class ExpRunner():
    """Earlier draft of the joint training loop (work in progress).

    train() still relies on names that are not defined in this class:
    self.n_epochs, scheduler, optimizer, dataloader_train, the module-level
    train_epoch() function, save_model, the models mmnet / faceid /
    ArcMargin, the *_ckpt_path and cur_logs_path paths, the total_*_arr
    accumulators, args, multigpu_mode and stop_flag.
    """

    def train_epoch(self):
        """Run one training epoch (placeholder: the body was never written)."""
        raise NotImplementedError("ExpRunner.train_epoch() is not implemented yet")

    def train(self):
        """Run self.n_epochs epochs, saving weights, loss logs and gradient norms."""
        for epoch in range(self.n_epochs):
            total = 0
            correct = 0

            train_loss_arr = []
            denoiser_loss_arr = []
            faceid_loss_arr = []

            total_loss = 0
            scheduler.step()
            # Calls the module-level train_epoch(), not the method above.
            # NOTE(review): denoiser_loss_arr / faceid_loss_arr are not passed
            # here (the total_* lists are), so nothing visible fills them
            # before np.mean() below -- confirm which lists were intended.
            total_loss = train_epoch(dataloader_train, optimizer, total, correct, total_loss, train_loss_arr, total_denoiser_loss_arr, total_faceid_loss_arr)

            save_model(mmnet, os.path.join(denoiser_ckpt_path, "weights_%d" % epoch), multigpu_mode)
            save_model(faceid, os.path.join(faceid_ckpt_path, "weights_%d" % epoch), multigpu_mode)
            save_model(ArcMargin, os.path.join(arcmargin_ckpt_path, "weights_%d" % epoch), multigpu_mode)

            # Append this epoch's mean to each running history and rewrite
            # the corresponding log file in full.
            total_train_loss_arr.append(np.mean(train_loss_arr))
            np.save(os.path.join(cur_logs_path,"train_loss_" + args.name), np.asarray(total_train_loss_arr))
            total_denoiser_loss_arr.append(np.mean(denoiser_loss_arr))
            np.save(os.path.join(cur_logs_path,"denoiser_loss_" + args.name), np.asarray(total_denoiser_loss_arr))
            total_faceid_loss_arr.append(np.mean(faceid_loss_arr))
            np.save(os.path.join(cur_logs_path,"faceid_loss_" + args.name), np.asarray(total_faceid_loss_arr))

            # L2 norm of the gradient of every faceid parameter that has one.
            params_with_grad = (p for p in faceid.parameters() if p.grad is not None)
            grads = [[idx, p.grad.data.norm(2).item()]
                     for idx, p in enumerate(params_with_grad)]
            np.save(os.path.join(cur_logs_path,"train_grads_" + args.name  + "_%d" % epoch), np.asarray(grads))
            print("\n")

            if stop_flag:
                break