In [1]:
import numpy as np
import pandas as pd
import os
import h5py
import cv2
import matplotlib.pyplot as plt
import tqdm
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
# TODO: set this to the directory that holds the dataset files
# (e.g. train-metadata.csv and test-metadata.csv),
# for example: current_dir = "/path/to/EECS-545-final-project".
# An empty string makes os.path.join(current_dir, name) resolve relative to
# the notebook's working directory.
current_dir = ""

In [6]:
from vae import CVAE
from torch import nn

## Data Loading
- Load the metadata tables and the HDF5 image archives used to train the VAE

In [None]:
# Read the training metadata table from `current_dir` and print a summary of
# its columns. low_memory=False makes pandas parse the file in a single pass.
train_metadata_path = os.path.join(current_dir, 'train-metadata.csv')
train_metadata = pd.read_csv(train_metadata_path, low_memory=False)
train_metadata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 401059 entries, 0 to 401058
Data columns (total 55 columns):
 #   Column                        Non-Null Count   Dtype  
---  ------                        --------------   -----  
 0   isic_id                       401059 non-null  object 
 1   target                        401059 non-null  int64  
 2   patient_id                    401059 non-null  object 
 3   age_approx                    398261 non-null  float64
 4   sex                           389542 non-null  object 
 5   anatom_site_general           395303 non-null  object 
 6   clin_size_long_diam_mm        401059 non-null  float64
 7   image_type                    401059 non-null  object 
 8   tbp_tile_type                 401059 non-null  object 
 9   tbp_lv_A                      401059 non-null  float64
 10  tbp_lv_Aext                   401059 non-null  float64
 11  tbp_lv_B                      401059 non-null  float64
 12  tbp_lv_Bext                   401059 non-nul

In [3]:
# Read the test metadata table from `current_dir` and print a summary of its
# columns. low_memory=False makes pandas parse the file in a single pass.
test_metadata_path = os.path.join(current_dir, 'test-metadata.csv')
test_metadata = pd.read_csv(test_metadata_path, low_memory=False)
test_metadata.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3 entries, 0 to 2
Data columns (total 44 columns):
 #   Column                       Non-Null Count  Dtype  
---  ------                       --------------  -----  
 0   isic_id                      3 non-null      object 
 1   patient_id                   3 non-null      object 
 2   age_approx                   3 non-null      float64
 3   sex                          3 non-null      object 
 4   anatom_site_general          3 non-null      object 
 5   clin_size_long_diam_mm       3 non-null      float64
 6   image_type                   3 non-null      object 
 7   tbp_tile_type                3 non-null      object 
 8   tbp_lv_A                     3 non-null      float64
 9   tbp_lv_Aext                  3 non-null      float64
 10  tbp_lv_B                     3 non-null      float64
 11  tbp_lv_Bext                  3 non-null      float64
 12  tbp_lv_C                     3 non-null      float64
 13  tbp_lv_Cext             

In [4]:
# Open the train/test image archives read-only.
# The paths are built with os.path.join, like the metadata CSV paths, so that
# changing `current_dir` relocates every input file at once. With the default
# current_dir == "" this resolves to the bare file name in the working
# directory (an f"{current_dir}/..." path would instead point at the
# filesystem root).
training_validation_hdf5 = h5py.File(os.path.join(current_dir, "train-image.hdf5"), 'r')
testing_hdf5 = h5py.File(os.path.join(current_dir, "test-image.hdf5"), 'r')

## Preprocess data
- Keep only the malignant images (metadata rows with a non-zero `target`)
- Resize them to 128 × 128 pixels with 3 colour channels, i.e. shape (128, 128, 3)
- Normalize pixel values to [0,1]

In [5]:
# Load the malignant training images into one float array.
#
# Only metadata rows whose `target` is non-zero are kept. Each image is stored
# in the HDF5 file as an encoded byte string keyed by its `isic_id`; it is
# decoded, converted to RGB, resized to 128x128 and scaled to [0, 1].

# Select the wanted ids with a boolean mask instead of calling .iloc[i] on
# every one of the metadata rows.
malignant_ids = train_metadata.loc[train_metadata['target'] != 0, 'isic_id']

