# Install Needed Dependencies

In [1]:
!pip install pytorch-metric-learning

Collecting pytorch-metric-learning
  Downloading pytorch_metric_learning-2.8.1-py3-none-any.whl.metadata (18 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.6.0->pytorch-metric-learning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.6.0->pytorch-metric-learning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.6.0->pytorch-metric-learning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.6.0->pytorch-metric-learning)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.6.0->pytorch-metric-learning)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux

In [2]:
!pip install torch torchvision pyyaml scikit-learn



# Load Dataset

In [3]:
import os

# Extraction target for the downloaded archive; exist_ok keeps the cell re-runnable.
os.makedirs(os.path.join('/content', 'dataset_folder'), exist_ok=True)

In [4]:
!pip install -U gdown

file_id = '1iVnP4gjw-iHXa0KerZQ1IfIO0i1jADsR'
output_file = '/content/dataset.zip'
!gdown --id {file_id} --output {output_file}

Downloading...
From (original): https://drive.google.com/uc?id=1iVnP4gjw-iHXa0KerZQ1IfIO0i1jADsR
From (redirected): https://drive.google.com/uc?id=1iVnP4gjw-iHXa0KerZQ1IfIO0i1jADsR&confirm=t&uuid=8a0cf342-41e1-4840-b4ea-8b0cc8a2259b
To: /content/dataset.zip
100% 9.20G/9.20G [01:34<00:00, 97.8MB/s]


In [5]:
unzip_folder = '/content/dataset_folder'
!unzip -q {output_file} -d {unzip_folder}

print("Contents of dataset_folder:")
print(os.listdir('/content/dataset_folder'))

Contents of dataset_folder:
['University-Release']


# Clone the ACMMM23-Solution-MBEG Repository

In [6]:
!git clone https://github.com/Reza-Zhu/ACMMM23-Solution-MBEG.git /content/ACMMM23-Solution-MBEG

%cd /content/ACMMM23-Solution-MBEG

