# Example: how to load the model with pretrained weights

In [1]:
from gchm.models.xception_sentinel2 import xceptionS2_08blocks_256

## Let's print the docstring
Please have a look at the choices for the argument `model_weights`.


In [2]:
print(xceptionS2_08blocks_256.__doc__)


    The model used in 'A high-resolution canopy height model of the Earth.'
    It is a smaller version (with only 8 sepconv blocks and 256 sepconv filters) of the model described in:
    'Country-wide high-resolution vegetation height mapping with Sentinel-2' <https://arxiv.org/abs/1904.13270>

    Args:
        in_channels (int): Number of channels of the input. (12 sentinel-2 bands + 3 lat-lon-encoding) = 15 channels)
        out_channels (int): Dimension of the pixel-wise output.
        returns (string): Key specifying the return. Choices: ['targets', 'variances_exp', 'variances']
        model_weights (string): This can either be set to the checkpoint path ".pt" or to one of the options below.

    Model weights choices:
        None: Randomly initialize the model weights.
        Path: Path to a pretrained checkpoint file. (E.g. './trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/checkpoint.pt')
        'GLOBAL_GEDI_MODEL_0': This will download the pretrained models and 

## Creating the model with randomly initialized weights

By default, the model weights are randomly initialized.

In [3]:
model = xceptionS2_08blocks_256()


## Loading pretrained weights

There are two ways to load pretrained model weights:

1. Automatically download the pretrained weights by setting `model_weights` to one of the predefined keys.
2. Set `model_weights` to the path of a checkpoint file.
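Together with the default `None`, the docstring's choices amount to a three-way decision on the value of `model_weights`. The sketch below is a hypothetical illustration of that decision, not the gchm implementation: `resolve_model_weights` and `PRETRAINED_KEYS` are made-up names, only `"GLOBAL_GEDI_MODEL_0"` is taken from this notebook, and treating any `.pt` path as a checkpoint is an assumption based on the docstring.

```python
from pathlib import Path
from typing import Optional

# Hypothetical set of predefined keys. This notebook only documents
# "GLOBAL_GEDI_MODEL_0"; the real list in gchm may differ.
PRETRAINED_KEYS = {"GLOBAL_GEDI_MODEL_0"}


def resolve_model_weights(model_weights: Optional[str]) -> str:
    """Classify a model_weights value into a loading strategy.

    Returns "random", "download", or "checkpoint".
    """
    if model_weights is None:
        return "random"  # keep the randomly initialized weights
    if model_weights in PRETRAINED_KEYS:
        return "download"  # fetch the published weights, then load them
    # Assumption: any ".pt" path is treated as a local checkpoint file.
    if Path(model_weights).suffix == ".pt":
        return "checkpoint"
    raise ValueError(f"unrecognized model_weights: {model_weights!r}")
```

Raising on unknown values, rather than silently falling back to random weights, makes a mistyped key or path fail loudly.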


### Automatically download the pretrained weights

As the docstring indicates, setting `model_weights="GLOBAL_GEDI_MODEL_0"` downloads the pretrained models and extracts them into `download_dir`.

The download is skipped if the pretrained weights have already been downloaded.
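The skip-if-present behavior can be sketched with the standard library's `zipfile` module. This is a hypothetical illustration, not gchm's code: the helper name `ensure_pretrained` is made up, the caller supplies the actual `download` function (so no URL is assumed), and "already downloaded" is approximated as "the extracted `GLOBAL_GEDI_2019_2020/` directory exists", based on the folder names in the unzip log of the next cell.

```python
import zipfile
from pathlib import Path
from typing import Callable


def ensure_pretrained(
    zip_path: Path,
    extract_dir: Path,
    download: Callable[[Path], None],
) -> bool:
    """Download and unzip the archive unless extract_dir already exists.

    Returns True if a download happened, False if it was skipped.
    """
    # Assumption: an existing extract_dir means the weights are already there.
    if extract_dir.is_dir():
        return False
    zip_path.parent.mkdir(parents=True, exist_ok=True)
    download(zip_path)  # caller-supplied fetch that writes the archive to zip_path
    with zipfile.ZipFile(zip_path) as archive:
        # The archive's top-level folder becomes extract_dir next to the zip.
        archive.extractall(zip_path.parent)
    return True
```

Passing the download step in as a callable keeps the skip logic testable without network access.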


In [4]:
model = xceptionS2_08blocks_256(in_channels=15, out_channels=1, 
                                model_weights="GLOBAL_GEDI_MODEL_0",
                                returns="variances_exp",
                                download_dir="./trained_models")


model_weights set to:  GLOBAL_GEDI_MODEL_0
downloading pretrained models...


100%|███████████████████████████████████████████████████████████████████████| 148M/148M [00:12<00:00, 12.1MB/s]


unzipping...
Archive:  ./trained_models/trained_models_GLOBAL_GEDI_2019_2020.zip
   creating: ./trained_models/GLOBAL_GEDI_2019_2020/
   creating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/
   creating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/args.json  
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/checkpoint.pt  
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/train_input_mean.npy  
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/train_input_std.npy  
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/train_target_mean.npy  
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/FT_Lm_SRCB/train_target_std.npy  
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/args.json  
  inflating: ./trained_models/GLOBAL_GEDI_2019_2020/model_0/checkpoint.pt  
  inflating: ./trained_models/GLOBAL_GE

### Setting the path to a given checkpoint

Note that the checkpoint file is expected to contain a dict whose `"model_state_dict"` entry holds the weights to be loaded.
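For reference, a checkpoint with this layout can be written and read with PyTorch's `torch.save`, `torch.load`, and `load_state_dict`. The sketch below round-trips a stand-in 1×1 `nn.Conv2d(15, 1, kernel_size=1)` (15 input channels and 1 output, matching the arguments used in this notebook) instead of the Xception model. `save_checkpoint` and `load_weights` are hypothetical helper names, not gchm functions, and real gchm checkpoints may carry additional keys besides `"model_state_dict"`.

```python
import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, path: str) -> None:
    # Store the weights under "model_state_dict", the layout this notebook
    # says the checkpoint file is expected to have.
    torch.save({"model_state_dict": model.state_dict()}, path)


def load_weights(model: nn.Module, path: str) -> nn.Module:
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(path, map_location="cpu")
    if "model_state_dict" not in checkpoint:
        raise KeyError(f"{path} has no 'model_state_dict' entry")
    # Copies the stored tensors into the model's parameters (strict key match).
    model.load_state_dict(checkpoint["model_state_dict"])
    return model
```

Because `load_state_dict` matches parameters by name and shape, the receiving model must be constructed with the same architecture and `in_channels`/`out_channels` as the one that was saved.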


In [5]:
model_weights = "./trained_models/GLOBAL_GEDI_2019_2020/model_1/FT_Lm_SRCB/checkpoint.pt"

model = xceptionS2_08blocks_256(in_channels=15, out_channels=1, 
                                model_weights=model_weights,
                                returns="variances_exp")


Loading pretrained model weights from:
./trained_models/GLOBAL_GEDI_2019_2020/model_1/FT_Lm_SRCB/checkpoint.pt
