In [1]:
!pip uninstall -y jax jaxlib
!pip install -Uqq datasets optuna kaleido

from google.colab import userdata, runtime
import subprocess

# Log in to the Hugging Face Hub with the token stored in the Colab secret
# 'hf_token'. The token is passed to the CLI on stdin, so it never appears on
# the command line; the trailing "n" presumably answers the CLI's follow-up
# yes/no prompt (verify against the installed huggingface-cli version).
hf_token = userdata.get('hf_token')
input_str = f'{hf_token}\nn\n'
result = subprocess.run(['huggingface-cli', 'login'], input=input_str, text=True, capture_output=True)
print(result.stdout)
# stderr is captured as well; show it when the CLI exits with an error so a
# failed login is not silently ignored.
if result.returncode != 0:
    print(f'huggingface-cli login failed (exit code {result.returncode}):\n{result.stderr}')

Found existing installation: jax 0.4.26
Uninstalling jax-0.4.26:
  Successfully uninstalled jax-0.4.26
Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
  Successfully uninstalled jaxlib-0.4.26+cuda12.cudnn89
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m380.1/380.1 kB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K 

In [None]:
from functools import reduce
import os
import random
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, default_collate
from torchvision.transforms import Resize, Normalize, ToTensor, Compose, transforms, CenterCrop, RandomCrop, RandomChoice
from torchvision.transforms.v2 import CutMix
from sklearn.metrics import f1_score
import optuna
import pickle
import json
import colorsys
import math

from datasets import load_dataset, concatenate_datasets

# Single seed used further down for Python, NumPy and PyTorch RNGs.
seed = 1984

# Device setup
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Ask cuDNN for deterministic behaviour and switch off its benchmark mode so
# repeated runs are as reproducible as possible.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# `focus` names this experiment; its results live in a directory of the same
# name. The sibling directories of results_dir hold the earlier Optuna studies.
focus = 'epoch_count'
results_dir = f'/content/drive/MyDrive/Colab_Notebooks/dye_test_opt/ternary/results/{focus}'
os.makedirs(results_dir, exist_ok=True)

# Merge the best hyperparameters found by the earlier Optuna studies. Each
# study is stored at <results root>/<study name>/study.pkl, i.e. in a sibling
# directory of this experiment's own results_dir.
studies = ['learning', 'synthetic_properties', 'augs']
best_params = {}

# Build the path from the parent directory instead of string-replacing `focus`
# inside the full path, which would also rewrite any other occurrence of that
# substring.
results_root = os.path.dirname(results_dir)
for study_name in studies:
    study_path = os.path.join(results_root, study_name, 'study.pkl')
    # NOTE: pickle.load can execute arbitrary code contained in the file; only
    # load study files produced by trusted runs of this project.
    with open(study_path, 'rb') as f:
        study = pickle.load(f)
    # Later studies overwrite keys from earlier ones if names collide.
    best_params.update(study.best_trial.params)

# Extract the learning rate and batch size
lr = best_params['lr']
bs = best_params['bs']
print(f'Learning rate: {lr}')
print(f'Batch size: {bs}')

best_params['synthetic_prob'] = 0.0 #remove synthetic image generation

# Load dataset
# Only the 'train' split of the Hub dataset is used; folds are carved out of it later.
ds = load_dataset('mpg-ranch/dye_test', split='train')

# Dataset preparation
total_samples = len(ds)
# NOTE: os.cpu_count() can return None; the value is passed to DataLoader as num_workers below.
n_workers = os.cpu_count()
print(f'Number of available CPU cores: {n_workers}')
n_epochs = 50
# Side length (pixels) of the crop that is finally fed to the model.
context_sz = 154
# The stored canvas is 84 px larger than the model input (154 + 84 = 238), which
# leaves room for the random crops applied at training time.
canvas_sz = context_sz + 14*6 # 1x1 meter context

# Per-channel statistics used by the Normalize transforms.
imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                  'std': [0.229, 0.224, 0.225]}

