In [1]:
import sys
import torch
import tqdm
import numpy as np
import random
import os
import json
sys.path.append('../../')

In [2]:
from models.cnn.search_cnn import  SearchCNN, SearchCNNController
from models.cnn_darts_hypernet.search_cnn_darts_hypernet import  SearchCNNControllerWithHyperNet

from configobj import ConfigObj

In [4]:

basecfg_path = '../../configs/hyper/fmnist.cfg'  #конфиг, на который мы ориентируемся при загрузки модели



cfg = ConfigObj(basecfg_path)
name = cfg['name'] # имя для сохранения результатов
ckp_path = '../../searchs/fmnist_darts/checkpoint_{}_49.ckp' # это шаблон названия сохраненных моделей
seeds = cfg['seeds'].split(';')  # сиды. можно брать из конфига
fine_epochs = 10


In [5]:
import utils
# get data with meta info
input_size, input_channels, n_classes, train_data, valid_data = utils.get_data(
    'fashionmnist', '../../data/', cutout_length=0, validation=True)

# split data to train/validation
n_train = len(train_data)
split = int(n_train * 0.5)
indices = list(range(n_train))

train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
    indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
    indices[split:])

train_loader = torch.utils.data.DataLoader(train_data,
                                        batch_size=64,
                                        sampler=train_sampler,
                                        num_workers=1,
                                        pin_memory=True)
valid_loader = torch.utils.data.DataLoader(train_data,
                                        batch_size=64,
                                        sampler=valid_sampler,
                                        num_workers=1,
                                        pin_memory=True)

test_loader = torch.utils.data.DataLoader(valid_data,
                                           batch_size=64,
                                           shuffle=False,
                                           num_workers=1,
                                           pin_memory=True)



In [6]:
def fix_structure(sc): # во время тестов нашей модели нужно перейти от непрерывной структуры к one-hot
    for alpha in sc.alpha_reduce:
        alpha.requires_grad = False
        for subalpha in alpha:
            
            argm = torch.argmax(subalpha)
            subalpha.data*=0
            subalpha.data[argm] += 1
    sc.sampling_mode = 'naive'
            
def calc_param_number(sc):
    penalty = 0
    for id, cell in enumerate(sc.net.cells):
            # можно не пробегать несколько раз, т.к. клетки одинаковы (С точностью до normal и reduce)                        
            weights = [alpha for alpha in sc.alpha_reduce] if cell.reduction else [
                alpha for alpha in sc.alpha_normal]        
            
            for edges, w_list in zip(cell.dag, weights):
                for mixed_op, weights in zip(edges, w_list):
                    for op, w in zip(mixed_op._ops, weights):                        
                        for param in op.parameters():
                            penalty += w*np.prod(param.shape) 
    return penalty            
            

In [7]:
# смотрим качество модели, которое мы получили на обучении, без фиксации структуры, но с переводом на обычный DARTS
cfg = ConfigObj(basecfg_path)
cfg['device'] = 'cuda'
for s in seeds:
    print ('seed {}'.format(s))
    sc = SearchCNNController(**cfg)        
    sc.load_state_dict(torch.load(ckp_path.format(s)))     
    sc = sc.to('cuda')        
    sc.eval()
    correct = 0
    total = 0
    for x,y in tqdm.tqdm(valid_loader):
        x = x.cuda()
        y = y.cuda()
        out = sc(x)
        correct += torch.eq(torch.argmax(out, 1), y).sum()
        total += len(x)
    print (correct*1.0/total*1.0)

seed 0


100%|██████████| 469/469 [00:11<00:00, 42.25it/s]


tensor(0.8997, device='cuda:0')
seed 13


100%|██████████| 469/469 [00:11<00:00, 41.42it/s]


tensor(0.8976, device='cuda:0')
seed 21


100%|██████████| 469/469 [00:11<00:00, 42.25it/s]

tensor(0.9011, device='cuda:0')





