In [1]:
import os
import random
import torch
import numpy as np 
import matplotlib.pyplot as plt
from torchvision.io import read_image
from torchvision import transforms
from tqdm.notebook import tqdm
from metrics import FrechetInceptionDistance, KernelInceptionDistance, IPR

from dataset import init_dataset

In [2]:
experiment_dir = "/datatmp/users/igeorvasilis/ddpm-continual-learning/results/cifar10/mask_v1_re/0"

args_dir = os.path.join(experiment_dir, "args.pt")
args = torch.load(args_dir)
args.device = 'cuda:2'
args.dataset_path = "/datatmp/users/igeorvasilis/datasets/cifar10"

args.pretrained_model_dir = os.path.join(experiment_dir, "epoch-200")
args.train_log_dir = os.path.join(experiment_dir, "train_log.pt")
args.folder_name = "ddim_fake_images"

train_log = torch.load(args.train_log_dir)

torch.manual_seed(args.gen_seed)
np.random.seed(args.gen_seed)
random.seed(args.gen_seed)

In [3]:
# load pretrained model
import json

with open(f"{args.pretrained_model_dir}/unet/config.json") as f:
    # load json, 'r') as f:
    config = json.load(f)
config 

{'_class_name': 'UNet2DModel',
 '_diffusers_version': '0.23.0',
 'act_fn': 'silu',
 'add_attention': True,
 'attention_head_dim': 8,
 'attn_norm_num_groups': None,
 'block_out_channels': [128, 128, 256, 256, 512, 512],
 'center_input_sample': False,
 'class_embed_type': None,
 'down_block_types': ['DownBlock2D',
  'DownBlock2D',
  'DownBlock2D',
  'DownBlock2D',
  'AttnDownBlock2D',
  'DownBlock2D'],
 'downsample_padding': 1,
 'downsample_type': 'conv',
 'dropout': 0.0,
 'flip_sin_to_cos': True,
 'freq_shift': 0,
 'in_channels': 3,
 'layers_per_block': 2,
 'mid_block_scale_factor': 1,
 'norm_eps': 1e-05,
 'norm_num_groups': 32,
 'num_class_embeds': None,
 'num_train_timesteps': None,
 'out_channels': 3,
 'resnet_time_scale_shift': 'default',
 'sample_size': 32,
 'time_embedding_type': 'positional',
 'up_block_types': ['UpBlock2D',
  'AttnUpBlock2D',
  'UpBlock2D',
  'UpBlock2D',
  'UpBlock2D',
  'UpBlock2D'],
 'upsample_type': 'conv'}

In [4]:
# print args values
print("Arguments:")
for k, v in vars(args).items():
    print(f"\t{k}: {v}")

Arguments:
	device: cuda:2
	dataset_name: CIFAR10
	dataset_path: /datatmp/users/igeorvasilis/datasets/cifar10
	target_dir: ./dataset
	pr_flip: False
	pr_rotate: False
	labels: [0]
	num_train_timesteps: 1000
	beta_start: 0.0001
	beta_end: 0.02
	beta_schedule: squaredcos_cap_v2
	mask: None
	num_tasks: None
	pipeline: ddim
	num_inference_steps: 50
	image_size: 32
	in_channels: 3
	out_channels: 3
	layers_per_block: 2
	block_out_channels: [128, 128, 256, 256, 512, 512]
	down_block_types: ['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'AttnDownBlock2D', 'DownBlock2D']
	up_block_types: ['UpBlock2D', 'AttnUpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D']
	num_epochs: 200
	train_batch_size: 64
	eval_batch_size: 64
	gradient_accumulation_steps: 1
	learning_rate: 0.0001
	lr_warmup_steps: 5
	pretrained_model_dir: /datatmp/users/igeorvasilis/ddpm-continual-learning/results/cifar10/mask_v1_re/0/epoch-200
	sample_image_epochs: 1
	generate_image_epochs: 200
	n_fake_images: 

# Evaluation

### Frechet Inception Distance (FID)

In [9]:
args.labels = [0, 1]

In [5]:
preprocess = transforms.Compose([ transforms.ToTensor()])

In [6]:
# List all fake images from the corresponding directory 
fake_images_dir = f'{args.pretrained_model_dir}/{args.folder_name}'
fake_images_list = os.listdir(fake_images_dir)
n_fake_images = len(fake_images_list)

# Create eval dataloader
trainset, testset = init_dataset(dataset_name=args.dataset_name, dataset_path=args.dataset_path, 
                                 labels=args.labels, preprocess=transforms.Compose([ transforms.ToTensor()]))

# trainset = torch.utils.data.ConcatDataset([trainset, testset])
eval_dataloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.eval_batch_size, shuffle=True)

# Define FID metric
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(args.device)

