# resnet

> Utilities for adapting pretrained torchgeo ResNet models for fine-tuning with fastai

[![](https://raw.githubusercontent.com/butchland/fastai-torchgeo/master/assets/colab.svg)](https://colab.research.google.com/github/butchland/fastai-torchgeo/blob/master/nbs/02_resnet.ipynb)

In [None]:
#| default_exp resnet

In [None]:
#| hide
# check if in colab and install package as needed
![ -e /content ] && ! pip show fastai-torchgeo && pip install git+https://github.com/butchland/fastai-torchgeo.git
![ -e /content ] && ! pip show nbdev && pip install nbdev

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| exporti
import fastai.vision.all as fv
from fastai_torchgeo.core import GeoTensorImage
from functools import partial
from fastai_torchgeo.data import GeoImageBlock
import torch.nn as nn

In [None]:
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT100

In [None]:
#| export
def make_resnet_model(model:nn.Module, # pretrained torchgeo model
                      n_out:int, # number of outputs
                     ) -> nn.Module: # new model with a new head for finetuning
    """
    Builds a fine-tunable model from a pretrained ResNet by replacing its pooling/classifier tail with a new fastai head.

    The body keeps every top-level child of `model` that comes before the last child for which
    fastai's `has_pool_type` is true. The head is whatever fastai's `add_head` builds for
    `model.num_features` input features and `n_out` outputs (concat pooling enabled, dropout 0.5),
    initialised with `nn.init.kaiming_normal_`.

    #### Parameters

    - `model` (torch.nn.Module): A pretrained ResNet model. It must expose a `num_features`
      attribute and have a pooling layer among its top-level children.
    - `n_out` (int): The number of output classes.

    #### Returns

    - `torch.nn.Module`: The cut body combined with the new head, as returned by `add_head`.

    #### Raises

    - `ValueError`: If no top-level child of `model` has a pooling layer, so there is no place to cut.
    """

    children = list(model.children())
    # Scan from the end so the final pooling layer is picked, not an earlier
    # pooling layer such as the max-pool near the input.
    cut = next(
        (idx for idx in reversed(range(len(children))) if fv.has_pool_type(children[idx])),
        None,
    )
    if cut is None:
        # A bare next() would leak StopIteration here, which is hard to diagnose.
        raise ValueError(
            "make_resnet_model: no pooling layer found among the top-level "
            f"children of {type(model).__name__}; cannot determine where to cut"
        )

    body = fv.cut_model(model, cut)
    # Arguments are spelled out so the head does not depend on the installed
    # fastai version's defaults.
    return fv.add_head(body,
                       model.num_features,
                       n_out=n_out,
                       init=nn.init.kaiming_normal_,
                       head=None,
                       concat_pool=True,
                       pool=True,
                       lin_ftrs=None,
                       ps=0.5,
                       first_bn=True,
                       bn_final=False,
                       lin_first=False,
                       y_range=None)

In [None]:
#| export
def resnet_split(m:nn.Module, # A model
                ) -> [nn.Module]: # A list of parameter groups
    """
    Partitions the parameters of an adapted resnet into three groups.

    fastai consumes these groups for freezing and for discriminative
    learning rates during finetuning.

    The groups are, in order: the first six children of the body (`m[0]`),
    the remaining children of the body, and everything after the body (`m[1:]`).

    #### Parameters
    - `m` (nn.Module): Model whose first element is the body, followed by the head

    #### Returns
    - An `L` with one list of parameters per group

    """
    body = m[0]
    early_layers = body[:6]
    late_layers = body[6:]
    head = m[1:]
    groups = fv.L(early_layers, late_layers, head)
    return groups.map(fv.params)


#### Adapting a pretrained resnet torchgeo model for a fastai Learner

In [None]:
from torchgeo.models import ResNet18_Weights, resnet18
from torchgeo.datamodules import EuroSATDataModule

In [None]:
# Build a torchgeo ResNet-18 initialised from the SENTINEL2_ALL_MOCO weights.
# NOTE(review): `make_resnet_model` later cuts the model at its last pooling layer,
# so the 10-class layer requested here is presumably discarded — confirm.
pretrained = resnet18(ResNet18_Weights.SENTINEL2_ALL_MOCO, num_classes=10) # load pretrained weights

In [None]:
# Replace the pretrained model's tail with a new 10-output head for finetuning
model = make_resnet_model(pretrained, n_out=10) 

In [None]:
# Describe how to turn image files on disk into (image, category) samples
dblock = fv.DataBlock(blocks=(GeoImageBlock(), fv.CategoryBlock()),  # input: geo image, target: class label
                      get_items=fv.get_image_files,  # collect the image files under the source path
                      splitter=fv.RandomSplitter(valid_pct=0.1, seed=42),  # 10% validation split, fixed seed
                      get_y=fv.parent_label,  # label comes from the name of the parent folder
                      item_tfms=fv.Resize(64),  # resize every item to 64
                      # normalize batches with the statistics published on EuroSATDataModule
                      batch_tfms=[fv.Normalize.from_stats(EuroSATDataModule.mean, EuroSATDataModule.std)],
                     )

In [None]:
# Fetch the archive at EuroSAT100.url with fastai's untar_data and keep the path it returns
sat_path = fv.untar_data(EuroSAT100.url)

In [None]:
# Build the dataloaders from the files under sat_path, 64 items per batch
dls = dblock.dataloaders(sat_path, bs=64)

In [None]:
# NOTE(review): repeats the earlier `make_resnet_model` cell with the same arguments;
# the Learner below uses this instance. Presumably kept so the head is freshly built — confirm.
model = make_resnet_model(pretrained, n_out=10) 

In [None]:
# Loader settings for the commented-out EuroSATDataModule cell below;
# no other cell visible in this notebook reads them.
batch_size=64
num_workers = fv.defaults.cpus

In [None]:
# datamodule = EuroSATDataModule(root=sat_path,batch_size=batch_size, num_workers=num_workers, download=True)

In [None]:
# %%time
# datamodule.prepare_data()

In [None]:
# Wrap the data and the adapted model in a fastai Learner
learn = fv.Learner(
    dls, 
    model,
    loss_func=fv.CrossEntropyLossFlat(),
    metrics=[fv.error_rate,fv.accuracy],
    splitter=resnet_split,  # defines the parameter groups used for freezing and per-group learning rates
)
# `freeze` uses the parameter groups created by `resnet_split`
# to lock the parameters of the pretrained body, leaving only the model head trainable

learn.freeze()

In [None]:
# note: only the head parameter group is trainable (except BatchNorm layers, which are always trainable)
learn.summary()

Sequential (Input shape: 64 x 13 x 64 x 64)
Layer (type)         Output Shape         Param #    Trainable 
                     64 x 64 x 32 x 32   
Conv2d                                    40768      False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 64 x 16 x 16   
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                       

In [None]:
# fine_tune(2) produces the two result tables below: a single epoch first, then two more.
# Per the fastai docs, the first phase runs frozen and the second runs unfrozen.
learn.fine_tune(2)

epoch,train_loss,valid_loss,error_rate,accuracy,time
0,4.042061,2.311349,0.8,0.2,00:01


epoch,train_loss,valid_loss,error_rate,accuracy,time
0,3.859293,2.302754,1.0,0.0,00:02
1,3.652791,2.247182,0.9,0.1,00:02


In [None]:
# unlock all weights and make the whole model trainable
# NOTE(review): `fine_tune` above presumably already leaves the model unfrozen,
# in which case this call is a harmless no-op kept for explicitness — confirm.
learn.unfreeze()

In [None]:
# all parameters are now trainable (see the Trainable column in the output below)
learn.summary()

Sequential (Input shape: 64 x 13 x 64 x 64)
Layer (type)         Output Shape         Param #    Trainable 
                     64 x 64 x 32 x 32   
Conv2d                                    40768      True      
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 64 x 16 x 16   
MaxPool2d                                                      
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    36864      True      
BatchNorm2d                               128        True      
ReLU                       

In [None]:
# Uses discriminative learning rates across the parameter groups from `resnet_split`
# to give the "upper layers" (the head) a higher learning rate while keeping the
# "lower layers" at a much lower learning rate, nearly freezing their weights.
# fastai assigns `slice.start` to the first (lowest) group and `slice.stop` to the
# last (head) group, so the small rate must come first.
learn.fit_one_cycle(5, lr_max=slice(8.e-6, 2.e-3))

epoch,train_loss,valid_loss,error_rate,accuracy,time
0,2.139503,2.269313,0.9,0.1,00:02
1,2.047539,2.380141,1.0,0.0,00:03
2,2.059563,2.353606,0.9,0.1,00:02
3,2.036056,2.318051,0.9,0.1,00:02
4,1.905594,2.308228,0.9,0.1,00:03
