In [1]:
import numpy as np
import seaborn as sns
import os
import pandas as pd
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import json
import cv2
import torchvision

  from .autonotebook import tqdm as notebook_tqdm


# Preparation

In [7]:
# --- Input locations (image-classification-sep2022 training round) ---
MODEL_FILEDIR = '/scratch/data/TrojAI/image-classification-sep2022-train/models/'
METADATA_FILEPATH = '/scratch/data/TrojAI/image-classification-sep2022-train/METADATA.csv'
EXAMPLE_SRC_DIR = '/scratch/data/TrojAI/image-classification-sep2022-train/image-classification-sep2022-example-source-dataset'

# --- Output location (presumably where extracted artifacts are written; not used in the cells shown here) ---
OUTPUT_FILEDIR = '/scratch/jialin/image-classification-sep2022/projects/weight_analysis/extracted_source/'

# Architecture labels, prefixed with 'classification:' the same way they appear
# in the metadata table's model_architecture column.
MODEL_ARCH = [
    'classification:{}'.format(arch)
    for arch in ('resnet50', 'vit_base_patch32_224', 'mobilenet_v2')
]

# NOTE(review): presumably the number of models in the training round -- confirm.
NUM_MODEL = 288

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def num_to_model_id(num):
    """Return the model directory name for a non-negative integer model index.

    The index is zero-padded to at least eight digits and prefixed with
    ``'id-'``, e.g. ``0 -> 'id-00000000'`` and ``287 -> 'id-00000287'``.
    Indices with more than eight digits keep all of their digits instead of
    silently losing the leading one (which would alias onto another model id).

    Args:
        num: Non-negative integer model index.

    Returns:
        The id string, e.g. ``'id-00000042'``.

    Raises:
        ValueError: If ``num`` is negative.
    """
    if num < 0:
        raise ValueError('model number must be non-negative, got {}'.format(num))
    return 'id-{:08d}'.format(num)

# Load Metadata

In [3]:
# Read the round's metadata table from disk; the preview below shows one row
# per model (model_name, architecture, trigger settings, ...).
METADATA = pd.read_csv(filepath_or_buffer=METADATA_FILEPATH)

# Preview the first five rows as the cell output.
METADATA.head(n=5)

Unnamed: 0,model_name,converged,nonconverged_reason,master_seed,task_type_level,task_type,source_dataset_level,source_dataset,model_architecture,model_architecture_level,...,trigger_2.trigger_size_restriction_option,trigger_2.polygon_texture_augmentation_level,trigger_2.polygon_texture_augmentation,trigger_2.size_percentage_of_foreground_min,trigger_2.size_percentage_of_foreground_max,trigger_2.min_area,trigger_2.spatial_quadrant_level,trigger_2.spatial_quadrant,trigger_2.options_level,trigger_2.options
0,id-00000000,True,,354103127,0,classification,0,cityscapes,classification:resnet50,3,...,,,,,,,,,,
1,id-00000001,True,,2049821827,0,classification,0,cityscapes,classification:vit_base_patch32_224,6,...,,,,,,,,,,
2,id-00000002,True,,74361305,0,classification,0,cityscapes,classification:mobilenet_v2,5,...,,,,,,,,,,
3,id-00000003,True,,197593124,0,classification,0,cityscapes,classification:resnet50,3,...,,,,,,,,,,
4,id-00000004,True,,69550395,0,classification,0,cityscapes,classification:resnet50,3,...,,,,,,,,,,


In [4]:
# Use model number 0 as the worked example and derive the two locations that
# are read from its directory: the serialized model and its clean example images.
model_dir = os.path.join(MODEL_FILEDIR, num_to_model_id(0))
model_filepath, clean_images_dir = (
    os.path.join(model_dir, entry) for entry in ('model.pt', 'clean-example-data')
)

