<a href="https://colab.research.google.com/github/lxouiis/Fairwool-ai/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cp "/content/drive/MyDrive/dataset 2.zip" /content/dataset.zip
!unzip -q dataset.zip


In [3]:
import os

print(os.listdir('/content/dataset/train'))
print(os.listdir('/content/dataset/val'))


['grade_B', 'grade_A', 'grade_C']
['grade_B', 'grade_A', 'grade_C']


In [4]:
!ls
!find . -maxdepth 3 -type d -name train


dataset  dataset.zip  drive  __MACOSX  sample_data
./dataset/train
./__MACOSX/dataset/train


In [5]:
!pip install -q timm torchvision


In [6]:
!rm -rf __MACOSX


In [7]:
%%writefile train_deit.py
import argparse
import json
import math
import os
import random
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import timm


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


@dataclass
class TrainConfig:
    data_dir: str
    output_dir: str
    model_name: str
    img_size: int
    batch_size: int
    epochs_head: int
    epochs_ft: int
    unfreeze_blocks: int
    lr_head: float
    lr_ft: float
    weight_decay: float
    label_smoothing: float
    num_workers: int
    seed: int
    patience: int
    grad_clip: float


def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple Silicon (MPS)
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def build_transforms(img_size: int):
    # Training: strong but safe augmentations (textures should survive)
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.65, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=8),
        transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.02),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])

    # Validation / inference: standard resize->center crop (256->224 style)
    resize_size = int(round(img_size * 256 / 224))
    val_tf = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    return train_tf, val_tf


def compute_class_weights(targets: List[int], num_classes: int) -> torch.Tensor:
    counts = torch.zeros(num_classes, dtype=torch.long)
    for t in targets:
        counts[t] += 1
    counts = counts.clamp(min=1)
    total = counts.sum().item()
    weights = total / (num_classes * counts.float())
    return weights


def freeze_all_but_head(model: nn.Module):
    for n, p in model.named_parameters():
        p.requires_grad = False

    # Unfreeze classifier head(s) and final norms for stability
    for n, p in model.named_parameters():
        if n.startswith("head") or n.startswith("head_dist") or n.startswith("norm") or n.startswith("fc_norm"):
            p.requires_grad = True


def unfreeze_last_blocks(model: nn.Module, unfreeze_blocks: int):
    # Keep everything frozen by default, then unfreeze selected parts
    for n, p in model.named_parameters():
        p.requires_grad = False

    # Unfreeze last N transformer blocks if present
    if hasattr(model, "blocks"):
        blocks = model.blocks
        n_blocks = len(blocks)
        start = max(0, n_blocks - unfreeze_blocks)
        for i in range(start, n_blocks):
            for p in blocks[i].parameters():
                p.requires_grad = True

    # Also unfreeze final norm(s) and head(s)
    for n, p in model.named_parameters():
        if n.startswith("head") or n.startswith("head_dist") or n.startswith("norm") or n.startswith("fc_norm"):
            p.requires_grad = True


def make_optimizer(model: nn.Module, lr: float, weight_decay: float, head_lr_mult: float = 5.0):
    head_params = []
    body_params = []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if n.startswith("head") or n.startswith("head_dist"):
            head_params.append(p)
        else:
            body_params.append(p)

    param_groups = []
    if body_params:
        param_groups.append({"params": body_params, "lr": lr})
    if head_params:
        param_groups.append({"params": head_params, "lr": lr * head_lr_mult})

    opt = torch.optim.AdamW(param_groups, lr=lr, weight_decay=weight_decay)
    return opt


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, float]:
    model.eval()
    correct = 0
    total = 0

    # Confusion matrix + macro-F1 without sklearn (keeps deps minimal)
    num_classes = loader.dataset.classes.__len__()
    cm = torch.zeros((num_classes, num_classes), dtype=torch.long)

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(images)
        preds = torch.argmax(logits, dim=1)

        correct += (preds == labels).sum().item()
        total += labels.numel()

        for t, p in zip(labels.view(-1), preds.view(-1)):
            cm[t.long(), p.long()] += 1

    acc = correct / max(1, total)

    # Macro F1
    f1s = []
    for k in range(num_classes):
        tp = cm[k, k].item()
        fp = cm[:, k].sum().item() - tp
        fn = cm[k, :].sum().item() - tp
        prec = tp / max(1, (tp + fp))
        rec = tp / max(1, (tp + fn))
        f1 = 0.0 if (prec + rec) == 0 else (2 * prec * rec / (prec + rec))
        f1s.append(f1)
    macro_f1 = sum(f1s) / len(f1s)

    return {
        "val_acc": acc,
        "macro_f1": macro_f1,
    }


