In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import os,time
import torch.optim as optim
import torchvision.models as models

In [2]:
fn_log = open('log.log','w')

In [3]:
def printLogFile(str,end_str='\n'):
    print(str,end=end_str)
    print(str,end=end_str,file=fn_log)
    fn_log.flush()

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
class StudentNet(nn.Module):
    '''
    在这个Net里面 我们会使用 Depthwise & Pointwise Convolution Layer 来叠 model
    你会发现,将原本的Convolution Layer 换成 Dw & Pw 后, Acurracy 通常不会降低很多
    
    另外,取名为 StudentNet 是因为这个 Model 后续要做 Knowledge Distillation
    '''
    
    def __init__(self,base=16,width_mult=1):
        '''
        Args:
            base: 这个model一开始的ch数量 每过一层都会*2 直到base*16为止
            width_mult: 为了之后的 Network Pruning使用,在base*8 chs的Layer上会 * width_mult代表剪枝后的ch数量
        '''
        super(StudentNet,self).__init__()
        multiplier = [1,2,4,8,16,16,16,16]
        
        # bandwidth: 每一个layer所使用的channel数量
        bandwidth = [base * m for m in multiplier]
        
        # 我们只是Pruning第三层以后的layer  ??? why not pruning layer 8
        for i in range(3,7):
            bandwidth[i] = int(bandwidth[i]*width_mult)
            
        self.cnn = nn.Sequential(
            # 第一层我们通常不做拆解Convolution Layer
            nn.Sequential(
                nn.Conv2d(3,bandwidth[0],3,1,1),
                nn.BatchNorm2d(bandwidth[0]),
                nn.ReLU6(),
                nn.MaxPool2d(2,2,0)
            ),
            
            # 接下来开始pruning
            nn.Sequential(
                # DW
                nn.Conv2d(bandwidth[0],bandwidth[0],3,1,1,groups=bandwidth[0]),
                # Batch Normalization
                nn.BatchNorm2d(bandwidth[0]),
                # RELU6 是限制neural最小只能到0,最大只能到6。MobileNet都是用的RELU6
                # 使用RELU6是因为如果数字过大时,不方便压缩到float16,也不方便之后的parameters quantization,所以用R
                nn.ReLU6(),
                # PW
                nn.Conv2d(bandwidth[0],bandwidth[1],1),
                # 过完PW后不需要再过RELU,经验上PW+RELU效果都会变差
                # 每过完一个block就进行down sampling
                nn.MaxPool2d(2,2,0)
            ),
            
            nn.Sequential(
                nn.Conv2d(bandwidth[1],bandwidth[1],3,1,1,groups=bandwidth[1]),
                nn.BatchNorm2d(bandwidth[1]),
                nn.ReLU6(),
                nn.Conv2d(bandwidth[1],bandwidth[2],1),
                nn.MaxPool2d(2,2,0),                
            ),

            nn.Sequential(
                nn.Conv2d(bandwidth[2],bandwidth[2],3,1,1,groups=bandwidth[2]),
                nn.BatchNorm2d(bandwidth[2]),
                nn.ReLU6(),
                nn.Conv2d(bandwidth[2],bandwidth[3],1),
                nn.MaxPool2d(2,2,0),                
            ),
            
            nn.Sequential(
                nn.Conv2d(bandwidth[3],bandwidth[3],3,1,1,groups=bandwidth[3]),
                nn.BatchNorm2d(bandwidth[3]),
                nn.ReLU6(),
                nn.Conv2d(bandwidth[3],bandwidth[4],1),
                nn.MaxPool2d(2,2,0),                
            ),
            
            nn.Sequential(
                nn.Conv2d(bandwidth[4],bandwidth[4],3,1,1,groups=bandwidth[4]),
                nn.BatchNorm2d(bandwidth[4]),
                nn.ReLU6(),
                nn.Conv2d(bandwidth[4],bandwidth[5],1),
                nn.MaxPool2d(2,2,0),                
            ),            
            
            nn.Sequential(
                nn.Conv2d(bandwidth[5],bandwidth[5],3,1,1,groups=bandwidth[5]),
                nn.BatchNorm2d(bandwidth[5]),
                nn.ReLU6(),
                nn.Conv2d(bandwidth[5],bandwidth[6],1),
                nn.MaxPool2d(2,2,0),                
            ),            

            nn.Sequential(
                nn.Conv2d(bandwidth[6],bandwidth[6],3,1,1,groups=bandwidth[6]),
                nn.BatchNorm2d(bandwidth[6]),
                nn.ReLU6(),
                nn.Conv2d(bandwidth[6],bandwidth[7],1),
                nn.MaxPool2d(2,2,0),                
            ),            

            # 这里我采用 global average pooling,
            # 如果输入的图片不一样 会因为 global average pooling 压成一样的形状,这样后面作FC就不会冲突
            # ？？？ 它这里没有采用对 dataset 做 resize transform
            nn.AdaptiveAvgPool2d((1,1)),
        )
        
        self.fc = nn.Sequential(
            # 这里我们直接 project 到11纬来输出答案
            nn.Linear(bandwidth[7],11),
        )

    def forward(self,x):
        out = self.cnn(x)
        out = out.view(out.size()[0],-1)
        return self.fc(out)
        

