In [1]:
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from google.colab import drive
import random
from torch.utils.data import DataLoader, dataset

In [2]:
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
class MatchingDataset(Dataset):
    def __init__(self, ground_dir, generated_dir, seg_dir, candidate_dir, transform=None):
        self.ground_dir = ground_dir
        self.generated_dir = generated_dir
        self.seg_dir = seg_dir
        self.candidate_dir = candidate_dir
        self.transform = transform

        self.ground_files = sorted([f for f in os.listdir(ground_dir) if f.endswith(".jpg")])
        self.generated_files = sorted([f for f in os.listdir(generated_dir) if f.endswith(".png")])
        self.seg_files = sorted([f for f in os.listdir(seg_dir) if f.endswith(".png")])
        self.candidate_files = sorted([f for f in os.listdir(candidate_dir) if f.endswith(".png")])

        assert len(self.generated_files) == len(self.seg_files), \
            "Mismatch tra immagini generate e segmentate!"
        assert len(self.generated_files) <= len(self.ground_files), \
            "Più immagini generate che ground!"

    def __len__(self):
        return len(self.generated_files)

    def __getitem__(self, idx):

        ground_filename = self.ground_files[idx]
        ground_path = os.path.join(self.ground_dir, ground_filename)
        ground_img = Image.open(ground_path).convert("RGB")

        gen_filename = self.generated_files[idx]
        gen_path = os.path.join(self.generated_dir, gen_filename)
        gen_img = Image.open(gen_path).convert("RGB")

        seg_filename = self.seg_files[idx]
        seg_path = os.path.join(self.seg_dir, seg_filename)
        seg_img = Image.open(seg_path).convert("RGB")

        image_id = os.path.splitext(ground_filename)[0]
        candidate_filename = f"input{image_id}.png"
        candidate_path = os.path.join(self.candidate_dir, candidate_filename)
        candidate_img = Image.open(candidate_path).convert("RGB")

        negative_idx = idx
        while negative_idx == idx:
            negative_idx = random.randint(0, len(self.candidate_files) - 1)
        negative_filename = self.candidate_files[negative_idx]
        negative_path = os.path.join(self.candidate_dir, negative_filename)
        negative_img = Image.open(negative_path).convert("RGB")

        if self.transform:
            ground_img = self.transform(ground_img)
            gen_img = self.transform(gen_img)
            seg_img = self.transform(seg_img)
            candidate_img = self.transform(candidate_img)
            negative_img = self.transform(negative_img)

        return {
            "ground": ground_img,
            "generated": gen_img,
            "seg": seg_img,
            "candidate": candidate_img,
            "candidate_neg": negative_img
        }

In [4]:
transform_rgb = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])

dataset = MatchingDataset(
    ground_dir="/content/drive/MyDrive/Dataset_Computer_Vision/Dataset_CVUSA/Dataset_CVUSA/streetview",
    generated_dir="/content/drive/MyDrive/generated_images",
    seg_dir="/content/drive/MyDrive/generated_seg",
    candidate_dir="/content/drive/MyDrive/Dataset_Computer_Vision/Dataset_CVUSA/Dataset_CVUSA/bingmap",
    transform=transform_rgb
)

loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class VGGFeatureExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        vgg = models.vgg16(pretrained=pretrained)
        self.features = nn.Sequential(*list(vgg.features.children()))
        self.pool = vgg.avgpool
        self.fc = nn.Sequential(*list(vgg.classifier.children())[:-1])  # fino a penultimo layer

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

class FeatureFusionNet(nn.Module):
    def __init__(self, input_dim, embed_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        x = self.net(x)
        x = F.normalize(x, p=2, dim=1)
        return x

class JointFeatureLearningNet(nn.Module):
    def __init__(self, pretrained=True, embed_dim=256):
        super().__init__()

        self.vgg_G = VGGFeatureExtractor(pretrained)
        self.vgg_A = VGGFeatureExtractor(pretrained)
        self.vgg_S = VGGFeatureExtractor(pretrained)
        self.vgg_C = self.vgg_A

        self.ffn_GAS = FeatureFusionNet(input_dim=4096*3, embed_dim=embed_dim)
        self.ffn_AC = FeatureFusionNet(input_dim=4096*2, embed_dim=embed_dim)

    def forward(self, G, A, S, C):
        fG = self.vgg_G(G)
        fA = self.vgg_A(A)
        fS = self.vgg_S(S)
        fC = self.vgg_C(C)

        embed_G = self.ffn_GAS(torch.cat([fG, fA, fS], dim=1))
        embed_C = self.ffn_AC(torch.cat([fA, fC], dim=1))
        return embed_G, embed_C


In [6]:
model = JointFeatureLearningNet(pretrained=True, embed_dim=256).to(device)

vgg_params = []
ffn_params = []

for name, param in model.named_parameters():
    if 'vgg' in name:
        vgg_params.append(param)
    else:
        ffn_params.append(param)

triplet_loss = nn.TripletMarginLoss(margin=0.2, p=2)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 229MB/s]


