In [1]:
# Notebook setup: import path, GPU selection, and all imports used below.
import sys
# Make the project checkout importable (presumably where `dataset` and
# `model.clip_unet` live -- confirm against the repository layout).
sys.path.append('/data2/junhong/proj/text_guide_attack')
import os
# Expose only GPU index 6 to this process. This is set before `import torch`
# so the restriction is in place before any CUDA initialisation happens.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import torch
import pyarrow.parquet as pq
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset, DataLoader
# NOTE(review): `MS_COCO` imported here is shadowed by a class of the same
# name defined later in this notebook; only `collate_fn` is effectively used.
from dataset import MS_COCO, collate_fn
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import argparse
from tqdm import tqdm
from model.clip_unet import CLIP_encoder_decoder

# Run on the (single visible) GPU when CUDA is available, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _str2bool(value):
    """Turn a command-line token into a real ``bool``.

    ``type=bool`` (or ``type=str`` with a bool default) is a classic argparse
    trap: any non-empty string, including ``"False"``, is truthy. This
    converter accepts the usual spellings and rejects everything else.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling
            (argparse reports this as an "invalid value" usage error).
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def args(argparse, argv=None):
    """Register the experiment options on a parser and parse them.

    Args:
        argparse: an ``argparse.ArgumentParser`` instance (the parameter name
            shadows the ``argparse`` module inside this function).
        argv: optional list of argument strings. ``None`` (the default) keeps
            the original behaviour of reading ``sys.argv[1:]``. Pass ``[]``
            inside a notebook, where ``sys.argv`` holds kernel arguments.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # Boolean switches go through _str2bool so "--shuffle False" is really False.
    argparse.add_argument("--train", type=_str2bool, default=True)
    argparse.add_argument("--clip_model_path", type=str, default="/data2/ModelWarehouse/clip-vit-base-patch32")
    argparse.add_argument("--image_path", type=str, default="/data2/zhiyu/data/coco/images/train2017")
    argparse.add_argument("--data_path", type=str,
                      default="/data2/junhong/proj/text_guide_attack/data/mscoco_exist.parquet")
    argparse.add_argument("--epoch", type=int, default=20)
    argparse.add_argument("--batch_size", type=int, default=256)
    argparse.add_argument("--shuffle", type=_str2bool, default=True)
    argparse.add_argument("--mode", type=str, default="test")
    argparse.add_argument("--model_path", type=str, default="save_model/model_epoch_20.pth")
    return argparse.parse_args(argv)

In [2]:
# The argparse-based configuration (`args(...)`) is not used in this notebook;
# the CLIP checkpoint directory is given inline instead.
clip_checkpoint_dir = "/data2/ModelWarehouse/clip-vit-base-patch32"

# Pre-/post-processing (tokenizer + image transforms) for the checkpoint.
processor = CLIPProcessor.from_pretrained(clip_checkpoint_dir)

# The CLIP model itself, moved to the selected device and put in eval mode.
encoder = CLIPModel.from_pretrained(clip_checkpoint_dir)
encoder = encoder.to(device).eval()

In [3]:
class MS_COCO(Dataset):
    """Image/caption pairs described by a parquet table.

    The table is read with pyarrow and converted to a pandas DataFrame. Each
    row must provide a ``TEXT`` column (the caption entry) and a ``URL``
    column (the image file name, joined onto ``image_path``).
    """

    def __init__(self, data_path, image_path):
        """
        Args:
            data_path: path of the parquet file listing the samples.
            image_path: directory that the ``URL`` file names are relative to.
        """
        # The previous version also did `self.args = args`, which captured the
        # unrelated module-level `args` function; nothing read it, so it is gone.
        self.data_list = pq.read_table(data_path).to_pandas()
        self.image_path = image_path

    def __len__(self):
        """Number of rows (samples) in the table."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(image, text)`` for the row at position ``idx``.

        Rows are addressed by position (``.iloc``), so the lookup still works
        when the DataFrame index is not a plain ``0..n-1`` range.
        """
        text = self.data_list['TEXT'].iloc[idx]
        image_name = self.data_list['URL'].iloc[idx]

        img_path = os.path.join(self.image_path, image_name)
        # `convert` yields a fully loaded copy, so the file can be closed
        # as soon as the context manager exits.
        with Image.open(img_path) as raw_image:
            img = raw_image.convert('RGB')

        # You can perform additional transformations on the image here if needed

        return img, text
