In [4]:
import os
import random
import torch
from copy import deepcopy
from torch.utils.data import DataLoader

from super_gradients.training import models
from super_gradients.training.datasets.detection_datasets import YoloDarknetFormatDetectionDataset
from super_gradients.training.metrics import DetectionMetrics

# Configuration
# Seed the RNGs used below (task sampling via `random`, weights/updates via torch)
# so a Restart & Run All reproduces the same episodes.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Dataset root. Set the DATA_DIR environment variable (absolute path) to point the
# notebook at another machine's copy; the default is unchanged.
DATA_DIR = os.environ.get("DATA_DIR", "D:/data")
CLASS_NAMES = ["airplane","ship","storage-tank","ground-track-field","harbor",
               "bridge","large-vehicle","small-vehicle","helicopter"]
NUM_CLASSES = len(CLASS_NAMES)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Episode shape: classes per task, support samples per class, query samples per class.
N_WAY, K_SHOT, Q_QUERY = 3, 5, 5
# Inner-loop (per-task) adaptation steps and the inner / outer learning rates.
INNER_STEPS = 3
INNER_LR, META_LR = 1e-3, 1e-4
# Outer-loop length and number of tasks accumulated per meta-update.
EPOCHS = 10
TASKS_PER_BATCH = 5

def load_dataset(split):
    """Build the YOLO-Darknet format detection dataset for one split.

    Args:
        split: Sub-directory name under ``images/`` and ``labels/`` inside DATA_DIR
            (this notebook uses "train" and "val").

    Returns:
        A ``YoloDarknetFormatDetectionDataset`` rooted at DATA_DIR using CLASS_NAMES.
    """
    # images_dir / labels_dir are documented as paths relative to data_dir, so pass
    # them relative; pre-joining DATA_DIR only works while DATA_DIR is absolute.
    return YoloDarknetFormatDetectionDataset(
        data_dir=DATA_DIR,
        images_dir=os.path.join("images", split),
        labels_dir=os.path.join("labels", split),
        classes=CLASS_NAMES,
    )

# Instantiate both splits up front: training data first, then validation data.
train_ds = load_dataset(split="train")
val_ds = load_dataset(split="val")

def build_index(dataset):
    """Map each class id to the samples containing at least one box of that class.

    Args:
        dataset: Iterable of samples. Each sample is indexed with ``"target"`` to get
            its annotation rows; rows with fewer than 5 values are ignored and the
            class id is read from the last value of a row.

    Returns:
        dict: ``{class_id: [sample, ...]}`` in first-seen order, where a sample is
        listed at most once per class.
    """
    idx = {}
    for sample in dataset:
        # Track the classes already recorded for this sample: appending once per box
        # would list an image with several boxes of one class repeatedly, which lets
        # sample_task draw the same image into both the support and the query set.
        seen = set()
        for ann in sample["target"]:
            if len(ann) < 5:
                continue
            cls = int(ann[-1])
            if cls not in seen:
                seen.add(cls)
                idx.setdefault(cls, []).append(sample)
    return idx

def sample_task(idx, n_way, k_shot, q_query):
    """Draw one N-way episode from a class index.

    Args:
        idx: Mapping ``class_id -> list of samples``.
        n_way: Number of classes in the episode.
        k_shot: Support samples drawn per class.
        q_query: Query samples drawn per class.

    Returns:
        tuple: ``(support, query)`` lists with ``n_way * k_shot`` and
        ``n_way * q_query`` samples. Both lists are empty when fewer than ``n_way``
        classes have at least ``k_shot + q_query`` samples.
    """
    per_class = k_shot + q_query
    # Choose only among classes that can fill their slot. Picking first and skipping
    # afterwards silently produced episodes with fewer than n_way classes, and
    # sampling from too small an index raised ValueError.
    eligible = [c for c, items in idx.items() if len(items) >= per_class]
    if len(eligible) < n_way:
        return [], []

    sup, qry = [], []
    for c in random.sample(eligible, n_way):
        sel = random.sample(idx[c], per_class)
        sup.extend(sel[:k_shot])
        qry.extend(sel[k_shot:])
    return sup, qry

def collate_batch(samples):
    """Turn a list of samples into one image batch plus per-sample targets on DEVICE.

    Args:
        samples: Sequence of samples indexed with ``"image"`` and ``"target"``.

    Returns:
        tuple: ``(imgs, tgts)`` where ``imgs`` is the stacked images moved to DEVICE
        and ``tgts`` is a list with each sample's target moved to DEVICE.
    """
    stacked = torch.stack([sample["image"] for sample in samples])
    imgs = stacked.to(DEVICE)
    tgts = []
    for sample in samples:
        tgts.append(sample["target"].to(DEVICE))
    return imgs, tgts

