In [None]:
import math
import pickle
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import lmdb
from tqdm import tqdm
import pdb
import matplotlib.pyplot as plt
import collections
collections.Iterable = collections.abc.Iterable
import kornia.augmentation as K

from cvtorchvision import cvtransforms
import numpy as np
import torch
import random
from PIL import ImageFilter
import random
import cv2
from argparse import Namespace
from pretrain_ssl.datasets.SSL4EO.ssl4eo_dataset_lmdb import LMDBDataset
from pretrain_ssl.datasets.SSL4EO.ssl4eo_dataset import SSL4EO

### Band statistics: per-band mean & std
# Calculated from a 50k subset.
# Each *_MEAN / *_STD pair has one entry per band. Presumably the entries follow the
# order of the matching ALL_BANDS_* list below (lengths agree: 2, 12 and 13) and are
# meant as the `means` / `stds` arguments of normalize_img / denormalize_img -- TODO confirm.

# Sentinel-1 (2 values).
S1_MEAN = [-12.54847273, -20.19237134]
S1_STD = [5.25697717, 5.91150917]

# Sentinel-2 L2A (12 values).
S2A_MEAN = [752.40087073, 884.29673756, 1144.16202635, 1297.47289228, 1624.90992062, 2194.6423161, 2422.21248945, 2517.76053101, 2581.64687018, 2645.51888987, 2368.51236873, 1805.06846033]
S2A_STD = [1108.02887453, 1155.15170768, 1183.6292542, 1368.11351514, 1370.265037, 1355.55390699, 1416.51487101, 1474.78900051, 1439.3086061, 1582.28010962, 1455.52084939, 1343.48379601]

# Sentinel-2 L1C (13 values).
S2C_MEAN = [1605.57504906, 1390.78157673, 1314.8729939, 1363.52445545, 1549.44374991, 2091.74883118, 2371.7172463, 2299.90463006, 2560.29504086, 830.06605044, 22.10351321, 2177.07172323, 1524.06546312]
S2C_STD = [786.78685367, 850.34818441, 875.06484736, 1138.84957046, 1122.17775652, 1161.59187054, 1274.39184232, 1248.42891965, 1345.52684884, 577.31607053, 51.15431158, 1336.09932639, 1136.53823676]

# Band names per product. The L1C list has 13 entries (it includes 'B10'); the L2A list has 12.
ALL_BANDS_S2_L2A = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12']
ALL_BANDS_S2_L1C = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B10', 'B11', 'B12']
# B4, B3, B2 sit at indices 3, 2, 1 of both Sentinel-2 lists above.
RGB_BANDS = ['B4', 'B3', 'B2']
ALL_BANDS_S1_GRD = ['VV', 'VH']

def normalize(img, mean, std):
    """
    Rescale one band to uint8 using a +/- 2 standard deviation window.

    Values at `mean - 2*std` map to 0 and values at `mean + 2*std` map to 255;
    everything outside that window is clipped before the cast to uint8.

    Args:
        img: single-channel Numpy array [H, W] (unbounded float32).
        mean: scalar mean of the band.
        std: scalar standard deviation of the band.

    Returns:
        Numpy uint8 array with the same shape as `img`, values in [0, 255].
    """
    lower = mean - 2 * std
    upper = mean + 2 * std
    window = upper - lower
    scaled = (img - lower) / window * 255.0
    clipped = np.clip(scaled, 0, 255)
    return clipped.astype(np.uint8)


