In [1]:
import contextlib
import ssl

import torch
from torchvision.datasets.cifar import CIFAR10
from fld.datasets.ImageTensorDataset import ImageTensorDataset
from fld.features.DINOv2FeatureExtractor import DINOv2FeatureExtractor
from fld.metrics.FLD import FLD


@contextlib.contextmanager
def unverified_https():
    """Temporarily disable TLS certificate verification for urllib-based HTTPS.

    Swaps ``ssl._create_default_https_context`` for the unverified variant and
    restores the previous factory on exit (even on error). Verification is
    therefore off only inside the ``with`` body, not for the rest of the kernel.
    NOTE(review): presumably a workaround for the cluster's CA setup — fixing
    the certificate bundle and dropping this entirely would be better.
    """
    previous = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        ssl._create_default_https_context = previous


# Only the steps that may touch the network (torch.hub model load inside the
# extractor, CIFAR10 download) run with certificate verification disabled.
with unverified_https():
    feature_extractor = DINOv2FeatureExtractor()
    cifar_train = CIFAR10(train=True, root="data", download=True)
    cifar_test = CIFAR10(train=False, root="data", download=True)

# Reference features: FLD compares generated samples against these.
train_feat = feature_extractor.get_features(cifar_train)
test_feat = feature_extractor.get_features(cifar_test)

Using cache found in /home/mila/k/kirill.neklyudov/.cache/torch/hub/facebookresearch_dinov2_main


Files already downloaded and verified


                                                 

Files already downloaded and verified


                                               

In [2]:
import glob
import numpy as np
import os

def get_samples(sample_dir, max_samples=50_000):
    """Load generated images from ``samples_*.npz`` shards into an ImageTensorDataset.

    Every file matching ``samples_*.npz`` in ``sample_dir`` is read. The
    ``'samples'`` array of each file has its last axis moved to position 1
    (presumably NHWC -> NCHW; TODO confirm the on-disk layout). The shards
    are then concatenated and truncated to the first ``max_samples`` images.

    Args:
        sample_dir: directory containing the ``samples_*.npz`` shards.
        max_samples: keep at most this many images (default 50_000).
            ``None`` keeps every image.

    Returns:
        ImageTensorDataset wrapping a ``torch.Tensor`` built from the
        concatenated samples.

    Raises:
        FileNotFoundError: if no ``samples_*.npz`` file exists in ``sample_dir``.
    """
    # glob returns files in arbitrary (filesystem) order. Sorting makes the
    # shard order, and therefore which images survive the truncation, reproducible.
    shard_paths = sorted(glob.glob(os.path.join(sample_dir, "samples_*.npz")))
    if not shard_paths:
        raise FileNotFoundError(f"no samples_*.npz files found in {sample_dir!r}")

    all_samples = []
    for path in shard_paths:
        # NpzFile is a context manager; closing it releases the file handle.
        with np.load(path) as shard:
            all_samples.append(shard['samples'].transpose((0, 3, 1, 2)))
    all_samples = np.concatenate(all_samples, axis=0)[:max_samples]
    return ImageTensorDataset(torch.tensor(all_samples))

def get_fld(feats, n_runs=10, train_features=None, test_features=None):
    """Compute FLD of ``feats`` ``n_runs`` times and return every value.

    A fresh ``FLD()`` instance is used for each run. Repeating the metric lets
    callers report mean ± std; the non-zero spreads in this notebook's outputs
    suggest the metric is stochastic. No seed is set here.

    Args:
        feats: features of the generated samples to evaluate.
        n_runs: number of repetitions (default 10).
        train_features: reference train features. Defaults to the
            notebook-level ``train_feat``.
        test_features: reference test features. Defaults to the
            notebook-level ``test_feat``.

    Returns:
        np.ndarray with one FLD value per run (presumably shape ``(n_runs,)``
        if ``compute_metric`` returns a scalar).
    """
    # Resolve the notebook globals at call time, so existing calls behave as before.
    if train_features is None:
        train_features = train_feat
    if test_features is None:
        test_features = test_feat
    return np.array([
        FLD().compute_metric(train_features, test_features, feats)
        for _ in range(n_runs)
    ])

In [None]:
# Joint vector-field checkpoint `temp_ab_joint_vf`.
# These results get their own `*_temp_ab_joint_*` names because the conditional
# cells below use `feat_joint_*` / `fld_joint_*` for the `cond_joint_vf`
# checkpoint. Shared names would let whichever cell ran last decide what the
# "FLD_joint_*" report refers to.
feat_temp_ab_joint_det = feature_extractor.get_features(get_samples('../checkpoint/temp_ab_joint_vf/eval/samples/'))
feat_temp_ab_joint_stoch = feature_extractor.get_features(get_samples('../checkpoint/temp_ab_joint_vf/eval/samples_stoch/'))

