In [None]:
import random
import os

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics.classification import MulticlassF1Score
from torch.optim import AdamW
import numpy as np

from print_color import print

In [None]:
# Export the flag through the process environment. NOTE(review): presumably the
# sits_siam imports in the next cell read it at import time, which is why it is
# set before them -- confirm against the package.
os.environ.update({"TRANSFORMER_FROM_SCRATCH": "True"})

In [None]:
from sits_siam.backbone import TransformerBackbone
from sits_siam.head import BertHead, ClassifierHead
from sits_siam.utils import SitsDataset
from sits_siam.bottleneck import PoolingBottleneck, NDVIWord2VecBottleneck
from sits_siam.augment import AddNDVIWeights, RandomChanSwapping, RandomChanRemoval, RandomAddNoise, RandomTempSwapping, RandomTempShift, RandomTempRemoval, AddMissingMask, Normalize

In [None]:
def setup_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global generator
    and Python's ``random`` module with the same value, and asks cuDNN to use
    deterministic kernels.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42, the value
            that was previously hard-coded.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Trade some speed for repeatable cuDNN results.
    torch.backends.cudnn.deterministic = True

# Left disabled, as in the original notebook. Uncomment to make weight
# initialisation and the random augmentations repeatable across runs.
# setup_seed()

In [None]:
# Load the full table from a parquet file, addressed relative to the working
# directory the notebook is started from.
whole_df = pd.read_parquet(
    path="data/california_sits_bert_original.parquet",
)

In [None]:
# Ten constants per list, the same values that are handed to `Normalize` as
# `a` / `b` further down.
# NOTE(review): going by the names these are per-channel medians and
# interquartile distances of the input bands -- confirm where they were computed.
median = [
    0.0656, 0.0948, 0.1094, 0.1507, 0.2372,
    0.2673, 0.2866, 0.2946, 0.2679, 0.1985,
]
iqd = [
    0.0456, 0.0536, 0.0946, 0.0769, 0.0851,
    0.1053, 0.1066, 0.1074, 0.1428, 0.1376,
]

In [None]:
# Training and validation must be normalised with the same statistics, so both
# pipelines reuse the `median` / `iqd` lists defined in the previous cell rather
# than repeating the twenty literals twice (the inlined copies were
# value-for-value identical to those lists). `list(...)` hands each `Normalize`
# its own copy, matching the independent literals used before.

# Training pipeline: NDVI weights, random augmentations, missing-data mask,
# then normalisation.
train_transforms = [
    AddNDVIWeights(),
    RandomAddNoise(),
    RandomTempSwapping(),
    RandomTempShift(),
    # RandomTempRemoval(),
    AddMissingMask(),
    Normalize(a=list(median), b=list(iqd)),
]

# Validation pipeline: same steps without the random augmentations.
val_transforms = [
    AddNDVIWeights(),
    AddMissingMask(),
    Normalize(a=list(median), b=list(iqd)),
]

In [None]:
# Split the frame by unique id, so every row sharing an id ends up on the same
# side of the train/validation boundary.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

ids = whole_df.id.unique()
# Use a dedicated, seeded generator: the split is then identical after every
# kernel restart (so validation scores stay comparable between runs) and does
# not depend on, or disturb, NumPy's global random state.
ids = np.random.default_rng(SPLIT_SEED).permutation(ids)
n_train = int(len(ids) * TRAIN_FRACTION)
train_ids = ids[:n_train]
val_ids = ids[n_train:]

train_df = whole_df[whole_df.id.isin(train_ids)].reset_index(drop=True)
val_df = whole_df[whole_df.id.isin(val_ids)].reset_index(drop=True)

# Only the transform list differs between the two datasets.
train_dataset = SitsDataset(train_df, max_seq_len=45, transform=train_transforms)
val_dataset = SitsDataset(val_df, max_seq_len=45, transform=val_transforms)

In [None]:
class TransformerClassifier(pl.LightningModule):
    """Lightning module chaining backbone -> bottleneck -> classifier head.

    Batches are dictionaries. ``forward`` reads the keys ``"x"``, ``"doy"``,
    ``"mask"`` and ``"weight"``; the step methods additionally read the target
    labels from ``"y"``.

    Args:
        max_seq_len: Forwarded to ``TransformerBackbone``.
        num_classes: Number of target classes; sizes the classifier head and
            the macro-F1 metrics.
        lr: Learning rate handed to AdamW. Defaults to 1e-2, the value that
            was previously hard-coded.
    """

    def __init__(self, max_seq_len=40, num_classes=13, lr=1e-2):
        super().__init__()
        self.lr = lr

        self.backbone = TransformerBackbone(max_seq_len=max_seq_len)
        self.bottleneck = NDVIWord2VecBottleneck()
        self.classifier = ClassifierHead(num_classes=num_classes)

        self.criterion = nn.CrossEntropyLoss()
        # Separate metric objects per stage so their accumulated state never mixes.
        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average='macro')
        self.test_f1 = MulticlassF1Score(num_classes=num_classes, average='macro')

    def forward(self, input):
        """Run one batch dictionary through the network and return the head's output."""
        x = input["x"]
        doy = input["doy"]
        mask = input["mask"]
        weight = input["weight"]

        features = self.backbone(x, doy, mask)
        features = self.bottleneck(features, weight)
        outputs = self.classifier(features)
        return outputs

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        targets = batch["y"]
        outputs = self(batch)

        loss = self.criterion(outputs, targets)

        # Only the loss is tracked during training.
        self.log("train_loss", loss, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate macro-F1 over the epoch."""
        targets = batch["y"]
        outputs = self(batch)
        loss = self.criterion(outputs, targets)

        # Accumulate statistics only. Macro-F1 is not an average of per-batch
        # scores, so the epoch value has to come from the accumulated state.
        preds = torch.argmax(outputs, dim=1)
        self.val_f1.update(preds, targets)

        self.log("val_loss", loss, prog_bar=True)
        # Logging the metric object (not a tensor) lets Lightning call
        # compute() at epoch end and reset the metric afterwards.
        self.log("val_f1", self.val_f1, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log the test loss and accumulate macro-F1 over the whole test set."""
        targets = batch["y"]
        outputs = self(batch)
        loss = self.criterion(outputs, targets)

        # Same accumulate-then-compute scheme as in validation_step.
        preds = torch.argmax(outputs, dim=1)
        self.test_f1.update(preds, targets)

        self.log("test_loss", loss)
        self.log("test_f1", self.test_f1, on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters using ``self.lr``."""
        optimizer = AdamW(self.parameters(), lr=self.lr)
        return optimizer

In [None]:
# Training batches are drawn in shuffled order; validation keeps dataset order.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=512,
)
val_dataloader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    shuffle=False,
    batch_size=512,
)

In [None]:
trainer = pl.Trainer(max_epochs=10)
# Both datasets are built with max_seq_len=45, while the class default is 40.
# Pass the length explicitly so the backbone is configured for the sequences
# it will actually receive.
model = TransformerClassifier(max_seq_len=45)


trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)