In [1]:
import os
# Restrict this process to GPU index 7. This cell runs before torch is
# imported in the next cell.
GPU_ENV_KEY = "CUDA_VISIBLE_DEVICES"
os.environ[GPU_ENV_KEY] = "7"
# Echo the value back so the notebook output records which GPU was selected.
print(os.environ[GPU_ENV_KEY])

7


In [2]:
from PIL import Image
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

import lavis
from lavis.models import load_model_and_preprocess

import random 
random.seed(43)

class ImgDataset(Dataset):
    """Dataset over every file found under ``imgRoot``, paired with a prompt file.

    Files are collected recursively and sorted by full path; ``promptFile`` is
    read line by line into ``self.prompts`` (whitespace-stripped).
    NOTE(review): prompts are matched to images purely by position
    (``prompts[i]`` <-> ``paths[i]``) by the code that consumes this dataset;
    nothing here checks that both have the same length.
    """

    def __init__(self, imgRoot, promptFile, vis_processors=None, txt_processors=None):
        """Collect image paths and prompts.

        Args:
            imgRoot: directory walked recursively; every file found is kept.
            promptFile: text file with one prompt per line.
            vis_processors: optional mapping whose ``'eval'`` entry is applied
                to each image in ``__getitem__``. ``None`` disables processing.
            txt_processors: stored as-is; not used by this class.
        """
        self.paths = sorted(
            os.path.join(root, file)
            for root, _dirs, files in os.walk(imgRoot)
            for file in files
        )

        print(f'Read {len(self.paths)} paths.')

        with open(promptFile, 'r') as f:
            self.prompts = [line.strip() for line in f]

        # The signature advertises vis_processors=None, so only subscript it
        # when something was actually passed (previously None raised TypeError).
        self.vis_processors = vis_processors['eval'] if vis_processors else None
        self.txt_processors = txt_processors

    def __len__(self):
        """Number of image files found under ``imgRoot``."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, index)``; the index lets callers look up the prompt."""
        image_path = self.paths[index]
        # convert() produces an independent image, so the file handle can be
        # released right away instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.vis_processors:
            image = self.vis_processors(image)

        return image, index


# This is for querying lots of images

from os import listdir
from os.path import isfile, join
from tqdm import tqdm
from sklearn.metrics import accuracy_score, confusion_matrix
import time

