In [1]:
import sys
import torch
import tqdm
import numpy as np
import random
import os
import json
sys.path.append('../../')

In [3]:
from models.cnn.search_cnn import  SearchCNN, SearchCNNController
from models.cnn_var_local.search_cnn  import  LVarSearchCNN, LVarSearchCNNController
from configobj import ConfigObj

In [4]:
import utils
# Load MNIST via the project helper; besides the train/validation datasets it
# also returns meta info (input size, channel count, number of classes).
input_size, input_channels, n_classes, train_data, valid_data = utils.get_data(
    'mnist', '../../data/', cutout_length=0, validation=True
)

# Settings shared by both loaders.
loader_common = dict(batch_size=64, pin_memory=True)

# Training batches are shuffled and fetched by 32 worker processes.
train_loader = torch.utils.data.DataLoader(
    train_data, shuffle=True, num_workers=32, **loader_common
)

# Validation batches keep their order and use a single worker.
valid_loader = torch.utils.data.DataLoader(
    valid_data, shuffle=False, num_workers=1, **loader_common
)



In [13]:
# Seeds of the search runs; checkpoint file names are built from them.
seeds = [0, 13, 21, 42, 99]

# Checkpoint prefix and the two config files used in this notebook.
ckp_path = '../../searchs/searchs/mnist_darts_var/checkpoint_'
basecfg_path = '../../configs/mnist/var_darts.cfg'
darts_basecfg_path = '../../configs/mnist/darts.cfg'

cfg = ConfigObj(basecfg_path)

# When fine_tune is on, the fine-tuned checkpoints are evaluated and the
# experiment name (used as prefix for the result JSON files) gets '_fine'.
fine_tune = False
name = (cfg['name'] + '_fine') if fine_tune else cfg['name']

In [34]:
# NOTE(review): `sc` is not created in any earlier visible cell - this cell
# appears to rely on a controller left in the kernel by a later cell; confirm.
sc = sc.to('cpu')

# Build a plain DARTS controller from the DARTS config, with naive sampling
# and the device set to CUDA.
cfg2 = ConfigObj(darts_basecfg_path)
cfg2['device'] = 'cuda'
cfg2['darts']['sampling_mode'] = 'naive'
sc2 = SearchCNNController(**cfg2)

In [35]:
# Overwrite each reduce-cell alpha of the DARTS controller with the matching
# q_gamma_reduce tensor of the other controller (zero it, then add in place).
for alpha, gamma in zip(sc2.alpha_reduce, sc.net.q_gamma_reduce):
    alpha.data.mul_(0).add_(gamma.data)
    

In [45]:
# Show the class of the first cell of the controller's network.
first_cell = sc.net.cells[0]
type(first_cell)

models.cnn_var_local.search_cells.SearchCell

In [7]:
def darts_fix_structure(sc):
    """Freeze the reduce-cell architecture of ``sc`` to a one-hot structure.

    For every tensor in ``sc.alpha_reduce`` gradient tracking is switched off
    and each row is replaced, in place, by a one-hot vector that is 1 at the
    row's argmax and 0 everywhere else.

    Args:
        sc: controller object exposing an iterable ``alpha_reduce`` of 2-D
            tensors/parameters (one row per edge - presumably one column per
            candidate operation; confirm against the controller class).

    Returns:
        None. ``sc.alpha_reduce`` is modified in place.
    """
    for alpha in sc.alpha_reduce:
        alpha.requires_grad = False
        for subalpha in alpha:
            # Take the argmax before the row is wiped.
            argm = torch.argmax(subalpha)
            # `subalpha` is a view, so writing through .data updates `alpha`.
            subalpha.data.zero_()
            subalpha.data[argm] = 1


# The evaluation cells further down call this helper as `fix_structure`;
# keep that name available so the notebook survives Restart & Run All.
fix_structure = darts_fix_structure
            
            
            

