In [None]:
# Original FID using InceptionV3
import os
import numpy as np
import torch
from torchvision.models import inception_v3
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from pytorch_fid import fid_score
from scipy.special import kl_div

# Pin this kernel to CUDA device index 2; subsequent `.cuda()` calls target it.
torch.cuda.set_device(device=2)


# Compute FID
def compute_fid(model, real_loader, fake_loader, device="cuda"):
    """Compute the Frechet distance between feature statistics of two image sets.

    Args:
        model: Feature extractor, expected to be in eval mode. Its output may be a
            tensor or a tuple/list whose first element is the feature tensor;
            features are flattened to ``(batch, dim)``.
        real_loader: Iterable yielding ``(images, labels)`` batches of real images.
        fake_loader: Iterable yielding ``(images, labels)`` batches of generated images.
        device: Device each batch is moved to before the forward pass. The default
            ``"cuda"`` targets the current CUDA device, like ``.cuda()``.

    Returns:
        The value returned by ``pytorch_fid.fid_score.calculate_frechet_distance``.

    Raises:
        ValueError: If either loader yields fewer than two images, because a
            covariance matrix cannot be estimated from them.
    """
    print("Computing FID...")

    def extract(loader):
        # Move each batch's features to the CPU right away so GPU memory does not
        # grow with dataset size, and keep float64 for the mean/covariance.
        batches = []
        with torch.no_grad():
            for images, _ in loader:
                out = model(images.to(device))
                # Some extractors (e.g. the head-less ResNet50 used elsewhere in
                # this notebook) return a tuple of feature tensors; take the first.
                if isinstance(out, (tuple, list)):
                    out = out[0]
                batches.append(out.flatten(start_dim=1).cpu().numpy().astype(np.float64))
        if not batches:
            return np.empty((0, 0))
        return np.concatenate(batches, axis=0)

    real_activations = extract(real_loader)
    fake_activations = extract(fake_loader)

    for name, acts in (("real", real_activations), ("fake", fake_activations)):
        if acts.shape[0] < 2:
            raise ValueError(
                f"compute_fid needs at least 2 {name} images, got {acts.shape[0]}"
            )

    mu1 = real_activations.mean(axis=0)
    mu2 = fake_activations.mean(axis=0)

    sigma1 = np.cov(real_activations, rowvar=False)
    sigma2 = np.cov(fake_activations, rowvar=False)

    fid = fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2)

    return fid


# Compute KL Divergence
def compute_kl(real_loader, fake_loader, bins=256, value_range=(0, 1), eps=1e-10):
    """Compute KL(real || fake) between pixel-intensity histograms.

    Every pixel value yielded by each loader is binned into a ``bins``-bin
    histogram over ``value_range``. The counts are normalised to probability
    vectors and compared with ``scipy.special.kl_div`` (elementwise
    ``x*log(x/y) - x + y``), summed over bins.

    Args:
        real_loader: Iterable of ``(images, labels)`` batches of real images.
        fake_loader: Iterable of ``(images, labels)`` batches of generated images.
        bins: Number of histogram bins (default 256).
        value_range: ``(low, high)`` histogram range. ``np.histogram`` ignores
            values outside it. The default ``(0, 1)`` matches ``ToTensor`` output.
        eps: Constant added to every bin so that empty bins do not produce ``inf``.

    Returns:
        The summed KL divergence as a NumPy float.

    Raises:
        ValueError: If a loader contributes no pixel values within ``value_range``.
    """
    print("Computing KL Divergence...")

    def pixel_histogram(loader):
        # Accumulate raw counts over all batches, then normalise once.
        counts = np.zeros(bins)
        for images, _ in loader:
            pixels = torch.as_tensor(images).detach().cpu().numpy().ravel()
            counts += np.histogram(pixels, bins=bins, range=value_range)[0]
        total = counts.sum()
        if total == 0:
            raise ValueError(
                f"compute_kl: loader yielded no pixel values within {value_range}"
            )
        return counts / total

    real_histogram = pixel_histogram(real_loader)
    fake_histogram = pixel_histogram(fake_loader)

    kl = kl_div(real_histogram + eps, fake_histogram + eps).sum()

    return kl


