# Deep RL 기반 Neural Network Pruning 구현
Neural Network를 pruning 하기 위한 Deep Reinforcement Learning 기반 알고리즘 구현을 목표로 한다.

1. Baseline_Net 학습
    - pruning을 할 Network이며 CIFAR10 dataset을 활용하여 ResNet50을 학습한다.
2. RL기반 Pruning 진행
3. Pruning된 model 재학습
4. model test (MACs 및 parameter 개수)

In [2]:
from RL_tool import Env, PolicyNetwork
from datetime import datetime
import torch.optim as optim
import torch

from tqdm import tqdm
import numpy as np

from torch.utils.tensorboard import SummaryWriter

2024-06-13 10:01:30.628825: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-13 10:01:30.859820: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Run tag (month/day_hour/minute) shared by the TensorBoard log directory and
# by the checkpoint names built from `dtime` in later cells.
_run_started_at = datetime.now()
dtime = _run_started_at.strftime("%m%d_%H%M")

# TensorBoard writer for this run's scalars.
writer = SummaryWriter(log_dir=f'runs/result_{dtime}')

### 1. Baseline_Net 학습
- pruning을 할 Network이며 CIFAR10 dataset을 활용하여 ResNet50을 학습한다.


In [None]:
from resnet_train import train

# Train the baseline network that will later be pruned.
# NOTE(review): an empty model_path presumably means "train from scratch", and
# the result is presumably written to ./checkpoints/ResNet50_{dtime}.pth (the
# path the pruning cell loads) -- confirm against resnet_train.train.
train(
    epochs=200,
    batch_size=128,
    lr=0.001,
    model_path='',
    name=f"ResNet50_{dtime}",
    is_pruned=False,
)

### 2. RL기반 Pruning 진행


In [12]:
# Step size for the policy optimizer.
learning_rate = 0.01

# The pruning environment is built from the baseline checkpoint of this run.
DNN_path = f"./checkpoints/ResNet50_{dtime}.pth"
env = Env(DNN_path)

# Policy network and the Adam optimizer that updates its parameters.
policy_network = PolicyNetwork()
optimizer = optim.Adam(params=policy_network.parameters(), lr=learning_rate)

In [None]:
def _discounted_returns(rewards, gamma):
    """Return the discounted return G_t = r_t + gamma * G_{t+1} for each step.

    The list is built back-to-front with append + reverse, which is linear in
    the episode length (repeated ``insert(0, ...)`` would be quadratic).
    """
    returns = []
    running = 0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


# Discrete action set 0.00, 0.01, ..., 0.99; the sampled entry is handed to
# env.step (presumably the pruning ratio for the current layer -- confirm in Env).
Action_list = np.arange(0, 1.00, 0.01)

gamma = 0.75
episodes = 100

# Thresholds (in percent) that an episode must beat before a checkpoint is saved.
is_best = {"sparsity": 50, "acc": 60}
total_rewards = []
total_sparsity = []
total_resnet_acc = []
for episode in range(episodes):
    log_probs = []
    rewards = []
    state = env.reset()

    # One step per entry of env.order_to_prune; the step index is passed to env.step.
    for i in tqdm(range(len(env.order_to_prune)), desc=f"{episode+1}/{episodes}"):
        state = torch.FloatTensor(state).unsqueeze(0)
        action_probs = policy_network(state)
        action = torch.multinomial(action_probs, num_samples=1).item()

        # Log-probability of the sampled action, kept for the policy-gradient loss.
        log_probs.append(torch.log(action_probs.squeeze(0)[action]))
        next_state, reward = env.step(i, Action_list[action])
        rewards.append(reward)

        state = next_state

    returns = torch.FloatTensor(_discounted_returns(rewards, gamma))
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)
    else:
        # std() of a single sample is NaN and would poison the policy weights;
        # only centre the return in that case.
        returns = returns - returns.mean()
    # After standardisation the mean is ~0, so this baseline is close to a
    # no-op; it is kept so the update matches the previous numerics.
    baseline = returns.mean()

    # REINFORCE loss: -sum_t log pi(a_t | s_t) * (G_t - baseline).
    policy_loss = -(torch.stack(log_probs) * (returns - baseline)).sum()

    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

    # state[0] / state[1] of the final state are logged as sparsity / accuracy.
    mean_reward = np.mean(rewards)
    sparsity_pct = state[0] * 100
    acc_pct = state[1] * 100
    sparsity_rounded = round(sparsity_pct, 2)
    acc_rounded = round(acc_pct, 2)

    total_rewards.append(mean_reward)
    total_sparsity.append(sparsity_pct)
    total_resnet_acc.append(acc_pct)

    writer.add_scalar('Total Reward', mean_reward, episode)
    writer.add_scalar('sparsity', sparsity_rounded, episode)
    writer.add_scalar('Resnet_acc', acc_rounded, episode)
    print(f"Episode {episode+1}, return: {mean_reward}, sparsity: {sparsity_rounded}, Resnet_acc: {acc_rounded}/{round(env.resnet.orig_test_acc*100,2)}")

    # Track both records, but write the checkpoint once: the file name is the
    # same whichever record was beaten, so a second save would only repeat I/O.
    improved = False
    if is_best['sparsity'] < sparsity_rounded:
        is_best['sparsity'] = sparsity_rounded
        improved = True
    if is_best['acc'] < acc_rounded:
        is_best['acc'] = acc_rounded
        improved = True
    if improved:
        env.resnet.save(f"{dtime}_episode{episode}_s{sparsity_rounded}_a{acc_rounded}")