class InstructBLIP():
    """Runs a VQA model over images and reports real-vs-fake accuracy.

    Label convention used by every method: 0 = real, 1 = fake. The three-class
    labels additionally split fakes into 1 = common fake and 2 = uncommon fake.
    The model and its processors are not created here; attach them with
    ``LoadModels`` before calling any ``Query*`` method.
    """

    def __init__(self, name="blip2_vicuna_instruct_textinv", model_type="vicuna7b", is_eval=True, device="cpu") -> None:
        """Initialise empty data and result slots.

        ``name``, ``model_type``, ``is_eval`` and ``device`` are currently
        unused: the ``load_model_and_preprocess`` call that consumed them is
        commented out below.
        """
        print(f'Loading model...')
        #self.model, self.vis_processors, self.txt_processors = load_model_and_preprocess(name, model_type, is_eval, device)
        self.imgs = []
        self.labels = []
        self.label_3class = []
        self.ans_list = []
        self.csv = None

        # QA
        self.question = ""

        # results
        self.acc = None
        self.confusion_mat = None

        self.acc_3class = None
        self.confusion_mat_3class = None

        self.real_acc = None
        self.real_confusion_mat = None
        self.com_acc = None
        self.com_confusion_mat = None
        self.uncom_acc = None
        self.uncom_confusion_mat = None

    def LoadModels(self, model, vis_processors, txt_processors, device):
        """Attach an already-loaded model, its processors and the target device."""
        self.model = model
        self.vis_processors = vis_processors
        self.txt_processors = txt_processors
        self.device = device

    def LoadImages(self, dir, num):
        """Randomly pick ``num`` files directly inside ``dir`` and load them as RGB images.

        Raises:
            ValueError: from ``random.sample`` when ``dir`` holds fewer than
                ``num`` files.
        """
        # Sort first so the random draw does not depend on listdir() order.
        candidates = [join(dir, f) for f in sorted(listdir(dir)) if isfile(join(dir, f))]
        chosen = random.sample(candidates, num)

        raw_img_list = []
        for img_path in tqdm(chosen, desc=f'Loading imgs from {dir}'):
            # convert() returns an independent image, so the file can be closed.
            with Image.open(img_path) as img_file:
                raw_img_list.append(img_file.convert("RGB"))

        return raw_img_list

    def LoadData(self, real_dir, fake_dir, num=1000):
        """Load ``num`` real and ``num`` fake images and build the 0/1 labels."""
        real_imgs = self.LoadImages(real_dir, num)
        fake_imgs = self.LoadImages(fake_dir, num)

        self.imgs = real_imgs + fake_imgs
        self.labels = [0]*len(real_imgs) + [1]*len(fake_imgs)

    def LoadData_batch(self, csv_path):
        """Build the dataset and dataloader used by ``QueryImgs_batch`` from a csv file."""
        self.csv = csv_path
        # NOTE(review): TextInvDataset is not defined in the visible part of this
        # notebook; it has to be in scope before this method is called.
        self.dataset = TextInvDataset(csv=csv_path, vis_processors=self.vis_processors["eval"])
        self.dataloader = DataLoader(dataset=self.dataset, batch_size=8, shuffle=False, num_workers=8)

    def LoadData3Class(self, real_dir, fake_common_dir, fake_uncommon_dir, num=[1000, 500, 500]):
        """Load real / common-fake / uncommon-fake images with binary and 3-class labels.

        ``num`` gives the sample count for each of the three directories, in order.
        """
        # Copy so later edits to self.num can never alter the shared default list.
        self.num = list(num)
        real_imgs = self.LoadImages(real_dir, num[0])
        fake_common_imgs = self.LoadImages(fake_common_dir, num[1])
        fake_uncommon_imgs = self.LoadImages(fake_uncommon_dir, num[2])

        self.imgs = real_imgs + fake_common_imgs + fake_uncommon_imgs
        self.labels = [0]*len(real_imgs) + [1]*(len(fake_common_imgs)+len(fake_uncommon_imgs))
        self.label_3class = [0]*len(real_imgs) + [1]*len(fake_common_imgs) + [2]*len(fake_uncommon_imgs)

    def QueryImgs(self, question, true_string="yes"):
        """Ask ``question`` about every loaded image, one at a time.

        An answer equal to ``true_string`` is recorded as 0 (real), anything
        else as 1 (fake). Results are printed to stdout.

        Returns:
            ``(accuracy, confusion_matrix, ans_list)``.
        """
        self.ans_list = []
        self.question = question

        for img in tqdm(self.imgs, desc=f'Answering'):
            image = self.vis_processors["eval"](img).unsqueeze(0).to(self.device)
            samples = {"image": image, "text_input": question}

            ans = self.model.predict_answers(samples=samples, inference_method="generate")[0]
            self.ans_list.append(0 if ans == true_string else 1)

        self.acc = accuracy_score(self.labels, self.ans_list)
        # labels=[0,1] keeps the matrix 2x2 even if only one class shows up,
        # which PrintConfusion relies on.
        self.confusion_mat = confusion_matrix(self.labels, self.ans_list, labels=[0,1])

        self.PrintResult()

        return self.acc, self.confusion_mat, self.ans_list

    def QueryImgs_batch(self, question, true_string="yes", logPath='log.txt'):
        """Ask ``question`` about every batch from ``self.dataloader`` and log a 3-class report.

        Each batch is unpacked as ``(image, label, is_uncommon)``. Predictions
        and ground-truth entries equal to ``true_string`` map to 0, everything
        else to 1; samples flagged ``is_uncommon`` get three-class label 2.

        Returns:
            ``(accuracy, confusion_matrix, ans_list, labels, label_3class)``
            with the last three as numpy arrays.
        """
        self.labels = []
        self.label_3class = []
        self.ans_list = []
        self.question = question

        for image, label, is_uncommon in tqdm(self.dataloader):

            image = image.to(self.device)

            questions = [self.question] * image.shape[0]
            samples = {"image": image, "text_input": questions}

            ans = self.model.predict_answers(samples=samples, inference_method="generate", answer_list=["yes", "no"])
            self.ans_list += [0 if a == true_string else 1 for a in ans]

            label = [0 if l == true_string else 1 for l in label]
            self.labels += label

            self.label_3class += [2 if is_uncommon[idx] else l for idx, l in enumerate(label)]

        self.acc = accuracy_score(self.labels, self.ans_list)
        self.confusion_mat = confusion_matrix(self.labels, self.ans_list, labels=[0,1])

        self.ans_list = np.array(self.ans_list)
        self.labels = np.array(self.labels)
        self.label_3class = np.array(self.label_3class)

        self.PrintResult(three_class=True, logPath=logPath)

        return self.acc, self.confusion_mat, self.ans_list, self.labels, self.label_3class

    def Query(self, image, question):
        """Ask ``question`` about a single raw image and return the model's answer."""
        image = self.vis_processors["eval"](image).unsqueeze(0).to(self.device)

        samples = {"image": image, "text_input": question}
        ans = self.model.predict_answers(samples=samples, inference_method="generate")[0]
        return ans

    def PrintResult(self, three_class=False, acc=None, confusion_mat=None, ans_list=None, labels=None, label_3class=None, logPath=None):
        """Print the stored results, optionally overriding them first.

        Args:
            three_class: also report real / common-fake / uncommon-fake subsets.
            acc, confusion_mat, ans_list, labels, label_3class: when not
                ``None``, replace the corresponding stored attribute.
            logPath: file to append the report to; stdout when falsy.
        """
        # Compare against None explicitly: numpy arrays have no single truth
        # value, and an accuracy of 0.0 is a legitimate override.
        if acc is not None:
            self.acc = acc
        if confusion_mat is not None:
            self.confusion_mat = confusion_mat
        if ans_list is not None:
            self.ans_list = ans_list
        if labels is not None:
            self.labels = labels
        if label_3class is not None:
            self.label_3class = label_3class

        if logPath:
            # The context manager closes the log even if reporting raises.
            with open(logPath, 'a') as logfile:
                self._write_report(three_class, logfile)
        else:
            # print(file=None) writes to sys.stdout.
            self._write_report(three_class, None)

    def _write_report(self, three_class, logfile):
        """Write the overall (and optionally per-subset) report to ``logfile``."""
        if three_class:
            # Boolean-mask indexing below needs arrays, not plain lists.
            self.ans_list = np.asarray(self.ans_list)
            self.label_3class = np.asarray(self.label_3class)

            print(f'[TIME]      : {time.ctime()}', file=logfile)
            print(f'[Finetuned] : {self.model.finetuned}', file=logfile)
            print(f'[Data csv]  : {self.csv}', file=logfile)
            print(f'[Question]  : {self.question}\n', file=logfile)

            print(f'=== Overall ===', file=logfile)
            print(f'Acc: {self.acc*100:.2f}%', file=logfile)
            self.PrintConfusion(self.confusion_mat, logfile=logfile)
            print('\n', file=logfile)

            self.real_acc, self.real_confusion_mat = self._report_subset('Real images', 0, 0, logfile)
            self.com_acc, self.com_confusion_mat = self._report_subset('Common fake images', 1, 1, logfile)
            self.uncom_acc, self.uncom_confusion_mat = self._report_subset('Uncommon fake images', 2, 1, logfile)
        else:
            print(f'Question: {self.question}\n', file=logfile)
            print(f'Acc: {self.acc*100:.2f}%', file=logfile)
            self.PrintConfusion(self.confusion_mat, logfile=logfile)
            print('\n', file=logfile)

    def _report_subset(self, title, class_id, true_label, logfile):
        """Score and print the predictions whose three-class label is ``class_id``.

        Every sample in the subset shares the binary ground truth ``true_label``.
        Returns ``(accuracy, confusion_matrix)`` for the subset.
        """
        subset_ans = self.ans_list[self.label_3class == class_id]
        subset_label = [true_label] * len(subset_ans)
        subset_acc = accuracy_score(subset_label, subset_ans)
        subset_mat = confusion_matrix(subset_label, subset_ans, labels=[0,1])
        print(f'=== {title} ===', file=logfile)
        print(f'Acc: {subset_acc*100:.2f}%', file=logfile)
        self.PrintConfusion(subset_mat, logfile=logfile)
        print('\n', file=logfile)
        return subset_acc, subset_mat

    def PrintConfusion(self, mat, logfile=None):
        """Print a 2x2 confusion matrix (rows = ground truth, columns = prediction).

        ``logfile`` is any writable text stream; ``None`` prints to stdout.
        """
        padding = ' '
        print(f'        | Pred real | Pred fake |', file=logfile)
        print(f'GT real | {mat[0, 0]:{padding}<{10}}| {mat[0, 1]:{padding}<{11}}|', file=logfile)
        print(f'GT fake | {mat[1, 0]:{padding}<{10}}| {mat[1, 1]:{padding}<{11}}|', file=logfile)

    def MultipleAns(self, ans1, ans2):
        """Combine two answer lists: real (0) only when both answers are 0.

        Returns:
            ``(accuracy, confusion_matrix, final_ans)``; the same values are
            also stored on the instance.
        """
        # Q1: Is this photo common in real world?
        # Q2: Is this photo generated by a model?

        final_ans = [0 if first == 0 and second == 0 else 1 for first, second in zip(ans1, ans2)]

        acc = accuracy_score(self.labels, final_ans)
        confusion_mat = confusion_matrix(self.labels, final_ans, labels=[0,1])
        print(f'Accuracy: {acc*100:.2f}%')
        self.PrintConfusion(confusion_mat)

        self.ans_list = final_ans
        self.acc = acc
        self.confusion_mat = confusion_mat

        return acc, confusion_mat, final_ans
    
