In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os

# Walk up the directory tree until the working directory ends with
# 'image_drift_generator', printing every directory visited on the way.
while not os.getcwd().endswith('image_drift_generator'):
    previous_dir = os.getcwd()
    os.chdir('..')
    # At the filesystem root, '..' resolves to the same directory; without
    # this guard the loop would never terminate when the notebook is started
    # outside the project tree.
    if os.getcwd() == previous_dir:
        raise RuntimeError(
            "Could not find a parent directory ending with "
            "'image_drift_generator'; start the notebook from inside the project."
        )
    print(os.getcwd())
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
import torch
import torchvision.transforms.v2 as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import polars as pl

from image_drift_generator.image_generator import ImageDatasetGenerator
from image_drift_generator.image_factory import *
import warnings
from torchvision import models
import torch.nn as nn

%matplotlib inline
warnings.filterwarnings('ignore')
# from alibi_detect.cd import MMDDrift
# set tight layout
plt.rcParams.update({'figure.autolayout': True})

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Notebook-wide random seed.
seed = 42

# Pick the compute device in order of preference: CUDA, then MPS, then CPU.
# The conditional chain short-circuits, so MPS is only probed without CUDA.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# Make tensors created from now on live on the selected device by default.
torch.set_default_device(device)
print(f'Supported device: {device}.')


Supported device: mps.


In [None]:
# Number of images per batch.
batch_size = 32

# Per-channel normalization statistics (three values, one per image channel);
# these match the commonly used ImageNet mean / standard deviation.
mean = [
    0.485,
    0.456,
    0.406,
]
std = [
    0.229,
    0.224,
    0.225,
]

In [None]:
# Preprocessing pipeline: resize every image to 224x224, then convert it to a
# float32 image tensor with values scaled to [0, 1].
# v2.ToTensor() is deprecated in torchvision.transforms.v2; the documented
# replacement is ToImage() followed by ToDtype(torch.float32, scale=True).
tensor_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        # Normalization is intentionally left disabled:
        # transforms.Normalize(mean=mean, std=std)
    ]
)

In [None]:
# Folder with the source images, given relative to the current working
# directory (the setup cell changes it with os.chdir); it is later passed as
# `input_path` to ImageDatasetGenerator.
# NOTE(review): presumably the Places dataset — confirm.
img_path = '../../data/places/'

In [12]:
# Build the dataset generator over the images found at `img_path`.
# The notebook-wide `seed` is reused instead of repeating the literal 42, so
# changing the seed in one place affects the whole run.
image_generator = ImageDatasetGenerator(
    seed=seed,
    input_path=img_path,
)

In [None]:
# One TransformInfo per image augmentation, in application order; each pair
# below is (augmentation type, drift level).
all_transform = [
    TransformInfo(transf_type=kind, drift_level=level)
    for kind, level in (
        (ImageTransform.ROTATE, 0.9),
        (ImageTransform.BRIGHTNESS, 0.3),
        (ImageTransform.CONTRAST, 0.1),
        (ImageTransform.HUE, 0.1),
        (ImageTransform.SATURATION, 0.5),
        (ImageTransform.GAUSSIAN_BLUR, 0.1),
        (ImageTransform.GAUSSIAN_NOISE, 1.0),
    )
]

In [15]:
# Register an abrupt drift on the generator: the drift targets the input and is
# of the image-augmentation type, driven by the transforms in `all_transform`.
image_generator.add_abrupt_drift(
    drift_target=DriftTarget.INPUT,
    input_drift_type=InputDriftType.IMAGE_AUGMENTATION,
    transform_list=all_transform,
)

In [16]:
# Display the generator's transform pipeline; in the recorded output it is a
# Compose holding one entry per TransformInfo passed to add_abrupt_drift.
image_generator.transform_pipeline

Compose(
      RandomRotation(degrees=[-81.0, 81.0], interpolation=InterpolationMode.NEAREST, expand=False, fill=0)
      ColorJitter(brightness=(1.5, 1.5))
      ColorJitter(contrast=(0.5, 0.5))
      ColorJitter(hue=(0.05, 0.05))
      ColorJitter(saturation=(7.5, 7.5))
      GaussianBlur(kernel_size=(5, 5), sigma=[0.30000000000000004, 0.30000000000000004])
      GaussianNoise(mean=0.1, sigma=0.5, clip=True)
)

In [None]:
# Draw 32 samples from the generator and write them under `output_path`.
# The recorded output reports that the sampled images are archived into
# sampled_images.zip and that the unzipped folder is deleted afterwards.
data_category = image_generator.sample(
    32, output_path='../../data/places/places_generated/'
)

Archive created successfully: ../../data/places/places_generated/sampled_images.zip
Original folder '../../data/places/places_generated/sampled_images' deleted after zipping.


