forked from render-examples/fastai-v3
-
Notifications
You must be signed in to change notification settings - Fork 5
/
learner.py
51 lines (40 loc) · 1.28 KB
/
learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Train an image classifier (fastai v1) on the images under BASE_PATH,
# print the prediction for EXAMPLE_PATH, and export the trained learner.
BASE_PATH = './'
EXAMPLE_PATH = './example.jpg'
# Imports
# NOTE(review): the wildcard imports are what bring ImageDataBunch,
# get_transforms, imagenet_stats, create_cnn, models, error_rate, open_image
# and Path into scope. fastai.widgets is not otherwise referenced in this
# script -- verify it is unneeded before removing it.
from fastai.vision import *
from fastai.widgets import *
# Root of the dataset; from_folder reads the images from here, and the
# learner built on this data uses it as its working path (export target).
path = Path(BASE_PATH)
# Create data loader / manager
data_bunch = ImageDataBunch.from_folder(
    # Data directory
    path,
    # Hold out a random 20 percent of the images as the validation set.
    # NOTE(review): no seed is given, so the split differs between runs.
    valid_pct=0.2,
    # Augmentations that create variations of the training images
    # (max_zoom=1.0 leaves zoom out of the random transforms)
    ds_tfms=get_transforms(max_zoom=1.0),
    # Side length, in pixels, that images are resized to
    size=224,
    # Number of worker processes used for data loading
    num_workers=4,
).normalize(
    # Use ImageNet stats to normalize (to match what was pre-trained with)
    imagenet_stats
)
# Create our learner: an ImageNet-pretrained ResNet-34 with a new head,
# reporting error_rate as its metric.
# NOTE(review): create_cnn was renamed cnn_learner in later fastai v1
# releases; the old name is kept here for compatibility with older installs.
learner = create_cnn(data_bunch, models.resnet34, pretrained=True, metrics=error_rate)
# Stage 1: train while the pretrained body is frozen (hence the unfreeze
# below), so only the head is updated.
learner.fit_one_cycle(10)
# Unfreeze the model so every layer group becomes trainable
learner.unfreeze()
# Stage 2: fine-tune the entire model with discriminative learning rates.
# fastai spreads slice(lo, hi) geometrically across the layer groups: the
# earliest (pretrained) layers get lo and the head gets hi. The smaller value
# must come first; slice(1e-3, 1e-5) would hit the pretrained early layers
# with the largest rate and barely move the head.
learner.fit_one_cycle(10, max_lr=slice(1e-5, 1e-3))
# Stage 3: same per-layer layout, one order of magnitude lower
learner.fit_one_cycle(10, max_lr=slice(1e-6, 1e-4))
# Grab an example image to sanity-check the trained model
example_image = open_image(Path(EXAMPLE_PATH))
# predict() returns a 3-tuple; only the first element (the class) is used
predicted_class, _, _ = learner.predict(example_image)
print(predicted_class)
# Export the model for later inference
learner.export()