<a href="https://colab.research.google.com/github/nightted/ML-LeeHongYi-HW/blob/master/HW_7_Network_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!gdown --id '19CzXudqN58R3D-1G8KeFWk8UDQwlb8is' --output food-11.zip
!unzip food-11.zip

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image 
import matplotlib.pyplot as plt

In [4]:
# Define Student-Net : MobileNet
class StudentNet(nn.Module):

  def __init__(self, b_size = 16 , width_multi = 1):

    super(StudentNet,self).__init__()
    # Define the sub_blok size
    block_multiplier = [1, 2, 4, 8, 16, 16, 16, 16]
    block_size = [b_size*size for size in block_multiplier]

    #Unknown method 照做XD
    for i in range(3,7):
      block_size[i] = int(block_size[i]*width_multi) # width_multi as pourning ratio 

    self.cnn = nn.Sequential(
        
        # 第一個 layer 先不做 DW&PW
        nn.Sequential(
            nn.Conv2d(3,block_size[0],3,1,1),
            nn.BatchNorm2d(block_size[0]),
            nn.ReLU6(), # Relu6 限制最大輸出為6,最小輸出為0, mobile-net的Activation皆全部使用Relu6
            nn.MaxPool2d(2,2,0)
        )
        ,

        nn.Sequential(
            nn.Conv2d(block_size[0],block_size[0],3,1,1,groups=block_size[0]), # Depth-Wise Convlution
            nn.BatchNorm2d(block_size[0]),
            nn.ReLU6(),    # Relu6 限制最大輸出為6,最小輸出為0, mobile-net的Activation皆全部使用Relu6
            nn.Conv2d(block_size[0],block_size[1],1),  # Point-Wise Convlution
            #這邊不用再過 Relu, 經驗上 PW 完再過 Relu 效果會變差
            nn.MaxPool2d(2,2,0) # down sampling 
        )
        ,

        nn.Sequential(
            nn.Conv2d(block_size[1],block_size[1],3,1,1,groups=block_size[1]),
            nn.BatchNorm2d(block_size[1]),
            nn.ReLU6(),    
            nn.Conv2d(block_size[1],block_size[2],1),           
            nn.MaxPool2d(2,2,0)
        )
        ,

        nn.Sequential(
            nn.Conv2d(block_size[2],block_size[2],3,1,1,groups=block_size[2]), 
            nn.BatchNorm2d(block_size[2]),
            nn.ReLU6(),    
            nn.Conv2d(block_size[2],block_size[3],1), 
            nn.MaxPool2d(2,2,0) 
        )
        ,
        #這邊就不再做 down-sampling 
        nn.Sequential(
            nn.Conv2d(block_size[3],block_size[3],3,1,1,groups=block_size[3]), 
            nn.BatchNorm2d(block_size[3]),
            nn.ReLU6(),    
            nn.Conv2d(block_size[3],block_size[4],1) 
        )
        , 

        nn.Sequential(
            nn.Conv2d(block_size[4],block_size[4],3,1,1,groups=block_size[4]), 
            nn.BatchNorm2d(block_size[4]),
            nn.ReLU6(),    
            nn.Conv2d(block_size[4],block_size[5],1) 
        )
        , 

        nn.Sequential(
            nn.Conv2d(block_size[5],block_size[5],3,1,1,groups=block_size[5]), 
            nn.BatchNorm2d(block_size[5]),
            nn.ReLU6(),    
            nn.Conv2d(block_size[5],block_size[6],1) 
        )
        , 

        nn.Sequential(
            nn.Conv2d(block_size[6],block_size[6],3,1,1,groups=block_size[6]), 
            nn.BatchNorm2d(block_size[6]),
            nn.ReLU6(),    
            nn.Conv2d(block_size[6],block_size[7],1) 
        )
        , 

        # 這邊我們採用Global Average Pooling。
        # 如果輸入圖片大小不一樣的話，就會因為Global Average Pooling壓成一樣的形狀，這樣子接下來做FC就不會對不起來。
        nn.AdaptiveAvgPool2d((1,1)),

    )

    self.fc = nn.Sequential(
        #directly project to dim = 11
        nn.Linear(block_size[7],11)
    )
  
  def forward(self,x):

    out = self.cnn(x)
    out = out.view(out.size()[0],-1)
    out = self.fc(out)

    return out


