In [None]:
from sits_siam.utils import SitsDataset
from sits_siam.augment import AddNDVIWeights, RandomChanSwapping, RandomChanRemoval, RandomAddNoise, RandomTempSwapping, RandomTempShift, RandomTempRemoval, AddMissingMask, Normalize

import random
import numpy as np
import matplotlib.pyplot as plt

import pandas as pd

# Print numpy arrays with 4 decimal places and without scientific notation,
# so values inspected in the notebook are easier to compare by eye.
np.set_printoptions(precision=4, suppress=True)

In [None]:
# Load the time-series table and keep only the rows whose label is 5,
# renumbering the index from zero so dataset indices are contiguous.
df = pd.read_parquet("data/california_sits_bert_original.parquet")
keep_rows = df["label"] == 5
df = df[keep_rows].reset_index(drop=True)

# Augmented pipeline: noise plus the temporal swap/shift/removal transforms
# are active; the commented entries are toggles that are currently off.
aug_transforms = [
    AddNDVIWeights(),
    # RandomChanSwapping(),
    # RandomChanRemoval(),
    RandomAddNoise(0.02),
    RandomTempSwapping(max_distance=3),
    RandomTempShift(),
    RandomTempRemoval(),
    AddMissingMask(),
    # Normalize()
]
aug_dataset = SitsDataset(df, max_seq_len=45, transform=aug_transforms)

# Reference pipeline over the same rows: only AddNDVIWeights and
# AddMissingMask are applied, every Random* transform is switched off.
common_transforms = [
    AddNDVIWeights(),
    # RandomChanSwapping(),
    # RandomChanRemoval(),
    # RandomAddNoise(),
    # RandomTempSwapping(),
    # RandomTempShift(),
    # RandomTempRemoval(),
    AddMissingMask(),
    # Normalize()
]
common_dataset = SitsDataset(df, max_seq_len=45, transform=common_transforms)

In [None]:
def plot_single_sample_lines(sample, ax=None):
    """Draw one line per channel of a dataset sample against its ``doy`` values.

    Args:
        sample: Mapping with ``'x'``, ``'doy'`` and ``'mask'`` entries, each
            exposing ``.numpy()``. ``x`` is indexed as ``x[:, channel]``, so it
            must be 2-D with time steps on axis 0 and channels on axis 1.
            ``mask`` must be boolean; time steps where it is True are dropped
            before plotting.
        ax: Matplotlib axes to draw on. When omitted, a new 10x5 figure is
            created.

    Up to ten channels are drawn, each in a fixed colour by column position.
    Samples with fewer than ten channels plot only the channels they have
    instead of failing with an ``IndexError``.
    """
    # Fixed colour per channel position (column 0 -> 'blue', column 1 -> 'green', ...).
    channel_colors = (
        'blue', 'green', 'red', 'orange', 'purple',
        'brown', 'pink', 'gray', 'olive', 'cyan',
    )

    x = sample['x'].numpy()
    doy = sample['doy'].numpy()
    mask = sample['mask'].numpy()

    # Keep only the time steps that are NOT flagged by the mask.
    keep = ~mask
    x = x[keep]
    doy = doy[keep]

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 5))

    # Never index past the channels actually present in the sample.
    n_channels = min(x.shape[1], len(channel_colors))
    for channel in range(n_channels):
        ax.plot(doy, x[:, channel], color=channel_colors[channel])

    # Fixed axes so plots of different samples are directly comparable.
    ax.set_xlim(0, 366)
    ax.set_ylim(-0.1, 1)
    ax.grid(True)

In [None]:
# Index of the sample to inspect in both datasets.
sample_id = 250

# Two panels side by side: the reference sample on the left and the same
# index drawn from the augmented dataset on the right.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))

for panel, dataset in zip(ax, (common_dataset, aug_dataset)):
    plot_single_sample_lines(dataset[sample_id], panel)

# The title shows the sample's target value, read from the reference dataset.
label = common_dataset[sample_id]["y"].numpy()
fig.suptitle("Label = " + str(label))
plt.show()

In [None]:
_, ax = plt.subplots(nrows=1, ncols=2, figsize=(20, 5))

# Left panel shows the reference dataset, right panel the augmented one.
panels = ((ax[0], common_dataset), (ax[1], aug_dataset))

# One horizontal row of dots per sample index: x = doy, y = sample index.
for y in range(100):
    for panel, dataset in panels:
        sample = dataset[y]

        # Only time steps that are not flagged by the mask are drawn.
        valid = ~sample['mask'].numpy()
        doys = sample['doy'].numpy()
        rows = np.full(sample['doy'].shape[0], y)

        # Channels 2, 1, 0 are passed as per-point RGB colours (presumably the
        # red/green/blue bands -- confirm against the dataset's channel order).
        # Values are scaled by pi and clipped into [0, 1] before use.
        rgb = sample['x'].numpy()[:, [2, 1, 0]] * np.pi
        rgb = np.clip(rgb, 0, 1)

        panel.scatter(doys[valid], rows[valid], c=rgb[valid])

plt.show()