class MAMLDetector(torch.nn.Module):
    """YOLO-NAS-S detector wrapped so the meta-loop can clone it per task."""

    def __init__(self, num_classes=None):
        """Create the COCO-pretrained network on DEVICE.

        Args:
            num_classes: Number of classes for the detection head. Defaults to
                NUM_CLASSES so the head matches the class ids of CLASS_NAMES.
        """
        super().__init__()
        if num_classes is None:
            num_classes = NUM_CLASSES
        # Without num_classes the model keeps the head size of the pretrained
        # checkpoint, which does not line up with this dataset's class ids.
        self.net = models.get(
            "yolo_nas_s", num_classes=num_classes, pretrained_weights="coco"
        ).to(DEVICE)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output unchanged."""
        return self.net(x)

    def clone(self):
        """Return an independent deep copy on DEVICE with its backbone frozen.

        Only the copy's backbone parameters get ``requires_grad = False``; this
        instance is not modified.
        """
        clone = deepcopy(self)
        for p in clone.net.backbone.parameters():
            p.requires_grad = False
        return clone.to(DEVICE)

# Shared meta-parameters, the outer-loop optimiser over them, and the
# class -> samples lookup that training episodes are drawn from.
meta_model = MAMLDetector()
meta_opt = torch.optim.Adam(params=meta_model.parameters(), lr=META_LR)
class_idx = build_index(dataset=train_ds)

# Meta-training (first-order): adapt a throw-away copy of the model on each task's
# support set, then feed the copy's query-set gradients back into meta_model.
for epoch in range(1, EPOCHS + 1):
    print(f"Epoch {epoch}/{EPOCHS}")
    meta_opt.zero_grad()
    total_loss = 0.0
    tasks_done = 0  # tasks that actually contributed a query loss this epoch

    for task in range(TASKS_PER_BATCH):
        # Sample first so an unusable episode does not cost a deep copy of the model.
        support, query = sample_task(class_idx, N_WAY, K_SHOT, Q_QUERY)
        if not support or not query:
            continue

        learner = meta_model.clone()
        inner_opt = torch.optim.SGD(filter(lambda p: p.requires_grad, learner.parameters()), lr=INNER_LR)

        # The support batch is identical for every inner step, so build it once.
        imgs_s, tgts_s = collate_batch(support)

        learner.train()
        for _ in range(INNER_STEPS):
            preds_s = learner(imgs_s)
            # NOTE(review): assumes the wrapped network exposes a loss(preds, targets)
            # callable — confirm against the installed super_gradients model API.
            loss_s = learner.net.loss(preds_s, tgts_s)
            inner_opt.zero_grad()
            loss_s.backward()
            inner_opt.step()

        # Clear the gradients left over from the last support step so that only the
        # query loss contributes to the meta-gradient below.
        inner_opt.zero_grad()

        learner.eval()
        imgs_q, tgts_q = collate_batch(query)
        preds_q = learner(imgs_q)
        loss_q = learner.net.loss(preds_q, tgts_q)
        loss_q.backward()

        # learner is a deep copy, so backward() only fills the copy's .grad fields.
        # Add them onto the matching meta_model parameters; otherwise meta_opt.step()
        # has nothing to apply. Parameters frozen by clone() have no gradient.
        for meta_p, fast_p in zip(meta_model.parameters(), learner.parameters()):
            if fast_p.grad is None:
                continue
            if meta_p.grad is None:
                meta_p.grad = fast_p.grad.detach().clone()
            else:
                meta_p.grad.add_(fast_p.grad)

        total_loss += loss_q.item()
        tasks_done += 1

    if tasks_done:
        # Average the accumulated gradients over the tasks that ran, then update.
        for meta_p in meta_model.parameters():
            if meta_p.grad is not None:
                meta_p.grad.div_(tasks_done)
        meta_opt.step()
    # Average over tasks that ran; skipped episodes must not dilute the mean.
    avg_loss = total_loss / tasks_done if tasks_done else float("nan")
    print(f"Meta Loss (avg): {avg_loss:.4f}")

# Validation
print("\nEvaluating on validation data")
# Reuse collate_batch so validation batches are built exactly like training episodes:
# stacked images plus a list of per-image targets, both on DEVICE.
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, collate_fn=collate_batch)
# NOTE(review): confirm the DetectionMetrics constructor/update arguments and the
# result keys used below against the installed super_gradients version.
metrics = DetectionMetrics(num_classes=NUM_CLASSES, compute_on_step=False)

# meta_model is never switched to eval mode earlier in the notebook; do it here so
# evaluation does not run with training-mode layer behaviour.
meta_model.eval()
with torch.no_grad():
    for batch in val_loader:
        imgs, tgts = batch
        preds = meta_model(imgs)
        metrics.update(preds=preds, target=tgts)

res = metrics.compute()
print(f"mAP@50:    {res['map_50']:.4f}")
print(f"mAP@50-95: {res['map_50_95']:.4f}")


Indexing dataset annotations: 100%|████████████████████████████████████████████████| 1165/1165 [00:32<00:00, 35.39it/s]
Indexing dataset annotations: 100%|██████████████████████████████████████████████████| 379/379 [00:10<00:00, 34.64it/s]


URLError: <urlopen error [Errno 11001] getaddrinfo failed>

In [5]:
import urllib.request

# Try to fetch the YOLO-NAS-S COCO checkpoint into the working directory.
# The recorded run of this cell ended with "HTTP Error 403: Forbidden".
url = "https://deci-pretrained-models.s3.amazonaws.com/super_gradients/YoloNAS_S_COCO.ckpt"
urllib.request.urlretrieve(url=url, filename="YoloNAS_S_COCO.ckpt")


HTTPError: HTTP Error 403: Forbidden