def normalize_img(img, means, stds):
    """
    Normalize a Numpy array of shape [C,H,W] band by band.

    For each channel c, values at `means[c] - 2*stds[c]` map to 0 and values at
    `means[c] + 2*stds[c]` map to 255; everything outside is clipped.

    Args:
        img: Numpy array [C, H, W] (unbounded float32).
        means: per-channel means, length C. Any array-like is accepted
            (list, tuple, Numpy array).
        stds: per-channel standard deviations, length C. Any array-like.

    Returns:
        Numpy uint8 array with the same shape as `img`, values in [0, 255].
    """
    # np.asarray accepts lists, tuples and arrays alike. Checking only for `list`
    # would let a tuple fall through to tuple arithmetic and raise a TypeError.
    means = np.asarray(means)
    stds = np.asarray(stds)

    min_values = (means - 2 * stds)[:, None, None]  # add extra dimensions to make broadcasting work
    max_values = (means + 2 * stds)[:, None, None]
    img = (img - min_values) / (max_values - min_values) * 255.0
    img = np.clip(img, 0, 255).astype(np.uint8)
    return img


def denormalize_img(img, means, stds):
    """
    Converts a normalized uint8 Numpy array in the range [0, 255] into a
    denormalized float32 Numpy array, with raw intensity values from Sentinel
    (divided by 10000 and clipped to [0, 1]).

    The input is assumed to have been normalized by the 'normalize' method for each
    channel, with the given mean/std.

    Args:
        img: Numpy uint8 array [C, H, W] with values in [0, 255].
        means: per-channel means, length C. Any array-like is accepted
            (list, tuple, Numpy array).
        stds: per-channel standard deviations, length C. Any array-like.

    Returns:
        Numpy float32 array [C, H, W] with values in [0, 1].
    """
    # np.asarray accepts lists, tuples and arrays alike (a `type(...) is list`
    # check would leave tuples unconverted and break the arithmetic below).
    means = np.asarray(means)
    stds = np.asarray(stds)

    min_values = (means - 2 * stds)[:, None, None]  # add extra dimensions to make broadcasting work
    max_values = (means + 2 * stds)[:, None, None]
    denormalized_img = (img.astype(np.float32) / 255.0) * (max_values - min_values) + min_values
    denormalized_img = np.clip(denormalized_img / 10000.0, 0, 1)
    # The float64 statistics promote the arithmetic above to float64; cast back so
    # the result really has the documented float32 dtype.
    return denormalized_img.astype(np.float32)

## Image Transforms

In [None]:
# if args.dtype=='uint8':
#     from models.rs_transforms_uint8 import RandomChannelDrop,RandomBrightness,RandomContrast,ToGray
# else:
from pretrain_ssl.models.rs_transforms_float32 import RandomChannelDrop,RandomBrightness,RandomContrast,ToGray
    
class TwoCropsTransform:
    """Take two random crops of one image as the query and key.

    The input `x` is indexed as `x[season, :, :, :]` and each selected slice is
    transposed with axes (1, 2, 0) before being handed to `base_transform`.

    `season` controls which of the four season slices (indices 0-3) are used:
      - 'augment': query and key each draw their own random season.
      - 'fixed':   query and key share one season that is the same on every call.
      - 'random':  query and key share one randomly drawn season.
    """

    def __init__(self, base_transform, season='fixed'):
        self.base_transform = base_transform
        self.season = season

    def _choose_seasons(self):
        """Return the (query, key) season indices for the configured mode."""
        seasons = [0, 1, 2, 3]
        if self.season == 'augment':
            first = np.random.choice(seasons)
            second = np.random.choice(seasons)
            return first, second
        if self.season == 'fixed':
            # A private RandomState(42) yields the same draw as
            # `np.random.seed(42); np.random.choice(...)`, but without reseeding the
            # global NumPy generator on every call (which would make every other
            # np.random-based draw in the process repeat itself).
            fixed = np.random.RandomState(42).choice(seasons)
            return fixed, fixed
        if self.season == 'random':
            shared = np.random.choice(seasons)
            return shared, shared
        raise ValueError(
            "season must be 'augment', 'fixed' or 'random', got {!r}".format(self.season))

    def __call__(self, x):
        season1, season2 = self._choose_seasons()

        x1 = np.transpose(x[season1,:,:,:],(1,2,0))
        x2 = np.transpose(x[season2,:,:,:],(1,2,0))

        q = self.base_transform(x1)
        k = self.base_transform(x2)

        return [q, k]

