In [1]:
%load_ext autoreload
%autoreload 2

In this example we are going to train an image classifier with the [EuroSAT](https://github.com/phelber/EuroSAT) dataset and show options to export the model for production. If you want to learn more about training the model, check [this](https://github.com/earthpulse/pytorch_eo/blob/main/examples/eurosat.ipynb) notebook.

We will use `torchscript` to export the model. Other options include `onnx`, or using the checkpoint generated by the `trainer` directly (keep in mind, however, that in the latter case the library, its dependencies and your model source code will be required in the production environment, which is very inconvenient).

The following model can be exported with `torchscript`, along with some additional information that can be useful later on.

In [2]:
import torch
import timm 

class Model(torch.nn.Module):

    def __init__(self, in_chans, classes):
        super().__init__()
        self.model = timm.create_model(
            'resnet18',
            pretrained=True,  # load ImageNet weights
            in_chans=in_chans,
            num_classes=len(classes)
        )
        self.classes = classes 
        self.info = 'This is a model trained for image classification with EuroSAT datset. Put here any useful information.'
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):  # take a plain tensor (not a dict) so the model exports cleanly with torchscript
        x = x.permute(0, 3, 1, 2) # b h w c -> b c h w 
        x = (x / 255).float()
        return self.model(x)

    @torch.jit.export # use this to export custom functions with torchscript
    def predict(self, x):
        with torch.no_grad():
            y_hat = self(x)
            probas = self.softmax(y_hat)
            return probas
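Before training, a quick check on a dummy batch can confirm the input convention (`b h w c` images with values in the 0-255 range) and that `predict` returns one probability per class. The sketch below mirrors the structure of `Model`, but, as an assumption to keep it fast and offline, it swaps the pretrained `timm` ResNet-18 for a hypothetical tiny `TinyBackbone`; the class names and image size are made up.

```python
import torch


class TinyBackbone(torch.nn.Module):
    """Hypothetical stand-in for the timm ResNet-18 used in `Model`."""

    def __init__(self, in_chans: int, num_classes: int):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_chans, 8, kernel_size=3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)  # (b, 8, 1, 1) -> (b, 8)
        return self.fc(x)


class SketchModel(torch.nn.Module):
    """Same forward/predict structure as `Model`, with the tiny backbone."""

    def __init__(self, in_chans: int, classes: list):
        super().__init__()
        self.model = TinyBackbone(in_chans, len(classes))
        self.classes = classes
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 3, 1, 2)  # b h w c -> b c h w
        x = (x / 255).float()      # 0-255 values -> floats in [0, 1]
        return self.model(x)

    @torch.jit.export
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.softmax(self(x))


# Dummy batch in the layout `forward` expects: b h w c, uint8 values
m = SketchModel(in_chans=3, classes=["Forest", "River", "Highway"]).eval()
batch = torch.randint(0, 256, (4, 64, 64, 3), dtype=torch.uint8)
probs = m.predict(batch)
print(probs.shape)       # torch.Size([4, 3])
print(probs.sum(dim=1))  # each row should sum to (approximately) 1
```

Running `torch.jit.script` on a small stand-in like this is also a cheap way to catch TorchScript incompatibilities before spending time on training.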

Once the `model` is defined, we need to define a `task`. In this case, we use the `ImageClassification` task.

In [14]:
from pytorch_eo.tasks.classification import ImageClassification
from pytorch_eo.datasets.eurosat import EuroSATRGB

ds = EuroSATRGB(batch_size=32) 
ds.setup()

model = Model(in_chans=ds.in_chans, classes=ds.classes)

task = ImageClassification(model)
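The notebook does not show what `ImageClassification` does internally. As a rough, assumed picture (the training summary below lists a `CrossEntropyLoss`, and testing reports `test_loss` and `test_acc`), a classification step computes a cross-entropy loss and an accuracy on each batch. The plain-PyTorch sketch below illustrates that; `classification_step` is a hypothetical helper, not the `pytorch_eo` API, and it assumes batches are dicts with `'image'` and `'label'` keys, like the ones read from the test dataloader later in this notebook.

```python
import torch

loss_fn = torch.nn.CrossEntropyLoss()


def classification_step(model, batch):
    """Hypothetical per-batch step: cross-entropy loss and accuracy."""
    logits = model(batch["image"])  # raw class scores, shape (b, num_classes)
    labels = batch["label"]         # integer class indices, shape (b,)
    loss = loss_fn(logits, labels)
    acc = (logits.argmax(dim=1) == labels).float().mean()
    return loss, acc


# Usage with a dummy "model" that returns its input unchanged as logits
step_batch = {"image": torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2]]),
              "label": torch.tensor([0, 1, 1])}
step_loss, step_acc = classification_step(lambda x: x, step_batch)
print(step_loss.item(), step_acc.item())  # two of the three argmax predictions match the labels
```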

Now we can use `PyTorch Lightning` to train the `model` to solve the `task` on the given dataset (`ds`).

In [16]:
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=1,
    max_epochs=3
)

trainer.fit(task, ds)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type             | Params
---------------------------------------------
0 | model   | Model            | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


In [17]:
trainer.test()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.889629602432251, 'test_loss': 0.3352689743041992}
--------------------------------------------------------------------------------


[{'test_acc': 0.889629602432251, 'test_loss': 0.3352689743041992}]

In [18]:
script_model = torch.jit.script(model.cpu())
script_model.save('my_model.pth')
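Before shipping `my_model.pth`, it may be worth checking that the scripted module reproduces the eager one and that the extra attributes and the exported `predict` method survive a save/load round trip. A self-contained sketch, assuming a hypothetical tiny `Toy` module in place of the trained `model` and an in-memory buffer instead of a file:

```python
import io

import torch


class Toy(torch.nn.Module):
    """Hypothetical stand-in for the trained model, with the same export-relevant features."""

    def __init__(self, classes):
        super().__init__()
        self.linear = torch.nn.Linear(3, len(classes))
        self.classes = classes        # list of str attribute, kept by torchscript
        self.info = "Some metadata."  # str attribute, kept by torchscript
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 3, 1, 2)  # b h w c -> b c h w
        x = (x / 255).float()
        return self.linear(x.mean(dim=[2, 3]))  # average over pixels, then classify

    @torch.jit.export
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.softmax(self(x))


eager = Toy(["Forest", "River"]).eval()
scripted = torch.jit.script(eager)

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # same serialization as script_model.save(...)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randint(0, 256, (2, 8, 8, 3), dtype=torch.uint8)
with torch.no_grad():
    torch.testing.assert_close(loaded(x), eager(x))  # raises if the outputs differ
print(loaded.classes, loaded.info)
print(loaded.predict(x).shape)
```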

Now this model can be loaded and used in any environment with `Python` and `PyTorch` installed, without depending on `PytorchEO` or on the model source code.

In [19]:
import torch

my_model = torch.jit.load('my_model.pth')
my_model.eval()

my_model.info

'This is a model trained for image classification with EuroSAT datset. Put here any useful information.'
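In production the images will not come from a `pytorch_eo` dataloader. Since `forward` permutes `b h w c -> b c h w` and divides by 255, the exported model expects a batch of channels-last images with values in the 0-255 range. A minimal sketch of building such a batch from raw arrays, assuming the images are already loaded as equally sized `uint8` NumPy arrays (`to_batch` is a hypothetical helper; reading the image files is left out):

```python
import numpy as np
import torch


def to_batch(images):
    """Stack equally sized uint8 (h, w, c) arrays into a (b, h, w, c) tensor."""
    arrays = [np.asarray(img, dtype=np.uint8) for img in images]
    return torch.from_numpy(np.stack(arrays, axis=0))


# Two hypothetical 64 x 64 RGB images (EuroSAT patches are 64 x 64 pixels)
raw_images = [np.zeros((64, 64, 3), dtype=np.uint8),
              np.full((64, 64, 3), 255, dtype=np.uint8)]
prod_batch = to_batch(raw_images)
print(prod_batch.shape, prod_batch.dtype)  # torch.Size([2, 64, 64, 3]) torch.uint8
# probs = my_model.predict(prod_batch)  # with the model loaded above
```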

In [20]:
# load a new batch of images (in this case, from the test dataset)
batch = next(iter(ds.test_dataloader(shuffle=True, batch_size=10)))
imgs, labels = batch['image'], batch['label']

preds = my_model.predict(imgs)

# use the class names stored in the exported model
for label, pred in zip(labels, preds):
    print(ds.classes[label], '->', my_model.classes[torch.argmax(pred)])

River -> River
River -> River
Pasture -> Pasture
Forest -> Forest
Highway -> Highway
Pasture -> Pasture
Residential -> Residential
Industrial -> Industrial
Residential -> Residential
HerbaceousVegetation -> HerbaceousVegetation
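`predict` returns one probability per class, so the same output can drive more than a single argmax. A sketch of reporting the top-k classes with their probabilities using `torch.topk`; `top_k_labels`, the class list and the probability values are made up for illustration:

```python
import torch


def top_k_labels(probs: torch.Tensor, classes, k: int = 2):
    """Return, for each row of probs, the k most likely (class, probability) pairs."""
    values, indices = torch.topk(probs, k=k, dim=1)
    return [
        [(classes[i], round(v, 3)) for v, i in zip(row_v.tolist(), row_i.tolist())]
        for row_v, row_i in zip(values, indices)
    ]


example_classes = ["Forest", "River", "Highway"]
example_probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
print(top_k_labels(example_probs, example_classes))
# [[('Forest', 0.7), ('River', 0.2)], [('Highway', 0.6), ('River', 0.3)]]
```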