In [5]:
# Pre-process the data
import os 
class FoodData(Dataset):

  def __init__(self,training_mode = True):

    img_path = './food-11'
    if training_mode:
      path = os.path.join(img_path,'training')
      #data augumentation
      self.transform = transforms.Compose([
                          transforms.RandomCrop(256, pad_if_needed=True, padding_mode='symmetric'),
                          transforms.RandomHorizontalFlip(),
                          transforms.RandomRotation(15),
                          transforms.ToTensor(),
                                          ]) 
    else :
      path = os.path.join(img_path,'validation')
      self.transform = transforms.Compose([
                          transforms.CenterCrop(256),
                          transforms.ToTensor(),
                                          ])

    self.x_path = [ os.path.join(path,paths) for paths in sorted(os.listdir(path))]
    self.y_label = [ paths.split('_')[0] for paths in sorted(os.listdir(path))]

  def __getitem__(self,index):
    
    Data_X = Image.open(self.x_path[index])
    Data_X = self.transform(Data_X)
    Data_Y = int(self.y_label[index]) # 槓!! 每次都忘記轉 INT @@ .....

    return Data_X ,Data_Y
  
  def __len__(self):

    return len(self.x_path)

def dataloader(mode , batch_size = 32):

  training_mode = True if mode == 'training' else False
  dataset = FoodData(training_mode)
  dataloader = torch.utils.data.DataLoader(dataset ,batch_size=batch_size ,shuffle=training_mode)

  return dataloader


In [6]:
#Download pre-train Net
!gdown --id '12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL' --output student_custom_small.bin

Downloading...
From: https://drive.google.com/uc?id=12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL
To: /content/student_custom_small.bin
  0% 0.00/1.05M [00:00<?, ?B/s]100% 1.05M/1.05M [00:00<00:00, 69.9MB/s]


In [7]:
def NetPruning(student_net,new_student_net):
  params = student_net.state_dict()
  new_params = new_student_net.state_dict()

  #Grab the idx of neuron with top gamma value
  selected_neuron_index = []
  for i in range(8):
    #different layer i 
    gamma = params["cnn.{}.1.weight".format(i)]
    gamma_new = new_params["cnn.{}.1.weight".format(i)]
    len_net = len(gamma) #extract the numbers of neuron in old net
    len_net_new = len(gamma_new )#extract the numbers of neuron in new net

    priority = torch.argsort(gamma,descending=True) #Sort the neuron idx by their gamma value 
    selected_neuron_index.append(priority[:len_net_new])
    #print("layer :{}".format(i),",original len:{}".format(len_net),",after shrink :{}".format(len_net_new),
          #",shrink ratio:{}%".format(int(len_net_new/len_net*100)),",select neuron:",priority) #Grab the top #(new_net neuron) gamma value idx of neuron  

  #print(len_net,len_net_new,len_net_new/len_net*100)
  #print("select neuron",selected_neuron_index)

  NOWIN_LAYER = 1 #set the current layer mark(note that it's start from the "second" layer.)
  for (subnet,para) , (subnet_new,para_new) in zip(params.items(),new_params.items()):

    #Only deal with the cnn layers
    if subnet.startswith('cnn') and para.size() != torch.Size([]) and NOWIN_LAYER != len(selected_neuron_index):
      # if is in PW layer(when encounter 3.weight), special handle for NOWIN_LAYER transform from NOWIN_LAYER -> NOWIN_LAYER+1
      if subnet.endswith("3.weight"):
        # if is in "LAST" PW layer , only need to delete the COLUMN VECTOR(the deleted neuron in NOWIN_LAYER) 
        if len(selected_neuron_index) == NOWIN_LAYER+1:
          new_params[subnet] = para[ : , selected_neuron_index[NOWIN_LAYER] ]
        # if is NOT in "LAST" PW layer, delete the COLUMN VECTOR(the deleted neuron in NOWIN_LAYER) and ROW VECTOR(the deleted neuron in NOWIN_LAYER+1)
        else:
          new_params[subnet] = para[selected_neuron_index[NOWIN_LAYER+1] ][ : , selected_neuron_index[NOWIN_LAYER] ]
        # After done 3.weight modification ,do layer transform from NOWIN_LAYER -> NOWIN_LAYER+1
        NOWIN_LAYER += 1  

      else:
      # In cnn.{layer}.{0,1,2} or "cnn.{layer}.{3}.bias" , only need to grab the select neuron .(Or in matrix operation , only need to delete the "ROW VECTOR". )
        new_params[subnet] = para[ selected_neuron_index[NOWIN_LAYER] ]
    
    else:
    # In F/C layer only need to copy the whole old_net to new_net 
      new_params[subnet] = para


  new_student_net.load_state_dict(new_params) #Remember to load the dict!!!!!
  return new_student_net

