# Method 2

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader, Subset
from torchgeo.datasets import VHR10
import random

# ----- Step 1: Load dataset -----
def preprocess(sample):
    sample["image"] = sample["image"].float() / 255.0
    return sample

ds = VHR10(
    root="data/VHR10/",
    split="positive",
    transforms=preprocess,
    download=True,
    checksum=True,
)

# ----- Step 2: Split base vs novel classes -----
# NWPU VHR-10: 10 classes (1–10). Giả sử chọn 3 novel classes: airplane=1, baseball diamond=4, tennis court=5
novel_classes = [1, 4, 5]
base_classes = [c for c in range(1, 11) if c not in novel_classes]

base_indices, novel_indices = [], []
for i in range(len(ds)):
    labels = ds[i]["label"]
    if any(l in novel_classes for l in labels):
        novel_indices.append(i)
    else:
        base_indices.append(i)

base_ds = Subset(ds, base_indices)
novel_ds = Subset(ds, novel_indices)

# Few-shot: chọn K=5 ảnh cho mỗi novel class
K = 5
fewshot_indices = []
for cls in novel_classes:
    cls_idxs = [i for i in novel_indices if cls in ds[i]["label"]]
    fewshot_indices.extend(random.sample(cls_idxs, min(K, len(cls_idxs))))
fewshot_ds = Subset(ds, fewshot_indices)

# ----- Step 3: Model (toy Faster R-CNN) -----
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights="DEFAULT"
)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features, 11  # 10 classes + background
)

# ----- Step 4: Memory Bank for MCL -----
class MemoryBank:
    def __init__(self, size=8192, feat_dim=256, device="cpu"):
        self.size = size
        self.ptr = 0
        self.feats = torch.zeros(size, feat_dim, device=device)
        self.labels = torch.zeros(size, dtype=torch.long, device=device)

    def update(self, feats, labels):
        bsz = labels.shape[0]   # <-- luôn sync với labels
        if bsz > self.size:
            feats, labels = feats[:self.size], labels[:self.size]
            bsz = self.size

        idxs = (self.ptr + torch.arange(bsz)) % self.size
        idxs = idxs.long().to(self.feats.device)

        # Debug
        # print("update:", bsz, feats.shape, labels.shape, idxs.shape)

        self.feats[idxs] = feats.detach()
        # self.labels[idxs] = labels.detach()
        self.labels[idxs] = labels.detach().view(-1)[:len(idxs)]

        self.ptr = (self.ptr + bsz) % self.size

    def get(self):
        return self.feats, self.labels

def supervised_contrastive_loss(features, labels, memory_feats, memory_labels, temperature=0.1):
    # Normalize
    features = F.normalize(features, dim=1)
    memory_feats = F.normalize(memory_feats, dim=1)

    logits = torch.mm(features, memory_feats.t()) / temperature
    labels = labels.view(-1, 1)
    mask = (labels == memory_labels.view(1, -1)).float()

    exp_logits = torch.exp(logits)
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-8)
    loss = -(mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)
    return loss.mean()

# ----- Step 5: Training Loop (pseudo) -----
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

memory = MemoryBank(size=1024, feat_dim=in_features)

# def extract_roi_features(model, images, targets):
#     # Toy: use box_head from FasterRCNN to get features
#     features = model.backbone(images.tensors)
#     return features  # placeholder (need integration with roi_heads)

def extract_gt_features(model, images, targets):
    # Transform
    transformed = model.transform(images, targets)
    images_t, targets_t = transformed

    # Backbone
    features = model.backbone(images_t.tensors)
    if isinstance(features, torch.Tensor):
        features = {"0": features}

    # Lấy gt boxes làm proposal
    proposals = [t["boxes"] for t in targets_t]

    # ROI Pooling trên GT boxes
    box_features = model.roi_heads.box_roi_pool(
        features, proposals, images_t.image_sizes
    )
    # Head
    box_features = model.roi_heads.box_head(box_features)

    labels = torch.cat([t["labels"] for t in targets_t])
    return box_features, labels

for epoch in range(2):  # demo only
    for batch in DataLoader(fewshot_ds, batch_size=2, shuffle=True, collate_fn=lambda x: x):
        images = [s["image"] for s in batch]
        targets = []
        for s in batch:
            boxes = s["bbox_xyxy"]
            labels = s["label"]
            targets.append({"boxes": boxes, "labels": labels})

        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        detection_loss = sum(loss for loss in loss_dict.values())

        # MCL: giả sử có roi_feats (cần implement thêm)
        # roi_feats = torch.randn(len(targets), in_features, device=device)  # placeholder
        # roi_labels = torch.cat([t["labels"] for t in targets])

        roi_feats, roi_labels = extract_gt_features(model, images, targets)

        memory_feats, memory_labels = memory.get()
        memory_feats, memory_labels = memory_feats.to(device), memory_labels.to(device)

        if memory_labels.sum() > 0:
            mcl_loss = supervised_contrastive_loss(roi_feats, roi_labels, memory_feats, memory_labels)
        else:
            mcl_loss = torch.tensor(0.0, device=device)

        loss = detection_loss + 0.3 * mcl_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        memory.update(roi_feats.cpu(), roi_labels.cpu())

    print(f"Epoch {epoch}, Loss {loss.item():.4f}")


Files already downloaded and verified
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
Epoch 0, Loss 3.2169
Epoch 1, Loss 1.9979


In [2]:
# Assuming the training loop is complete
# Path to save the model
model_path = "faster_rcnn_mcl_finetuned.pth"

# Save the state dictionary of the model
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Model saved to faster_rcnn_mcl_finetuned.pth


In [3]:
# Step 1: Instantiate the same model architecture
# You need the same model definition as before
reloaded_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=None  # Start with no pre-trained weights
)
in_features = reloaded_model.roi_heads.box_predictor.cls_score.in_features
reloaded_model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features, 11  # Match the output classes
)

# Step 2: Load the saved state dictionary
model_path = "faster_rcnn_mcl_finetuned.pth"
reloaded_model.load_state_dict(torch.load(model_path))

# Step 3: Set the model to evaluation mode
reloaded_model.eval()

# Step 4: Move the model to the desired device
device = "cuda" if torch.cuda.is_available() else "cpu"
reloaded_model.to(device)

print("Model loaded successfully and ready for inference!")

Model loaded successfully and ready for inference!
