In [4]:
!pip install -Uqq datasets optuna

In [18]:
import torch
from torchvision import transforms
from PIL import Image, ImageDraw
import numpy as np

import matplotlib.pyplot as plt

from datasets import load_dataset

from google.colab import userdata, runtime
import subprocess

from torch.utils.data import DataLoader, default_collate
from torchvision.transforms import Resize, Normalize, ToTensor, Compose, transforms, CenterCrop, RandomCrop, RandomChoice
import torch.nn.functional as F
import random

import torch
import random
import colorsys
import math

import pickle
import optuna

# Log in to the Hugging Face Hub non-interactively. The CLI reads the token from
# stdin, then "n" as the answer to its "Add token as git credential?" prompt.
# Feeding the token via stdin (rather than argv) keeps it off the command line.
hf_token = userdata.get('hf_token')
input_str = f'{hf_token}\nn\n'
result = subprocess.run(['huggingface-cli', 'login'], input=input_str, text=True, capture_output=True)
print(result.stdout)
if result.returncode != 0:
    # stderr is captured, so without this a failed login would go unnoticed until
    # a later Hub download fails with a less obvious error.
    print(f'huggingface-cli login failed (exit code {result.returncode}):')
    print(result.stderr)

# Seed for the random / numpy / torch RNGs; it is applied later, right before the
# preview DataLoader is built.
seed = 1984

# Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Add token as git credential? (Y/n) Token is valid (permission: write).
Your token has been saved to /root/.cache/huggi

In [22]:
# Load the saved Optuna study (presumably from the synthetic-properties tuning run,
# judging by the path); its best_params configure the SuperimposeSquare overlay.
# NOTE(review): assumes Google Drive is already mounted at /content/drive - no
# mount call is visible in this notebook.
# NOTE(review): pickle.load can execute arbitrary code; only load study files you
# produced yourself.
study_path = '/content/drive/MyDrive/Colab_Notebooks/dye_test_opt/ternary/results/synthetic_properties/study.pkl'

with open(study_path, 'rb') as f:
    study = pickle.load(f)

best_params = study.best_params

class SuperimposeSquare(object):
    """Blend a centred, semi-transparent red or blue square into an image.

    Each call randomly picks a square size, a colour (red or blue) and an
    opacity. It alpha-blends a solid square of that colour into the centre of a
    copy of the image and returns the copy with a class label (1 = blue, 2 = red).

    Args:
        red_hue, blue_hue: HSV hue in [0, 1] for each colour.
        red_value, blue_value: HSV value (brightness) in [0, 1] for each colour.
        red_saturation, blue_saturation: HSV saturation for each colour. These are
            only used when ``saturation_override`` is None.
        max_opacity, min_opacity: bounds of the uniformly sampled blend weight.
            Each is clamped to [0, 1].
        mask_sizes: candidate square side lengths in pixels. One is chosen
            uniformly per call and clamped to the image height/width.
        saturation_override: if not None, this saturation is used for both colours
            instead of ``red_saturation`` / ``blue_saturation``. The default of 1.0
            reproduces the historical behaviour (always full saturation).
    """

    def __init__(self, red_hue=0.83, blue_hue=0.45,
                 red_value=0.4, blue_value=0.4,
                 red_saturation=0.4, blue_saturation=0.4,
                 max_opacity=0.3, min_opacity=0.1,
                 mask_sizes=(15, 77), saturation_override=1.0):
        self.red_hue = red_hue
        self.blue_hue = blue_hue
        self.red_value = red_value
        self.blue_value = blue_value
        self.red_saturation = red_saturation
        self.blue_saturation = blue_saturation
        self.max_opacity = max(0, min(1, max_opacity))
        self.min_opacity = max(0, min(1, min_opacity))
        self.mask_sizes = tuple(mask_sizes)
        self.saturation_override = saturation_override

    def __call__(self, tensor):
        """Return ``(image, label)`` with a coloured square blended into a copy of ``tensor``.

        Args:
            tensor: image tensor shaped (C, H, W). C is assumed to be 3, because
                the square colour is RGB (TODO confirm for non-RGB inputs).

        Returns:
            A new tensor with the same shape as ``tensor`` (the input is left
            unmodified), and the int label: 1 for blue, 2 for red.
        """
        # unsqueeze() returns a view, so writing into it without clone() would
        # silently modify the caller's tensor.
        image = tensor.unsqueeze(0).clone()
        h, w = image.shape[-2:]

        # RNG draw order (size, colour, opacity) matches the original implementation,
        # so seeded runs produce the same overlays. Clamping keeps the square inside
        # images smaller than the chosen size.
        mask_size = min(random.choice(self.mask_sizes), h, w)

        if random.choice(['blue', 'red']) == 'red':
            hue, value, saturation, label = self.red_hue, self.red_value, self.red_saturation, 2
        else:
            hue, value, saturation, label = self.blue_hue, self.blue_value, self.blue_saturation, 1

        if self.saturation_override is not None:
            saturation = self.saturation_override
        color_rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        # Create the colour on the image's device so CUDA inputs work too.
        color_tensor = torch.tensor(color_rgb, device=image.device)

        # Top-left corner that centres the square in the image.
        x = (w - mask_size) // 2
        y = (h - mask_size) // 2

        square = color_tensor.view(3, 1, 1).expand(-1, mask_size, mask_size)
        opacity = random.uniform(self.min_opacity, self.max_opacity)
        region = image[:, :, y:y + mask_size, x:x + mask_size]
        image[:, :, y:y + mask_size, x:x + mask_size] = opacity * square + (1 - opacity) * region

        return image.squeeze(0), label

