In [1]:
import torch
import timm

In [2]:
# Device on which checkpoints are loaded.
DEVICE = 'cpu'

# Class labels; `len(classes)` sizes the classifier head below.
# NOTE(review): the order presumably matches the label order used in training - confirm.
# Spellings are runtime data and are kept exactly as they are.
classes = [
    'haze',
    'primary',
    'agriculture',
    'clear',
    'water',
    'habitation',
    'road',
    'cultivation',
    'slash_burn',
    'cloudy',
    'partly_cloudy',
    'conventional_mine',
    'bare_ground',
    'artisinal_mine',
    'blooming',
    'selective_logging',
    'blow_down',
]

In [3]:
# ResNet-18 variant: load the checkpoint's state_dict onto DEVICE.
model_path = '../weights/model.best.pth'
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=torch.device(DEVICE))

model = timm.create_model(model_name="resnet18", pretrained=False, num_classes=len(classes))
# A freshly built module is in training mode; switch to inference mode so that
# BatchNorm/Dropout layers behave deterministically in the checks and the export.
model.eval()
# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [3]:
# EfficientNet-B3 variant: load the checkpoint's state_dict onto DEVICE.
# Running this cell rebinds `model_path`, `state_dict` and `model`.
model_path = '../weights/model_efb3_288.best.pth'
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=torch.device(DEVICE))

model = timm.create_model(model_name="efficientnet_b3", pretrained=False, num_classes=len(classes))
# A freshly built module is in training mode; switch to inference mode so that
# BatchNorm/Dropout layers behave deterministically in the checks and the export.
model.eval()
# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [4]:
class ModelWrapper(torch.nn.Module):
    """Bundle a model with the metadata stored alongside it for inference.

    ``forward`` applies a sigmoid to the wrapped model's output. The remaining
    constructor arguments are stored unchanged as plain attributes.

    Args:
        model: Module whose output is passed through the sigmoid.
        classes: Class labels, stored as ``self.classes``.
        size: Stored as ``self.size``; presumably the expected input
            (height, width) - confirm against the consumer of the export.
        thresholds: Stored as ``self.thresholds``; presumably one decision
            threshold per class - confirm against the consumer of the export.
    """

    def __init__(self, model, classes, size, thresholds):
        super().__init__()
        self.model = model
        self.classes = classes
        self.size = size
        self.thresholds = thresholds

    def forward(self, image):
        """Return ``sigmoid(self.model(image))``."""
        # Call the module itself instead of ``.forward`` so that hooks
        # registered on the wrapped model are not silently skipped.
        return torch.sigmoid(self.model(image))

In [5]:
# 224x224 variant: attach the labels, the size and a uniform 0.5 threshold per class.
wrapper = ModelWrapper(
    model,
    classes=classes,
    size=(224, 224),
    thresholds=tuple(0.5 for _ in classes),
)

In [5]:
# 288x288 variant: attach the labels, the size and a uniform 0.5 threshold per class.
wrapper = ModelWrapper(
    model,
    classes=classes,
    size=(288, 288),
    thresholds=tuple(0.5 for _ in classes),
)

In [6]:
# Force inference mode before scripting so the exported module is not left in
# training mode (`eval()` returns the module itself, so it can be passed directly).
scripted_model = torch.jit.script(wrapper.eval())

In [7]:
# The label list given to ModelWrapper stays readable on the scripted module (see output).
scripted_model.classes

['haze',
 'primary',
 'agriculture',
 'clear',
 'water',
 'habitation',
 'road',
 'cultivation',
 'slash_burn',
 'cloudy',
 'partly_cloudy',
 'conventional_mine',
 'bare_ground',
 'artisinal_mine',
 'blooming',
 'selective_logging',
 'blow_down']

In [8]:
# Trace in inference mode, with an example batch sized from the wrapper's own
# `size` so the trace matches whichever wrapper variant (224 or 288) was built.
traced_model = torch.jit.trace(wrapper.eval(), torch.rand(1, 3, *wrapper.size))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [9]:
# Expected to raise AttributeError (see output): unlike the scripted module,
# the traced module does not expose the wrapper's `classes` attribute.
traced_model.classes

