In [None]:
%cd crop-type-segmentation/

In [None]:
from pathlib import Path

from matplotlib import pyplot as plt
import torch
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import pandas as pd
from tqdm import tqdm

from prithvi import PrithviSemanticSegmentation
from datamodule import CropTypeDataModule

In [None]:
# Modify the following configuration as needed
config = {}
config["num_classes"] = 17
config["num_frames"] = 7
# Months passed to the data module; presumably one model input frame per month (see check below).
config["months"] = [2, 4, 5, 6, 7, 8, 9]
config["img_size"] = 224
config["learning_rate"] = 0.001
config["decoder_num_convs"] = 1
config["in_channels"] = 6
# Channel indices selected by the data module; presumably one model input channel per entry.
config["channels"] = [0, 1, 2, 3, 4, 5]
config["batch_size"] = 5
config["wandb_project"] = "test"
config["wandb_name"] = "test"
config["max_epochs"] = 1
config["num_workers"] = 4 if torch.cuda.is_available() else 0
# Checkpoint of the trained model to evaluate; must be set before the model-loading cell runs.
config["ckpt_path"] = ""
config["train_percent"] = 0.7

# The model hyper-parameters are stated separately from the data selection they must agree with;
# fail fast here instead of with a shape error deep inside the model if one is edited without the other.
assert config["num_frames"] == len(config["months"]), (
    f'num_frames={config["num_frames"]} but {len(config["months"])} months are selected'
)
assert config["in_channels"] == len(config["channels"]), (
    f'in_channels={config["in_channels"]} but {len(config["channels"])} channels are selected'
)

### Load the data

In [None]:
# Dataset root, relative to the current working directory (the project directory).
data_dir = Path(".") / "data"

In [None]:
# Every data-module option except the directory is read verbatim from the config cell.
datamodule_kwargs = {
    option: config[option]
    for option in (
        "months",
        "channels",
        "train_percent",
        "img_size",
        "batch_size",
        "num_workers",
    )
}
datamodule = CropTypeDataModule(data_dir=data_dir, **datamodule_kwargs)

In [None]:
# Build the datasets used below (datamodule.train_dataset / datamodule.val_dataset).
# NOTE(review): stage=None presumably sets up every stage (Lightning convention) — confirm in datamodule.py.
datamodule.setup(stage=None)

In [None]:
# Report how many samples landed in each split after setup().
for split_name, split_dataset in (
    ("train", datamodule.train_dataset),
    ("val", datamodule.val_dataset),
):
    print(f"len {split_name} dataset: {len(split_dataset)}")

In [None]:
# Load the trained segmentation model from config["ckpt_path"], passing the architecture settings
# from the config cell. NOTE(review): these presumably must match the values used for training.
# The config cell ships with an empty path; fail with an actionable message instead of an opaque
# file-loading error.
if not config["ckpt_path"]:
    raise ValueError(
        'config["ckpt_path"] is empty: set it to the trained checkpoint file in the configuration cell.'
    )

model = PrithviSemanticSegmentation.load_from_checkpoint(
    checkpoint_path=config["ckpt_path"],
    num_classes=config["num_classes"],
    in_channels=config["in_channels"],
    num_frames=config["num_frames"],
    decoder_num_convs=config["decoder_num_convs"],
    img_size=config["img_size"],
    learning_rate=config["learning_rate"],
)

### Val dataset

In [None]:
# Run the model over the whole validation split and pool every pixel's ground-truth label and
# predicted class (argmax over dim 1 of the model output) into two flat arrays.
# Inference only: eval mode plus no_grad, so no autograd graph is built and kept per batch.
model.eval()

y_chunks_val = []
pred_chunks_val = []
with torch.no_grad():
    for x, y in tqdm(datamodule.val_dataloader()):
        pred = torch.argmax(model.model(x).output, dim=1)
        # Keep one flat tensor per batch; extending a list with a tensor would store one
        # Python 0-d tensor object per pixel, which is very slow and memory hungry.
        y_chunks_val.append(y.flatten().cpu())
        pred_chunks_val.append(pred.flatten().cpu())

all_y_val = torch.cat(y_chunks_val).numpy()
all_pred_val = torch.cat(pred_chunks_val).numpy()