Concept of CNN multiple and Pruning :
![alt text](https://drive.google.com/uc?export=view&id=1fStw0OsajARGzjjOKlv7C1DBSsEDkDWr)

In [8]:
# Prepare the dataloader 
train_dataloader = dataloader('training', batch_size=32)
valid_dataloader = dataloader('validation', batch_size=32)

In [9]:
# Prepare the Net
student_net = StudentNet().cuda()
student_net.load_state_dict(torch.load("/content/student_custom_small.bin"))

<All keys matched successfully>

In [10]:
optimizer = optim.AdamW(student_net.parameters(),lr = 0.001)
criterion = F.cross_entropy

In [11]:
def run_epoch(dataloader, update=True, alpha=0.5):
    total_num, total_hit, total_loss = 0, 0, 0
    for now_step, batch_data in enumerate(dataloader):
        # 清空 optimizer
        optimizer.zero_grad()
        # 處理 input
        inputs, labels = batch_data
        inputs = inputs.cuda()
        labels = labels.cuda()
  
        logits = student_net(inputs)
        loss = criterion(logits, labels)
        if update:
            loss.backward()
            optimizer.step()

        total_hit += torch.sum(torch.argmax(logits, dim=1) == labels).item()
        total_num += len(inputs)
        total_loss += loss.item() * len(inputs)

    return total_loss / total_num, total_hit / total_num

now_width_mult = 1
for i in range(5):
    now_width_mult *= 0.95
    new_net = StudentNet(width_multi=now_width_mult).cuda()
    params = student_net.state_dict()
    student_net = NetPruning(student_net, new_net)
    now_best_acc = 0
    for epoch in range(5):
        student_net.train()
        train_loss, train_acc = run_epoch(train_dataloader, update=True)
        student_net.eval()
        valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)
        # 在每個width_mult的情況下，存下最好的model。
        '''
        if valid_acc > now_best_acc:
            now_best_acc = valid_acc
            torch.save(net.state_dict(), f'custom_small_rate_{now_width_mult}.bin')
        '''
        print('rate {:6.4f} epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(now_width_mult, 
            epoch, train_loss, train_acc, valid_loss, valid_acc))


rate 0.9500 epoch   0: train loss: 0.4695, acc 0.8675 valid loss: 1.1268, acc 0.8012
rate 0.9500 epoch   1: train loss: 0.4903, acc 0.8632 valid loss: 1.1077, acc 0.8012
rate 0.9500 epoch   2: train loss: 0.4795, acc 0.8667 valid loss: 1.1079, acc 0.7983
rate 0.9500 epoch   3: train loss: 0.4833, acc 0.8683 valid loss: 1.1157, acc 0.7927
rate 0.9500 epoch   4: train loss: 0.4809, acc 0.8695 valid loss: 1.1585, acc 0.7968
rate 0.9025 epoch   0: train loss: 0.5926, acc 0.8417 valid loss: 1.1729, acc 0.7808
rate 0.9025 epoch   1: train loss: 0.5810, acc 0.8399 valid loss: 1.1631, acc 0.7854
rate 0.9025 epoch   2: train loss: 0.5928, acc 0.8390 valid loss: 1.1649, acc 0.7834
rate 0.9025 epoch   3: train loss: 0.5982, acc 0.8400 valid loss: 1.2053, acc 0.7813
rate 0.9025 epoch   4: train loss: 0.5967, acc 0.8419 valid loss: 1.1885, acc 0.7822
rate 0.8574 epoch   0: train loss: 0.7265, acc 0.8053 valid loss: 1.2153, acc 0.7653
rate 0.8574 epoch   1: train loss: 0.6873, acc 0.8080 valid loss:

In [12]:
'''
#Grab the idx of neuron with top gamma value
selected_neuron_index = []
for i in range(8):
  #different layer i 
  gamma = params["cnn.{}.1.weight".format(i)]
  gamma_new = new_params["cnn.{}.1.weight".format(i)]
  len_net = len(gamma) #extract the numbers of neuron in old net
  len_net_new = len(gamma_new )#extract the numbers of neuron in new net

  priority = torch.argsort(gamma) #Sort the neuron idx by their gamma value 
  selected_neuron_index.append(priority[:len_net_new])
  #print("layer :{}".format(i),",original len:{}".format(len_net),",after shrink :{}".format(len_net_new),
        #",shrink ratio:{}%".format(int(len_net_new/len_net*100)),",select neuron:",priority) #Grab the top #(new_net neuron) gamma value idx of neuron  

#print(len_net,len_net_new,len_net_new/len_net*100)
#print("select neuron",selected_neuron_index)
'''

