In [None]:
import math

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW, Adam

import numpy as np
from tqdm.std import tqdm

import pytorch_lightning as pl
from torchmetrics.classification import MulticlassF1Score

import copy

# Tensor printing: start from the "full" profile, then cap the summarisation
# threshold at 10k elements and show two fixed-point decimals (no sci notation).
torch.set_printoptions(
    profile="full",
    threshold=10_000,
    precision=2,
    sci_mode=False,
)

In [None]:
def get_activation_function(name):
    """Look up the functional form of an activation by name.

    Args:
        name: ``"relu"`` or ``"gelu"``.

    Returns:
        ``F.relu`` or ``F.gelu`` respectively.

    Raises:
        ValueError: if ``name`` is neither of the supported names.
    """
    for key, fn in (("relu", F.relu), ("gelu", F.gelu)):
        if name == key:
            return fn
    raise ValueError(f"Unsupported activation function: {name}")


def get_activation_module(name):
    """Build a fresh activation ``nn.Module`` by name.

    Args:
        name: ``"relu"`` or ``"gelu"``.

    Returns:
        A new ``nn.ReLU()`` or ``nn.GELU()`` instance (one per call).

    Raises:
        ValueError: if ``name`` is neither of the supported names.
    """
    for key, factory in (("relu", nn.ReLU), ("gelu", nn.GELU)):
        if name == key:
            return factory()
    raise ValueError(f"Unsupported activation function: {name}")


class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention' without using the math library.

    ``forward`` returns ``(output, p_attn)`` where ``p_attn`` is the softmax of
    ``query @ key^T / sqrt(d_k)`` over the last (key) axis and
    ``output = p_attn @ value``.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """
        Args:
            query, key, value: tensors whose last two axes are (length, d_k).
            mask: optional boolean tensor broadcastable to the score tensor;
                ``True`` marks positions that must NOT be attended to.
            dropout: optional callable applied to the attention probabilities.

        Returns:
            Tuple ``(output, p_attn)``.
        """
        d_k = query.size(-1)
        # Plain Python scalar: avoids allocating a device tensor on every call.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))

        p_attn = F.softmax(scores, dim=-1)

        if mask is not None:
            # A query row whose keys are ALL masked is all -inf, and softmax of
            # that is NaN. Such a row has nothing to attend to, so give it zero
            # attention (zero output) instead of propagating NaN.
            fully_masked = mask.expand_as(scores).all(dim=-1, keepdim=True)
            p_attn = p_attn.masked_fill(fully_masked, 0.0)

        if dropout is not None:
            p_attn = dropout(p_attn)

        output = torch.matmul(p_attn, value)
        return output, p_attn