fld_temp_ab_joint_det = get_fld(feat_temp_ab_joint_det)
fld_temp_ab_joint_stoch = get_fld(feat_temp_ab_joint_stoch)
print(f"FLD_temp_ab_joint_det: {fld_temp_ab_joint_det.mean():.3f}±{fld_temp_ab_joint_det.std():.3f}")
print(f"FLD_temp_ab_joint_stoch: {fld_temp_ab_joint_stoch.mean():.3f}±{fld_temp_ab_joint_stoch.std():.3f}")

In [6]:
# Sample locations for runs A and B and the jointly trained conditional model.
# The scratch root defaults to the original account's path. Set the
# $FLD_SCRATCH_DIR environment variable to use another machine or user.
SCRATCH_DIR = os.environ.get('FLD_SCRATCH_DIR', '/network/scratch/k/kirill.neklyudov')
RUN_A_EVAL_DIR = os.path.join(SCRATCH_DIR, '5294839', 'eval')
RUN_B_EVAL_DIR = os.path.join(SCRATCH_DIR, '5294900', 'eval')
COND_JOINT_EVAL_DIR = '../checkpoint/cond_joint_vf/eval'

# DINOv2 features for the `samples` and `samples_stoch` sets of each run.
feat_a = feature_extractor.get_features(get_samples(os.path.join(RUN_A_EVAL_DIR, 'samples')))
feat_b = feature_extractor.get_features(get_samples(os.path.join(RUN_B_EVAL_DIR, 'samples')))
feat_a_stoch = feature_extractor.get_features(get_samples(os.path.join(RUN_A_EVAL_DIR, 'samples_stoch')))
feat_b_stoch = feature_extractor.get_features(get_samples(os.path.join(RUN_B_EVAL_DIR, 'samples_stoch')))
feat_joint_det = feature_extractor.get_features(get_samples(os.path.join(COND_JOINT_EVAL_DIR, 'samples')))
feat_joint_stoch = feature_extractor.get_features(get_samples(os.path.join(COND_JOINT_EVAL_DIR, 'samples_stoch')))

                                                 

In [7]:
# Repeated FLD evaluations for every sample set. Evaluation order is kept fixed
# (A, B, A_stoch, B_stoch, joint det, joint stoch, mixed, mixed stoch).
fld_a, fld_b = get_fld(feat_a), get_fld(feat_b)
fld_a_stoch, fld_b_stoch = get_fld(feat_a_stoch), get_fld(feat_b_stoch)
fld_joint_det, fld_joint_stoch = get_fld(feat_joint_det), get_fld(feat_joint_stoch)

# "Mixed" baseline: the first 25k feature rows of A stacked on the first 25k of B.
fld_mixed = get_fld(torch.cat((feat_a[:25_000], feat_b[:25_000])))
fld_mixed_stoch = get_fld(torch.cat((feat_a_stoch[:25_000], feat_b_stoch[:25_000])))

                                               

In [8]:
# Report mean ± std over the repeated FLD runs. `.std()` is numpy's default
# population std (ddof=0).
for label, values in [
    ("FLD_A", fld_a),
    ("FLD_B", fld_b),
    ("FLD_A_stoch", fld_a_stoch),
    ("FLD_B_stoch", fld_b_stoch),
    ("FLD_joint_det", fld_joint_det),
    ("FLD_joint_stoch", fld_joint_stoch),
    ("FLD_Mixed", fld_mixed),
    ("FLD_Mixed_stoch", fld_mixed_stoch),
]:
    print(f"{label}: {values.mean():.3f}±{values.std():.3f}")

FLD_A: 6.824±0.086
FLD_B: 7.059±0.130
FLD_A_stoch: 6.263±0.112
FLD_B_stoch: 6.302±0.160
FLD_joint_det: 6.860±0.082
FLD_joint_stoch: 6.195±0.082
FLD_Mixed: 7.038±0.115
FLD_Mixed_stoch: 6.269±0.151


## Unconditional model — entire CIFAR-10

In [3]:
# DINOv2 features for run 5617628's two sample sets: `samples` and
# `samples_stoch` (presumably deterministic vs. stochastic sampling).
uncond_eval_dir = '/network/scratch/k/kirill.neklyudov/5617628/eval'
feat_det = feature_extractor.get_features(get_samples(uncond_eval_dir + '/samples'))
feat_stoch = feature_extractor.get_features(get_samples(uncond_eval_dir + '/samples_stoch'))

                                                 

In [4]:
# Repeated FLD evaluations for both unconditional sample sets (det first, then stoch).
fld_det, fld_stoch = get_fld(feat_det), get_fld(feat_stoch)

                                               

In [5]:
# Mean ± population std (numpy default, ddof=0) over the repeated FLD runs.
for label, values in (("FLD_det", fld_det), ("FLD_stoch", fld_stoch)):
    print(f"{label}: {values.mean():.3f}±{values.std():.3f}")

FLD_det: 8.059±0.116
FLD_stoch: 7.508±0.112
