# SPARK Dataset


In [12]:
from spark_utils import PyTorchSparkDataset
from matplotlib import pyplot as plt
from random import randint

In [13]:
import torch
import torchvision

# This notebook relies on torchvision APIs that are still marked as BETA. Silencing the
# associated warning acknowledges that those APIs may change slightly in future releases.
torchvision.disable_beta_transforms_warning()

In [14]:
# Wrap a PyTorchSparkDataset dataset for usage with torchvision.transforms.v2

class PyTorchSparkDatasetV2(torch.utils.data.Dataset):
    """Adapt ``PyTorchSparkDataset`` samples to the torchvision detection format.

    Every item is an ``(image, target)`` pair: ``image`` is a
    ``tv_tensors.Image`` and ``target`` is a dict with the keys ``'boxes'``
    (a ``tv_tensors.BoundingBoxes`` declared as XYXY) and ``'labels'``
    (a one-element tensor).

    Args:
        class_map: Forwarded unchanged to ``PyTorchSparkDataset``.
        split: Forwarded unchanged to ``PyTorchSparkDataset``.
        root_dir: Forwarded unchanged to ``PyTorchSparkDataset``.
        transform: Forwarded unchanged to ``PyTorchSparkDataset``.
        detection: Accepted for signature compatibility but ignored. The
            wrapped dataset is always built with ``detection=True`` because
            ``__getitem__`` unpacks an image, a label and a bounding box.
    """

    def __init__(self, class_map, split='train', root_dir='', transform=None, detection=True):
        super().__init__()
        # NOTE(review): the transform is handed to the wrapped dataset. If that
        # dataset applies it to the image alone, geometric augmentations (flip,
        # zoom-out) will not be reflected in the boxes -- confirm against
        # spark_utils.
        self.dataset = PyTorchSparkDataset(
            class_map,
            split=split,
            root_dir=root_dir,
            transform=transform,
            detection=True,
        )

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(image, {'boxes': ..., 'labels': ...})``."""
        image, label, bbox = self.dataset[idx]

        image = torchvision.tv_tensors.Image(image)
        label = torch.tensor([label])

        # canvas_size must be the (height, width) pair of the image. On a tensor,
        # ``image.size`` is a bound method rather than a value, so the spatial
        # dimensions are read from the shape instead.
        # NOTE(review): assumes the wrapped dataset yields boxes as
        # (x_min, y_min, x_max, y_max) -- confirm against spark_utils.
        bbox = torchvision.tv_tensors.BoundingBoxes(
            bbox,
            format=torchvision.tv_tensors.BoundingBoxFormat.XYXY,
            canvas_size=tuple(image.shape[-2:]),
        )

        target = {'boxes': bbox, 'labels': label}
        return image, target

In [15]:
from torchvision.transforms import v2 as T
from torchvision import models, datasets, tv_tensors

def get_transform(train):
    """Build the torchvision v2 transform pipeline for one split.

    The pipeline always starts with ``T.ToImage()`` and ends with
    ``T.ToDtype(torch.float, scale=True)`` followed by ``T.ToPureTensor()``.
    When ``train`` is truthy, photometric distortion, zoom-out and horizontal
    flip augmentations are inserted in between.

    Args:
        train: Whether to include the training-time augmentations.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    augmentations = []
    if train:
        augmentations = [
            T.RandomPhotometricDistort(p=0.5),
            # Zoom-out padding: the RGB triple fills images, 0 fills everything else.
            T.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
            # T.RandomIoUCrop() is intentionally left disabled.
            T.RandomHorizontalFlip(p=0.5),
            # T.SanitizeBoundingBoxes() is intentionally left disabled.
        ]

    return T.Compose([
        T.ToImage(),
        *augmentations,
        T.ToDtype(torch.float, scale=True),
        T.ToPureTensor(),
    ])


In [16]:
# Path to a local copy of the SPARK dataset. The label CSV files are expected in this
# directory, and the image sets in ./data/train, ./data/validation and ./data/test.

data_path = './data/'

# Integer label for each SPARK class name. Labels start at 1, which leaves 0 free for
# the background class used by torchvision detection models.
class_map = {
    'proba_2': 1,
    'cheops': 2,
    'debris': 3,
    'double_star': 4,
    'earth_observation_sat_1': 5,
    'lisa_pathfinder': 6,
    'proba_3_csc': 7,
    'proba_3_ocs': 8,
    'smart_1': 9,
    'soho': 10,
    'xmm_newton': 11,
}