# Preprocessing transforms
# Applied once to every image up front: keep only the central canvas_sz x canvas_sz region.
preprocs = Compose([
    CenterCrop((canvas_sz, canvas_sz))
])

def preproc_transforms(examples):
    """Add an ``"img"`` column with the preprocessed version of every image.

    Each entry of ``examples["image"]`` is converted to RGB and passed through
    the module-level ``preprocs`` pipeline. The batch is modified in place and
    returned.
    """
    processed = []
    for raw_image in examples["image"]:
        rgb_image = raw_image.convert("RGB")
        processed.append(preprocs(rgb_image))
    examples["img"] = processed
    return examples

print("Applying preprocessing transforms...")
# Process the whole dataset as a single batch (batch_size=len(ds)); the map adds
# the "img" column and drops the listed original columns.
ds = ds.map(preproc_transforms, remove_columns=["image","color","size","concentration"], batched=True, batch_size=len(ds))
# Number of distinct label values; used as the output size of the model head.
n_classes = len(np.unique(ds['label']))

# Define model loading function
def load_model(arch, n_classes):
    """Fetch a DINOv2 backbone from torch.hub and give it a fresh linear head.

    Args:
        arch: entry-point name in the ``facebookresearch/dinov2`` hub repo.
        n_classes: number of outputs of the new classification head.

    Returns:
        The model, moved to the module-level ``device``.
    """
    print("Loading model...")
    backbone = torch.hub.load('facebookresearch/dinov2', arch)
    # The width of the final norm layer is the input size of the new head.
    feature_dim = backbone.norm.normalized_shape[0]
    backbone.head = nn.Linear(in_features=feature_dim, out_features=n_classes)
    backbone.to(device)
    return backbone

def modified_f1_score(labels, predictions):
    """Compute binary F1 scores for class 1 vs. 0 and class 2 vs. 0, plus their mean.

    For each comparison, samples whose *true* label is the other non-zero class
    are dropped; the remaining labels and predictions are binarised as
    "equals the positive class" and scored with a binary F1.

    Args:
        labels: NumPy array of true class ids (0, 1 or 2).
        predictions: NumPy array of predicted class ids, same length as ``labels``.

    Returns:
        Tuple ``(f1_1_vs_0, f1_2_vs_0, mean_f1)``. A comparison with no positive
        label and no positive prediction is reported as 0.0 and left out of the
        mean; if both comparisons are undefined the mean is 0.0 as well.
    """

    def _one_vs_background(positive, excluded):
        # Binary F1 for `positive` after dropping samples labelled `excluded`;
        # returns None when the score is undefined (no positives on either side).
        keep = (labels != excluded)
        y_true = labels[keep] == positive
        y_pred = predictions[keep] == positive
        if not (np.any(y_true) or np.any(y_pred)):
            return None
        return f1_score(y_true, y_pred, pos_label=True, average='binary')

    raw_1_vs_0 = _one_vs_background(positive=1, excluded=2)
    raw_2_vs_0 = _one_vs_background(positive=2, excluded=1)

    # Only defined scores take part in the mean.
    scores = [score for score in (raw_1_vs_0, raw_2_vs_0) if score is not None]
    mean_f1 = np.mean(scores) if scores else 0.0

    # Previously an undefined comparison left its variable unassigned and the
    # return statement raised UnboundLocalError; report 0.0 instead.
    f1_1_vs_0 = 0.0 if raw_1_vs_0 is None else raw_1_vs_0
    f1_2_vs_0 = 0.0 if raw_2_vs_0 is None else raw_2_vs_0

    return f1_1_vs_0, f1_2_vs_0, mean_f1