def print_combine_result(pretrained_ans, finetuned_ans, label, logPath):
    """Merge two models' 0/1 predictions and append the report to ``logPath``.

    A sample is predicted real (0) only when both models say 0: the merged
    answer is ``ceil((pretrained + finetuned) / 2)``.

    Args:
        pretrained_ans: 0/1 predictions of the first model (array-like).
        finetuned_ans: 0/1 predictions of the second model, same length.
        label: ground-truth labels (array-like). Subsets are reported for
            ``label == 0`` (real) and ``label == 1`` (common fake).
        logPath: file the report is appended to.

    Returns:
        ``(overall_accuracy, overall_confusion_matrix, merged_answers)``.
    """

    def _print_confusion(mat, logfile):
        padding = ' '
        print(f'        | Pred real | Pred fake |', file=logfile)
        print(f'GT real | {mat[0, 0]:{padding}<{10}}| {mat[0, 1]:{padding}<{11}}|', file=logfile)
        print(f'GT fake | {mat[1, 0]:{padding}<{10}}| {mat[1, 1]:{padding}<{11}}|', file=logfile)

    def _report(title, acc, mat, logfile):
        print(f'=== {title} ===', file=logfile)
        print(f'Acc: {acc*100:.2f}%', file=logfile)
        _print_confusion(mat, logfile=logfile)
        print('\n', file=logfile)

    # Accept plain lists too: "+" on lists would concatenate, and "label == 0"
    # on a list is a single False instead of an element-wise mask.
    pretrained_ans = np.asarray(pretrained_ans)
    finetuned_ans = np.asarray(finetuned_ans)
    label = np.asarray(label)

    comb_ans = np.ceil((pretrained_ans + finetuned_ans)/2).astype(np.int64)

    comb_acc = accuracy_score(label, comb_ans)
    comb_confusion_mat = confusion_matrix(label, comb_ans, labels=[0,1])

    # Context manager: the original never closed the log file.
    with open(logPath, 'a') as logfile:
        _report('Overall (Comb)', comb_acc, comb_confusion_mat, logfile)

        real_ans_list = comb_ans[label==0]
        real_label = [0] * len(real_ans_list)
        real_acc = accuracy_score(real_label, real_ans_list)
        real_confusion_mat = confusion_matrix(real_label, real_ans_list, labels=[0,1])
        _report('Real images (Comb)', real_acc, real_confusion_mat, logfile)

        com_ans_list = comb_ans[label==1]
        com_label = [1] * len(com_ans_list)
        com_acc = accuracy_score(com_label, com_ans_list)
        com_confusion_mat = confusion_matrix(com_label, com_ans_list, labels=[0,1])
        _report('Common fake images (Comb)', com_acc, com_confusion_mat, logfile)

    return comb_acc, comb_confusion_mat, comb_ans

  from .autonotebook import tqdm as notebook_tqdm
