In [1]:
import os
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [2]:
import torch
from torch.utils.data import DataLoader

In [3]:
from utils.pytorch_datasets import Ben19Dataset
from utils.pytorch_models import ResNet18
from utils.pytorch_utils import start_cuda, get_classification_report, print_micro_macro
from utils.pytorch_utils import MetricTracker, init_results, update_results

## Parameters

In [4]:
from pathlib import Path

# Root folder that holds the LMDB database and the split CSV files.
data_dir = Path("data/")

# LMDB file containing all BEN19 patches.
lmdb_path = data_dir.joinpath("BigEarth_Serbia_Summer_S2.lmdb/")
# CSV files listing the Serbia patches of each split. Note that the
# validation split used in this notebook is read from "test.csv".
csv_train_path = data_dir.joinpath("train.csv")
csv_val_path = data_dir.joinpath("test.csv")

In [5]:
# Index of the CUDA device. Not referenced by any code visible in this
# notebook (presumably intended for `start_cuda` - TODO confirm).
cuda_no = 1
# Number of samples per batch for both DataLoaders.
batch_size = 128
# Worker processes used by the DataLoaders (0 = load in the main process).
num_workers = 0
# Number of passes over the training set performed by `train`.
epochs = 1

# Number of input channels handed to the ResNet18 constructor.
channels = 10
# Number of classes; used by ResNet18 and by init_results / update_results.
num_classes = 19
# Forwarded to get_classification_report during validation.
dataset_filter = "serbia"

## Initialize Model, Optimizer and Loss Function (Criterion)

## Initialize Train & Val Set and DataLoaders

In [6]:
# Training split: patches come from the shared LMDB file, membership from the
# training CSV. Batches are reshuffled on every pass.
training_set = Ben19Dataset(
    csv_path=csv_train_path, lmdb_path=lmdb_path, img_transform="default"
)
train_loader = DataLoader(
    dataset=training_set,
    shuffle=True,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

# Validation split: same LMDB file, different CSV, fixed sample order.
validation_set = Ben19Dataset(
    csv_path=csv_val_path, lmdb_path=lmdb_path, img_transform="default"
)
val_loader = DataLoader(
    dataset=validation_set,
    shuffle=False,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

## Functions for Training

In [7]:
def train_epoch(model, train_loader, criterion, optimizer, epoch):
    """Run one optimisation pass of ``model`` over ``train_loader``.

    Args:
        model: network to optimise; it is switched to training mode.
        train_loader: iterable of batches. Each batch is a mapping that
            provides at least the keys ``"data"`` and ``"label"``.
        criterion: callable ``criterion(logits, labels)`` returning the loss.
        optimizer: optimizer updating the parameters of ``model``.
        epoch: number of the current epoch. Kept for interface
            compatibility; it is not used inside the function.

    Returns:
        ``loss_tracker.avg`` after all batches were processed (presumably the
        sample-weighted mean loss - confirm against ``MetricTracker``).
    """
    loss_tracker = MetricTracker()
    model.train()

    # Inputs must live on the same device as the model. Look the device up
    # once; a model without parameters leaves the batches untouched.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for batch in tqdm(train_loader, desc="training"):
        data, labels = batch["data"], batch["label"]
        if device is not None:
            data = data.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the number of samples in the batch.
        loss_tracker.update(loss.item(), data.size(0))

    print("Train loss: {:.6f}".format(loss_tracker.avg))
    return loss_tracker.avg

In [8]:
def val_epoch(model, val_loader, dataset_filter):
    """Evaluate ``model`` on ``val_loader`` and build a classification report.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        val_loader: iterable of batches. Each batch is a mapping with the
            keys ``"data"`` (model input) and ``"label"`` (a tensor that
            supports ``.numpy()``).
        dataset_filter: forwarded unchanged to ``get_classification_report``.

    Returns:
        Whatever ``get_classification_report`` returns for the collected
        labels, thresholded predictions and sigmoid probabilities.
    """
    model.eval()
    y_true = []
    predicted_probs = []

    # Inputs must live on the same device as the model. Look the device up
    # once; a model without parameters leaves the batches untouched.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="test"):
            data = batch["data"]
            if device is not None:
                data = data.to(device, non_blocking=True)
            labels = batch["label"].numpy()

            logits = model(data)
            # Bring probabilities back to the CPU before converting to NumPy.
            probs = torch.sigmoid(logits).cpu().numpy()

            # Collect per-sample rows; they are stacked after the loop.
            predicted_probs.extend(probs)
            y_true.extend(labels)

    predicted_probs = np.asarray(predicted_probs)
    # A class counts as predicted when its probability reaches 0.5.
    y_predicted = (predicted_probs >= 0.5).astype(np.float32)

    y_true = np.asarray(y_true)
    report = get_classification_report(
        y_true, y_predicted, predicted_probs, dataset_filter
    )
    return report

In [9]:
def train(
    model, train_loader, val_loader, criterion, optimizer, epochs, dataset_filter
):
    """Alternate training and validation for ``epochs`` epochs.

    Each epoch calls ``train_epoch`` on ``train_loader``, then ``val_epoch``
    on ``val_loader``. The validation report is merged into the running
    results via ``update_results`` and summarised with ``print_micro_macro``.
    The class count is taken from the notebook-level ``num_classes``.

    Returns:
        The results object created by ``init_results`` after it has been
        updated with the report of every epoch.
    """
    history = init_results(num_classes)
    rule = "-" * 10

    for completed in range(epochs):
        current = completed + 1  # epochs are reported 1-based
        print("Epoch {}/{}".format(current, epochs))
        print(rule)

        train_epoch(model, train_loader, criterion, optimizer, current)
        epoch_report = val_epoch(model, val_loader, dataset_filter)

        history = update_results(history, epoch_report, num_classes)
        print_micro_macro(epoch_report)

    return history

### Run to silence warnings.

In [None]:
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=UserWarning)  # Suppress UserWarnings
    warnings.filterwarnings("ignore", category=FutureWarning)  # Suppress FutureWarnings

### Select Model

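Run one of the following cells to train the corresponding model. Each cell builds the model, its optimizer and the loss function, and stores the training history in a model-specific `*_results` variable.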

In [11]:
from models.poolformer import poolformer_s12

# PoolFormer-S12. Channel and class counts come from the parameter cell so the
# model always agrees with `init_results(num_classes)` used inside `train`.
model = poolformer_s12(in_chans=channels, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
s12_results = train(
    model, train_loader, val_loader, criterion, optimizer, epochs, dataset_filter
)

Epoch 1/1
----------


training:   0%|          | 0/61 [00:00<?, ?it/s]

In [None]:
from models.poolformer import poolformer_s24

# PoolFormer-S24. Channel and class counts come from the parameter cell so the
# model always agrees with `init_results(num_classes)` used inside `train`.
model = poolformer_s24(in_chans=channels, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
s24_results = train(
    model, train_loader, val_loader, criterion, optimizer, epochs, dataset_filter
)

In [None]:
from models.poolformer import poolformer_s36

# PoolFormer-S36. Channel and class counts come from the parameter cell so the
# model always agrees with `init_results(num_classes)` used inside `train`.
model = poolformer_s36(in_chans=channels, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
s36_results = train(
    model, train_loader, val_loader, criterion, optimizer, epochs, dataset_filter
)

In [None]:
from models.poolformer import poolformer_m36

# PoolFormer-M36. Channel and class counts come from the parameter cell so the
# model always agrees with `init_results(num_classes)` used inside `train`.
model = poolformer_m36(in_chans=channels, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
m36_results = train(
    model, train_loader, val_loader, criterion, optimizer, epochs, dataset_filter
)

In [None]:
from models.poolformer import poolformer_m48

# PoolFormer-M48. Channel and class counts come from the parameter cell so the
# model always agrees with `init_results(num_classes)` used inside `train`.
model = poolformer_m48(in_chans=channels, num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
m48_results = train(
    model, train_loader, val_loader, criterion, optimizer, epochs, dataset_filter
)

In [None]:
# ResNet18 baseline, configured from the parameter cell and trained with the
# same loss and optimizer settings as the PoolFormer variants.
model = ResNet18(num_cls=num_classes, channels=channels, pretrained=True)
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
resnet_results = train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
    dataset_filter=dataset_filter,
)



Epoch 1/10
----------


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.211027


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.8068 | recall: 0.6508 | f1-score: 0.7204 | support: 9725 | mAP: 0.8245
macro     precision: 0.4929 | recall: 0.3141 | f1-score: 0.3433 | support: 9725 | mAP: 0.4195

Epoch 2/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.158587


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7856 | recall: 0.7189 | f1-score: 0.7508 | support: 9725 | mAP: 0.8518
macro     precision: 0.4721 | recall: 0.3554 | f1-score: 0.3901 | support: 9725 | mAP: 0.4480

Epoch 3/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.146699


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.8218 | recall: 0.6746 | f1-score: 0.7409 | support: 9725 | mAP: 0.8460
macro     precision: 0.5225 | recall: 0.3449 | f1-score: 0.3849 | support: 9725 | mAP: 0.4568

Epoch 4/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.134175


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7814 | recall: 0.7139 | f1-score: 0.7462 | support: 9725 | mAP: 0.8435
macro     precision: 0.4981 | recall: 0.3815 | f1-score: 0.4109 | support: 9725 | mAP: 0.4555

Epoch 5/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.120330


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7695 | recall: 0.6157 | f1-score: 0.6841 | support: 9725 | mAP: 0.8090
macro     precision: 0.5094 | recall: 0.3200 | f1-score: 0.3464 | support: 9725 | mAP: 0.4488

Epoch 6/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.104891


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7879 | recall: 0.7171 | f1-score: 0.7509 | support: 9725 | mAP: 0.8478
macro     precision: 0.5117 | recall: 0.3846 | f1-score: 0.4119 | support: 9725 | mAP: 0.4599

Epoch 7/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.086484


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7429 | recall: 0.6989 | f1-score: 0.7203 | support: 9725 | mAP: 0.8170
macro     precision: 0.4454 | recall: 0.4016 | f1-score: 0.4013 | support: 9725 | mAP: 0.4544

Epoch 8/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.065308


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7641 | recall: 0.7488 | f1-score: 0.7564 | support: 9725 | mAP: 0.8471
macro     precision: 0.4970 | recall: 0.3971 | f1-score: 0.4226 | support: 9725 | mAP: 0.4583

Epoch 9/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.045216


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7496 | recall: 0.7174 | f1-score: 0.7331 | support: 9725 | mAP: 0.8252
macro     precision: 0.4807 | recall: 0.3905 | f1-score: 0.4116 | support: 9725 | mAP: 0.4530

Epoch 10/10
----------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/61 [00:00<?, ?it/s]

Train loss: 0.028988


test:   0%|          | 0/28 [00:00<?, ?it/s]

micro     precision: 0.7527 | recall: 0.7353 | f1-score: 0.7439 | support: 9725 | mAP: 0.8275
macro     precision: 0.4938 | recall: 0.3858 | f1-score: 0.4125 | support: 9725 | mAP: 0.4626



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# List the keys of the results collected for the PoolFormer-S12 run.
print(s12_results.keys())

dict_keys(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', 'micro avg', 'macro avg', 'ap_mic', 'ap_mac'])


In [None]:
import pickle

# Persist the result dictionaries as pickle files in the working directory.
# Only one model cell has to be run, so either variable may be missing; skip
# it with a message instead of aborting the cell with a NameError.
for results_name in ("s12_results", "resnet_results"):
    if results_name not in globals():
        print("{} is not defined - skipping".format(results_name))
        continue
    with open("{}.pkl".format(results_name), "wb") as f:
        pickle.dump(globals()[results_name], f)

In [None]:
# Round-trip check: the pickle written above must reproduce the in-memory
# results. Only unpickle files you created yourself; pickle can execute
# arbitrary code when loading untrusted data.
with open("s12_results.pkl", "rb") as f:
    loaded_dict = pickle.load(f)

assert loaded_dict == s12_results