Демонстрация простой вариационной модели и ее прунинга.

**Disclaimer**: могут быть ошибки, кроме того, функии могут быть написаны неоптимально.

In [52]:
import torch as t 
import torchvision
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pylab as plt
from torch.nn.utils import clip_grad_value_
%matplotlib inline
import pickle
from torchvision import datasets, transforms
import random

In [23]:
device = 'cpu' # cuda or cpu

In [24]:
batch_size = 64
init_log_sigma = -3.0 # логарифм дисперсии вариационного распределения при инициализации
prior_sigma = 0.1 # априорная дисперсия
epoch_num = 20 #количество эпох
lamb = [0, 0.1, 0.5, 1, 5, 10, 100, 1000]
# lam = 1.0 # коэффициент перед дивергенцией
hidden_num = 100 # количество нейронов на скрытом слое
t.manual_seed(42) # задаем значение генератора случайных чисел для повторяемости экспериментов
acc_delete = [] 
filename = 'save_array_0.1' # куда сохранять

In [25]:
# сохранение данных
def save(file):
    outfile = open(filename, 'wb')
    pickle.dump(file, outfile)
    outfile.close()
    
def load(path = filename):
    infile = open(path, 'rb')
    file = pickle.load(infile)
    infile.close()
    return file
    
    

In [26]:
# загрузка данных
train_data = torchvision.datasets.MNIST('./files/', train=True, download=True,
                             transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                                  torchvision.transforms.Lambda(lambda x: x.view(-1))
                              ]))

test_data = torchvision.datasets.MNIST('./files/', train=False, download=True,
                             transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,)),
                                  torchvision.transforms.Lambda(lambda x: x.view(-1))
                              ]))


train_loader = t.utils.data.DataLoader(train_data, batch_size=batch_size, pin_memory=True )
test_loader = t.utils.data.DataLoader(test_data, batch_size=batch_size)


In [27]:
class VarLayer(nn.Module): # вариационная однослойная сеть
    def __init__(self, in_,  out_,   act=F.relu):         
        nn.Module.__init__(self)                    
        self.mean = nn.Parameter(t.randn(in_, out_, device=device)) # параметры средних
        t.nn.init.xavier_uniform(self.mean) 
        self.log_sigma = nn.Parameter(t.ones(in_, out_, device = device)*init_log_sigma) # логарифм дисперсии
        self.mean_b = nn.Parameter(t.randn(out_, device=device)) # то же самое для свободного коэффициента
        self.log_sigma_b = nn.Parameter(t.ones(out_, device=device) * init_log_sigma)
                
        self.in_ = in_
        self.out_ = out_
        self.act = act
        
    def forward(self,x):
        if self.training: # во время обучения - сэмплируем из нормального распределения
            self.eps_w = t.distributions.Normal(self.mean, t.exp(self.log_sigma))
            self.eps_b = t.distributions.Normal(self.mean_b, t.exp(self.log_sigma_b))
        
            w = self.eps_w.rsample()
            b = self.eps_b.rsample()
             
        else:  # во время контроля - смотрим средние значения параметра        
            w = self.mean 
            b = self.mean_b
            
        # функция активации 
        return self.act(t.matmul(x, w)+b)

    def KLD(self):        
        # подсчет дивергенции
        size = self.in_, self.out_
        out = self.out_
        self.eps_w = t.distributions.Normal(self.mean, t.exp(self.log_sigma))
        self.eps_b = t.distributions.Normal(self.mean_b,  t.exp(self.log_sigma_b))
        self.h_w = t.distributions.Normal(t.zeros(size, device=device), t.ones(size, device=device)*prior_sigma)
        self.h_b = t.distributions.Normal(t.zeros(out, device=device), t.ones(out, device=device)*prior_sigma)                
        k1 = t.distributions.kl_divergence(self.eps_w,self.h_w).sum()        
        k2 = t.distributions.kl_divergence(self.eps_b,self.h_b).sum()        
        return k1+k2