class SuperimposeSquare(object):
    """Blend a semi-transparent coloured square into the centre of an image.

    Calling an instance on a ``(3, H, W)`` tensor randomly picks a square size
    (15 or 77 pixels), a colour ('blue' or 'red') and an opacity between
    ``min_opacity`` and ``max_opacity``, and returns the blended image together
    with the matching label (1 for blue, 2 for red).

    Args:
        red_hue, blue_hue: HSV hue of the red / blue square.
        red_value, blue_value: HSV value of the red / blue square.
        red_saturation, blue_saturation: stored, but currently not used when
            the colour is built (see the note in ``__call__``).
        max_opacity, min_opacity: opacity bounds, each clamped to [0, 1].
    """

    def __init__(self, red_hue=0.83, blue_hue=0.45,
                 red_value=0.4, blue_value=0.4,
                 red_saturation=0.4, blue_saturation=0.4,
                 max_opacity=0.3, min_opacity=0.1):
        self.red_hue = red_hue
        self.blue_hue = blue_hue
        self.red_value = red_value
        self.blue_value = blue_value
        self.red_saturation = red_saturation
        self.blue_saturation = blue_saturation
        self.max_opacity = max(0, min(1, max_opacity))
        self.min_opacity = max(0, min(1, min_opacity))

    def __call__(self, tensor):
        """Return ``(image_with_square, label)`` without modifying ``tensor``."""
        # unsqueeze() returns a view, so work on a copy: writing the square
        # into a view would silently alter the caller's tensor.
        image = tensor.clone().unsqueeze(0)
        h, w = image.size()[-2:]

        # Randomly choose between small and large box sizes
        small_box = random.choice([True, False])
        mask_size = 15 if small_box else 77
        # Never let the square exceed the image; otherwise the centring offset
        # goes negative and the slices below no longer match the square.
        mask_size = min(mask_size, h, w)

        color_choice = random.choice(['blue', 'red'])
        if color_choice == 'red':
            hue, value, label = self.red_hue, self.red_value, 2
        else:
            hue, value, label = self.blue_hue, self.blue_value, 1

        # NOTE(review): saturation is pinned to 1.0, so red_saturation and
        # blue_saturation have no effect. Kept as-is because the existing
        # hyperparameters were tuned with this behaviour; confirm the intent.
        saturation = 1.0  # Full saturation for vivid colors
        color_rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        color_tensor = torch.tensor(color_rgb, device=image.device)

        x = (w - mask_size) // 2  # Centering the square on x-axis
        y = (h - mask_size) // 2  # Centering the square on y-axis

        square = color_tensor.view(3, 1, 1).expand(-1, mask_size, mask_size)
        opacity = random.uniform(self.min_opacity, self.max_opacity)
        # Alpha-blend the flat colour with the pixels already under the square.
        region = image[:, :, y:y+mask_size, x:x+mask_size]
        image[:, :, y:y+mask_size, x:x+mask_size] = opacity * square + (1 - opacity) * region

        return image.squeeze(0), label

# Hyperparameters for augmentations
# All values come from the merged best_params of the earlier studies.
synthetic_prob = best_params['synthetic_prob']
random_crop_prob = best_params['random_crop_prob']
random_horizontal_flip = best_params['random_horizontal_flip']
random_vertical_flip = best_params['random_vertical_flip']
random_rotation = best_params['random_rotation']
brightness = best_params['brightness']
contrast = best_params['contrast']
saturation = best_params['saturation']
hue = best_params['hue']

# Train transforms
# Operate on tensors: the PIL -> tensor conversion happens in the training collate function.
train_tfms = Compose([
    # Random crop with probability random_crop_prob, otherwise a centre crop; both yield context_sz.
    RandomChoice([RandomCrop(size=context_sz), CenterCrop(context_sz)], p=[random_crop_prob, 1 - random_crop_prob]),
    # The flip hyperparameters are on/off switches: 0.5 flip probability when enabled, never otherwise.
    transforms.RandomHorizontalFlip(p=0.5 if random_horizontal_flip else 0),
    transforms.RandomVerticalFlip(p=0.5 if random_vertical_flip else 0),
    transforms.RandomRotation(random_rotation),
    transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue),
    #ToTensor applied in collate_fn
    Normalize(mean=imagenet_stats['mean'], std=imagenet_stats['std'])  # Normalize using ImageNet stats
])

