In [None]:
import random
import os

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics.classification import MulticlassF1Score
from torch.optim import AdamW
import numpy as np

from print_color import print

from lightly.loss import NegativeCosineSimilarity
from lightly.utils.debug import std_of_l2_normalized
from lightly.models.modules.heads import SimSiamPredictionHead, SimSiamProjectionHead

In [None]:
# os.environ["TRANSFORMER_FROM_SCRATCH"] = "True"

In [None]:
from sits_siam.backbone import TransformerBackbone
from sits_siam.head import BertHead, ClassifierHead
from sits_siam.utils import SitsDataset
from sits_siam.bottleneck import PoolingBottleneck, NDVIWord2VecBottleneck
from sits_siam.augment import AddNDVIWeights, RandomChanSwapping, RandomChanRemoval, RandomAddNoise, RandomTempSwapping, RandomTempShift, RandomTempRemoval, AddMissingMask, Normalize, Pipeline, ToPytorchTensor

In [None]:
def setup_seed(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), and asks cuDNN for deterministic kernels.

    Args:
        seed: Value used for all generators. Defaults to 42, which is the
            value this function always used before it became a parameter.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick different kernels from run to run; keep it off so
    # the deterministic flag above is not undermined.
    torch.backends.cudnn.benchmark = False

# setup_seed()

In [None]:
# Load the full table from disk. Later cells in this notebook read its
# `use_bert` column (to split train/validation) and its `id` / `label` columns.
whole_df = pd.read_parquet("data/california_sits_bert_original.parquet")

In [None]:
# Constants handed to `Normalize` (as `a` and `b`) by the training transform.
# NOTE(review): presumably per-channel median and interquartile distance of the
# input bands (10 values each) — confirm the band order against the dataset.
median = [0.0656, 0.0948, 0.1094, 0.1507, 0.2372, 0.2673, 0.2866, 0.2946, 0.2679, 0.1985]
iqd = [0.0456, 0.0536, 0.0946, 0.0769, 0.0851, 0.1053, 0.1066, 0.1074, 0.1428, 0.1376]

In [None]:
class FastSiamMultiViewTransform(object):
    """Produce several independently augmented views of one sample.

    Each call copies the incoming sample and runs the copy through the
    augmentation pipeline once per view, so the views differ only through the
    randomness of the augmentations.

    Args:
        n_views: Number of views returned per sample.
        a: First normalisation constant forwarded to ``Normalize``. When
            ``None`` the notebook-level ``median`` is used.
        b: Second normalisation constant forwarded to ``Normalize``. When
            ``None`` the notebook-level ``iqd`` is used.
    """

    def __init__(
        self,
        n_views: int = 2,
        a=None,
        b=None,
    ):
        self.n_views = n_views
        self.transform = Pipeline([
            # AddNDVIWeights(),
            RandomAddNoise(),
            RandomTempSwapping(),
            RandomTempShift(),
            # RandomTempRemoval(),
            AddMissingMask(),
            Normalize(
                # Fall back to the notebook-level statistics so that calling
                # the constructor without arguments behaves as before.
                a=median if a is None else a,
                b=iqd if b is None else b,
            ),
            ToPytorchTensor()
        ])

    def __call__(self, sample: dict):
        """Return a list with ``n_views`` transformed copies of ``sample``.

        ``sample`` is a mapping whose values expose ``.copy()``; every view
        starts from a fresh copy of each value so that the caller's data is
        not modified by the pipeline.
        """
        return [
            self.transform({key: value.copy() for key, value in sample.items()})
            for _ in range(self.n_views)
        ]

In [None]:
# `whole_df` is already loaded from the parquet file in an earlier cell;
# re-reading the same file here only repeated the disk I/O.
# Rows flagged use_bert in {0, 2} form the training split, use_bert == 1 the
# validation split. NOTE(review): the meaning of the individual flag values is
# not visible in this notebook — confirm against the data preparation step.
train_df = whole_df[whole_df.use_bert.isin([0, 2])].reset_index(drop=True)
val_df = whole_df[whole_df.use_bert == 1].reset_index(drop=True)

# Both datasets use the multi-view augmentation transform, so every item is a
# list of augmented views of the same sample.
train_dataset = SitsDataset(train_df, max_seq_len=45, transform=FastSiamMultiViewTransform())
val_dataset = SitsDataset(val_df, max_seq_len=45, transform=FastSiamMultiViewTransform())

In [None]:
# Print PyTorch tensors without scientific notation and with three digits
# after the decimal point.
torch.set_printoptions(precision=3, sci_mode=False)