In [8]:
# смотрим качество модели, которое мы получили на обучении, без фиксации структуры, но с переводом на обычный DARTS
cfg = ConfigObj(basecfg_path)
cfg['device'] = 'cuda'
for s in seeds:
    print ('seed {}'.format(s))
    sc = SearchCNNController(**cfg)        
    sc.load_state_dict(torch.load(ckp_path.format(s)))     
    sc = sc.to('cuda')   
    fix_structure(sc)
    sc.eval()
    correct = 0
    total = 0
    for x,y in tqdm.tqdm(valid_loader):
        x = x.cuda()
        y = y.cuda()
        out = sc(x)
        correct += torch.eq(torch.argmax(out, 1), y).sum()
        total += len(x)
    print (correct*1.0/total*1.0)
    print (calc_param_number(sc))

seed 0


100%|██████████| 469/469 [00:10<00:00, 43.24it/s]


tensor(0.0856, device='cuda:0')
tensor(12496., device='cuda:0')
seed 13


100%|██████████| 469/469 [00:10<00:00, 43.11it/s]


tensor(0.0886, device='cuda:0')
tensor(12352., device='cuda:0')
seed 21


100%|██████████| 469/469 [00:10<00:00, 42.75it/s]

tensor(0.2225, device='cuda:0')
tensor(13120., device='cuda:0')





In [9]:
import torch.nn.functional as F
class HackedNNController(SearchCNNController):
    def __init__(self, **kwargs):
        SearchCNNController.__init__(self, **kwargs)
        self.lam = 0.0
        
    def loss(self, X, y):
        logits = self.forward(X)
        penalty = 0
        for id, cell in enumerate(self.net.cells):            
            
            weights = [alpha for alpha in self.alpha_reduce] if cell.reduction else [
                alpha for alpha in self.alpha_normal]

            weights = [F.softmax(w/self.t, dim=-1) for w in weights]

            for edges, w_list in zip(cell.dag, weights):
                for mixed_op, weights in zip(edges, w_list):
                    for op, w in zip(mixed_op._ops, weights):
                        for param in op.parameters():
                            penalty += w*np.prod(param.shape)           

        return self.criterion(logits, y)   + penalty * self.lam[0,0] 

