-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·40 lines (27 loc) · 857 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/python3
from __future__ import print_function
import theano
import theano.tensor.signal.downsample
from utils import *
import time
import pickle
import sys
from layers import *
from network import *
from train import *
from prepare_cifar10 import *
def run(template, path):
    """Compile a network from ``template``, train it on CIFAR-10, report test error.

    The function obtains the dataset via ``prepare_cifar10()``, builds the
    network with ``compile(template)``, trains it with ``train(...)`` using
    the training and validation streams, and finally prints the test-set
    error rate as a percentage.  Progress messages are written to stdout.

    Args:
        template: Network description handed unchanged to ``compile``.
        path: Forwarded unchanged as the last positional argument of
            ``train``.  NOTE(review): presumably where training state or
            results are stored -- confirm against ``train``.

    Returns:
        None.  Results are reported on stdout only.
    """
    cifar = prepare_cifar10()
    # Fetch every stream up front so a missing attribute fails before the
    # compile/train steps run.  Only the *_stream attributes are used here.
    train_stream = cifar.train_stream
    validation_stream = cifar.validation_stream
    test_stream = cifar.test_stream

    # The message ends with a space instead of a newline, so flush explicitly
    # to make it visible while compilation is still in progress.
    print("Compiling...", end=" ")
    sys.stdout.flush()
    # NOTE(review): `compile` is expected to be the project function brought
    # in by one of the star imports (it shadows the builtin of the same
    # name) -- confirm which module provides it.
    network = compile(template)
    print("DONE")
    sys.stdout.flush()

    # NOTE(review): 1e-3 and 0.7 are passed positionally; they look like
    # training hyper-parameters (e.g. learning rate and momentum) -- confirm
    # against the signature of `train`.
    train(network, train_stream, validation_stream, 1e-3, 0.7, path)

    test_error = compute_error_rate(test_stream, network.predict)
    print("Test error rate is %f%%" % (test_error * 100.0,))