In [None]:
class TransformerClassifier(pl.LightningModule):
    """Self-supervised multi-view model built from SimSiam components.

    A transformer backbone followed by a pooling bottleneck produces one
    feature vector per sample. A projection head maps it to ``z`` and a
    prediction head maps ``z`` to ``p``. For every view, the loss compares its
    prediction ``p`` with the mean of the (gradient-detached) projections
    ``z`` of all *other* views using negative cosine similarity.

    Args:
        max_seq_len: Maximum sequence length forwarded to the backbone.
    """

    def __init__(self, max_seq_len=45):
        super().__init__()
        self.backbone = TransformerBackbone(max_seq_len=max_seq_len)
        self.bottleneck = PoolingBottleneck()
        # NOTE(review): assumes backbone + bottleneck emit 128-dim features.
        self.projection_head = SimSiamProjectionHead(128, 512, 128)
        self.prediction_head = SimSiamPredictionHead(128, 64, 128)

        self.criterion = NegativeCosineSimilarity()

    def forward(self, input):
        """Encode one view.

        Args:
            input: Mapping with the keys ``"x"``, ``"doy"`` and ``"mask"``.

        Returns:
            Tuple ``(z, p)``: the projection, detached from the graph so it
            acts as a stop-gradient target, and the prediction.
        """
        x = input["x"]
        doy = input["doy"]
        mask = input["mask"]
        # weight = input["weight"]

        f = self.backbone(x, doy, mask)
        f = self.bottleneck(f)
        z = self.projection_head(f)
        p = self.prediction_head(z)
        z = z.detach()
        return z, p

    def _shared_step(self, views):
        """Compute the multi-view loss and the collapse metric for a batch.

        Args:
            views: Sequence of view batches, each accepted by ``forward``.

        Returns:
            Tuple ``(loss, collapse)``. ``collapse`` is
            ``std_of_l2_normalized`` of the first view's predictions.

        Raises:
            ValueError: If fewer than two views are given; the target of a
                view is the mean over the other views, which would be empty
                and yield a NaN loss.
        """
        n_views = len(views)
        if n_views < 2:
            raise ValueError(
                f"At least 2 views are required to compute the loss, got {n_views}."
            )

        features = [self.forward(view) for view in views]
        zs = torch.stack([z for z, _ in features])
        ps = torch.stack([p for _, p in features])

        view_ids = torch.arange(n_views, device=zs.device)
        loss = 0.0
        for i in range(n_views):
            # Target for view i: mean projection of every view except i.
            others = view_ids != i
            loss += self.criterion(ps[i], torch.mean(zs[others], dim=0)) / n_views

        collapse = std_of_l2_normalized(ps[0].detach())
        return loss, collapse

    def training_step(self, batch, batch_idx):
        """Run one training step and log ``train_loss`` / ``train_collapse``."""
        loss, collapse = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_collapse", collapse, sync_dist=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step and log ``val_loss`` / ``val_collapse``."""
        loss, collapse = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, sync_dist=True)
        self.log("val_collapse", collapse, sync_dist=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Run one test step and log ``test_loss`` / ``test_collapse``."""
        loss, collapse = self._shared_step(batch)
        self.log("test_loss", loss, sync_dist=True)
        # sync_dist added so the metric is reduced like its validation twin.
        self.log("test_collapse", collapse, sync_dist=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return the SGD optimizer used for all parameters."""
        optim = torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
        return optim

In [None]:
# Batch loaders for pre-training: the training loader reshuffles every epoch,
# the validation loader keeps the dataset order.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=4,
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Pre-train the model for 20 epochs, validating on the validation loader.
trainer = pl.Trainer(max_epochs=20)
model = TransformerClassifier()

trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

In [None]:
# Stop guard: raising here halts a "Run All" at this point, so the evaluation
# cells below only run when executed manually.
# NOTE(review): presumably intentional (the message is informal Portuguese,
# roughly "stop right there"); consider replacing it with an explicit flag.
raise Exception("Pare malandro")

In [None]:
class FastSiamTestTransform(object):
    """Evaluation-time transform without any random augmentation.

    Applies, in order: NDVI weights, missing-value mask, normalisation and
    conversion to PyTorch tensors. The random augmentations used for training
    (noise, temporal swapping / shift / removal) are deliberately left out.

    Args:
        n_views: Number of transformed copies returned per sample.
        a: Forwarded to ``Normalize`` as ``a``.
        b: Forwarded to ``Normalize`` as ``b``.
    """

    def __init__(
        self,
        n_views: int = 1,
        a=None,
        b=None
    ):
        self.n_views = n_views
        steps = []
        steps.append(AddNDVIWeights())
        steps.append(AddMissingMask())
        steps.append(Normalize(a=a, b=b))
        steps.append(ToPytorchTensor())
        self.transform = steps

    def apply_transform(self, sample):
        """Run every step on a copy of ``sample`` and return the result.

        The values are copied first because the individual steps modify their
        input in place.
        """
        working = {key: value.copy() for key, value in sample.items()}
        for step in self.transform:
            working = step(working)
        return working

    def __call__(self, sample: np.ndarray):
        """Return a list holding ``n_views`` transformed copies of ``sample``.

        Despite the annotation, ``sample`` is used as a mapping (its
        ``.items()`` are iterated in ``apply_transform``).
        """
        views = []
        for _ in range(self.n_views):
            views.append(self.apply_transform(sample))
        return views

In [None]:
# Take the trained encoder parts out of the Lightning module and switch them
# to evaluation mode (`eval()` returns the module itself).
backbone = model.backbone.eval()
bottleneck = model.bottleneck.eval()
bottleneck

In [None]:
# Extract encoder features for every training sample.
#
# `train_dataloader` shuffles, so features taken from it come back in a random
# order and cannot be matched to labels by position. Use a dedicated loader
# over the same dataset that preserves the dataset order instead.
train_feature_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=512, shuffle=False, num_workers=4
)
# NOTE(review): `train_dataset` still applies the random training
# augmentations; `FastSiamTestTransform` looks like the intended transform for
# feature extraction — confirm and switch the dataset if so.