class MultiHeadedAttention(nn.Module):
    """
    Multi-Headed Attention module without using built-in attention modules.
    Supports parameters: d_model, num_heads, dropout, batch_first.

    Query, key and value are each projected by their own ``nn.Linear``, split
    into ``num_heads`` heads of width ``d_model // num_heads``, passed through
    scaled dot-product attention, re-assembled and projected once more.
    """

    def __init__(self, d_model, num_heads, dropout=0.1, batch_first=True):
        """
        Args:
            d_model: model width; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            dropout: dropout probability applied to the attention weights.
            batch_first: if True inputs are (batch, seq, d_model), otherwise
                (seq, batch, d_model).
        """
        super(MultiHeadedAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.num_heads = num_heads
        self.batch_first = batch_first

        # One projection each for query, key and value (in that order).
        self.linear_layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(3)]
        )
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        """
        Args:
            query: (batch, tgt_len, d_model), or (tgt_len, batch, d_model) when
                ``batch_first`` is False.
            key, value: same layout as ``query``; their sequence length may
                differ from the query's (but must match each other).
            attn_mask: optional boolean (tgt_len, src_len) mask; True = blocked.
            key_padding_mask: optional boolean (batch, src_len) mask;
                True = padded key that must be ignored.

        Returns:
            Tensor with the same layout and shape as ``query``.
        """
        if not self.batch_first:
            # Work batch-first internally.
            query, key, value = (t.transpose(0, 1) for t in (query, key, value))

        batch_size, tgt_len, _ = query.size()
        # Keys/values carry their own length; reusing the query length here
        # would make the reshape fail whenever the two differ.
        src_len = key.size(1)

        # Linear projections, then split into heads:
        # (batch, len, d_model) -> (batch, num_heads, len, d_k)
        query, key, value = [
            linear(x)
            .view(batch_size, length, self.num_heads, self.d_k)
            .transpose(1, 2)
            for linear, x, length in zip(
                self.linear_layers,
                (query, key, value),
                (tgt_len, src_len, src_len),
            )
        ]

        # Build one boolean mask that broadcasts against the
        # (batch, num_heads, tgt_len, src_len) score tensor. Broadcasting makes
        # an explicit per-head expansion unnecessary.
        combined_mask = None
        if key_padding_mask is not None:
            # (batch, src_len) -> (batch, 1, 1, src_len)
            combined_mask = key_padding_mask[:, None, None, :]
        if attn_mask is not None:
            # (tgt_len, src_len) -> (1, 1, tgt_len, src_len)
            attn_mask = attn_mask[None, None, :, :]
            combined_mask = (
                attn_mask if combined_mask is None else combined_mask | attn_mask
            )

        # Apply attention (the attention weights themselves are not returned).
        x, _ = self.attention(
            query, key, value, mask=combined_mask, dropout=self.dropout
        )

        # Concatenate heads and apply final linear layer
        x = x.transpose(1, 2).contiguous().view(batch_size, tgt_len, self.d_model)
        x = self.output_linear(x)

        if not self.batch_first:
            x = x.transpose(0, 1)

        return x


class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder block.

    Self-attention followed by a two-layer feed-forward network; each
    sub-block is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(
        self,
        d_model,
        num_heads,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        batch_first=True,
    ):
        """
        Args:
            d_model: width of the input/output features.
            num_heads: number of attention heads.
            dim_feedforward: hidden width of the feed-forward network.
            dropout: dropout probability used throughout the block.
            activation: name understood by ``get_activation_function``.
            batch_first: forwarded to ``MultiHeadedAttention``.
        """
        super().__init__()

        # Attention sub-block
        self.self_attn = MultiHeadedAttention(
            d_model, num_heads, dropout=dropout, batch_first=batch_first
        )

        # Position-wise feed-forward sub-block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation_fn = get_activation_function(activation)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Residual plumbing: one norm + one dropout per sub-block
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Run the block on ``src`` and return a tensor of the same shape."""
        # Residual self-attention, then normalise.
        attended = self.self_attn(
            src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )
        src = self.norm1(src + self.dropout1(attended))

        # Residual feed-forward, then normalise.
        hidden = self.activation_fn(self.linear1(src))
        hidden = self.linear2(self.dropout(hidden))
        return self.norm2(src + self.dropout2(hidden))


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` independent deep copies of ``encoder_layer``.

    An optional ``norm`` module is applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # Deep copies so that every layer owns its own parameters.
        clones = [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Feed ``src`` through every layer in order, then through ``norm`` if set."""
        hidden = src
        for block in self.layers:
            hidden = block(
                hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )
        return hidden if self.norm is None else self.norm(hidden)