2023-10-05 22:12:08.807608: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Alternative checkpoints that were tried earlier (kept for reference):
# model, vis_processors, txt_processors = load_model_and_preprocess(name="blip2_vicuna_instruct", model_type="vicuna7b", is_eval=True, device=device)
# model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_vqa", model_type="vqav2", is_eval=True, device=device)
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="albef_vqa",
    model_type="vqav2",
    is_eval=True,
    device=device,
)

print('Load model OK!')

100%|██████████| 4.33G/4.33G [04:22<00:00, 17.7MB/s] 


Load model OK!


In [20]:
# Pick the device the same way the model-loading cell does instead of
# hard-coding 'cuda', so this cell also runs on a CPU-only host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imgRoot = "/eva_data0/iammingggg/textual_inversion/60k_6k_6k/test/1_fake/SDXLInpaint/foreground_rectangle"
promptFile = "/eva_data0/denny/textual_inversion/60k_6k_6k/test/real_prompts.txt"

question_base = "Which word best describe the image?"

print(type(vis_processors))
dataset = ImgDataset(imgRoot=imgRoot, promptFile=promptFile, vis_processors=vis_processors)
# batch_size must stay 1: each image is ranked against the words of its own
# prompt, and only the first sample's candidate list is passed to the model.
dataloader = DataLoader(dataset=dataset, batch_size=1, num_workers=8, pin_memory=True)
print('DL OK')