## Loss function

In [6]:
def loss_fn_kd(outputs, labels, teacher_outputs, T=20, alpha=0.5):
    # 一般的Cross Entropy
    hard_loss = F.cross_entropy(outputs, labels) * (1. - alpha)
    # 讓logits的log_softmax對目標機率(teacher的logits/T後softmax)做KL Divergence。
    soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T)
    return hard_loss + soft_loss

## Data Processing

In [15]:
import re
import torch
from glob import glob
from PIL import Image
import torchvision.transforms as transforms

class MyDataset(torch.utils.data.Dataset):

    def __init__(self, folderName, transform=None):
        self.transform = transform
        self.data = []
        self.label = []

        for img_path in sorted(glob(folderName + '/*.jpg')):
            try:
                # Get classIdx by parsing image path
                class_idx = int(re.findall(re.compile(r'\d+'), img_path)[1])
            except:
                # if inference mode (there's no answer), class_idx default 0
                class_idx = 0

            self.data.append(img_path)
            self.label.append(class_idx)

    def __len__(self):
        return len(self.data)
#         return 100

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        img_path = self.data[idx]
        image = Image.open(img_path)
        # Get File Descriptor
        image_fp = image.fp
        image.load()
        # Close File Descriptor (or it'll reach OPEN_MAX)
        image_fp.close()

        if self.transform:
            image = self.transform(image)

        return image, self.label[idx]