In [None]:
class SitsTransformer(nn.Module):
    """Transformer classifier over per-time-step feature sequences.

    Pipeline: per-step MLP embedding -> LayerNorm -> add day-of-year positional
    encoding -> Transformer encoder -> pooling over time -> MLP decoder that
    produces ``num_classes`` logits.
    """

    def __init__(
        self,
        input_dim=10,
        num_classes=13,
        d_model=128,
        n_head=16,
        n_layers=1,
        d_inner=128,
        activation="relu",
        dropout=0.2,
        max_len=366,
        max_seq_len=70,
        T=1000,
        max_temporal_shift=30,
    ):
        """
        Args:
            input_dim: number of features per time step.
            num_classes: number of output logits.
            d_model: embedding / encoder width.
            n_head: attention heads per encoder layer.
            n_layers: number of encoder layers.
            d_inner: feed-forward width inside each encoder layer.
            activation: activation name for the encoder feed-forward network.
            dropout: dropout probability (embeddings, encoder, pooling weights).
            max_len: largest day-of-year index; the positional table is sized
                ``max_len + 2 * max_temporal_shift``.
            max_seq_len: sequence length used for the bundled sample input.
            T: base of the sinusoidal positional encoding.
            max_temporal_shift: extra head-room added to the positional table.
        """
        super(SitsTransformer, self).__init__()
        self.modelname = self._get_name()
        self.max_seq_len = max_seq_len

        # Per-time-step embedding MLP: input_dim -> 32 -> 64 -> d_model
        self.mlp_dim = [input_dim, 32, 64, d_model]
        layers = []
        for i in range(len(self.mlp_dim) - 1):
            layers.append(LinLayer(self.mlp_dim[i], self.mlp_dim[i + 1]))
        self.mlp1 = nn.Sequential(*layers)

        self.inlayernorm = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)
        self.position_enc = PositionalEncoding(
            d_model, max_len=max_len + 2 * max_temporal_shift, T=T
        )

        encoder_layer = TransformerEncoderLayer(
            d_model, n_head, d_inner, dropout, activation, batch_first=True
        )
        encoder_norm = nn.LayerNorm(d_model)
        self.transformerencoder = TransformerEncoder(
            encoder_layer, n_layers, encoder_norm
        )

        # Classification head: d_model -> 64 -> 32 -> num_classes, with
        # BatchNorm + ReLU after every layer except the last.
        layers = []
        decoder = [d_model, 64, 32, num_classes]
        for i in range(len(decoder) - 1):
            layers.append(nn.Linear(decoder[i], decoder[i + 1]))
            if i < (len(decoder) - 2):
                layers.extend([nn.BatchNorm1d(decoder[i + 1]), nn.ReLU()])
        self.decoder = nn.Sequential(*layers)

        # Random example batch (size 2) and a tensor with the matching output
        # shape, used by the notebook's shape sanity check.
        self.input_sample = {
            "doy": torch.randint(1, max_len, (2, self.max_seq_len), dtype=torch.int64),
            "mask": torch.zeros((2, self.max_seq_len), dtype=torch.bool),
            "weight": torch.rand((2, self.max_seq_len), dtype=torch.float32),
            "x": torch.rand((2, self.max_seq_len, input_dim), dtype=torch.float32)
        }
        self.expected_output_sample = torch.rand((2, num_classes), dtype=torch.float32)

    def forward(self, input, is_bert=False):
        """
        Args:
            input: dict with
                ``"x"``      float (batch, seq, input_dim) features,
                ``"doy"``    integer (batch, seq) indices into the positional table,
                ``"mask"``   bool (batch, seq), True = padded step,
                ``"weight"`` float (batch, seq) pooling weights (only read when
                             ``is_bert`` is False).
            is_bert: if False, pool the encoder output with the normalised
                ``weight``; if True, max-pool over the time axis.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = input["x"]
        doy = input["doy"]
        mask = input["mask"]

        x = self.mlp1(x)

        x = self.inlayernorm(x)
        x = self.dropout(x + self.position_enc(doy))

        x = self.transformerencoder(x, src_key_padding_mask=mask)

        if not is_bert:
            weight = self.dropout(input["weight"])
            total = weight.sum(1, keepdim=True)
            # All weights zero (fully padded sample, or everything dropped):
            # divide by 1 so the pooled vector is zero instead of NaN.
            total = torch.where(total == 0, torch.ones_like(total), total)
            # Out-of-place: in eval mode dropout returns its input unchanged, so
            # an in-place division would overwrite the caller's tensor.
            weight = weight / total
            # squeeze(1) removes only the pooled time axis; a bare squeeze()
            # would also drop the batch axis when batch_size == 1.
            x = torch.bmm(weight.unsqueeze(1), x).squeeze(1)
        else:
            # NOTE(review): the max runs over every time step, including those
            # flagged in `mask` -- confirm this is intended.
            x, _ = torch.max(x, dim=1)

        logits = self.decoder(x)

        return logits


class PositionalEncoding(nn.Module):
    """Sinusoidal lookup table indexed by (day-of-year style) integer positions.

    The table has ``max_len + 1`` rows. Row 0 is all zeros; row ``i`` (for
    ``i >= 1``) holds the sine/cosine encoding of position ``i - 1``, with sines
    in the even columns and cosines in the odd columns.
    """

    def __init__(self, d_model: int, max_len: int = 5000, T: int = 10000):
        """
        Args:
            d_model: width of each encoding vector (even or odd).
            max_len: number of encoded positions; valid indices are 0..max_len.
            T: base of the geometric frequency progression.
        """
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(T) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len + 1, d_model)
        pe[1:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to fit (no-op for even d_model).
        pe[1:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, doy):
        """
        Args:
            doy: Tensor, shape [batch_size, seq_len]

        Returns:
            Tensor of shape [batch_size, seq_len, d_model] gathered from the table.
        """
        return self.pe[doy]


class LinLayer(nn.Module):
    """Linear projection followed by LayerNorm and ReLU."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)
        self.ln = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Return ``relu(layer_norm(linear(x)))``."""
        return F.relu(self.ln(self.lin(x)))

In [None]:
# Smoke test: a freshly built model, fed its own bundled sample batch, must
# return a tensor shaped like its bundled expected-output sample.

with torch.inference_mode():
    model = SitsTransformer()
    output = model(model.input_sample)
    assert model.expected_output_sample.shape == output.shape

In [None]:
def get_weight(x):
    """Compute per-time-step pooling weights from raw band values.

    Each row of ``x`` is one time step. Columns 0-2 are read as the visible
    (RGB) bands, column 2 as red, column 6 as NIR and columns 8-9 as SWIR1/2.
    The weight of a step is ``exp(NDVI)``, where steps flagged as cloudy or
    dark are forced to ``NDVI = -1``. All-zero (padding) rows get weight 0 and
    the remaining weights are normalised to sum to 1.

    Args:
        x: array of shape (time, bands), un-normalised band values.

    Returns:
        1-D float array of length ``time`` (at least float32 precision). If
        every row is padding, all weights are 0.
    """
    # Promote to at least float32. In float16 the 1e-8 epsilon below underflows
    # to 0, so a row with nir + red == 0 would give 0/0 = NaN and, through the
    # normalisation, turn every weight of the sequence into NaN.
    x = np.asarray(x)
    x = x.astype(np.result_type(x.dtype, np.float32), copy=False)

    all_zero_mask = np.all(x == 0, axis=1)

    # Brightness score from the visible bands, capped at 1.
    score = np.minimum(1.0, (x[:, [0, 1, 2]].sum(1) - 0.2) / 0.6)  # rgb
    cloud = score * 100 > 20
    dark = x[:, [6, 8, 9]].sum(1) < 0.35  # NIR, SWIR1, SWIR2

    nir, red = x[:, 6], x[:, 2]
    ndvi = (nir - red) / (nir + red + 1e-8)
    ndvi[cloud | dark] = -1
    ndvi = ndvi.clip(-1, 1)

    weight = np.exp(ndvi)
    # Drop padding BEFORE normalising so the valid steps sum to exactly 1.
    weight[all_zero_mask] = 0
    total = weight.sum()
    if total > 0:
        weight /= total

    return weight

In [None]:
def normalize(x, mean, std):
    """Standardise ``x`` as ``(x - mean) / std`` while keeping padding rows at zero.

    A row counts as padding when every entry is exactly 0; those rows are
    written back as zeros after standardisation. The input is not modified.
    """
    padded_rows = (x == 0).all(axis=1)

    standardized = (x - mean) / std
    standardized[padded_rows] = 0

    return standardized

In [None]:
class SitsDataset(torch.utils.data.Dataset):
    """Dense per-sample time series built from a long-format dataframe.

    NOTE: this notebook defines ``SitsDataset`` again in a following cell; after
    "Run All" the later definition is the one in effect.

    The dataframe needs the columns ``id``, ``time``, ``doy``, ``label`` and the
    ten band columns listed in ``BANDS``. Every distinct ``id`` becomes one
    sample of shape (max_seq_len, 10); steps that never appear stay all-zero
    and are reported as padding.
    """

    BANDS = ["blue", "green", "red", "red_edge_1", "red_edge_2",
             "red_edge_3", "nir", "red_edge_4", "swir_1", "swir_2"]

    def __init__(self, dataframe, mean=None, std=None, max_seq_len=70):
        """
        Args:
            dataframe: long-format observations, one row per (id, time).
            mean: optional per-band means; defaults to the built-in statistics.
            std: optional per-band standard deviations; defaults likewise.
            max_seq_len: number of time slots per sample; every ``time`` value
                must lie in ``[0, max_seq_len)``.

        Raises:
            IndexError: if a ``time`` value falls outside ``[0, max_seq_len)``.
        """
        if mean is None:
            mean = [[0.0656, 0.0948, 0.1094, 0.1507, 0.2372, 0.2673, 0.2866, 0.2946, 0.2679, 0.1985]]
        if std is None:
            std = [0.036289, 0.043310, 0.064736, 0.057953, 0.074167, 0.096407, 0.097816, 0.098368, 0.089847, 0.097866]
        self.mean = np.asarray(mean, dtype=np.half)
        self.std = np.asarray(std, dtype=np.half)

        # Map every id to a dense row index. Using `id - id.min()` instead only
        # works for contiguous ids and breaks on any subset of ids.
        unique_ids, sample_idx = np.unique(dataframe["id"].to_numpy(), return_inverse=True)
        times = dataframe["time"].to_numpy().astype(int)
        if times.size and (times.min() < 0 or times.max() >= max_seq_len):
            raise IndexError(f"time values must lie in [0, {max_seq_len})")

        self.xs = np.zeros((len(unique_ids), max_seq_len, len(self.BANDS)), dtype=np.half)
        self.doys = np.zeros((len(unique_ids), max_seq_len), np.int16)
        # Vectorised scatter instead of a Python loop over dataframe rows.
        self.xs[sample_idx, times] = dataframe[self.BANDS].to_numpy()
        self.doys[sample_idx, times] = dataframe["doy"].to_numpy()

        # groupby sorts by id, which matches the ordering of np.unique above.
        self.ys = dataframe[["id", "label"]].groupby("id").first().label.to_numpy()

    def __len__(self):
        """Number of samples (distinct ids)."""
        return self.ys.shape[0]

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` for sample ``idx``.

        ``inputs`` holds ``doy`` (long), ``mask`` (bool, True = padded step),
        ``x`` (normalised float features) and ``weight`` (float pooling weights).
        """
        x = self.xs[idx]

        return {
            "doy": torch.from_numpy(self.doys[idx]).long(),
            # Same padding definition as normalize() / get_weight().
            "mask": torch.from_numpy(np.all(x == 0, axis=1)),
            "x": torch.from_numpy(normalize(x, self.mean, self.std)).float(),
            "weight": torch.from_numpy(get_weight(x)).float(),
        }, torch.tensor(self.ys[idx], dtype=torch.long)

