In [1]:
import torch
import numpy as np
import pandas as pd
import os

os.chdir("../")

%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

data_root = Path(os.getenv("PASTIS_DATA_DIR", "")) # Path to PASTIS data directory
assert data_root.is_dir(), f"PASTIS data directory not found at {data_root}"

In [9]:
from src.data.components.pastis import PASTISSubpatchedDatasetS2

dataset = PASTISSubpatchedDatasetS2(data_root, 3, with_datetime=True, metadata=[4])
len(dataset)

7343

In [13]:
from src.data.functional import unique


y = dataset[0]["target"]

classes, counts = unique(y, return_counts=True)
classes, counts


(array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
         92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104,
        105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117,
        118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130,
        131]),
 array([18044,  2375,   769,  3945,  2307,    66,   764,    23,    69,
          165,   528,   284,  4529,   782,   151,   173,   204,     2,
            2,   860,     3,    22,    38,    28,    29,    40,    19,
          134,   288,     2,    75,   545,

In [8]:
from src.data.few_shot_datamodule import sample_k_per_class, get_class_indices


class_indices = get_class_indices(dataset)
class_indices

Getting class indices:   0%|          | 0/37411 [00:00<?, ?it/s]

Getting class indices: 100%|██████████| 37411/37411 [00:05<00:00, 6419.93it/s]


defaultdict(list,
            {np.int64(0): [0,
              1,
              2,
              3,
              4,
              5,
              6,
              7,
              8,
              9,
              10,
              11,
              12,
              13,
              14,
              15,
              16,
              17,
              18,
              19,
              20,
              21,
              22,
              23,
              24,
              25,
              26,
              27,
              28,
              29,
              30,
              31,
              32,
              33,
              34,
              35,
              36,
              37,
              38,
              39,
              40,
              41,
              42,
              43,
              44,
              45,
              46,
              47,
              48,
              49,
              50,
              51,
              52,
              53,
       

In [38]:
torch.manual_seed(42)

sampled = sample_k_per_class(class_indices, 5, exclusive=False)
sampled

{np.int64(0): [1953, 18978, 19710, 33937, 18715],
 np.int64(1): [27044, 3919, 10033, 554, 32266],
 np.int64(2): [8821, 5958, 25734, 14925, 23522],
 np.int64(3): [30986, 2939, 27089, 5057, 35267],
 np.int64(4): [33467, 32753, 34866, 31522, 23284],
 np.int64(5): [3020, 20825, 29958, 19719, 15607],
 np.int64(6): [36576, 12998, 9058, 7951, 3683],
 np.int64(7): [10786, 19824, 28695, 15704, 20814],
 np.int64(8): [29169, 28149, 23811, 16536, 28601],
 np.int64(9): [16728, 30588, 21543, 19630, 21022],
 np.int64(10): [8727, 4426, 21344, 29554, 8956],
 np.int64(11): [7588, 34433, 15929, 31269, 29722],
 np.int64(12): [3748, 34183, 30643, 24157, 14058],
 np.int64(13): [6653, 13088, 4578, 12113, 1711],
 np.int64(14): [13281, 12495, 16198, 14005, 4885],
 np.int64(15): [37240, 30724, 7104, 23728, 34371],
 np.int64(16): [4149, 33185, 4884, 37111, 36165],
 np.int64(17): [22887, 34721, 16003, 9417, 12573],
 np.int64(18): [25341, 1367, 24942, 2843, 29184],
 np.int64(19): [4245, 25563, 23451, 19883, 9049],

In [39]:
unique_indexes = set()
for cls, indices in sampled.items():
    unique_indexes.update(indices)

unique_indexes

{34816,
 1,
 23202,
 30724,
 14341,
 4,
 14342,
 30938,
 20493,
 16,
 10257,
 36884,
 30741,
 20500,
 28695,
 10268,
 12318,
 10270,
 10272,
 6178,
 12324,
 38,
 22569,
 18479,
 47,
 34866,
 4149,
 32826,
 28731,
 32828,
 32833,
 12357,
 26695,
 14409,
 20554,
 12362,
 12364,
 34892,
 10314,
 34894,
 30800,
 32849,
 34897,
 10321,
 34900,
 10325,
 4182,
 32853,
 22616,
 10328,
 30809,
 10329,
 10327,
 30810,
 34902,
 32865,
 16483,
 32869,
 30823,
 30824,
 30825,
 30826,
 30827,
 6252,
 32877,
 30828,
 30830,
 32880,
 30831,
 30832,
 30829,
 16500,
 30834,
 30833,
 10362,
 16507,
 30848,
 30849,
 31776,
 8323,
 30852,
 4230,
 30855,
 10377,
 4233,
 8330,
 20620,
 10379,
 30857,
 8335,
 31779,
 14483,
 4245,
 31780,
 30871,
 16536,
 34630,
 34972,
 16540,
 8349,
 31782,
 30884,
 35447,
 22697,
 35448,
 37078,
 35450,
 30903,
 37080,
 30914,
 30919,
 18638,
 12495,
 18641,
 2257,
 37075,
 37076,
 30932,
 30934,
 37079,
 30936,
 12505,
 32986,
 30939,
 30937,
 37083,
 37086,
 30942,
 8416

In [65]:
from src.data.transforms import TemporalFeatureExtraction

temp = TemporalFeatureExtraction()

In [12]:
import pandas as pd

sample = dataset[0]
temp.transform(sample["dates"])

array([[267.],
       [272.],
       [277.]], dtype=float32)

In [35]:
from src.data.base_datamodule import BaseDataModule

base = BaseDataModule(train=dataset, batch_size=16)
base

<src.data.base_datamodule.BaseDataModule at 0x7ae94ed98050>

In [38]:
base.setup("train")

In [39]:
bx = base.train_dataloader()
bx

<torch.utils.data.dataloader.DataLoader at 0x7aeac1f33a10>

In [42]:
data = next(iter(bx))
data["dates"]

tensor([[20180924, 20180929, 20181123, 20190117, 20190122, 20190216, 20190221,
         20190226, 20190313, 20190328, 20190412, 20190422, 20181004, 20190427,
         20190512, 20190527, 20190601, 20190606, 20190621, 20190626, 20190701,
         20190706, 20190711, 20181009, 20190716, 20190731, 20190810, 20190815,
         20190820, 20190825, 20190830, 20190904, 20190914, 20190919, 20181019,
         20190924, 20191009, 20191019, 20181024, 20181029, 20181108, 20181113,
         20181118],
        [20180924, 20180929, 20181123, 20190117, 20190122, 20190216, 20190221,
         20190226, 20190313, 20190328, 20190412, 20190422, 20181004, 20190427,
         20190512, 20190527, 20190601, 20190606, 20190621, 20190626, 20190701,
         20190706, 20190711, 20181009, 20190716, 20190731, 20190810, 20190815,
         20190820, 20190825, 20190830, 20190904, 20190914, 20190919, 20181019,
         20190924, 20191009, 20191019, 20181024, 20181029, 20181108, 20181113,
         20181118],
        [201