In [17]:
import numpy as np
import os
import time
import random
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

from load_data import ECGDataset, ECGCollate, SmartBatchSampler, load_dataset, load_ecg
from resnet1d import ResNet1D
from mask import Mask

%load_ext autoreload
%autoreload 2

os.environ['KMP_DUPLICATE_LIB_OK']='True' # To prevent the kernel from dying.

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def create_tqdm_bar(iterable, desc):
    return tqdm(enumerate(iterable),total=len(iterable), ncols=150, desc=desc)


def train_model(model, train_loader, val_loader, loss_func, tb_logger, epochs=10, name="default"):
    """
    Train the classifier for a number of epochs.
    """
    loss_cutoff = len(train_loader) // 10
    optimizer = torch.optim.Adam(model.parameters(), 0.001)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                            mode='min', 
                                                            factor=0.1, # like in Hannun et al.
                                                            patience=2 # 2 in Hannun et al. "two consecutive epochs"
                                                            )
    for epoch in range(epochs):

        # Training stage, where we want to update the parameters.
        model.train()  # Set the model to training mode

        training_loss = []
        validation_loss = []

        # Create a progress bar for the training loop.
        training_loop = create_tqdm_bar(train_loader, desc=f'Training Epoch [{epoch + 1}/{epochs}]')
        for train_iteration, batch in training_loop:
            optimizer.zero_grad() # Reset the gradients - VERY important! Otherwise they accumulate.
            ecgs, labels = batch # Get the images and labels from the batch, in the fashion we defined in the dataset and dataloader.
            ecgs, labels = ecgs.to(device), labels.to(device) # Send the data to the device (GPU or CPU) - it has to be the same device as the model.


            pred = model(ecgs) # Stage 1: Forward().
            loss = loss_func(pred, labels) # Compute the loss over the predictions and the ground truth.
            loss.backward()  # Stage 2: Backward().
            optimizer.step() # Stage 3: Update the parameters.
            # scheduler.step() # Update the learning rate.


            training_loss.append(loss.item())
            training_loss = training_loss[-loss_cutoff:]

            # Update the progress bar.
            training_loop.set_postfix(curr_train_loss = "{:.8f}".format(np.mean(training_loss)),
                                      lr = "{:.8f}".format(optimizer.param_groups[0]['lr'])
            )

            # Update the tensorboard logger.
            #tb_logger.add_scalar(f'classifier_{name}/train_loss', loss.item(), epoch * len(train_loader) + train_iteration)

        # Validation stage, where we don't want to update the parameters. Pay attention to the classifier.eval() line
        # and "with torch.no_grad()" wrapper.
        model.eval()
        val_loop = create_tqdm_bar(val_loader, desc=f'Validation Epoch [{epoch + 1}/{epochs}]')

        with torch.no_grad():
            for val_iteration, batch in val_loop:
                ecgs, labels = batch
                ecgs, labels = ecgs.to(device), labels.to(device)

                pred = model(ecgs)
                loss = loss_func(pred, labels)
                validation_loss.append(loss.item())
                # Update the progress bar.
                val_loop.set_postfix(val_loss = "{:.8f}".format(np.mean(validation_loss)))

                # Update the tensorboard logger.
                #tb_logger.add_scalar(f'classifier_{name}/val_loss', loss.item(), epoch * len(val_loader) + val_iteration)
        
        scheduler.step(np.mean(validation_loss))
    
    return model


In [3]:
def prune(pruning_fraction: float = 0.2, pruning_layers_to_ignore: str = None, trained_model = None, current_mask: Mask = None) : 
    """
    A one iteration of pruning : returns the new updated mask after pruning.

    trained_model : the original fully trained model.
    pruning_fraction = The fraction of additional weights to prune from the network.
    layers_to_ignore = A comma-separated list of addititonal tensors that should not be pruned.
    """
    current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

    # Determine the number of weights that need to be pruned.
    number_of_remaining_weights = np.sum([np.sum(v) for v in current_mask.values()])
    number_of_weights_to_prune = np.ceil(pruning_fraction * number_of_remaining_weights).astype(int)

    # Determine which layers can be pruned.
    prunable_tensors = set(trained_model.prunable_layer_names)
    if pruning_layers_to_ignore:
        prunable_tensors -= set(pruning_layers_to_ignore.split(','))
    print("prunable_tensors : \n", prunable_tensors)
    # Get the model weights.
    weights = {k: v.clone().cpu().detach().numpy()
                for k, v in trained_model.state_dict().items()
                if k in prunable_tensors}

    # Create a vector of all the unpruned weights in the model.
    weight_vector = np.concatenate([v[current_mask[k] == 1] for k, v in weights.items()])
    threshold = np.sort(np.abs(weight_vector))[number_of_weights_to_prune]

    new_mask = Mask({k: np.where(np.abs(v) > threshold, current_mask[k], np.zeros_like(v))
                        for k, v in weights.items()})
    for k in current_mask:
        if k not in new_mask: # if this weight was already pruned add it to the new mask
            new_mask[k] = current_mask[k]

    return new_mask



