# Transfer Learning

## Loading Libraries

In [1]:
#Numerical Computing
import numpy as np

# Data Manipulation
import pandas as pd

# Data Visualization
import seaborn as sns
import matplotlib
import matplotlib_inline
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
import matplotlib.patches as patches

# Dataset's Iteration Performance
from tqdm import tqdm

# Time
import time

# OS
import os
import re
import sys
import json
import string
import unicodedata
from glob import glob
from io import BytesIO
from imageio import imread
from zipfile import ZipFile
import requests, zipfile, io
from collections import Counter 
from urllib.request import urlopen


# Warnings 
import warnings
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

# SciPy
from scipy.signal import convolve

# PyTorch
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import *
from torchvision.ops import nms
import torch.nn.functional as F
from torchvision import transforms
# from torchtext.datasets import AG_NEWS
# from torchtext.data.utils import get_tokenizer
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator


# IDLMAM Libraries
from idlmam import moveTo, run_epoch, set_seed, View, pad_and_pack
from idlmam import train_simple_network, set_seed, Flatten, weight_reset, train_network
from idlmam import LanguageNameDataset, pad_and_pack, EmbeddingPackable, LastTimeStep, LambdaLayer
from idlmam import AttentionAvg, GeneralScore, DotScore, AdditiveAttentionScore, ApplyAttention, getMaskByFill


# Scikit-Learn
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

#  IPython Display
from IPython.display import Latex
from IPython.display import display_pdf
from IPython.display import set_matplotlib_formats

  from .autonotebook import tqdm as notebook_tqdm


### Visualization Set-Up

In [2]:
%matplotlib inline

matplotlib_inline.backend_inline.set_matplotlib_formats('png', 'pdf')

### Setting Seeds & Device

In [3]:
# Prefer deterministic cuDNN kernels so GPU results repeat from run to run.
torch.backends.cudnn.deterministic = True

# Seed the RNGs through the idlmam helper (presumably torch/numpy/random — see
# idlmam.set_seed) before any data split or weight initialisation below.
set_seed(42)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Transferring Model Parameters

### Retrieving & Setting Data Up

In [8]:
# Microsoft's mirror of the Kaggle cats-vs-dogs images.
data_url_zip = "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip"

# Downloading: fetch and unpack the archive only when the images are not on
# disk yet, so re-running the notebook does not download it again.
if not os.path.isdir('./data/PetImages'):
    # `with` closes the HTTP response and the archive once we are done with
    # them. The archive is bound to `archive` (not `zipfile`) so the `zipfile`
    # module imported at the top of the notebook is not shadowed.
    with urlopen(data_url_zip) as resp:
        archive_bytes = BytesIO(resp.read())
    with ZipFile(archive_bytes) as archive:
        archive.extractall(path='./data')

# Removing Bad Files: images flagged as bad in this dataset (presumably
# unreadable/corrupt files that would break image decoding later).
bad_files = [
    './data/PetImages/Dog/11702.jpg',
    "./data/PetImages/Cat/666.jpg"
]
for f in bad_files:
    if os.path.isfile(f):
        os.remove(f)

In [None]:
# Resize each image (shorter side -> 130 px), keep the central 128x128 crop and
# convert it to a (C, H, W) tensor.
preprocess = transforms.Compose([
    transforms.Resize(130),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
])

# One class per sub-folder of PetImages (Cat/, Dog/).
all_images = torchvision.datasets.ImageFolder("./data/PetImages", transform=preprocess)

# Train Test Split: 80% of the images for training, the remainder for testing.
n_images = len(all_images)
train_size = int(n_images * 0.8)
test_size = n_images - train_size

In [None]:
# Random Split: partition the images into disjoint train/test subsets with the
# sizes computed above. No explicit generator is passed, so this draws from the
# global torch RNG seeded in the set-up cell.
train_data, test_data = torch.utils.data.random_split(all_images, [train_size, test_size])

In [None]:
# Batch Size
B = 128

# Data Loader: reshuffle the training split every epoch; keep the test split in
# a fixed order.
train_loader, test_loader = (
    DataLoader(train_data, batch_size=B, shuffle=True),
    DataLoader(test_data, batch_size=B),
)

In [None]:
# Preview the first 8 test images in a 2x4 grid, labelled with their integer
# class index and the class (folder) name ImageFolder derived it from.
fig, axarr = plt.subplots(2, 4, figsize=(20, 10))