# Iterate over all fake images
n_images_to_eval = min(len(fake_images_list), len(eval_dataloader.dataset))
for batch_idx in tqdm(range(0, n_images_to_eval, args.eval_batch_size), 
                      desc=f'Calculating FID for {n_images_to_eval} images...'):
    
    # Get the real images
    real_images, _ = next(iter(eval_dataloader))
    real_images = real_images.to(args.device)

    # Get the fake images
    fake_images = [read_image(f"{fake_images_dir}/{i}") for i in fake_images_list[batch_idx:batch_idx+args.eval_batch_size]]
    fake_images = torch.stack(fake_images).to(args.device) 
    fake_images = fake_images.float() / 255.0
    
    real_images = real_images[:fake_images.shape[0]]

    # Update the FID metric
    fid.update(real_images, real=True)
    fid.update(fake_images, real=False)

# Compute the FID score
print(fid.compute())
fid.reset()

Number of training images: 5000, Number of test images: 1000


Calculating FID for 5000 images...:   0%|          | 0/79 [00:00<?, ?it/s]

tensor(116.8470, device='cuda:2')


### Kernel Inception Distance (KID) 

In [18]:
# List all fake images from the corresponding directory 
fake_images_dir = f'{args.pretrained_model_dir}/{args.folder_name}'
fake_images_list = os.listdir(fake_images_dir)
n_fake_images = len(fake_images_list)

# Create eval dataloader
trainset, testset = init_dataset(dataset_name=args.dataset_name, dataset_path=args.dataset_path, 
                                 labels=args.labels, preprocess=transforms.Compose([ transforms.ToTensor()]))

# trainset = torch.utils.data.ConcatDataset([trainset, testset])
eval_dataloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.eval_batch_size, shuffle=True)

# Define kid metric
subsets = 100
subset_size = 1000
kid = KernelInceptionDistance(
    feature=2048, 
    normalize=True,
    subsets=subsets,
    subset_size=subset_size
).to(args.device)

# Iterate over all fake images
n_images_to_eval = min(len(fake_images_list), len(eval_dataloader.dataset))
for batch_idx in tqdm(range(0, n_images_to_eval, args.eval_batch_size), desc='Calculating kid...'):
    
    # Get the corresponding real images
    real_images, _ = next(iter(eval_dataloader))
    real_images = real_images.to(args.device)

    # Get the fake images
    fake_images = [read_image(f"{fake_images_dir}/{i}") for i in fake_images_list[batch_idx:batch_idx+args.eval_batch_size]]
    fake_images = torch.stack(fake_images).to(args.device) 
    fake_images = fake_images.float() / 255.0

    # Update the kid metric
    kid.update(real_images, real=True)
    kid.update(fake_images, real=False)

# Compute the kid score
print(kid.compute())
kid.reset()

Number of training images: 5000, Number of test images: 1000


Calculating kid...:   0%|          | 0/25 [00:00<?, ?it/s]

(tensor(0.1294, device='cuda:2'), tensor(0.0016, device='cuda:2'))


### Precision Recall 

In [9]:
trainset, _ = init_dataset(dataset_name=args.dataset_name, dataset_path=args.dataset_path, labels=args.labels, preprocess=preprocess)
eval_dataloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.eval_batch_size, shuffle=True)

real_images = torch.cat([image for image, _ in eval_dataloader])

Number of training images: 10000, Number of test images: 2000


In [10]:
# List all fake images from the corresponding directory 
fake_images_dir = f'{args.pretrained_model_dir}/{args.folder_name}'
fake_images_list = os.listdir(fake_images_dir)
fake_images = torch.stack([read_image(f'{fake_images_dir}/{i}') for i in fake_images_list])
fake_images = fake_images.float() / 255.0

In [11]:
n_images_to_eval = min(real_images.shape[0], fake_images.shape[0])
real_images = real_images[:n_images_to_eval]
fake_images = fake_images[:n_images_to_eval]

In [12]:
from improved_precision_recall import IPR 

# Define IPR metric
ipr = IPR(batch_size=8, k=3, num_samples=n_images_to_eval, device=args.device)

# Compute Manifold 
ipr.compute_manifold_ref(real_images)

loading vgg16 for improved precision and recall...done
IPR: resizing (32, 32) to (224, 224)


extracting features of 10000 images: 100%|██████████| 1250/1250 [00:33<00:00, 37.80it/s]


In [13]:
metric = ipr.precision_and_recall(fake_images)
# Print results
print('precision =', metric.precision)
print('recall =', metric.recall)

# r_score = ipr.realism(fake_images)
# print('realism =', r_score)

IPR: resizing (32, 32) to (224, 224)


extracting features of 10000 images: 100%|██████████| 1250/1250 [00:33<00:00, 37.73it/s]




