In [None]:
0!pip uninstall -y jax jaxlib
!pip install -Uqq datasets

from google.colab import userdata, runtime
import subprocess

hf_token = userdata.get('hf_token')
input_str = f'{hf_token}\nn\n'
result = subprocess.run(['huggingface-cli', 'login'], input=input_str, text=True, capture_output=True)
print(result.stdout)

Found existing installation: jax 0.4.26
Uninstalling jax-0.4.26:
  Successfully uninstalled jax-0.4.26
Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
  Successfully uninstalled jaxlib-0.4.26+cuda12.cudnn89
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[?25h
    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|

In [None]:
import os
import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, default_collate
from torchvision.transforms import ToTensor, Compose, CenterCrop, Normalize
from sklearn.metrics import f1_score
import json

from datasets import load_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Request deterministic cuDNN kernels and disable benchmark-mode autotuning so
# repeated runs give reproducible results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# The RNG seed is defined once, in the config cell (`seed`), which is what the
# seeding cell uses.

# os.cpu_count() returns None when the count cannot be determined.
n_workers = os.cpu_count() or 1
print(f'Number of available CPU cores: {n_workers}')

Using device: cuda
Number of available CPU cores: 16


In [None]:
# Inference configuration.
seed = 0
bs = 96            # DataLoader batch size
context_sz = 154   # 1x1 meter context (side of the center crop, in pixels)

results_dir = '/content/drive/MyDrive/Colab_Notebooks/dye_test_opt/ternary/results'
model_path = os.path.join(results_dir, 'models', 'test_real_only.pth')

# Per-channel RGB normalization statistics (ImageNet).
imagenet_stats = dict(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

In [None]:
ds = load_dataset('mpg-ranch/dye_test', split='test')

# Preprocessing transforms. context_sz = 154 = 11 * 14, i.e. a whole number of
# 14-px patches for the ViT-*/14 backbone loaded below.
preprocs = Compose([
    CenterCrop((context_sz, context_sz)),
    ToTensor(),  # Convert the image to a PyTorch tensor
    Normalize(mean=imagenet_stats['mean'], std=imagenet_stats['std']),  # Normalize using ImageNet stats
])

def preproc_transforms(examples):
    """Batched `Dataset.map` hook: add a normalized `img` tensor for every PIL `image`."""
    examples["img"] = [preprocs(image.convert("RGB")) for image in examples["image"]]
    return examples

print("Applying preprocessing transforms...")
# A bounded map batch keeps peak memory flat as the split grows; the transform is
# per-example, so the result does not depend on the batch size.
test_ds = ds.map(
    preproc_transforms,
    remove_columns=["image", "color", "size", "concentration"],
    batched=True,
    batch_size=64,
)
test_ds.set_format(type='torch')

# Prefer the declared class count (when `label` is a ClassLabel) over counting the
# labels present in this split: a class absent from the test split would otherwise
# shrink the classifier head and make the checkpoint fail to load.
n_classes = (getattr(ds.features['label'], 'num_classes', None)
             or len(np.unique(test_ds['label'])))

Downloading data:   0%|          | 0.00/407M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/105M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Applying preprocessing transforms...


Map:   0%|          | 0/549 [00:00<?, ? examples/s]

In [None]:
def _binary_f1_vs_background(labels, predictions, positive, excluded):
    """Binary F1 of class `positive` vs. the rest, ignoring samples whose TRUE label is `excluded`.

    Samples *predicted* as `excluded` are kept and count as non-positive predictions.
    Returns NaN when `positive` occurs in neither the kept labels nor the kept
    predictions, since F1 is undefined there.
    """
    keep = labels != excluded
    y_true = labels[keep] == positive
    y_pred = predictions[keep] == positive
    if not (y_true.any() or y_pred.any()):
        return float('nan')
    return f1_score(y_true, y_pred, pos_label=True, average='binary')


def modified_f1_score(labels, predictions):
    """F1 scores for the two "class k vs. class 0" sub-problems of a 3-class task.

    Args:
        labels: array-like of true class ids (0, 1 or 2).
        predictions: array-like of predicted class ids, same shape as `labels`.

    Returns:
        (f1_1_vs_0, f1_2_vs_0, mean_f1). A sub-score is NaN when its positive class
        appears in neither the relevant labels nor predictions. `mean_f1` averages
        only the defined sub-scores and is 0.0 if neither is defined.

    Raises:
        ValueError: if `labels` and `predictions` have different shapes.
    """
    # Accept lists as well as arrays; boolean-mask indexing below needs ndarrays.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    if labels.shape != predictions.shape:
        raise ValueError(f'labels {labels.shape} and predictions {predictions.shape} differ in shape')

    # 1 vs. 0: drop samples whose true label is 2.
    f1_1_vs_0 = _binary_f1_vs_background(labels, predictions, positive=1, excluded=2)
    # 2 vs. 0: drop samples whose true label is 1.
    f1_2_vs_0 = _binary_f1_vs_background(labels, predictions, positive=2, excluded=1)

    # Mean over the sub-scores that are actually defined.
    defined = [s for s in (f1_1_vs_0, f1_2_vs_0) if not np.isnan(s)]
    mean_f1 = np.mean(defined) if defined else 0.0

    return f1_1_vs_0, f1_2_vs_0, mean_f1

In [None]:
print(f"Seed: {seed}")
# Seed every RNG in play: Python, NumPy, and PyTorch (CPU and all CUDA devices).
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)

# No shuffling, so predictions come out in the same row order as the dataset.
test_loader = DataLoader(test_ds, shuffle=False, batch_size=bs, num_workers=0)

def load_model(arch, n_classes):
    """Fetch a DINOv2 backbone from torch.hub and attach a fresh linear head.

    Args:
        arch: hub entry-point name in 'facebookresearch/dinov2' (e.g. 'dinov2_vitb14').
        n_classes: number of output logits for the new `head`.

    Returns:
        The model, moved to the global `device`.
    """
    print("Loading model...")
    backbone = torch.hub.load('facebookresearch/dinov2', arch)
    # The head's input width is the size the final LayerNorm normalizes over.
    backbone.head = nn.Linear(backbone.norm.normalized_shape[0], n_classes)
    backbone.to(device)
    return backbone

model = load_model('dinov2_vitb14', n_classes)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime as well.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

print("Running inference...")

all_labels, all_predictions = [], []
with torch.no_grad():
    for batch in test_loader:
        logits = model(batch['img'].to(device))
        # Labels are only needed on the host for scoring, so they never go to the GPU.
        all_labels.extend(batch['label'].numpy())
        all_predictions.extend(logits.argmax(dim=1).cpu().numpy())

# F1 for 1 vs 0 ("blue"), 2 vs 0 ("red"), and their mean.
blue, red, f1 = modified_f1_score(np.array(all_labels), np.array(all_predictions))
print(f"Blue F1 score: {blue:.4f}\nRed F1 score: {red:.4f}\nMean F1 score: {f1:.4f}")

# Cast to built-in floats so json.dump never receives NumPy scalar types.
with open(os.path.join(results_dir, 'f1_test_results_overall_real_only.json'), 'w') as f:
    json.dump({'blue': float(blue), 'red': float(red), 'mean': float(f1)}, f)

Seed: 0
Loading model...


Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth

  0%|          | 0.00/330M [00:00<?, ?B/s][A
  1%|          | 1.91M/330M [00:00<00:19, 17.4MB/s][A
  1%|▏         | 4.43M/330M [00:00<00:15, 22.4MB/s][A
  2%|▏         | 6.95M/330M [00:00<00:14, 22.7MB/s][A
  3%|▎         | 10.2M/330M [00:00<00:12, 27.0MB/s][A
  4%|▍         | 14.6M/330M [00:00<00:09, 33.4MB/s][A
  6%|▌         | 19.5M/330M [00:00<00:08, 37.9MB/s][A
  8%|▊         | 25.3M/330M [00:00<00:07, 45.0MB/s][A
 10%|█         | 33.2M/330M [00:00<00:05, 56.3MB/s][A
 13%|█▎        | 42.0M/330M [00:00<00:04, 67.5MB/s][A
 16%|█▌        | 53.5M/330M [00:01<00:03, 83.5MB/s][A
 20%|██        | 66.7M/330M [00:01<00:02, 99.8MB/s][A
 25%|██▍       | 82.5M/330M [00:01<00:02, 120MB/s] [A
 31%|███       | 101M/330M [

Running inference...
Blue F1 score: 0.9290
Red F1 score: 0.9307
Mean F1 score: 0.9298


In [None]:
# Per-example metadata plus the model's prediction. The image column is removed
# before conversion so image data is never pulled into the DataFrame.
results_df = ds.remove_columns(['image']).to_pandas()
# test_loader does not shuffle, so predictions are in dataset row order.
assert len(all_predictions) == len(results_df), 'prediction count does not match dataset size'
results_df['pred'] = all_predictions
results_df.to_csv(os.path.join(results_dir, 'test_results_real_only.csv'), index=False)

In [None]:
# Release this Colab runtime (google.colab `runtime.unassign`) after all results
# have been written to Drive. NOTE(review): presumably anything not saved outside
# the VM is lost afterwards, so keep this as the last cell.
runtime.unassign()