In [1]:
import argparse
import yaml
from datasets import get_dataset
from tqdm import tqdm
from models import get_model
import torch
from torch.autograd import Variable
import os
from torchvision.utils import save_image
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

def get_config(config):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config: Path to the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document
        (a dict when the file holds a top-level mapping).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(config, 'r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# --- Configuration -----------------------------------------------------------
# Per the original author's note, the config is mainly needed here for the
# path to the test data. The batch size is overridden to 1.
config = get_config('./configs/ucla_gan256.yaml')
config['batch_size'] = 1

# --- Image-quality metrics (moved to the GPU) --------------------------------
lpips = LearnedPerceptualImagePatchSimilarity(net_type='alex', normalize=True).cuda()
fid = FrechetInceptionDistance(feature=2048, normalize=True).cuda()
inception = InceptionScore(normalize=True).cuda()

# --- Model --------------------------------------------------------------------
# Build both networks, then restore the generator weights from the checkpoint.
discriminator, generator = get_model(config)

# NOTE(review): the generator is wrapped in DataParallel before loading the
# state dict, presumably so its parameter names match the checkpoint keys
# -- confirm against the training script.
generator = torch.nn.DataParallel(generator, device_ids=[0])

# Single source of truth for the run directory (checkpoint in, images out).
run_dir = os.path.join('./ckpts', 'UCLA WGANGP (full conditional)')

# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoint was saved from; load_state_dict copies the values into the
# generator's existing parameters.
checkpoint = torch.load(os.path.join(run_dir, 'model.pt'), map_location='cpu')
generator.load_state_dict(checkpoint['g_model_state_dict'])

# --- Data ---------------------------------------------------------------------
train_loader, test_loader = get_dataset(config)

# --- Output folder --------------------------------------------------------------
# exist_ok=True makes the cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
save_folder = os.path.join(run_dir, 'test_set')
os.makedirs(save_folder, exist_ok=True)

# Sentence appended to the prompt when the label at the given index equals 1.
# Index 0 (protest / not a protest) is handled separately below because it is
# the only label that also contributes a sentence when it is not set.
ATTRIBUTE_SENTENCES = (
    (1, "The protest is violent. "),
    (2, "A protester is holding a visual sign. "),
    (3, "The sign contains a photo of a person. "),
    (4, "There is fire or smoke in the scene. "),
    (5, "Police or troops are present in the scene. "),
    (6, "There are children in the scene. "),
    (7, "There are roughly more than 20 people in the scene. "),
    (8, "There are roughly more than 100 people in the scene. "),
    (9, "There are flags in the scene. "),
    (10, "The scene is at night. "),
    (11, "There are one or more people shouting. "),
)

# Preview: build and print the text prompt for the first test sample only.
for sample in tqdm(test_loader):
    fname = sample['fname'][0]

    # Take the first label row of the batch as a CPU float tensor. The labels
    # are only inspected from Python, so there is no reason to send them to the
    # GPU and back; the deprecated Variable wrapper is not needed either.
    labels = sample['label'][0].float().cpu()

    # The image is moved to the GPU as a float tensor; it is not used further
    # inside this loop.
    image = sample['image'].float().cuda()
    print(labels)

    # Plain Python floats make the per-label comparisons cheap and explicit.
    flags = labels.tolist()

    prompt = (
        "This is an image of a protest. "
        if flags[0] == 1
        else "This is not an image of a protest. "
    )
    prompt += "".join(
        sentence for index, sentence in ATTRIBUTE_SENTENCES if flags[index] == 1
    )

    # Defined but left empty and unused within this cell.
    negative_prompt = ""

    print(prompt)

    # Stop after the first sample: this cell only previews one prompt.
    break

  0%|          | 0/8153 [00:00<?, ?it/s]

tensor([1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0.])
This is an image of a protest. A protester is holding a visual sign. There are roughly more than 20 people in the scene. There are roughly more than 100 people in the scene. The scene is at night. 