computing precision...: 100%|██████████| 10000/10000 [00:00<00:00, 29049.32it/s]
computing recall...: 100%|██████████| 10000/10000 [00:00<00:00, 28971.21it/s]

precision = 0.8955
recall = 0.4318





### Classifier 

* Cifar10

In [14]:
# Load model directly
from transformers import AutoModelForImageClassification
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import math
from utils import get_preprocess_function
args.device = 'cuda:1'

model = AutoModelForImageClassification.from_pretrained("nateraw/vit-base-patch16-224-cifar10").to(args.device)
preprocess = get_preprocess_function("EvalCifar10")

pred_labels = {label:0 for label in args.labels}
pred_labels['no'] = 0

for b_idx in tqdm(range(0, len(fake_images_list), args.eval_batch_size)):
    
    fake_images_names = fake_images_list[b_idx:b_idx+args.eval_batch_size]
    fake_images = [ read_image(f'{fake_images_dir}/{fi_name}') for fi_name in fake_images_names ]
    fake_images = [ preprocess(image) for image in fake_images ]
    fake_images = torch.stack(fake_images)
    fake_images = fake_images.float() / 255.0

    # Get predictions
    outputs = model(fake_images.to(args.device))
    outputs = torch.softmax(outputs.logits, dim=1)
    predicted = torch.argmax(outputs, dim=1).cpu().numpy()    
    for label in predicted: 
        if label in pred_labels: pred_labels[label] += 1
        else: pred_labels['no'] += 1

# Compute the total count of all labels
total_count = sum(pred_labels.values())
# Compute the frequency of each label
frequencies = {label: count / total_count for label, count in pred_labels.items()}
print("Frequencies:", frequencies)

# Compute the entropy
entropy = -sum([p * math.log(p) for p in frequencies.values()])
print("Entropy:", entropy)


  0%|          | 0/313 [00:00<?, ?it/s]

Frequencies: {0: 0.4222244408945687, 1: 0.41902955271565495, 'no': 0.15874600638977635}
Entropy: 1.020691376559435


* Mnist

In [None]:
# from torchvision import models
# import torch.nn as nn
# import matplotlib.pyplot as plt
# import numpy as np
# from tqdm.notebook import tqdm
# import math

# checkpoint = torch.load('/datatmp/users/igeorvasilis/ddpm-continual-learning/results/mnist/eval_classifier/checkpoint_1n2.pth')

# # Load the actual and the opposite mapping
# map_labels = checkpoint['mapping']
# map_labels_inv = {v: k for k, v in map_labels.items()}

# # Define the pretrained model
# model = models.resnet18(pretrained=True)
# model.fc = nn.Linear(model.fc.in_features, len(args.labels))

# model.load_state_dict(checkpoint['state_dict'])
# model = model.to(args.device)

# # List all fake images from the corresponding directory 
# fake_images_dir = f'{args.pretrained_model_dir}/{args.folder_name}'
# fake_images_list = os.listdir(fake_images_dir)

# n_counts = {label:0 for label in map_labels.keys()}

# for b_idx in tqdm(range(0, len(fake_images_list), args.eval_batch_size)):
    
#     fake_images_names = fake_images_list[b_idx:b_idx+args.eval_batch_size]
#     fake_images = [ read_image(f'{fake_images_dir}/{fi_name}') for fi_name in fake_images_names ]
#     fake_images = torch.stack(fake_images)
#     fake_images = fake_images.float() / 255.0

#     # Get predictions
#     outputs = model(fake_images.to(args.device))
#     _, predicted = torch.max(outputs, 1)
    
#     predicted = [map_labels_inv[label.item()] for label in predicted]
#     for label in predicted: n_counts[label] += 1

# # Compute the total count of all labels
# total_count = sum(n_counts.values())
# # Compute the frequency of each label
# frequencies = {label: count / total_count for label, count in n_counts.items()}
# print("Frequencies:", frequencies)

# # Compute the entropy
# entropy = -sum([p * math.log(p) for p in frequencies.values()])
# print("Entropy:", entropy)

### Visualize

In [None]:
import matplotlib.pyplot as plt
from diffusers.utils import make_image_grid

fig, axes = plt.subplots(1, 2, figsize=(10, 5))

fakes = [transforms.ToPILImage()(image) for image in fake_images[:32]]
grid = make_image_grid(fakes, rows=4, cols=8)
axes[0].imshow(grid); axes[0].set_title(f'Fake images')
axes[0].axis('off')

reals = [transforms.ToPILImage()(image) for image in real_images[:32]]
grid = make_image_grid(reals, rows=4, cols=8)
axes[1].imshow(grid); axes[1].set_title(f'Real images')
axes[1].axis('off')