In [None]:
class SitsDataset(torch.utils.data.Dataset):
    """Dense per-sample time series built from a long-format dataframe.

    The dataframe needs the columns ``id``, ``time``, ``doy``, ``label`` and
    the ten band columns (blue ... swir_2). Every distinct ``id`` becomes one
    sample of shape (max_seq_len, 10); time slots that never appear stay
    all-zero and are reported as padding by ``__getitem__``.
    """

    def __init__(self, dataframe, mean=None, std=None, max_seq_len=70):
        """
        Args:
            dataframe: long-format observations, one row per (id, time).
            mean: optional per-band means used for normalisation; defaults to
                the built-in statistics below.
            std: optional per-band standard deviations; defaults likewise.
            max_seq_len: number of time slots per sample; every ``time`` value
                must lie in ``[0, max_seq_len)``.

        Raises:
            IndexError: if a ``time`` value falls outside ``[0, max_seq_len)``.
        """
        default_mean = [[0.0656, 0.0948, 0.1094, 0.1507, 0.2372,
                         0.2673, 0.2866, 0.2946, 0.2679, 0.1985]]
        default_std = [0.036289, 0.043310, 0.064736, 0.057953, 0.074167,
                       0.096407, 0.097816, 0.098368, 0.089847, 0.097866]
        # Honour caller-supplied statistics instead of silently ignoring them.
        self.mean = np.asarray(default_mean if mean is None else mean,
                               dtype=np.half)
        self.std = np.asarray(default_std if std is None else std,
                              dtype=np.half)

        bands = ['blue', 'green', 'red', 'red_edge_1', 'red_edge_2',
                 'red_edge_3', 'nir', 'red_edge_4', 'swir_1', 'swir_2']

        # Sort the dataframe by 'id' and 'time'
        dataframe = dataframe.sort_values(['id', 'time'])

        # Extract numpy arrays from dataframe columns
        ids = dataframe['id'].to_numpy()
        times = dataframe['time'].astype(int).to_numpy()
        doys = dataframe['doy'].to_numpy()
        bands_data = dataframe[bands].to_numpy()

        # Negative indices would wrap around silently in the scatter below.
        if times.size and (times.min() < 0 or times.max() >= max_seq_len):
            raise IndexError(f"time values must lie in [0, {max_seq_len})")

        # Map unique ids to indices
        unique_ids, id_indices = np.unique(ids, return_inverse=True)
        num_ids = len(unique_ids)

        # Initialize arrays (all-zero rows double as padding)
        self.xs = np.zeros((num_ids, max_seq_len, len(bands)), dtype=np.half)
        self.doys = np.zeros((num_ids, max_seq_len), dtype=np.int16)

        # Assign values using advanced indexing
        self.xs[id_indices, times, :] = bands_data
        self.doys[id_indices, times] = doys

        # Extract labels: first row per id, ordered like unique_ids
        labels_df = dataframe[['id', 'label']].drop_duplicates('id').set_index('id')
        self.ys = labels_df.loc[unique_ids, 'label'].to_numpy()

    def __len__(self):
        """Number of samples (distinct ids)."""
        return self.ys.shape[0]

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` for sample ``idx``.

        ``inputs`` holds ``doy`` (long), ``mask`` (bool, True = padded step),
        ``x`` (normalised float features) and ``weight`` (float pooling weights).
        """
        x = self.xs[idx]
        return {
            "doy": torch.from_numpy(self.doys[idx]).long(),
            # Same padding definition as normalize() / get_weight().
            "mask": torch.from_numpy(np.all(x == 0, axis=1)),
            "x": torch.from_numpy(normalize(x, self.mean, self.std)).float(),
            "weight": torch.from_numpy(get_weight(x)).float(),
        }, torch.tensor(self.ys[idx], dtype=torch.long)

In [None]:
# Load the full table and split it on the stored `use_bert` flag:
# rows with use_bert == 0 go to training, rows with use_bert == 1 to validation.
whole_df = pd.read_parquet("data/california_sits_bert_original.parquet")
train_df = whole_df.loc[whole_df["use_bert"].eq(0)].reset_index(drop=True)
val_df = whole_df.loc[whole_df["use_bert"].eq(1)].reset_index(drop=True)

train_dataset, val_dataset = SitsDataset(train_df), SitsDataset(val_df)

In [None]:
# Split the whole dataframe 80/20 by unique sample id, so that all rows of a
# sample end up on the same side. The permutation is seeded: without a seed the
# validation set changes on every run and metrics are not comparable.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

ids = np.random.default_rng(SPLIT_SEED).permutation(whole_df.id.unique())
n_train = int(len(ids) * TRAIN_FRACTION)
train_ids = ids[:n_train]
val_ids = ids[n_train:]

train_df = whole_df[whole_df.id.isin(train_ids)].reset_index(drop=True)
val_df = whole_df[whole_df.id.isin(val_ids)].reset_index(drop=True)

train_dataset = SitsDataset(train_df)
val_dataset = SitsDataset(val_df)

In [None]:
# The datasets hold their own dense arrays; drop the dataframe references.
del whole_df, train_df, val_df

In [None]:
BATCH_SIZE = 512

# Training batches are reshuffled every epoch; without shuffling the model sees
# the samples in the same order (the dataset's storage order) on every epoch.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
# Validation order does not affect the metrics, so keep it deterministic.
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
class SitsLightningModel(pl.LightningModule):
    """Lightning wrapper that trains ``SitsTransformer`` with cross-entropy.

    Macro-averaged F1 is tracked for validation and test and reported once per
    epoch over the whole split.
    """

    def __init__(self, num_classes=13, lr=1e-3):
        """
        Args:
            num_classes: number of target classes (model head and F1 metrics).
            lr: learning rate for AdamW.
        """
        super().__init__()
        self.lr = lr
        self.model = SitsTransformer(num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()

        # Macro F1 for validation and test (one metric object per split so
        # their accumulated states never mix).
        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average='macro')
        self.test_f1 = MulticlassF1Score(num_classes=num_classes, average='macro')

    def forward(self, batch):
        """Run the wrapped model with ``is_bert=True`` (max-pooling over time)."""
        return self.model(batch, True)

    @staticmethod
    def _dump_nan_context(outputs, targets, inputs):
        """Print the tensors of the offending batch to help locate a NaN."""
        sections = (
            ("outputs", outputs),
            ("\ntargets", targets),
            ("\nxs", inputs["x"]),
            ("\ndoys", inputs["doy"]),
            ("\nmasks", inputs["mask"]),
        )
        for title, tensor in sections:
            print(title)
            print(tensor)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; abort on NaN outputs.

        Raises:
            RuntimeError: if the model output contains NaN.
        """
        inputs, targets = batch
        outputs = self(inputs)

        # Fail fast (with context) rather than training on NaN.
        if torch.isnan(outputs).any():
            self._dump_nan_context(outputs, targets, inputs)
            raise RuntimeError("NaN detected in training output")

        loss = self.criterion(outputs, targets)

        self.log("train_loss", loss, prog_bar=True)

        return loss

    def _shared_eval_step(self, batch, metric, prefix, prog_bar):
        """Loss + F1 bookkeeping shared by validation and test."""
        inputs, targets = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, targets)

        preds = torch.argmax(outputs, dim=1)
        metric(preds, targets)

        self.log(f"{prefix}_loss", loss, prog_bar=prog_bar)
        # Log the metric OBJECT: Lightning then computes F1 over the whole
        # epoch and resets the state afterwards. Logging the per-batch value
        # would report the mean of batch-level macro-F1 scores, which is not
        # the macro-F1 of the split, and would never reset the metric.
        self.log(f"{prefix}_f1", metric, on_step=False, on_epoch=True,
                 prog_bar=prog_bar)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and epoch-level ``val_f1``."""
        return self._shared_eval_step(batch, self.val_f1, "val", prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and epoch-level ``test_f1``."""
        return self._shared_eval_step(batch, self.test_f1, "test", prog_bar=False)

    def configure_optimizers(self):
        """AdamW over all parameters with the configured learning rate."""
        optimizer = AdamW(self.parameters(), lr=self.lr)
        return optimizer

In [None]:
# Train for at most 150 epochs, validating on the held-out loader.
trainer = pl.Trainer(max_epochs=150)

# Lightning module wrapping the SITS transformer
model = SitsLightningModel()

# Fit: (module, training loader, validation loader)
trainer.fit(model, train_dataloader, val_dataloader)