In [1]:
from models.base_network import BaseNetwork
from file_handling.load_datasets import load_mnist
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from preprocessing.noise_models import gaussian_noise

  from ._conv import register_converters as _register_converters


In [2]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: 255*x)])
test_data = datasets.CIFAR10(root='./cifar-10', train=False, download=False, transform=transform)
train_data = datasets.CIFAR10(root='./cifar-10', train=True, download=False, transform=transform)

In [3]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1000, shuffle=False, num_workers=2)

In [8]:
train_loader.dataset[0][0]

tensor([[[ 59.,  43.,  50.,  ..., 158., 152., 148.],
         [ 16.,   0.,  18.,  ..., 123., 119., 122.],
         [ 25.,  16.,  49.,  ..., 118., 120., 109.],
         ...,
         [208., 201., 198.,  ..., 160.,  56.,  53.],
         [180., 173., 186.,  ..., 184.,  97.,  83.],
         [177., 168., 179.,  ..., 216., 151., 123.]],

        [[ 62.,  46.,  48.,  ..., 132., 125., 124.],
         [ 20.,   0.,   8.,  ...,  88.,  83.,  87.],
         [ 24.,   7.,  27.,  ...,  84.,  84.,  73.],
         ...,
         [170., 153., 161.,  ..., 133.,  31.,  34.],
         [139., 123., 144.,  ..., 148.,  62.,  53.],
         [144., 129., 142.,  ..., 184., 118.,  92.]],

        [[ 63.,  45.,  43.,  ..., 108., 102., 103.],
         [ 20.,   0.,   0.,  ...,  55.,  50.,  57.],
         [ 21.,   0.,   8.,  ...,  50.,  50.,  42.],
         ...,
         [ 96.,  34.,  26.,  ...,  70.,   7.,  20.],
         [ 96.,  42.,  30.,  ...,  94.,  34.,  34.],
         [116.,  94.,  87.,  ..., 140.,  84.,  72.]]]

In [31]:
# reg_net = BaseNetwork("reg_net", [(3, 16, 5),(16,32,5),(800,5000), (5000,100),(100,10)],
#                                 ["conv", "conv", "fc", "fc", "fc"])
reg_net = BaseNetwork("reg_net", [(3*1024,5000),(5000,100),(100,10)],
                                ["fc", "fc", "fc"])
reg_net_opt = optim.SGD(reg_net.parameters(), lr=0.00015, momentum=0.9)

In [18]:
# emb_net = BaseNetwork("emb_net", [(3, 16, 5),(16,32,5), (800,5000),(5000,100),(100,10)],
#                                 ["conv", "conv", "emb", "fc", "fc"])
emb_net = BaseNetwork("emb_net", [(3*1024,5000),(5000,100),(100,10)],
                                ["emb", "fc", "fc"])
emb_net_opt = optim.SGD(emb_net.parameters(), lr=0.05, momentum=0.9)

In [27]:
criterion = nn.CrossEntropyLoss()

In [19]:
emb_net.train_model(train_loader, 15, emb_net_opt, criterion)

Training Model


[1,    16] loss: 2.30106
[1,    32] loss: 2.31283
[1,    48] loss: 2.30354
[1,    64] loss: 2.32058
[1,    80] loss: 2.30881
[1,    96] loss: 2.31354
[1,   112] loss: 2.31029
[1,   128] loss: 2.30898
[1,   144] loss: 2.30271
[1,   160] loss: 2.30429
[1,   176] loss: 2.29985
[1,   192] loss: 2.29576
[1,   208] loss: 2.29453
[1,   224] loss: 2.28320
[1,   240] loss: 2.27474
[1,   256] loss: 2.28868
[1,   272] loss: 2.25877
[1,   288] loss: 2.26135
[1,   304] loss: 2.25397
[1,   320] loss: 2.21960
[1,   336] loss: 2.24174
[1,   352] loss: 2.18894
[1,   368] loss: 2.19148
[1,   384] loss: 2.16115
[1,   400] loss: 2.16512
[1,   416] loss: 2.16173
[1,   432] loss: 2.21420
[1,   448] loss: 2.15645
[1,   464] loss: 2.15716
[1,   480] loss: 2.18131
[1,   496] loss: 2.12904
[1,   512] loss: 2.10541
[1,   528] loss: 2.17108
[1,   544] loss: 2.16484
[1,   560] loss: 2.16627
[1,   576] loss: 2.14330
[1,   592] loss: 2.12008
[1,   608] loss: 2.09750
[1,   624] loss: 2.09665
[1,   640] loss: 2.11739


[2,    16] loss: 1.93534
[2,    32] loss: 1.85886
[2,    48] loss: 1.86332
[2,    64] loss: 1.87091
[2,    80] loss: 1.86911
[2,    96] loss: 1.88764
[2,   112] loss: 1.84138
[2,   128] loss: 1.92349
[2,   144] loss: 1.95233
[2,   160] loss: 1.89323
[2,   176] loss: 1.92351
[2,   192] loss: 1.82410
[2,   208] loss: 1.94680
[2,   224] loss: 1.87491
[2,   240] loss: 1.82010
[2,   256] loss: 1.89515
[2,   272] loss: 1.92870
[2,   288] loss: 1.90460
[2,   304] loss: 1.87405
[2,   320] loss: 1.97926
[2,   336] loss: 1.93512
[2,   352] loss: 1.84637
[2,   368] loss: 1.86228
[2,   384] loss: 1.82682
[2,   400] loss: 1.83731
[2,   416] loss: 1.86499
[2,   432] loss: 1.92254
[2,   448] loss: 1.86139
[2,   464] loss: 1.86067
[2,   480] loss: 1.92056
[2,   496] loss: 1.91361
[2,   512] loss: 1.78054
[2,   528] loss: 1.88927
[2,   544] loss: 1.86760
[2,   560] loss: 1.91152
[2,   576] loss: 1.92507
[2,   592] loss: 1.93294
[2,   608] loss: 1.82468
[2,   624] loss: 1.80455
[2,   640] loss: 1.99004


[3,    16] loss: 1.81872
[3,    32] loss: 1.75473
[3,    48] loss: 1.79427
[3,    64] loss: 1.80119
[3,    80] loss: 1.75679
[3,    96] loss: 1.79590
[3,   112] loss: 1.76667
[3,   128] loss: 1.84920
[3,   144] loss: 1.83579
[3,   160] loss: 1.83602
[3,   176] loss: 1.83347
[3,   192] loss: 1.81798
[3,   208] loss: 1.85109
[3,   224] loss: 1.80931
[3,   240] loss: 1.73310
[3,   256] loss: 1.86338
[3,   272] loss: 1.89409
[3,   288] loss: 1.85659
[3,   304] loss: 1.78586
[3,   320] loss: 1.93813
[3,   336] loss: 1.85783
[3,   352] loss: 1.79159
[3,   368] loss: 1.79396
[3,   384] loss: 1.75904
[3,   400] loss: 1.74917
[3,   416] loss: 1.76314
[3,   432] loss: 1.84636
[3,   448] loss: 1.73926
[3,   464] loss: 1.72827
[3,   480] loss: 1.94809
[3,   496] loss: 1.79520
[3,   512] loss: 1.69953
[3,   528] loss: 1.80997
[3,   544] loss: 1.75948
[3,   560] loss: 1.89998
[3,   576] loss: 1.88025
[3,   592] loss: 1.85681
[3,   608] loss: 1.74311
[3,   624] loss: 1.72521
[3,   640] loss: 1.93519


[4,    16] loss: 1.72712
[4,    32] loss: 1.72438
[4,    48] loss: 1.74880
[4,    64] loss: 1.73224
[4,    80] loss: 1.68013
[4,    96] loss: 1.75433
[4,   112] loss: 1.70461
[4,   128] loss: 1.79459
[4,   144] loss: 1.78349
[4,   160] loss: 1.79614
[4,   176] loss: 1.77895
[4,   192] loss: 1.73570
[4,   208] loss: 1.81890
[4,   224] loss: 1.77043
[4,   240] loss: 1.67945
[4,   256] loss: 1.81528
[4,   272] loss: 1.83514
[4,   288] loss: 1.79282
[4,   304] loss: 1.77410
[4,   320] loss: 1.87671
[4,   336] loss: 1.80741
[4,   352] loss: 1.75035
[4,   368] loss: 1.74021
[4,   384] loss: 1.71537
[4,   400] loss: 1.67043
[4,   416] loss: 1.70378
[4,   432] loss: 1.84821
[4,   448] loss: 1.65212
[4,   464] loss: 1.69315
[4,   480] loss: 1.93076
[4,   496] loss: 1.76826
[4,   512] loss: 1.66953
[4,   528] loss: 1.79387
[4,   544] loss: 1.69817
[4,   560] loss: 1.87336
[4,   576] loss: 1.81562
[4,   592] loss: 1.78078
[4,   608] loss: 1.68618
[4,   624] loss: 1.66644
[4,   640] loss: 1.86165


[5,    16] loss: 1.69392
[5,    32] loss: 1.72918
[5,    48] loss: 1.72154
[5,    64] loss: 1.72194
[5,    80] loss: 1.65962
[5,    96] loss: 1.72155
[5,   112] loss: 1.68612
[5,   128] loss: 1.77817
[5,   144] loss: 1.72563
[5,   160] loss: 1.74230
[5,   176] loss: 1.73996
[5,   192] loss: 1.70702
[5,   208] loss: 1.76976
[5,   224] loss: 1.71451
[5,   240] loss: 1.65204
[5,   256] loss: 1.79736
[5,   272] loss: 1.80908
[5,   288] loss: 1.76801
[5,   304] loss: 1.76430
[5,   320] loss: 1.88864
[5,   336] loss: 1.76761
[5,   352] loss: 1.69206
[5,   368] loss: 1.72591
[5,   384] loss: 1.65838
[5,   400] loss: 1.62263
[5,   416] loss: 1.67934
[5,   432] loss: 1.84346
[5,   448] loss: 1.64084
[5,   464] loss: 1.66942
[5,   480] loss: 1.88678
[5,   496] loss: 1.74415
[5,   512] loss: 1.65897
[5,   528] loss: 1.78092
[5,   544] loss: 1.64718
[5,   560] loss: 1.85018
[5,   576] loss: 1.78635
[5,   592] loss: 1.72789
[5,   608] loss: 1.64510
[5,   624] loss: 1.62813
[5,   640] loss: 1.82704


Process Process-23:
Process Process-24:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/Applications/anaconda/l

[5,  2160] loss: 1.69833


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x1c4904b9e8>>
Traceback (most recent call last):
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 399, in __del__
    self._shutdown_workers()
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 378, in _shutdown_workers
    self.worker_result_queue.get()
  File "/Applications/anaconda/lib/python3.6/multiprocessing/queues.py", line 337, in get
    return _ForkingPickler.loads(res)
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 167, in rebuild_storage_filename
    storage = cls._new_shared_filename(manager, handle, size)
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 227, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 7290) exited unexpectedl