# For every image, label_list collects the prompt word the model ranks highest.
label_list = []
# Inference only: no gradients are needed.
with torch.no_grad():
    for image, index in tqdm(dataloader):

        image = image.to(device)

        questions = [question_base] * image.shape[0]

        # Candidate answers are the space-separated words of the image's prompt.
        # (An earlier variant appended the candidates to the question text.)
        answer_list = [dataset.prompts[i].split(' ') for i in index]

        samples = {"image": image, "text_input": questions}

        ans = model.predict_answers(samples=samples, answer_list=answer_list[0], inference_method="rank")
        print(ans)
        label_list.append(ans[0])


<class 'dict'>
Read 100 paths.
DL OK


  2%|▏         | 2/100 [00:01<01:01,  1.59it/s]

['room']
['open']


  4%|▍         | 4/100 [00:01<00:29,  3.24it/s]

['elephant']
['cake']


  6%|▌         | 6/100 [00:02<00:20,  4.63it/s]

['phone']
['motorcycle']


  8%|▊         | 8/100 [00:02<00:17,  5.40it/s]

['building']
['here']


 10%|█         | 10/100 [00:02<00:14,  6.32it/s]

['building']
['donut']


 12%|█▏        | 12/100 [00:02<00:12,  7.25it/s]

['horses']
['tennis']


 14%|█▍        | 14/100 [00:03<00:11,  7.40it/s]

['man']
['umbrella']


 16%|█▌        | 16/100 [00:03<00:10,  8.22it/s]

['frisbee']
['statues']


 18%|█▊        | 18/100 [00:03<00:09,  8.41it/s]

['cat']
['white']


 20%|██        | 20/100 [00:03<00:09,  8.54it/s]

['cake']
['rescue']


 22%|██▏       | 22/100 [00:04<00:08,  8.86it/s]

['food']
['bedroom']


 24%|██▍       | 24/100 [00:04<00:08,  8.92it/s]

['jet']
['skateboard']


 26%|██▌       | 26/100 [00:04<00:08,  8.96it/s]

['horse']
['cooking']


 28%|██▊       | 28/100 [00:04<00:08,  8.98it/s]

['man']
['backpack']


 30%|███       | 30/100 [00:04<00:07,  9.03it/s]

['circus']
['frisbee']


 32%|███▏      | 32/100 [00:05<00:07,  9.14it/s]

['flowers']
['skateboard']


 34%|███▍      | 34/100 [00:05<00:07,  8.93it/s]

['tennis']
['fruit']


 36%|███▌      | 36/100 [00:05<00:07,  9.08it/s]

['garbage']
['man']


 38%|███▊      | 38/100 [00:05<00:06,  9.16it/s]

['boy']
['shower']


 40%|████      | 40/100 [00:06<00:06,  9.22it/s]

['group']
['broth']


 42%|████▏     | 42/100 [00:06<00:06,  9.00it/s]

['cars']
['bear']


 44%|████▍     | 44/100 [00:06<00:06,  8.92it/s]

['cat']
['man']


 46%|████▌     | 46/100 [00:06<00:05,  9.02it/s]

['teddy']
['man']


 48%|████▊     | 48/100 [00:06<00:05,  9.17it/s]