writer.close()

torch.save(policy_network.state_dict(), f"./checkpoints/PolicyNet_{dtime}.pt")

### 3. Pruning된 model 재학습


In [None]:
from resnet_train import train

# Placeholders to fill in before running this cell.
name = "$name$"  # name under which the retrained model is saved
model_path = "$PATH$"  # location of the pruned model.pth to retrain

# Retrain the pruned model with the same hyper-parameters as the baseline run.
train(
    epochs=200,
    batch_size=128,
    lr=0.001,
    model_path=model_path,
    name=name,
    is_pruned=True,
)

### 4. model test
- MACs 및 parameter 개수

In [None]:
# 
import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn


# Placeholder: path of the pruned checkpoint (a state_dict) to evaluate.
model_path = "$ Pruned Model Path $"

# model setting
# Build the architecture only. Every parameter and buffer is overwritten by the
# strict load_state_dict below, so fetching the pretrained ImageNet weights
# (weights=ResNet50_Weights.DEFAULT) would be a wasted download.
pruned_net = resnet50(weights=None)
# Replace the classifier head with a 10-way layer to match the checkpoint.
pruned_net.fc = nn.Linear(pruned_net.fc.in_features, 10)
# Load onto the CPU first so the checkpoint opens regardless of which device
# it was saved from; the model is moved to the GPU afterwards.
pruned_net.load_state_dict(torch.load(model_path, map_location="cpu"))
pruned_net.cuda()

In [18]:
from ptflops import get_model_complexity_info

# Report MACs (via ptflops, for a 3x224x224 input) and the number of non-zero
# parameters of the pruned network.
with torch.cuda.device(0):
    macs, _ = get_model_complexity_info(
        pruned_net,
        (3, 224, 224),
        as_strings=True,
        backend='pytorch',
        print_per_layer_stat=False,
        verbose=False,
    )
    # Only non-zero entries are counted, so zeroed (pruned) weights are excluded.
    params = sum(torch.count_nonzero(p).cpu() for p in pruned_net.parameters())

    print('##### Pruned Network #####\n')
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8,}'.format('Number of parameters: ', params))


##### Pruned Network #####

Computational complexity:       1.43 GMac
Number of parameters:           7,665,530


In [17]:
# Reference model: ResNet50 with torchvision's default pretrained weights and a
# fresh 10-way classifier head.
# NOTE(review): no checkpoint is loaded here, so this is not the trained
# baseline from step 1 -- confirm that is intended for the comparison.
baseline_net = resnet50(weights=ResNet50_Weights.DEFAULT)
baseline_net.fc = nn.Linear(baseline_net.fc.in_features, 10)
baseline_net.cuda()

# Same measurement as for the pruned network: MACs for a 3x224x224 input and
# the count of non-zero parameters.
with torch.cuda.device(0):
    macs, _ = get_model_complexity_info(
        baseline_net,
        (3, 224, 224),
        as_strings=True,
        backend='pytorch',
        print_per_layer_stat=False,
        verbose=False,
    )
    params = sum(torch.count_nonzero(p).cpu() for p in baseline_net.parameters())

    print('##### Baseline Network #####\n')
    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8,}'.format('Number of parameters: ', params))


##### Baseline Network #####

Computational complexity:       4.13 GMac
Number of parameters:           23,528,522