In [None]:
# Display names of the segmentation classes. A class id is its position in this list, so
# class_ids[i] is the label value named by class_names[i].
crop_class_names = [
    "Other",
    "Rice",
    "Sunflower",
    "Winter Wheat",
    "Alfalfa",
    "Other Hay/Non Alfalfa",
    "Tomatoes",
    "Fallow/Idle Cropland",
    "Almonds",
    "Walnuts",
    "Open Water",
    "Evergreen Forest",
    "Shrubland",
    "Grassland/Pasture",
    "Herbaceous Wetlands",
    "Plums",
    "Developed",
]
classes_used = {
    "class_ids": list(range(len(crop_class_names))),
    "class_names": crop_class_names,
}

In [None]:
# Row-normalised (normalize="true") confusion matrix over all validation pixels.
# labels= fixes the matrix to all class ids in classes_used order. Without it, sklearn only uses
# the labels present in y/pred, so a class that never occurs shrinks the matrix and it no longer
# matches the 17 tick labels of the heatmap below.
cm = confusion_matrix(
    all_y_val,
    all_pred_val,
    labels=classes_used["class_ids"],
    normalize="true",
)

In [None]:
# Heatmap of the confusion matrix `cm`: rows are true classes, columns are predicted classes.
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    ax=ax,
    cmap="Blues",
    cbar=False,
    linewidths=1,
    linecolor="black",
    xticklabels=classes_used["class_names"],
    yticklabels=classes_used["class_names"],
)
ax.set(xlabel="Predicted", ylabel="Actual", title="Confusion Matrix")
plt.show()

In [None]:
# Per-class (average=None, the default) precision / recall / F1 over all validation pixels.
# labels= pins the arrays to every class id in classes_used order. Without it, sklearn drops
# classes absent from both y and predictions, and the 17-row metrics table below fails on a
# length mismatch. zero_division=0 keeps the default value (0) for undefined ratios but silences
# the warning. `support` is named rather than `_` so IPython's last-output variable is not clobbered.
# NOTE: the train-set section reuses the names precision/recall/f1.
precision, recall, f1, support = precision_recall_fscore_support(
    all_y_val,
    all_pred_val,
    labels=classes_used["class_ids"],
    zero_division=0,
)

In [None]:
# Per-class validation metrics table: row i pairs class_names[i] with the i-th metric values.
val_metrics = pd.DataFrame({"class": classes_used["class_names"]}).assign(
    precision=precision,
    recall=recall,
    f1=f1,
)

In [None]:
# Display the per-class validation metrics table.
val_metrics

### Train dataset

In [None]:
# Pool per-pixel ground truth and predicted class (argmax over dim 1) over the training split,
# for comparison with the validation metrics. eval()/no_grad are set here as well so this cell
# does not depend on the validation cell having run first, and no autograd graph is built.
# NOTE(review): train_dataloader() may shuffle and/or augment — confirm in datamodule.py.
# Shuffling does not change the pooled metrics; random augmentation would.
model.eval()

y_chunks_train = []
pred_chunks_train = []
with torch.no_grad():
    for x, y in tqdm(datamodule.train_dataloader()):
        pred = torch.argmax(model.model(x).output, dim=1)
        # One flat tensor per batch instead of one Python 0-d tensor object per pixel.
        y_chunks_train.append(y.flatten().cpu())
        pred_chunks_train.append(pred.flatten().cpu())

all_y_train = torch.cat(y_chunks_train).numpy()
all_pred_train = torch.cat(pred_chunks_train).numpy()

In [None]:
# Per-class (average=None, the default) precision / recall / F1 over all training pixels.
# labels= keeps the arrays aligned with all class ids in classes_used order even when a class
# never occurs. Otherwise sklearn drops it and the 17-row table below fails on a length mismatch.
# zero_division=0 keeps the default value (0) for undefined ratios without the warning.
# NOTE: this overwrites the precision/recall/f1 names computed for the validation set.
precision, recall, f1, support = precision_recall_fscore_support(
    all_y_train,
    all_pred_train,
    labels=classes_used["class_ids"],
    zero_division=0,
)

In [None]:
# Per-class training metrics table: row i pairs class_names[i] with the i-th metric values.
train_metrics = pd.DataFrame({"class": classes_used["class_names"]}).assign(
    precision=precision,
    recall=recall,
    f1=f1,
)

In [None]:
# Display the per-class training metrics table (compare with the validation table above).
train_metrics