In [18]:
# Show the target table of the sampled data; in the recorded output it has one
# row per sample with `sample-id`, `timestamp` and `label` columns.
print(data_category.target)

shape: (32, 3)
┌───────────┬───────────┬───────┐
│ sample-id ┆ timestamp ┆ label │
│ ---       ┆ ---       ┆ ---   │
│ i64       ┆ f64       ┆ i64   │
╞═══════════╪═══════════╪═══════╡
│ 0         ┆ 1.7172e9  ┆ 1     │
│ 1         ┆ 1.7172e9  ┆ 1     │
│ 2         ┆ 1.7172e9  ┆ 0     │
│ 3         ┆ 1.7172e9  ┆ 1     │
│ 4         ┆ 1.7172e9  ┆ 0     │
│ …         ┆ …         ┆ …     │
│ 27        ┆ 1.7172e9  ┆ 0     │
│ 28        ┆ 1.7172e9  ┆ 0     │
│ 29        ┆ 1.7172e9  ┆ 0     │
│ 30        ┆ 1.7172e9  ┆ 1     │
│ 31        ┆ 1.7172e9  ┆ 1     │
└───────────┴───────────┴───────┘


In [19]:
# Show the input mapping of the sampled data; in the recorded output it links
# each `sample-id` / `timestamp` to the `file_name` of the generated image.
print(data_category.input_mapping)

shape: (32, 3)
┌───────────┬───────────┬─────────────────────────────────┐
│ sample-id ┆ timestamp ┆ file_name                       │
│ ---       ┆ ---       ┆ ---                             │
│ i64       ┆ f64       ┆ str                             │
╞═══════════╪═══════════╪═════════════════════════════════╡
│ 0         ┆ 1.7172e9  ┆ e6c1313a-589c-4afc-9cb1-d36a95… │
│ 1         ┆ 1.7172e9  ┆ c8bc7a53-5aa1-4706-82c4-944d11… │
│ 2         ┆ 1.7172e9  ┆ a3bbe67a-2d11-4e83-8206-df016d… │
│ 3         ┆ 1.7172e9  ┆ ae1cc5c1-7239-4fb9-90d2-25b74f… │
│ 4         ┆ 1.7172e9  ┆ 334df5a3-e337-40c8-bffa-475771… │
│ …         ┆ …         ┆ …                               │
│ 27        ┆ 1.7172e9  ┆ 5760139c-29da-4a81-b43c-4921df… │
│ 28        ┆ 1.7172e9  ┆ 9e4bf50b-81b8-4b34-91ad-5b353a… │
│ 29        ┆ 1.7172e9  ┆ fa94bc39-5570-4a8c-a30b-1cae85… │
│ 30        ┆ 1.7172e9  ┆ f039a178-6c92-48eb-ad97-ec7673… │
│ 31        ┆ 1.7172e9  ┆ 9f50d228-f6bf-4b9a-a445-55e97c… │
└───────────┴───────────┴

In [None]:
# Print the storage metadata of the sampled data, one value per line:
# folder type, folder location, file type, and the two folder flags.
print(
    data_category.input_folder_type,
    data_category.input_folder,
    data_category.input_file_type,
    data_category.is_input_folder,
    data_category.is_target_folder,
    sep='\n',
)


FolderType.ZIP
/Users/cristian/Personal/data/places/places_generated/sampled_images.zip
png
True
False


In [21]:
# Fetch the generator's data schema and print every item produced by iterating
# over it. In the recorded output the first item is a ('columns', [ColumnInfo, ...])
# pair, so each `col` is a (name, value) tuple rather than a single column.
columns = image_generator.get_dataschema()
for col in columns:
    print(col)

('columns', [ColumnInfo(name='timestamp', role=<ColumnRole.TIME_ID: 'time_id'>, is_nullable=False, data_type=<DataType.FLOAT: 'float'>, predicted_target=None, possible_values=None, model_id=None, dims=None, classes_names=None, subrole=None, image_mode=None), ColumnInfo(name='sample-id', role=<ColumnRole.ID: 'id'>, is_nullable=False, data_type=<DataType.STRING: 'string'>, predicted_target=None, possible_values=None, model_id=None, dims=None, classes_names=None, subrole=None, image_mode=None), ColumnInfo(name='label', role=<ColumnRole.TARGET: 'target'>, is_nullable=False, data_type=<DataType.CATEGORICAL: 'categorical'>, predicted_target=None, possible_values=[0, 1], model_id=None, dims=None, classes_names=None, subrole=None, image_mode=None), ColumnInfo(name='image', role=<ColumnRole.INPUT: 'input'>, is_nullable=False, data_type=<DataType.ARRAY_3: 'array_3'>, predicted_target=None, possible_values=None, model_id=None, dims=(224, 224, 3), classes_names=None, subrole=None, image_mode=<Imag