<a href="https://colab.research.google.com/github/jaehong31/mobilenetv2_tutorial_2022/blob/main/mobilenetv2_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Clone the Tutorial Code

In [None]:
# Clone the tutorial repository. Skip the clone if the directory already exists, so this cell is safe to re-run.
![ -d mobilenetv2_tutorial_2022 ] || git clone https://github.com/jaehong31/mobilenetv2_tutorial_2022.git

In [None]:
# Move into the cloned repository's `src/` directory. This is presumably where the
# project modules imported below (`arguments`, `models`, `datasets`, `logger`) live.
# NOTE: the path is relative, so re-running this cell after it has already
# changed into `src/` will fail.
%cd mobilenetv2_tutorial_2022/src

In [None]:
# Install `tree` non-interactively (-y) and quietly (-qq), then print the
# repository layout three levels deep.
!apt-get install -y -qq tree
!tree -L 3

# Install additional packages

In [None]:
# `%pip` installs into the environment of the running kernel; `!pip` may target a different interpreter.
# For reproducibility, consider pinning versions (e.g. `timm==X.Y.Z`).
%pip install -q timm
%pip install -q -U PyYAML

# Import packages

In [21]:
from arguments import get_args
from models import get_model
from datasets import get_dataset
import time, datetime
import torch.distributed as dist
from logger import create_logger
import numpy as np
import torch    

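# Load Arguments (see `arguments.py` and `configs/base_cifar10.yaml`)

The learning rates from the config (`base_lr`, `warmup_lr`, `min_lr`) are linearly rescaled by `batch_size / 512`.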

In [None]:
# Linear LR scaling rule: the learning rates in the config refer to a batch
# size of LR_REFERENCE_BATCH_SIZE and are rescaled to the actual batch size.
LR_REFERENCE_BATCH_SIZE = 512

args = get_args()
logger = create_logger(output_dir=args.log_dir, name=f"{args.tag}")
for lr_name in ('base_lr', 'warmup_lr', 'min_lr'):
    # float() because the YAML config may store these values as strings (e.g. "1e-3").
    scaled_lr = float(getattr(args.train, lr_name)) * args.train.batch_size / LR_REFERENCE_BATCH_SIZE
    setattr(args.train, lr_name, scaled_lr)

# Get Datasets and Models

MobileNet-v2 is implemented in: `models/backbones/MobileNetV2.py`

The base training model, including the optimizer and the learning-rate scheduler, is implemented in: `models/utils/model.py`




In [None]:
# Build the dataset selected by `args` and its train/test data loaders.
dataset = get_dataset(args)
train_loader, test_loader = dataset.get_data_loaders() 
# Number of training batches per epoch. It is passed to get_model (presumably
# for the LR schedule; TODO confirm in models/utils/model.py). The training
# loop also uses it to average the per-epoch losses.
len_train_loader = len(train_loader)
model = get_model(args, len_train_loader, logger)

# Define Evaluation Function

In [24]:
@torch.no_grad()
def evaluate(model, test_loader, logger, loss=torch.nn.CrossEntropyLoss()):
    """Evaluate ``model`` on ``test_loader`` and log accuracy and mean test loss.

    ``sparsify_weights`` is called first, so that once it is implemented the
    reported numbers reflect the sparsified weights.

    Args:
        model: the model under training. It must be callable on a batch of
            images and provide ``eval()``/``train()``.
        test_loader: iterable of ``(images, labels)`` batches.
        logger: receives one ``info`` line with the results.
        loss: criterion applied per batch. It is assumed to average over the
            batch (``reduction='mean'``, the ``CrossEntropyLoss`` default);
            TODO confirm if a custom criterion is passed.

    Reads the notebook-global ``args`` for ``args.device``.
    """
    sparsify_weights(args, model, logger)

    # Inference mode: BatchNorm uses (and does not update) its running
    # statistics, and dropout is disabled. The previous mode is restored afterwards.
    was_training = getattr(model, 'training', True)
    model.eval()

    corrects, totals, loss_sum = 0, 0, 0.0
    try:
        for images, labels in test_loader:
            labels = labels.to(args.device)
            logits = model(images.to(args.device))
            batch_size = logits.shape[0]
            # Weight each batch's mean loss by its size, so a smaller last
            # batch does not skew the dataset-level average.
            loss_sum += loss(logits, labels).item() * batch_size
            corrects += (logits.argmax(dim=1) == labels).sum().item()
            totals += batch_size
    finally:
        if was_training:
            model.train()

    if totals == 0:
        logger.info('Accuracy: N/A (empty test loader)')
        return

    logger.info(f'Accuracy: {(corrects/totals)*100:.2f} % ({corrects}/{totals}), Test Loss: {loss_sum/totals:.4f}')


def sparsify_weights(args, model, logger):
    """Hook for the pruning step that runs before every evaluation (exercise stub).

    To be implemented by the tutorial participant:
      1. If a pruning method is applied, update the weights to be sparse using
         the configured threshold (``args.hyperparameters.XXXX.thr``).
      2. Report the weight sparsity (%) via ``logger.info()``, excluding
         BatchNorm parameters and biases.

    As written it does nothing: all arguments are ignored and ``None`` is returned.
    """
    # TODO: implement steps 1 and 2 from the docstring.
    return None

# Training MobileNet-V2

In [None]:
# Optional hook some models expose; called once before training starts.
if hasattr(model, 'set_task'):
    logger.info('set task')
    model.set_task()

start_time = time.time()
for epoch in range(args.train.num_epochs):
    start = time.time()
    model.train()

    # Running sums of the per-batch training loss and penalty term. They are
    # averaged over the number of batches in the end-of-epoch log.
    tr_losses = 0.
    tr_p_losses = 0.
    # training phase
    for images, labels in train_loader:
        data_dict = model.observe(images, labels)
        tr_losses += data_dict['loss']
        tr_p_losses += data_dict['penalty']

    # Evaluate every `args.eval.interval_epochs` epochs.
    if (epoch + 1) % args.eval.interval_epochs == 0:
        evaluate(model, test_loader, logger)

    # Note: this duration also includes evaluation time in epochs that evaluated.
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
    # `lr` comes from the result dict of the epoch's last batch.
    logger.info('LR: {}, TR_LOSS: {}, TR_P_LOSS: {}'.format(
        np.round(data_dict['lr'], 6),
        np.round(tr_losses / len_train_loader, 4),
        np.round(tr_p_losses / len_train_loader, 4)))

# Optional counterpart hook, called once after training.
# NOTE(review): assumes end_task() takes no arguments, like set_task(). Confirm
# against the model implementation in models/.
if hasattr(model, 'end_task'):
    logger.info('end task')
    model.end_task()

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info(f'TOTAL TRAINING TIME {total_time_str}')

# **[TODO] Implement L1-Norm Regularization during Training**
1. Implement the penalty term in `models/l1norm.py`.
2. Set `model.method` to ***L1NORM*** in `configs/base_cifar10.yaml`.
3. Implement the ***sparsify_weights*** function.
4. Tune the hyperparameters (including the pruning threshold `thr`).

# **[TODO] Implement Grouped-Norm Regularization (a.k.a. Structured Sparsity) during Training**
Follow the same steps as for the L1 norm above.