In [27]:
class PrunedModel(nn.Module): # Remplacer Model par ResNet1D 
    @staticmethod
    def to_mask_name(name):
        return 'mask_' + name.replace('.', '___')

    def __init__(self, model: ResNet1D, mask: Mask):
        if isinstance(model, PrunedModel): raise ValueError('Cannot nest pruned models.')
        super(PrunedModel, self).__init__()
        self.model = model

        for k in self.model.prunable_layer_names:
            if k not in mask: raise ValueError('Missing mask value {}.'.format(k))
            if not np.array_equal(mask[k].shape, np.array(self.model.state_dict()[k].shape)):
                raise ValueError('Incorrect mask shape {} for tensor {}.'.format(mask[k].shape, k))

        for k in mask:
            if k not in self.model.prunable_layer_names:
                raise ValueError('Key {} found in mask but is not a valid model tensor.'.format(k))

        # for k, v in mask.items(): self.register_buffer(PrunedModel.to_mask_name(k), v.float())
        # self._apply_mask()
        device = next(model.parameters()).device 

        for k, v in mask.items(): 
            # On envoie le masque sur le même device que le modèle AVANT de l'enregistrer
            self.register_buffer(PrunedModel.to_mask_name(k), v.float().to(device))
            
        self._apply_mask()

    def _apply_mask(self):
        for name, param in self.model.named_parameters():
            if hasattr(self, PrunedModel.to_mask_name(name)):
                param.data *= getattr(self, PrunedModel.to_mask_name(name))

    def forward(self, x):
        self._apply_mask()
        return self.model.forward(x)

    @property
    def prunable_layer_names(self):
        return self.model.prunable_layer_names

    # @property
    # def output_layer_names(self):
    #     return self.model.output_layer_names

    # @property
    # def loss_criterion(self):
    #     return self.model.loss_criterion

    # def save(self, save_location, save_step):
    #     self.model.save(save_location, save_step)

    # @staticmethod
    # def default_hparams(): raise NotImplementedError()
    # @staticmethod
    # def is_valid_model_name(model_name): raise NotImplementedError()
    # @staticmethod
    # def get_model_from_name(model_name, outputs, initializer): raise NotImplementedError()

In [29]:
pruning_params = {
                  "p_init" : 30,
                  "target_reduction_factor" : 120, 
                  "alpha" : 1.1,
                  "pruning_layers_to_ignore" : None
                  }


def run_lth_ecg(pruning_params, network, train_loader, val_loader, loss_func) : 
    # Randomly initialize the given DL network D. (quelle initialisation ? Hannun et al. -> "He normal")
    pruning_fraction = pruning_params["p_init"]/100
    current_mask = Mask.ones_like(network).numpy()
    initial_weights_number = np.sum([np.sum(v) for v in current_mask.values()]) # eta 
    print(f"eta = {initial_weights_number:.2e}")
    # current_model = network
    initial_untrained_model = copy.deepcopy(network)
    
    remaining_weights_number = initial_weights_number

    D = copy.deepcopy(network) #current_network
    step = 0
    while (initial_weights_number/remaining_weights_number) < pruning_params["target_reduction_factor"]:
        print("="*60,f"STEP : {step}")
        print(f"remaining_weights_number = {remaining_weights_number:.2e}")
        print("current reduction factor = ", np.round(initial_weights_number/remaining_weights_number, 2))
        print("="*60, "\n")
        # Train the DL network with the given data x.
        D = train_model(D, train_loader, val_loader, loss_func, name = "lth_ecg", epochs=1,tb_logger=None)

        # Prune p_init% of weights which are of least magnitude
        new_mask = prune(pruning_fraction, pruning_params["pruning_layers_to_ignore"], D)

        #D_sparse = PrunedModel(D,new_mask) 

        pruning_fraction = pruning_fraction**(1/pruning_params['alpha']) # alpha = 1.1

        # reset unpruned weights to their initial random values and D = D_sparse
        D = PrunedModel(model=copy.deepcopy(initial_untrained_model), mask=new_mask).to(device)
        
        # remaining_weights_number = # On utilise la somme native de Python, et .item() pour extraire la valeur du tenseur
        remaining_weights_number = sum(v.sum().item() for v in new_mask.values())

        current_mask = new_mask
        step+=1

    return current_mask



In [10]:
print("Loading training set...")
train = load_dataset("train.json",256)
train_ecgs, train_labels = train

Loading training set...


100%|██████████| 7676/7676 [00:01<00:00, 4233.13it/s]


In [8]:
# reduciton of size to improve training time
train_ecgs, train_labels = train_ecgs[:1000], train_labels[:1000]

In [9]:
print("Loading dev set...")
val_ecgs,val_labels = load_dataset("dev.json",256)