trainTransform = transforms.Compose([
    transforms.RandomCrop(256, pad_if_needed=True, padding_mode='symmetric'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
])
testTransform = transforms.Compose([
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

In [16]:
def get_dataloader(mode='training', batch_size=32):

    assert mode in ['training', 'testing', 'validation']

    dataset = MyDataset(
        f'../input/ml2020spring-hw7/food-11/{mode}',
        transform=trainTransform if mode == 'training' else testTransform)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(mode == 'training'),
        num_workers=2
    )

    return dataloader

## Pre-processing

In [17]:
train_loader = get_dataloader(mode='training',batch_size=32)

In [18]:
valid_loader = get_dataloader(mode='validation',batch_size=32)

In [19]:
teacher_net = models.resnet18(pretrained=False,num_classes=11).to(device)

student_net = StudentNet(base=16).to(device)

teacher_net.load_state_dict(torch.load(f'../input/hw7-model/teacher_resnet18.bin',map_location=device))

optimizer = optim.AdamW(student_net.parameters(), lr=1e-3)

## Training

In [20]:
def run_epoch(data_loader,update=True,alpha=0.5):
    total_num,total_hit,total_loss = 0,0,0
    for now_step, batch_data in enumerate(data_loader):
        optimizer.zero_grad()
        inputs, hard_labels = batch_data
#         inputs = inputs.to(device)
        inputs = inputs.to(torch.device('cuda'))
        teacher_net.cuda()
        student_net.cuda()
        hard_labels = torch.LongTensor(hard_labels).to(device)
        # 因为 TecherNet 没有要做 backprop 所以 我们要用 torch.no_grad 
        # 告诉torch不要去存储中间值（用来做 backprop）以浪费内存
        
        with torch.no_grad():
            soft_labels = teacher_net(inputs)
            
        if update:
            logits = student_net(inputs) 
            # 使用我們之前所寫的融合soft label&hard label的loss。
            # T=20是原始論文的參數設定。
            loss = loss_fn_kd(logits,hard_labels,soft_labels,20,alpha)
            loss.backward()
            optimizer.step()
        else:
            # 只是算validation acc的話，就開no_grad節省空間。
            with torch.no_grad():
                logits = student_net(inputs) 
                loss = loss_fn_kd(logits,hard_labels,soft_labels,20,alpha)

        total_hit += torch.sum(torch.argmax(logits,dim=1) == hard_labels).item()
        total_num += len(inputs)
    
        total_loss += loss.item()*len(inputs)
        print('Nowstep {}/{}: loss: {:6.4f}, acc {:6.4f} device:{}'.format(
            now_step+1, len(data_loader), total_loss/total_num, total_hit/total_num, device),end='\r')
    return total_loss/total_num, total_hit/total_num

In [None]:
# TeacherNet 永远都是eval模式
teacher_net.eval()
now_best_acc = 0

# 从之前训练的模型基础上开始训练模型

# student_net.load_state_dict(torch.load('../input/hw7-model/student_model.bin'))

for epoch in range(200):
    start = time.time()
    student_net.train()
    train_loss,train_acc = run_epoch(train_loader,update=True)
    student_net.eval()
    val_loss,val_acc = run_epoch(valid_loader,update=False)
    
    # 存下最好的 model
    if val_acc > now_best_acc:
        now_best_acc = val_acc
        torch.save(student_net.state_dict(),'student_model.bin')

    printLogFile('Epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}, seconds used {:6.4f}'.format(
        epoch, train_loss, train_acc, val_loss, val_acc, time.time()-start))

    fn_log.flush()

## Eval the student_net

In [21]:
student_net.load_state_dict(torch.load('../input/hw7-model/student_model.bin'))

<All keys matched successfully>

In [22]:
student_net.eval()
val_loss,val_acc = run_epoch(valid_loader,update=False)

printLogFile(f'\nTrained student_net loss:{val_loss:6.4f} acc:{val_acc:6.4f}')

Nowstep 4/4: loss: 8.2315, acc 0.6400 device:cuda
Trained student_net loss:8.2315 acc:0.6400


In [23]:
printLogFile(f"\noriginal cost: {os.stat('../input/hw7-model/student_model.bin').st_size} bytes")


original cost: 1046701 bytes


In [24]:
params = torch.load('../input/hw7-model/student_model.bin')

## 32-bit Tensor --> 16-bit

In [25]:
import numpy as np
import pickle

def encode16(params,fname):
    '''
    将params压缩成16bit后输出到fname
    
    Args:
        params: model`s state_dict
        fname: output file`s name
    '''
    custom_dict = {}
    for (name,param) in params.items():
        param = np.float64(param.cpu().numpy())
        # some item is just a number, need not to be compressed
        if type(param) == np.ndarray:
            custom_dict[name] = np.float16(param)
        else:
            custom_dict[name] = param
    pickle.dump(custom_dict,open(fname,'wb'))
    
def decode16(fname):
    '''
    从fname读取各个param 将其从16bit还原回torch.tensor 后存如 state_dict 
    
    Args:
        fname: file name
    '''
    params = pickle.load(open(fname,'rb'))
    custom_dict = {}
    for (name,param) in params.items():
        param = torch.tensor(param)
        custom_dict[name] = param
    return custom_dict

encode16(params,'16_bit_model.pkl')

printLogFile(f"\n16-bit cost: {os.stat('16_bit_model.pkl').st_size} bytes")

student_net.load_state_dict(decode16('16_bit_model.pkl'))

student_net.eval()
val_loss,val_acc = run_epoch(valid_loader,update=False)

printLogFile(f'\n16-bit Trained student_net loss:{val_loss:6.4f} acc:{val_acc:6.4f}')


16-bit cost: 522958 bytes
Nowstep 4/4: loss: 8.2303, acc 0.6400 device:cuda
Trained student_net loss:8.2303 acc:0.6400


# 32-bit Tensor -> 8-bit (OPTIONAL)

這邊提供轉成8-bit的方法，僅供大家參考。
因為沒有8-bit的float，所以我們先對每個weight記錄最小值和最大值，進行min-max正規化後乘上$2^8-1$在四捨五入，就可以用np.uint8存取了。

$W' = round(\frac{W - \min(W)}{\max(W) - \min(W)} \times (2^8 - 1)$)



> 至於能不能轉成更低的形式，例如4-bit呢? 當然可以，待你實作。

In [26]:
def encode8(params,fname):
    '''
    将params压缩成8bit后输出到fname
    
    Args:
        params: model`s state_dict
        fname: output file`s name
    '''
    custom_dict = {}
    for (name,param) in params.items():
        param = np.float64(param.cpu().numpy())
        # some item is just a number, need not to be compressed
        if type(param) == np.ndarray:
            min_val = np.min(param)
            max_val = np.max(param)
            param = np.round((param-min_val)/(max_val-min_val)*255)
            param = np.uint8(param)
            custom_dict[name] = (min_val,max_val,param)
        else:
            custom_dict[name] = param
    pickle.dump(custom_dict,open(fname,'wb'))
    
def decode8(fname):
    '''
    从fname读取各个param 将其从8bit还原回torch.tensor 后存如 state_dict 
    
    Args:
        fname: file name
    '''
    params = pickle.load(open(fname,'rb'))
    custom_dict = {}
    for (name,param) in params.items():
        if type(param) == tuple:
            min_val,max_val,param = param
            param = np.float64(param)
            param = (param / 255 * (max_val - min_val)) + min_val

        param = torch.tensor(param)
        custom_dict[name] = param
    return custom_dict

encode8(params,'8_bit_model.pkl')

printLogFile(f"\n8-bit cost: {os.stat('8_bit_model.pkl').st_size} bytes")

student_net.load_state_dict(decode8('8_bit_model.pkl'))

student_net.eval()
val_loss,val_acc = run_epoch(valid_loader,update=False)

printLogFile(f'\n8-bit Trained student_net loss:{val_loss:6.4f} acc:{val_acc:6.4f}')


8-bit cost: 268471 bytes
Nowstep 4/4: loss: 8.3283, acc 0.6200 device:cuda
Trained student_net loss:8.3283 acc:0.6200


## Prediction

In [27]:
test_loader = get_dataloader(mode='testing',batch_size=32)

In [28]:
pred = []
for i,data in enumerate(test_loader):
    inputs,_ = data
    inputs = inputs.to(device)
    student_net.to(device)
    student_net.eval()
    with torch.no_grad():
        logits = student_net(inputs)
        logits = torch.argmax(logits,dim=1).cpu().numpy().tolist()
        pred += logits
    print(f'{i+1}/{len(test_loader)}',end='\r')


4/4

In [29]:
import pandas as pd

df = pd.DataFrame({'id':list(range(len(pred))),'value':pred})

df.to_csv('predict.csv',index=None)

In [30]:
fn_log.close()