train_images = []
for image_id in tqdm.tqdm(malignant_ids):
    encoded = np.frombuffer(training_validation_hdf5[image_id][()], dtype=np.uint8)
    image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals failure by returning None; fail with the offending id
        # rather than with an opaque error from cv2.resize.
        raise ValueError(f"Could not decode image {image_id}")
    # OpenCV decodes to BGR channel order, but the later cells treat
    # train_images as RGB (ToPILImage in the augmentation cell, and the
    # explicit RGB -> BGR flip before cv2.imwrite), so convert once here.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (128, 128))
    train_images.append(image / 255)  # uint8 [0, 255] -> float [0, 1]
train_images = np.array(train_images)


print(f"Training images shape: {train_images.shape}")

100%|██████████| 401059/401059 [00:09<00:00, 42256.45it/s]


Training images shape: (393, 128, 128, 3)


In [6]:
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image

In [7]:
# Save randomly augmented copies of every training image as PNG files under
# augmented_images/.
# Assumes train_images is a NumPy array: (N, 128, 128, 3)
os.makedirs("augmented_images", exist_ok=True)

# Random parameters are drawn anew on every call of the composed transform.
augmentation_steps = [
    # NOTE(review): ToPILImage reads a 3-channel uint8 array as RGB -- confirm
    # that train_images is stored in RGB order.
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
]
augmentation = transforms.Compose(augmentation_steps)

# Apply 5 augmentations per image
n_augments = 5

for idx, img in enumerate(train_images):
    # The transform pipeline starts from 8-bit pixels, so undo the [0, 1] scaling.
    img = (img * 255).astype(np.uint8)
    for j in range(n_augments):
        # The pipeline returns a tensor; turn it back into a PIL image to save it.
        to_pil_image(augmentation(img)).save(f"augmented_images/img_{idx:03}_aug{j:02}.png")

In [8]:
import cv2

def jigsaw_augment(img, grid_size=4):
    """Return a copy of `img` with its grid tiles randomly rearranged.

    The image is cut into `grid_size` x `grid_size` equally sized tiles of
    (height // grid_size) x (width // grid_size) pixels. The tiles are shuffled
    with NumPy's global random state (`np.random.shuffle`) and written back
    into the grid positions in row-major order.

    Args:
        img: Image array whose first two axes are height and width, e.g.
            (H, W, C) or a 2-D (H, W) grayscale image. It is not modified.
        grid_size: Number of tiles along each axis.

    Returns:
        A new array with the same shape and dtype as `img`. When the height or
        width is not divisible by `grid_size`, the leftover rows/columns at the
        bottom/right edge keep their original pixels instead of being blanked.
        If `grid_size` exceeds the height or width the tiles are empty and the
        result is simply a copy of `img`.
    """
    h, w = img.shape[:2]
    patch_h, patch_w = h // grid_size, w // grid_size

    # Top-left corner of every tile, in row-major order.
    corners = [
        (row * patch_h, col * patch_w)
        for row in range(grid_size)
        for col in range(grid_size)
    ]

    # Slices of `img` are views, which is safe here because the tiles are only
    # ever written into the separate output array below.
    patches = [img[top:top + patch_h, left:left + patch_w] for top, left in corners]
    np.random.shuffle(patches)

    # Start from a copy so pixels outside the tiled area are preserved.
    new_img = img.copy()
    for (top, left), patch in zip(corners, patches):
        new_img[top:top + patch_h, left:left + patch_w] = patch

    return new_img

In [9]:
# Save tile-shuffled ("jigsaw") variants of every training image as PNG files
# under jigsaw_images/.
os.makedirs("jigsaw_images", exist_ok=True)

n_jigsaws = 3  # 3 jigsaw variations per image

for idx, img in enumerate(train_images):
    # jigsaw_augment works on 8-bit pixels, so undo the [0, 1] scaling first.
    pixels = (img * 255).astype(np.uint8)
    for j in range(n_jigsaws):
        shuffled = jigsaw_augment(pixels, grid_size=4)
        # cv2.imwrite expects BGR, so reverse the channel axis (RGB → BGR).
        # NOTE(review): this presumes train_images is in RGB order -- confirm
        # against the loading cell.
        bgr = shuffled[..., ::-1]
        cv2.imwrite(f"jigsaw_images/img_{idx:03}_jigsaw{j:02}.png", bgr)