'\n#Grab the idx of neuron with top gamma value\nselected_neuron_index = []\nfor i in range(8):\n  #different layer i \n  gamma = params["cnn.{}.1.weight".format(i)]\n  gamma_new = new_params["cnn.{}.1.weight".format(i)]\n  len_net = len(gamma) #extract the numbers of neuron in old net\n  len_net_new = len(gamma_new )#extract the numbers of neuron in new net\n\n  priority = torch.argsort(gamma) #Sort the neuron idx by their gamma value \n  selected_neuron_index.append(priority[:len_net_new])\n  #print("layer :{}".format(i),",original len:{}".format(len_net),",after shrink :{}".format(len_net_new),\n        #",shrink ratio:{}%".format(int(len_net_new/len_net*100)),",select neuron:",priority) #Grab the top #(new_net neuron) gamma value idx of neuron  \n\n#print(len_net,len_net_new,len_net_new/len_net*100)\n#print("select neuron",selected_neuron_index)\n'

In [13]:
'''
NOWIN_LAYER = 1 #set the current layer mark(note that it's start from the "second" layer.)
for (subnet,para) , (subnet_new,para_new) in zip(params.items(),new_params.items()):

  #Only deal with the cnn layers
  if subnet.startswith('cnn') and para.size() != torch.Size([]) and NOWIN_LAYER != len(selected_neuron_index):
    # if is in PW layer(when encounter 3.weight), special handle for NOWIN_LAYER transform from NOWIN_LAYER -> NOWIN_LAYER+1
    if subnet.endswith("3.weight"):
      # if is in "LAST" PW layer , only need to delete the COLUMN VECTOR(the deleted neuron in NOWIN_LAYER) 
      if len(selected_neuron_index) == NOWIN_LAYER+1:
        new_params[subnet] = para[ : , selected_neuron_index[NOWIN_LAYER] ]
      # if is NOT in "LAST" PW layer, delete the COLUMN VECTOR(the deleted neuron in NOWIN_LAYER) and ROW VECTOR(the deleted neuron in NOWIN_LAYER+1)
      else:
        new_params[subnet] = para[selected_neuron_index[NOWIN_LAYER+1] ][ : , selected_neuron_index[NOWIN_LAYER] ]
      # After done 3.weight modification ,do layer transform from NOWIN_LAYER -> NOWIN_LAYER+1
      NOWIN_LAYER += 1  

    else:
    # In cnn.{layer}.{0,1,2} or "cnn.{layer}.{3}.bias" , only need to grab the select neuron .(Or in matrix operation , only need to delete the "ROW VECTOR". )
      new_params[subnet] = para[ selected_neuron_index[NOWIN_LAYER] ]
  
  else:
  # In F/C layer only need to copy the whole old_net to new_net 
    new_params[subnet] = para
'''

'\nNOWIN_LAYER = 1 #set the current layer mark(note that it\'s start from the "second" layer.)\nfor (subnet,para) , (subnet_new,para_new) in zip(params.items(),new_params.items()):\n\n  #Only deal with the cnn layers\n  if subnet.startswith(\'cnn\') and para.size() != torch.Size([]) and NOWIN_LAYER != len(selected_neuron_index):\n    # if is in PW layer(when encounter 3.weight), special handle for NOWIN_LAYER transform from NOWIN_LAYER -> NOWIN_LAYER+1\n    if subnet.endswith("3.weight"):\n      # if is in "LAST" PW layer , only need to delete the COLUMN VECTOR(the deleted neuron in NOWIN_LAYER) \n      if len(selected_neuron_index) == NOWIN_LAYER+1:\n        new_params[subnet] = para[ : , selected_neuron_index[NOWIN_LAYER] ]\n      # if is NOT in "LAST" PW layer, delete the COLUMN VECTOR(the deleted neuron in NOWIN_LAYER) and ROW VECTOR(the deleted neuron in NOWIN_LAYER+1)\n      else:\n        new_params[subnet] = para[selected_neuron_index[NOWIN_LAYER+1] ][ : , selected_neuron_index

In [14]:
#new_params['fc.0.weight'].shape

In [15]:
#params['fc.0.weight'].shape

In [16]:
#Check the Net paramter priority structure
#for (subnet,para) , (subnet_new,para_new) in zip(params.items(),new_params.items()):
#  print(subnet)