# Side length, in pixels, of the square centre crop applied to every image.
context_sz = 154

# Training split of the dye-test dataset from the Hugging Face Hub.
ds = load_dataset('mpg-ranch/dye_test', split='train')

# Preprocessing is a centre crop only; any further cropping is deferred to the
# sub-sample loop.
preprocs = Compose([CenterCrop((context_sz, context_sz))])

def preproc_transforms(examples):
    """Batched ``datasets.map`` callback that adds an 'img' column.

    Each entry of ``examples["image"]`` (presumably a PIL image, given the
    ``.convert`` call) is converted to RGB and run through the module-level
    ``preprocs`` pipeline. The batch dict is updated in place and returned.
    """
    cropped = []
    for pil_image in examples["image"]:
        cropped.append(preprocs(pil_image.convert("RGB")))
    examples["img"] = cropped
    return examples

# Crop the whole dataset in a single batch (batch_size=len(ds)) and drop the
# metadata columns that are not needed here.
ds = ds.map(
    preproc_transforms,
    remove_columns=["image", "color", "size", "concentration"],
    batched=True,
    batch_size=len(ds),
)

# NOTE(review): not used in this cell; the preview DataLoader sets its own batch size.
batch_size = 8

# Per-image conversion to a tensor, applied inside collate_fun. The square
# overlay is applied separately by SuperimposeSquare, not by this pipeline.
transform = transforms.Compose([ToTensor()])

def collate_fun(batch):
    """Collate dataset rows into a batch, adding a synthetic dye square to label-0 images.

    Each row's 'img' goes through the module-level ``transform``. Rows with
    label 0 get a red or blue square blended in by ``SuperimposeSquare``
    (configured from the Optuna ``best_params``) and take the label it returns.
    All other rows pass through unchanged.

    Args:
        batch: list of dataset rows, each a mapping with 'img' and 'label' keys.

    Returns:
        dict with 'img' (the transformed images stacked along a new first
        dimension) and 'label' (a tensor of the per-image labels).
    """
    imgs = [transform(item['img']) for item in batch]
    labels = [item['label'] for item in batch]

    # Build the overlay once per batch rather than once per image. Its constructor
    # draws no random numbers, so the output is unchanged.
    overlay = SuperimposeSquare(
        red_hue=best_params['red_hue'],
        blue_hue=best_params['blue_hue'],
        red_value=best_params['red_value'],
        blue_value=best_params['blue_value'],
        red_saturation=best_params['red_saturation'],
        blue_saturation=best_params['blue_saturation'],
        max_opacity=best_params['max_opacity'],
        min_opacity=best_params['min_opacity'],
    )

    new_imgs = []
    new_labels = []
    for img, label in zip(imgs, labels):
        if label == 0:
            img, label = overlay(img)
        new_imgs.append(img)
        new_labels.append(label)

    return {'img': torch.stack(new_imgs), 'label': torch.tensor(new_labels)}

# Preview set: the first n_preview rows with label == 0. collate_fun overlays a
# synthetic red or blue square on every label-0 image, so the whole preview batch
# is synthetic. (Despite its name, balanced_ds holds label-0 rows only.)
n_preview = 16
label_0_ds = ds.filter(lambda x: x['label'] == 0)

# Clamp so the Subset never indexes past the end when fewer than n_preview
# label-0 rows exist.
balanced_ds = torch.utils.data.ConcatDataset([
    torch.utils.data.Subset(label_0_ds, range(min(n_preview, len(label_0_ds)))),
])

# Seed every RNG in play: `random` drives SuperimposeSquare's size/colour/opacity
# choices, and torch drives the DataLoader's shuffling.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

dataloader = DataLoader(balanced_ds, batch_size=n_preview, shuffle=True, collate_fn=collate_fun)

# Colour names indexed by label - 1 (SuperimposeSquare assigns 1 = blue, 2 = red).
colors = ['blue', 'red']

def plot_images(images, labels, rows=4, cols=4, show_labels=False):
    """Show up to ``rows * cols`` images in a grid.

    Args:
        images: indexable collection of image tensors shaped (C, H, W).
        labels: per-image labels; only read when ``show_labels`` is True.
        rows, cols: grid shape. Images beyond ``rows * cols`` are not shown.
        show_labels: if True, title each panel with its colour name
            (1 = blue, 2 = red); any other label is shown as its number.
    """
    label_names = {1: 'blue', 2: 'red'}
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works.
    fig, axes = plt.subplots(rows, cols, figsize=(12, 12), squeeze=False)

    for i, ax in enumerate(axes.flat):
        # Hide axes on every panel, including unused trailing ones that would
        # otherwise render as empty framed boxes.
        ax.axis('off')
        if i >= len(images):
            continue
        # (C, H, W) -> (H, W, C); detach/cpu so GPU or grad-tracking tensors also plot.
        ax.imshow(images[i].detach().cpu().permute(1, 2, 0).numpy())
        if show_labels:
            label = int(labels[i])
            ax.set_title(label_names.get(label, str(label)))

    # tight_layout recomputes all subplot spacing, so a preceding
    # subplots_adjust call would have no effect.
    plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0)
    plt.show()

# Draw one shuffled preview batch from the DataLoader and display it as a grid.
dataiter = iter(dataloader)
batch = next(dataiter)
images = batch['img']
labels = batch['label']

plot_images(images, labels)

Output hidden; open in https://colab.research.google.com to view.