def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    grad_clip: float,
) -> float:
    model.train()
    running_loss = 0.0
    n = 0

    use_amp = (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        if use_amp:
            with torch.cuda.amp.autocast():
                logits = model(images)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            if grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
            optimizer.step()

        running_loss += loss.item() * labels.size(0)
        n += labels.size(0)

    return running_loss / max(1, n)


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_dir", type=str, required=True, help="Root folder containing train/ and val/")
    ap.add_argument("--output_dir", type=str, required=True)
    ap.add_argument("--model_name", type=str, default="deit_small_patch16_224.fb_in1k")
    ap.add_argument("--img_size", type=int, default=224)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--epochs_head", type=int, default=4)
    ap.add_argument("--epochs_ft", type=int, default=10)
    ap.add_argument("--unfreeze_blocks", type=int, default=4, help="Unfreeze last N transformer blocks")
    ap.add_argument("--lr_head", type=float, default=1e-3)
    ap.add_argument("--lr_ft", type=float, default=1e-4)
    ap.add_argument("--weight_decay", type=float, default=0.05)
    ap.add_argument("--label_smoothing", type=float, default=0.1)
    ap.add_argument("--num_workers", type=int, default=4)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--patience", type=int, default=4, help="Early stopping patience (by val_acc plateau)")
    ap.add_argument("--grad_clip", type=float, default=1.0)
    args = ap.parse_args()

    cfg = TrainConfig(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        model_name=args.model_name,
        img_size=args.img_size,
        batch_size=args.batch_size,
        epochs_head=args.epochs_head,
        epochs_ft=args.epochs_ft,
        unfreeze_blocks=args.unfreeze_blocks,
        lr_head=args.lr_head,
        lr_ft=args.lr_ft,
        weight_decay=args.weight_decay,
        label_smoothing=args.label_smoothing,
        num_workers=args.num_workers,
        seed=args.seed,
        patience=args.patience,
        grad_clip=args.grad_clip,
    )

    seed_everything(cfg.seed)
    device = get_device()

    out_dir = Path(cfg.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Data
    train_tf, val_tf = build_transforms(cfg.img_size)
    train_dir = Path(cfg.data_dir) / "train"
    val_dir = Path(cfg.data_dir) / "val"

    if not train_dir.exists() or not val_dir.exists():
        raise FileNotFoundError(
            f"Expected {train_dir} and {val_dir} to exist. "
            f"If you only have class folders, run split_dataset.py first."
        )

    ds_train = datasets.ImageFolder(str(train_dir), transform=train_tf)
    ds_val = datasets.ImageFolder(str(val_dir), transform=val_tf)

    if ds_train.classes != ds_val.classes:
        raise ValueError(f"Class mismatch. train classes={ds_train.classes}, val classes={ds_val.classes}")

    num_classes = len(ds_train.classes)
    class_to_idx = ds_train.class_to_idx
    idx_to_class = {v: k for k, v in class_to_idx.items()}

    # Class weights
    class_weights = compute_class_weights(ds_train.targets, num_classes=num_classes).to(device)

    # Loss
    try:
        criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=cfg.label_smoothing)
    except TypeError:
        criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Loaders
    pin_memory = (device.type == "cuda")
    dl_train = DataLoader(
        ds_train,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )
    dl_val = DataLoader(
        ds_val,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    # Model
    model = timm.create_model(cfg.model_name, pretrained=True, num_classes=num_classes)
    model.to(device)

    # Save config immediately
    with open(out_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(asdict(cfg), f, indent=2)

    print(f"Device: {device}")
    print(f"Classes: {ds_train.classes}")
    print(f"Class weights: {class_weights.detach().cpu().tolist()}")

    best_acc = -1.0
    best_epoch = -1
    epochs_no_improve = 0

    def save_ckpt(path: Path, epoch: int, val_acc: float):
        torch.save(
            {
                "model_name": cfg.model_name,
                "img_size": cfg.img_size,
                "mean": IMAGENET_MEAN,
                "std": IMAGENET_STD,
                "classes": ds_train.classes,
                "class_to_idx": class_to_idx,
                "idx_to_class": idx_to_class,
                "epoch": epoch,
                "val_acc": val_acc,
                "state_dict": model.state_dict(),
            },
            path,
        )

    # Phase 1: head-only
    print("\n=== Phase 1: Train head (freeze backbone) ===")
    freeze_all_but_head(model)
    optimizer = make_optimizer(model, lr=cfg.lr_head, weight_decay=cfg.weight_decay, head_lr_mult=1.0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, cfg.epochs_head))

    for epoch in range(cfg.epochs_head):
        train_loss = train_one_epoch(model, dl_train, optimizer, criterion, device, cfg.grad_clip)
        metrics = evaluate(model, dl_val, device)
        scheduler.step()

        val_acc = metrics["val_acc"]
        print(
            f"[Head] Epoch {epoch+1}/{cfg.epochs_head} | "
            f"train_loss={train_loss:.4f} | val_acc={val_acc:.4f} | macro_f1={metrics['macro_f1']:.4f}"
        )

        save_ckpt(out_dir / "last.pt", epoch=epoch, val_acc=val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            epochs_no_improve = 0
            save_ckpt(out_dir / "best.pt", epoch=epoch, val_acc=val_acc)
        else:
            epochs_no_improve += 1

    # Phase 2: partial fine-tune
    print("\n=== Phase 2: Fine-tune last transformer blocks + head ===")
    unfreeze_last_blocks(model, unfreeze_blocks=cfg.unfreeze_blocks)
    optimizer = make_optimizer(model, lr=cfg.lr_ft, weight_decay=cfg.weight_decay, head_lr_mult=5.0)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, cfg.epochs_ft))

    for epoch in range(cfg.epochs_ft):
        train_loss = train_one_epoch(model, dl_train, optimizer, criterion, device, cfg.grad_clip)
        metrics = evaluate(model, dl_val, device)
        scheduler.step()

        global_epoch = cfg.epochs_head + epoch
        val_acc = metrics["val_acc"]
        print(
            f"[FT]   Epoch {epoch+1}/{cfg.epochs_ft} (global {global_epoch}) | "
            f"train_loss={train_loss:.4f} | val_acc={val_acc:.4f} | macro_f1={metrics['macro_f1']:.4f}"
        )

        save_ckpt(out_dir / "last.pt", epoch=global_epoch, val_acc=val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = global_epoch
            epochs_no_improve = 0
            save_ckpt(out_dir / "best.pt", epoch=global_epoch, val_acc=val_acc)
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= cfg.patience:
            print(f"\nEarly stopping: no improvement in {cfg.patience} evals.")
            break

    print(f"\nBest val_acc={best_acc:.4f} at epoch={best_epoch}")
    print(f"Saved checkpoints in: {out_dir}")


if __name__ == "__main__":
    main()


Writing train_deit.py


In [8]:
!pip install -q timm torchvision

In [9]:
!python train_deit.py \
  --data_dir "/content/dataset" \
  --output_dir "/content/outputs_deit" \
  --model_name "deit_small_patch16_224" \
  --batch_size 32 \
  --epochs_head 4 \
  --epochs_ft 10 \
  --unfreeze_blocks 4 \
  --lr_head 1e-3 \
  --lr_ft 1e-4 \
  --num_workers 2

model.safetensors: 100% 88.2M/88.2M [00:01<00:00, 69.2MB/s]
Device: cuda
Classes: ['grade_A', 'grade_B', 'grade_C']
Class weights: [1.622666597366333, 2.817129611968994, 0.49291208386421204]

=== Phase 1: Train head (freeze backbone) ===
  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast():
[Head] Epoch 1/4 | train_loss=0.8310 | val_acc=0.7986 | macro_f1=0.7275
[Head] Epoch 2/4 | train_loss=0.6967 | val_acc=0.8141 | macro_f1=0.7367
[Head] Epoch 3/4 | train_loss=0.6498 | val_acc=0.8569 | macro_f1=0.7809
[Head] Epoch 4/4 | train_loss=0.6485 | val_acc=0.8655 | macro_f1=0.7914

=== Phase 2: Fine-tune last transformer blocks + head ===
[FT]   Epoch 1/10 (global 4) | train_loss=0.6933 | val_acc=0.9349 | macro_f1=0.8931
[FT]   Epoch 2/10 (global 5) | train_loss=0.6244 | val_acc=0.9220 | macro_f1=0.8838
[FT]   Epoch 3/10 (global 6) | train_loss=0.5836 | val_acc=0.8243 | macro_f1=0.8016
[FT]   Epoch 4/10 (global 7) | train_loss=0.5537 | val_acc=0.7678 | macro_f

In [10]:
!find /content/drive -name "best.pt"

In [11]:
!zip -r fairwool_project.zip \
/content/app.py \
/content/predict.py \
/content/train_deit.py \
/content/outputs_deit \
/content/outputs_deit2 \
/content/cloudflared-linux-amd64.deb


  adding: content/train_deit.py (deflated 71%)
  adding: content/outputs_deit/ (stored 0%)
  adding: content/outputs_deit/best.pt (deflated 7%)
  adding: content/outputs_deit/config.json (deflated 45%)
  adding: content/outputs_deit/last.pt (deflated 7%)


In [12]:
from google.colab import files
files.download("fairwool_project.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [13]:
%%writefile predict.py
import argparse
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import timm


def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def build_infer_transform(img_size: int, mean, std):
    resize_size = int(round(img_size * 256 / 224))
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", type=str, required=True)
    ap.add_argument("--image_path", type=str, required=True)
    ap.add_argument("--topk", type=int, default=3)
    args = ap.parse_args()

    ckpt = torch.load(args.checkpoint, map_location="cpu")

    model_name = ckpt["model_name"]
    classes = ckpt["classes"]
    img_size = ckpt["img_size"]
    mean = ckpt["mean"]
    std = ckpt["std"]

    device = get_device()

    model = timm.create_model(model_name, pretrained=False, num_classes=len(classes))
    model.load_state_dict(ckpt["state_dict"], strict=True)
    model.to(device)
    model.eval()

    tf = build_infer_transform(img_size, mean, std)

    img = Image.open(args.image_path).convert("RGB")
    x = tf(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(x)
        probs = F.softmax(logits, dim=1).squeeze(0).detach().cpu()

    topk = min(args.topk, len(classes))
    values, indices = torch.topk(probs, k=topk)

    pred_idx = int(torch.argmax(probs).item())
    pred_class = classes[pred_idx]
    pred_prob = float(probs[pred_idx].item())

    print(f"Prediction: {pred_class} (p={pred_prob:.4f})")
    print("Top-k:")
    for v, i in zip(values.tolist(), indices.tolist()):
        print(f"  {classes[i]}: {v:.4f}")

    # Optional: print full distribution
    print("\nFull probabilities:")
    for cls, p in zip(classes, probs.tolist()):
        print(f"  {cls}: {p:.4f}")


if __name__ == "__main__":
    main()


Writing predict.py


In [14]:
import random, glob

images = []
for c in ["grade_A","grade_B","grade_C"]:
    images += glob.glob(f"/content/dataset/val/{c}/*")

for img in random.sample(images, 6):
    print("\nIMAGE:", img)
    !python predict.py --checkpoint "/content/outputs_deit/best.pt" --image_path "$img"


IMAGE: /content/dataset/val/grade_A/79d04b31eca473391201305076.jpg
Prediction: grade_A (p=0.7413)
Top-k:
  grade_A: 0.7413
  grade_B: 0.1464
  grade_C: 0.1123

Full probabilities:
  grade_A: 0.7413
  grade_B: 0.1464
  grade_C: 0.1123

IMAGE: /content/dataset/val/grade_A/d32a51126513ed621235157601.jpg
Prediction: grade_A (p=0.6185)
Top-k:
  grade_A: 0.6185
  grade_C: 0.2373
  grade_B: 0.1442

Full probabilities:
  grade_A: 0.6185
  grade_B: 0.1442
  grade_C: 0.2373

IMAGE: /content/dataset/val/grade_C/line_2018-10-10 11_47_45.349139.jpg
Prediction: grade_C (p=0.6620)
Top-k:
  grade_C: 0.6620
  grade_A: 0.1908
  grade_B: 0.1472

Full probabilities:
  grade_A: 0.1908
  grade_B: 0.1472
  grade_C: 0.6620

IMAGE: /content/dataset/val/grade_C/371.jpg
Prediction: grade_C (p=0.7361)
Top-k:
  grade_C: 0.7361
  grade_B: 0.1388
  grade_A: 0.1252

Full probabilities:
  grade_A: 0.1252
  grade_B: 0.1388
  grade_C: 0.7361

IMAGE: /content/dataset/val/grade_A/0017_000_00.png
Prediction: grade_A (p=0.

In [15]:
!ls "/content/dataset/val/grade_A" | head

0001_000_01.png
0001_000_02.png
0001_000_03.png
0001_000_05.png
0002_000_00.png
0002_000_02.png
0002_000_03.png
0002_000_04.png
0002_000_05.png
0002_000_06.png


In [16]:
!python predict.py \
  --checkpoint "/content/outputs_deit/best.pt" \
  --image_path "/content/dataset/val/grade_A//content/dataset/val/grade_A//content/dataset/val/grade_A/00bd30945ae5fb701316505594.jpg" \
  --topk 3

Traceback (most recent call last):
  File "/content/predict.py", line 79, in <module>
    main()
  File "/content/predict.py", line 53, in main
    img = Image.open(args.image_path).convert("RGB")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/PIL/Image.py", line 3513, in open
    fp = builtins.open(filename, "rb")
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
FileNotFoundError: [Errno 2] No such file or directory: '/content/dataset/val/grade_A//content/dataset/val/grade_A//content/dataset/val/grade_A/00bd30945ae5fb701316505594.jpg'


In [17]:
!ls "/content/dataset/val/grade_A/00bd30945ae5fb701316505594.jpg"

/content/dataset/val/grade_A/00bd30945ae5fb701316505594.jpg


In [18]:
!python predict.py \
  --checkpoint "/content/outputs_deit/best.pt" \
  --image_path "/content/dataset/val/grade_A/00bd30945ae5fb701316505594.jpg" \
  --topk 3

Prediction: grade_A (p=0.6354)
Top-k:
  grade_A: 0.6354
  grade_B: 0.2133
  grade_C: 0.1513

Full probabilities:
  grade_A: 0.6354
  grade_B: 0.2133
  grade_C: 0.1513


In [21]:
import random, glob

images = []
for c in ["grade_A","grade_B","grade_C"]:
    images += glob.glob(f"/content/dataset/val/{c}/*")

for img in random.sample(images, 4):
    print("\nIMAGE:", img)
    !python predict.py --checkpoint "/content/outputs_deit2/best.pt" --image_path "$img"


IMAGE: /content/dataset/val/grade_A/b774903aca7e08be1548164385.jpg
Prediction: grade_A (p=0.6697)
Top-k:
  grade_A: 0.6697
  grade_B: 0.1767
  grade_C: 0.1536

Full probabilities:
  grade_A: 0.6697
  grade_B: 0.1767
  grade_C: 0.1536

IMAGE: /content/dataset/val/grade_A/0001_000_01.png
Prediction: grade_A (p=0.9429)
Top-k:
  grade_A: 0.9429
  grade_B: 0.0503
  grade_C: 0.0068

Full probabilities:
  grade_A: 0.9429
  grade_B: 0.0503
  grade_C: 0.0068

IMAGE: /content/dataset/val/grade_C/hole_2018-10-11 13_59_17.317728.jpg
Prediction: grade_C (p=0.7580)
Top-k:
  grade_C: 0.7580
  grade_A: 0.1481
  grade_B: 0.0939

Full probabilities:
  grade_A: 0.1481
  grade_B: 0.0939
  grade_C: 0.7580

IMAGE: /content/dataset/val/grade_A/54.jpg
Prediction: grade_A (p=0.4643)
Top-k:
  grade_A: 0.4643
  grade_C: 0.3934
  grade_B: 0.1422

Full probabilities:
  grade_A: 0.4643
  grade_B: 0.1422
  grade_C: 0.3934


In [22]:
%%writefile app.py
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image
import timm
from torchvision import transforms

# ---------------- THEME ----------------
st.set_page_config(
    page_title="ReWoolution AI",
    page_icon="üêë",
    layout="centered"
)

# Pastoral theme colors
st.markdown("""
<style>
body {
    background-color: #f5f1ed;
}
.main-title {
    font-size: 36px;
    font-weight: 700;
    color: #404954;
}
.subtitle {
    font-size: 18px;
    color: #70635B;
}
.card {
    padding: 20px;
    border-radius: 12px;
    background-color: white;
    box-shadow: 0px 4px 12px rgba(0,0,0,0.08);
}
.gradeA {color: #2E7D32; font-weight: bold;}
.gradeB {color: #F9A825; font-weight: bold;}
.gradeC {color: #C62828; font-weight: bold;}
</style>
""", unsafe_allow_html=True)

# ---------- Load model ----------
ckpt = torch.load("outputs_deit2/best.pt", map_location="cpu")
classes = ckpt["classes"]
img_size = ckpt["img_size"]
mean, std = ckpt["mean"], ckpt["std"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = timm.create_model(ckpt["model_name"], pretrained=False, num_classes=len(classes))
model.load_state_dict(ckpt["state_dict"])
model.to(device).eval()

# ---------- Transform ----------
tf = transforms.Compose([
    transforms.Resize(int(round(img_size * 256 / 224))),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# ---------- UI ----------
st.markdown('<div class="main-title">üêë FairWool AI</div>', unsafe_allow_html=True)
st.markdown('<div class="subtitle">AI-powered grading for sustainable wool ecosystems</div>', unsafe_allow_html=True)
st.markdown("---")

uploaded = st.file_uploader("Upload wool image", type=["jpg","png","jpeg"])

if uploaded:
    img = Image.open(uploaded).convert("RGB")
    st.image(img, caption="Input Image", use_container_width=True)

    x = tf(img).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = F.softmax(model(x), dim=1).squeeze(0).cpu()

    pred_idx = int(torch.argmax(probs))
    grade = classes[pred_idx]
    conf = float(probs[pred_idx])

    # Confidence label
    if conf > 0.75:
        conf_label = "HIGH"
    elif conf > 0.55:
        conf_label = "MEDIUM"
    else:
        conf_label = "LOW ‚Äî Recommend manual review"

    # Use case mapping
    if grade == "grade_A":
        use_case = "Premium fiber ‚Äì Textile / Fabric"
        grade_color = "gradeA"
    elif grade == "grade_B":
        use_case = "Medium fiber ‚Äì Insulation / Craft"
        grade_color = "gradeB"
    else:
        use_case = "Coarse fiber ‚Äì Industrial / Low-grade"
        grade_color = "gradeC"

    st.markdown('<div class="card">', unsafe_allow_html=True)
    st.markdown(f"### Predicted Grade: <span class='{grade_color}'>{grade}</span>", unsafe_allow_html=True)
    st.write(f"Confidence: **{conf:.2f} ({conf_label})**")
    st.write(f"Best Use: **{use_case}**")
    st.markdown("</div>", unsafe_allow_html=True)

    st.markdown("---")
    st.markdown("### üåç Social Impact")
    st.write("‚úî Helps pastoralists get fair grading")
    st.write("‚úî Reduces wool waste")
    st.write("‚úî Supports women cooperatives")


    st.markdown(
    """
    ### üåæ Our Vision
    > **‚ÄúToday, shepherds don‚Äôt know what their wool is worth.
    > FairWool AI gives grading power and price intelligence directly in their hands ‚Äî
    > turning waste into dignity.‚Äù**
    """)

Writing app.py


In [23]:
!pkill -f streamlit || true
!pip -q install streamlit timm torchvision

!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb
!dpkg -i cloudflared-linux-amd64.deb

!streamlit run app.py --server.port 8501 --server.headless true --server.enableCORS false --server.enableXsrfProtection false & \
cloudflared tunnel --url http://localhost:8501

^C
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m9.1/9.1 MB[0m [31m77.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m6.9/6.9 MB[0m [31m101.7 MB/s[0m eta [36m0:00:00[0m
[?25hSelecting previously unselected package cloudflared.
(Reading database ... 121852 files and directories currently installed.)
Preparing to unpack cloudflared-linux-amd64.deb ...
Unpacking cloudflared (2026.2.0) ...
Setting up cloudflared (2026.2.0) ...
Processing triggers for man-db (2.10.2-1) ...
[90m2026-02-21T21:23:59Z[0m [32mINF[0m Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms o

In [26]:
mkdir -p "/content/drive/MyDrive/github_projects/my_ai_project"

In [27]:
cp "/content/Untitled0.ipynb" "/content/drive/MyDrive/github_projects/my_ai_project/"

cp: cannot stat '/content/Untitled0.ipynb': No such file or directory