for k, ax in enumerate(axarr.flat):  # row-major, i.e. k == i*4 + j
    x, y = test_data[k]
    # x is channels-first; imshow wants (H, W, C).
    ax.imshow(x.numpy().transpose(1, 2, 0))
    # y is an integer class index, so show it together with its class name.
    ax.text(0.0, 0.5, f"{y} ({all_images.classes[y]})", dict(size=20, color='red'))
    ax.set_axis_off()

plt.show()

## Transfer Learning & Training with CNNs

In [None]:
# ResNet18 Set-Up: architecture only, no pretrained weights requested.
model = torchvision.models.resnet18()

# Surgical Procedure: replace the final fully-connected layer with a fresh one
# that keeps the same input width but emits 2 logits (cat vs. dog).
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=2)

In [None]:
# Loss Function: cross-entropy on the raw 2-class logits of the classifiers.
loss = torch.nn.CrossEntropyLoss()

In [None]:
# Model Training: train the randomly initialised ResNet18 for 10 epochs on
# `device`, scoring with sklearn's accuracy_score against test_loader.
normal_results = train_network(
    model,
    loss,
    train_loader,
    epochs=10,
    device=device,
    test_loader=test_loader,
    score_funcs={'Accuracy': accuracy_score},
)

In [None]:
# Test accuracy per epoch for the ResNet18 trained from scratch.
fig, ax = plt.subplots()
sns.lineplot(x='epoch', y='test Accuracy', data=normal_results, label='Regular', ax=ax)

ax.set_title('ResNet18 Model Training')
ax.grid(True)
plt.show()

### Adjusting Pretrained Parameters

In [None]:
# Pretrained Model: ResNet18 initialised with torchvision's ImageNet weights.
# NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour of
# `weights=`; switch once the minimum supported torchvision version allows it.
model_pretrained = torchvision.models.resnet18(pretrained=True)

# Surgical Adjustment: swap the ImageNet head for a 2-logit cat/dog head with
# the same input width.
model_pretrained.fc = nn.Linear(in_features=model_pretrained.fc.in_features, out_features=2)

In [None]:
# Copy of the first conv layer's filters as a NumPy array (gradient tracking
# dropped, moved to host memory first).
filters_pretrained = model_pretrained.conv1.weight.detach().cpu().numpy()

In [None]:
# Min-max rescale every weight into [0, 1] so imshow can draw the filters as
# colour images. Both steps build new arrays on purpose: the array returned by
# .numpy() can share memory with conv1's weights, and an in-place update would
# silently modify the model.
filters_pretrained = filters_pretrained - filters_pretrained.min()  # shift: minimum becomes 0
filters_pretrained = filters_pretrained / filters_pretrained.max()  # scale: maximum becomes 1

In [None]:
# Reallocating Image Dims: move the channel axis (axis 1) to the end, giving
# the channels-last layout imshow expects for each filter.
filters_pretrained = np.moveaxis(filters_pretrained, source=1, destination=-1)

In [None]:
# Near-square grid: round(sqrt(n)) rows and floor(n / rows) columns, so any
# filters beyond rows * columns are not drawn.
i_max = int(round(np.sqrt(filters_pretrained.shape[0])))
j_max = filters_pretrained.shape[0] // i_max

f, axarr = plt.subplots(i_max, j_max, figsize=(10, 10))

# np.ndindex walks (i, j) in row-major order, matching indx = i*j_max + j.
for i, j in np.ndindex(i_max, j_max):
    indx = i * j_max + j
    axarr[i, j].imshow(filters_pretrained[indx, :])
    axarr[i, j].set_axis_off()
        axarr[i,j].set_axis_off() 