In [7]:
model = JointFeatureLearningNet().to(device)

checkpoint = torch.load('/content/drive/MyDrive/checkpoints_feature/all_models_epoch05.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print("Modello e ottimizzatore ripristinati con successo.")

model.train()
num_epochs = 5
optimizer = torch.optim.Adam([
    {'params': vgg_params, 'lr': 1e-4},
    {'params': ffn_params, 'lr': 1e-3},
])

Modello e ottimizzatore ripristinati con successo.


In [8]:
for epoch in range(num_epochs):
    for batch in loader:
        G = batch["ground"].to(device)
        A = batch["generated"].to(device)
        S = batch["seg"].to(device)
        C_pos = batch["candidate"].to(device)
        C_neg = batch["candidate_neg"].to(device)

        embed_G, embed_pos = model(G, A, S, C_pos)
        _, embed_neg = model(G, A, S, C_neg)

        loss = triplet_loss(embed_G, embed_pos, embed_neg)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {loss.item():.4f}")


Epoch 1/5 - Loss: 0.1175
Epoch 2/5 - Loss: 0.0165
Epoch 3/5 - Loss: 0.0737
Epoch 4/5 - Loss: 0.1505
Epoch 5/5 - Loss: 0.0508


In [9]:
import os
os.makedirs("checkpoints_feature", exist_ok=True)

In [10]:
torch.save({
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoints_feature/all_models_epoch05.pt')

print("Checkpoint salvato correttamente!")

Checkpoint salvato correttamente!


In [19]:
!cp -r checkpoints_feature drive/MyDrive/

In [12]:
@torch.no_grad()
def evaluate_matching_accuracy(model, dataloader, device):
    model.eval()
    total = 0
    correct = 0

    for batch in dataloader:
        G = batch["ground"].to(device)
        A = batch["generated"].to(device)
        S = batch["seg"].to(device)
        C_pos = batch["candidate"].to(device)
        C_neg = batch["candidate_neg"].to(device)

        embed_G, embed_pos = model(G, A, S, C_pos)
        _, embed_neg = model(G, A, S, C_neg)

        sim_pos = F.cosine_similarity(embed_G, embed_pos, dim=1)
        sim_neg = F.cosine_similarity(embed_G, embed_neg, dim=1)

        correct += (sim_pos > sim_neg).sum().item()
        total += G.size(0)

    acc = correct / total
    print(f"Accuracy (pos > neg): {acc:.4f}")
    return acc

In [13]:

acc = evaluate_matching_accuracy(
    model=model,
    dataloader=loader,
    device=device
    #k=5
)

print(f"Accuracy: {acc:.2f}%")

Accuracy (pos > neg): 0.5016
Accuracy: 0.50%


In [14]:
import gc

gc.collect()
torch.cuda.empty_cache()

In [15]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
import random

class MatchingEvalDataset(Dataset):
    def __init__(self, ground_dir, generated_dir, seg_dir, candidate_dir,
                 transform=None, num_candidates=10):
        self.ground_dir = ground_dir
        self.generated_dir = generated_dir
        self.seg_dir = seg_dir
        self.candidate_dir = candidate_dir
        self.transform = transform
        self.num_candidates = num_candidates

        self.ground_files = sorted([f for f in os.listdir(ground_dir) if f.endswith(".jpg")])
        self.generated_files = sorted([f for f in os.listdir(generated_dir) if f.endswith(".png")])
        self.seg_files = sorted([f for f in os.listdir(seg_dir) if f.endswith(".png")])
        self.candidate_files = sorted([f for f in os.listdir(candidate_dir) if f.endswith(".png")])

        self.candidate_map = {f: i for i, f in enumerate(self.candidate_files)}

        assert len(self.generated_files) == len(self.seg_files), \
            "Mismatch tra immagini generate e segmentate!"
        assert len(self.generated_files) <= len(self.ground_files), \
            "Più immagini generate che ground!"

    def __len__(self):
        return len(self.generated_files)

    def __getitem__(self, idx):
        ground_filename = self.ground_files[idx]
        image_id = os.path.splitext(ground_filename)[0]
        gt_candidate_filename = f"input{image_id}.png"

        if gt_candidate_filename not in self.candidate_map:
            raise ValueError(f"Candidato corretto {gt_candidate_filename} non trovato")

        ground_img = Image.open(os.path.join(self.ground_dir, ground_filename)).convert("RGB")
        gen_img = Image.open(os.path.join(self.generated_dir, self.generated_files[idx])).convert("RGB")
        seg_img = Image.open(os.path.join(self.seg_dir, self.seg_files[idx])).convert("RGB")
        gt_candidate_img = Image.open(os.path.join(self.candidate_dir, gt_candidate_filename)).convert("RGB")

        all_candidates = set(self.candidate_files)
        all_candidates.remove(gt_candidate_filename)
        negative_candidates = random.sample(list(all_candidates), self.num_candidates - 1)

        candidate_imgs = []
        gt_index = random.randint(0, self.num_candidates - 1)

        for i in range(self.num_candidates):
            if i == gt_index:
                img = gt_candidate_img
            else:
                fname = negative_candidates.pop()
                img = Image.open(os.path.join(self.candidate_dir, fname)).convert("RGB")
            if self.transform:
                img = self.transform(img)
            candidate_imgs.append(img)

        if self.transform:
            ground_img = self.transform(ground_img)
            gen_img = self.transform(gen_img)
            seg_img = self.transform(seg_img)

        candidates_tensor = torch.stack(candidate_imgs, dim=0)

        return {
            "ground": ground_img,
            "generated": gen_img,
            "seg": seg_img,
            "candidates": candidates_tensor,
            "gt_index": gt_index
        }


In [16]:
import torch.nn.functional as F

def evaluate_matching_topk(model, dataloader, device, topk=[1, 5]):
    model.eval()
    correct_at_k = {k: 0 for k in topk}
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            G = batch["ground"].to(device)
            A = batch["generated"].to(device)
            S = batch["seg"].to(device)
            C_all = batch["candidates"].to(device)
            gt_indices = batch["gt_index"]

            B, N, C, H, W = C_all.shape
            C_all_flat = C_all.view(B * N, C, H, W)

            embed_G, _ = model(G, A, S, C_all[:, 0])
            embed_G = embed_G.unsqueeze(1)

            embed_C = []
            for i in range(N):
                _, embed_ci = model(G, A, S, C_all[:, i])
                embed_C.append(embed_ci)
            embed_C = torch.stack(embed_C, dim=1)

            sims = F.cosine_similarity(embed_G, embed_C, dim=2)

            ranks = sims.argsort(dim=1, descending=True)

            for k in topk:
                for i in range(B):
                    if gt_indices[i] in ranks[i, :k]:
                        correct_at_k[k] += 1

            total += B

    for k in topk:
        acc = correct_at_k[k] / total
        print(f"Top-{k} accuracy: {acc:.4f}")

    return [correct_at_k[k] / total for k in topk]


In [17]:
eval_dataset = MatchingEvalDataset(
    ground_dir="/content/drive/MyDrive/Dataset_Computer_Vision/Dataset_CVUSA/Dataset_CVUSA/streetview",
    generated_dir="/content/drive/MyDrive/generated_images",
    seg_dir="/content/drive/MyDrive/generated_seg",
    candidate_dir="/content/drive/MyDrive/Dataset_Computer_Vision/Dataset_CVUSA/Dataset_CVUSA/bingmap",
    transform=transform_rgb,
    num_candidates=10
)

eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=8, shuffle=False)


In [18]:
top1, top5 = evaluate_matching_topk(model, eval_loader, device)

Top-1 accuracy: 0.0898
Top-5 accuracy: 0.5038
