In [12]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
import torch
import torchvision.transforms.v2 as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import polars as pl

from scripts.image_generator import ImageDatasetGenerator
from scripts.image_factory import *
import warnings
from torchvision import models
import torch.nn as nn
%matplotlib inline
# Silence every warning category for the rest of the kernel session.
# NOTE(review): a blanket ignore also hides library deprecation notices
# (e.g. from torchvision transforms); consider narrowing it to specific categories.
warnings.simplefilter('ignore')
# Apply tight_layout automatically to every figure.
plt.rcParams['figure.autolayout'] = True




In [13]:
seed = 42
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Tensors created without an explicit device are placed on `device` from here on.
torch.set_default_device(device)
print(f"Supported device: {device}.")

# Seed the global RNGs too: any consumer that is not handed `g_seed` explicitly
# (e.g. a shuffling DataLoader given no generator) draws from the global
# torch / numpy state, which would otherwise stay unseeded.
torch.manual_seed(seed)
np.random.seed(seed)

# Dedicated generator living on the same device as the default tensors.
g_seed = torch.Generator(device=device)
g_seed.manual_seed(seed);

Supported device: cpu.


In [14]:
# Per-channel statistics for an optional Normalize step (these appear to be the
# usual ImageNet RGB values).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# Number of images per DataLoader batch.
batch_size = 32

In [15]:
# Resize to 224x224, then convert to a float32 tensor. ToImage + ToDtype(scale=True)
# is torchvision's recommended replacement for the deprecated v2.ToTensor: uint8
# pixels (presumably RGB PIL images from ImageFolder's default loader) are
# scaled by 1/255 into [0, 1], exactly as ToTensor did.
tensor_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    # Normalization with `mean`/`std` is disabled, so tensors stay in [0, 1].
    # NOTE(review): confirm this is intentional.
    # transforms.Normalize(mean=mean, std=std)
])

In [16]:
# Source images, loaded with the ImageFolder directory-per-class convention.
img_path = "../../data/places/"
dataset = ImageFolder(
    root=img_path,
    transform=tensor_transform,
)
# Pass the seeded generator on every device. Without it, the CPU path shuffles
# from the global torch RNG and the batch order is not tied to `seed`.
# `g_seed` lives on `device`, which is what the sampler needs when the default
# device is CUDA (presumably why only the CUDA case used to pass it).
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=g_seed,
)

In [17]:
# Drift generator reading the same image folder as `dataset`. Per the sample()
# output further down, generated archives are written under `output_path`.
# Reuse the notebook-wide `seed` rather than a duplicated literal, so that one
# edit reseeds everything.
image_generator = ImageDatasetGenerator(
    seed=seed,
    input_path=img_path,
    output_path="../../data/places_generated/",
)

In [31]:
# Augmentations that make up the abrupt input drift, given as
# (transform, drift_level) pairs. drift_level appears to scale each transform's
# strength: e.g. ROTATE at 0.9 shows up as ±81° in the resulting pipeline.
all_transform = [
    TransformInfo(transf_type=transform_kind, drift_level=level)
    for transform_kind, level in (
        (ImageTransform.ROTATE, 0.9),
        (ImageTransform.BRIGHTNESS, 0.3),
        (ImageTransform.CONTRAST, 0.1),
        (ImageTransform.HUE, 0.1),
        (ImageTransform.SATURATION, 0.5),
        (ImageTransform.GAUSSIAN_BLUR, 0.1),
        (ImageTransform.GAUSSIAN_NOISE, 1.0),
    )
]

In [32]:
# Register an abrupt drift on the model inputs, realised through the image
# augmentations listed in `all_transform`.
# NOTE(review): unclear whether calling this again appends a second drift —
# check before re-running this cell on its own.
image_generator.add_abrupt_drift(
    transform_list=all_transform,
    input_drift_type=InputDriftType.IMAGE_AUGMENTATION,
    drift_target=DriftTarget.INPUT,
)

In [33]:
# Inspect the augmentation pipeline (presumably populated by add_abrupt_drift):
# the displayed Compose lists one transform per `all_transform` entry, in order.
image_generator.transform_pipeline

Compose(
      RandomRotation(degrees=[-81.0, 81.0], interpolation=InterpolationMode.NEAREST, expand=False, fill=0)
      ColorJitter(brightness=(0.0, 2.5))
      ColorJitter(contrast=(0.5, 1.5))
      ColorJitter(hue=(-0.05, 0.05))
      ColorJitter(saturation=(0.0, 8.5))
      GaussianBlur(kernel_size=(5, 5), sigma=[0.30000000000000004, 0.30000000000000004])
      GaussianNoise(mean=0.1, sigma=0.5, clip=True)
)

In [34]:
# Draw a sample of images from the generator. Per the cell output, this writes
# `sampled_images.zip` under the generator's output path and deletes the
# intermediate `sampled_images` folder afterwards.
# NOTE(review): the saved output's sample-ids start at 64, which suggests earlier
# sample() calls in the same kernel; a fresh Restart & Run All may number differently.
N_SAMPLES = 32
data_category = image_generator.sample(N_SAMPLES)