In [53]:
def get_output_from_example_images(model_filepath, image_filedir, device, loss_fn):
    """Run a saved model on a folder of example images; collect logits and input gradients.

    Every ``*.jpg`` file directly inside ``image_filedir`` is read with OpenCV,
    converted from BGR to RGB, permuted to channel-first layout, and converted
    to a float tensor. The target of each image is read from the JSON file that
    shares its base name (``<id>.jpg`` -> ``<id>.json``). All images are stacked
    into one batch, passed through the model in eval mode, and the loss is
    back-propagated to the input pixels.

    Assumptions (not checked here): every image has three channels and all
    images have the same height and width, otherwise the BGR->RGB conversion or
    ``torch.stack`` fails; each JSON file holds a single integer class index.

    Args:
        model_filepath: Path to a model object serialized with ``torch.save``.
        image_filedir: Directory containing ``<int>.jpg`` / ``<int>.json`` pairs.
        device: ``torch.device`` on which the forward/backward pass is run.
        loss_fn: Callable invoked as ``loss_fn(logits, targets)`` that returns
            a scalar tensor supporting ``.backward()``.

    Returns:
        dict with keys:
            ``'image_id'``: list of int ids parsed from the file names (sorted
                by file name, i.e. lexicographically);
            ``'logits'``: CPU tensor of model outputs, detached from autograd;
            ``'targets'``: list of the values loaded from the JSON files;
            ``'grad_images'``: CPU tensor with the gradient of the loss with
                respect to the input batch (same shape as the batch).

    Raises:
        OSError: If an image file cannot be decoded by OpenCV.
    """
    image_filenames = sorted(
        name for name in os.listdir(image_filedir) if name.endswith('.jpg')
    )

    images, targets, ids = [], [], []
    for image_filename in image_filenames:
        # splitext only strips the trailing extension; str.replace would also
        # rewrite a '.jpg' that happens to appear earlier in the path.
        stem = os.path.splitext(image_filename)[0]
        image_filepath = os.path.join(image_filedir, image_filename)

        img = cv2.imread(image_filepath, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError('could not read image: {}'.format(image_filepath))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        with open(os.path.join(image_filedir, stem + '.json')) as json_file:
            target = json.load(json_file)

        # (H, W, C) array -> (C, H, W) float tensor.
        img = torch.as_tensor(img).permute((2, 0, 1))
        img = torchvision.transforms.functional.convert_image_dtype(img, torch.float)

        images.append(img)
        targets.append(target)
        ids.append(int(stem))

    # Build the batch first and mark it as requiring grad afterwards, so the
    # batch itself is the leaf tensor whose .grad is filled in by backward().
    images = torch.stack(images, 0).to(device).requires_grad_()

    # map_location lets a model that was saved on a GPU be loaded on a CPU-only
    # machine (DEVICE falls back to 'cpu' when CUDA is unavailable).
    # NOTE: torch.load unpickles arbitrary Python objects -- only use it on
    # trusted model files.
    model = torch.load(model_filepath, map_location=device).to(device)
    model.eval()

    logits = model(images)
    target_tensor = torch.as_tensor(targets, dtype=torch.long, device=device)
    loss = loss_fn(logits, target_tensor)
    loss.backward()

    # Detach the logits so the returned tensor does not keep the autograd graph
    # (and through it the model's parameters and activations) alive.
    return {
        'image_id': ids,
        'logits': logits.detach().cpu(),
        'targets': targets,
        'grad_images': images.grad.cpu(),
    }


In [57]:
# Forward/backward pass of the example model over its clean example images,
# using cross-entropy between the logits and the stored targets as the loss.
output_dict = get_output_from_example_images(
    model_filepath=model_filepath,
    image_filedir=clean_images_dir,
    device=DEVICE,
    loss_fn=torch.nn.CrossEntropyLoss(),
)

In [60]:
# Inspect the shape of the returned logits tensor
# (presumably (number of example images, number of classes) -- confirm).
output_dict['logits'].shape

torch.Size([20, 119])