In [None]:
def visualizeFilters(conv_filters, figsize=(10, 10)):
    """Draw every convolution filter as a small image on a near-square grid.

    The weights are min-max rescaled into [0, 1] using the minimum and maximum
    over *all* filters, so relative magnitudes between filters are preserved.

    Args:
        conv_filters: array of conv weights in channels-first layout, i.e.
            (n_filters, in_channels, kH, kW), such as
            ``conv.weight.data.cpu().numpy()``. Each filter is rendered with
            ``imshow``, so ``in_channels`` must be 1 (drawn in grayscale), 3 or 4.
        figsize: size of the whole figure in inches (default ``(10, 10)``).

    Raises:
        ValueError: if ``conv_filters`` contains no values.
    """
    conv_filters = np.asarray(conv_filters)
    if conv_filters.size == 0:
        raise ValueError("conv_filters must contain at least one filter")
    n_filters = conv_filters.shape[0]

    # Min-max rescale into [0, 1]. Out-of-place on purpose: the caller's array
    # may share memory with the model's weights. When every weight is equal the
    # range is 0, so skip the division instead of producing NaNs.
    conv_filters = conv_filters - np.min(conv_filters)
    value_range = np.max(conv_filters)
    if value_range > 0:
        conv_filters = conv_filters / value_range

    # Channels-first -> channels-last, the (H, W, C) layout imshow expects.
    conv_filters = np.moveaxis(conv_filters, 1, -1)

    # imshow rejects (H, W, 1) images: drop the singleton channel and draw it in
    # grayscale on the same fixed [0, 1] scale used for colour filters.
    imshow_kwargs = {}
    if conv_filters.shape[-1] == 1:
        conv_filters = conv_filters[..., 0]
        imshow_kwargs = dict(cmap='gray', vmin=0.0, vmax=1.0)

    # round(sqrt(n)) rows and enough columns to show every filter; leftover
    # cells in the last row are left blank.
    i_max = int(round(np.sqrt(n_filters)))
    j_max = int(np.ceil(n_filters / i_max))
    # squeeze=False keeps axarr 2-D even when the grid is a single row/column,
    # so axarr[i, j] indexing always works.
    f, axarr = plt.subplots(i_max, j_max, figsize=figsize, squeeze=False)
    for i in range(i_max):
        for j in range(j_max):
            indx = i*j_max+j
            if indx < n_filters:
                axarr[i, j].imshow(conv_filters[indx], **imshow_kwargs)
            axarr[i, j].set_axis_off()

In [None]:
# First-layer filters of the ResNet18 trained from scratch (`model`); .cpu()
# brings them back to host memory in case training left the model on the GPU.
filters_catdog = model.conv1.weight.detach().cpu().numpy()
visualizeFilters(filters_catdog)

In [None]:
class NormalizeInput(nn.Module):
    """Wrap a model so its input is standardized per channel before the forward pass.

    The default mean/std are the per-channel ImageNet statistics documented for
    torchvision's pretrained models, so a pretrained backbone sees inputs on the
    scale it was trained with.
    """

    def __init__(self, baseModel, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """
        baseModel: the original ResNet model that needs to have its inputs pre-processed
        mean: per-channel values subtracted from the input (broadcast along dim 1)
        std: per-channel values the centred input is divided by (broadcast along dim 1)
        """
        super().__init__()
        self.baseModel = baseModel

        # Buffers, not parameters: these are fixed constants, so they must never be
        # handed to an optimizer or re-enabled for gradients. Buffers still move with
        # .to(device) and are saved in state_dict under the same "mean"/"std" keys.
        dtype = torch.get_default_dtype()
        self.register_buffer("mean", torch.as_tensor(mean, dtype=dtype).view(1, -1, 1, 1))
        self.register_buffer("std", torch.as_tensor(std, dtype=dtype).view(1, -1, 1, 1))

    def forward(self, input):
        """Normalize `input` channel-wise, then return `baseModel`'s output for it."""
        input = (input - self.mean) / self.std
        return self.baseModel(input)

In [None]:
# Put ImageNet-style input normalization in front of the pretrained ResNet.
# The isinstance guard keeps this cell idempotent: re-running it must not wrap
# the model a second time (which would normalize every input twice).
if not isinstance(model_pretrained, NormalizeInput):
    model_pretrained = NormalizeInput(model_pretrained)

### Training with Warm Starts

In [None]:
# Fine-tune the normalized, ImageNet-initialised ResNet18 with the same
# settings used for the from-scratch model, so the two runs are comparable.
warmstart_results = train_network(
    model_pretrained,
    loss,
    train_loader,
    epochs=10,
    device=device,
    test_loader=test_loader,
    score_funcs={'Accuracy': accuracy_score},
)

In [None]:
# Test accuracy per epoch: random initialisation ("Regular") vs. ImageNet
# warm start ("Warm"), drawn on shared axes.
fig, ax = plt.subplots()

# Regular Model
sns.lineplot(x='epoch', y='test Accuracy', data=normal_results, label='Regular', ax=ax)

# Warm Start
sns.lineplot(x='epoch', y='test Accuracy', data=warmstart_results, label='Warm', ax=ax)

ax.set_title('Regular vs. Warm Start Model')
ax.grid(True)
plt.show()

In [None]:
# First-layer filters of the fine-tuned pretrained ResNet (reached through the
# NormalizeInput wrapper's `baseModel`), copied to host memory for plotting.
filters_catdog_finetuned = model_pretrained.baseModel.conv1.weight.detach().cpu().numpy()
visualizeFilters(filters_catdog_finetuned)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=7b262bf3-85b2-4421-a448-4fe589bc864f' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>