In [None]:
from pathlib import Path
from collections import defaultdict

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics import Accuracy, MeanMetric, ClasswiseWrapper, ConfusionMatrix


from vistaformer.config import get_model_config, PROJECT_ROOT
from vistaformer.models import get_model
from vistaformer.datasets.pastis.dataloader import (
    get_val_transforms,
    get_dataloader,
    PASTISDataset,
)
from vistaformer.train_and_evaluate.eval_utils import generate_test_metrics


def num_params(x: nn.Module) -> int:
    """Return the number of trainable parameters of ``x``.

    Only parameters whose ``requires_grad`` flag is set are counted; the
    count is the sum of ``numel()`` over those parameters.
    """
    return sum(p.numel() for p in x.parameters() if p.requires_grad)

# --- Evaluation setup ---------------------------------------------------------
# The assertion below aborts when no GPU is present, so the "cpu" fallback only
# takes effect if that assertion is removed (or Python runs with -O).
device = "cuda" if torch.cuda.is_available() else "cpu"

assert torch.cuda.is_available(), "CUDA not available"
root_path = Path("<very_real_path_to_training_outputs>")
assert root_path.exists(), f"Path {root_path} does not exist"

# The run directory is expected to hold both the config and the checkpoint.
config = get_model_config(root_path / "config.yaml")

model = get_model(config)
model = model.to(device)
model = model.eval()
print(f"Number of trainable parameters: {num_params(model)}")

# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
weights = torch.load(root_path / "best_model.pth", map_location=device)