# Validation transforms: deterministic centre crop, tensor conversion and normalisation only.
val_tfms = Compose([
    CenterCrop(size=context_sz),
    ToTensor(),
    Normalize(mean=imagenet_stats['mean'], std=imagenet_stats['std'])  # Normalize using ImageNet stats
])

def batch_tfms_val(examples):
    """Run the validation pipeline ``val_tfms`` over every image of a batch.

    The ``"img"`` entry of ``examples`` is replaced by the list of transformed
    images; the same batch object is returned.
    """
    transformed = list(map(val_tfms, examples["img"]))
    examples["img"] = transformed
    return examples

# For every epoch number (1..n_epochs), the list of mean validation F1 scores;
# the training loop appends one value per fold.
f1_scores = {epoch: [] for epoch in range(1, n_epochs + 1)}

# Seed every RNG used below (Python, NumPy, PyTorch CPU and CUDA). There is a
# single seed, not a loop over seeds.
print(f"Seed: {seed}")
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

def create_fold_mapping(unique_ids, n_folds):
    """Randomly assign each id to one of ``n_folds`` folds.

    Fold numbers are drawn independently (with replacement) from NumPy's global
    RNG, so fold sizes are not guaranteed to be equal.

    Returns:
        dict mapping every element of ``unique_ids`` to its fold number.
    """
    n_ids = len(unique_ids)
    drawn_folds = np.random.choice(n_folds, size=n_ids, replace=True)
    return {uid: fold_no for uid, fold_no in zip(unique_ids, drawn_folds)}

def assign_fold(batch, fold_mapping):
    """Attach a ``'fold'`` column to ``batch`` by looking up each ``'idx'`` value.

    Args:
        batch: dict of columns containing an ``'idx'`` list.
        fold_mapping: dict from idx value to fold number.

    Returns:
        The same ``batch`` object, now with a ``'fold'`` list aligned to ``'idx'``.
    """
    folds = []
    for sample_id in batch['idx']:
        folds.append(fold_mapping[sample_id])
    batch['fold'] = folds
    return batch

# Folds are assigned per unique 'idx' value, so all rows sharing an idx end up in the same fold.
unique_ids = ds.unique('idx')  # Get unique ids
fold_mapping = create_fold_mapping(unique_ids, 5)

# Map folds to the dataset based on the fold mapping
# (processed as a single batch; adds a 'fold' column to every row)
ds = ds.map(lambda batch: assign_fold(batch, fold_mapping), batched=True, batch_size=len(ds))

