In [1]:
# CALL FROM TERMINAL

# cd codes/utils/
# python img2feat_decoded.py --gpu 0 --subject subj01 --method cvpr

# IMPORTS

In [2]:
import argparse, os, sys, glob
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import torchvision
from torchvision import transforms
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# SETUP

In [18]:
gpu = 0
subject = 'subj01'
method = 'cvpr'

# Parameters
# Select a CUDA device only when CUDA is actually available. Calling
# torch.cuda.set_device unconditionally crashes on CPU-only machines, which
# defeated the CPU fallback this cell is meant to provide.
if torch.cuda.is_available():
    torch.cuda.set_device(gpu)
    device = torch.device(f"cuda:{gpu}")
else:
    device = torch.device("cpu")

# SELECT INPUT IMAGES AND OUTPUT DIRECTORY

In [26]:
image_model = 'netdecoded'     # Options : 'paper', 'lasso', 'net', 'netdecoded'

# Sub-directory of ../../decoded/ that holds the generated samples for each option.
# The output directory under ../../identification/ is named after the option itself.
_DECODED_SUBDIRS = {
    'paper': 'all_generated_images',
    'lasso': 'lasso_images',
    'net': 'net_images',
    'netdecoded': 'netdecoded_images',
}
if image_model not in _DECODED_SUBDIRS:
    # Without this check an unknown option left imglist/outdir undefined and
    # only failed later with a NameError.
    raise ValueError(
        f"Unknown image_model {image_model!r}; expected one of {sorted(_DECODED_SUBDIRS)}"
    )

img_pattern = f'../../decoded/{_DECODED_SUBDIRS[image_model]}/image-{method}/{subject}/samples/*'
imglist = sorted(glob.glob(img_pattern))
outdir = f'../../identification/{image_model}/{method}/{subject}/'

if not imglist:
    # Paths are relative to the working directory (see the terminal command at the
    # top); an empty glob would otherwise make the extraction loop a silent no-op.
    raise FileNotFoundError(f"No images matched {img_pattern!r}")

# Keep only the first 3000 files: per the original note, 500 test images x 6 samples each.
imglist = imglist[:3000]

os.makedirs(outdir, exist_ok=True)

# LOAD MODELS

In [20]:
# Build the three feature extractors used for identification.

# --- Inception V3 -----------------------------------------------------------
# ImageNet-normalised 299x299 crop. The same transform is also fed to AlexNet
# in the extraction loop.
preprocess = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# eval() and to() both return the module itself, so they can be chained.
model_inception = torchvision.models.inception_v3(pretrained=True).eval().to(device)
model_inception = torchvision.models.feature_extraction.create_feature_extractor(
    model_inception, {'flatten': 'flatten'}
)

# --- AlexNet ----------------------------------------------------------------
# Each graph node is returned under its own name.
model_alexnet = torchvision.models.alexnet(pretrained=True).eval().to(device)
model_alexnet = torchvision.models.feature_extraction.create_feature_extractor(
    model_alexnet,
    {node: node for node in ('features.5', 'features.12', 'classifier.5')},
)

# --- CLIP (ViT-L/14) --------------------------------------------------------
# Hugging Face from_pretrained already returns the model in eval mode.
model_clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
processor_clip = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

In [27]:
# Sanity check: number of image files selected for feature extraction.
num_images = len(imglist)
num_images

3000

In [28]:
print(f"Now processing start for : {method}")
for img in tqdm(imglist):
    imgname = img.split('/')[-1].split('.')[0]
    #print(img)
    image = Image.open(img)

    # Inception
    input_tensor = preprocess(image)
    input_batch = input_tensor.unsqueeze(0)
    input_batch = input_batch.to(device)
    with torch.no_grad():
        feat = model_inception(input_batch)
    feat_inception = feat['flatten'].cpu().detach().numpy().copy()    

    # AlexNet
    with torch.no_grad():
        feat = model_alexnet(input_batch)
    feat_alexnet5 = feat['features.5'].flatten().cpu().detach().numpy().copy()    
    feat_alexnet12 = feat['features.12'].flatten().cpu().detach().numpy().copy()    
    feat_alexnet18 = feat['classifier.5'].flatten().cpu().detach().numpy().copy()    

    # CLIP
    inputs = processor_clip(text="",images=image, return_tensors="pt").to(device)
    outputs = model_clip(**inputs,output_hidden_states=True)
    feat_clip = outputs.image_embeds.cpu().detach().numpy().copy()
    feat_clip_h6 = outputs.vision_model_output.hidden_states[6].flatten().cpu().detach().numpy().copy()
    feat_clip_h12 = outputs.vision_model_output.hidden_states[12].flatten().cpu().detach().numpy().copy()

    # SAVE
    fname = f'{outdir}/{imgname}'
    np.save(f'{fname}_inception.npy',feat_inception)
    np.save(f'{fname}_alexnet5.npy',feat_alexnet5)
    np.save(f'{fname}_alexnet12.npy',feat_alexnet12)
    np.save(f'{fname}_alexnet18.npy',feat_alexnet18)
    #np.save(f'{fname}_clip.npy',feat_clip)
    #np.save(f'{fname}_clip_h6.npy',feat_clip_h6)
    #np.save(f'{fname}_clip_h12.npy',feat_clip_h12)
    
    

Now processing start for : cvpr


100%|███████████████████████████████████████| 3000/3000 [03:03<00:00, 16.32it/s]