# torchvision detection heads count the background as a class, so the classifier needs
# one more output than there are object classes. Without the +1, the highest label (11)
# would fall outside the range of the classification head.
num_classes = len(class_map) + 1

def get_dataset(is_train, class_map, data_path):
    """Create the wrapped SPARK dataset for the training or validation split.

    Args:
        is_train: Selects the ``"train"`` split when truthy and the
            ``"validation"`` split otherwise; it is also passed to
            ``get_transform`` to choose the transform pipeline.
        class_map: Passed through to ``PyTorchSparkDatasetV2``.
        data_path: Passed to ``PyTorchSparkDatasetV2`` as ``root_dir``.

    Returns:
        The constructed ``PyTorchSparkDatasetV2`` instance.
    """
    if is_train:
        split = "train"
    else:
        split = "validation"

    return PyTorchSparkDatasetV2(
        class_map,
        split=split,
        root_dir=data_path,
        transform=get_transform(is_train),
    )

In [17]:
# Training split, built with the training transform pipeline.
dataset = get_dataset(
    is_train=True,
    class_map=class_map,
    data_path=data_path,
)

# Validation split, built with the non-training transform pipeline.
dataset_valid = get_dataset(
    is_train=False,
    class_map=class_map,
    data_path=data_path,
)


In [18]:
batch_size = 10


def _collate_detection_batch(batch):
    """Regroup a list of ``(image, target)`` samples into ``(images, targets)``.

    The samples are transposed with ``zip`` rather than stacked, so the result
    is a tuple of tuples. Defined as a named function (instead of a lambda) so
    that both loaders share it and it can be pickled by worker processes.
    """
    return tuple(zip(*batch))


# Training loader: shuffled, and the final incomplete batch is dropped.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=_collate_detection_batch,
)

# Validation loader: fixed order, and the final incomplete batch is kept so that
# every validation sample takes part in the evaluation.
data_loader_valid = torch.utils.data.DataLoader(
    dataset_valid,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=_collate_detection_batch,
)

In [19]:
# Check dataset format for debugging purposes

# Fetch the first training sample and print the types of its parts.
image, target = sample = dataset[0]
print(type(image))
print(type(target), list(target))
print(*(type(target[key]) for key in ("boxes", "labels")))

<class 'torchvision.tv_tensors._image.Image'>
<class 'dict'> ['boxes', 'labels']
<class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'> <class 'torch.Tensor'>


In [20]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_model_instance_segmentation(num_classes):
    """Create a Faster R-CNN detector with a freshly initialised box predictor.

    Despite the function name, the model built here is
    ``fasterrcnn_resnet50_fpn_v2`` loaded with its ``"DEFAULT"`` weights; only
    the box predictor of the ROI heads is replaced.

    Args:
        num_classes: Number of outputs of the new ``FastRCNNPredictor``.

    Returns:
        The model with its ``roi_heads.box_predictor`` swapped out.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")

    # The replacement head must accept the same feature width as the old classifier.
    roi_heads = detector.roi_heads
    feature_width = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_width, num_classes)

    return detector

In [21]:
from engine import train_one_epoch, evaluate

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the detector with the helper and place it on the selected device.
model = get_model_instance_segmentation(num_classes)
model.to(device)

# Optimise only the parameters that require gradients.
params = [weight for weight in model.parameters() if weight.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Step scheduler configured with step_size=3 and gamma=0.1; it is stepped once per epoch.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# Train for 5 epochs.
num_epochs = 5

for epoch in range(num_epochs):
    # One pass over the training loader, logging every 10 iterations.
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # Advance the learning-rate schedule.
    lr_scheduler.step()
    # Evaluate on the validation loader.
    evaluate(model, data_loader_valid, device=device)

print("That's it!")

In [9]:
# model = models.get_model("fasterrcnn_resnet50_fpn_v2", weights='COCO_V1', weights_backbone=None).train()

# for imgs, targets in data_loader:
#     loss_dict = model(imgs, targets)
#     # Put your training logic here

#     print(f"{[img.shape for img in imgs] = }")
#     print(f"{[type(target) for target in targets] = }")
#     for name, loss_val in loss_dict.items():
#         print(f"{name:<20}{loss_val:.3f}")

In [10]:
# !train.py\
#     --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
#     --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1

In [11]:
# rows = 3
# cols = 4

# fig, axes = plt.subplots(rows, cols, figsize=(15, 15))
# total_images=9 # total number of images in the split

# for i in range(rows):
#     for j in range(cols):
#         dataset.visualize(randint(0, total_images),size = (10,10),ax=axes[i][j])
#         axes[i][j].axis('off')
# fig.tight_layout()