In [1]:
import lmdb
from bigearthnet_patch_interface.s2_interface import BigEarthNet_S2_Patch
from pathlib import Path
import importlib

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import Tensor, cat, stack

from torch.utils.data import DataLoader, ConcatDataset

In [2]:
import os
from os.path import join, dirname
from dotenv import load_dotenv

# Get paths
load_dotenv('./.env')

BEN_LMDB_PATH = os.environ.get("BEN_LMDB_PATH")
BEN_PATH = os.environ.get("BEN_PATH")

TRAIN_CSV_FILE = os.environ.get("TRAIN_CSV")
TEST_CSV_FILE = os.environ.get("TEST_CSV")
VAL_CSV_FILE = os.environ.get("VAL_CSV")

assert Path(BEN_PATH).exists()
assert Path(BEN_LMDB_PATH).exists()
assert os.path.exists(TRAIN_CSV_FILE)

In [3]:
env = lmdb.open(BEN_LMDB_PATH, readonly=True, readahead=False, lock=False)
txn = env.begin()
cur = txn.cursor()

In [12]:
import dataset_class as ds_class
importlib.reload(ds_class)

val_ds = ds_class.BenDataset(VAL_CSV_FILE, BEN_LMDB_PATH)
test_ds = ds_class.BenDataset(TEST_CSV_FILE, BEN_LMDB_PATH)
train_ds = ds_class.BenDataset(TRAIN_CSV_FILE, BEN_LMDB_PATH)

ds = ConcatDataset([val_ds, test_ds, train_ds])
ds_loader = DataLoader(ds, batch_size=1, shuffle=False)

print("Dataset size:", len(ds_loader))

Dataset size: 14714


#### Plot class distribution

In [17]:
# for X, y in [next(iter(ds_loader))]:
    # print(X.shape)
    # print(y.shape)

# _, y = next(iter(ds_loader))

class_dist = torch.zeros([19], dtype=torch.uint8)
print(class_dist)
cnt = 0

for _, y in ds_loader:
    cnt += 1
    
    if cnt % 1000 == 0:
        print(cnt)
        print(class_dist)
    
    class_dist = torch.add(class_dist, y)
    
print(class_dist)
# tensor([[   0, 7827,    2, 7840,    0, 4846,  447,  395, 1230,  363, 5378,    0,
#           811,   12, 1713, 1703,  289, 5128, 2229]])

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       dtype=torch.uint8)
1000
tensor([[  0, 944,   0, 173,   0, 180,   0,  49, 110,  52, 118,   0,   3,   0,
           9,  96,  11, 166, 181]])
2000
tensor([[   0, 1265,    0,  829,    0,  524,   50,   74,  151,   63,  486,    0,
          114,    0,  212,  248,   23,  619,  287]])
3000
tensor([[   0, 1849,    0, 1374,    0,  945,  102,   93,  233,   83, 1005,    0,
          177,    2,  334,  352,   37, 1026,  496]])
4000
tensor([[   0, 2337,    0, 1959,    0, 1262,  116,  115,  326,  124, 1304,    0,
          197,    2,  388,  485,   43, 1339,  609]])
5000
tensor([[   0, 3087,    0, 2239,    0, 1571,  124,  172,  426,  151, 1548,    0,
          222,    2,  398,  569,   63, 1500,  802]])
6000
tensor([[   0, 3378,    0, 2951,    0, 1866,  203,  178,  486,  164, 1948,    0,
          355,    5,  720,  709,   73, 2067,  891]])
7000
tensor([[   0, 3723,    0, 3721,    0, 2315,  218,  189,  544,  170, 2556,    0,
       