In [None]:
from datasets import load_dataset
import matplotlib.pyplot as plt
# Modeling
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from torchvision.transforms.functional import to_pil_image

from ecallisto_dataset import (
    EcallistoDatasetBinary,
    CustomSpecAugment,
    custom_resize,
    TimeWarpAugmenter
)
from PIL import Image
import pandas as pd
from io import BytesIO
print(f'PyTorch version {torch.__version__}')

# Report which accelerator the kernel can see and keep the choice as a plain
# device string ('cuda' or 'cpu') for later cells.
has_gpu = torch.cuda.is_available()
if not has_gpu:
    print("GPU is not available.")
    device = 'cpu'
else:
    print(f"GPU is available: {torch.cuda.get_device_name(0)}")
    device = 'cuda'

In [None]:
# Stream the test split so only the rows actually consumed are downloaded.
# (`load_dataset` is imported in the top import cell.)
dataset = load_dataset("i4ds/ecallisto_radio_sunburst", split="test", streaming=True)

# Take the first row; the following cells decode its image.
single_row = next(iter(dataset))

# Printing the whole row would flood the output with the raw encoded image
# bytes, so show the remaining fields and only the size of the image payload.
print({key: value for key, value in single_row.items() if key != "image"})
print(f"image: {len(single_row['image']['bytes'])} bytes")


In [3]:
# Decode the first row's encoded image payload into a PIL image.
raw_image_bytes = single_row["image"]["bytes"]
image = Image.open(BytesIO(raw_image_bytes))

In [None]:
# Show the decoded PIL image via the notebook's rich display (last expression of the cell).
image

In [None]:
from torchvision.transforms import Resize, ToTensor

# Convert the PIL image into a torch tensor and display its shape.
to_tensor_op = ToTensor()
image_t = to_tensor_op(image)
image_t.shape

In [117]:
# Make the shape assumption explicit instead of re-displaying it.
# A later cell calls `image_t.squeeze(0)`, which only removes the leading axis
# when it has size 1, so assert a (1, H, W) layout here.
assert image_t.ndim == 3 and image_t.shape[0] == 1, (
    f"expected a (1, H, W) tensor, got {tuple(image_t.shape)}"
)
image_t.shape

torch.Size([1, 193, 3600])

In [7]:
# Convert the tensor back into a PIL image to eyeball the ToTensor round trip.
image_t_img = to_pil_image(pic=image_t)

In [None]:
# Display the round-tripped image for visual comparison with `image` shown earlier.
image_t_img

In [9]:
from ecallisto_dataset import (
    CustomSpecAugment,
    EcallistoDatasetBinary,
    TimeWarpAugmenter,
    custom_resize,
    filter_antennas,
    randomly_reduce_class_samples,
    remove_background,
)

In [66]:
# Transforms
# Resize each sample to (224, 224) using the project's custom_resize helper.
# A named function (instead of a lambda) gives readable reprs and tracebacks.
# NOTE(review): if DataLoader workers are started with the "spawn" method
# (the default on macOS/Windows), the transform must be unpicklable in the
# worker; functions defined inside a notebook are not importable there, so
# consider moving this helper into ecallisto_dataset if that matters.
def _resize_to_224(sample):
    return custom_resize(sample, (224, 224))


resize_func = Compose([_resize_to_224])

# Augmentations passed to the dataset as its before-resize / after-resize hooks.
# The meaning of the constructor arguments is defined in ecallisto_dataset
# (not visible here).
augm_before_resize = TimeWarpAugmenter(1000)
augm_after_resize = CustomSpecAugment(
    method='random',
    time_masking_para=1,
    frequency_masking_para=1,
)

In [None]:
# Dataset and DataLoader built from the transforms defined in the previous cell.
# NOTE(review): the first positional argument is passed as None; presumably
# it is the underlying data source. Confirm what EcallistoDatasetBinary does with it.
ds_burst = EcallistoDatasetBinary(
    None,
    normalization_transform=remove_background,
    resize_func=resize_func,
    augm_before_resize=augm_before_resize,
    augm_after_resize=augm_after_resize,
)

# Batched, order-preserving loader; workers are not kept alive between epochs.
loader_kwargs = dict(
    batch_size=32,
    num_workers=8,
    shuffle=False,
    persistent_workers=False,
)
test_dataloader = DataLoader(ds_burst, **loader_kwargs)


In [115]:
# Persist the tensor to a file relative to the kernel's working directory.
snapshot_path = 'img.torch'
torch.save(obj=image_t, f=snapshot_path)

In [112]:
# Drop the leading axis (size 1 per the shape displayed earlier) and run the
# dataset's augmentation on the result.
image_squeezed = image_t.squeeze(0)
image_aug = ds_burst.augment_image(image_squeezed)

In [None]:
# Render the augmented tensor as an image (last expression of the cell).
to_pil_image(pic=image_aug)