In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
# Walk up from the current working directory until it ends with the project
# folder name, printing every directory visited on the way.
# At the filesystem root `os.chdir("..")` is a no-op, so without the guard
# below the loop would spin forever when the notebook is started outside the
# project tree. On failure the starting directory is restored before raising.
_start_dir = os.getcwd()
while not os.getcwd().endswith("image-drift-generator"):
    if os.path.dirname(os.getcwd()) == os.getcwd():
        os.chdir(_start_dir)
        raise FileNotFoundError(
            "Could not find an 'image-drift-generator' directory above "
            f"{_start_dir!r}; start the notebook from inside the project."
        )
    os.chdir("..")
    print(os.getcwd())
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
import torch
import torchvision.transforms.v2 as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import polars as pl

from scripts.image_generator import ImageDatasetGenerator
from scripts.image_factory import *
import warnings
from torchvision import models
import torch.nn as nn
%matplotlib inline
# Silence every Python warning for the rest of the session.
# NOTE(review): with no category/module filter this also hides deprecation
# warnings raised by the libraries imported above - consider narrowing it.
warnings.filterwarnings('ignore')
# from alibi_detect.cd import MMDDrift
# Make matplotlib apply its automatic tight layout to every figure.
plt.rcParams.update({'figure.autolayout': True})

/Users/cristian/Personal/Projects/image-drift-generator


In [2]:
# Notebook-wide random seed.
seed = 42

# Choose the compute backend by priority: CUDA, then Apple MPS, then CPU.
# The conditional expression evaluates the checks lazily, in this order.
_backend_name = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
device = torch.device(_backend_name)

# Tensors created without an explicit device are allocated on `device`.
torch.set_default_device(device)
print(f"Supported device: {device}.")


Supported device: mps.


In [3]:
# Per-channel normalization statistics; the values match the commonly used
# ImageNet channel mean / standard deviation.
mean = [
    0.485,
    0.456,
    0.406,
]
std = [
    0.229,
    0.224,
    0.225,
]

# Number of images per batch.
batch_size = 32

In [4]:
# Preprocessing pipeline: resize to 224x224, then convert the image to a
# float32 tensor with values scaled to [0, 1].
# `ToTensor()` is deprecated in torchvision.transforms.v2 (the warning is
# hidden by the global warnings filter); the documented replacement is
# `ToImage()` followed by `ToDtype(torch.float32, scale=True)`.
tensor_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    # transforms.Normalize(mean=mean, std=std)
])

In [9]:
# Folder with the source images, passed to ImageDatasetGenerator as
# `input_path`. The path is relative, so it is resolved against the current
# working directory, which the first cell moves up to the
# "image-drift-generator" folder.
img_path = "../../data/places/"

In [11]:
# Create the dataset generator over the images found under `img_path`.
# The notebook-wide `seed` is reused instead of repeating the literal 42, so
# the seed is configured in exactly one place.
image_generator = ImageDatasetGenerator(
    seed=seed,
    input_path=img_path,
)

In [12]:
# (transformation kind, drift level) pairs, listed in the order in which the
# TransformInfo objects are placed in `all_transform`.
_transform_levels = [
    (ImageTransform.ROTATE, 0.9),
    (ImageTransform.BRIGHTNESS, 0.3),
    (ImageTransform.CONTRAST, 0.1),
    (ImageTransform.HUE, 0.1),
    (ImageTransform.SATURATION, 0.5),
    (ImageTransform.GAUSSIAN_BLUR, 0.1),
    (ImageTransform.GAUSSIAN_NOISE, 1.),
]

# One TransformInfo per pair, built with the same keyword arguments.
all_transform = [
    TransformInfo(transf_type=kind, drift_level=level)
    for kind, level in _transform_levels
]

In [13]:
# Register an abrupt drift on the inputs, of type image augmentation, using
# the transformations collected in `all_transform`.
# NOTE(review): this changes the state of `image_generator`; presumably
# re-running the cell registers the drift a second time - confirm before
# re-executing it on a live kernel.
image_generator.add_abrupt_drift(
    drift_target=DriftTarget.INPUT,
    input_drift_type=InputDriftType.IMAGE_AUGMENTATION,
    transform_list=all_transform,
)

In [14]:
# Show the transform pipeline currently held by the generator (the cell
# output is its repr).
image_generator.transform_pipeline

Compose(
      RandomRotation(degrees=[-81.0, 81.0], interpolation=InterpolationMode.NEAREST, expand=False, fill=0)
      ColorJitter(brightness=(0.0, 2.5))
      ColorJitter(contrast=(0.5, 1.5))
      ColorJitter(hue=(-0.05, 0.05))
      ColorJitter(saturation=(0.0, 8.5))
      GaussianBlur(kernel_size=(5, 5), sigma=[0.30000000000000004, 0.30000000000000004])
      GaussianNoise(mean=0.1, sigma=0.5, clip=True)
)

In [20]:
# Number of images to draw and the folder passed as `output_path`.
n_samples = 32
generated_dir = "../../data/places/places_generated/"