# 5-fold cross-validation: for each fold, train a fresh model on the other four
# folds and record the validation F1 after every epoch.
for fold in range(5):
  print(f"Fold: {fold}")
  # Rows of the current fold form the validation set; all other rows are used for training.
  train_ds = ds.filter(lambda example: example['fold'] != fold)
  val_ds = ds.filter(lambda example: example['fold'] == fold)

  print(len(train_ds), len(val_ds))

  # Validation images go through batch_tfms_val when accessed; training images
  # are transformed inside collate_fun instead.
  val_ds.set_transform(batch_tfms_val)

  # Model, optimizer, and loss function setup
  # A new model is loaded for every fold so that folds do not share weights.
  model = load_model('dinov2_vitb14', n_classes)
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
  criterion = nn.CrossEntropyLoss().to(device)

  # Training collate function: converts the batch items to tensors, optionally
  # turns label-0 images into synthetic positives, and applies train_tfms.
  def collate_fun(batch):
    imgs = [ToTensor()(item['img']) for item in batch]  # Convert PIL images to tensors
    imgs = torch.stack(imgs)
    labels = torch.tensor([item['label'] for item in batch])

    new_imgs = []
    new_labels = []

    for img, label in zip(imgs, labels):
        # Only label-0 images are candidates for a synthetic square, each with
        # probability synthetic_prob (0.0 disables the branch entirely).
        if label == 0 and synthetic_prob > random.random():
            transformed_img, new_label = SuperimposeSquare(best_params['red_hue'],
                                                            best_params['blue_hue'],
                                                            best_params['red_value'],
                                                            best_params['blue_value'],
                                                            best_params['red_saturation'],
                                                            best_params['blue_saturation'],
                                                            best_params['max_opacity'],
                                                            best_params['min_opacity']
                                                           )(img)
            transformed_img = train_tfms(transformed_img)
            new_imgs.append(transformed_img)
            new_labels.append(new_label)
        else:
            img = train_tfms(img)
            new_imgs.append(img)
            new_labels.append(label)

    imgs = torch.stack(new_imgs)
    labels = torch.tensor(new_labels)
    # Wrap the result in a dictionary
    return {'img': imgs, 'label': labels}

  # Only the training loader uses the custom collate function; the validation
  # loader relies on the default collation of the already transformed samples.
  train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=n_workers, collate_fn=collate_fun)
  val_loader = DataLoader(val_ds, batch_size=bs, num_workers=n_workers)

  # Training loop
  for epoch in range(1, n_epochs +1):
      print(f"Epoch: {epoch}/{n_epochs}")
      model.train()
      for _, data in enumerate(train_loader):
          inputs, labels = data['img'].to(device), data['label'].to(device)
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = criterion(outputs, labels)
          loss.backward()
          optimizer.step()

      # Validation loop
      # Runs after every epoch; predictions are the arg-max over the class outputs.
      model.eval()
      all_labels, all_predictions = [], []
      with torch.no_grad():
          for data in val_loader:
              inputs, labels = data['img'].to(device), data['label'].to(device)
              outputs = model(inputs)
              _, predicted = torch.max(outputs.data, 1)
              all_labels.extend(labels.cpu().numpy())
              all_predictions.extend(predicted.cpu().numpy())

      # Compute F1 score
      # modified_f1_score returns (class 1 vs 0, class 2 vs 0, mean); they are
      # reported here as blue, red and mean respectively.
      blue, red, f1 = modified_f1_score(np.array(all_labels), np.array(all_predictions))
      print(f"Blue F1 score: {blue:.4f}\nRed F1 score: {red:.4f}\nMean F1 score: {f1:.4f}")

      # One entry per fold is accumulated for each epoch number.
      f1_scores[epoch].append(f1)

# Average the per-fold F1 scores of each epoch, then pick the epoch with the highest average
mean_f1_scores = {epoch: np.mean(scores) for epoch, scores in f1_scores.items()}
highest_f1_epoch, highest_f1_score = max(mean_f1_scores.items(), key=lambda x: x[1])

results = {
    'highest_f1_epoch': highest_f1_epoch,
    'highest_f1_score': highest_f1_score
}

print(f"Highest F1-score {highest_f1_score} at epoch {highest_f1_epoch}")

# Persist the best epoch and its score next to the other results of this experiment.
best_epoch_path = os.path.join(results_dir, 'epoch_result_real_only.json')
with open(best_epoch_path, 'w') as f:
    json.dump(results, f)

Using device: cuda
Learning rate: 8.113530089356352e-06
Batch size: 96


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/414 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/407M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/105M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Number of available CPU cores: 16
Applying preprocessing transforms...


Map:   0%|          | 0/2136 [00:00<?, ? examples/s]

Seed: 1984


Map:   0%|          | 0/2136 [00:00<?, ? examples/s]

Fold: 0


Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

1704 432
Loading model...


Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth
100%|██████████| 330M/330M [00:02<00:00, 146MB/s]


