In [1]:
import xarray
from pyproj import Transformer
import numpy as np
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, TensorDataset
import presto
import os 
import warnings
from sklearn.metrics import classification_report, accuracy_score
from single_file_presto import Presto, FinetuningHead, PrestoFineTuningModel
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore")
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# Names for the 24 values read per pixel; passed to Presto as `s2_bands` in process_images.
# Per the original note, entries after 'B12' are not taken as inputs.
BANDS = (
    ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
    + ['QA10', 'QA20', 'NDVI']
    + ['water', 'trees', 'grass', 'flooded_vegetation', 'crops',
       'shrub_and_scrub', 'built', 'bare', 'snow_and_ice']
    + ['label']
)

In [3]:
def process_images(folder_path):
    """Turn each labelled GeoTIFF in ``folder_path`` into one Presto input sample.

    Only files ending in ``.tif`` are read; names starting with ``.DS_Store`` or
    ``nurpur`` are skipped. The class label is the value for the first of these
    keywords found in the filename: ``maize`` -> 0, ``nurpur`` -> 3, ``rice`` -> 1,
    ``sug`` -> 2. Files matching none of them are ignored. The month is parsed
    from the one or two digits immediately before ``.tif`` and stored zero-based.

    For every kept file the pixel at index (0, 0) is read, scaled by 10000,
    truncated to integers and handed to ``presto.construct_single_presto_input``
    as a single-timestep series with ``BANDS`` as the band names. The last
    channel of the returned array is then overwritten with an NDVI computed from
    channels 8 and 4 of that array (0 where their sum is 0).

    Args:
        folder_path: Directory containing the GeoTIFF files.

    Returns:
        A 7-tuple ``(arrays, masks, dynamic_worlds, latlons, labels, image_names,
        months)``. Every entry is a tensor with one row per kept sample, except
        ``image_names``, which is a list of filenames in the same order.
        ``latlons`` rows are ``[lat, lon]``.

    Raises:
        RuntimeError: If no labelled ``.tif`` file is found in ``folder_path``.
    """
    # Checked in this order; the first keyword contained in the filename wins.
    keyword_labels = (('maize', 0), ('nurpur', 3), ('rice', 1), ('sug', 2))
    arrays, masks, latlons, image_names, labels, dynamic_worlds, month = [], [], [], [], [], [], []
    shape_printed = False

    # os.listdir() returns entries in arbitrary order; sorting keeps the sample
    # order (and everything indexed by it downstream) reproducible between runs.
    for filename in sorted(os.listdir(folder_path)):
        if filename.startswith('.DS_Store') or filename.startswith('nurpur'):
            continue
        if not filename.endswith('.tif'):  # Process only TIFF files
            continue

        label = next((value for keyword, value in keyword_labels if keyword in filename), None)
        if label is None:
            # Unlabelled rasters were never kept, so skip them before any file I/O.
            continue

        # Month number is the one or two digits right before '.tif'; store it zero-based.
        if filename[-6].isdigit():
            month_index = int(filename[-6:-4]) - 1
        else:
            month_index = int(filename[-5]) - 1

        full_path = os.path.join(folder_path, filename)
        # NOTE(review): xarray.open_rasterio is deprecated upstream in favour of
        # rioxarray; kept here because the notebook does not import rioxarray.
        with xarray.open_rasterio(full_path) as tif_file:  # closes the raster when done
            crs = tif_file.crs.split("=")[-1]
            transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)

            # Only the first pixel of each image is sampled (both ranges have length 1).
            for x_idx in range(0, 1):
                for y_idx in range(0, 1):
                    x_coord, y_coord = tif_file.x[x_idx], tif_file.y[y_idx]
                    lon, lat = transformer.transform(x_coord, y_coord)

                    # open_rasterio arrays are laid out (band, y, x), so the row
                    # index comes before the column index.
                    pixel_values = tif_file.values[:, y_idx, x_idx]
                    s2_data_for_pixel = torch.from_numpy((pixel_values * 10000).astype(int)).float()
                    if not shape_printed:
                        # One-off sanity check of the per-pixel vector length.
                        print(s2_data_for_pixel.shape)
                        shape_printed = True

                    s2_data_with_time_dimension = s2_data_for_pixel.unsqueeze(0)
                    eo, mask, dynamic_world = presto.construct_single_presto_input(
                        s2=s2_data_with_time_dimension, s2_bands=BANDS
                    )

                    # Overwrite the last channel with NDVI from channels 8 and 4,
                    # using 0 instead of nan/inf when the denominator is zero.
                    nir, red = eo[0][8], eo[0][4]
                    denominator = nir + red
                    eo[0][-1] = torch.where(
                        denominator != 0,
                        (nir - red) / denominator,
                        torch.zeros_like(denominator),
                    )

                    latlons.append(torch.tensor([lat, lon]))
                    arrays.append(eo)
                    masks.append(mask)
                    dynamic_worlds.append(dynamic_world)
                    labels.append(label)
                    image_names.append(filename)
                    month.append(month_index)

    if not arrays:
        # torch.stack on an empty list fails with an unhelpful message; be explicit.
        raise RuntimeError(f"No labelled .tif files found in {folder_path!r}")

    return (torch.stack(arrays, axis=0),
            torch.stack(masks, axis=0),
            torch.stack(dynamic_worlds, axis=0),
            torch.stack(latlons, axis=0),
            torch.tensor(labels),
            image_names,
            torch.tensor(month)
            )

In [4]:
# Build the 7-tuple of stacked Presto inputs from the GeoTIFFs in the "sample data" folder.
sample_data = process_images("sample data")

torch.Size([24])


In [14]:
# Inspect the first sample: its input array (tuple index 0) and its mask (tuple index 1).
print(sample_data[0][0])
print(sample_data[1][0])

tensor([[ 1.0000,  1.0000,  0.2247,  0.2133,  0.1955,  0.2474,  0.3085,  0.3376,
          0.3279,  0.3498,  0.2949,  0.2474, -7.7757,  0.0000,  0.0000,  0.0000,
          0.2530]])
tensor([[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.]])


For each pixel, the `x` values are ordered as [S1 band, S1 band, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, ERA5, ERA5, SRTM, SRTM, NDVI (which we calculated)]. The mask has one entry per `x` index: 1 where that index of `x` has no data and 0 where it does.

In [10]:
# In-place edit of the dynamic-world tensor (tuple index 2): every entry equal to 9 becomes 4.
# In the class list given later in this notebook, 9 is "unknown" and 4 is "crops".
sample_data[2][sample_data[2] == 9] = 4 #dynamic world for all is 4

In [11]:
# Wrap the sample tensors in a dataset whose field order is
# (x, mask, dynamic world, latlons, month, labels); loops over dl_1 unpack in that order.
sample_dataset = TensorDataset(
    sample_data[0].float(),   # x
    sample_data[1].bool(),    # mask
    sample_data[2].long(),    # dynamic world
    sample_data[3].float(),   # latlons
    sample_data[6],           # month
    sample_data[4],           # labels
)
# Batches are served in stored order (no shuffling), 146 samples at a time.
dl_1 = DataLoader(sample_dataset, batch_size=146, shuffle=False)

This is the part I need help with 

So the decoder output is a tuple. From what I understand, the first part corresponds to the 'x' input (which has the spectral band information) and the second corresponds to Dynamic World, whose classes are:
        0: "water",
        1: "trees",
        2: "grass",
        3: "flooded_vegetation",
        4: "crops",
        5: "shrub_and_scrub",
        6: "built",
        7: "bare",
        8: "snow_and_ice"
        and 9 if unknown 

I am not sure whether the two parts of the tuple correspond to what I think they do, which is what I want to confirm. Most of my understanding came from presto/presto.py. If I am correct, how can I proceed with the loss function on both the x input and dynamic world? Should I stack them or something like that?

In [12]:
# Load the pre-trained model
pretrained_model = presto.Presto.load_pretrained()

# Make both the encoder and the decoder trainable, and put the model in training mode
# explicitly so mode-dependent layers behave as intended while fitting.
pretrained_model.encoder.requires_grad_(True)
pretrained_model.decoder.requires_grad_(True)
pretrained_model.train()

# Optimizer and reconstruction loss
optimizer = torch.optim.Adam(pretrained_model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
num_epochs = 10

for epoch in range(num_epochs):
    epoch_loss, num_batches = 0.0, 0
    for batch in dl_1:
        # Field order matches the TensorDataset behind dl_1.
        x, mask, dynamic_world, latlons, month, labels = batch
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so registered hooks also run.
        decoder_output = pretrained_model(x, dynamic_world, latlons, mask, month)
        # NOTE(review): only the first tuple element is used in the loss. If the
        # second element really holds per-class Dynamic World scores (to be
        # confirmed against presto/presto.py), the two outputs should not be
        # stacked; a cross-entropy term on decoder_output[1] vs. dynamic_world
        # could instead be added to this MSE term.
        loss = loss_fn(decoder_output[0], x)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean over all batches instead of whichever batch happened to be
    # last; max() keeps this defined even if the loader yields nothing.
    mean_epoch_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_epoch_loss}")

# Save the fine-tuned model
torch.save(pretrained_model.state_dict(), "pretrained_finetuned_model.pth")

Epoch 1/10, Loss: 0.0067171091213822365
Epoch 2/10, Loss: 0.005097079556435347
Epoch 3/10, Loss: 0.0026093136984854937
Epoch 4/10, Loss: 0.0018327711150050163
Epoch 5/10, Loss: 0.0014371551806107163
Epoch 6/10, Loss: 0.0011533841025084257
Epoch 7/10, Loss: 0.0009246274130418897
Epoch 8/10, Loss: 0.0007092871237546206
Epoch 9/10, Loss: 0.0005239146994426847
Epoch 10/10, Loss: 0.0004037352337036282


After this is separate stuff that has been figured out

## Fine tuning and prediction

In [134]:
# Start from the pre-trained Presto weights; only its encoder is reused below.
presto_model = presto.Presto.load_pretrained()
# Classification head (regression=False) with one output per class; three classes are used here.
num_classes = 3
head = FinetuningHead(hidden_size=presto_model.encoder.embedding_size, num_outputs=num_classes, regression = False)
# Combine the pre-trained encoder with the new head into a single trainable model.
fine_tuning_model = PrestoFineTuningModel(encoder=presto_model.encoder, head=head)
# Loss function for multi-class classification
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the combined model (encoder and head).
learning_rate = 0.001
optimizer = optim.Adam(fine_tuning_model.parameters(), lr=learning_rate)

In [None]:
# Supervised fine-tuning with best-checkpoint selection on validation loss.
# NOTE(review): train_dl and val_dl are not defined in the visible part of this
# notebook; they are assumed to yield (x, mask, dynamic_world, latlons, month, labels).
num_epochs = 20
lowest_val_loss = float('inf')  # Lowest validation loss seen so far
best_checkpoint_saved = False   # True once 'best_model.pth' has been written in this run

for epoch in range(num_epochs):
    # Training pass
    fine_tuning_model.train()
    total_loss = 0.0

    for x, mask, dynamic_world, latlons, month, labels in train_dl:
        optimizer.zero_grad()
        outputs = fine_tuning_model(x, dynamic_world, latlons, mask, month)
        # Cast to long exactly as the validation pass does, so both passes agree.
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Average training loss for the epoch
    avg_train_loss = total_loss / len(train_dl)

    # Validation pass (no gradients needed)
    fine_tuning_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x_val, mask_val, dynamic_world_val, latlons_val, month_val, labels_val in val_dl:
            labels_val = labels_val.long()
            val_outputs = fine_tuning_model(x_val, dynamic_world_val, latlons_val, mask_val, month_val)
            val_loss += criterion(val_outputs, labels_val).item()

    # Average validation loss for the epoch
    avg_val_loss = val_loss / len(val_dl)

    # Keep a checkpoint of the best epoch so far
    if avg_val_loss < lowest_val_loss:
        lowest_val_loss = avg_val_loss
        torch.save(fine_tuning_model.state_dict(), 'best_model.pth')
        best_checkpoint_saved = True
        print("Best model saved with validation loss:", lowest_val_loss)

    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}")

# After the loop, restore the best checkpoint. Doing this inside the loop would
# discard every epoch that did not improve validation loss while the optimizer
# state kept advancing. Skip it if nothing was saved (e.g. loss was never finite).
if best_checkpoint_saved:
    fine_tuning_model.load_state_dict(torch.load('best_model.pth'))

In [145]:
# Run every batch of dl_1 through the fine-tuned encoder (inference only) and
# collect the outputs as NumPy arrays, one per batch, then join them row-wise.
# NOTE(review): dl_1 as defined earlier in this notebook wraps sample_data — confirm
# it holds the training split before these features are paired with training labels.
fine_tuning_model.eval()
features_list = []
with torch.no_grad():  # no gradient bookkeeping is needed when only extracting features
    for x, mask, dw, latlons, month, _ in tqdm(dl_1):
        encoder_output = fine_tuning_model.encoder(
            x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
        )
        encodings = encoder_output.cpu().numpy()
        features_list.append(encodings)
features_np = np.concatenate(features_list)

100%|██████████| 90/90 [00:01<00:00, 63.27it/s]


In [146]:
# Random forest trained on the encoder features; class_weight="balanced" and a fixed seed.
model = RandomForestClassifier(class_weight="balanced", random_state=42)
# Rows of features_np follow the order of dl_1, while labels are read from train_data[4].
# NOTE(review): the two must have the same length and order — confirm dl_1 was built from train_data.
model.fit(features_np, train_data[4].numpy()) #here the labels come in

In [147]:
# Test-set loader. Field order is (x, mask, dynamic world, latlons, month);
# labels are left out here and read separately from test_data[4] when scoring.
test_dataset = TensorDataset(
    test_data[0].float(),   # x
    test_data[1].bool(),    # mask
    test_data[2].long(),    # dynamic world
    test_data[3].float(),   # latlons
    test_data[6],           # month
)
# Batches are served in stored order (no shuffling), 146 samples at a time.
dl_2 = DataLoader(test_dataset, batch_size=146, shuffle=False)

In [148]:
# Encode each test batch with the fine-tuned encoder and let the random forest
# turn those features into class probabilities (one array per batch).
test_preds = []
# Switching to eval mode is loop-invariant, so do it once rather than per batch.
fine_tuning_model.eval()
with torch.no_grad():  # inference only
    for x, mask, dw, latlons, month in tqdm(dl_2):
        encodings = (
            fine_tuning_model.encoder(
                x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
            )
            .cpu()
            .numpy()
        )
        test_preds.append(model.predict_proba(encodings))

100%|██████████| 10/10 [00:00<00:00, 45.94it/s]


In [149]:
# Number of pixels sampled from each image.
pix_per_image = 1

# Join the per-batch probability arrays along the sample axis (the class axis stays last),
# then regroup the rows as (image, pixel within image, class).
test_preds_np = np.concatenate(test_preds, axis=0)
num_test_images = int(len(test_preds_np) / pix_per_image)
num_output_classes = test_preds_np.shape[-1]
test_preds_np = np.reshape(test_preds_np, (num_test_images, pix_per_image, num_output_classes))

In [150]:
# Hard label per pixel = class with the highest probability; the prediction for an
# image is the most frequent label among its pixels.
per_pixel_labels = np.argmax(test_preds_np, axis=-1)
test_preds_np_argmax = stats.mode(per_pixel_labels, axis=1, keepdims=False)[0]

# Ground truth: group the labels by image and keep the first pixel's label of each.
num_target_images = int(len(test_data[4]) / pix_per_image)
target = np.reshape(test_data[4], (num_target_images, pix_per_image))[:, 0]

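Class 0 = Maize
Class 1 = Rice
Class 2 = Sugarcane
(Nurpur Rice is label 3 in `process_images`, but files whose names start with 'nurpur' are skipped there, so only three classes appear in the results below.)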

In [152]:
# Score the image-level predictions against the ground truth:
# per-class report, overall accuracy and support-weighted F1.
report = classification_report(target, test_preds_np_argmax)
accuracy = accuracy_score(target, test_preds_np_argmax)
weighted_f1 = f1_score(target, test_preds_np_argmax, average="weighted")

print("Classification Report:", report, sep="\n")

print("\nAccuracy:", accuracy)
print("\nOverallF1 score:", weighted_f1)

Classification Report:
              precision    recall  f1-score   support

           0       0.64      0.66      0.65       348
           1       0.79      0.79      0.79       600
           2       0.80      0.79      0.79       492

    accuracy                           0.76      1440
   macro avg       0.74      0.74      0.74      1440
weighted avg       0.76      0.76      0.76      1440


Accuracy: 0.7569444444444444

OverallF1 score: 0.7574077527153302


In [153]:
from sklearn.metrics import accuracy_score

# Per-class accuracy: among the samples whose true label is a given class, the
# fraction predicted as that class. Classes are taken from the labels actually
# present in `target` instead of a hard-coded count.
class_labels = np.unique(target)
class_accuracies = []
for class_label in class_labels:
    # Positions where the true label is class_label
    class_indices = np.where(target == class_label)[0]
    # Predictions made for those samples
    class_predictions = test_preds_np_argmax[class_indices]
    # y_true is constant (class_label) for this subset; y_pred is what the model said.
    class_accuracy = accuracy_score(np.full_like(class_predictions, class_label), class_predictions)
    class_accuracies.append(class_accuracy)

print("Class Accuracies:")
# A separate loop variable keeps the overall `accuracy` from the metrics cell intact.
for class_label, per_class_accuracy in zip(class_labels, class_accuracies):
    print(f"Class {class_label}: {per_class_accuracy}")

Class Accuracies:
Class 0: 0.6551724137931034
Class 1: 0.7916666666666666
Class 2: 0.7865853658536586
