# In [None]:
def trainer(cfg, data_cfg, train_fold, val_fold, balance_params, *, device_ids=(0, 2)):
    """Build the validation DataLoader and a DataParallel-wrapped network.

    Despite its name, nothing is trained here. This helper:
    - builds the dataset objects and DataLoaders via ``dat``;
    - instantiates ``cfg.model`` on the selected device and wraps it in
      ``torch.nn.DataParallel``;
    - restores weights from ``cfg.weights_path`` when that file exists.

    Args:
        cfg: Run configuration. Reads ``model``, ``model_params`` (updated in
            place with the input length / decimated bins) and ``weights_path``.
            It is also forwarded to the ``dat`` helpers.
        data_cfg: Data configuration. Reads ``network_sample_length`` and
            ``_decimated_bins``.
        train_fold, val_fold, balance_params: Forwarded unchanged to the
            ``dat`` helpers.
        device_ids: Preferred CUDA device indices for ``DataParallel``.
            Indices that do not exist on this host are dropped, so the default
            ``(0, 2)`` also works on machines with fewer GPUs. Without CUDA the
            model stays on the CPU.

    Returns:
        tuple: ``(Network, val_loader)``.
    """
    # Get the dataset objects for training and validation
    train_data, val_data, aux_data = dat.get_dataset_objects(cfg, data_cfg, train_fold, val_fold)

    # Get the PyTorch DataLoader objects; only the validation loader is used here
    _, val_loader, _, _, _ = dat.get_dataloader(cfg, train_data, val_data, aux_data, balance_params)

    # Initialise chosen model architecture (Backend + Frontend)
    # Equivalent to the "Network" variable in manual mode without weights
    cfg.model_params.update(dict(_input_length=data_cfg.network_sample_length,
                                 _decimated_bins=data_cfg._decimated_bins))

    # Pick devices: keep only the requested GPUs that actually exist.
    # DataParallel requires the module to live on device_ids[0].
    if torch.cuda.is_available():
        available = range(torch.cuda.device_count())
        gpu_ids = [i for i in device_ids if i in available] or [0]
        device = torch.device(f"cuda:{gpu_ids[0]}")
    else:
        # Without CUDA, DataParallel just stores the module and ignores device_ids
        gpu_ids = None
        device = torch.device("cpu")

    Network = cfg.model(**cfg.model_params).to(device)
    Network = torch.nn.DataParallel(Network, device_ids=gpu_ids)

    weights_path = cfg.weights_path
    if weights_path is not None and os.path.exists(weights_path):
        weights = torch.load(weights_path, map_location=device)

        # Workaround for "module." key mismatches: keys without "module" get the
        # DataParallel "module." prefix; keys containing the nested
        # "features.module." pattern are rewritten to "module.features.".
        def _remap(key):
            if 'module' not in key:
                return 'module.' + key
            return key.replace('features.module.', 'module.features.')

        new_state_dict = {_remap(k): v for k, v in weights.items()}
        Network.load_state_dict(new_state_dict)
        del weights, new_state_dict
        gc.collect()
    else:
        # Validating an untrained network is almost never intended; make it visible
        print("WARNING: weights file {!r} not found; using randomly initialised weights".format(weights_path))

    ## Display
    print("Sample length for training and testing = {}".format(data_cfg.network_sample_length))

    # NOTE: the original built an optimizer, scheduler, loss function and
    # checkpoint here that were never used (and would fail if the config sets
    # them to None); they are intentionally omitted for this validation-only path.
    return Network, val_loader

# In [None]:
# ---- Prepare Data ----
# Run configuration plus the data creation/usage configuration
cfg = Validate_1epoch_D4_BEST_background_estimation
data_cfg = DefaultOTF

# Determine the input sample length (used in torch summary and to initialise norm layers)
dat.input_sample_length(data_cfg)
# Create the export directory for this run
dat.make_export_dir(cfg)
# Prepare input data; this should create/use a dataset and save a copy of the lookup table
dat.get_summary(cfg, data_cfg, cfg.export_dir)

# No folds or balancing parameters are passed for this validation run
Network, val_loader = trainer(
    cfg,
    data_cfg,
    train_fold=None,
    val_fold=None,
    balance_params=None,
)

# In [None]:
def validation_phase(Network, validation_samples):
    """Run one forward pass of ``Network`` on a validation batch.

    The pass runs without gradient tracking. When CUDA is available it also
    runs under CUDA autocast (mixed precision). On CPU-only hosts autocast is
    explicitly disabled rather than left to emit a "CUDA is not available"
    warning and disable itself.

    Args:
        Network: Callable model (e.g. an ``nn.Module`` or ``DataParallel``).
        validation_samples: Input batch; presumably already on the model's
            device (the visible caller moves it there).

    Returns:
        Whatever ``Network(validation_samples)`` returns.
    """
    use_amp = torch.cuda.is_available()
    with torch.amp.autocast('cuda', enabled=use_amp):
        # Gradient evaluation is not required for validation and testing;
        # never call .backward() on anything produced inside this scope.
        with torch.no_grad():
            validation_output = Network(validation_samples)

    return validation_output

# In [None]:
# Inference over the whole validation set. Each batch's raw network output is
# appended, as returned, to `save_voutput`.
# NOTE(review): outputs are kept on the device they were produced on; for long
# runs on GPU this list grows device memory — consider moving them to CPU.
Network.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
save_voutput = []

with torch.no_grad():
    pbar = tqdm(
        val_loader,
        bar_format='{l_bar}{bar:50}{r_bar}{bar:-10b}',
        position=0,
        leave=True,
    )
    # The loader yields 3-item batches; only the first item (the inputs) is used
    for nbatch, (validation_samples, _, _) in enumerate(pbar):
        # Move the input batch to the compute device
        validation_samples = validation_samples.to(device=device)
        # Forward pass only (no loss/accuracy); keep the raw output
        voutput = validation_phase(Network, validation_samples)
        save_voutput.append(voutput)