## Transfer Learning

Remember how in lesson 2 we saw that Convolutional Neural Networks (CNNs) learn simple features in the first layer and more complicated features in the later layers?

<img src="../data/lesson_images/layer1.png" alt="Activations of the first layer of a CNN" width="300" caption="Activations of the first layer of a CNN (courtesy of Matthew D. Zeiler and Rob Fergus)" id="img_layer1">


<img src="../data/lesson_images/layer2.png" alt="Activations of the second layer of a CNN" width="800" caption="Activations of the second layer of a CNN (courtesy of Matthew D. Zeiler and Rob Fergus)" id="img_layer2">

<img src="../data/lesson_images/chapter2_layer3.PNG" alt="Activations of the third layer of a CNN" width="800" caption="Activations of the third layer of a CNN (courtesy of Matthew D. Zeiler and Rob Fergus)" id="img_layer3">

<img src="../data/lesson_images/chapter2_layer4and5.PNG" alt="Activations of layers 4 and 5 of a CNN" width="800" caption="Activations of layers 4 and 5 of a CNN (courtesy of Matthew D. Zeiler and Rob Fergus)" id="img_layer4">

The idea of **transfer learning** is that we reuse the first and middle layers of a CNN that has *already been trained*. We keep these layers because they already recognise generally useful features, from simple shapes such as circles in the early layers to more complex patterns such as faces in the later ones.

We then add fresh final layers, called the **head**, and train them to recognise the particular objects we care about.
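
To make the split between a reused body and a new head concrete, here is a toy sketch in plain Python. It is not how fastai implements transfer learning, and every name in it (`BODY_WEIGHTS`, `body`, `train_head`, the toy data) is a hypothetical stand-in: the "body" is a pair of fixed feature detectors playing the role of pretrained layers, and only the small head on top of them is trained.

```python
import math
import random

# Hypothetical stand-in for a pretrained body: two fixed feature detectors.
# These weights are never updated during training (the body is "frozen").
BODY_WEIGHTS = ((1.0, 1.0), (1.0, -1.0))

def body(x):
    """Frozen feature extractor: maps a 2-D input to two ReLU features."""
    return [max(0.0, w0 * x[0] + w1 * x[1]) for w0, w1 in BODY_WEIGHTS]

def predict(head, x):
    """Head: logistic regression on the body's features; returns P(class 1)."""
    weights, bias = head
    z = sum(w * f for w, f in zip(weights, body(x))) + bias
    return 1.0 / (1.0 + math.exp(-z))

def train_head(data, epochs=200, lr=0.1, seed=0):
    """Fit only the head's parameters with stochastic gradient descent."""
    rng = random.Random(seed)
    weights = [rng.uniform(-0.1, 0.1) for _ in BODY_WEIGHTS]
    bias = 0.0
    for _ in range(epochs):
        for x, y in data:
            feats = body(x)
            err = predict((weights, bias), x) - y  # d(log loss)/dz
            weights = [w - lr * err * f for w, f in zip(weights, feats)]
            bias -= lr * err
    return weights, bias

# Toy task (made up for illustration): label 1 when x[0] > x[1], else 0.
data = [((2, 0), 1), ((3, 1), 1), ((1, 0), 1),
        ((0, 2), 0), ((1, 3), 0), ((0, 1), 0)]

head = train_head(data)
for x, y in data:
    print(x, "label:", y, "predicted probability of class 1:", round(predict(head, x), 2))
```

Because the body already produces a useful feature (the second detector fires only when `x[0] > x[1]`), the head needs very little data to learn the task. The same reasoning is why a pretrained CNN can be adapted with only a handful of images per category.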

In [None]:
# DO NOT CHANGE THIS CODE

import fastbook
fastbook.setup_book()

from icrawler.builtin import BingImageCrawler
from pathlib import Path

In [None]:
# DO NOT CHANGE THIS CODE

from fastbook import *
from fastai.vision.widgets import *

In [None]:
# Feel free to change these categories to something you want to teach your model
# e.g. dogs vs cats
# or cats vs bats

# The word on the left is the name of the category
# The sentence on the right is a search term that will find images of that category on the internet

categories = {
    'oyster':             'oyster mushroom pleurotus ostreatus',
    'death_cap':          'death cap mushroom amanita phalloides',
}

In [None]:
# DO NOT CHANGE THIS CODE

# Here we download images of the above categories from the internet to use as training data

training_images_folder = 'training_images'

for cls, term in categories.items():
    dest = Path(training_images_folder)/cls
    dest.mkdir(parents=True, exist_ok=True)
    crawler = BingImageCrawler(storage={'root_dir': str(dest)})
    crawler.crawl(keyword=term, max_num=10, min_size=(200,200), overwrite=True)


In [None]:
# DO NOT CHANGE THIS CODE

# Here we remove any images that failed to download correctly

failed = verify_images(get_image_files(training_images_folder))
failed.map(Path.unlink)  # delete any that failed to open

In [None]:
# DO NOT CHANGE THIS CODE

training_image_filenames = get_image_files(training_images_folder)

In [None]:
# DO NOT CHANGE THIS CODE

# Let's look at one of the images we downloaded

im = Image.open(training_image_filenames[0])
im.to_thumb(64,64)

In [None]:
# DO NOT CHANGE THIS CODE

# These are NOT AI predictions
# Instead, these are some of the images we downloaded, along with their category names
# You should check that the images are correct for the category names as this is our training data

path = Path(training_images_folder)
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(0.2, seed=42),
    get_y=parent_label,
    item_tfms=Resize(224)
)

dls = dblock.dataloaders(path, bs=16)
dls.show_batch(max_n=9)


In [None]:
# DO NOT CHANGE THIS CODE

# Here we copy an existing trained model (resnet18), attach a new head and fine-tune it to recognize our categories

learn = vision_learner(dls, resnet18, metrics=error_rate).to_fp16()
learn.fine_tune(3)

In [None]:
# Feel free to change this code

# Change the URL to an image on the internet you want to test 
# The image should be of one of the categories you trained your model on

# 1) Make a test folder
test_path = Path('test_images')
test_path.mkdir(exist_ok=True)

# 2) Download your image
url = 'https://live.staticflickr.com/2927/14293282232_03499896bb.jpg'  # ← replace with your URL
save_filename = 'testOysterMushroom.jpg'                   # name to save the downloaded image under
save_path = test_path/save_filename

headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)',
    'Referer': 'https://www.flickr.com/'
}

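import requests  # imported explicitly in case the star imports above do not expose it

resp = requests.get(url, headers=headers, timeout=30)
resp.raise_for_status()  # raises an HTTPError if the download failed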
with open(save_path,'wb') as f:
    f.write(resp.content)


image_path = save_path  # the image we just downloaded

plt.imshow(PILImage.create(image_path))
plt.axis('off')
plt.title('Test Image of an Oyster Mushroom')
plt.show()


In [None]:
# Feel free to change this code

# Here we predict the category of the image we downloaded

# If you want to upload your own image, save it in the test_images folder as "my_image.jpg" 
# .. and uncomment the line below: image_path = test_path/'my_image.jpg'

# image_path = test_path/'my_image.jpg'  


image = PILImage.create(image_path) # open the image we downloaded
pred_class, pred_idx, probs = learn.predict(image)


fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Left: show the image with its predicted label
axes[0].imshow(image)
axes[0].axis('off')
axes[0].set_title(f'Predicted: {pred_class}\nConfidence: {probs[pred_idx]:.0%}')

# Right: bar-chart of all class probabilities
classes = learn.dls.vocab
axes[1].barh(classes, probs)
axes[1].invert_yaxis()                # highest probability at top
axes[1].set_xlabel('Probability')
axes[1].set_title('Class Probabilities')

plt.tight_layout()
plt.show()