# Draw the samples; the returned object is inspected in the following cells.
data_category = image_generator.sample(n_samples, output_path=generated_dir)

Archive created successfully: ../../data/places/places_generated/sampled_images.zip
Original folder '../../data/places/places_generated/sampled_images' deleted after zipping.


In [21]:
# Targets (sample id, timestamp, label) of the sampled batch.
# Left as the last expression of the cell so Jupyter renders the frame with
# its rich table repr instead of the plain-text `print()` dump.
data_category.target

shape: (32, 3)
┌───────────┬───────────┬───────┐
│ sample-id ┆ timestamp ┆ label │
│ ---       ┆ ---       ┆ ---   │
│ i64       ┆ f64       ┆ i64   │
╞═══════════╪═══════════╪═══════╡
│ 32        ┆ 1.7172e9  ┆ 0     │
│ 33        ┆ 1.7172e9  ┆ 0     │
│ 34        ┆ 1.7172e9  ┆ 1     │
│ 35        ┆ 1.7172e9  ┆ 0     │
│ 36        ┆ 1.7172e9  ┆ 1     │
│ …         ┆ …         ┆ …     │
│ 59        ┆ 1.7172e9  ┆ 0     │
│ 60        ┆ 1.7172e9  ┆ 0     │
│ 61        ┆ 1.7172e9  ┆ 0     │
│ 62        ┆ 1.7172e9  ┆ 0     │
│ 63        ┆ 1.7172e9  ┆ 1     │
└───────────┴───────────┴───────┘


In [22]:
# Mapping from sample id / timestamp to the generated image file name.
# Left as the last expression of the cell so Jupyter renders the frame with
# its rich table repr instead of the plain-text `print()` dump.
data_category.input_mapping

shape: (32, 3)
┌───────────┬───────────┬─────────────────────────────────┐
│ sample-id ┆ timestamp ┆ file_name                       │
│ ---       ┆ ---       ┆ ---                             │
│ i64       ┆ f64       ┆ str                             │
╞═══════════╪═══════════╪═════════════════════════════════╡
│ 32        ┆ 1.7172e9  ┆ 707ae2b2-50ab-4a27-88e6-5f9f0a… │
│ 33        ┆ 1.7172e9  ┆ c4cff463-2b48-4d0d-864a-81a809… │
│ 34        ┆ 1.7172e9  ┆ 6b933fe9-c8c6-4c31-b082-8883ad… │
│ 35        ┆ 1.7172e9  ┆ a18a13ad-61b1-4afc-bd65-591389… │
│ 36        ┆ 1.7172e9  ┆ edb0b297-537a-4ce1-a163-0a1e8a… │
│ …         ┆ …         ┆ …                               │
│ 59        ┆ 1.7172e9  ┆ 6961a9bd-bcc6-460b-8ea4-2d17ca… │
│ 60        ┆ 1.7172e9  ┆ ee19bf8d-675f-49d5-a2a9-7c5c33… │
│ 61        ┆ 1.7172e9  ┆ ba635c98-ed00-4b2e-8fcd-2d1013… │
│ 62        ┆ 1.7172e9  ┆ db3c59ed-b045-450c-8794-a16023… │
│ 63        ┆ 1.7172e9  ┆ 66eb567d-e77d-4235-a179-8961b7… │
└───────────┴───────────┴

In [23]:

# Print the storage metadata of the sampled data, one attribute per line,
# in a fixed order.
for _attr_name in (
    "input_folder_type",
    "input_folder",
    "input_file_type",
    "is_input_folder",
    "is_target_folder",
):
    print(getattr(data_category, _attr_name))


FolderType.ZIP
/Users/cristian/Personal/data/places/places_generated/sampled_images.zip
png
True
False


In [24]:
# Fetch the data schema describing the generated dataset.
columns = image_generator.get_dataschema()

# Iterating the schema yields one entry at a time; print each on its own line.
for schema_entry in columns:
    print(schema_entry)

('columns', [ColumnInfo(name='timestamp', role=<ColumnRole.TIME_ID: 'time_id'>, is_nullable=False, data_type=<DataType.FLOAT: 'float'>, predicted_target=None, possible_values=None, model_id=None, dims=None, classes_names=None, subrole=None, image_mode=None), ColumnInfo(name='sample-id', role=<ColumnRole.ID: 'id'>, is_nullable=False, data_type=<DataType.STRING: 'string'>, predicted_target=None, possible_values=None, model_id=None, dims=None, classes_names=None, subrole=None, image_mode=None), ColumnInfo(name='label', role=<ColumnRole.TARGET: 'target'>, is_nullable=False, data_type=<DataType.CATEGORICAL: 'categorical'>, predicted_target=None, possible_values=[0, 1], model_id=None, dims=None, classes_names=None, subrole=None, image_mode=None), ColumnInfo(name='image', role=<ColumnRole.INPUT: 'input'>, is_nullable=False, data_type=<DataType.ARRAY_3: 'array_3'>, predicted_target=None, possible_values=None, model_id=None, dims=(224, 224, 3), classes_names=None, subrole=None, image_mode=<Imag