# resnet

> Utilities for adapting torchgeo ResNet models to fastai learners

[![](https://raw.githubusercontent.com/butchland/fastai-torchgeo/master/assets/colab.svg)](https://colab.research.google.com/github/butchland/fastai-torchgeo/blob/master/nbs/01_data.ipynb)

In [None]:
#| default_exp resnet

In [None]:
#| hide
# check if in colab and install package as needed
# Both lines are IPython shell escapes (`!`). Each one only proceeds when the
# /content directory exists (presumably meaning we are running on Colab) and
# `pip show` fails for the package, i.e. the package is not yet installed.
![ -e /content ] && ! pip show fastai-torchgeo && pip install git+https://github.com/butchland/fastai-torchgeo.git
![ -e /content ] && ! pip show nbdev && pip install nbdev

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| exporti
import fastai.vision.all as fv
from fastai_torchgeo.core import GeoTensorImage
from functools import partial
from fastai_torchgeo.data import GeoImageBlock
import torch.nn as nn

In [None]:
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.datasets import EuroSAT100

In [None]:
#| export
def make_resnet_model(model, n_out, *, init=nn.init.kaiming_normal_, head=None,
                      concat_pool=True, pool=True, lin_ftrs=None, ps=0.5,
                      first_bn=True, bn_final=False, lin_first=False, y_range=None):
    """Cut `model` at its last pooling child and attach a fresh fastai head.

    The body is `fv.cut_model(model, cut)`, where `cut` is the index of the
    last direct child of `model` for which `fv.has_pool_type` is true. That
    child and everything after it (e.g. the original classifier) are dropped.
    A new head with `n_out` outputs is then added with `fv.add_head`.

    Args:
        model: backbone to adapt (e.g. a torchgeo ResNet). Its child modules
            are reused for the body (fastai's `cut_model` wraps them; it does
            not copy them), so training the result also updates `model`.
        n_out: number of outputs of the new head.
        init, head, concat_pool, pool, lin_ftrs, ps, first_bn, bn_final,
        lin_first, y_range: keyword-only options forwarded unchanged to
            `fv.add_head`. The defaults are the values this function has
            always used.

    Returns:
        The model built by `fv.add_head`: the body at index 0, followed by
        the head. `resnet_split` relies on this layout.

    Raises:
        ValueError: if no direct child of `model` contains a pooling layer.
    """
    children = list(model.children())
    # Search from the end so the final (global) pool wins over earlier pools
    # such as the stem's MaxPool2d.
    cut = next((i for i in range(len(children) - 1, -1, -1)
                if fv.has_pool_type(children[i])), None)
    if cut is None:
        raise ValueError(
            f"make_resnet_model: no pooling layer found among the children of "
            f"{type(model).__name__}; cannot determine where to cut the body")
    body = fv.cut_model(model, cut)
    # timm-style models expose `num_features`. For other backbones, fall back
    # to measuring the body's output channels with fastai's dummy forward pass.
    nf = getattr(model, "num_features", None)
    if nf is None:
        nf = fv.num_features_model(body)
    return fv.add_head(body, nf, n_out=n_out, init=init, head=head,
                       concat_pool=concat_pool, pool=pool, lin_ftrs=lin_ftrs,
                       ps=ps, first_bn=first_bn, bn_final=bn_final,
                       lin_first=lin_first, y_range=y_range)

In [None]:
#| export
def resnet_split(m):
    """Return three parameter groups for discriminative learning rates.

    Groups, in order:
      1. body children `[:6]`, where the body is `m[0]`;
      2. the remaining body children `[6:]`;
      3. everything after the body, `m[1:]` (the head).

    For a timm-style ResNet body, group 1 is presumably the stem plus the
    first two residual stages. Confirm the cut index for other architectures.
    """
    body = m[0]
    groups = fv.L(body[:6], body[6:], m[1:])
    return groups.map(fv.params)


In [None]:
from torchgeo.models import ResNet18_Weights, resnet18
from torchgeo.datamodules import EuroSATDataModule

In [None]:
# Load a torchgeo ResNet-18 with Sentinel-2 (all bands) MoCo pretrained weights.
# The classifier created for `num_classes` is expected to be discarded:
# make_resnet_model cuts the network at its last pooling child and adds a new head.
pretrained = resnet18(weights=ResNet18_Weights.SENTINEL2_ALL_MOCO, num_classes=10)

In [None]:
# Adapt the pretrained backbone to a 10-class head.
model = make_resnet_model(model=pretrained, n_out=10)

In [None]:
# Batch-level normalization using the mean/std published on torchgeo's
# EuroSATDataModule (presumably one value per band; the model input has 13 channels).
band_norm = fv.Normalize.from_stats(EuroSATDataModule.mean, EuroSATDataModule.std)

# Images -> multispectral GeoImageBlock; labels -> parent folder name.
# A fixed-seed random 10% of the items is held out for validation.
dblock = fv.DataBlock(
    blocks=(GeoImageBlock(), fv.CategoryBlock()),
    get_items=fv.get_image_files,
    splitter=fv.RandomSplitter(valid_pct=0.1, seed=42),
    get_y=fv.parent_label,
    item_tfms=fv.Resize(64),
    batch_tfms=[band_norm],
)

In [None]:
# Fetch and extract the small EuroSAT100 archive; keep the local directory.
sat_url = EuroSAT100.url
sat_path = fv.untar_data(sat_url)

In [None]:
# Build train/valid DataLoaders from the extracted folder with batches of 64.
dls = dblock.dataloaders(source=sat_path, bs=64)

In [None]:
import copy

# Rebuild `model` from a deep copy of `pretrained`. make_resnet_model reuses
# the backbone's own child modules for the body (fastai's cut_model does not
# copy them). Without the copy, fine-tuning would overwrite `pretrained`'s
# weights, and re-running this cell would not restart from the pretrained weights.
model = make_resnet_model(copy.deepcopy(pretrained), n_out=10)

In [None]:
# Settings for the commented-out torchgeo EuroSATDataModule path that follows.
# Note: `dls` above passes bs=64 directly and does not read `batch_size`.
batch_size, num_workers = 64, fv.defaults.cpus

In [None]:
# datamodule = EuroSATDataModule(root=sat_path,batch_size=batch_size, num_workers=num_workers, download=True)

In [None]:
# %%time
# datamodule.prepare_data()

In [None]:
# Learner with parameter groups from resnet_split. freeze() makes only the last
# group trainable for the frozen phase (the summary output shows BatchNorm
# layers staying trainable while convolutions are frozen).
learn = fv.Learner(
    dls=dls,
    model=model,
    loss_func=fv.CrossEntropyLossFlat(),
    metrics=[fv.error_rate, fv.accuracy],
    splitter=resnet_split,
)
learn.freeze()

In [None]:
# Display a per-layer summary: output shapes, parameter counts, trainable flags.
learn.summary()

Sequential (Input shape: 64 x 13 x 64 x 64)
Layer (type)         Output Shape         Param #    Trainable 
                     64 x 64 x 32 x 32   
Conv2d                                    40768      False     
BatchNorm2d                               128        True      
ReLU                                                           
____________________________________________________________________________
                     64 x 64 x 16 x 16   
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Identity                                                       
ReLU                                                           
Identity                                                       
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                       

In [None]:
# fastai fine_tune: a frozen phase (one epoch in the log below), then
# `n_epochs` epochs with all parameter groups unfrozen.
n_epochs = 10
learn.fine_tune(n_epochs)

epoch,train_loss,valid_loss,error_rate,accuracy,time
0,1.002229,1.870585,0.5,0.5,00:07


epoch,train_loss,valid_loss,error_rate,accuracy,time
0,1.135068,1.851963,0.5,0.5,00:03
1,1.071356,1.824092,0.5,0.5,00:06
2,0.970566,1.785548,0.5,0.5,00:03
3,0.801415,1.728489,0.5,0.5,00:03
4,0.692353,1.671152,0.5,0.5,00:04
5,0.607209,1.615755,0.5,0.5,00:03
6,0.541593,1.56689,0.5,0.5,00:04
7,0.482271,1.519107,0.5,0.5,00:03
8,0.434583,1.487241,0.5,0.5,00:03
9,0.395028,1.466383,0.5,0.5,00:03