Epoch: 1/50
Blue F1 score: 0.2973
Red F1 score: 0.8504
Mean F1 score: 0.5738
Epoch: 2/50
Blue F1 score: 0.8468
Red F1 score: 0.8618
Mean F1 score: 0.8543
Epoch: 3/50
Blue F1 score: 0.8889
Red F1 score: 0.8976
Mean F1 score: 0.8933
Epoch: 4/50
Blue F1 score: 0.8889
Red F1 score: 0.9242
Mean F1 score: 0.9066
Epoch: 5/50
Blue F1 score: 0.8696
Red F1 score: 0.8960
Mean F1 score: 0.8828
Epoch: 6/50
Blue F1 score: 0.9060
Red F1 score: 0.8777
Mean F1 score: 0.8918
Epoch: 7/50
Blue F1 score: 0.8143
Red F1 score: 0.9147
Mean F1 score: 0.8645
Epoch: 8/50
Blue F1 score: 0.8889
Red F1 score: 0.9048
Mean F1 score: 0.8968
Epoch: 9/50
Blue F1 score: 0.8406
Red F1 score: 0.9008
Mean F1 score: 0.8707
Epoch: 10/50
Blue F1 score: 0.9060
Red F1 score: 0.9134
Mean F1 score: 0.9097
Epoch: 11/50
Blue F1 score: 0.8793
Red F1 score: 0.9254
Mean F1 score: 0.9023
Epoch: 12/50
Blue F1 score: 0.8689
Red F1 score: 0.9313
Mean F1 score: 0.9001
Epoch: 13/50
Blue F1 score: 0.8870
Red F1 score: 0.9048
Mean F1 score: 0.

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

1701 435
Loading model...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Epoch: 1/50
Blue F1 score: 0.4910
Red F1 score: 0.8000
Mean F1 score: 0.6455
Epoch: 2/50
Blue F1 score: 0.8571
Red F1 score: 0.9136
Mean F1 score: 0.8854
Epoch: 3/50
Blue F1 score: 0.8113
Red F1 score: 0.9673
Mean F1 score: 0.8893
Epoch: 4/50
Blue F1 score: 0.9000
Red F1 score: 0.9615
Mean F1 score: 0.9308
Epoch: 5/50
Blue F1 score: 0.8943
Red F1 score: 0.9740
Mean F1 score: 0.9342
Epoch: 6/50
Blue F1 score: 0.9138
Red F1 score: 0.9677
Mean F1 score: 0.9408
Epoch: 7/50
Blue F1 score: 0.8649
Red F1 score: 0.9737
Mean F1 score: 0.9193
Epoch: 8/50
Blue F1 score: 0.8649
Red F1 score: 0.9673
Mean F1 score: 0.9161
Epoch: 9/50
Blue F1 score: 0.9231
Red F1 score: 0.9673
Mean F1 score: 0.9452
Epoch: 10/50
Blue F1 score: 0.9043
Red F1 score: 0.9737
Mean F1 score: 0.9390
Epoch: 11/50
Blue F1 score: 0.9091
Red F1 score: 0.9677
Mean F1 score: 0.9384
Epoch: 12/50
Blue F1 score: 0.9076
Red F1 score: 0.9530
Mean F1 score: 0.9303
Epoch: 13/50
Blue F1 score: 0.9138
Red F1 score: 0.9737
Mean F1 score: 0.

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

1743 393
Loading model...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Epoch: 1/50
Blue F1 score: 0.5946
Red F1 score: 0.8293
Mean F1 score: 0.7119
Epoch: 2/50
Blue F1 score: 0.8777
Red F1 score: 0.8788
Mean F1 score: 0.8782
Epoch: 3/50
Blue F1 score: 0.9091
Red F1 score: 0.8322
Mean F1 score: 0.8707
Epoch: 4/50
Blue F1 score: 0.8936
Red F1 score: 0.8689
Mean F1 score: 0.8812
Epoch: 5/50
Blue F1 score: 0.9091
Red F1 score: 0.8960
Mean F1 score: 0.9025
Epoch: 6/50
Blue F1 score: 0.8503
Red F1 score: 0.9147
Mean F1 score: 0.8825
Epoch: 7/50
Blue F1 score: 0.9262
Red F1 score: 0.8955
Mean F1 score: 0.9108
Epoch: 8/50
Blue F1 score: 0.9014
Red F1 score: 0.8800
Mean F1 score: 0.8907
Epoch: 9/50
Blue F1 score: 0.9014
Red F1 score: 0.8976
Mean F1 score: 0.8995
Epoch: 10/50
Blue F1 score: 0.9315
Red F1 score: 0.9048
Mean F1 score: 0.9181
Epoch: 11/50
Blue F1 score: 0.9342
Red F1 score: 0.9134
Mean F1 score: 0.9238
Epoch: 12/50
Blue F1 score: 0.9467
Red F1 score: 0.8819
Mean F1 score: 0.9143
Epoch: 13/50
Blue F1 score: 0.9315
Red F1 score: 0.8689
Mean F1 score: 0.

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