# Define paths to your folders
fake_folder = "/workspace/dso/playground/fid_kl/data/loragen"
real_folder = "/workspace/dso/playground/fid_kl/data/fusrs"
# Not used by the Inception evaluation in this cell.
pre_trained_ResNet50 = "/workspace/dso/clsar/outputs/res50_fusrs_v2_pretrain/res50_1x128_lr1e-1+200e+im21k_fusrs_v2/best_f1_score_epoch_158.pth"

# Categories to evaluate, resolved by class-folder name in each dataset.
CATEGORIES = ["cargo", "fishing", "dredger", "tanker"]
BATCH_SIZE = 28

# Create data loaders. Pixels stay in [0, 1] (ToTensor) because compute_kl
# histograms intensities over (0, 1); model-specific scaling is done in fid_model.
transform = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

real_dataset = ImageFolder(real_folder, transform=transform)
fake_dataset = ImageFolder(fake_folder, transform=transform)

# Initialize Inception model; replacing fc exposes the pooled features.
inception_model = inception_v3(pretrained=True, transform_input=False).cuda()
inception_model.fc = torch.nn.Identity()
inception_model = inception_model.eval()

# With transform_input=False, torchvision's Inception v3 weights expect inputs in
# [-1, 1] (what transform_input=True derives from ImageNet-normalised images), so
# map ToTensor's [0, 1] range through (x - 0.5) / 0.5 before the network.
# NOTE(review): these are torchvision's ImageNet weights, not the TF "FID
# Inception" weights used by pytorch_fid, so values are not directly comparable
# to published FID numbers.
fid_model = torch.nn.Sequential(
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    inception_model,
).eval()

# Compute FID and KL Divergence for each category.
# ImageFolder numbers classes by *sorted* folder name, independently for each
# dataset, so labels are looked up by name instead of by list position.
for category in CATEGORIES:
    print(f"Category: {category}")

    missing = [
        name
        for name, dataset in (("real", real_dataset), ("fake", fake_dataset))
        if category not in dataset.class_to_idx
    ]
    if missing:
        print(
            f"Skipping: no '{category}' class folder in {', '.join(missing)} data "
            f"(real classes: {real_dataset.classes}, fake classes: {fake_dataset.classes})\n"
        )
        continue

    real_label = real_dataset.class_to_idx[category]
    fake_label = fake_dataset.class_to_idx[category]
    indices_real = [idx for idx, label in enumerate(real_dataset.targets) if label == real_label]
    indices_fake = [idx for idx, label in enumerate(fake_dataset.targets) if label == fake_label]

    real_loader = DataLoader(
        torch.utils.data.Subset(real_dataset, indices_real),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )
    fake_loader = DataLoader(
        torch.utils.data.Subset(fake_dataset, indices_fake),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    fid = compute_fid(fid_model, real_loader, fake_loader)
    print(f"FID: {fid}")
    kl = compute_kl(real_loader, fake_loader)
    print(f"KL Divergence: {kl}\n")

In [None]:
# Original FID using SAR pre-trained ResNet50
import os
import numpy as np
import torch
from torchvision.models import inception_v3
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from pytorch_fid import fid_score
from scipy.special import kl_div
from mmpretrain import get_model
from sklearn.neighbors import KernelDensity
from scipy.special import kl_div
import numpy as np

# Pin this kernel to CUDA device index 2; subsequent `.cuda()` calls target it.
torch.cuda.set_device(device=2)


# Compute FID
def compute_fid(model, real_loader, fake_loader, device="cuda"):
    """Compute the Frechet distance between feature statistics of two image sets.

    Args:
        model: Feature extractor, expected to be in eval mode. Its output may be a
            tensor or a tuple/list whose first element is the feature tensor;
            features are flattened to ``(batch, dim)``.
        real_loader: Iterable yielding ``(images, labels)`` batches of real images.
        fake_loader: Iterable yielding ``(images, labels)`` batches of generated images.
        device: Device each batch is moved to before the forward pass. The default
            ``"cuda"`` targets the current CUDA device, like ``.cuda()``.

    Returns:
        The value returned by ``pytorch_fid.fid_score.calculate_frechet_distance``.

    Raises:
        ValueError: If either loader yields fewer than two images, because a
            covariance matrix cannot be estimated from them.
    """
    print("Computing FID...")

    def extract(loader):
        # Move each batch's features to the CPU right away so GPU memory does not
        # grow with dataset size, and keep float64 for the mean/covariance.
        batches = []
        with torch.no_grad():
            for images, _ in loader:
                out = model(images.to(device))
                # The head-less ResNet50 used in this cell returns a tuple of
                # feature tensors; plain tensors are used as-is.
                if isinstance(out, (tuple, list)):
                    out = out[0]
                batches.append(out.flatten(start_dim=1).cpu().numpy().astype(np.float64))
        if not batches:
            return np.empty((0, 0))
        return np.concatenate(batches, axis=0)

    real_activations = extract(real_loader)
    fake_activations = extract(fake_loader)

    for name, acts in (("real", real_activations), ("fake", fake_activations)):
        if acts.shape[0] < 2:
            raise ValueError(
                f"compute_fid needs at least 2 {name} images, got {acts.shape[0]}"
            )

    mu1 = real_activations.mean(axis=0)
    mu2 = fake_activations.mean(axis=0)

    sigma1 = np.cov(real_activations, rowvar=False)
    sigma2 = np.cov(fake_activations, rowvar=False)

    fid = fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2)

    return fid


