In [6]:
import json

import PIL.Image
from matplotlib import pyplot as plt
import os
import torch
import torchvision
from prototypes.classical.descriptors.vetorizer import GaborAttentionLBPVectors
from prototypes.deeplearning.dataloader.IsicDataLoader import LoadDataVectors, LoadPreProcessVectors
import albumentations as A
%load_ext autoreload
%autoreload 2

# Load project configuration (dataset paths, split ratio, metadata columns, ...).
# JSON is UTF-8 by specification; pass the encoding explicitly so loading does not
# depend on the platform's default locale encoding (e.g. cp1252 on Windows).
with open("../config.json", "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

# Augmentation

In [7]:
#Augmentation per sample
from PIL import Image
import numpy as np

class Augmentation:
    """Adapter that lets an albumentations-style pipeline sit inside a torchvision ``Compose``.

    The wrapped pipeline is called with a NumPy array via the ``image=`` keyword and
    must return a mapping whose ``"image"`` entry holds the augmented array. That
    array is turned back into a PIL image before being handed to the next transform.
    """

    def __init__(self, augmentation_transform):
        # Callable taking ``image=<ndarray>`` and returning a mapping with an "image" key.
        self.augmentation_transform = augmentation_transform

    def __call__(self, sample):
        # `sample` must be convertible with np.array (presumably a PIL image — confirm
        # against what LoadDataVectors yields).
        pixels = np.array(sample)
        result = self.augmentation_transform(image=pixels)
        augmented_pixels = result["image"]
        return Image.fromarray(augmented_pixels)

# Per-sample albumentations pipeline. Resizing and normalization for the ViT are not
# done here; they come from the torchvision weights preset appended further down.
augmentation_transform = A.Compose(
    [
        A.CLAHE(p=0.4),
        A.RandomRotate90(p=0.7),
        A.Transpose(p=0.6),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
        A.Blur(blur_limit=3),
        A.OpticalDistortion(p=0.5),
        A.GridDistortion(p=0.5),
        A.HueSaturationValue(p=0.5),
    ]
)

# Full per-sample transform: albumentations (via the PIL adapter), then the
# preprocessing preset bundled with the ViT-B/16 SWAG linear-probe weights.
albumentations_step = Augmentation(augmentation_transform=augmentation_transform)
vit_preset = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1.transforms()
transform_augmentation = torchvision.transforms.Compose([albumentations_step, vit_preset])

In [8]:
# NOTE: despite its name, `dataloader` is the underlying dataset object (it is later
# passed to random_split and wrapped in a torch DataLoader).
dataset_dir = config["DATASET_PATH"]
image_hdf5_path = os.path.join(dataset_dir, "train-image.hdf5")
metadata_csv = os.path.join(dataset_dir, "train-metadata.csv")
metadata_column_names = config["METADATA_COLUMNS"].split("\t")

dataloader = LoadDataVectors(
    hd5_file_path=image_hdf5_path,
    metadata_csv_path=metadata_csv,
    metadata_columns=metadata_column_names,
    transform=transform_augmentation,
)

In [9]:
# Split the dataset into train/val by the configured fraction.
# Without an explicit generator, random_split draws from the global torch RNG, so
# split membership changes on every kernel restart. Seed it for reproducibility;
# the optional "SEED" key in config.json overrides the fallback of 42.
# Both returned Subsets reference the SAME underlying dataset object, so its
# transform applies to train and val alike.
split_generator = torch.Generator().manual_seed(config.get("SEED", 42))
train, val = torch.utils.data.random_split(
    dataloader,
    [config["TRAIN_SPLIT"], 1 - config["TRAIN_SPLIT"]],
    generator=split_generator,
)

In [10]:
# Fetch the first training item to sanity-check the pipeline; each item is a
# (feature_vector, metadata, target) triple.
first_train_sample = next(iter(train))
feature_vector, metadata, target = first_train_sample

In [11]:
# Number of metadata entries returned for one sample; presumably one per
# METADATA_COLUMNS entry — TODO confirm against LoadDataVectors.
len(metadata)

In [None]:
# Assumes `feature_vector` is the (C, H, W) tensor produced by the ViT preset at the
# end of `transform_augmentation` — TODO confirm in LoadDataVectors.
# matplotlib expects (H, W, C): use permute(1, 2, 0). transpose(0, 2) would give
# (W, H, C) and show the image flipped across its diagonal.
# The preset normalizes with mean/std, so undo that and clamp to [0, 1] for display.
vit_preset = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1.transforms()
norm_mean = torch.tensor(vit_preset.mean).view(-1, 1, 1)
norm_std = torch.tensor(vit_preset.std).view(-1, 1, 1)
plt.imshow((feature_vector * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0).numpy());

In [None]:
# Display the first validation item. `val` shares its underlying dataset with
# `train`, so it presumably goes through the same `transform_augmentation`.
next(iter(val))

In [11]:
# Preview one validation image.
# NOTE: random_split returns Subsets over the *same* dataset object, so `val` yields
# samples passed through `transform_augmentation` exactly like `train`. Assigning
# `val.dataset.transform` would also change the training transform, so it is not
# done here. TODO: build a separate, non-augmented dataset for validation.
# Each item is a (feature_vector, metadata, target) tuple; unpack it rather than
# calling tensor methods on the tuple (and without clobbering the train `feature_vector`).
val_feature_vector, val_metadata, val_target = next(iter(val))

# Assumes a normalized (C, H, W) tensor from the ViT preset — TODO confirm.
# matplotlib needs (H, W, C) in [0, 1]: undo mean/std, clamp, then permute.
vit_preset = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1.transforms()
norm_mean = torch.tensor(vit_preset.mean).view(-1, 1, 1)
norm_std = torch.tensor(vit_preset.std).view(-1, 1, 1)
plt.imshow((val_feature_vector * norm_std + norm_mean).clamp(0, 1).permute(1, 2, 0).numpy());

In [None]:
# Shuffled mini-batches of 8 training samples, loaded by 8 worker processes.
# NOTE(review): each worker receives a copy of the dataset; if LoadDataVectors opens
# the HDF5 file in __init__, confirm that handle is safe to use across workers.
vector_dataloader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)

In [None]:
# Pull one collated batch and show the first sample's feature tensor as a NumPy array.
first_batch = next(iter(vector_dataloader))
first_batch[0][0].numpy()

In [None]:
# Dataset over the "gabor_attention_maps" feature files stored under ../feature_vectors.
preprocess_vectors = LoadPreProcessVectors(
    dataset_base_path="../feature_vectors",
    feature_name="gabor_attention_maps",
    target_index=[0],
    dimensions=128,
)

In [None]:
# Take the first (x, y) item yielded by `preprocess_vectors` for inspection.
x, y = next(iter(preprocess_vectors))

In [None]:
# Inspect the shapes of the first item; y presumably holds the target(s) selected
# by target_index=[0] — confirm in LoadPreProcessVectors.
x.shape, y.shape

In [None]:
# Show the raw value of y (presumably the target selected by target_index=[0]).
y