Archive created successfully: ../../data/places_generated/sampled_images.zip
Original folder '../../data/places_generated/sampled_images' deleted after zipping.


In [35]:
# Labels of the sampled images (columns: sample-id, timestamp, label).
# A bare last expression renders the polars frame as a rich table instead of plain text.
data_category.target

shape: (32, 3)
┌───────────┬───────────┬───────┐
│ sample-id ┆ timestamp ┆ label │
│ ---       ┆ ---       ┆ ---   │
│ i64       ┆ f64       ┆ i64   │
╞═══════════╪═══════════╪═══════╡
│ 64        ┆ 1.7285e9  ┆ 1     │
│ 65        ┆ 1.7285e9  ┆ 1     │
│ 66        ┆ 1.7285e9  ┆ 1     │
│ 67        ┆ 1.7285e9  ┆ 0     │
│ 68        ┆ 1.7285e9  ┆ 1     │
│ …         ┆ …         ┆ …     │
│ 91        ┆ 1.7285e9  ┆ 1     │
│ 92        ┆ 1.7285e9  ┆ 1     │
│ 93        ┆ 1.7285e9  ┆ 0     │
│ 94        ┆ 1.7285e9  ┆ 1     │
│ 95        ┆ 1.7285e9  ┆ 0     │
└───────────┴───────────┴───────┘


In [36]:
# File name of each sampled image, keyed by sample-id and timestamp.
# A bare last expression renders the polars frame as a rich table instead of plain text.
data_category.input_mapping

shape: (32, 3)
┌───────────┬───────────┬─────────────────────────────────┐
│ sample-id ┆ timestamp ┆ file_name                       │
│ ---       ┆ ---       ┆ ---                             │
│ i64       ┆ f64       ┆ str                             │
╞═══════════╪═══════════╪═════════════════════════════════╡
│ 64        ┆ 1.7285e9  ┆ f7d4225e-6f85-4809-a03a-835852… │
│ 65        ┆ 1.7285e9  ┆ 4f59436e-aa0a-4eb9-b705-e1da28… │
│ 66        ┆ 1.7285e9  ┆ 67ce62c4-d3ff-4fb9-85d2-3bf5b3… │
│ 67        ┆ 1.7285e9  ┆ 7f06874e-6399-4f13-b27f-d2c7f6… │
│ 68        ┆ 1.7285e9  ┆ 766f12f9-dceb-4f63-85f3-49ad78… │
│ …         ┆ …         ┆ …                               │
│ 91        ┆ 1.7285e9  ┆ 18270172-2a62-4914-b249-40ab16… │
│ 92        ┆ 1.7285e9  ┆ 151c12a2-1902-44aa-962d-e3f350… │
│ 93        ┆ 1.7285e9  ┆ 6cfacf0d-c400-44fa-8b01-ca2389… │
│ 94        ┆ 1.7285e9  ┆ 9c5ef26d-a4f8-4471-bd67-f3c40d… │
│ 95        ┆ 1.7285e9  ┆ c9285567-7113-4357-935f-41c85e… │
└───────────┴───────────┴─────────────────────────────────┘

In [37]:

# Storage metadata for the sampled data. Each value is printed with its attribute
# name so the output can be read without cross-referencing the code.
# NOTE(review): input_folder names `sampled_images`, which the sample() output
# reports as deleted after zipping — confirm consumers resolve the ZIP archive.
for attr_name in (
    "input_folder_type",
    "input_folder",
    "input_file_type",
    "is_input_folder",
    "is_target_folder",
):
    # print() applies str() to the value, matching the original unlabeled output.
    print(f"{attr_name}:", getattr(data_category, attr_name))


FolderType.ZIP
../../data/places_generated/sampled_images
png
True
False


In [23]:
# List the generator's data schema, one column definition per line.
# NOTE(review): in the saved run this cell executed before the drift/sample cells
# above it; re-run top-to-bottom to confirm the schema after add_abrupt_drift.
# NOTE(review): the schema declares `sample-id` as STRING, while the sampled
# target / input_mapping frames hold it as i64 — confirm which is intended.
columns = image_generator.get_dataschema()
for column_spec in columns:
    print(column_spec)

name='timestamp' role=<ColumnRole.TIME_ID: 'time_id'> is_nullable=False data_type=<DataType.FLOAT: 'float'> predicted_target=None possible_values=None model_id=None dims=None classes_names=None subrole=None image_mode=None
name='sample-id' role=<ColumnRole.ID: 'id'> is_nullable=False data_type=<DataType.STRING: 'string'> predicted_target=None possible_values=None model_id=None dims=None classes_names=None subrole=None image_mode=None
name='label' role=<ColumnRole.TARGET: 'target'> is_nullable=False data_type=<DataType.CATEGORICAL: 'categorical'> predicted_target=None possible_values=[0, 1] model_id=None dims=None classes_names=None subrole=None image_mode=None
name='image' role=<ColumnRole.INPUT: 'input'> is_nullable=False data_type=<DataType.ARRAY_3: 'array_3'> predicted_target=None possible_values=None model_id=None dims=(224, 224, 3) classes_names=None subrole=None image_mode=<ImageMode.RGB: 'rgb'>