In [49]:
if fine_tune:
    # Fine-tune the network weights for one pass over the training set with
    # the reduce-cell structure frozen to one-hot, then save the state dict
    # next to the original checkpoint with a '_fine' suffix.
    cfg = ConfigObj(basecfg_path)
    cfg['darts']['sampling_mode'] = 'naive'
    cfg['device'] = 'cuda'
    for s in seeds:
        print (s)
        sc = SearchCNNController(**cfg)
        # NOTE(review): torch.load unpickles a whole model object here; only
        # load checkpoints from a trusted source.
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49.ckp').state_dict())
        # Use the helper under the name it is actually defined with in this
        # notebook (`fix_structure` is not defined in any visible cell).
        darts_fix_structure(sc)
        sc = sc.to('cuda')
        optim = torch.optim.Adam(sc.weights())
        for x,y in tqdm.tqdm(train_loader):
            x = x.cuda()
            y = y.cuda()
            optim.zero_grad()
            loss = sc.loss(x,y)
            loss.backward()
            optim.step()
        torch.save(sc.state_dict(), ckp_path+str(s)+'_49_fine.ckp')

0


100%|██████████| 938/938 [02:06<00:00,  7.42it/s]
  0%|          | 0/938 [00:00<?, ?it/s]

13


100%|██████████| 938/938 [02:05<00:00,  7.50it/s]


21


100%|██████████| 938/938 [02:05<00:00,  7.47it/s]
  0%|          | 0/938 [00:00<?, ?it/s]

42


100%|██████████| 938/938 [02:05<00:00,  7.48it/s]
  0%|          | 0/938 [00:00<?, ?it/s]

99


100%|██████████| 938/938 [02:05<00:00,  7.49it/s]


In [15]:
# check everything is ok
# Validation accuracy of the LVarSearchCNNController checkpoint of every seed.
cfg = ConfigObj(basecfg_path)
cfg['device'] = 'cuda'
for s in seeds:
    print (s)
    sc = LVarSearchCNNController(**cfg)
    if fine_tune:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49_fine.ckp'))
    else:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49.ckp'))
    sc = sc.to('cuda')
    sc.eval()
    correct = 0
    total = 0
    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        for x,y in tqdm.tqdm(valid_loader):
            x = x.cuda()
            y = y.cuda()
            out = sc(x)
            correct += torch.eq(torch.argmax(out, 1), y).sum()
            total += len(x)
    # `correct` is a tensor, so this prints a tensor-valued accuracy.
    print (correct*1.0/total*1.0)

  0%|          | 0/157 [00:00<?, ?it/s]

0


100%|██████████| 157/157 [00:04<00:00, 33.46it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9939, device='cuda:0')
13


100%|██████████| 157/157 [00:04<00:00, 35.30it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9934, device='cuda:0')
21


100%|██████████| 157/157 [00:04<00:00, 35.24it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9937, device='cuda:0')
42


100%|██████████| 157/157 [00:04<00:00, 35.21it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9938, device='cuda:0')
99


100%|██████████| 157/157 [00:04<00:00, 35.19it/s]

tensor(0.9935, device='cuda:0')





In [50]:
# check everything is ok, when the structure is fixed
# Validation accuracy of every seed's checkpoint after the reduce-cell
# structure has been frozen to one-hot.
cfg = ConfigObj(basecfg_path)
cfg['darts']['sampling_mode'] = 'naive'
cfg['device'] = 'cuda'
for s in seeds:
    print (s)
    sc = SearchCNNController(**cfg)
    if fine_tune:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49_fine.ckp'))
    else:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49.ckp').state_dict())
    # Use the helper under the name it is actually defined with in this
    # notebook (`fix_structure` is not defined in any visible cell).
    darts_fix_structure(sc)

    sc = sc.to('cuda')
    sc.eval()
    correct = 0
    total = 0
    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        for x,y in tqdm.tqdm(valid_loader):
            x = x.cuda()
            y = y.cuda()
            out = sc(x)
            correct += torch.eq(torch.argmax(out, 1), y).sum()
            total += len(x)
    # `correct` is a tensor, so this prints a tensor-valued accuracy.
    print (correct*1.0/total*1.0)

0


100%|██████████| 157/157 [00:04<00:00, 37.23it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9822, device='cuda:0')
13


100%|██████████| 157/157 [00:04<00:00, 37.19it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9895, device='cuda:0')
21


100%|██████████| 157/157 [00:04<00:00, 37.20it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9803, device='cuda:0')
42


