In [10]:
import torch
import timm

In [11]:
classes = [
    'haze', 'primary', 'agriculture', 'clear', 'water', 'habitation', 'road', 
    'cultivation', 'slash_burn', 'cloudy', 'partly_cloudy', 'conventional_mine', 
    'bare_ground', 'artisinal_mine', 'blooming', 'selective_logging', 'blow_down'
]
DEVICE = 'cpu'

In [26]:
model_path = '../weights/model.best.pth'
state_dict = torch.load(model_path, map_location=torch.device(DEVICE))

model = timm.create_model(model_name="resnet18", pretrained=False, num_classes=len(classes))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [16]:
class ModelWrapper(torch.nn.Module):
    def __init__(self, model, classes, size, thresholds):
        super().__init__()
        self.model = model
        self.classes = classes
        self.size = size
        self.thresholds = thresholds
    
    def forward(self, image):
        return torch.sigmoid(self.model.forward(image))

In [39]:
wrapper = ModelWrapper(model, classes=classes, size=(224, 224), thresholds=(0.5,) * len(classes))

In [40]:
scripted_model = torch.jit.script(wrapper)

In [41]:
scripted_model.classes

['haze',
 'primary',
 'agriculture',
 'clear',
 'water',
 'habitation',
 'road',
 'cultivation',
 'slash_burn',
 'cloudy',
 'partly_cloudy',
 'conventional_mine',
 'bare_ground',
 'artisinal_mine',
 'blooming',
 'selective_logging',
 'blow_down']

In [33]:
traced_model = torch.jit.trace(wrapper, torch.rand(1, 3, 224, 224))

In [34]:
traced_model.classes

In [22]:
dummy_input = torch.rand(1, 3, 224, 224)

In [23]:
with torch.no_grad():
    print(torch.sigmoid(model(dummy_input)))

tensor([[0.0206, 0.9866, 0.1904, 0.7749, 0.1151, 0.0271, 0.0861, 0.0505, 0.0047,
         0.0057, 0.0851, 0.0034, 0.0113, 0.0046, 0.0063, 0.0063, 0.0036]])


In [24]:
with torch.no_grad():
    print(scripted_model(dummy_input))

tensor([[0.0206, 0.9866, 0.1904, 0.7749, 0.1151, 0.0271, 0.0861, 0.0505, 0.0047,
         0.0057, 0.0851, 0.0034, 0.0113, 0.0046, 0.0063, 0.0063, 0.0036]])


In [42]:
torch.jit.save(scripted_model, model_path.replace('.pth', '.zip'))

In [43]:
model = torch.jit.load(model_path.replace('.pth', '.zip'), map_location='cpu')

In [44]:
model.classes

['haze',
 'primary',
 'agriculture',
 'clear',
 'water',
 'habitation',
 'road',
 'cultivation',
 'slash_burn',
 'cloudy',
 'partly_cloudy',
 'conventional_mine',
 'bare_ground',
 'artisinal_mine',
 'blooming',
 'selective_logging',
 'blow_down']

In [45]:
model.thresholds

(0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5)