In [56]:
class LowRankNet(nn.Module): #сеть с аппроксимацией
    def __init__(self, in_, out_, hidden = 3, diagonal=False, act= lambda x: x):
        nn.Module.__init__(self)
        self.w = nn.Linear(1, hidden)
        t.nn.init.xavier_uniform(self.w.weight)
        self.act = act
        self.diagonal = diagonal
        if diagonal:
            self.w_d = nn.Linear(hidden, out_)
            t.nn.init.xavier_uniform(self.w_d.weight)
        else:
            self.w_a1 = nn.Linear(hidden, in_)
            t.nn.init.xavier_uniform(self.w_a1.weight)
            self.w_a2 = nn.Linear(hidden, out_)
            t.nn.init.xavier_uniform(self.w_a2.weight)
            self.w_b = nn.Linear(hidden, out_)
            t.nn.init.xavier_uniform(self.w_b.weight)
            self.in_ = in_
            self.out_ = out_
            
    def forward(self, lam):
        h = self.act(self.w(lam))        
        if self.diagonal:
            return self.w_d(h)
        else:
            a1 = self.w_a1(h)
            a2 = self.w_a2(h)
            b = self.w_b(h)
         
            return t.matmul(a1.view(-1, 1), a2.view(1, -1)) + b
        
    def KLD(self):        
        # подсчет дивергенции
        size = self.in_, self.out_
        out = self.out_
        self.eps_w = t.distributions.Normal(self.w, t.exp(self.a_1))
        self.eps_b = t.distributions.Normal(self.w_b,  t.exp(self.a_2))
        self.h_w = t.distributions.Normal(t.zeros(size, device=device), t.ones(size, device=device)*prior_sigma)
        self.h_b = t.distributions.Normal(t.zeros(out, device=device), t.ones(out, device=device)*prior_sigma)                
        k1 = t.distributions.kl_divergence(self.eps_w,self.h_w).sum()        
        k2 = t.distributions.kl_divergence(self.eps_b,self.h_b).sum()        
        return k1+k2

In [41]:
class VarSeqNet(nn.Sequential):    
    # класс-обертка на случай, если у нас многослойная нейронная сеть
    def KLD(self):
        k = 0
        for l in self:
            k+=l.KLD()
        return k

In [66]:
def train_batches(net, loss_fn, optimizer, i, out, out_loss, kld, loss):
    for id, (x,y) in enumerate(train_loader):  
            id+=1
            if device == 'cuda':
                x = x.cuda()
                y = y.cuda()            
            optimizer.zero_grad() 
            lam = t.from_numpy(np.random.randint(low = 0, high = 100, size = 5))
            out[i] = net(lam)
            # правдоподобие должно суммироваться по всей обучающей выборке
            # в случае батчей - она приводится к тому же порядку 
            out_loss[i] = loss_fn(out[i], y)* len(train_data)         
            kld[i] =  net.KLD() * lam        
            loss[i] = (out_loss[i]+kld[i])       
            if id %100 == 0:           
                print ("Number of net:",i, loss[i].data, out_loss[i].data, kld[i].data)            
            loss[i].backward()       
            clip_grad_value_(net.parameters(), 1.0) # для стабильности градиента. С этим можно играться
            optimizer.step()

In [43]:
def statistic(net, loss_fn, i, kld, loss, out, out_loss):
    net.eval()  
    kld[i] =  net.KLD() 
    loss[i] = kld[i]
    for x,y in test_loader:
         if device == 'cuda':
            x = x.cuda()
            y = y.cuda()          
    out[i] = net(x)   
    out_loss[i] = loss_fn(out[i], y)* len(train_data)/len(test_data)   
    #  print(out_loss[i])
    # print(loss[i])
    loss[i] += out_loss[i]
    net.train()
    print (loss[i])
    return loss[i]


    

In [44]:
# рассмотрим для примера сеть, состояющую из двух слоев
# второй слой - softmax. По сути для обучения задавать активацию явно не нужно, она забита в nn.CrossEntropyLoss
def init_nets(loss_fn_nets):
    for i in range(3):
        nets.append(VarSeqNet(LowRankNet(784,  hidden_num), LowRankNet(hidden_num, 10, act=lambda x:x)))
        optimizer_nets.append(optim.Adam(nets[i].parameters()))
        loss_fn_nets.append(nn.CrossEntropyLoss())
    loss_graph=[[],[],[]]
    out = [None, None, None]
    out_loss = [None, None, None]
    kld = [None, None, None]
    loss = [None, None, None]
    return out, out_loss, kld, loss, loss_graph

def train_nets(out, out_loss, kld, loss, loss_graph):
    for epoch in range(epoch_num):             
        for i,net in enumerate(nets):
            train_batches(net,loss_fn_nets[i], optimizer_nets[i],i, out, out_loss, kld, loss)
        print ('end of epoch: ', epoch)   
        for i,net in enumerate(nets):
            print("Number of net:",i)        
            loss_graph[i].append(statistic(net, loss_fn_nets[i], i, kld, loss, out, out_loss))

        