In [None]:
# попробуем дообучить модельки с GS при фиксированной лямбде
cfg = ConfigObj(basecfg_path)
cfg['device'] = 'cuda'
for lam_ in range(-8, -3):    
    lam_ = 10.0**(lam_)
    lam = torch.tensor([[lam_]]).cuda()
    
    def hyperloss(self, X, y):
        logits = self.forward(X, lam)
        penalty = 0
        for id, cell in enumerate(self.net.cells):            
            
            weights = [alpha for alpha in self.hyper_reduce] if cell.reduction else [
                alpha for alpha in self.hyper_normal]

            weights = [F.softmax(w/self.t, dim=-1) for w in weights]

            for edges, w_list in zip(cell.dag, weights):
                for mixed_op, weights in zip(edges, w_list):
                    for op, w in zip(mixed_op._ops, weights):
                        for param in op.parameters():
                            penalty += w*np.prod(param.shape)           

        return self.criterion(logits, y)   + penalty * lam[0,0] 

        
    for s in seeds:
       
        sc0 = HackedNNController(**cfg)        
        sc0.load_state_dict(torch.load(ckp_path.format(s)))     
        sc0 = sc0.to('cuda')
        sc0.samling_mode='gumbel-softmax'
        sc0.lam = lam


    
        
        batch_id = 0
        for e in range(fine_epochs//2):
            tq = tqdm.tqdm((zip(train_loader, valid_loader)))
            losses = []
            for ((trn_X, trn_y), (val_X, val_y)) in tq:
                batch_id += 1                
                t = 0.2 + (0.8 - 0.8 * batch_id/(fine_epochs//2*len(train_loader)))
                sc0.t = t
                trn_X, trn_y = trn_X.to('cuda', non_blocking=True), trn_y.to('cuda', non_blocking=True)
                val_X, val_y = val_X.to('cuda', non_blocking=True), val_y.to('cuda', non_blocking=True)                 
                loss = sc0.train_step(trn_X, trn_y, val_X, val_y).detach().cpu().numpy()
                losses.append(loss)
                tq.set_description('{};{}'.format(sc0.t, str(np.mean(losses))))
                

    
        
    
        fix_structure(sc0)
        sc0.eval()
        correct = 0
        total = 0
        
        for x,y in tqdm.tqdm(valid_loader):
            x = x.cuda()
            y = y.cuda()
            out = sc0(x)
            correct += torch.eq(torch.argmax(out, 1), y).sum()
            total += len(x)
        
        penalty = calc_param_number(sc0)        
        print ('seed {}, lam: {}'.format(s, lam_))
        print (correct*1.0/total*1.0)
        print ('param num', penalty)
        print ('\n'*3)
        torch.save(sc0.state_dict(), 'darts_{}_prefine_lam_{}.pth'.format(s, lam_ ))
                            
        

0.8400000000000001;0.37684387: : 469it [02:31,  3.10it/s]
0.6799999999999999;0.37178084: : 469it [02:39,  2.94it/s]
0.52;0.40263748: : 469it [02:36,  3.00it/s]              
0.35999999999999993;0.4215558: : 469it [02:36,  2.99it/s] 
0.2;0.43260792: : 469it [02:40,  2.92it/s]                
100%|██████████| 469/469 [00:11<00:00, 41.93it/s]


seed 0, lam: 1e-08
tensor(0.8324, device='cuda:0')
param num tensor(11040., device='cuda:0')






0.8400000000000001;0.3811183: : 469it [02:41,  2.90it/s] 
0.6799999999999999;0.3699376: : 469it [02:40,  2.91it/s] 
0.52;0.4135063: : 469it [02:41,  2.91it/s]               
0.35999999999999993;0.4480036: : 469it [02:41,  2.91it/s] 
0.2;0.47684354: : 469it [02:41,  2.90it/s]                
100%|██████████| 469/469 [00:11<00:00, 41.51it/s]


seed 13, lam: 1e-08
tensor(0.8384, device='cuda:0')
param num tensor(10128., device='cuda:0')






0.8400000000000001;0.38550937: : 469it [02:41,  2.91it/s]
0.6799999999999999;0.37989107: : 469it [02:41,  2.90it/s]
0.52;0.41088405: : 469it [02:41,  2.91it/s]              
0.35999999999999993;0.4422571: : 469it [02:41,  2.90it/s] 
0.2;0.4486926: : 469it [02:40,  2.91it/s]                 
100%|██████████| 469/469 [00:11<00:00, 41.51it/s]


seed 21, lam: 1e-08
tensor(0.8389, device='cuda:0')
param num tensor(11808., device='cuda:0')






0.8400000000000001;0.37633765: : 469it [02:41,  2.90it/s]
0.6799999999999999;0.37092686: : 469it [02:41,  2.91it/s]
0.52;0.403047: : 469it [02:40,  2.92it/s]                
0.35999999999999993;0.4378077: : 469it [02:41,  2.91it/s] 
0.2;0.43894133: : 469it [02:41,  2.91it/s]                
100%|██████████| 469/469 [00:11<00:00, 41.50it/s]


seed 0, lam: 1e-07
tensor(0.8454, device='cuda:0')
param num tensor(11040., device='cuda:0')






0.8400000000000001;0.383313: : 469it [02:41,  2.91it/s]  
0.6799999999999999;0.3754614: : 469it [02:41,  2.91it/s] 
0.52;0.40271854: : 469it [02:41,  2.91it/s]              
0.35999999999999993;0.448941: : 469it [02:41,  2.90it/s]  
0.2;0.4706503: : 469it [02:41,  2.90it/s]                 
100%|██████████| 469/469 [00:11<00:00, 41.17it/s]


seed 13, lam: 1e-07
tensor(0.8088, device='cuda:0')
param num tensor(9616., device='cuda:0')






0.8400000000000001;0.38738075: : 469it [02:41,  2.90it/s]
0.6799999999999999;0.37609303: : 469it [02:40,  2.91it/s]
0.52;0.40970162: : 469it [02:41,  2.91it/s]              
0.35999999999999993;0.4445109: : 469it [02:41,  2.91it/s] 
0.2;0.45687944: : 469it [02:41,  2.91it/s]                
100%|██████████| 469/469 [00:11<00:00, 41.43it/s]


seed 21, lam: 1e-07
tensor(0.8318, device='cuda:0')
param num tensor(10896., device='cuda:0')






0.8400000000000001;0.38372692: : 469it [02:41,  2.91it/s]
0.6799999999999999;0.38254455: : 469it [02:41,  2.90it/s]
0.52;0.4179877: : 469it [02:41,  2.90it/s]               
0.35999999999999993;0.44831303: : 469it [02:41,  2.90it/s]
0.2;0.44789633: : 469it [02:41,  2.91it/s]                
100%|██████████| 469/469 [00:11<00:00, 41.19it/s]


seed 0, lam: 1e-06
tensor(0.8330, device='cuda:0')
param num tensor(11040., device='cuda:0')






0.8400000000000001;0.38843343: : 469it [02:41,  2.90it/s]
0.6799999999999999;0.37935135: : 469it [02:41,  2.90it/s]
0.52;0.41520765: : 469it [02:41,  2.91it/s]              
0.35999999999999993;0.4513578: : 469it [02:41,  2.91it/s] 
0.2;0.47515437: : 469it [02:41,  2.90it/s]                
100%|██████████| 469/469 [00:11<00:00, 41.57it/s]


seed 13, lam: 1e-06
tensor(0.8138, device='cuda:0')
param num tensor(9984., device='cuda:0')






0.8400000000000001;0.3943096: : 469it [02:41,  2.90it/s] 
0.6799999999999999;0.39136225: : 469it [02:41,  2.91it/s]
0.52;0.408577: : 469it [02:41,  2.91it/s]                
0.35999999999999993;0.44473007: : 469it [02:41,  2.91it/s]
0.2;0.4567476: : 469it [02:41,  2.91it/s]                 
100%|██████████| 469/469 [00:11<00:00, 41.62it/s]


seed 21, lam: 1e-06
tensor(0.7576, device='cuda:0')
param num tensor(11808., device='cuda:0')






0.8400000000000001;0.4639917: : 469it [02:41,  2.90it/s] 
0.6799999999999999;0.4606934: : 469it [02:41,  2.91it/s] 
0.5735607675906185;0.50101376: : 312it [01:46,  2.88it/s]

In [13]:
sc0.loss

<function __main__.hyperloss(self, X, y)>

In [21]:
'model_{}_prefine_lam_{}.pth'.format(s, lam_ )

'model_0_prefine_lam_0.001.pth'

In [29]:
(sc(x) - sc0(x, torch.tensor([10e-9]).cuda()))

tensor([[-2.0407e+00, -2.5080e+00, -4.8985e+00,  6.0774e+00,  1.1547e-01,
         -1.7027e+00, -4.5315e-01,  8.9518e+00, -4.7982e+00,  2.8806e+00],
        [-2.4257e+00, -1.9611e-01, -6.0386e+00,  9.4664e-03, -4.2474e+00,
         -1.2515e+00,  1.2721e-01,  8.6204e+00,  6.9197e-01,  5.5956e+00],
        [ 2.1178e+00, -5.2917e-01,  4.5798e+00,  1.0271e-01, -5.7425e+00,
         -7.9479e+00,  3.9173e+00,  1.7867e+00, -3.9782e-01,  2.3104e+00],
        [ 3.5477e+00,  1.9448e+00,  8.7270e+00,  1.7255e+00, -6.4827e+00,
         -9.9263e+00, -3.4194e-01,  1.1371e+00,  6.0196e+00, -5.7826e+00],
        [-6.3454e-01, -7.4744e-01,  1.8433e+00,  4.2400e+00, -7.6684e+00,
         -3.7727e+00,  3.5501e+00,  3.1300e+00, -6.1671e+00,  7.1702e+00],
        [-2.3435e+00, -7.0764e-01, -4.9570e+00, -9.3269e-01, -4.9343e+00,
         -2.6882e+00,  7.4749e-01,  8.8085e+00,  1.3677e+00,  6.1104e+00],
        [ 1.5993e+00,  3.2610e+00,  5.3848e+00, -6.0362e-02, -7.3853e+00,
         -8.0302e+00, -3.8744e-0

In [18]:
sc0.hyper_reduce[0](torch.tensor([0.0])), sc0.hyper_reduce[0](torch.tensor([1.0]))

(tensor([[-0.8079, -0.9030, -0.6230, -0.2134,  3.2586, -0.3532, -0.4448, -0.4817],
         [-0.8494, -0.8544, -0.6758, -0.5583,  4.6323, -0.5762, -0.6024, -0.5743]],
        grad_fn=<ViewBackward>),
 tensor([[-1.0875, -1.2262, -0.8786, -0.3082,  4.5715, -0.4987, -0.6206, -0.6553],
         [-1.1367, -1.1572, -0.9096, -0.7862,  6.3077, -0.7849, -0.8311, -0.7623]],
        grad_fn=<ViewBackward>))

In [19]:
sc0.hyper_reduce[0].model[0].weight, sc0.hyper_reduce[0].model[0].bias

(Parameter containing:
 tensor([[-0.2796],
         [-0.3232],
         [-0.2557],
         [-0.0948],
         [ 1.3129],
         [-0.1455],
         [-0.1758],
         [-0.1736],
         [-0.2873],
         [-0.3027],
         [-0.2337],
         [-0.2279],
         [ 1.6754],
         [-0.2087],
         [-0.2287],
         [-0.1880]], requires_grad=True), Parameter containing:
 tensor([-0.8079, -0.9030, -0.6230, -0.2134,  3.2586, -0.3532, -0.4448, -0.4817,
         -0.8494, -0.8544, -0.6758, -0.5583,  4.6323, -0.5762, -0.6024, -0.5743],
        requires_grad=True))

In [34]:
lam_ = 0.0
for h,a in zip(sc0.hyper_reduce, sc.alpha_reduce):
                a.data = torch.clone(h(torch.tensor([lam_]).cuda())) 
       
fix_structure(sc)
for a in sc.alpha_reduce:
    print (a)

Parameter containing:
tensor([[-0., -0., 0., -0., 0., 0., -0., 1.],
        [-0., -0., -0., -0., -0., 0., 0., 1.]], device='cuda:0')
Parameter containing:
tensor([[-0., -0., -0., 0., -0., 0., -0., 1.],
        [-0., -0., -0., 0., -0., 0., -0., 1.],
        [1., -0., 0., 0., 0., -0., -0., 0.]], device='cuda:0')
Parameter containing:
tensor([[-0., 0., -0., -0., 0., -0., -0., 1.],
        [-0., -0., 0., -0., -0., -0., -0., 1.],
        [-0., -0., 0., -0., 1., -0., -0., 0.],
        [-0., -0., 0., 0., 0., 0., -0., 1.]], device='cuda:0')
Parameter containing:
tensor([[0., 0., -0., -0., -0., -0., -0., 1.],
        [-0., -0., -0., -0., -0., -0., -0., 1.],
        [-0., -0., 0., -0., 1., -0., -0., 0.],
        [0., 0., 0., -0., -0., -0., -0., 1.],
        [0., 0., -0., -0., 1., 0., -0., -0.]], device='cuda:0')


In [33]:
lam_ = 1.0
for h,a in zip(sc0.hyper_reduce, sc.alpha_reduce):
                a.data = torch.clone(h(torch.tensor([lam_]).cuda()) )
       
fix_structure(sc)
for a in sc.alpha_reduce:
    print (a)

Parameter containing:
tensor([[-0., -0., 0., -0., 0., 0., -0., 1.],
        [-0., -0., -0., -0., -0., 0., 0., 1.]], device='cuda:0')
Parameter containing:
tensor([[-0., -0., -0., 0., -0., 0., -0., 1.],
        [-0., -0., -0., 0., -0., 0., -0., 1.],
        [0., -0., 0., 0., -0., -0., -0., 1.]], device='cuda:0')
Parameter containing:
tensor([[-0., 0., -0., -0., 0., -0., -0., 1.],
        [-0., -0., 0., -0., -0., -0., -0., 1.],
        [-0., -0., 0., -0., 0., -0., -0., 1.],
        [-0., -0., 0., -0., 0., 0., -0., 1.]], device='cuda:0')
Parameter containing:
tensor([[0., 0., -0., -0., -0., -0., -0., 1.],
        [-0., -0., -0., -0., -0., -0., -0., 1.],
        [0., 0., 0., -0., 0., -0., -0., 1.],
        [0., 0., 0., -0., -0., -0., -0., 1.],
        [1., 0., -0., -0., 0., 0., -0., -0.]], device='cuda:0')


In [11]:
import torch.nn.functional as F
def hyperloss(self, X, y, lam):
        #logits = self.forward(X, lam)
        penalty = 0
        
        for id, cell in enumerate(self.net.cells):
            # можно не пробегать несколько раз, т.к. клетки одинаковы (С точностью до normal и reduce)            
            
            lam_ = self.norm_lam(lam)
            weights = [alpha(lam_) for alpha in self.hyper_reduce] if cell.reduction else [
                alpha(lam_) for alpha in self.hyper_normal]
            
            weights = [F.softmax(w/self.t, dim=-1) for w in weights]
            
              
            for edges, w_list in zip(cell.dag, weights):
                for mixed_op, weights in zip(edges, w_list):
                    for op, w in zip(mixed_op._ops, weights):                        
                        for param in op.parameters():
                            penalty += w*np.prod(param.shape)          
            #penalty += lam_[0,0] * (torch.norm(self.net.linear.weight)**2 + torch.norm(self.net.linear.bias)**2)
            #penalty += (1.0-lam_[0,0]) * (torch.norm(self.net.linear2.weight)**2 + torch.norm(self.net.linear2.bias)**2)
            

        # oleg return self.criterion(logits, y)   + penalty * lam[0,0] 
        return penalty 
    

In [14]:
sc0 = sc0.cuda()
sc0.t = 0.1

In [15]:
for r in range(-10, -4):
    
    print (hyperloss(sc0, 0, 0, torch.tensor([[10**r]]).cuda()))

tensor(6328.6978, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5707.0322, device='cuda:0', grad_fn=<AddBackward0>)
tensor(5262.7612, device='cuda:0', grad_fn=<AddBackward0>)
tensor(4670.6914, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3872.0901, device='cuda:0', grad_fn=<AddBackward0>)
tensor(3365.1294, device='cuda:0', grad_fn=<AddBackward0>)


'64'

In [14]:
sc0.hyper_reduce[0].model[0].weight, sc0.hyper_reduce[0].model[0].bias

(Parameter containing:
 tensor([[ 1.1084],
         [-0.0756],
         [-0.0762],
         [-0.0475],
         [-0.6884],
         [-0.0391],
         [-0.0297],
         [-0.0831],
         [ 1.0998],
         [-0.0604],
         [-0.0712],
         [-0.0253],
         [-0.6183],
         [-0.0505],
         [-0.0459],
         [-0.0997]], device='cuda:0', requires_grad=True),
 Parameter containing:
 tensor([ 1.5646, -0.8272, -0.6284, -0.3190,  1.9144, -0.6870, -0.3635, -1.1838,
          1.3121, -0.8585, -0.7832,  0.6159,  1.7367, -0.8624, -0.6465, -1.3189],
        device='cuda:0', requires_grad=True))

In [29]:
for lam_ in [-4.0,  -1.0]:
        lam = torch.tensor([10**lam_])
        # создаем модель с обязательным указанием, что к структуре не применяется softmax
        cfg = ConfigObj(basecfg_path)
        cfg['darts']['sampling_mode'] = 'naive' 
        cfg['device'] = 'cuda'    
        for s in seeds:
            print (s)
            sc0 = SearchCNNControllerWithHyperNet(**cfg)        
            sc0.load_state_dict(torch.load(ckp_path.format(s)))                            
            sc = SearchCNNController(**cfg)
            
            sc.net.load_state_dict(sc0.net.state_dict())
            for h,a in zip(sc0.hyper_reduce, sc.alpha_reduce):
                a.data = torch.clone(h(lam))                                            
            fix_structure(sc)
            
            sc = sc.to('cuda')
            optim = torch.optim.Adam(sc.weights())
            correct = 0
            total = 0
            # дообучаем n эпох
            print ('seed {}, lam: {}'.format(s, lam_))
            for e in range(5):
                for x,y in tqdm.tqdm(train_loader):
                    x = x.cuda()
                    y = y.cuda()            
                    optim.zero_grad()
                    loss = sc.loss(x,y)
                    loss.backward()
                    optim.step()                              
                
                sc.eval()
                correct = 0
                total = 0
                for x,y in tqdm.tqdm(valid_loader):
                    x = x.cuda()
                    y = y.cuda()
                    out = sc(x)
                    correct += torch.eq(torch.argmax(out, 1), y).sum()
                    total += len(x)
                print (correct*1.0/total*1.0)
                sc.train()
            torch.save(sc.state_dict(), ckp_path.format(s)+'{}.fine'.format(lam_))

0


  0%|          | 0/938 [00:00<?, ?it/s]

seed 0, lam: -4.0


  cpuset_checked))
  3%|▎         | 27/938 [00:02<01:27, 10.43it/s]


KeyboardInterrupt: 

In [34]:
# смотрим качество модели, которое мы получили на обучении
cfg = ConfigObj(basecfg_path)
cfg['device'] = 'cuda'
for lam_ in [-4.0,  -1.0]:
    for s in seeds:
        print ('seed {}, lam: {}'.format(s, lam_))
        sc = SearchCNNController(**cfg)
        
        sc.load_state_dict(torch.load(ckp_path.format(s)+'{}.fine'.format(lam_)))
        
        sc = sc.to('cuda')
        sc.eval()
        correct = 0
        total = 0
        for x,y in tqdm.tqdm(valid_loader):
            x = x.cuda()
            y = y.cuda()
            out = sc(x)
            correct += torch.eq(torch.argmax(out, 1), y).sum()
            total += len(x)
        print (correct*1.0/total*1.0)

  0%|          | 0/157 [00:00<?, ?it/s]

seed 21, lam: -4.0


100%|██████████| 157/157 [00:03<00:00, 43.77it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.0960, device='cuda:0')
seed 21, lam: -1.0


100%|██████████| 157/157 [00:03<00:00, 44.67it/s]

tensor(0.0967, device='cuda:0')





ParameterList(
    (0): Parameter containing: [torch.FloatTensor of size 2x8]
    (1): Parameter containing: [torch.FloatTensor of size 3x8]
    (2): Parameter containing: [torch.FloatTensor of size 4x8]
    (3): Parameter containing: [torch.FloatTensor of size 5x8]
)