In [1]:
from pprint import pprint

import torch
import torch.nn as nn
import timm
import kornia.augmentation as K
from kornia.contrib import Lambda
from faissknn import FaissKNNClassifier
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)

from src.model import swin_v2
from src.datasets.ships2ais import ShipS2AISDataModule
from src.utils import extract_features


data_root = "./data/ship-s2-ais/"
config_path = "checkpoints/swin-v2-rgb/config.yaml"
device = torch.device("cuda:0")

In [2]:
dm = ShipS2AISDataModule(root=data_root, bands="rgb")
dm.setup()

### Eval SwinV2 pretrained on Hydro dataset

In [3]:
model, transforms, config = swin_v2(config_path)
model = model.to(device)

Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.
=> merge config from checkpoints/swin-v2-rgb/config.yaml


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
mean = torch.tensor(config.DATA.MEAN) / 10000.0
std = torch.tensor(config.DATA.STD) / 10000.0

if config.DATA.BANDS == "rgb":
    mean = mean[[3, 2, 1]]
    std = std[[3, 2, 1]]


transforms = [
    K.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)),
    Lambda(lambda x: x / 10000.0),
    K.Normalize(mean=mean, std=std),
]
transforms = nn.Sequential(*transforms).to(device)

In [5]:
x_train, y_train = extract_features(model, dm.train_dataloader(), device, transforms=transforms)
x_test, y_test = extract_features(model, dm.test_dataloader(), device, transforms=transforms)

100%|██████████| 358/358 [01:15<00:00,  4.76it/s]
100%|██████████| 86/86 [00:18<00:00,  4.67it/s]


In [6]:
model = FaissKNNClassifier(
    n_neighbors=5,
    n_classes=2,
    device="cuda:0"
)
model.fit(x_train, y_train)
y_pred = model.predict(x_test)

In [7]:
metrics = {
    "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
    "f1_macro": f1_score(y_test, y_pred, average="macro"),
    "f1_micro": f1_score(y_test, y_pred, average="micro"),
    "precision_micro": precision_score(y_test, y_pred, average="micro"),
    "precision_macro": precision_score(y_test, y_pred, average="macro"),
    "precision_weighted": precision_score(y_test, y_pred, average="weighted"),
    "recall_micro": recall_score(y_test, y_pred, average="micro"),
    "recall_macro": recall_score(y_test, y_pred, average="macro"),
    "recall_weighted": recall_score(y_test, y_pred, average="weighted"),
    "accuracy": accuracy_score(y_test, y_pred),
}
pprint(metrics)

{'accuracy': 0.8765115426896299,
 'f1_macro': 0.5357993172814213,
 'f1_micro': 0.8765115426896299,
 'f1_weighted': 0.8376021482567848,
 'precision_macro': 0.661124012724216,
 'precision_micro': 0.8765115426896299,
 'precision_weighted': 0.8323591661055489,
 'recall_macro': 0.5337417679837893,
 'recall_micro': 0.8765115426896299,
 'recall_weighted': 0.8765115426896299}


### Eval SwinV2 ImageNet

In [4]:
model = timm.create_model("swinv2_base_window16_256", pretrained=True).to(device)

In [5]:
mean = torch.tensor(config.DATA.MEAN) / 10000.0
std = torch.tensor(config.DATA.STD) / 10000.0

if config.DATA.BANDS == "rgb":
    mean = mean[[3, 2, 1]]
    std = std[[3, 2, 1]]


transforms = [
    K.Resize(model.patch_embed.img_size),
    Lambda(lambda x: x / 10000.0),
    K.Normalize(mean=mean, std=std),
]
transforms = nn.Sequential(*transforms).to(device)

In [6]:
x_train, y_train = extract_features(model, dm.train_dataloader(), device, transforms=transforms)
x_test, y_test = extract_features(model, dm.test_dataloader(), device, transforms=transforms)

100%|██████████| 358/358 [01:14<00:00,  4.78it/s]
100%|██████████| 86/86 [00:18<00:00,  4.69it/s]


In [8]:
model = FaissKNNClassifier(
    n_neighbors=5,
    n_classes=2,
    device="cuda:0"
)
model.fit(x_train, y_train)
y_pred = model.predict(x_test)

In [9]:
metrics = {
    "f1_weighted": f1_score(y_test, y_pred, average="weighted"),
    "f1_macro": f1_score(y_test, y_pred, average="macro"),
    "f1_micro": f1_score(y_test, y_pred, average="micro"),
    "precision_micro": precision_score(y_test, y_pred, average="micro"),
    "precision_macro": precision_score(y_test, y_pred, average="macro"),
    "precision_weighted": precision_score(y_test, y_pred, average="weighted"),
    "recall_micro": recall_score(y_test, y_pred, average="micro"),
    "recall_macro": recall_score(y_test, y_pred, average="macro"),
    "recall_weighted": recall_score(y_test, y_pred, average="weighted"),
    "accuracy": accuracy_score(y_test, y_pred),
}
pprint(metrics)

{'accuracy': 0.9113228288750458,
 'f1_macro': 0.7345492927402577,
 'f1_micro': 0.9113228288750457,
 'f1_weighted': 0.898939949828334,
 'precision_macro': 0.8367945308330393,
 'precision_micro': 0.9113228288750458,
 'precision_weighted': 0.9015650324045857,
 'recall_macro': 0.6886100557244175,
 'recall_micro': 0.9113228288750458,
 'recall_weighted': 0.9113228288750458}