class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    `sigma` is a two-element sequence [low, high]; each call blurs with a sigma
    drawn uniformly from that interval.
    """

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        low = self.sigma[0]
        high = self.sigma[1]
        drawn_sigma = random.uniform(low, high)
        # The kernel size is passed as (0, 0); only sigma is specified explicitly.
        blurred = cv2.GaussianBlur(x, (0, 0), drawn_sigma)
        return blurred


## Visualization functions

In [None]:
def _hide_ticks_and_frame(ax):
    """Remove ticks and the surrounding frame from `ax` but keep its axis labels visible."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def plot_img(img, rgb_bands=[3, 2, 1], band_names=None, row_label=None):
    """
    Plot the given img as one row of panels: one panel per band (in channel
    order), followed by an RGB composite as the LAST panel if rgb_bands is provided.

    Args:
        img: Numpy array or Tensor of shape [C, H, W]. Float inputs are expected
            to have pixel values in [0, 1]; uint8 inputs are rescaled from
            [0, 255] to [0, 1].
        rgb_bands: channel indices used (in R, G, B order) for the composite, or
            None to skip the composite. The default list is never mutated.
        band_names: list of length C with the name of each band (used as panel
            titles). If None, no titles are drawn.
        row_label: optional label for the row, drawn as the y-label of the first panel.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    # If image is uint8, rescale it to be in [0, 1]
    if img.dtype == np.uint8:
        img = img.astype(np.float32) / 255.0

    n_bands = img.shape[0]
    if band_names is not None:
        assert len(band_names) == n_bands
    n_cols = n_bands
    if rgb_bands is not None:
        n_cols += 1

    # squeeze=False always returns a 2-D array of Axes, so indexing below also
    # works when the row has a single panel (plt.subplots would otherwise return
    # a bare Axes object).
    fig, axes_grid = plt.subplots(1, n_cols, figsize=(n_cols*3, 3), squeeze=False)
    axeslist = axes_grid[0]

    # Row label. Ticks/frames are hidden via _hide_ticks_and_frame rather than
    # set_axis_off(), because set_axis_off() would also hide this label.
    if row_label is not None:
        axeslist[0].set_ylabel(row_label, size=14)

    # Image of each band
    for i in range(n_bands):
        band_img = img[i, :, :] # [H, W]
        axeslist[i].imshow(band_img, vmin=0, vmax=1)
        _hide_ticks_and_frame(axeslist[i])
        if band_names is not None:
            axeslist[i].set_title(band_names[i])

    # RGB image
    if rgb_bands is not None:
        rgb_img = img[rgb_bands, :, :].transpose((1, 2, 0))  # Transpose to [H, W, C]
        axeslist[-1].imshow(rgb_img)
        _hide_ticks_and_frame(axeslist[-1])
        if band_names is not None:
            axeslist[-1].set_title("RGB")

    fig.tight_layout()
    plt.show()


def plot_histogram(ax, values, title):
    """
    Draw a 30-bin histogram of 'values' on 'ax' and give it a descriptive title.

    The title starts with the 'title' string and reports the mean, std, min, max
    and the 1st / 99th percentiles of 'values', each with two decimals.
    """
    ax.hist(values, bins=30)

    summary = (
        np.mean(values), np.std(values),
        np.min(values), np.max(values),
        np.quantile(values, 0.01), np.quantile(values, 0.99),
    )
    template = "{}:\nmean={:.2f}, std={:.2f}\nmin={:.2f}, max={:.2f}\np1={:.2f}, p99={:.2f}"
    ax.set_title(template.format(title, *summary))


def plot_histogram_all_channels(imgs, band_names=None):
    """
    Draw one histogram per channel, laid out on a grid that is 4 panels wide.

    'imgs' is a Numpy array of images, shape [B, C, H, W]; the values of each
    channel are pooled over the batch and both spatial dimensions.
    'band_names' is a list of channel names, length C. When omitted, channels
    are titled "Channel 0", "Channel 1", ...
    """
    ncols = 4
    if band_names is None:
        n_channels = imgs.shape[1]
        band_names = ["Channel {}".format(i) for i in range(n_channels)]

    nrows = math.ceil(len(band_names) / ncols)
    fig, axeslist = plt.subplots(nrows, ncols, figsize=(3*ncols, 3*nrows))
    flat_axes = axeslist.ravel()

    # One panel per band, filled row by row.
    for band_idx in range(len(band_names)):
        channel_values = imgs[:, band_idx, :, :].flatten()
        plot_histogram(flat_axes[band_idx], channel_values, band_names[band_idx])

    plt.tight_layout()
    plt.show()


def plot_batch(imgs, rgb_bands=[3, 2, 1], band_names=None, n_samples=3):
    """
    Summarize a batch of images: print its shape/dtype/type, plot a histogram of
    pixel values for each channel, and display a few sample images.

    Args:
        imgs: batch of images (Tensor or Numpy), of shape [B, C, H, W].
        rgb_bands: channel indices forwarded to plot_img for the RGB composite
            (None disables the composite). The default list is never mutated.
        band_names: optional list of length C with the channel names.
        n_samples: maximum number of images to display (default 3). Fewer are
            shown when the batch is smaller.
    """
    print("Imgs shape", imgs.shape, imgs.dtype, type(imgs))
    if isinstance(imgs, torch.Tensor):
        imgs = imgs.detach().cpu().numpy()

    plot_histogram_all_channels(imgs, band_names)
    # Clamp to the batch size so a batch with fewer than n_samples images does not
    # raise an IndexError.
    n_shown = min(n_samples, imgs.shape[0])
    for idx in range(n_shown):
        plot_img(imgs[idx, :, :, :], rgb_bands=rgb_bands, band_names=band_names, row_label=f"Image {idx}")

## SSL4EO images (raw uint8, not LMDB)


In [None]:
# Load the raw images, without normalizing.
args = Namespace(
    root="/mnt/beegfs/bulk/mirror/jyf6/datasets/SSL4EO_data/0k_251k_uint8_jpeg_tif",
    normalize=False,
    mode=["s2c"],
    dtype="uint8",
)

train_dataset = SSL4EO(root=args.root, normalize=args.normalize, mode=args.mode, dtype=args.dtype)

# Draw 16 random indices (np.random.choice samples with replacement by default)
# and keep only the s2c entry of each item.
sample_indices = np.random.choice(len(train_dataset), 16)
s2c_stacks = []
for idx in sample_indices:
    # Each item unpacks into (s1, s2a, s2c); s2c is a numpy array of shape [T, C, H, W], uint8
    s1, s2a, s2c = train_dataset[idx]
    s2c_stacks.append(s2c)

# Joining along the first axis merges the T axis of all samples into one batch axis.
imgs = np.concatenate(s2c_stacks, axis=0)

# Plot the images
plot_batch(imgs, band_names=ALL_BANDS_S2_L1C)

## SSL4EO images, LMDB (no transform)

In [None]:
# LMDBDataset WITHOUT any transform (s2c_transform=None): items are plotted as
# the dataset returns them.
train_dataset_raw = LMDBDataset(
    lmdb_file='/mnt/beegfs/bulk/mirror/jyf6/datasets/SSL4EO_data/0k_251k_uint8_jpeg_tif/ssl4eo_251k_s2c_uint8.lmdb',
    s2c_transform=None,
    is_slurm_job=False,
    normalize=False,
    mode=['s2c'],
    dtype='uint8',
)

# Plot images directly from the dataset (__getitem__)
raw_stacks = []
for idx in np.random.choice(len(train_dataset_raw), 16):
    # Expected: NUMPY array of shape [T, C, H, W], denormalized from [0, 255] to the
    # raw S2 values divided by 10000 -- the print below shows the actual shape/dtype.
    s2c = train_dataset_raw[idx]
    print(s2c.shape, s2c.dtype)
    raw_stacks.append(s2c)

imgs = np.concatenate(raw_stacks, axis=0)
plot_batch(imgs, band_names=ALL_BANDS_S2_L1C)

In [None]:
# Same data, but pulled through a DataLoader instead of indexing the dataset.
train_loader_raw = DataLoader(train_dataset_raw, batch_size=32, shuffle=False)

# First batch from the dataloader: [B, T, C, H, W]
imgs = next(iter(train_loader_raw))

# Merge the batch and season axes so plot_batch receives [B*T, C, H, W].
merged_shape = (-1, imgs.shape[2], imgs.shape[3], imgs.shape[4])
imgs = imgs.reshape(merged_shape)
plot_batch(imgs, band_names=ALL_BANDS_S2_L1C)

## Contrastive SSL4EO dataset (LMDB). Augmented views

In [None]:
# Data augmentations
def _build_train_transforms(n_channels, channel_drop=True):
    """Compose the contrastive augmentation pipeline for an input with `n_channels` bands.

    Steps, in order: random resized crop to 112, brightness + contrast jitter
    (applied together with p=0.8), ToGray (p=0.2), Gaussian blur (p=0.5),
    horizontal flip, optional random channel drop (p=0.5), ToTensor.
    """
    steps = [
        cvtransforms.RandomResizedCrop(112, scale=(0.2, 1.)),
        cvtransforms.RandomApply([
            RandomBrightness(0.4),
            RandomContrast(0.4)
        ], p=0.8),
        cvtransforms.RandomApply([ToGray(n_channels)], p=0.2),
        cvtransforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        cvtransforms.RandomHorizontalFlip(),
    ]
    if channel_drop:
        steps.append(cvtransforms.RandomApply([RandomChannelDrop(min_n_drop=1, max_n_drop=6)], p=0.5))
    steps.append(cvtransforms.ToTensor())
    return cvtransforms.Compose(steps)


# The three pipelines differ only in channel count; the S1 pipeline has no channel drop.
train_transforms_s1 = _build_train_transforms(2, channel_drop=False)
train_transforms_s2a = _build_train_transforms(12)
train_transforms_s2c = _build_train_transforms(13)

# LMDB data with TwoCropsTransform (actually used). Only the s2c mode is loaded,
# so only the s2c pipeline is wired in here.
train_dataset = LMDBDataset(
    lmdb_file='/mnt/beegfs/bulk/mirror/jyf6/datasets/SSL4EO_data/0k_251k_uint8_jpeg_tif/ssl4eo_251k_s2c_uint8.lmdb',
    s2c_transform=TwoCropsTransform(train_transforms_s2c, season='augment'),
    is_slurm_job=False,
    normalize=False,
    mode=['s2c'],
    dtype='uint8',
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)

# view0, view1 are float32 Tensor [B, C, H, W]
first_batch = next(iter(train_loader))
view0, view1 = first_batch
plot_batch(view0, band_names=ALL_BANDS_S2_L1C)

## BigEarthNet dataset (RAW)

In [None]:
from transfer_classification.datasets.BigEarthNet.bigearthnet_dataset_seco import Bigearthnet

args = Namespace(
    data_dir="/mnt/beegfs/bulk/mirror/jyf6/datasets/geospatial/datasets/BigEarthNet",
    download=False,
)
all_bands = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
train_dataset = Bigearthnet(
    root=args.data_dir,
    split='train',
    bands=all_bands,
    download=args.download,
    normalize=False,
)

# Get a batch of images (directly from __getitem__)
scaled_imgs = []
for idx in np.random.choice(len(train_dataset), 16):
    # img: Numpy int16, [H, W, C]  Target: binary vector (multi-hot)
    img, target = train_dataset[idx]
    # Channels first, then scale by 1/10000 and clip to [0, 1] -> Numpy float32, [C, H, W]
    img = img.transpose((2, 0, 1)).astype(np.float32)
    img = np.clip(img / 10000.0, 0, 1)
    scaled_imgs.append(img)

imgs = np.stack(scaled_imgs, axis=0)  # [B, C, H, W]
plot_batch(imgs, band_names=ALL_BANDS_S2_L2A)

## BigEarthNet dataset (LMDB)

In [None]:
import os
# Imported under an alias on purpose: this class has the same name as the SSL4EO
# LMDBDataset imported at the top of the notebook. Rebinding `LMDBDataset` here
# would make every later cell that builds an SSL4EO LMDBDataset (with
# `s2c_transform=...`) pick up the BigEarthNet class instead on Restart & Run All.
from transfer_classification.datasets.BigEarthNet.bigearthnet_dataset_seco_lmdb_s2_uint8 import LMDBDataset as BigEarthNetLMDBDataset

train_transforms = cvtransforms.Compose([
    cvtransforms.RandomResizedCrop(224,scale=(0.8,1.0)), # multilabel, avoid cropping out labels
    cvtransforms.RandomHorizontalFlip(),
    cvtransforms.ToTensor()])
train_dataset = BigEarthNetLMDBDataset(
    lmdb_file=os.path.join(args.data_dir, "train_B12.lmdb"),
    transform=train_transforms,
)

# Get images directly from Dataset
imgs = []
for idx in np.random.choice(len(train_dataset), 16):
    img, target = train_dataset[idx]  # Tensor [channel, height, width], unnormalized divided by 10000 (float32)
    imgs.append(img)
imgs = torch.stack(imgs, axis=0)  # [batch, channel, height, width]
plot_batch(imgs, band_names=ALL_BANDS_S2_L2A)

In [None]:
# Same dataset, this time batched (and shuffled) by a DataLoader.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
first_batch = next(iter(train_loader))
imgs, target = first_batch
plot_batch(imgs, band_names=ALL_BANDS_S2_L2A)

## EuroSAT dataset

In [None]:
from transfer_classification.datasets.EuroSat.eurosat_dataset import EurosatDataset, Subset
from sklearn.model_selection import train_test_split
from cvtorchvision import cvtransforms

args = Namespace()
args.data_dir = "/mnt/beegfs/bulk/mirror/jyf6/datasets/geospatial/datasets/EuroSAT_MS"
args.bands = "B13"
args.seed = 42
eurosat_dataset = EurosatDataset(root=args.data_dir,bands=args.bands, normalize=False)

# Training: random crop + flip. Validation: deterministic resize + center crop.
train_transforms = cvtransforms.Compose([
        cvtransforms.RandomResizedCrop(224),
        #cvtransforms.Resize(args.in_size),
        cvtransforms.RandomHorizontalFlip(),
        cvtransforms.ToTensor(),
        ])

val_transforms = cvtransforms.Compose([
        cvtransforms.Resize(256),
        cvtransforms.CenterCrop(224),
        cvtransforms.ToTensor(),
        ])

# Split into train/val (stratified 80/20 split, seeded for reproducibility)
indices = np.arange(len(eurosat_dataset))
train_indices, test_indices = train_test_split(indices, train_size=0.8,stratify=eurosat_dataset.targets,random_state=args.seed)
train_dataset = Subset(eurosat_dataset, train_indices, train_transforms)
val_dataset = Subset(eurosat_dataset, test_indices, val_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
# The validation loader must wrap val_dataset (it previously wrapped train_dataset,
# so the held-out split was never read) and does not need shuffling.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

imgs, target = next(iter(train_loader))
plot_batch(imgs, band_names=ALL_BANDS_S2_L1C)

## SustainBench-BigEarthNet dataset

In [None]:
import sys
sys.path.append("../../geo_benchhmarks")
from geo_benchhmarks.test import get_data_loaders

# Augmentations
# NOTE(review): `transforms` is not imported anywhere in the visible notebook, so this
# cell raises NameError on a fresh kernel. It presumably refers to
# `torchvision.transforms` -- confirm and add the import to the first cell.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(), # Random horizontal flip
    transforms.RandomVerticalFlip(), # Random vertical flip
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to [-1, 1]
])
# Test-time pipeline: tensor conversion only, no random flips.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Normalize to [-1, 1]
])

# Build the three loaders for the "bigearthnet" dataset and plot the first training batch.
train_loader, val_loader, test_loader = get_data_loaders("bigearthnet", train_transform=train_transform, test_transform=test_transform)
# NOTE(review): assumes each batch is a single image tensor (no label alongside it),
# since `.shape` / `.dtype` are read directly below -- confirm against get_data_loaders.
imgs = next(iter(train_loader))
print("Imgs shape", imgs.shape, imgs.dtype)
plot_batch(imgs, band_names=ALL_BANDS_S2_L2A)

In [None]:
# Same loaders as the previous cell, for the "brick-kiln" dataset.
train_loader, val_loader, test_loader = get_data_loaders("brick-kiln", train_transform=train_transform, test_transform=test_transform)
imgs = next(iter(train_loader))
print("Imgs shape", imgs.shape, imgs.dtype)
# `ALL_BANDS_S2_L2C` does not exist in this notebook (NameError); the 13-band list
# defined in the first cell is ALL_BANDS_S2_L1C.
# NOTE(review): assumes the brick-kiln images carry those 13 bands -- plot_img asserts
# len(band_names) == C, so a mismatch surfaces immediately.
plot_batch(imgs, band_names=ALL_BANDS_S2_L1C)

## ShapeContrast dataset

In [None]:
from models.rs_transforms_kornia import RandomChannelDrop,RandomBrightness,RandomContrast,ToGray

class ShapeColorTransform:
    """Take three crops of one image:
    V1 = base_transform(img)
    V2 = intensity_transform(base_transform(img))
    V3 = geometric_transform(base_transform(img))

    V2 is the "shape key" (similar shape but with different color),
    V3 is the "color key" (similar color but geometric transform).
    All get passed through 'base_transform'. Note that 'base_transform' is
    invoked separately for each of the three views.

    `season` selects which of the four season slices (indices 0-3) of the input
    feed the views. V1 and V3 always use the first season, V2 the second:
      - 'augment': the two seasons are drawn independently, so changing seasons
        counts as an intensity transform.
      - 'fixed':   one season shared by all views, identical on every call.
      - 'random':  one randomly drawn season shared by all views.
    """

    def __init__(self, base_transform, intensity_transform, geometric_transform, season='fixed'):
        self.base_transform = base_transform
        self.intensity_transform = intensity_transform
        self.geometric_transform = geometric_transform
        self.season = season

    def _choose_seasons(self):
        """Return the (season1, season2) indices for the configured mode."""
        seasons = [0, 1, 2, 3]
        if self.season == 'augment':
            first = np.random.choice(seasons)
            second = np.random.choice(seasons)
            return first, second
        if self.season == 'fixed':
            # A private RandomState(42) yields the same draw as
            # `np.random.seed(42); np.random.choice(...)`, but without reseeding the
            # global NumPy generator on every call.
            fixed = np.random.RandomState(42).choice(seasons)
            return fixed, fixed
        if self.season == 'random':
            shared = np.random.choice(seasons)
            return shared, shared
        raise ValueError(
            "season must be 'augment', 'fixed' or 'random', got {!r}".format(self.season))

    def __call__(self, x):
        season1, season2 = self._choose_seasons()

        x1 = torch.tensor(x[season1,:,:,:])  # @joshuafan: For kornia, there is no need to transpose. We keep the order of [C, H, W].
        x2 = torch.tensor(x[season2,:,:,:])

        v1 = self.base_transform(x1).squeeze(0)  # Kornia adds a batch dimension (1) at the beginning. We can remove it.
        v2 = self.intensity_transform(self.base_transform(x2)).squeeze(0)
        v3 = self.geometric_transform(self.base_transform(x1)).squeeze(0)
        return [v1, v2, v3]

In [None]:
# Random crop applied to every view.
base_transform = K.RandomResizedCrop(size=(112, 112), scale=(0.2, 1.), p=1)

# Intensity (colour) augmentations, used for the "shape key" view.
intensity_transforms = [
    RandomBrightness(0.4, p=0.8),
    RandomContrast(0.4, p=0.8),
    ToGray(13, p=0.2),
    RandomChannelDrop(min_n_drop=1, max_n_drop=6, p=0.5),
]
intensity_transform = K.AugmentationSequential(*intensity_transforms, data_keys=["input"])

# Geometric augmentations, used for the "color key" view.
geometric_transforms = [
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
    K.RandomGaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0), p=0.5),
    K.RandomJigsaw(grid=(16,16), p=0.8),
]
geometric_transform = K.AugmentationSequential(*geometric_transforms, data_keys=["input"])


def _open_ssl4eo_s2c(s2c_transform):
    """Open the SSL4EO s2c LMDB file with the given transform (all other options fixed)."""
    return LMDBDataset(
        lmdb_file='/mnt/beegfs/bulk/mirror/jyf6/datasets/geospatial/datasets/SSL4EO/0k_251k_uint8_jpeg_tif/ssl4eo_251k_s2c_uint8.lmdb',
        s2c_transform=s2c_transform,
        is_slurm_job=False,
        normalize = False,
        mode = ['s2c'],
        dtype='uint8'
    )


# Augmented dataset (three views per item) and an untransformed twin for reference.
train_dataset = _open_ssl4eo_s2c(
    ShapeColorTransform(base_transform, intensity_transform, geometric_transform, season='augment'))
train_dataset_raw = _open_ssl4eo_s2c(None)

# Get images directly from the dataset __getitem__ method
for idx in [1]:
    # Untransformed item: [season, channel, height, width]; plot its first season.
    s2c_orig = train_dataset_raw[idx]
    print(s2c_orig.shape)
    plot_img(s2c_orig[0, :, :, :], band_names=ALL_BANDS_S2_L1C, rgb_bands=[3,2,1], row_label="Raw image")

    # Augmented item: list of views, each [channel, height, width]
    s2c_views = train_dataset[idx]
    for view_idx in range(3):
        view = s2c_views[view_idx]
        plot_img(view[:, :, :], band_names=ALL_BANDS_S2_L1C, rgb_bands=[3,2,1], row_label=f"View {view_idx} S2C")

# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, num_workers=0, shuffle=False)
# train_loader_raw = torch.utils.data.DataLoader(train_dataset, batch_size=4, num_workers=0, shuffle=False)
# for idx, s2c in enumerate(train_loader):
#     if idx>0:
#         break
#     print(s2c[0].shape)

#     # Print example values
#     for view_idx in range(3):
#         plot_img(s2c[view_idx][0, :, :, :], band_names=ALL_BANDS_S2_L1C, rgb_bands=[3,2,1], row_label=f"View {view_idx} S2C")



## Brickkiln dataset

In [None]:
import sys
sys.path.append("../../../geo_benchhmarks")
from ssl_main import GeoBenchDataSet

data_set_name = "brick-kiln"
args.bands = "B13"
# `batch_size` was used below without ever being defined in this notebook
# (NameError on a fresh kernel). 32 matches the batch size of every other
# DataLoader built in this notebook.
batch_size = 32

# NOTE(review): `args`, `train_transform` and `test_transform` come from earlier cells
# (the EuroSAT and SustainBench cells); those must have been run first.
train_supervised_ds = GeoBenchDataSet(dataset_name=data_set_name,split="train",sl_split="sl", transforms=train_transform, bands=args.bands)
test_ds = GeoBenchDataSet(dataset_name=data_set_name,split="test", transforms=test_transform, bands=args.bands)

# Now put as torch data loaders
train_supervised_loader = torch.utils.data.DataLoader(train_supervised_ds, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)