Loading dev set...


100%|██████████| 852/852 [00:00<00:00, 4257.38it/s]


In [25]:
# reduciton of size to improve training time
val_ecgs, val_labels = val_ecgs[:100], val_labels[:100]

In [11]:
train_dataset = ECGDataset(train_ecgs, train_labels)
val_dataset = ECGDataset(val_ecgs, val_labels)

MEAN :  7.4661856  STD :  236.10312
self.classes :  ['A', 'N', 'O', '~']
self.class_to_int :  {'A': 0, 'N': 1, 'O': 2, '~': 3}
MEAN :  8.029898  STD :  242.35907
self.classes :  ['A', 'N', 'O', '~']
self.class_to_int :  {'A': 0, 'N': 1, 'O': 2, '~': 3}


In [12]:
# Instanciation du Sampler intelligent
train_batch_sampler = SmartBatchSampler(train_dataset, 32)
val_batch_sampler = SmartBatchSampler(val_dataset, 32)

train_collate_fn = ECGCollate(
    pad_val_x=train_dataset.pad_value_x_normalized,
    num_classes=train_dataset.num_classes
)

val_collate_fn = ECGCollate(
    pad_val_x=val_dataset.pad_value_x_normalized,
    num_classes=val_dataset.num_classes
)

train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_batch_sampler, 
    collate_fn=train_collate_fn,
    num_workers=4
)

val_loader = DataLoader(
    val_dataset,
    batch_sampler=val_batch_sampler, 
    collate_fn=val_collate_fn,
    num_workers=4
)

Tri du dataset par longueur pour minimiser le padding...
Tri du dataset par longueur pour minimiser le padding...


In [16]:
epochs = 20

loss_func = nn.CrossEntropyLoss() # The loss function we use for classification.

# make model
device_str = "cuda"
device = torch.device(device_str if torch.cuda.is_available() else "cpu")
print(f"Running on {device}")

kernel_size = 16 # 16 in Hannun et al.
stride = 2
n_block = 16 # 16 in Hannun et al.
downsample_gap = 2 # 2 in Hannun et al.
increasefilter_gap = 4 # 4 in Hannun et al.

model = ResNet1D(
    in_channels=1, 
    base_filters=32, # 32 in Hannun et al.
    kernel_size=kernel_size, 
    stride=stride, 
    groups=1, # like a classical ResNet
    n_block=n_block, 
    n_classes=4, 
    downsample_gap=downsample_gap, 
    increasefilter_gap=increasefilter_gap, 
    use_bn=True,
    use_do=True,
    verbose = False
    ).to(device)

Running on cuda


In [30]:
final_mask = run_lth_ecg(pruning_params, model, train_loader, val_loader, loss_func) 

eta = 1.05e+07
remaining_weights_number = 1.05e+07
current reduction factor =  1.0



Training Epoch [1/1]: 100%|██████████████████████████████████████████████| 240/240 [00:47<00:00,  5.01it/s, curr_train_loss=0.72173179, lr=0.00100000]
Validation Epoch [1/1]: 100%|████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 15.54it/s, val_loss=0.69594120]


prunable_tensors : 
 {'basicblock_list.3.conv1.conv.weight', 'basicblock_list.1.conv1.conv.weight', 'basicblock_list.2.conv1.conv.weight', 'basicblock_list.3.conv2.conv.weight', 'basicblock_list.12.conv1.conv.weight', 'basicblock_list.10.conv2.conv.weight', 'basicblock_list.15.conv1.conv.weight', 'basicblock_list.5.conv1.conv.weight', 'basicblock_list.5.conv2.conv.weight', 'basicblock_list.13.conv1.conv.weight', 'basicblock_list.13.conv2.conv.weight', 'basicblock_list.15.conv2.conv.weight', 'first_block_conv.conv.weight', 'basicblock_list.7.conv2.conv.weight', 'basicblock_list.1.conv2.conv.weight', 'basicblock_list.11.conv1.conv.weight', 'basicblock_list.4.conv2.conv.weight', 'basicblock_list.0.conv1.conv.weight', 'basicblock_list.14.conv1.conv.weight', 'basicblock_list.4.conv1.conv.weight', 'basicblock_list.7.conv1.conv.weight', 'basicblock_list.8.conv2.conv.weight', 'basicblock_list.9.conv2.conv.weight', 'basicblock_list.0.conv2.conv.weight', 'basicblock_list.10.conv1.conv.weight', '

Training Epoch [1/1]: 100%|██████████████████████████████████████████████| 240/240 [00:49<00:00,  4.85it/s, curr_train_loss=0.62704772, lr=0.00100000]
Validation Epoch [1/1]: 100%|████████████████████████████████████████████████████████████████████| 27/27 [00:01<00:00, 15.01it/s, val_loss=0.65502224]


KeyError: 'first_block_conv.conv.weight'