['black']
['boy']


 50%|█████     | 50/100 [00:07<00:05,  9.06it/s]

['frisbee']
['clock']


 52%|█████▏    | 52/100 [00:07<00:05,  9.12it/s]

['cow']
['cat']


 54%|█████▍    | 54/100 [00:07<00:05,  9.11it/s]

['man']
['cake']


 56%|█████▌    | 56/100 [00:07<00:04,  9.14it/s]

['restaurant']
['room']


 58%|█████▊    | 58/100 [00:08<00:04,  8.88it/s]

['wall']
['pizza']


 60%|██████    | 60/100 [00:08<00:04,  8.66it/s]

['food']
['tower']


 62%|██████▏   | 62/100 [00:08<00:04,  8.89it/s]

['woman']
['baseball']


 64%|██████▍   | 64/100 [00:08<00:03,  9.07it/s]

['soccer']
['street']


 66%|██████▌   | 66/100 [00:08<00:03,  8.90it/s]

['sign']
['bears']


 68%|██████▊   | 68/100 [00:09<00:03,  8.93it/s]

['dogs']
['plane']


 70%|███████   | 70/100 [00:09<00:03,  8.89it/s]

['apple']
['bear']


 72%|███████▏  | 72/100 [00:09<00:03,  8.94it/s]

['room']
['cabinet']


 74%|███████▍  | 74/100 [00:09<00:02,  8.78it/s]

['water']
['pasta']


 76%|███████▌  | 76/100 [00:10<00:02,  8.89it/s]

['cat']
['bathroom']


 78%|███████▊  | 78/100 [00:10<00:02,  8.71it/s]

['car']
['zebras']


 80%|████████  | 80/100 [00:10<00:02,  8.96it/s]

['baseball']
['people']


 82%|████████▏ | 82/100 [00:10<00:02,  8.90it/s]

['skateboard']
['sad']


 84%|████████▍ | 84/100 [00:10<00:01,  8.87it/s]

['man']
['snow']


 86%|████████▌ | 86/100 [00:11<00:01,  9.11it/s]

['elephant']
['man']


 88%|████████▊ | 88/100 [00:11<00:01,  9.36it/s]

['motorcycle']
['food']


 90%|█████████ | 90/100 [00:11<00:01,  9.47it/s]

['church']
['hair']


 92%|█████████▏| 92/100 [00:11<00:00,  9.58it/s]

['pizza']
['hat']


 94%|█████████▍| 94/100 [00:12<00:00,  9.55it/s]

['man']


 96%|█████████▌| 96/100 [00:12<00:00,  9.32it/s]

['mirror']
['airplane']


 98%|█████████▊| 98/100 [00:12<00:00,  9.17it/s]

['giraffe']
['donut']


100%|██████████| 100/100 [00:12<00:00,  8.92it/s]

['rain']
['kite']


100%|██████████| 100/100 [00:12<00:00,  7.75it/s]


In [23]:
import os
import random
import re

path = 'log/new_prompt.txt'
# open(..., 'w') does not create missing parent directories.
os.makedirs(os.path.dirname(path), exist_ok=True)


def _swap_word(prompt, old_word, new_word):
    """Replace whole-word occurrences of ``old_word`` in ``prompt`` with ``new_word``.

    Plain ``str.replace`` also rewrites longer words that merely contain the
    label (e.g. "man" inside "woman"), and with an empty label it inserts the
    replacement between every character.
    """
    if not old_word:
        return prompt
    pattern = r'(?<!\w)' + re.escape(old_word) + r'(?!\w)'
    # A function replacement keeps backslashes in new_word literal.
    return re.sub(pattern, lambda _match: new_word, prompt)


if len(dataset.prompts) != len(label_list):
    print(f'Warning: {len(dataset.prompts)} prompts but {len(label_list)} labels; extra entries are skipped.')

new_prompts = []
with open(path, 'w') as f:
    # zip stops at the shorter sequence, so a prompt file with more lines than
    # labelled images no longer raises IndexError.
    for prompt, label in zip(dataset.prompts, label_list):
        # Swap the word chosen for this image with a word chosen for a random
        # image (which may be the same word).
        new_prompt = _swap_word(prompt, label, random.sample(label_list, 1)[0])
        new_prompts.append(new_prompt)
        print(new_prompt, file=f)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