AttributeError: 'RecursiveScriptModule' object has no attribute 'classes'

In [11]:
# Random example batch (1 image, 3 channels) sized from `wrapper.size`, so it
# stays consistent with whichever wrapper variant (224 or 288) was built.
dummy_input = torch.rand(1, 3, *wrapper.size)

In [12]:
# Reference values: model output with the sigmoid applied by hand,
# to compare against the scripted wrapper's output.
with torch.no_grad():
    raw_output = model(dummy_input)
    print(torch.sigmoid(raw_output))

tensor([[0.0199, 0.9874, 0.1850, 0.7732, 0.1074, 0.0250, 0.0821, 0.0469, 0.0043,
         0.0053, 0.0819, 0.0030, 0.0103, 0.0041, 0.0056, 0.0056, 0.0033]])


In [13]:
# The scripted wrapper applies the sigmoid itself, so its output is printed as-is.
with torch.no_grad():
    scripted_output = scripted_model(dummy_input)
    print(scripted_output)

tensor([[0.0199, 0.9874, 0.1850, 0.7732, 0.1074, 0.0250, 0.0821, 0.0469, 0.0043,
         0.0053, 0.0819, 0.0030, 0.0103, 0.0041, 0.0056, 0.0056, 0.0033]])


In [15]:
# Show which checkpoint path is currently bound; the save/load cells derive the .zip name from it.
model_path

'../weights/model.best.pth'

In [8]:
# Save the scripted module under the checkpoint's path with '.pth' swapped for '.zip'.
export_path = model_path.replace('.pth', '.zip')
torch.jit.save(scripted_model, export_path)

In [9]:
# Reload the exported TorchScript archive onto the notebook's configured device
# (DEVICE rather than a second hard-coded 'cpu').
# NOTE: this rebinds `model` from the eager model to the loaded scripted module.
model = torch.jit.load(model_path.replace('.pth', '.zip'), map_location=DEVICE)

In [10]:
# Check that the class labels are still readable after the save/load round trip.
model.classes

['haze',
 'primary',
 'agriculture',
 'clear',
 'water',
 'habitation',
 'road',
 'cultivation',
 'slash_burn',
 'cloudy',
 'partly_cloudy',
 'conventional_mine',
 'bare_ground',
 'artisinal_mine',
 'blooming',
 'selective_logging',
 'blow_down']

In [11]:
# Check that the thresholds are still readable after the save/load round trip.
model.thresholds

(0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5,
 0.5)

In [None]:
import torch

# Scratch check: sigmoid the raw values, then binarise at 0.5 into a float 0/1 tensor.
target = (torch.sigmoid(torch.tensor([0.1, 0.1, 0.2, 0.3])) > 0.5).float()
target

## Test torchmetrics

In [7]:
import torch
from torchmetrics.classification import MultilabelAUROC

# Toy batch: 4 samples x 3 labels of scores and binary targets.
preds = torch.tensor([[0.75, 0.05, 0.35],
                      [0.45, 0.75, 0.05],
                      [0.05, 0.55, 0.75],
                      [0.05, 0.65, 0.05]])
target = torch.tensor([[1, 0, 1],
                       [0, 0, 0],
                       [0, 1, 1],
                       [1, 1, 1]])

# Per-label AUROC, evaluated once: every call to the metric also accumulates the
# batch into its internal state, so a second call would recompute and re-accumulate it.
metric = MultilabelAUROC(num_labels=3, average=None, thresholds=None)
per_label_auroc = metric(preds, target)
print(per_label_auroc.numpy())
print(per_label_auroc.numpy().mean())

metric_avg = MultilabelAUROC(num_labels=3, average='macro', thresholds=None)
macro_auroc = metric_avg(preds, target)
print(macro_auroc.numpy())

# What this cell checks: macro averaging equals the plain mean of the per-label scores.
assert torch.isclose(macro_auroc, per_label_auroc.mean())

[0.625     0.5       0.8333334]
0.6527778
0.6527778


## Cut samples