In [None]:
from google.colab import files
import zipfile
import os

# Ask the user for the dataset archive; `uploaded` is keyed by file name.
uploaded = files.upload()

# A cancelled/empty upload would otherwise surface as a bare IndexError below.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded; re-run this cell and select the dataset .zip archive."
    )

# Only the first uploaded file is used as the dataset archive.
zip_name = next(iter(uploaded))
extract_path = "/content/dataset"

with zipfile.ZipFile(zip_name, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Dataset extracted to:", extract_path)

In [None]:
!pip install albumentations opencv-python

In [None]:
import os
import cv2
import random
import albumentations as A
from tqdm import tqdm

In [None]:
# Augmentation pipeline applied by the class-balancing loop to synthesize
# extra images. Transforms are listed in the order they are composed; `p` is
# the per-transform probability argument.
augment = A.Compose([
    # Random crop (scale range 0.6-1.0) resized to a fixed 128x128 output.
    A.RandomResizedCrop(size=(128, 128), scale=(0.6, 1.0)),
    # Small rotation of up to +/-10 (limit=10), applied half of the time.
    A.Rotate(limit=10, p=0.5),
    # Mild photometric jitter, applied half of the time.
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
    # NOTE(review): the image is already 128x128 after RandomResizedCrop, so a
    # 128x128 center crop looks like a no-op here — confirm intent.
    A.CenterCrop(height=128, width=128, p=0.3),
    # With p=0.3 apply exactly one of: blur or additive Gaussian noise.
    A.OneOf([
        A.GaussianBlur(blur_limit=(3, 5), p=1.0),
        # NOTE(review): `var_limit` may be a legacy argument name in newer
        # albumentations releases — TODO confirm against the installed version.
        A.GaussNoise(var_limit=(10.0, 50.0), p=1.0),
    ], p=0.3),
    # Randomly masks out up to 8 rectangles of at most 16x16 pixels.
    # NOTE(review): `max_holes`/`max_height`/`max_width` may likewise be legacy
    # argument names — TODO confirm against the installed version.
    A.CoarseDropout(
        max_holes=8,
        max_height=16,
        max_width=16,
        p=0.3
    )
])

In [None]:
# Root folder holding one sub-directory per class; the augmentation loop uses
# each sub-directory name as the class name.
DATASET_DIR = "/content/dataset/custom_fs_ts"
# Number of images every class folder is topped up to by augmentation.
TARGET_SAMPLES = 50

In [None]:
# Top up every class folder under DATASET_DIR to TARGET_SAMPLES images by
# writing augmented copies (prefixed "aug_") next to the originals.
for class_name in os.listdir(DATASET_DIR):
    class_path = os.path.join(DATASET_DIR, class_name)
    if not os.path.isdir(class_path):
        continue

    images = [f for f in os.listdir(class_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    current_count = len(images)

    if current_count >= TARGET_SAMPLES:
        continue

    # Sample sources only from files that are not themselves augmentation
    # outputs, so augmentations are never stacked on top of each other (also
    # after a partially completed earlier run). If every file carries the
    # "aug_" prefix, fall back to using all of them.
    source_images = [f for f in images if not f.startswith("aug_")] or list(images)
    if not source_images:
        # random.choice() on an empty list would raise IndexError.
        print(f"Skipping class '{class_name}': no images to augment")
        continue

    print(f"Augmenting class '{class_name}': {current_count} → {TARGET_SAMPLES}")

    while len(images) < TARGET_SAMPLES:
        if not source_images:
            print(f"Stopping class '{class_name}' early: no readable source images left")
            break

        img_name = random.choice(source_images)
        img_path = os.path.join(class_path, img_name)

        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None for unreadable/corrupt files; drop the
            # file from the pool instead of crashing in cvtColor.
            print(f"Skipping unreadable image: {img_path}")
            source_images.remove(img_name)
            continue
        # OpenCV loads BGR; the pipeline is fed RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        augmented = augment(image=image)["image"]

        # len(images) grows with every saved file, keeping names unique.
        new_name = f"aug_{len(images)}_{img_name}"
        save_path = os.path.join(class_path, new_name)

        # cv2.imwrite reports failure via its return value; counting a file
        # that was never written would leave the class short of the target.
        if not cv2.imwrite(save_path, cv2.cvtColor(augmented, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write augmented image: {save_path}")
        images.append(new_name)

In [None]:
# Print, per class folder (alphabetical), how many directory entries it holds.
# Every entry is counted, not only image files.
for class_name in sorted(os.listdir(DATASET_DIR)):
    class_path = os.path.join(DATASET_DIR, class_name)
    if not os.path.isdir(class_path):
        continue
    entry_count = len(os.listdir(class_path))
    print(class_name, "→", entry_count)

In [None]:
import torch

# Report whether torch can see a GPU and, if so, the name of device 0.
has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)

gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None"
print("GPU:", gpu_name)

In [None]:
# Clone the DINOv3 repository and make it the working directory. From here on
# relative paths (e.g. requirements.txt, the saved checkpoint file) resolve
# inside ./dinov3.
# NOTE(review): the model below is loaded through `transformers`, so the
# cloned code may not be needed at all — confirm.
!git clone https://github.com/facebookresearch/dinov3.git
%cd dinov3

In [None]:
!pip install -q -r requirements.txt
!pip install -q einops timm pillow tqdm

In [None]:
from huggingface_hub import login
# Authenticate with the Hugging Face Hub; called without arguments, so no
# token is hard-coded in the notebook.
# NOTE(review): presumably required because the DINOv3 checkpoint loaded below
# is access-gated — confirm.
login()

In [None]:
!pip install -U transformers accelerate safetensors

In [None]:
import torch
from transformers import AutoModel, AutoImageProcessor

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint id shared by the image processor and the backbone.
DINOV3_CHECKPOINT = "facebook/dinov3-vits16-pretrain-lvd1689m"

processor = AutoImageProcessor.from_pretrained(DINOV3_CHECKPOINT)

# Load the backbone, move it to the chosen device and switch to eval mode.
model = AutoModel.from_pretrained(DINOV3_CHECKPOINT)
model = model.to(DEVICE)
model.eval()

print("DINOv3 loaded successfully via Hugging Face")

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Per-channel statistics used to normalize the image tensors.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing: resize, 224x224 center crop, tensor conversion
# and normalization.
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
transform = transforms.Compose(preprocess_steps)

# One class per sub-directory of the dataset root.
dataset = datasets.ImageFolder(root="/content/dataset/custom_fs_ts", transform=transform)

# No shuffling here: the loader is only iterated to extract features.
loader_options = dict(batch_size=32, shuffle=False, num_workers=2)
dataloader = DataLoader(dataset, **loader_options)

num_classes = len(dataset.classes)
print("Classes:", dataset.classes)

In [None]:
from transformers import AutoModel

# Load the DINOv3 backbone used as a fixed feature extractor.
encoder = AutoModel.from_pretrained("facebook/dinov3-vits16-pretrain-lvd1689m")
encoder = encoder.to(DEVICE)
encoder.eval()

# Freeze every weight: only the linear head defined later gets trained.
for param in encoder.parameters():
    param.requires_grad = False

print("DINOv3 encoder loaded and frozen")

In [None]:
import torch

def split_into_quadrants(images):
    """Cut a 4-D image batch into its four spatial quadrants.

    The last two axes (height, width) are each split at their integer
    midpoint, so for odd sizes the bottom/right pieces are one pixel larger.

    Args:
        images: Tensor whose shape unpacks into exactly four values
            (batch, channels, height, width); any other rank raises
            ValueError.

    Returns:
        A list ``[top_left, top_right, bottom_left, bottom_right]`` of slices
        of ``images``.
    """
    _, _, height, width = images.shape
    row_cut = height // 2
    col_cut = width // 2

    top_half = images[:, :, :row_cut]
    bottom_half = images[:, :, row_cut:]
    column_spans = (slice(None, col_cut), slice(col_cut, None))

    return [half[..., span] for half in (top_half, bottom_half) for span in column_spans]

In [None]:
# Keep the backbone in eval mode before defining the extractor.
encoder.eval()


@torch.no_grad()
def extract_quadrant_features(images):
    """Build global + quadrant features for a batch of images.

    The encoder is run once on the full images and once on each of the four
    quadrants returned by ``split_into_quadrants``. From every output the
    token at position 0 of ``last_hidden_state`` is kept (treated as the CLS
    token). The four quadrant tokens are averaged and concatenated to the
    full-image token along dim 1.

    Args:
        images: Image batch accepted by ``encoder(pixel_values=...)`` and by
            ``split_into_quadrants``.

    Returns:
        Tensor ``cat([full_cls, mean_quadrant_cls], dim=1)``, computed with
        gradients disabled.
    """
    whole_cls = encoder(pixel_values=images).last_hidden_state[:, 0]

    quadrant_cls = [
        encoder(pixel_values=quadrant).last_hidden_state[:, 0]
        for quadrant in split_into_quadrants(images)
    ]
    quadrant_avg = torch.stack(quadrant_cls, dim=1).mean(dim=1)

    return torch.cat([whole_cls, quadrant_avg], dim=1)

In [None]:
from tqdm import tqdm

# Run every batch through the frozen extractor and collect features/labels on
# the CPU.
features, labels = [], []

for imgs, lbls in tqdm(dataloader):
    imgs = imgs.to(DEVICE)
    feats = extract_quadrant_features(imgs)

    # Move features off the device so they can be accumulated on the host.
    features.append(feats.cpu())
    labels.append(lbls)

# Stack the per-batch pieces into one feature matrix and one label vector.
X = torch.cat(features, dim=0)
y = torch.cat(labels, dim=0)

print("Feature shape:", X.shape)

In [None]:
from sklearn.model_selection import train_test_split

# Seeded 80/20 split of the extracted features, with every part moved to
# DEVICE. train_test_split returns (X_train, X_val, y_train, y_val) in order.
# NOTE(review): the training cells below build their loaders from
# `random_split` instead, so these four tensors look unused — confirm.
X_train, X_val, y_train, y_val = (
    part.to(DEVICE)
    for part in train_test_split(X, y, test_size=0.2, random_state=42)
)

In [None]:
import torch.nn as nn

# X concatenates two equally sized embeddings, so one embedding is half of the
# feature width.
num_classes = len(dataset.classes)
emb_dim = X.shape[1] // 2

# Linear probe: a single linear layer from the concatenated features to the
# class logits.
classifier = nn.Linear(in_features=2 * emb_dim, out_features=num_classes)
classifier = classifier.to(DEVICE)

In [None]:
from torch.utils.data import TensorDataset

# Pair each extracted feature row in X with its label in y so they can be
# split and batched together.
full_dataset = TensorDataset(X, y)

In [None]:
from torch.utils.data import random_split

# Hold out 20% of the feature/label pairs for validation.
val_ratio = 0.2
val_size = int(len(full_dataset) * val_ratio)
train_size = len(full_dataset) - val_size

# Seed the split so the same samples land in the validation set on every run
# (same seed value as the train_test_split call above); an unseeded
# random_split makes reported validation accuracy non-reproducible.
split_generator = torch.Generator().manual_seed(42)

# NOTE(review): the features come from every file in the dataset folder,
# including the "aug_*" copies written by the augmentation cell, so augmented
# variants of one source image can end up on both sides of this split —
# consider splitting before augmenting.
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

In [None]:
from torch.utils.data import DataLoader

batch_size = 32

# Training batches are reshuffled on each pass over the loader; validation
# order is kept fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Loss and optimizer for the linear probe. Only `classifier`'s parameters are
# handed to Adam, so this optimizer never updates the encoder.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

In [None]:
epochs = 100

for epoch in range(epochs):
    # ---- training pass over the pre-extracted features ----
    classifier.train()
    train_loss, correct, total = 0.0, 0, 0

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
        batch_n = y_batch.size(0)

        optimizer.zero_grad()
        logits = classifier(X_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch figure is a
        # per-sample average even when the last batch is smaller.
        train_loss += loss.item() * batch_n
        correct += (logits.argmax(dim=1) == y_batch).sum().item()
        total += batch_n

    train_loss = train_loss / total
    train_acc = correct / total

    # ---- validation pass (no gradients, no weight updates) ----
    classifier.eval()
    val_correct, val_total = 0, 0

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            logits = classifier(X_batch)
            val_correct += (logits.argmax(dim=1) == y_batch).sum().item()
            val_total += y_batch.size(0)

    val_acc = val_correct / val_total

    # Log only every 10th epoch to keep the cell output short.
    is_report_epoch = (epoch + 1) % 10 == 0
    if is_report_epoch:
        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Train Loss: {train_loss:.4f} "
            f"Train Acc: {train_acc:.4f} "
            f"Val Acc: {val_acc:.4f}"
        )

In [None]:
# Bundle the encoder and classifier weights with the class-name list (the
# index -> name mapping for the logits) and write them to a single file in
# the current working directory.
checkpoint = {
    "encoder": encoder.state_dict(),
    "classifier": classifier.state_dict(),
    "classes": dataset.classes,
}
torch.save(checkpoint, "dinov3_linear_probe.pth")

print("Model saved successfully")

In [None]:
from PIL import Image
import torch

# NOTE(review): this sample lives inside the dataset folder the features were
# extracted from, so the prediction is a smoke test rather than an evaluation
# on unseen data — confirm.
img_path = "/content/dataset/custom_fs_ts/EA/Dropped Image (2).png"

# Load the image (closing the file handle afterwards), apply the same
# preprocessing as for feature extraction and add a batch axis.
with Image.open(img_path) as img_file:
    img = img_file.convert("RGB")
img = transform(img).unsqueeze(0).to(DEVICE)

encoder.eval()
classifier.eval()

# Reuse the extractor that produced the training features instead of
# repeating its steps here, so inference cannot drift from training.
# It already runs with gradients disabled.
feat = extract_quadrant_features(img)

with torch.no_grad():
    pred = classifier(feat).argmax(dim=1)

print("Predicted class:", dataset.classes[pred.item()])