100%|██████████| 157/157 [00:04<00:00, 37.24it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

tensor(0.9881, device='cuda:0')
99


100%|██████████| 157/157 [00:04<00:00, 37.18it/s]

tensor(0.9669, device='cuda:0')





In [51]:
# FGM
# For every seed and every eps in [0, 1], perturb each validation batch by
# eps * sign(d loss / d input) and record the accuracy on the perturbed batch.
# `results` holds one list of (eps, accuracy) pairs per seed.
results = []
cfg = ConfigObj(basecfg_path)
cfg['darts']['sampling_mode'] = 'naive'
cfg['device'] = 'cuda'
for s in seeds:
    results.append([])
    print (s)
    sc = SearchCNNController(**cfg)
    if fine_tune:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49_fine.ckp'))
    else:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49.ckp').state_dict())

    # Use the helper under the name it is actually defined with in this
    # notebook (`fix_structure` is not defined in any visible cell).
    darts_fix_structure(sc)

    sc = sc.to('cuda')
    sc.eval()
    for eps in np.linspace(0.0, 1.0, 11):

        correct = 0
        total = 0

        for x,y in tqdm.tqdm(valid_loader):
            # The input itself must track gradients for the attack.
            x = x.cuda().requires_grad_(True)
            y = y.cuda()
            out = sc(x)
            loss = sc.criterion(out, y)
            sc.zero_grad()
            loss.backward()
            sign_data_grad = x.grad.detach().sign()
            # The second forward pass is evaluation only - skip the graph.
            with torch.no_grad():
                perturbed_image = x.detach() + eps*sign_data_grad
                out = sc(perturbed_image)
            correct += torch.eq(torch.argmax(out, 1), y).sum().item()
            total += len(x)
        print (eps, correct*1.0/total*1.0)
        results[-1].append((eps, float(correct*1.0/total*1.0)))
# A dedicated handle name avoids clobbering `out` (the model output).
with open(name+'_fgm.json', 'w') as fh:
    fh.write(json.dumps(results))

  0%|          | 0/157 [00:00<?, ?it/s]

0


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9822


100%|██████████| 157/157 [00:23<00:00,  6.69it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.7842


100%|██████████| 157/157 [00:23<00:00,  6.60it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.3918


100%|██████████| 157/157 [00:23<00:00,  6.69it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.1786


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.0861


100%|██████████| 157/157 [00:23<00:00,  6.68it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.0598


100%|██████████| 157/157 [00:23<00:00,  6.61it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.053


100%|██████████| 157/157 [00:23<00:00,  6.68it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.0475


100%|██████████| 157/157 [00:23<00:00,  6.69it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.0436


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.0406


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

1.0 0.0356
13


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9895


100%|██████████| 157/157 [00:23<00:00,  6.69it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.8847


100%|██████████| 157/157 [00:24<00:00,  6.48it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.5343


100%|██████████| 157/157 [00:23<00:00,  6.59it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.2394


100%|██████████| 157/157 [00:23<00:00,  6.57it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.1543


100%|██████████| 157/157 [00:24<00:00,  6.51it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.1208


100%|██████████| 157/157 [00:23<00:00,  6.62it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.1018


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.0913


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.086


100%|██████████| 157/157 [00:23<00:00,  6.68it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.0819


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

1.0 0.0801
21


100%|██████████| 157/157 [00:23<00:00,  6.63it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9803


100%|██████████| 157/157 [00:23<00:00,  6.71it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.6799


100%|██████████| 157/157 [00:23<00:00,  6.61it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.2035


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.0708


100%|██████████| 157/157 [00:23<00:00,  6.71it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.0419


100%|██████████| 157/157 [00:23<00:00,  6.69it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.0281


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.021


100%|██████████| 157/157 [00:23<00:00,  6.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.0173


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.0149


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.0144


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]


1.0 0.0139
42


100%|██████████| 157/157 [00:23<00:00,  6.64it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9881


100%|██████████| 157/157 [00:23<00:00,  6.58it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.8512


100%|██████████| 157/157 [00:23<00:00,  6.63it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.4825


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.2297


100%|██████████| 157/157 [00:23<00:00,  6.63it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.1546


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.1175


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.1011


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.0938


100%|██████████| 157/157 [00:23<00:00,  6.64it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.088


100%|██████████| 157/157 [00:23<00:00,  6.64it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.0831


100%|██████████| 157/157 [00:24<00:00,  6.54it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

1.0 0.0814
99


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9669


100%|██████████| 157/157 [00:23<00:00,  6.55it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.657


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.2991


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.1054


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.0703


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.0631


100%|██████████| 157/157 [00:23<00:00,  6.64it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.0612


100%|██████████| 157/157 [00:23<00:00,  6.64it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.0581


100%|██████████| 157/157 [00:23<00:00,  6.64it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.053


100%|██████████| 157/157 [00:23<00:00,  6.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.0536


100%|██████████| 157/157 [00:23<00:00,  6.64it/s]

1.0 0.0578





In [52]:
# structure
# For every seed and every eps in [0, 1]: for each validation batch, zero out
# a random fraction eps of the one-hot reduce-cell rows, measure accuracy, and
# restore the rows afterwards. `results` holds one list of (eps, accuracy)
# pairs per seed, the same layout as the FGM results.
cfg = ConfigObj(basecfg_path)
cfg['darts']['sampling_mode'] = 'naive'
cfg['device'] = 'cuda'
results = []
for s in seeds:
    print (s)
    sc = SearchCNNController(**cfg)
    if fine_tune:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49_fine.ckp'))
    else:
        sc.load_state_dict(torch.load(ckp_path+str(s)+'_49.ckp').state_dict())
    results.append([])
    # Use the helper under the name it is actually defined with in this
    # notebook (`fix_structure` is not defined in any visible cell).
    darts_fix_structure(sc)

    # (tensor index, row index, position of the 1) for every one-hot row.
    ones = [
        (i, j, int(torch.argmax(subalpha)))
        for i, alpha in enumerate(sc.alpha_reduce)
        for j, subalpha in enumerate(alpha)
    ]
    # Per-seed generator so the pruned subsets are reproducible across runs.
    rng = random.Random(s)

    sc = sc.to('cuda')
    sc.eval()
    for eps in np.linspace(0.0, 1.0, 11):
        correct = 0
        total = 0

        # Pure inference: no autograd graph is needed, so do not build one.
        with torch.no_grad():
            for x,y in tqdm.tqdm(valid_loader):
                to_prune = rng.sample(ones, int(len(ones)*eps))
                for i, j, argm in to_prune:
                    sc.alpha_reduce[i][j].data *= 0

                x = x.cuda()
                y = y.cuda()
                out = sc(x)
                correct += torch.eq(torch.argmax(out, 1), y).sum().item()
                total += len(x)
                # Put the 1 back so the next batch starts from the full structure.
                for i, j, argm in to_prune:
                    sc.alpha_reduce[i][j][argm].data += 1
        print (eps, correct*1.0/total*1.0)
        # Append to this seed's list, not to the outer list.
        results[-1].append((eps, float(correct*1.0/total*1.0)))
# A dedicated handle name avoids clobbering `out` (the model output).
with open(name+'_struct.json', 'w') as fh:
    fh.write(json.dumps(results))

  0%|          | 0/157 [00:00<?, ?it/s]

0


100%|██████████| 157/157 [00:04<00:00, 36.74it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9822


100%|██████████| 157/157 [00:04<00:00, 36.10it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.7853


100%|██████████| 157/157 [00:04<00:00, 36.46it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.6204


100%|██████████| 157/157 [00:04<00:00, 36.31it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.4197


100%|██████████| 157/157 [00:04<00:00, 36.28it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.2909


100%|██████████| 157/157 [00:04<00:00, 36.26it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.1784


100%|██████████| 157/157 [00:04<00:00, 35.24it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.1539


100%|██████████| 157/157 [00:04<00:00, 36.21it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.1349


100%|██████████| 157/157 [00:04<00:00, 36.07it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.1232


100%|██████████| 157/157 [00:04<00:00, 35.97it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.1202


100%|██████████| 157/157 [00:04<00:00, 35.93it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

1.0 0.1009
13


100%|██████████| 157/157 [00:04<00:00, 36.85it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9895


100%|██████████| 157/157 [00:04<00:00, 36.66it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.9092


100%|██████████| 157/157 [00:04<00:00, 36.62it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.7689


100%|██████████| 157/157 [00:04<00:00, 36.32it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.4794


100%|██████████| 157/157 [00:04<00:00, 36.46it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.4108


100%|██████████| 157/157 [00:04<00:00, 36.31it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.2338


100%|██████████| 157/157 [00:04<00:00, 36.19it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.2113


100%|██████████| 157/157 [00:04<00:00, 36.09it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.1749


100%|██████████| 157/157 [00:04<00:00, 36.08it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.1288


100%|██████████| 157/157 [00:04<00:00, 36.02it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.1242


100%|██████████| 157/157 [00:04<00:00, 35.79it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

1.0 0.1009
21


100%|██████████| 157/157 [00:04<00:00, 36.43it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9803


100%|██████████| 157/157 [00:04<00:00, 36.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.8591


100%|██████████| 157/157 [00:04<00:00, 36.04it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.7668


100%|██████████| 157/157 [00:04<00:00, 36.13it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.5337


100%|██████████| 157/157 [00:04<00:00, 35.40it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.4606


100%|██████████| 157/157 [00:04<00:00, 35.90it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.292


100%|██████████| 157/157 [00:04<00:00, 35.52it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.2083


100%|██████████| 157/157 [00:04<00:00, 34.12it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.1963


100%|██████████| 157/157 [00:04<00:00, 35.61it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.129


100%|██████████| 157/157 [00:04<00:00, 35.97it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.1188


100%|██████████| 157/157 [00:04<00:00, 35.06it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

1.0 0.1009
42


100%|██████████| 157/157 [00:04<00:00, 36.28it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9881


100%|██████████| 157/157 [00:04<00:00, 36.63it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.8464


100%|██████████| 157/157 [00:04<00:00, 36.55it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.7254


100%|██████████| 157/157 [00:04<00:00, 36.46it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.4886


100%|██████████| 157/157 [00:04<00:00, 36.38it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.4084


100%|██████████| 157/157 [00:04<00:00, 36.29it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.297


100%|██████████| 157/157 [00:04<00:00, 36.25it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.2566


100%|██████████| 157/157 [00:04<00:00, 36.23it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.195


100%|██████████| 157/157 [00:04<00:00, 35.92it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.134


100%|██████████| 157/157 [00:04<00:00, 35.91it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.1222


100%|██████████| 157/157 [00:04<00:00, 35.65it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

1.0 0.1009
99


100%|██████████| 157/157 [00:04<00:00, 36.88it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.0 0.9669


100%|██████████| 157/157 [00:04<00:00, 36.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.1 0.8484


100%|██████████| 157/157 [00:04<00:00, 35.70it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.2 0.6605


100%|██████████| 157/157 [00:04<00:00, 36.48it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.30000000000000004 0.4646


100%|██████████| 157/157 [00:04<00:00, 36.40it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.4 0.407


100%|██████████| 157/157 [00:04<00:00, 36.31it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.5 0.2378


100%|██████████| 157/157 [00:04<00:00, 36.22it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.6000000000000001 0.2037


100%|██████████| 157/157 [00:04<00:00, 36.14it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.7000000000000001 0.1931


100%|██████████| 157/157 [00:04<00:00, 36.04it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.8 0.1271


100%|██████████| 157/157 [00:04<00:00, 35.98it/s]
  0%|          | 0/157 [00:00<?, ?it/s]

0.9 0.1138


100%|██████████| 157/157 [00:04<00:00, 35.94it/s]

1.0 0.1009





In [127]:
# Count correct predictions on the batch (`x`, `y`) and controller (`sc`)
# left in the kernel by the previous cell.
x, y = x.cuda(), y.cuda()
out = sc(x)
(torch.argmax(out, 1) == y).sum()


tensor(60, device='cuda:0')

In [4]:
# retraining

In [5]:
# structure

In [6]:
# training

OrderedDict([('alpha_reduce.0',
              tensor([[ 2.0884e-04, -6.6295e-04, -1.6188e-04, -7.5637e-04,  8.9418e-04,
                        9.7369e-05,  4.7107e-04, -1.7893e-03],
                      [-4.0987e-04,  1.2731e-03, -2.0698e-03,  1.0106e-03,  2.1814e-05,
                        1.9141e-04, -8.4204e-04, -2.6669e-03]])),
             ('alpha_reduce.1',
              tensor([[-1.5093e-03,  9.3491e-04,  2.8583e-04,  1.2068e-03, -1.2244e-03,
                       -1.0065e-04, -7.2345e-04,  1.3378e-05],
                      [ 2.1611e-03, -3.4841e-04, -2.8028e-04, -3.6471e-04, -2.6839e-04,
                       -9.8860e-04, -1.2835e-04, -7.8459e-04],
                      [ 3.3627e-04, -5.9228e-04,  9.6359e-04, -7.1787e-04, -5.2517e-05,
                       -1.1599e-03, -6.3788e-04,  2.2427e-04]])),
             ('alpha_reduce.2',
              tensor([[ 4.4362e-04, -1.2023e-03,  1.9001e-03, -8.1682e-04, -7.5418e-04,
                        1.6238e-03, -5.1434e-04, -2.170