# Locations of the caption table and image folder, plus the loader batch size.
image_path = "/data2/zhiyu/data/coco/images/train2017"
data_path = "/data2/junhong/proj/text_guide_attack/data/mscoco_exist.parquet"
batch_size = 256

# Build the dataset and a shuffling loader that batches via the project's collate_fn.
train_dataset = MS_COCO(data_path=data_path, image_path=image_path)
train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [4]:
def process_text(texts):
    """Return a list holding the first element of every entry in ``texts``.

    Presumably each entry is the caption collection for one image and only
    the first caption is kept -- confirm against ``collate_fn``.
    """
    return [entry[0] for entry in texts]
# Run only the first batch through the CLIP processor, then stop iterating.
for (images, texts) in tqdm(train_data_loader, position=0, desc="Testing"):
    inputs = processor(
        text=process_text(texts),
        images=images,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)
    break

Testing:   0%|                                          | 0/463 [00:03<?, ?it/s]


In [5]:
# Inspection-only forward pass: run it without autograd so the displayed
# output (which the notebook keeps alive in Out[...]) does not pin the whole
# computation graph for the batch in GPU memory.
with torch.no_grad():
    clip_output = encoder(**inputs)
clip_output

CLIPOutput(loss=None, logits_per_image=tensor([[28.3239, 14.1571, 18.9245,  ..., 11.2521,  8.8391, 13.8871],
        [18.4652, 30.3185,  9.9260,  ..., 12.1234, 18.7343, 12.8444],
        [17.2572, 12.5573, 28.3465,  ..., 11.9973, 10.2950, 11.0302],
        ...,
        [18.9985, 12.8682,  8.7138,  ..., 33.1035, 11.9111, 13.3936],
        [17.2938, 14.5991, 14.4264,  ..., 10.5532, 37.3439,  9.2228],
        [16.4232, 15.4139, 13.6517,  ..., 14.3283, 10.8179, 31.8879]],
       device='cuda:0', grad_fn=<TBackward0>), logits_per_text=tensor([[28.3239, 18.4652, 17.2572,  ..., 18.9985, 17.2938, 16.4232],
        [14.1571, 30.3185, 12.5573,  ..., 12.8682, 14.5991, 15.4139],
        [18.9245,  9.9260, 28.3465,  ...,  8.7138, 14.4264, 13.6517],
        ...,
        [11.2521, 12.1234, 11.9973,  ..., 33.1035, 10.5532, 14.3283],
        [ 8.8391, 18.7343, 10.2950,  ..., 11.9111, 37.3439, 10.8179],
        [13.8871, 12.8444, 11.0302,  ..., 13.3936,  9.2228, 31.8879]],
       device='cuda:0', grad_f

In [6]:
def process_text(texts):
    """Collect element 0 of each item in ``texts`` into a new list.

    Note: a function with the same name and behaviour is already defined in
    an earlier cell of this notebook; this definition replaces it when run.
    """
    return list(map(lambda item: item[0], texts))

def calculate_cos(image_emb, output_encode):
    """Sum of row-wise cosine similarities between two embedding batches.

    The similarity is taken along ``dim=1`` for each pair of matching rows,
    and the per-row values are added into a single scalar tensor.
    """
    per_pair = F.cosine_similarity(image_emb, output_encode, dim=1)
    return torch.sum(per_pair)
# Mean image/text cosine similarity over the first `eval_batch` batches.
sim = 0.0          # running sum of per-pair cosine similarities
eval_batch = 30    # number of batches to evaluate before stopping
count = 0          # batches processed so far
num_samples = 0    # image/text pairs processed so far
with torch.no_grad():
    for (images, texts) in tqdm(train_data_loader, desc="Testing", position=0):
        inputs = processor(text=process_text(texts), images=images, return_tensors="pt", padding=True).to(device)
        output = encoder(**inputs)
        img_embeds = output.image_embeds
        text_embeds = output.text_embeds
        sim += calculate_cos(img_embeds, text_embeds)
        count += 1
        # Divide by the number of pairs actually seen rather than
        # count * batch_size, so a short (e.g. final) batch does not
        # skew the running mean.
        num_samples += img_embeds.shape[0]
        print(f"count_{count}:", sim / num_samples)
        if count == eval_batch:
            break




Testing:   0%|                                  | 1/463 [00:03<28:37,  3.72s/it]

count_1: tensor(0.3039, device='cuda:0')


Testing:   0%|▏                                 | 2/463 [00:07<29:26,  3.83s/it]

count_2: tensor(0.3043, device='cuda:0')


Testing:   1%|▏                                 | 3/463 [00:11<29:09,  3.80s/it]

count_3: tensor(0.3052, device='cuda:0')


Testing:   1%|▎                                 | 4/463 [00:15<28:47,  3.76s/it]

count_4: tensor(0.3050, device='cuda:0')


Testing:   1%|▎                                 | 5/463 [00:18<28:36,  3.75s/it]

count_5: tensor(0.3035, device='cuda:0')


Testing:   1%|▍                                 | 6/463 [00:22<28:56,  3.80s/it]

count_6: tensor(0.3034, device='cuda:0')


Testing:   2%|▌                                 | 7/463 [00:26<28:05,  3.70s/it]

count_7: tensor(0.3036, device='cuda:0')


Testing:   2%|▌                                 | 8/463 [00:29<27:25,  3.62s/it]

count_8: tensor(0.3036, device='cuda:0')


Testing:   2%|▋                                 | 9/463 [00:33<27:27,  3.63s/it]

count_9: tensor(0.3037, device='cuda:0')


Testing:   2%|▋                                | 10/463 [00:36<27:24,  3.63s/it]

count_10: tensor(0.3037, device='cuda:0')


Testing:   2%|▊                                | 11/463 [00:40<27:42,  3.68s/it]

count_11: tensor(0.3037, device='cuda:0')


Testing:   3%|▊                                | 12/463 [00:44<28:05,  3.74s/it]

count_12: tensor(0.3041, device='cuda:0')


Testing:   3%|▉                                | 13/463 [00:48<27:55,  3.72s/it]

count_13: tensor(0.3041, device='cuda:0')
count_14: 

Testing:   3%|▉                                | 14/463 [00:51<27:35,  3.69s/it]

tensor(0.3039, device='cuda:0')


Testing:   3%|█                                | 15/463 [00:55<27:21,  3.67s/it]

count_15: tensor(0.3038, device='cuda:0')


Testing:   3%|█▏                               | 16/463 [00:59<27:55,  3.75s/it]

count_16: tensor(0.3040, device='cuda:0')


Testing:   4%|█▏                               | 17/463 [01:03<27:36,  3.71s/it]

count_17: tensor(0.3040, device='cuda:0')


Testing:   4%|█▎                               | 18/463 [01:06<27:10,  3.66s/it]

count_18: tensor(0.3041, device='cuda:0')


Testing:   4%|█▎                               | 19/463 [01:10<26:57,  3.64s/it]

count_19: tensor(0.3040, device='cuda:0')


Testing:   4%|█▍                               | 20/463 [01:13<26:25,  3.58s/it]

count_20: tensor(0.3041, device='cuda:0')
count_21: 

Testing:   5%|█▍                               | 21/463 [01:17<26:17,  3.57s/it]

tensor(0.3039, device='cuda:0')


Testing:   5%|█▌                               | 22/463 [01:20<26:11,  3.56s/it]

count_22: tensor(0.3040, device='cuda:0')


Testing:   5%|█▋                               | 23/463 [01:24<26:02,  3.55s/it]

count_23: tensor(0.3039, device='cuda:0')


Testing:   5%|█▋                               | 24/463 [01:27<26:19,  3.60s/it]

count_24: tensor(0.3040, device='cuda:0')


Testing:   5%|█▊                               | 25/463 [01:31<26:20,  3.61s/it]

count_25: tensor(0.3040, device='cuda:0')


Testing:   6%|█▊                               | 26/463 [01:35<26:31,  3.64s/it]

count_26: tensor(0.3040, device='cuda:0')
count_27: 

Testing:   6%|█▉                               | 27/463 [01:39<26:32,  3.65s/it]

tensor(0.3039, device='cuda:0')


Testing:   6%|█▉                               | 28/463 [01:42<26:23,  3.64s/it]

count_28: tensor(0.3038, device='cuda:0')


Testing:   6%|██                               | 29/463 [01:46<26:41,  3.69s/it]

count_29: tensor(0.3038, device='cuda:0')


Testing:   6%|██                               | 29/463 [01:49<27:25,  3.79s/it]

count_30: tensor(0.3038, device='cuda:0')



