In [2]:
import json
import os

# Generation run being evaluated, and the local COCO root (expects a val2014/ sub-directory).
generated_dir = '../results/2023-01-25_14.40.38.728069'
coco_dir = '../coco'

# generated.json is a list of records with 'image_id', 'caption' and 'generated_image'
# keys, one record per caption (see the preview below).
# JSON is UTF-8 by spec, so decode it explicitly instead of relying on the platform's
# default locale encoding.
with open(os.path.join(generated_dir, 'generated.json'), encoding='utf-8') as f:
    generated_pointers = json.load(f)

In [2]:
# Preview the first 20 records.
generated_pointers[:20]

[{'image_id': 203564,
  'caption': 'A bicycle replica with a clock as the front wheel.',
  'generated_image': '203564_a-bicycle-replica-with-a-clock-as-the-front-wheel.jpg'},
 {'image_id': 179765,
  'caption': 'A black Honda motorcycle parked in front of a garage.',
  'generated_image': '179765_a-black-honda-motorcycle-parked-in-front-of-a-garage.jpg'},
 {'image_id': 322141,
  'caption': 'A room with blue walls and a white sink and door.',
  'generated_image': '322141_a-room-with-blue-walls-and-a-white-sink-and-door.jpg'},
 {'image_id': 16977,
  'caption': 'A car that seems to be parked illegally behind a legally parked car',
  'generated_image': '16977_a-car-that-seems-to-be-parked-illegally-behind-a-legally-parked-car.jpg'},
 {'image_id': 106140,
  'caption': 'A large passenger airplane flying through the air.',
  'generated_image': '106140_a-large-passenger-airplane-flying-through-the-air.jpg'},
 {'image_id': 106140,
  'caption': 'There is a GOL plane taking off in a partly cloudy s

In [3]:
"""
CLIP-score doesn't actually need the real image; I am just testing that we are able to locate it.
We will need this for FID-score.
"""

def get_real_image_path(coco_dir, image_id):
    image_id_str = str(image_id).zfill(12) # pad to 12 characters
    image_base_fname = f"COCO_val2014_{image_id_str}.jpg"
    image_path = os.path.join(coco_dir, 'val2014', image_base_fname)
    return image_path

# Locate the real COCO image behind the first generated sample.
real_image_path = get_real_image_path(
    coco_dir,
    generated_pointers[0]['image_id'],
)
real_image_path

'../coco/val2014/COCO_val2014_000000203564.jpg'

In [10]:
# CLIP score demo

import torch
from torchvision.io import read_image
import torchvision.transforms as T
from torchmetrics.multimodal import CLIPScore

# Run CLIP on the GPU when one is available, otherwise fall back to the CPU.
# The metric (and its CLIP model) must be moved to this device explicitly;
# creating a torch.device on its own does not move any computation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)

for i, pointer in enumerate(generated_pointers):
    generated_image_path = os.path.join(generated_dir, pointer['generated_image'])
    # read_image returns a (C, H, W) uint8 tensor; put it on the metric's device.
    generated_image = read_image(generated_image_path).to(device)

    metric.update(generated_image, pointer['caption'])

    if i % 100 == 0:
        print(f"starting {i}")


starting 0
starting 100
starting 200
starting 300
starting 400
starting 500
starting 600
starting 700
starting 800
starting 900
starting 1000
starting 1100
starting 1200
starting 1300
starting 1400
starting 1500
starting 1600
starting 1700
starting 1800
starting 1900
starting 2000
starting 2100
starting 2200
starting 2300
starting 2400
starting 2500
starting 2600
starting 2700
starting 2800
starting 2900
starting 3000
starting 3100
starting 3200
starting 3300
starting 3400
starting 3500
starting 3600
starting 3700
starting 3800
starting 3900
starting 4000
starting 4100
starting 4200
starting 4300
starting 4400
starting 4500
starting 4600
starting 4700
starting 4800
starting 4900
starting 5000
starting 5100
starting 5200
starting 5300
starting 5400
starting 5500
starting 5600
starting 5700
starting 5800
starting 5900
starting 6000
starting 6100
starting 6200
starting 6300
starting 6400
starting 6500
starting 6600
starting 6700
starting 6800
starting 6900
starting 7000
starting 7100
star

In [11]:
# Aggregate CLIP score over every (generated image, caption) pair passed to metric.update in the loop above.
metric.compute()

tensor(26.1764)

In [19]:
def convert_to_3channel(img):
    """Return ``img`` (a channels-first ``(C, H, W)`` tensor) with exactly 3 channels.

    ``read_image`` keeps the file's native channel count by default, so grayscale
    files decode to 1 channel and files with transparency to 2 (gray + alpha) or
    4 (RGB + alpha). The FID/CLIP metrics used here need 3-channel input.

    - 1 channel: broadcast to 3 channels with ``expand`` (a view, no copy).
    - 2 channels: drop alpha, then broadcast the gray channel.
    - 3 channels: returned unchanged (same object).
    - 4 channels: drop the alpha channel.

    Raises:
        ValueError: for any other channel count.
    """
    channels = img.shape[0]
    if channels == 1:
        return img.expand(3, *img.shape[1:])
    if channels == 2:
        return img[:1].expand(3, *img.shape[1:])
    if channels == 3:
        return img
    if channels == 4:
        return img[:3]
    raise ValueError(f"expected 1-4 channels, got tensor of shape {tuple(img.shape)}")

# prep for FID score demo:
# decode every real/generated pair once, normalized to 3-channel tensors.
# NOTE(review): this keeps every decoded image in memory at the same time; for large
# runs consider streaming images straight into the FID metric instead.

generated_images = []
real_images = []

for pointer in generated_pointers:
    real_path = get_real_image_path(coco_dir, pointer['image_id'])
    real_images.append(convert_to_3channel(read_image(real_path)))

    generated_path = os.path.join(generated_dir, pointer['generated_image'])
    generated_images.append(convert_to_3channel(read_image(generated_path)))


    

In [22]:
# FID score demo

import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Run the Inception feature extractor on the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fid = FrechetInceptionDistance(feature=2048).to(device)

# We run the updates below with batch size 1. A larger batch would likely be faster,
# but the real images have various resolutions and aspect ratios, and it's convenient
# to let torchmetrics take care of resizing each image.
#
# NOTE(review): generated_pointers holds one record per caption and image_ids repeat
# (e.g. 106140 appears twice), so each real image enters the real distribution once
# per caption. Confirm whether the real set should be deduplicated by image_id.

# The two lists are consumed pairwise; zip would silently drop a mismatched tail.
assert len(generated_images) == len(real_images), "generated/real image lists must pair up"

for i, (generated, real) in enumerate(zip(generated_images, real_images)):
    if i % 100 == 0:
        print(f"fid score iteration {i}: generated.shape={tuple(generated.shape)}, real.shape={tuple(real.shape)}")

    # Each image is (C, H, W); unsqueeze adds the leading batch dimension.
    fid.update(generated.unsqueeze(0).to(device), real=False)
    fid.update(real.unsqueeze(0).to(device), real=True)

fid_score = fid.compute()
print(f"fid score: {fid_score}")

generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 480, 640])
fid score iteration 100
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 640, 480])
fid score iteration 200
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 375, 500])
fid score iteration 300
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 442, 640])
fid score iteration 400
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 425, 640])
fid score iteration 500
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 443, 640])
fid score iteration 600
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 427, 640])
fid score iteration 700
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 361, 640])
fid score iteration 800
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 424, 640])
fid score iteration 900
generated.shape: torch.Size([3, 512, 512])
real.shape: 

generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 426, 640])
fid score iteration 7900
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 426, 640])
fid score iteration 8000
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 427, 640])
fid score iteration 8100
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 425, 640])
fid score iteration 8200
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 480, 640])
fid score iteration 8300
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 612, 612])
fid score iteration 8400
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 426, 640])
fid score iteration 8500
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 480, 640])
fid score iteration 8600
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 640, 480])
fid score iteration 8700
generated.shape: torch.Size([3, 512, 512])
rea

generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 484, 640])
fid score iteration 15600
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 375, 500])
fid score iteration 15700
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 329, 640])
fid score iteration 15800
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 426, 640])
fid score iteration 15900
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 640, 360])
fid score iteration 16000
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 427, 640])
fid score iteration 16100
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 427, 640])
fid score iteration 16200
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 428, 640])
fid score iteration 16300
generated.shape: torch.Size([3, 512, 512])
real.shape: torch.Size([3, 480, 640])
fid score iteration 16400
generated.shape: torch.Size([3, 512, 

In [9]:
import torch

# Scratch check: expanding a single-channel tensor gives a 3-channel view of it.
img = torch.zeros((1, 200, 200)).expand(3, 200, 200)

In [11]:
# Channel count after the expand above.
img.size(0)

3