def compute_kl(model, real_loader, fake_loader, bandwidth=0.2, device="cuda"):
    """Estimate the symmetric KL divergence between real and fake feature distributions.

    Features are extracted with ``model`` (first element when the output is a
    tuple/list, flattened to ``(batch, dim)``). A Gaussian ``KernelDensity`` with
    ``bandwidth`` is fitted to each feature set, and each direction is estimated
    by Monte Carlo on the samples of the first distribution:

        KL(P || Q) ~= mean over x ~ P of [log p(x) - log q(x)]

    where ``KernelDensity.score_samples`` supplies the log-densities.

    NOTE(review): the KDEs are evaluated on their own fit points, which biases
    log p(x) upward, and KDE in high-dimensional feature spaces is unreliable.
    Treat the result as a rough relative score.

    Args:
        model: Feature extractor, expected to be in eval mode.
        real_loader: Iterable of ``(images, labels)`` batches of real images.
        fake_loader: Iterable of ``(images, labels)`` batches of generated images.
        bandwidth: Gaussian kernel bandwidth (default 0.2).
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``0.5 * (KL(real || fake) + KL(fake || real))``.

    Raises:
        ValueError: If either loader yields no images.
    """

    def extract(loader):
        batches = []
        with torch.no_grad():
            for images, _ in loader:
                out = model(images.to(device))
                if isinstance(out, (tuple, list)):
                    out = out[0]
                batches.append(out.flatten(start_dim=1).cpu().numpy())
        if not batches:
            raise ValueError("compute_kl: loader yielded no images")
        return np.concatenate(batches, axis=0)

    real_features = extract(real_loader)
    fake_features = extract(fake_loader)

    # Fit KDE to real and fake features
    kde_real = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(real_features)
    kde_fake = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(fake_features)

    # Both densities must be evaluated on the SAME samples; kl_div would also
    # reject log-densities, which are not valid (non-negative) probabilities.
    kl_real_fake = np.mean(
        kde_real.score_samples(real_features) - kde_fake.score_samples(real_features)
    )
    kl_fake_real = np.mean(
        kde_fake.score_samples(fake_features) - kde_real.score_samples(fake_features)
    )

    # Return symmetric KL divergence
    return 0.5 * (kl_real_fake + kl_fake_real)