Cloning into '/content/ACMMM23-Solution-MBEG'...
remote: Enumerating objects: 45, done.[K
remote: Counting objects: 100% (45/45), done.[K
remote: Compressing objects: 100% (40/40), done.[K
remote: Total 45 (delta 17), reused 20 (delta 4), pack-reused 0 (from 0)[K
Receiving objects: 100% (45/45), 8.80 MiB | 17.70 MiB/s, done.
Resolving deltas: 100% (17/17), done.
/content/ACMMM23-Solution-MBEG


In [7]:
# Checkpoint directory inside the cloned repo; also recorded in settings.yaml below.
weights_path = os.path.join("/content/ACMMM23-Solution-MBEG", "weights")
os.makedirs(weights_path, exist_ok=True)
print("Directory created at: " + weights_path)

Directory created at: /content/ACMMM23-Solution-MBEG/weights


In [8]:
import yaml

SETTINGS_PATH = "/content/ACMMM23-Solution-MBEG/settings.yaml"

# Read the repo's settings and close the file before rewriting the same path.
with open(SETTINGS_PATH, 'r') as f:
    config = yaml.safe_load(f)

# Overrides for this Colab run (small batch and few epochs for the 100-class subset).
config.update({
    'dataset_path': '/content/dataset_folder/University-Release',
    'weight_save_path': weights_path,
    'num_epochs': 10,
    'batch_size': 4,
    'model': 'ResNet',
    'name': 'ResNet_1652',
})

# Write back to the same absolute path, so the result does not depend on the current directory.
with open(SETTINGS_PATH, 'w') as f:
    yaml.dump(config, f)

In [9]:
import os

# Work from inside the cloned repo so the relative file writes below land there.
os.chdir(os.path.join("/content", "ACMMM23-Solution-MBEG"))

# Create a training subset of 100 classes with up to 20 images per class

In [20]:
import os
import shutil
import random

full_dataset_path = "/content/dataset_folder/University-Release/train/"
subset_path = "/content/dataset_subset"
subset_size_per_class = 20
num_classes_to_sample = 100
SUBSET_SEED = 42  # fixes which classes/images are drawn, so re-runs build the same subset

modalities = ['drone', 'satellite']

# Rebuild from scratch: copying into an existing subset folder would merge a new
# random draw with the previous one and grow the class count on every re-run.
shutil.rmtree(subset_path, ignore_errors=True)

# Private RNG so the draw neither depends on nor disturbs the global `random` state.
rng = random.Random(SUBSET_SEED)

sat_classes = set(os.listdir(os.path.join(full_dataset_path, 'satellite')))
drone_classes = set(os.listdir(os.path.join(full_dataset_path, 'drone')))
# Sort before sampling: set iteration order of str values changes between interpreter
# sessions (hash randomization), which would defeat the fixed seed.
shared_classes = sorted(sat_classes & drone_classes)

selected_classes = rng.sample(shared_classes, min(num_classes_to_sample, len(shared_classes)))
print(f"Selected {len(selected_classes)} classes.")

for modality in modalities:
    modality_full_path = os.path.join(full_dataset_path, modality)
    modality_subset_path = os.path.join(subset_path, modality)
    os.makedirs(modality_subset_path, exist_ok=True)

    for cls in selected_classes:
        class_full_path = os.path.join(modality_full_path, cls)
        class_subset_path = os.path.join(modality_subset_path, cls)
        os.makedirs(class_subset_path, exist_ok=True)

        # Sorted so the seeded sample below does not depend on directory listing order.
        all_images = sorted(f for f in os.listdir(class_full_path) if f.lower().endswith(('.jpg', '.png', '.jpeg')))
        if not all_images:
            print(f"No images found in {class_full_path}, skipping.")
            continue

        sampled_images = rng.sample(all_images, min(subset_size_per_class, len(all_images)))

        for img_name in sampled_images:
            shutil.copyfile(os.path.join(class_full_path, img_name),
                            os.path.join(class_subset_path, img_name))

print("Subset folder with randomly selected classes and images created successfully!")


Selected 100 classes.
Subset folder with randomly selected classes and images created successfully!


# Modifying Data Preprocessing & Model Definition

In [21]:
# Source of the dualresnet.py helper module. Raw string, so any backslashes reach the file verbatim.
code = r'''
"""Two-branch ResNet-18 classifier for paired satellite/drone images."""
from pathlib import Path

import torch.nn as nn
from torchvision import models


def _resnet18(pretrained):
    """Return a torchvision ResNet-18, ImageNet-initialised when `pretrained` is true."""
    # `pretrained=` is deprecated since torchvision 0.13; `weights=` is the supported API.
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    return models.resnet18(weights=weights)


class DualResNet(nn.Module):
    """Two independent ResNet-18 backbones, each with its own linear classifier.

    forward(x1, x2) returns (logits1, logits2, feat1, feat2): x1 goes through
    branch 1 and x2 through branch 2; the branches share no weights.
    """

    def __init__(self, num_classes=1652, pretrained=True):
        super().__init__()
        self.backbone1 = _resnet18(pretrained)
        self.backbone2 = _resnet18(pretrained)
        # Take the embedding width from the backbone rather than hard-coding 512.
        feat_dim = self.backbone1.fc.in_features
        self.backbone1.fc = nn.Identity()
        self.backbone2.fc = nn.Identity()
        self.classifier1 = nn.Linear(feat_dim, num_classes)
        self.classifier2 = nn.Linear(feat_dim, num_classes)

    def forward(self, x1, x2):
        f1 = self.backbone1(x1)
        f2 = self.backbone2(x2)
        return self.classifier1(f1), self.classifier2(f2), f1, f2


def get_num_classes(data_dir):
    """Count the class sub-directories under `<data_dir>/satellite`."""
    sat_dir = Path(data_dir) / "satellite"
    return sum(1 for entry in sat_dir.iterdir() if entry.is_dir())
'''
with open("dualresnet.py", "w") as f:
    f.write(code)

In [57]:
# Source of the Modified_Preprocessing.py helper module. Raw string, so backslashes reach the file verbatim.
code = r'''
"""Paired satellite/drone dataset for a University-1652-style folder layout."""
import os

from PIL import Image
from torch.utils.data import Dataset

IMG_EXTENSIONS = (".jpg", ".jpeg", ".png")


def _list_images(folder):
    """Return the sorted image file names in `folder`, skipping hidden and non-image files."""
    return sorted(
        name for name in os.listdir(folder)
        if not name.startswith(".") and name.lower().endswith(IMG_EXTENSIONS)
    )


def _load_rgb(path):
    """Open an image as RGB and close the underlying file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


def get_num_classes(data_dir):
    """Count the class sub-directories under `<data_dir>/satellite`.

    Also exported here because the training code imports it from this module.
    """
    sat_dir = os.path.join(data_dir, "satellite")
    return sum(os.path.isdir(os.path.join(sat_dir, name)) for name in os.listdir(sat_dir))


class PairedU1652Dataset(Dataset):
    """One sample per drone image, paired with the single satellite image of its class.

    Expects `<root_dir>/satellite/<class>/` (exactly one image) and
    `<root_dir>/drone/<class>/` folders. Labels index the sorted satellite class
    folders (see `class_to_idx`). Items are (sat_img, drone_img, label).

    Raises:
        ValueError: a satellite class folder does not hold exactly one image.
        FileNotFoundError: a satellite class has no matching drone folder.
    """

    def __init__(self, root_dir, transform_sat=None, transform_drone=None):
        self.root_dir = root_dir
        self.sat_dir = os.path.join(root_dir, 'satellite')
        self.drone_dir = os.path.join(root_dir, 'drone')
        self.transform_sat = transform_sat
        self.transform_drone = transform_drone

        # Only real, non-hidden directories count as classes; stray files or hidden
        # folders would otherwise become bogus labels.
        self.sat_classes = sorted(
            name for name in os.listdir(self.sat_dir)
            if not name.startswith(".") and os.path.isdir(os.path.join(self.sat_dir, name))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.sat_classes)}

        self.sat_images = []
        self.drone_images = []
        self.labels = []

        for cls_name in self.sat_classes:
            sat_cls_path = os.path.join(self.sat_dir, cls_name)
            drone_cls_path = os.path.join(self.drone_dir, cls_name)

            sat_imgs = _list_images(sat_cls_path)
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if len(sat_imgs) != 1:
                raise ValueError(f"Expected exactly 1 satellite image per class '{cls_name}', but got {len(sat_imgs)}")
            if not os.path.isdir(drone_cls_path):
                raise FileNotFoundError(f"No drone folder for class '{cls_name}': {drone_cls_path}")

            sat_img_path = os.path.join(sat_cls_path, sat_imgs[0])
            label = self.class_to_idx[cls_name]
            for drone_img_name in _list_images(drone_cls_path):
                # The class's satellite image is repeated once per drone view.
                self.sat_images.append(sat_img_path)
                self.drone_images.append(os.path.join(drone_cls_path, drone_img_name))
                self.labels.append(label)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sat_img = _load_rgb(self.sat_images[idx])
        drone_img = _load_rgb(self.drone_images[idx])
        if self.transform_sat:
            sat_img = self.transform_sat(sat_img)
        if self.transform_drone:
            drone_img = self.transform_drone(drone_img)
        return sat_img, drone_img, self.labels[idx]
'''
with open("Modified_Preprocessing.py", "w") as f:
    f.write(code)


In [58]:
import importlib

import dualresnet
import Modified_Preprocessing

# Both helper modules are regenerated by the cells above. Reload them so the kernel
# picks up the current file contents instead of the first-imported versions.
importlib.reload(dualresnet)
importlib.reload(Modified_Preprocessing)


<module 'Modified_Preprocessing' from '/content/ACMMM23-Solution-MBEG/Modified_Preprocessing.py'>

# Train Using DualResNet

In [60]:
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from torchvision import transforms

from dualresnet import DualResNet
from Modified_Preprocessing import PairedU1652Dataset

SETTINGS_PATH = "/content/ACMMM23-Solution-MBEG/settings.yaml"


def get_yaml_value(key, path=SETTINGS_PATH):
    """Return `key` from the YAML settings file written by the config cell."""
    with open(path, "r") as f:
        return yaml.safe_load(f)[key]


def _make_transform():
    """Resize to 224x224, convert to a tensor and apply ImageNet mean/std normalisation."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def train():
    """Train DualResNet on the paired subset, saving the epoch with the best mean train accuracy.

    Reads num_epochs / lr / batch_size from settings.yaml and writes
    weights/best_model.pth inside the repo.

    Raises:
        RuntimeError: if the subset folder yields no satellite/drone pairs.
    """
    num_epochs = get_yaml_value("num_epochs")
    lr = get_yaml_value("lr")
    batch_size = get_yaml_value("batch_size")
    data_dir = "/content/dataset_subset"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = PairedU1652Dataset(root_dir=data_dir,
                                 transform_sat=_make_transform(),
                                 transform_drone=_make_transform())
    if len(dataset) == 0:
        raise RuntimeError(f"No satellite/drone pairs found under {data_dir}")

    # Size the classifiers from the dataset's own label map so they always cover
    # every label the loader can emit.
    num_classes = len(dataset.class_to_idx)
    print(f"\n Detected number of classes: {num_classes}")

    # Pinned memory only speeds up host-to-GPU copies; on CPU it just raises a warning.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=4, pin_memory=device.type == "cuda")

    model = DualResNet(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_acc = 0.0
    weights_dir = "/content/ACMMM23-Solution-MBEG/weights"
    os.makedirs(weights_dir, exist_ok=True)
    weights_path = os.path.join(weights_dir, "best_model.pth")
    print("\n Starting training...\n")
    since = time.time()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct1 = correct2 = total = 0

        for sat_img, drone_img, labels in dataloader:
            sat_img, drone_img, labels = sat_img.to(device), drone_img.to(device), labels.to(device)

            optimizer.zero_grad()
            out1, out2, _, _ = model(sat_img, drone_img)
            # Equal-weight average of the satellite and drone classification losses.
            loss = (criterion(out1, labels) + criterion(out2, labels)) / 2
            loss.backward()
            optimizer.step()

            n = labels.size(0)
            total_loss += loss.item() * n
            correct1 += (out1.argmax(1) == labels).sum().item()
            correct2 += (out2.argmax(1) == labels).sum().item()
            total += n

        epoch_loss = total_loss / total
        acc1 = correct1 / total
        acc2 = correct2 / total
        avg_acc = (acc1 + acc2) / 2

        print(f" Epoch [{epoch+1}/{num_epochs}] | Loss: {epoch_loss:.4f} | Sat Acc: {acc1:.4f}, Drone Acc: {acc2:.4f}")

        # NOTE: checkpoint selection uses training accuracy; there is no validation split.
        if avg_acc > best_acc:
            best_acc = avg_acc
            torch.save(model.state_dict(), weights_path)
            print(f"New best model saved with avg accuracy: {best_acc:.4f}")

    time_elapsed = time.time() - since
    print(f"\nTraining complete in {int(time_elapsed // 60)}m {int(time_elapsed % 60)}s")
    print(f" Best avg accuracy: {best_acc:.4f}")


if __name__ == '__main__':
    train()


 Detected number of classes: 100

 Starting training...

 Epoch [1/10] | Loss: 5.0207 | Sat Acc: 0.0200, Drone Acc: 0.0170
New best model saved with avg accuracy: 0.0185
 Epoch [2/10] | Loss: 4.0769 | Sat Acc: 0.0810, Drone Acc: 0.0335
New best model saved with avg accuracy: 0.0573
 Epoch [3/10] | Loss: 3.3876 | Sat Acc: 0.2695, Drone Acc: 0.0550
New best model saved with avg accuracy: 0.1623
 Epoch [4/10] | Loss: 2.3824 | Sat Acc: 0.7440, Drone Acc: 0.0715
New best model saved with avg accuracy: 0.4078
 Epoch [5/10] | Loss: 1.8651 | Sat Acc: 0.9630, Drone Acc: 0.0995
New best model saved with avg accuracy: 0.5312
 Epoch [6/10] | Loss: 1.6415 | Sat Acc: 0.9955, Drone Acc: 0.1455
New best model saved with avg accuracy: 0.5705
 Epoch [7/10] | Loss: 1.4621 | Sat Acc: 0.9770, Drone Acc: 0.2335
New best model saved with avg accuracy: 0.6052
 Epoch [8/10] | Loss: 1.2339 | Sat Acc: 0.9875, Drone Acc: 0.3275
New best model saved with avg accuracy: 0.6575
 Epoch [9/10] | Loss: 1.0106 | Sat Acc

In [61]:
# Source of the standalone train.py script. Raw string: with a plain ''' string, "\n" inside
# the print calls would become real newlines and leave train.py with unterminated strings.
code = r'''
"""Train DualResNet on the paired satellite/drone subset.

Reads num_epochs, lr and batch_size from settings.yaml next to this script and saves
the epoch with the best mean training accuracy to weights/best_model.pth.
"""
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from torch.utils.data import DataLoader
from torchvision import transforms

from dualresnet import DualResNet
from Modified_Preprocessing import PairedU1652Dataset

SETTINGS_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "settings.yaml")


def get_yaml_value(key, path=SETTINGS_PATH):
    """Return `key` from the YAML settings file."""
    with open(path, "r") as f:
        return yaml.safe_load(f)[key]


def _make_transform():
    """Resize to 224x224, convert to a tensor and apply ImageNet mean/std normalisation."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def train():
    """Train DualResNet and keep the checkpoint with the best mean train accuracy."""
    num_epochs = get_yaml_value("num_epochs")
    lr = get_yaml_value("lr")
    batch_size = get_yaml_value("batch_size")
    data_dir = "/content/dataset_subset"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = PairedU1652Dataset(root_dir=data_dir,
                                 transform_sat=_make_transform(),
                                 transform_drone=_make_transform())
    if len(dataset) == 0:
        raise RuntimeError(f"No satellite/drone pairs found under {data_dir}")

    # Size the classifiers from the dataset's label map so they cover every label.
    num_classes = len(dataset.class_to_idx)
    print(f"\n Detected number of classes: {num_classes}")

    # Pinned memory only helps host-to-GPU copies.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=4, pin_memory=device.type == "cuda")

    model = DualResNet(num_classes=num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_acc = 0.0
    weights_dir = "/content/ACMMM23-Solution-MBEG/weights"
    os.makedirs(weights_dir, exist_ok=True)
    weights_path = os.path.join(weights_dir, "best_model.pth")
    print("\n Starting training...\n")
    since = time.time()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct1 = correct2 = total = 0

        for sat_img, drone_img, labels in dataloader:
            sat_img, drone_img, labels = sat_img.to(device), drone_img.to(device), labels.to(device)

            optimizer.zero_grad()
            out1, out2, _, _ = model(sat_img, drone_img)
            # Equal-weight average of the two branch losses.
            loss = (criterion(out1, labels) + criterion(out2, labels)) / 2
            loss.backward()
            optimizer.step()

            n = labels.size(0)
            total_loss += loss.item() * n
            correct1 += (out1.argmax(1) == labels).sum().item()
            correct2 += (out2.argmax(1) == labels).sum().item()
            total += n

        epoch_loss = total_loss / total
        acc1 = correct1 / total
        acc2 = correct2 / total
        avg_acc = (acc1 + acc2) / 2

        print(f" Epoch [{epoch+1}/{num_epochs}] | Loss: {epoch_loss:.4f} | Sat Acc: {acc1:.4f}, Drone Acc: {acc2:.4f}")

        # NOTE: checkpoint selection uses training accuracy; there is no validation split.
        if avg_acc > best_acc:
            best_acc = avg_acc
            torch.save(model.state_dict(), weights_path)
            print(f"New best model saved with avg accuracy: {best_acc:.4f}")

    time_elapsed = time.time() - since
    print(f"\nTraining complete in {int(time_elapsed // 60)}m {int(time_elapsed % 60)}s")
    print(f" Best avg accuracy: {best_acc:.4f}")


if __name__ == '__main__':
    train()
'''
with open("train.py", "w") as f:
    f.write(code)

# Create a 100-class subset of the test dataset (all images of each selected class)

In [75]:
import os
import shutil
import random

full_dataset_path = "/content/dataset_folder/University-Release/test/"
subset_path = "/content/dataset_subset_test"
num_classes_to_sample = 100
SUBSET_SEED = 42  # fixes which test classes are drawn, so evaluation runs are comparable

modalities = ['query_drone', 'query_satellite']

# Rebuild from scratch: otherwise each re-run adds another random draw of classes
# on top of the previous one.
shutil.rmtree(subset_path, ignore_errors=True)

# Private RNG so the draw neither depends on nor disturbs the global `random` state.
rng = random.Random(SUBSET_SEED)

sat_classes = set(os.listdir(os.path.join(full_dataset_path, 'query_satellite')))
drone_classes = set(os.listdir(os.path.join(full_dataset_path, 'query_drone')))
# Sort first: str hashing is randomized per interpreter session, so set order (and
# therefore a seeded sample of it) would otherwise differ between runs.
shared_classes = sorted(sat_classes & drone_classes)

selected_classes = rng.sample(shared_classes, min(num_classes_to_sample, len(shared_classes)))
print(f"Selected {len(selected_classes)} classes.")

for modality in modalities:
    modality_full_path = os.path.join(full_dataset_path, modality)
    modality_subset_path = os.path.join(subset_path, modality)
    os.makedirs(modality_subset_path, exist_ok=True)

    for cls in selected_classes:
        class_full_path = os.path.join(modality_full_path, cls)
        class_subset_path = os.path.join(modality_subset_path, cls)
        os.makedirs(class_subset_path, exist_ok=True)

        all_images = [f for f in os.listdir(class_full_path) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
        if not all_images:
            print(f"No images found in {class_full_path}, skipping.")
            continue

        for img_name in all_images:
            shutil.copyfile(os.path.join(class_full_path, img_name),
                            os.path.join(class_subset_path, img_name))

print("Subset folder with all images from selected classes created successfully!")


Selected 100 classes.
Subset folder with all images from selected classes created successfully!


# Testing & Evaluating

In [76]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from dualresnet import DualResNet


def make_dataloaders(root_dir, image_size=224, batch_size=16, num_workers=2):
    """Build ImageFolder datasets and loaders for the satellite queries and the drone gallery.

    Args:
        root_dir: folder containing `query_satellite/` and `query_drone/` class sub-folders.
        image_size: side length images are resized to.
        batch_size: loader batch size.
        num_workers: loader worker processes.

    Returns:
        (loaders, datasets): two dicts keyed by 'satellite' and 'drone'. Loaders do not
        shuffle, so feature order follows each dataset's sample order.

    Raises:
        ValueError: if the two folders hold different class sub-folders. Each ImageFolder
            numbers its own classes, so integer labels are only comparable across the two
            views when the class lists are identical.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    ds_sat = datasets.ImageFolder(os.path.join(root_dir, "query_satellite"), transform)
    ds_dr = datasets.ImageFolder(os.path.join(root_dir, "query_drone"), transform)

    if ds_sat.classes != ds_dr.classes:
        only_sat = sorted(set(ds_sat.classes) - set(ds_dr.classes))
        only_dr = sorted(set(ds_dr.classes) - set(ds_sat.classes))
        raise ValueError(
            "query_satellite and query_drone contain different class folders "
            f"(satellite-only: {only_sat[:5]}, drone-only: {only_dr[:5]}); "
            "integer labels would not line up across the two views."
        )

    def _loader(ds):
        # Fixed order: extracted features must stay aligned with dataset labels.
        return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return (
        {'satellite': _loader(ds_sat), 'drone': _loader(ds_dr)},
        {'satellite': ds_sat, 'drone': ds_dr},
    )

import numpy as np
import torch

def compute_mAP(index, good_index, junk_index):
    """Average precision and CMC curve for one query's ranked gallery.

    Args:
        index: gallery indices ordered from best to worst match (numpy array).
        good_index: gallery indices that are correct matches.
        junk_index: gallery indices to drop from the ranking before scoring.

    Returns:
        (ap, cmc): `cmc` is an IntTensor of length len(index) that is 1 from the rank of
        the first correct match onward. When the query has no correct match at all,
        returns (0.0, cmc) with cmc[0] = -1 so the caller's `cmc[0] == -1` check skips it.
    """
    cmc = torch.IntTensor(len(index)).zero_()
    if len(good_index) == 0:
        if len(cmc) > 0:
            cmc[0] = -1
        return 0.0, cmc

    index = index[~np.isin(index, junk_index)]
    order = np.where(np.isin(index, good_index))[0]
    if len(order) == 0:
        return 0.0, cmc

    cmc[order[0]:] = 1

    # AP = mean over the correct matches of (matches found so far) / (rank, 1-based).
    hits = np.arange(1, len(order) + 1)
    ap = np.sum(hits / (order + 1)) / len(good_index)
    return ap, cmc

def evaluate(qf, ql, gf, gl):
    """Rank the gallery for a single query by cosine similarity and score that ranking.

    Args:
        qf: query feature tensor (flattened to one row).
        ql: query label.
        gf: gallery feature tensor, one row per gallery item.
        gl: gallery labels (array-like); entries equal to -1 are treated as junk.

    Returns:
        The (ap, cmc) pair produced by compute_mAP.
    """
    query_row = qf.view(1, -1)
    similarity = torch.nn.functional.cosine_similarity(gf, query_row, dim=1).cpu().numpy()
    ranking = np.argsort(similarity)[::-1]  # best match first

    gallery_labels = np.asarray(gl)
    matches = np.argwhere(gallery_labels == ql).flatten()
    ignored = np.argwhere(gallery_labels == -1).flatten()
    return compute_mAP(ranking, matches, ignored)


def extract_features(model, loader, view=1):
    """Return L2-normalised class-logit descriptors and labels for every image in `loader`.

    view=1 routes images through branch 1 (backbone1 -> classifier1); any other value
    uses branch 2. Only the selected branch is run. The two branches share no weights,
    so in eval mode this gives the same logits as `model(x1, x2)` for that view,
    without pushing a dummy zeros batch through the other backbone.

    Returns:
        (features, labels): a float tensor with one row per image and a numpy label array.
    """
    model.eval()
    if view == 1:
        backbone, classifier = model.backbone1, model.classifier1
    else:
        backbone, classifier = model.backbone2, model.classifier2
    device = next(model.parameters()).device

    feats, labels = [], []
    with torch.no_grad():
        for imgs, labs in loader:
            logits = classifier(backbone(imgs.to(device)))
            feats.append(F.normalize(logits, dim=1).cpu())
            labels.extend(labs.numpy())
    return torch.cat(feats), np.array(labels)


def test():
    """Evaluate satellite -> drone retrieval on the test subset and print Recall@1/5 and mAP.

    Raises:
        RuntimeError: if no query has a matching gallery image.
    """
    root_dir = '/content/dataset_subset_test'
    weights_path = '/content/ACMMM23-Solution-MBEG/weights/best_model.pth'

    loaders, image_sets = make_dataloaders(root_dir)

    # The classifier width belongs to the trained checkpoint, not to the test folders.
    # The test subset is sampled separately from the training subset, so its class
    # count need not match, and load_state_dict would fail on a shape mismatch.
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    num_classes = state_dict['classifier1.weight'].shape[0]
    model = DualResNet(num_classes=num_classes)
    model.load_state_dict(state_dict)
    model.eval()

    gf, gl = extract_features(model, loaders['drone'], view=2)
    qf, ql = extract_features(model, loaders['satellite'], view=1)

    CMCs, APs = [], []
    for i in range(len(ql)):
        ap, cmc = evaluate(qf[i], ql[i], gf, gl)
        if cmc[0] == -1:  # query without any ground-truth match in the gallery
            continue
        CMCs.append(cmc.unsqueeze(0))
        APs.append(ap)

    if not CMCs:
        raise RuntimeError("No query has a matching gallery image; cannot compute metrics.")

    CMC = torch.cat(CMCs).float().mean(0).numpy()
    mAP = np.mean(APs)
    # With fewer than 5 gallery images, Recall@5 equals recall at the last rank.
    recall5 = CMC[min(4, len(CMC) - 1)]
    print(f"Recall@1: {CMC[0]*100:.2f}%, Recall@5: {recall5*100:.2f}%, mAP: {mAP*100:.2f}%")


if __name__ == '__main__':
    test()


Recall@1: 2.00%, Recall@5: 7.00%, mAP: 2.68%


In [77]:
# Source of the standalone U1652_test_and_evaluate.py script. Raw string, so backslashes reach the file verbatim.
code = r'''
"""Satellite -> drone retrieval evaluation (Recall@1, Recall@5, mAP) for a trained DualResNet."""
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from dualresnet import DualResNet


def make_dataloaders(root_dir, image_size=224, batch_size=16, num_workers=2):
    """Build non-shuffled ImageFolder loaders for `query_satellite/` and `query_drone/`.

    Returns (loaders, datasets), two dicts keyed by 'satellite' and 'drone'.
    Raises ValueError when the two folders hold different class sub-folders, because
    ImageFolder numbers classes per folder and labels are compared as integers.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    ds_sat = datasets.ImageFolder(os.path.join(root_dir, "query_satellite"), transform)
    ds_dr = datasets.ImageFolder(os.path.join(root_dir, "query_drone"), transform)
    if ds_sat.classes != ds_dr.classes:
        raise ValueError("query_satellite and query_drone contain different class folders; "
                         "integer labels would not line up across the two views.")

    def _loader(ds):
        return DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return (
        {'satellite': _loader(ds_sat), 'drone': _loader(ds_dr)},
        {'satellite': ds_sat, 'drone': ds_dr},
    )


def compute_mAP(index, good_index, junk_index):
    """Return (ap, cmc) for one ranked gallery; cmc[0] == -1 flags a query with no ground truth."""
    cmc = torch.IntTensor(len(index)).zero_()
    if len(good_index) == 0:
        if len(cmc) > 0:
            cmc[0] = -1
        return 0.0, cmc

    index = index[~np.isin(index, junk_index)]
    order = np.where(np.isin(index, good_index))[0]
    if len(order) == 0:
        return 0.0, cmc

    cmc[order[0]:] = 1
    # AP = mean over the correct matches of (matches found so far) / (rank, 1-based).
    hits = np.arange(1, len(order) + 1)
    ap = np.sum(hits / (order + 1)) / len(good_index)
    return ap, cmc


def evaluate(qf, ql, gf, gl):
    """Rank the gallery for one query by cosine similarity and score it with compute_mAP."""
    qf = qf.view(1, -1)
    scores = torch.nn.functional.cosine_similarity(gf, qf, dim=1).cpu().numpy()
    index = np.argsort(scores)[::-1]
    gl = np.asarray(gl)

    good_index = np.argwhere(gl == ql).flatten()
    junk_index = np.argwhere(gl == -1).flatten()

    return compute_mAP(index, good_index, junk_index)


def extract_features(model, loader, view=1):
    """L2-normalised class logits from one branch (1 = satellite branch, else drone) plus labels.

    Only the selected branch is run; the branches share no weights, so in eval mode the
    logits equal those of the full forward pass.
    """
    model.eval()
    if view == 1:
        backbone, classifier = model.backbone1, model.classifier1
    else:
        backbone, classifier = model.backbone2, model.classifier2
    device = next(model.parameters()).device

    feats, labels = [], []
    with torch.no_grad():
        for imgs, labs in loader:
            logits = classifier(backbone(imgs.to(device)))
            feats.append(F.normalize(logits, dim=1).cpu())
            labels.extend(labs.numpy())
    return torch.cat(feats), np.array(labels)


def test():
    """Evaluate satellite -> drone retrieval on the test subset and print the metrics."""
    root_dir = '/content/dataset_subset_test'
    weights_path = '/content/ACMMM23-Solution-MBEG/weights/best_model.pth'

    loaders, image_sets = make_dataloaders(root_dir)

    # Classifier width comes from the checkpoint; the test subset's class count may differ.
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
    model = DualResNet(num_classes=state_dict['classifier1.weight'].shape[0])
    model.load_state_dict(state_dict)
    model.eval()

    gf, gl = extract_features(model, loaders['drone'], view=2)
    qf, ql = extract_features(model, loaders['satellite'], view=1)

    CMCs, APs = [], []
    for i in range(len(ql)):
        ap, cmc = evaluate(qf[i], ql[i], gf, gl)
        if cmc[0] == -1:  # query without any ground-truth match
            continue
        CMCs.append(cmc.unsqueeze(0))
        APs.append(ap)

    if not CMCs:
        raise RuntimeError("No query has a matching gallery image; cannot compute metrics.")

    CMC = torch.cat(CMCs).float().mean(0).numpy()
    mAP = np.mean(APs)
    recall5 = CMC[min(4, len(CMC) - 1)]
    print(f"Recall@1: {CMC[0]*100:.2f}%, Recall@5: {recall5*100:.2f}%, mAP: {mAP*100:.2f}%")


if __name__ == '__main__':
    test()
'''
with open("U1652_test_and_evaluate.py", "w") as f:
    f.write(code)