# AnySat Guide

#### Simple Usage

AnySat is available through PyTorch Hub.

```python
import torch

# force_reload=True re-downloads the repository on every call; drop it to reuse the cached copy
model = torch.hub.load('gastruc/anysat', 'anysat', pretrained=True, force_reload=True, flash_attn=False)
```


#### Local usage

Repo installation:

```bash
git clone https://github.com/gastruc/AnySat.git
cd AnySat
pip install -e .
```



```python
# Run from the root of the cloned repository, where hubconf.py lives
from hubconf import AnySat

model = AnySat.from_pretrained('base', flash_attn=False)  # set flash_attn=True if the flash-attn package is installed
# device = "cuda"  # uncomment to run on GPU (the default is CPU)
```

#### Reproducing Experiments

All experiments are available in the [experiments](https://github.com/gastruc/AnySat/tree/main/experiments) folder.

To reproduce the AnySat environment, run:

```bash
# clone project
git clone https://github.com/gastruc/anysat
cd anysat

# [OPTIONAL] create conda environment
conda create -n anysat python=3.9
conda activate anysat

# install requirements
pip install -r requirements.txt

# Create data folder where you can put your datasets
mkdir data
# Create logs folder
mkdir logs
```

Then run the experiment you want:

```bash
# Run AnySat pretraining on GeoPlex
python src/train.py exp=GeoPlex_AnySAT

# Run AnySat finetuning on BraDD-S1TS
python src/train.py exp=BraDD_AnySAT_FT

# Run AnySat linear probing on BraDD-S1TS
python src/train.py exp=BraDD_AnySAT_LP
```

Any parameter can be overridden from the command line through Hydra. For example, to train the Small version of AnySat on the GeoPlex datasets, run:

```bash
python src/train.py exp=GeoPlex_AnySAT model=Any_Small_multi
```



## Inference with AnySat

#### Input data format

We use an example based on the TreeSatAI-TS dataset, with random tensors standing in for real data.

To get features for one observation or a batch of observations, pass the model a dictionary whose keys are taken from the following list:

| Key          | Type               | Tensor size | Channels                                     | Resolution |
|--------------|--------------------|-------------|----------------------------------------------|------------|
| aerial       | Single-date tensor | Bx4xHxW     | RGB, NIR                                     | 0.2m       |
| aerial-flair | Single-date tensor | Bx5xHxW     | RGB, NIR, Elevation                          | 0.2m       |
| spot         | Single-date tensor | Bx3xHxW     | RGB                                          | 1m         |
| naip         | Single-date tensor | Bx4xHxW     | RGB, NIR                                     | 1.25m      |
| s2           | Time series tensor | BxTx10xHxW  | B2, B3, B4, B5, B6, B7, B8, B8a, B11, B12    | 10m        |
| s1-asc       | Time series tensor | BxTx2xHxW   | VV, VH                                       | 10m        |
| s1           | Time series tensor | BxTx3xHxW   | VV, VH, Ratio                                | 10m        |
| alos         | Time series tensor | BxTx3xHxW   | HH, HV, Ratio                                | 30m        |
| l7           | Time series tensor | BxTx6xHxW   | B1, B2, B3, B4, B5, B7                       | 30m        |
| l8           | Time series tensor | BxTx11xHxW  | B8, B1, B2, B3, B4, B5, B6, B7, B9, B10, B11 | 10m        |
| modis        | Time series tensor | BxTx7xHxW   | B1, B2, B3, B4, B5, B6, B7                   | 250m       |
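
The table can double as a quick pre-flight check before calling the model. The sketch below is a hypothetical helper, not part of AnySat: it transcribes the table into `MODALITIES` and checks plain shape tuples (for a tensor `t`, use `tuple(t.shape)`) for the number of dimensions and channels listed above. It assumes the layouts in the table are exhaustive and does not check the `_dates` entries.

```python
# Hypothetical helper (not part of AnySat): the table above as
# {key: (is_time_series, number_of_channels)}.
MODALITIES = {
    "aerial": (False, 4),
    "aerial-flair": (False, 5),
    "spot": (False, 3),
    "naip": (False, 4),
    "s2": (True, 10),
    "s1-asc": (True, 2),
    "s1": (True, 3),
    "alos": (True, 3),
    "l7": (True, 6),
    "l8": (True, 11),
    "modis": (True, 7),
}


def check_input_shapes(shapes):
    """Return a list of problems in a {key: shape_tuple} mapping (empty if none).

    Single-date keys must be (B, C, H, W) and time-series keys (B, T, C, H, W).
    "{key}_dates" entries are skipped.
    """
    problems = []
    for key, shape in shapes.items():
        if key.endswith("_dates"):
            continue
        if key not in MODALITIES:
            problems.append(f"unknown key: {key}")
            continue
        temporal, channels = MODALITIES[key]
        expected_rank = 5 if temporal else 4
        if len(shape) != expected_rank:
            problems.append(f"{key}: expected {expected_rank} dims, got {len(shape)}")
            continue
        channel_axis = 2 if temporal else 1  # C follows T for time series
        if shape[channel_axis] != channels:
            problems.append(f"{key}: expected {channels} channels, got {shape[channel_axis]}")
    return problems


# Shapes of the example batch used below.
example = {
    "aerial": (2, 4, 300, 300),
    "s2": (2, 4, 10, 6, 6),
    "s2_dates": (2, 4),
    "s1": (2, 4, 3, 6, 6),
    "s1_dates": (2, 4),
}
print(check_input_shapes(example))  # []
```

With real tensors, `check_input_shapes({k: tuple(v.shape) for k, v in data.items()})` runs the same check on the dictionary passed to the model.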

```python
import torch

# Random tensors stand in for a real sample; replace them with your own data.
data = {
    "aerial": torch.randn(2, 4, 300, 300),  # batch size 2, 4 channels, 300x300 pixels
    "s2": torch.randn(2, 4, 10, 6, 6),  # batch size 2, 4 dates, 10 channels, 6x6 pixels
    "s2_dates": torch.randint(0, 367, (2, 4)),  # day of year of each of the 4 dates
    "s1": torch.randn(2, 4, 3, 6, 6),  # batch size 2, 4 dates, 3 channels, 6x6 pixels
    "s1_dates": torch.randint(0, 367, (2, 4)),
}
```
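
Since the patch size has to divide the spatial cover of every modality, it helps to confirm that all modalities describe the same footprint. In the example above, 300 aerial pixels at 0.2 m and 6 Sentinel pixels at 10 m both span 60 m. A minimal sketch, assuming the footprint is H × resolution and W × resolution with resolutions from the table; `RESOLUTION_CM` and `spatial_extent_m` are hypothetical names, not AnySat functions:

```python
# Hypothetical helper (not part of AnySat). Resolutions from the table above,
# stored in centimetres so that footprints compare exactly as integers.
RESOLUTION_CM = {
    "aerial": 20, "aerial-flair": 20, "spot": 100, "naip": 125,
    "s2": 1000, "s1-asc": 1000, "s1": 1000, "alos": 3000,
    "l7": 3000, "l8": 1000, "modis": 25000,
}


def spatial_extent_m(shapes):
    """Return the (height, width) footprint in metres shared by all modalities.

    `shapes` maps keys to shape tuples whose last two entries are (H, W);
    "{key}_dates" entries are ignored. Raises ValueError if footprints differ.
    """
    extents = {
        key: (shape[-2] * RESOLUTION_CM[key], shape[-1] * RESOLUTION_CM[key])
        for key, shape in shapes.items()
        if not key.endswith("_dates")
    }
    if len(set(extents.values())) != 1:
        raise ValueError(f"modalities do not share one footprint (in cm): {extents}")
    height_cm, width_cm = next(iter(extents.values()))
    return height_cm / 100, width_cm / 100


# Shapes of the `data` dictionary above.
shapes = {
    "aerial": (2, 4, 300, 300),
    "s2": (2, 4, 10, 6, 6),
    "s2_dates": (2, 4),
    "s1": (2, 4, 3, 6, 6),
    "s1_dates": (2, 4),
}
print(spatial_extent_m(shapes))  # (60.0, 60.0)
```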

Each time-series key requires a matching `{key}_dates` tensor (for example `s2_dates`) of size BxT, whose integer values give the day of the year of each acquisition.
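
If acquisition dates are available as calendar dates, the day-of-year values can be derived with the standard library. A minimal sketch, assuming a B x T nested list of `datetime.date` objects with the same T for every sample; `days_of_year` and the dates below are hypothetical:

```python
from datetime import date


def days_of_year(acquisitions):
    """Convert a B x T nested list of datetime.date objects to day-of-year ints.

    Numbering follows Python's tm_yday (1 January = 1, 31 December = 365 or 366).
    Every sample must have the same number of dates T.
    """
    if len({len(sample) for sample in acquisitions}) > 1:
        raise ValueError("every sample needs the same number of dates (T)")
    return [[d.timetuple().tm_yday for d in sample] for sample in acquisitions]


# Hypothetical acquisition dates for a batch of 2 samples with 4 dates each.
s2_acquisitions = [
    [date(2020, 1, 1), date(2020, 3, 15), date(2020, 7, 4), date(2020, 12, 31)],
    [date(2021, 1, 10), date(2021, 4, 1), date(2021, 8, 20), date(2021, 11, 30)],
]
print(days_of_year(s2_acquisitions))  # [[1, 75, 186, 366], [10, 91, 232, 334]]
```

`torch.tensor(days_of_year(s2_acquisitions))` would then give an integer tensor of shape BxT to use as `s2_dates`. This guide does not state whether 1 January is day 0 or day 1 (the random example draws values in 0–366), so check the convention against your training data.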

Before calling the model, choose:
- **Patch size** (in m, must be a multiple of 10): adjust it to the scale of your tiles and to your GPU memory. As a rule of thumb, avoid more than 1024 patches per tile.
- **Output type**, one of:
  - `'tile'`: a single vector per tile
  - `'patch'`: one vector per patch
  - `'dense'`: one vector per sub-patch; these vectors are twice as long (1536 instead of 768 in the examples below)
  - `'all'`: one vector per patch, with the class token in first position

The sub-patches are `1x1` pixels for time series and `10x10` pixels for VHR images. With `output='dense'`, use `output_modality` to choose which modality's sub-patch grid is returned. The patch size must divide the spatial extent of every modality and be a multiple of 10.
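
The patch-size rules above (a multiple of 10, divides the tile extent, ideally no more than 1024 patches per tile) can be enumerated for a given tile. A minimal sketch, assuming square tiles; `valid_patch_sizes` is a hypothetical helper, and the 1024 limit is this guide's rule of thumb rather than a hard constraint enforced by the model:

```python
def valid_patch_sizes(extent_m, max_patches=1024):
    """List patch sizes (in metres) allowed for a square tile of side `extent_m`.

    A size is kept if it is a multiple of 10, divides the tile side exactly,
    and yields at most `max_patches` patches per tile.
    """
    sizes = []
    for size in range(10, int(extent_m) + 1, 10):  # multiples of 10 only
        if extent_m % size == 0 and (extent_m // size) ** 2 <= max_patches:
            sizes.append(size)
    return sizes


print(valid_patch_sizes(60))   # [10, 20, 30, 60]
print(valid_patch_sizes(640))  # [20, 40, 80, 160, 320, 640]
```

For the 60 m example tile, the patch sizes 10, 20 and 60 used below all appear in this list. For a 640 m tile, 10 m patches would give 64 × 64 = 4096 patches, so the smallest size kept is 20 m (exactly 1024 patches).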

If `output='dense'` is used without `output_modality`, the first modality in the data dictionary is used.
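
The example outputs below follow a simple pattern, which this sketch reproduces for a square tile. It is inferred from those outputs only: 768 is the feature size of the base model in these runs, `'all'` is left out because its shape is not shown, and `expected_output_shape` is a hypothetical helper. The sub-patch size is 10 pixels × 0.2 m = 2 m for aerial images and 1 pixel × 10 m = 10 m for Sentinel time series.

```python
EMBED_DIM = 768  # feature size seen in the example outputs (base model)


def expected_output_shape(batch, extent_m, output, patch_size_m=None, subpatch_m=None):
    """Predict the feature shape for a square tile of side `extent_m` metres.

    'tile'  -> (B, 768)
    'patch' -> (B, N, N, 768),  N = extent / patch size
    'dense' -> (B, M, M, 1536), M = extent / sub-patch size of output_modality
    """
    if output == "tile":
        return (batch, EMBED_DIM)
    if output == "patch":
        if patch_size_m is None:
            raise ValueError("output='patch' needs patch_size_m")
        n = int(extent_m // patch_size_m)
        return (batch, n, n, EMBED_DIM)
    if output == "dense":
        if subpatch_m is None:
            raise ValueError("output='dense' needs subpatch_m")
        m = int(extent_m // subpatch_m)
        return (batch, m, m, 2 * EMBED_DIM)  # dense vectors are twice as long
    raise ValueError(f"unsupported output type: {output!r}")


print(expected_output_shape(2, 60, "tile"))                    # (2, 768)
print(expected_output_shape(2, 60, "patch", patch_size_m=20))  # (2, 3, 3, 768)
print(expected_output_shape(2, 60, "dense", subpatch_m=2))     # (2, 30, 30, 1536)
print(expected_output_shape(2, 60, "dense", subpatch_m=10))    # (2, 6, 6, 1536)
```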


```python
features = model(data, patch_size=10, output='tile')
print(features.shape)  # torch.Size([2, 768])
```


```python
features = model(data, patch_size=10, output='patch')
print(features.shape)  # torch.Size([2, 6, 6, 768])
```


```python
features = model(data, patch_size=20, output='patch')
print(features.shape)  # torch.Size([2, 3, 3, 768])
```


```python
features = model(data, patch_size=60, output='patch')
print(features.shape)  # torch.Size([2, 1, 1, 768])
```


```python
features = model(data, patch_size=20, output='dense', output_modality="aerial")
print(features.shape)  # torch.Size([2, 30, 30, 1536])
```


```python
features = model(data, patch_size=20, output='dense', output_modality="s2")
print(features.shape)  # torch.Size([2, 6, 6, 1536])
```