# Define paths to your folders
fake_folder = "/workspace/dso/gensar/lora/output/sarlora/256/rank/256_fp32_s20000+100e_wp0_bs32_lr1e-03_rank1"
real_folder = "/workspace/dso/gensar/geneval/data/fusrs"
pre_trained_ResNet50 = "/workspace/dso/clsar/outputs/res50_fusrs_v2_pretrain/res50_1x128_lr1e-1+200e+im21k_fusrs_v2/best_f1_score_epoch_158.pth"

# Categories to evaluate, resolved by class-folder name in each dataset.
CATEGORIES = ["cargo", "fishing", "dredger", "tanker"]
BATCH_SIZE = 28

# Create data loaders
transform = transforms.Compose(
    [
        transforms.Resize((299, 299)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
)

real_dataset = ImageFolder(real_folder, transform=transform)
fake_dataset = ImageFolder(fake_folder, transform=transform)

# Initialize SAR pre-trained ResNet50
model = get_model(
    "resnet50_8xb32_in1k",
    head=None,  # to extract only activation vectors
    pretrained=pre_trained_ResNet50,
).cuda()
# NOTE(review): calling model(x) directly presumably skips mmpretrain's
# data_preprocessor, so the [0, 1] ToTensor output reaches the backbone
# unnormalised. Confirm this matches how the checkpoint was trained.
model = model.eval()

# Compute FID (and optionally KL Divergence) for each category.
# ImageFolder numbers classes by *sorted* folder name, independently for each
# dataset, so labels are looked up by name instead of by list position.
for category in CATEGORIES:
    print(f"Category: {category}")

    missing = [
        name
        for name, dataset in (("real", real_dataset), ("fake", fake_dataset))
        if category not in dataset.class_to_idx
    ]
    if missing:
        print(
            f"Skipping: no '{category}' class folder in {', '.join(missing)} data "
            f"(real classes: {real_dataset.classes}, fake classes: {fake_dataset.classes})\n"
        )
        continue

    real_label = real_dataset.class_to_idx[category]
    fake_label = fake_dataset.class_to_idx[category]
    indices_real = [idx for idx, label in enumerate(real_dataset.targets) if label == real_label]
    indices_fake = [idx for idx, label in enumerate(fake_dataset.targets) if label == fake_label]

    real_loader = DataLoader(
        torch.utils.data.Subset(real_dataset, indices_real),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )
    fake_loader = DataLoader(
        torch.utils.data.Subset(fake_dataset, indices_fake),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    fid = compute_fid(model, real_loader, fake_loader)
    print(f"FID: {fid}")
    # KDE-based KL is disabled in this run; uncomment to also report it.
    # kl = compute_kl(model, real_loader, fake_loader)
    # print(f"KL Divergence: {kl}\n")

In [None]:
from mmpretrain import get_model

# Reload the SAR-pretrained ResNet50 with its classification head removed
# (head=None), so the forward pass yields activation vectors rather than logits.
# Alternative from an earlier experiment: head=dict(num_classes=5).
model = get_model(
    "resnet50_8xb32_in1k",
    pretrained="/workspace/dso/clsar/outputs/res50_fusrs_v2_pretrain/res50_1x128_lr1e-1+200e+im21k_fusrs_v2/best_f1_score_epoch_158.pth",
    head=None,
)

In [None]:
import torch

# Sanity check: push one random 224x224 RGB batch through the loaded model and
# report the type/shape of the first output element.
# NOTE(review): the evaluation cells resize to 299x299, not 224x224.
model = model.cuda().eval()
x = torch.rand((1, 3, 224, 224)).cuda()
with torch.no_grad():  # inference only; no autograd graph is needed
    y = model(x)[0]
print(type(y), y.shape)

In [None]:
# Display the loaded model's module structure (its repr) as the cell output.
model