In [45]:
#print(loss_graf)
def graph_loss_func(loss_graph, nets):
    for i,net in enumerate(nets): 
        plt.plot(loss_graph[i])
    plt.ylabel('Loss function')
    plt.xlabel('Number of epoche')
    plt.show()
#print(out_loss)

#graph_loss_func()

In [46]:
def test_acc(out): # точность классификации
    acc = []
    for i,net in enumerate(nets):
        correct = 0
        net.eval()
        for x,y in test_loader:
            if device == 'cuda':
                x = x.cuda()
                y = y.cuda()     
            out[i] = net(x)    
            correct += out[i].argmax(1).eq(y).sum().cpu().numpy()
        acc.append(correct / len(test_data))
    print(sum(acc)/len(acc))   
    return(acc)
#test_acc(out)

In [47]:
# коэффициенты информативности, см. статью practical variational inference
# попробуем удалять параметры первого слоя по этому коэффициенту

def init_coeff(prune_coef, mu, sigma):
    for i,net in enumerate(nets): 
        mu.append(net[0].mean) 
        sigma.append(t.exp(2*net[0].log_sigma))
        prune_coef.append((mu[i]**2/sigma[i]).cpu().detach().numpy())  


In [48]:
# будем удалять по 10% от модели и смотреть качество
def delete_10(acc_delete, prune_coef, mu, sigma, nets, out):
    acc_delete = []
    sorted_coefs = []
    for i, net in enumerate(nets):
        sorted_coefs.append(np.sort(prune_coef[i].flatten()))
    for j in range(10):
        for i,net in enumerate(nets): 
            ids = (prune_coef[i] <= sorted_coefs[i][round(j/10*len(sorted_coefs[i]))]) 
            net[0].mean.data*=(1-t.tensor(ids*1.0, device=device, dtype=t.float))
            print ('nonzero params: ', (abs(net[0].mean)>0).float().mean())
        acc_delete.append(test_acc(out))
    return acc_delete    


In [49]:
def graph(acc_delete, lamb):
    proc = [0,10,20,30,40,50,60,70,80,90]
    plt.rcParams['figure.figsize'] = 12, 12
    for k, lam in enumerate(lamb):
        acc_delete_n = np.array(acc_delete[k])
        plt.plot(proc, np.mean(acc_delete_n, 1), label = 'lambda = {}'.format(str(lam)))
        # откладываем ошибку вокруг среднего, альфа - прозрачность линии
        plt.fill_between(proc, np.mean(acc_delete_n, 1)  + np.std(acc_delete_n, 1) , np.mean(acc_delete_n, 1) - np.std(acc_delete_n, 1) , alpha = 0.5 )
    plt.ylabel('Точность классификации', fontsize = 20)
    plt.xlabel('Процент удаления', fontsize = 20)
    plt.tick_params(axis='both', which='major', labelsize=18)
    plt.legend(loc='best')
    plt.savefig('1')
    plt.show()

#acc_delete = load('save_array_0.1')    
#graph(acc_delete, lamb)


    

In [50]:
# проверяем, что фокусов тут нет, удаляем оставшиеся 10%\
def delete_last10():
    flag = 0
    for j in range(10):
        for i,net in enumerate(nets): 
            if (flag == 0):
                sorted_coefs = np.sort(prune_coef[i].flatten())
                flag = 1
            ids = (prune_coef[i] <= sorted_coefs[round((0.9+j/100)*len(sorted_coefs))]) 
            net[0].mean.data*=(1-t.tensor(ids*1.0, device=device, dtype=t.float))
            print ('nonzero params: ', (abs(net[0].mean)>0).float().mean())
        (test_acc())
    for i,net in enumerate(nets):
        net[0].mean.data*=0
        print ('nonzero params: ', (abs(net[0].mean)>0).float().mean())
    (test_acc())
    
#delete_last10()    

In [67]:
loss_fn_nets =[]
nets = []
optimizer_nets = []
mu_glob = []
sigma_glob = []
prune_coef_glob = []
init_nets_output =  init_nets(loss_fn_nets)
train_nets(init_nets_output[0], init_nets_output[1], init_nets_output[2], init_nets_output[3], init_nets_output[4])
#graph_loss_func()
init_coeff(prune_coef_glob, mu_glob, sigma_glob)
for k,lam in enumerate(lamb):
    acc_delete.append(None)
    acc_delete[k]= delete_10(acc_delete[k], prune_coef_glob, mu_glob, sigma_glob, nets, init_nets_output[0])
graph(acc_delete,lamb)
save(acc_delete)


  """
  del sys.path[0]
  from ipykernel import kernelapp as app


RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'mat2' in call to _th_mm