In [13]:
import sys
sys.path.append('/atlas2/u/jonxuxu/harvest-piles/src')

import numpy as np
import torch

from config import Swin_Pretrain
from torch.utils.data import DataLoader
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    ToPILImage,
    Normalize,
)
from transformers import Swinv2Config
import os

In [14]:
config = Swin_Pretrain()
pretrained_model_path = "microsoft/swinv2-base-patch4-window8-256"
model_config = Swinv2Config.from_pretrained(pretrained_model_path)

In [57]:
import pandas as pd
from torch.utils.data import Dataset
import torch
import os
import cv2

class SkysatUnlabelled(Dataset):
    def __init__(self, filenames, image_dir, transform):
        self.x = filenames
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        path = os.path.join(self.image_dir, self.x[index])
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.transform(image)

def create_SkysatUnlabelled_dataset(csv_file, image_dir, transform, train_split):
    df = pd.read_csv(csv_file, usecols=["filename"])
    filename_list = df["filename"].tolist()
    train_size = int(len(filename_list) * train_split)
    train_examples = filename_list[:train_size]
    test_examples = filename_list[train_size:]
    print(type(test_examples))

    return SkysatUnlabelled(train_examples, image_dir, transform), SkysatUnlabelled(
        test_examples, image_dir, transform
    )


In [58]:
class MaskGenerator:
    """
    A class to generate boolean masks for the pretraining task.

    A mask is a 1D tensor of shape (image_size / model_patch_size)**2 where the value is either 0 or 1,
    where 1 indicates "masked".
    """

    def __init__(
        self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6
    ):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        if self.input_size % self.mask_patch_size != 0:
            raise ValueError("Input size must be divisible by mask patch size")
        if self.mask_patch_size % self.model_patch_size != 0:
            raise ValueError("Mask patch size must be divisible by model patch size")

        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size**2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return torch.tensor(mask.flatten())

mask_generator = MaskGenerator(
    input_size=model_config.image_size,
    mask_patch_size=config.mask_patch_size,
    model_patch_size=config.model_patch_size,
    mask_ratio=config.mask_ratio,
)

transforms = Compose(
    [
        ToPILImage(),
        Resize((model_config.image_size, model_config.image_size)),
        # torchvision.transforms.RandomHorizontalFlip(),
        # torchvision.transforms.RandomVerticalFlip(),
        ToTensor(),
        Normalize(
            mean=[0.412, 0.368, 0.326], std=[0.110, 0.097, 0.098]
        ),  # our dataset vals
    ]
)

def preprocess_images(x):
    """Preprocess a batch of images by applying transforms + creating a corresponding mask, indicating
    which patches to mask."""
    out = {
        "pixel_values": transforms(x),
        "mask": mask_generator(),
    }
    return out

train_set, test_set = create_SkysatUnlabelled_dataset(
    os.path.join(config.dataset_path, "merged.csv"),
    os.path.join(config.dataset_path, "merged"),
    preprocess_images,
    config.train_val_split,
)

train_dl = DataLoader(train_set, batch_size=config.per_device_train_batch_size)
test_dl = DataLoader(test_set, batch_size=config.per_device_eval_batch_size)

<class 'list'>


In [54]:
data_iterator = iter(test_dl)
batch = next(data_iterator)

In [None]:
batch[0]

In [21]:
config.train_val_split

0.8