1683 453
Loading model...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Epoch: 1/50
Blue F1 score: 0.3019
Red F1 score: 0.8923
Mean F1 score: 0.5971
Epoch: 2/50
Blue F1 score: 0.8312
Red F1 score: 0.9077
Mean F1 score: 0.8694
Epoch: 3/50
Blue F1 score: 0.8462
Red F1 score: 0.9091
Mean F1 score: 0.8776
Epoch: 4/50
Blue F1 score: 0.8462
Red F1 score: 0.9147
Mean F1 score: 0.8804
Epoch: 5/50
Blue F1 score: 0.8387
Red F1 score: 0.9147
Mean F1 score: 0.8767
Epoch: 6/50
Blue F1 score: 0.8696
Red F1 score: 0.9219
Mean F1 score: 0.8957
Epoch: 7/50
Blue F1 score: 0.8679
Red F1 score: 0.9134
Mean F1 score: 0.8907
Epoch: 8/50
Blue F1 score: 0.8235
Red F1 score: 0.9219
Mean F1 score: 0.8727
Epoch: 9/50
Blue F1 score: 0.8387
Red F1 score: 0.9231
Mean F1 score: 0.8809
Epoch: 10/50
Blue F1 score: 0.8834
Red F1 score: 0.9231
Mean F1 score: 0.9033
Epoch: 11/50
Blue F1 score: 0.8679
Red F1 score: 0.9302
Mean F1 score: 0.8991
Epoch: 12/50
Blue F1 score: 0.8712
Red F1 score: 0.9231
Mean F1 score: 0.8971
Epoch: 13/50
Blue F1 score: 0.8481
Red F1 score: 0.9048
Mean F1 score: 0.

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2136 [00:00<?, ? examples/s]

1713 423
Loading model...


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Epoch: 1/50
Blue F1 score: 0.2609
Red F1 score: 0.8403
Mean F1 score: 0.5506
Epoch: 2/50
Blue F1 score: 0.8571
Red F1 score: 0.9489
Mean F1 score: 0.9030
Epoch: 3/50
Blue F1 score: 0.9009
Red F1 score: 0.9706
Mean F1 score: 0.9357
Epoch: 4/50
Blue F1 score: 0.9107
Red F1 score: 0.9701
Mean F1 score: 0.9404
Epoch: 5/50
Blue F1 score: 0.9231
Red F1 score: 0.9706
Mean F1 score: 0.9468
Epoch: 6/50
Blue F1 score: 0.9189
Red F1 score: 0.9778
Mean F1 score: 0.9483
Epoch: 7/50
Blue F1 score: 0.9580
Red F1 score: 0.9706
Mean F1 score: 0.9643
Epoch: 8/50
Blue F1 score: 0.9744
Red F1 score: 0.9429
Mean F1 score: 0.9586
Epoch: 9/50
Blue F1 score: 0.9204
Red F1 score: 0.9778
Mean F1 score: 0.9491
Epoch: 10/50
Blue F1 score: 0.9107
Red F1 score: 0.9778
Mean F1 score: 0.9442
Epoch: 11/50
Blue F1 score: 0.9381
Red F1 score: 0.9778
Mean F1 score: 0.9579
Epoch: 12/50
Blue F1 score: 0.9483
Red F1 score: 0.9635
Mean F1 score: 0.9559
Epoch: 13/50
Blue F1 score: 0.9322
Red F1 score: 0.9706
Mean F1 score: 0.

In [None]:
# Release the Colab runtime once the run has finished (presumably to stop
# consuming compute after the results were written; `runtime` is imported from google.colab).
runtime.unassign()