KeyboardInterrupt: 

In [20]:
emb_net.save_model("saved_models/emb_net_test.pt")

In [21]:
emb_net.test_model_once(test_loader, None, 0)

Evaluating model once


36.7

In [22]:
emb_net.test_model(test_loader, gaussian_noise)

Evaluating model once
Evaluating model once
Evaluating model once
Evaluating model once
Evaluating model once


[36.61, 36.61, 36.65, 36.45, 36.62]

In [32]:
reg_net.train_model(train_loader, 15, reg_net_opt, criterion)

Training Model


[1,    16] loss: 52.72527
[1,    32] loss: 2.60997
[1,    48] loss: 2.30600
[1,    64] loss: 2.31226
[1,    80] loss: 2.30366
[1,    96] loss: 2.30575
[1,   112] loss: 2.30334
[1,   128] loss: 2.30697
[1,   144] loss: 2.30553
[1,   160] loss: 2.30706
[1,   176] loss: 2.31166
[1,   192] loss: 2.30756
[1,   208] loss: 2.30496
[1,   224] loss: 2.30650
[1,   240] loss: 2.30314
[1,   256] loss: 2.30307
[1,   272] loss: 2.30361
[1,   288] loss: 2.30522
[1,   304] loss: 2.30741
[1,   320] loss: 2.30394
[1,   336] loss: 2.30833
[1,   352] loss: 2.29948


Process Process-46:
Process Process-45:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/Applications/anaconda/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/Applications/anaconda/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/Applications/anaconda/lib/python3.6/site-packages/torch/utils/data/dataloader.py"

KeyboardInterrupt: 

In [None]:
# reg_net.save_model("saved_models/reg_net_test.pt")

In [None]:
reg_net.load_model("saved_models/reg_net_test.pt")

In [None]:
reg_net.test_model_once(test_loader, None, 0)

In [None]:
reg_net.test_model(test_loader, gaussian_noise)

In [None]:
# emb_net.train_model(train_loader, 15, emb_net_opt, criterion)

In [None]:
trainset = datasets.CIFAR10(root='../data_cifar_10', train=True,
                                       download=False, transform=transform.ToTensor())
train_loader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=False, num_workers=2)

In [None]:
testset = datasets.CIFAR10(root='../data_cifar_10', train=False,
                                       download=False, transform=transform.ToTensor())
test_loader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)