feature_batches = []
with torch.inference_mode():
    for batch in train_feature_dataloader:
        sample = batch[0]  # first view only

        f = backbone(sample["x"], sample["doy"], sample["mask"])
        feature_batches.append(bottleneck(f))

    # Concatenate once instead of growing a tensor on every iteration.
    all_features = torch.cat(feature_batches, dim=0)

train_df_with_features = train_df[["id", "label"]].groupby("id").first().reset_index()
# The merge below is positional; a length mismatch would silently drop rows.
# NOTE(review): this also assumes the dataset yields samples in the same
# (sorted-by-id) order as the groupby — verify against SitsDataset.
assert len(train_df_with_features) == all_features.shape[0], (
    f"{len(train_df_with_features)} ids but {all_features.shape[0]} feature rows"
)
train_features_df = pd.DataFrame(
    all_features.detach().cpu().numpy(),
    columns=[f"feature_{i}" for i in range(all_features.shape[1])],
)
train_df_with_features = train_df_with_features.merge(train_features_df, left_index=True, right_index=True)

In [None]:
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt

# t-SNE cannot project new points with a previously fitted model (there is no
# `transform`); the embedding is computed by `fit_transform` in the plotting
# cell. Fitting on the training features here was therefore discarded work,
# so only the estimator is configured.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` —
# update the keyword if this raises on your installed version.
tsne = TSNE(n_components=2, n_iter=250)

In [None]:
# Extract encoder features for every validation sample. `val_dataloader` does
# not shuffle, so the features come back in dataset order.
# NOTE(review): `val_dataset` applies the random training augmentations;
# `FastSiamTestTransform` looks like the intended transform here — confirm.
feature_batches = []
with torch.inference_mode():
    for batch in val_dataloader:
        sample = batch[0]  # first view only

        f = backbone(sample["x"], sample["doy"], sample["mask"])
        feature_batches.append(bottleneck(f))

    # Concatenate once instead of growing a tensor on every iteration.
    all_features = torch.cat(feature_batches, dim=0)

val_df_with_features = val_df[["id", "label"]].groupby("id").first().reset_index()
# The merge below is positional; a length mismatch would silently drop rows.
# NOTE(review): this also assumes the dataset yields samples in the same
# (sorted-by-id) order as the groupby — verify against SitsDataset.
assert len(val_df_with_features) == all_features.shape[0], (
    f"{len(val_df_with_features)} ids but {all_features.shape[0]} feature rows"
)
val_features_df = pd.DataFrame(
    all_features.detach().cpu().numpy(),
    columns=[f"feature_{i}" for i in range(all_features.shape[1])],
)
val_df_with_features = val_df_with_features.merge(val_features_df, left_index=True, right_index=True)

In [None]:
# Embed the validation features in two dimensions with t-SNE and draw them as
# a scatter plot coloured by label.

X_embedded = tsne.fit_transform(all_features.cpu().detach().numpy())

val_df_with_features["tsne-2d-one"] = X_embedded[:, 0]
val_df_with_features["tsne-2d-two"] = X_embedded[:, 1]

n_labels = len(val_df_with_features.label.unique())

fig, ax = plt.subplots(figsize=(16, 10))
sns.scatterplot(
    data=val_df_with_features[["tsne-2d-one", "tsne-2d-two", "label"]],
    x="tsne-2d-one",
    y="tsne-2d-two",
    hue="label",
    palette=sns.color_palette("hsv", n_labels),
    legend="full",
    alpha=0.7,
    ax=ax,
)

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import f1_score

# Feature columns are looked up once on the training frame and reused for the
# validation frame, which carries the same `feature_*` column names.
feature_columns = [
    name for name in train_df_with_features.columns if name.startswith("feature_")
]

# 5-nearest-neighbour classifier on the frozen features, scored on validation
# data with the weighted F1.
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_df_with_features[feature_columns], train_df_with_features.label)
y_pred = knn.predict(val_df_with_features[feature_columns])

f1 = f1_score(val_df_with_features.label, y_pred, average="weighted")
print(f"F1 Score: {f1}")