# Getting Started | SMARTIES Hugging Face Transformer Model 
This notebook demonstrates how to use the SMARTIES pretrained models through the Hugging Face Transformers interface.

## Setup
To use the SMARTIES model weights through the HF Transformers interface, the only required package is Transformers ($\geq$ v4.52.0). This notebook uses the EuroSAT dataset from torchgeo as an example downstream task, which additionally requires the [torchgeo](https://github.com/microsoft/torchgeo) package. To install all the packages required by this notebook, run the following cells:

In [None]:
%pip install transformers torchgeo
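
Since the setup text asks for Transformers v4.52.0 or newer, a quick check of the installed version can catch an outdated environment early. The following is a minimal sketch using only the standard library; the helper names `parse_version` and `meets_minimum` are illustrative and not part of Transformers, and the simple parser ignores pre-release suffixes such as `.dev0`.

```python
from importlib import metadata
import re

MIN_TRANSFORMERS = "4.52.0"  # minimum version stated in the setup text

def parse_version(version: str) -> tuple:
    """Return three numeric components, e.g. '4.52.0.dev0' -> (4, 52, 0)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad so that '4.52' and '4.52.0' compare as equal
    return tuple((parts + [0, 0, 0])[:3])

def meets_minimum(installed: str, minimum: str = MIN_TRANSFORMERS) -> bool:
    return parse_version(installed) >= parse_version(minimum)

try:
    installed = metadata.version("transformers")
    print(f"transformers {installed} meets minimum: {meets_minimum(installed)}")
except metadata.PackageNotFoundError:
    print("transformers is not installed")
```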

In [None]:
import torch
from transformers import AutoModel
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchgeo.datasets import EuroSAT

## Load SMARTIES Weights
With one line of code, you can load the SMARTIES model weights from the Hugging Face Hub. The model is loaded in evaluation mode by default. Two versions of the SMARTIES model are available on the Hub, one with a ViT-B backbone and the other with a ViT-L backbone. Choose the one that fits your needs via the model name:
```python
'gsumbul/SMARTIES-v1-ViT-B' or 'gsumbul/SMARTIES-v1-ViT-L'
```

In [None]:
model = AutoModel.from_pretrained(
    "gsumbul/SMARTIES-v1-ViT-B",
    trust_remote_code=True
)
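
When choosing between the two variants, model size is often the deciding factor. As a rough sketch, the helper below counts the parameters of any `torch.nn.Module`; the name `count_parameters` is illustrative, and a small linear layer stands in for the SMARTIES model so that the block runs without downloading weights.

```python
import torch.nn as nn

def count_parameters(module: nn.Module, trainable_only: bool = False) -> int:
    """Count parameters, optionally restricted to those with requires_grad=True."""
    return sum(
        p.numel() for p in module.parameters()
        if p.requires_grad or not trainable_only
    )

stand_in = nn.Linear(10, 5)  # placeholder, not the SMARTIES architecture
print(count_parameters(stand_in))  # 10 * 5 weights + 5 biases = 55
```

With the model loaded above, `count_parameters(model)` would report its size, and the standard `model.training` attribute can be inspected to confirm that it is in evaluation mode.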

## Prepare Dataloader

In the SMARTIES paper, data preprocessing consists of two steps: first, min-max image normalization using the 1% and 99% percentile values, and then image standardization using mean and standard deviation values (calculated after the first step). This makes SMARTIES robust to differences in data distribution across sensors (e.g., the long-tailed distribution of 12-bit Sentinel-2 images vs. the short-tailed distribution of 8-bit RGB images). However, you can also use SMARTIES with only the widely used mean-std standardization, at the cost of a slight drop in performance.
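
The two preprocessing steps can be sketched on synthetic data as follows. This is an illustration only, not the SMARTIES preprocessing code: the helper names `percentile_bounds` and `two_step_normalize` are hypothetical, the random tensor stands in for real imagery, and all statistics are computed per band over the whole toy batch.

```python
import torch

def percentile_bounds(images: torch.Tensor, low: float = 0.01, high: float = 0.99):
    """Per-band low/high percentiles for a batch shaped (N, C, H, W)."""
    num_bands = images.shape[1]
    flat = images.float().permute(1, 0, 2, 3).reshape(num_bands, -1)
    bounds = torch.quantile(flat, torch.tensor([low, high]), dim=1)  # shape (2, C)
    return bounds[0], bounds[1]

def two_step_normalize(images: torch.Tensor, p_low: torch.Tensor, p_high: torch.Tensor):
    # Step 1: min-max scaling with the percentile bounds, clipped to [0, 1]
    scaled = (images - p_low.view(1, -1, 1, 1)) / (p_high - p_low).view(1, -1, 1, 1)
    scaled = scaled.clamp(0, 1)
    # Step 2: standardization with statistics computed after step 1
    mean = scaled.mean(dim=(0, 2, 3))
    std = scaled.std(dim=(0, 2, 3))
    standardized = (scaled - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
    return standardized, mean, std

torch.manual_seed(0)
toy_images = torch.rand(8, 3, 16, 16) * 4000  # synthetic 3-band batch
p_low, p_high = percentile_bounds(toy_images)
normalized, mean, std = two_step_normalize(toy_images, p_low, p_high)
print(normalized.mean(dim=(0, 2, 3)), normalized.std(dim=(0, 2, 3)))
```

After both steps, each band of the toy batch has approximately zero mean and unit standard deviation, and the clipping in step 1 limits the influence of extreme pixel values.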

In [None]:
class PercentileNormalize(torch.nn.Module):
    def __init__(self, percentile1, percentile99):
        super().__init__()
        self.percentile1 = torch.tensor(percentile1)
        self.percentile99 = torch.tensor(percentile99)
    def forward(self, inpts):
        image, label = inpts['image'], inpts['label']
        return {
            'image': image.sub_(self.percentile1.view(-1, 1, 1)).div_((self.percentile99 - self.percentile1).view(-1, 1, 1)).clamp_(min=0,max=1),
            'label': label
        }

# EuroSAT dataset initialization
# You may need to adjust the root path and download/checksum flags as needed
dataset = lambda split: EuroSAT(
    root="EuroSAT",
    split=split,
    bands=('B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12'),
    transforms=v2.Compose([
        PercentileNormalize(percentile1=[968.0, 697.0, 457.0, 242.0, 203.0, 179.0, 158.0, 131.0, 119.0, 57.0, 22.0, 11.0],
                            percentile99=[2084.0, 2223.0, 2321.0, 2862.0, 2883.0, 3898.0, 4876.0, 4806.0, 5312.0, 1851.0, 4205.0, 3132.0]),
        v2.Normalize(
            mean=[0.34366747736930847, 0.2713719308376312, 0.3102375864982605, 0.2662188410758972, 0.36944717168807983, 0.4893955886363983, 0.4686998128890991, 0.46322500705718994, 0.4768053889274597, 0.3750271201133728, 0.42840376496315, 0.3525424003601074],
            std=[0.21045532822608948, 0.1970716118812561, 0.19605565071105957, 0.21756012737751007, 0.20496250689029694, 0.22960464656352997, 0.22847740352153778, 0.23722581565380096, 0.23559165000915527, 0.22142820060253143, 0.23700211942195892, 0.23857484757900238],
        ),
        v2.Resize((224, 224))
    ]),
    download=True,
)
train_ds, val_ds = dataset('train'), dataset('val')

# You may need to adjust the batch size and number of workers based on your system's capabilities
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=8)
# Shuffling is not needed for validation
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=8)

batch = next(iter(train_loader))
image = batch["image"]
label = batch["label"]
print("Image shape:", image.shape)
print("Label shape:", label.shape)

## Downstream Transfer
SMARTIES enables sensor-agnostic processing of remote sensing (RS) data. A single unified model can therefore be transferred to a diverse set of RS sensors and tasks, with arbitrary combinations of spectral bands. To do so, you need to specify which spectrum-aware projection layers to use for downstream transfer. Each projection layer is associated with a unique range of the electromagnetic spectrum (i.e., a spectral band), defined in the `spectrum_specs.yaml` file with band names as keys (e.g., `aerosol`, `red_edge_1`, `blue_1`). SMARTIES supports a wide range of spectral bands, and you can choose the ones relevant to your application.

To set the bands you want to use, you can either:
1. Use the `sensor_type` parameter for the predefined bands of well-known sensors. It can be set to `S2` for Sentinel-2 L2A image bands, `S1` for Sentinel-1 GRD image bands, or `RGB` for VHR commercial RGB image bands.
   - Example: `model(image, sensor_type='S2')`
2. Pass a list of band names to the `bands` parameter when calling the model. Note that the order of the band names must match the order of the bands in the input image tensor.
   - Example: `model(image, bands=['blue_1', 'green_1', 'red_1', 'near_infrared_1'])`


### Image-Level Feature Extraction

In [None]:
# The dataset bands follow the Sentinel-2 L2A image bands, so the sensor_type parameter can be used
image_feats = model(image, sensor_type='S2')
print("Features shape:", image_feats.shape)

# Or you can specify the bands you want to use directly
# All the bands used
all_bands = ['aerosol', 'blue_1', 'green_2', 'red_2', 'red_edge_1', 'red_edge_2', 'near_infrared_2', 'near_infrared_1', 
         'near_infrared_3', 'short_wave_infrared_1', 'short_wave_infrared_3', 'short_wave_infrared_4']

image_feats = model(image, bands=all_bands)
print("Features shape:", image_feats.shape)

# Only RGB bands used 
rgb_bands = ['blue_1', 'green_2', 'red_2']
image_rgb_feats = model(image[:,[1,2,3]], bands=rgb_bands)
print("Features shape:", image_rgb_feats.shape)
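
The hard-coded channel indices in the RGB example can also be derived from band names. The sketch below assumes the positional pairing between the image channels and the band-name list used in this notebook; the helper `channel_indices` is hypothetical and not part of the SMARTIES interface.

```python
# Band names in the same channel order as the image tensor in this notebook
ALL_BANDS = ['aerosol', 'blue_1', 'green_2', 'red_2', 'red_edge_1', 'red_edge_2',
             'near_infrared_2', 'near_infrared_1', 'near_infrared_3',
             'short_wave_infrared_1', 'short_wave_infrared_3', 'short_wave_infrared_4']

def channel_indices(wanted, available=ALL_BANDS):
    """Return channel positions for `wanted`, preserving the requested order."""
    missing = [name for name in wanted if name not in available]
    if missing:
        raise KeyError(f"Unknown band names: {missing}")
    return [available.index(name) for name in wanted]

rgb_bands = ['blue_1', 'green_2', 'red_2']
print(channel_indices(rgb_bands))  # [1, 2, 3]
```

The result can then index the channel dimension, e.g. `model(image[:, channel_indices(rgb_bands)], bands=rgb_bands)`, which keeps the tensor channels and the `bands` list in the same order.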


### Dense Feature Extraction with All Tokens
For dense tasks (e.g., semantic segmentation) during downstream transfer, the features of all tokens (including the CLS token) can be extracted by setting `all_tokens=True`.

In [None]:
image_feats_dense = model(image, bands=all_bands, all_tokens=True)
print("Features shape (all tokens):", image_feats_dense.shape)
print("Features shape (all tokens without cls token):", image_feats_dense[:,1:,:].shape)
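
For segmentation heads, the patch tokens usually need to be arranged back into a 2D grid. The following is a minimal sketch, assuming one token per image patch, a square patch grid in row-major order, and the CLS token at index 0 (as in the slicing above). The function `tokens_to_feature_map` is illustrative, and a random tensor with an assumed shape stands in for `image_feats_dense`.

```python
import math
import torch

def tokens_to_feature_map(tokens: torch.Tensor) -> torch.Tensor:
    """Reshape (B, 1 + N, D) tokens into a (B, D, H, W) map, dropping CLS."""
    patch_tokens = tokens[:, 1:, :]
    batch, num_patches, dim = patch_tokens.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError("Expected a square patch grid")
    # (B, N, D) -> (B, D, N) -> (B, D, H, W)
    return patch_tokens.transpose(1, 2).reshape(batch, dim, side, side)

dummy_tokens = torch.randn(2, 1 + 14 * 14, 768)  # assumed shape, for illustration
feature_map = tokens_to_feature_map(dummy_tokens)
print(feature_map.shape)  # torch.Size([2, 768, 14, 14])
```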

### Linear Probing for Scene-Classification

This notebook uses linear probing for scene classification on EuroSAT as an example of downstream transfer. We freeze the SMARTIES encoder and train a single FC (linear) layer on top of it: features are extracted for each image, and a linear classifier is trained on these features.

Below, we show how to set up, train, and evaluate linear probing on top of SMARTIES features.

#### Scene Classification Model

In [None]:
import torch.nn as nn

# Use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Freeze SMARTIES encoder
model.eval()
for param in model.parameters():
    param.requires_grad = False

# Move the encoder to the same device as the inputs before the dummy forward pass
model.to(device)

# Get feature dimension from a dummy forward pass
with torch.no_grad():
    dummy_feats = model(image.to(device), sensor_type='S2')
    feat_dim = dummy_feats.shape[-1]

num_classes = len(train_ds.classes)

class SceneClassification(nn.Module):
    def __init__(self, backbone, feat_dim, num_classes):
        super().__init__()
        self.backbone = backbone
        self.head = torch.nn.Sequential(
            torch.nn.BatchNorm1d(feat_dim, affine=False, eps=1e-6), 
            nn.Linear(feat_dim, num_classes)
        )
    def forward(self, x, **kwargs):
        feats = self.backbone(x, **kwargs)
        return self.head(feats)

cls_model = SceneClassification(model, feat_dim, num_classes).to(device)

#### Training Loop for Linear Probing
We train only the linear layer while keeping the backbone frozen.

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(cls_model.head.parameters(), lr=1e-3, betas=(0.9, 0.95))

def train_one_epoch(model, dataloader, optimizer, criterion, device):
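    model.train()
    # Keep the frozen encoder in eval mode; only the head should be in training mode
    model.backbone.eval()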
    total_loss, total_correct, total_samples = 0, 0, 0
    for batch in dataloader:
        imgs, labels = batch['image'].to(device), batch['label'].to(device)
        optimizer.zero_grad()
        outputs = model(imgs, sensor_type='S2')
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        total_correct += (outputs.argmax(1) == labels).sum().item()
        total_samples += imgs.size(0)
    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    return avg_loss, avg_acc

epochs = 10
for epoch in range(epochs):
    loss, acc = train_one_epoch(cls_model, train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {loss:.4f} - Acc: {acc:.4f}")
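
Because the encoder is frozen and the transforms used here are deterministic, the features could also be extracted once and reused across epochs instead of re-running the backbone each time. The sketch below illustrates this idea with a small stand-in backbone and synthetic data; `cache_features` and `ToyDataset` are hypothetical names, and with the real model one would pass `sensor_type='S2'` as a forward keyword argument.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset

class ToyDataset(Dataset):
    """Synthetic samples in the same dict format as the dataset above."""
    def __init__(self, n=10):
        self.images = torch.randn(n, 12, 8, 8)
        self.labels = torch.arange(n) % 3
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return {"image": self.images[idx], "label": self.labels[idx]}

@torch.no_grad()
def cache_features(backbone, loader, device, **forward_kwargs):
    """Run the frozen backbone once and store its outputs with the labels."""
    backbone.eval()
    feats, labels = [], []
    for batch in loader:
        out = backbone(batch["image"].to(device), **forward_kwargs)
        feats.append(out.cpu())
        labels.append(batch["label"])
    return TensorDataset(torch.cat(feats), torch.cat(labels))

# Stand-in for the frozen encoder, not the SMARTIES architecture
stand_in_backbone = nn.Sequential(nn.Flatten(), nn.Linear(12 * 8 * 8, 32))
cached = cache_features(stand_in_backbone, DataLoader(ToyDataset(), batch_size=4), "cpu")
print(len(cached), cached[0][0].shape)
```

A linear head can then be trained directly on `cached`, which trades memory for a much cheaper training loop.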

#### Evaluation
Evaluate the linear probe on the validation set of the dataset.

In [None]:
def evaluate(model, dataloader, device):
    model.eval()
    total_correct, total_samples = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            imgs, labels = batch['image'].to(device), batch['label'].to(device)
            outputs = model(imgs, sensor_type='S2')
            preds = outputs.argmax(1)
            total_correct += (preds == labels).sum().item()
            total_samples += imgs.size(0)
    acc = total_correct / total_samples
    print(f"Evaluation Accuracy: {acc:.4f}")

evaluate(cls_model, val_loader, device)
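
Overall accuracy can hide weak classes, so a per-class breakdown is often a useful complement. The following sketch builds a confusion matrix from predictions and labels with `torch.bincount`; the helper names are illustrative, and the toy tensors stand in for the `preds` and `labels` collected inside `evaluate`.

```python
import torch

def confusion_matrix(preds, labels, num_classes):
    """Rows are true classes, columns are predicted classes."""
    flat_index = labels * num_classes + preds
    counts = torch.bincount(flat_index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

def per_class_accuracy(matrix):
    # clamp avoids division by zero for classes absent from the labels
    return matrix.diag().float() / matrix.sum(dim=1).clamp(min=1).float()

toy_labels = torch.tensor([0, 0, 1, 1, 2, 2])
toy_preds = torch.tensor([0, 1, 1, 1, 2, 0])
matrix = confusion_matrix(toy_preds, toy_labels, num_classes=3)
print(matrix)
print(per_class_accuracy(matrix))  # tensor([0.5000, 1.0000, 0.5000])
```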