# strict=False tolerates key mismatches between checkpoint and model. Report
# them explicitly so a partially-initialised model is never evaluated silently.
load_result = model.load_state_dict(weights["model_state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "WARNING: checkpoint/model key mismatch - "
        f"missing keys: {list(load_result.missing_keys)}, "
        f"unexpected keys: {list(load_result.unexpected_keys)}"
    )

# Point the config at the local copy of the dataset and make sure the output
# directory exists.
datapath = Path("<very_real_path_to_data>")
config.dataset.path = datapath
out_path = Path("./outputs")
out_path.mkdir(exist_ok=True, parents=True)

# Dataloader over the folds listed under "test_folds" in the dataset kwargs.
dataloader = get_dataloader(config, fold=config.dataset.kwargs["test_folds"])

In [None]:
# Accumulate a single confusion matrix over every trained run found below
# `root_path`. Each directory containing a file matching "*config.yaml" is
# treated as one run directory.
root_path = Path("very_real_path_to_training_outputs")
paths = [p.parent for p in root_path.rglob("*config.yaml")]


confusion_matrix = ConfusionMatrix(task="multiclass", num_classes=config.num_classes).to(device)

print(f"Computing confusion matrix for {len(paths)} models")

for path in paths:
    # Every run must be scored with its OWN weights; otherwise the model loaded
    # in the setup cell would be evaluated against every run's test folds.
    # Assumes the checkpoint sits next to the config as "best_model.pth", as it
    # does for the run loaded in the setup cell - TODO confirm for all runs.
    checkpoint_path = path / "best_model.pth"
    if not checkpoint_path.exists():
        print(f"Skipping {path}: no checkpoint found at {checkpoint_path}")
        continue

    print(f"Computing confusion matrix for {path.parent.parent.name}/{path.parent.name}...")

    config = get_model_config(path / "config.yaml")
    config.dataset.path = datapath

    # Rebuild the model from this run's config and load this run's weights.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model = get_model(config).to(device)
    weights = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(weights["model_state_dict"], strict=False)
    model = model.eval()

    dataloader = get_dataloader(config, fold=config.dataset.kwargs["test_folds"])

    with torch.no_grad():
        for data, target in dataloader:
            target = target.to(device)
            if config.is_multi_input_model:
                # Multi-input models receive the "s2" and "s1" entries as two
                # separate positional arguments.
                s2, s1a = data["s2"].to(device), data["s1"].to(device)
                output = model(s2, s1a)
                del data
            else:
                data = data.to(device)
                output = model(data)

            # Reduce the scores along dim 1 to the predicted class index.
            output = torch.argmax(output, dim=1)
            confusion_matrix.update(output, target)


In [None]:
# Names shown on the axes; the void label is left out of the plot.
class_names = [c for c in dataloader.dataset.index_to_label.values() if c != "void-label"]
num_plotted_classes = len(class_names)

confusion_matrix_np = confusion_matrix.compute().cpu().numpy()
# Keep only the leading block covering the plotted classes, i.e. drop the void
# label. Assumes the void label occupies the last index - TODO confirm.
conf_matrix = confusion_matrix_np[:num_plotted_classes, :num_plotted_classes]

# Normalise each row by its total so every row sums to 1.
# Rows whose total is 0 yield NaN.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
df_cm = pd.DataFrame(
    conf_matrix / row_totals,
    index=class_names,
    columns=class_names,
).round(2)
df_cm.to_csv("confusion_matrix.csv")

# Set the theme before creating the figure so it applies to everything drawn.
sns.set_theme(style='white', font_scale=0.7)
plt.figure(figsize=(10, 5))
sns.heatmap(df_cm, annot=True, cmap="Blues", fmt='g')
plt.savefig("confusion_matrix.pdf", facecolor="white", transparent=False, dpi=100, bbox_inches="tight", format="pdf")

In [None]:
# Count how many label pixels belong to each class in the PASTIS dataset.
# NOTE(review): fold=None presumably selects every fold - confirm against
# get_dataloader.
dataloader = get_dataloader(config, fold=None)

class_pixel_counts = defaultdict(int)
total_pixels = 0

for inputs, labels in dataloader:
    labels = labels.squeeze()
    # Count the pixels for each class present in this batch
    unique_classes, class_counts = torch.unique(labels, return_counts=True)
    for class_idx, count in zip(unique_classes, class_counts):
        class_pixel_counts[class_idx.item()] += count.item()
        total_pixels += count.item()

# Map class names to their pixel totals. This must run AFTER the counting loop:
# building it first raises NameError on a fresh kernel (or reads stale counts
# left over from an earlier execution). Classes never seen get a count of 0.
class_name_dict = {v: class_pixel_counts[k] for k, v in dataloader.dataset.index_to_label.items()}

In [None]:
# Bar chart of the per-class pixel counts held in `class_name_dict`.
plt.figure(figsize=(12, 6))

# Two-column frame: one row per class with its pixel count.
class_freq_df = pd.DataFrame(
    {
        'Class': list(class_name_dict.keys()),
        'Frequency': list(class_name_dict.values()),
    }
)
sns.barplot(data=class_freq_df, x='Class', y='Frequency', hue='Class', palette="muted")

plt.xlabel('Class Labels')
plt.ylabel('Frequency')
plt.grid(axis='y', linestyle='--', alpha=0.7)
# Put one tick on every class label and rotate the labels in the same call.
plt.xticks(list(class_name_dict.keys()), rotation=-90)

# Write the figure to disk before displaying it.
plt.savefig("pastis_class_labels.pdf", facecolor="white", dpi=100, bbox_inches="tight", format="pdf")
plt.show()

In [None]:
from vistaformer.datasets.mtlcc.dataloader import get_dataloader as get_mtlcc_dataloader

# Per-class pixel totals accumulated over all three MTLCC splits.
class_pixel_counts = defaultdict(int)
total_pixels = 0

for split in ("train", "test", "val"):
    # Point the shared config at the MTLCC data before building the loader.
    config.dataset.path = Path("/mnt/sata1/datasets/time_series/MTLCC/data_IJGI18/datasets/full/240pkl")
    dataloader = get_mtlcc_dataloader(config, split=split)

    for inputs, labels in dataloader:
        labels = labels.squeeze()
        # Distinct class ids in this batch together with their pixel counts.
        present_classes, pixels_per_class = torch.unique(labels, return_counts=True)
        for class_id, n_pixels in zip(present_classes.tolist(), pixels_per_class.tolist()):
            class_pixel_counts[class_id] += n_pixels
            total_pixels += n_pixels

In [None]:
# Bar chart of the per-class pixel counts accumulated in `class_pixel_counts`.
# Class names come from the most recently built dataloader's dataset.
class_name_dict = {v: class_pixel_counts[k] for k, v in dataloader.dataset.index_to_label.items()}

# Set the theme BEFORE drawing so this figure does not depend on an earlier
# cell having configured seaborn.
sns.set_theme(style='white', font_scale=0.7)

plt.figure(figsize=(12, 6))
class_freq_df = pd.DataFrame(class_name_dict.items(), columns=['Class', 'Frequency'])
sns.barplot(x='Class', y='Frequency', data=class_freq_df, hue='Class', palette="muted")

plt.xlabel('Class Labels')
plt.ylabel('Frequency')
plt.xticks(list(class_name_dict.keys()))  # Ensure x-ticks are the class labels
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xticks(rotation=-90)

# Write the figure to disk before displaying it.
plt.savefig("mtlcc_class_labels.pdf", facecolor="white", dpi=100, bbox_inches="tight", format="pdf")
plt.show()

In [None]:
import typing
from itertools import zip_longest
from collections import Counter
from numpy import prod
from functools import partial

from fvcore.nn import FlopCountAnalysis
from fvcore.nn.jit_handles import generic_activation_jit
from natten.flops import add_natten_handle


def get_shape(val: object) -> typing.List[int]:
    """
    Return the shape of a jit value object as a list.

    A complete tensor reports the sizes of its type, or ``[1]`` when that
    size list is empty. Values whose type kind is ``IntType`` or
    ``FloatType`` are treated as shape ``[1]``.

    Raises:
        ValueError: if the value is neither a complete tensor nor an
            int/float value.
    """
    if not val.isCompleteTensor():
        # Non-tensor values: only int/float kinds are supported.
        if val.type().kind() not in ("IntType", "FloatType"):
            raise ValueError()
        return [1]

    sizes = val.type().sizes()
    return sizes if sizes else [1]

def basic_binary_op_flop_jit(inputs, outputs, name):
    """
    Count the FLOPs of an element-wise binary op as the size of its broadcast shape.

    Args:
        inputs: jit values feeding the op; their shapes are read with ``get_shape``.
        outputs: jit values produced by the op (not used).
        name: key under which the count is stored in the returned counter.

    Returns:
        A ``Counter`` mapping ``name`` to the product of the broadcast shape.
    """
    # Reverse each shape so trailing dimensions line up, as broadcasting requires.
    reversed_shapes = [get_shape(value)[::-1] for value in inputs]
    # Pad shorter shapes with 1 and take the largest extent along every dimension.
    aligned_dims = np.array(list(zip_longest(*reversed_shapes, fillvalue=1)))
    broadcast_shape = aligned_dims.max(1)
    return Counter({name: prod(broadcast_shape)})


def pretty_flops(num_flops: int):
    """
    Format a FLOP count as a string with two decimals and the largest fitting unit.

    Units are tried from GFLOPs down to FLOPs; the first unit whose value does
    not exceed ``num_flops`` is used. Counts below 1 give ``"0 FLOPs"``.
    """
    scales = (("GFLOPs", 1e9), ("MFLOPs", 1e6), ("KFLOPs", 1e3), ("FLOPs", 1))
    formatted = (
        f"{num_flops / scale:.2f} {label}"
        for label, scale in scales
        if num_flops >= scale
    )
    # Take the first (largest) matching unit, falling back when none matches.
    return next(formatted, "0 FLOPs")


# --- FLOP estimate for a freshly built model ---------------------------------
# Override the spatial size and sequence settings on the shared config, then
# rebuild the model from it (no trained weights are loaded here).
input_dim = 32
config.image_size = input_dim
config.max_seq_len = 30
config.model_kwargs["seq_lens"] = [30, 15, 7]

model = get_model(config)
model = model.to(device)

# Random dummy inputs for tracing.
# Presumably laid out as (batch, time, channels, height, width) - TODO confirm.
inputs = torch.randn(4, 30, 6, input_dim, input_dim).to(device)
sen = torch.randn(4, 30, 10, input_dim, input_dim).to(device)

counter = FlopCountAnalysis(model, inputs=(sen, inputs))

# Register each handle under its own operator name. Looping keeps the operator
# and the label passed to the handle identical; previously "aten::sigmoid" was
# registered with a handle labelled "aten::gelu".
for activation_op in ("aten::softmax", "aten::sigmoid", "aten::gelu", "aten::mish"):
    counter.set_op_handle(activation_op, generic_activation_jit(activation_op))

# Element-wise binary ops are counted via the size of their broadcast shape.
for binary_op in ("aten::div_", "aten::mul", "aten::add", "aten::add_"):
    counter.set_op_handle(binary_op, partial(basic_binary_op_flop_jit, name=binary_op))

add_natten_handle(counter)

print(f"Total number of estimated flops: {pretty_flops(counter.total())}")

In [None]:
# Hard-coded GFLOPs per model, one value per entry of `input_dimensions`.
model_flops = {
    "U-TAE": [11.64, 46.81, 105.32, 187.24],
    "TSViT": [65.23, 326.36, 979.64, 2352.28],
    "VistaFormer": [7.58, 26.8, 113.79, 326.24],
    "VistaFormer(NeighbourAttn)": [4.85, 14.97, 32.57, 57.67],
}

# Input dimensions
input_dimensions = [32, 64, 96, 128]

# Guard against a model entry that does not line up with the input sizes;
# otherwise the long-format columns below would silently be misaligned.
assert all(len(flops) == len(input_dimensions) for flops in model_flops.values()), \
    "every model needs one GFLOPs value per input dimension"

# Convert the data into long format (one row per model/input-size pair) for Seaborn
data = {
    "Input Dimensions": input_dimensions * len(model_flops),
    "GFLOPs": [value for flops in model_flops.values() for value in flops],
    "Model": [model_name for model_name in model_flops for _ in input_dimensions],
}

# Set the theme BEFORE drawing so it applies to this figure regardless of
# which cells ran earlier.
sns.set_theme(style='white', font_scale=0.7)

# Create the plot
plt.figure(figsize=(10, 6))
sns.lineplot(x="Input Dimensions", y="GFLOPs", hue="Model", data=data, palette="muted", marker='o')

# Log scale keeps the models distinguishable across the range of values.
plt.yscale('log')
plt.xlabel('Input Dimensions', fontweight='bold', fontsize=10)
plt.ylabel('GFLOPs (log scale)', fontweight="bold", fontsize=10)

plt.savefig("model_gflops.pdf", facecolor="white", transparent=False, dpi=100, bbox_inches="tight", format="pdf")
# Show the plot
plt.show()