In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import transformers
import torch
from transformers import AutoImageProcessor, AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import json
import cv2

In [None]:
class FunnyNotFunnyDataset(Dataset):
    """Binary image dataset read from a ``root_dir/<class_name>/<file>`` directory layout.

    Every sub-directory of ``root_dir`` is a class. Class names are sorted so that label
    assignment does not depend on filesystem iteration order: the class at sorted
    position ``i`` gets raw label ``1 - i``. For the ``Funny`` / ``Not_Funny`` layout this
    gives ``Funny -> 1`` and ``Not_Funny -> 0``.
    NOTE(review): earlier versions used unsorted ``os.scandir`` order; confirm that
    ``Funny -> 1`` matches the labelling the loaded checkpoint was trained with.

    Args:
        data: Ignored; kept only for backward compatibility (samples are always rebuilt
            from ``root_dir``).
        root_dir: Dataset root directory (required). A trailing ``'/'`` is appended if missing.
        transform: Used only when ``image_processor`` is falsy. Either a callable applied to
            the PIL image, or a zero-argument class/factory whose result is applied to it.
        image_processor: Image processor called as
            ``image_processor(images=image, return_tensors='pt')``.

    Raises:
        ValueError: if ``root_dir`` is ``None`` or empty.
    """

    def __init__(self, data=None, root_dir=None, transform=None, image_processor=None):
        if not root_dir:
            raise ValueError("root_dir must be a non-empty directory path")
        if not root_dir.endswith('/'):
            root_dir += '/'
        self.root_dir = root_dir
        # Sorted: os.scandir order is filesystem-dependent and the label below is derived
        # from each class's position.
        with os.scandir(root_dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        samples = []
        for i, class_name in enumerate(classes):
            class_dir = root_dir + class_name + '/'
            for file_name in sorted(os.listdir(class_dir)):
                samples.append((class_dir + file_name, 1 - i))
        self.data = samples
        self.num_classes = len(classes)
        self.transform = transform
        self.image_processor = image_processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return one sample as a dict with pixel data, a (1,)-shaped float label and the file path.

        With an ``image_processor`` the processor output is returned with ``pixel_values``
        un-batched plus ``'label'`` and ``'filename'`` keys; otherwise a plain dict with
        ``'image_data'``, ``'label'`` and ``'filename'``. The label tensor is 1.0 for raw
        label 1 and 0.0 for anything else.
        """
        path, label = self.data[index]
        # Load eagerly so the file handle is closed, and force 3 channels so grayscale,
        # palette or RGBA files do not break the model's RGB input pipeline.
        with Image.open(path) as img:
            image = img.convert('RGB')
        label_tensor = torch.tensor([1.0 if label == 1 else 0.0])

        if self.image_processor:
            features = self.image_processor(images=image, return_tensors='pt')
            # The processor returns a batch of one; drop that axis so DataLoader can re-batch.
            features['pixel_values'] = features['pixel_values'].squeeze(0)
            features['label'] = label_tensor
            features['filename'] = path
            return features

        if self.transform:
            image = self._apply_transform(image)
        return {'image_data': image, 'label': label_tensor, 'filename': path}

    def _apply_transform(self, image):
        """Apply ``self.transform``, instantiating it first if it is a zero-argument class/factory."""
        try:
            transform = self.transform()
        except TypeError:
            # Already a ready-to-use callable (e.g. a transform instance or a 1-arg function).
            return self.transform(image)
        return transform(image)

In [None]:
VIT_CHECKPOINT = "google/vit-huge-patch14-224-in21k"
# Per-GPU memory caps (GPUs 0, 1 and 2) used by device_map='auto' placement.
GPU_MAP = {gpu_index: "8GiB" for gpu_index in range(3)}

# 2-way classification head; ignore_mismatched_sizes=True allows loading even though the
# checkpoint's classifier shape does not match num_labels=2.
model = AutoModelForImageClassification.from_pretrained(
    VIT_CHECKPOINT,
    device_map='auto',
    max_memory=GPU_MAP,
    num_labels=2,
    ignore_mismatched_sizes=True,
)
processor = AutoImageProcessor.from_pretrained(VIT_CHECKPOINT)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Fine-tuned weights for the model built above.
# map_location='cpu': restore tensors on the CPU instead of the device they were saved
# from (which may not exist / have room here); load_state_dict then copies each tensor
# into the corresponding parameter wherever device_map='auto' placed it.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (weights_only=True is available on newer PyTorch versions).
state_dict = torch.load(
    '/home/vedaant/send/benchmarks/binary_classification/transformers/vit_huge_Funny/model/0.pth',
    map_location='cpu',
)
model.load_state_dict(state_dict)

In [None]:
# Test split, preprocessed by the ViT image processor.
# shuffle=False: evaluation needs no shuffling, and a fixed order keeps runs reproducible.
test_dataset_transformers = FunnyNotFunnyDataset(root_dir='/home/vedaant/send/NewDataset2/Test', image_processor=processor)
test_dataloader_transformers = DataLoader(test_dataset_transformers, batch_size=16, shuffle=False)

In [None]:
import gc
def TestTransformersWithFileNames():
    """Evaluate ``model`` on ``test_dataloader_transformers`` and record a result per file.

    Uses the notebook globals ``model``, ``test_dataloader_transformers`` and ``device``.
    Prints the sum of the per-batch mean cross-entropy losses and the overall accuracy
    (NaN if the loader is empty).

    Returns:
        ``(file_results, file_detailed_results)`` where
        - ``file_results`` maps each filename to 1 if classified correctly, else 0;
        - ``file_detailed_results`` maps each filename to
          ``{'pred': int, 'out': float, 'label': int}``: predicted class, softmax
          probability of that predicted class, and true class.
    """
    model.eval()
    total_size = 0
    total_loss = []
    correct = 0
    file_results = {}
    file_detailed_results = {}
    with torch.no_grad():
        for data in test_dataloader_transformers:
            # .float() rather than .type(torch.cuda.FloatTensor) so this also runs on CPU.
            inputs = data['pixel_values'].to(device).float()
            file_names = data['filename']
            batch_size = inputs.size(0)

            logits = model(inputs).logits
            # Labels are float tensors (one value per sample); flatten and cast to class
            # indices on the logits' device in case the model returns them elsewhere.
            targets = data['label'].flatten().long().to(logits.device)

            total_loss.append(torch.nn.functional.cross_entropy(logits, targets).item())
            probs = torch.softmax(logits, dim=1)
            predictions = torch.argmax(probs, dim=1)
            # Guards against accidental broadcasting in the comparison below.
            assert predictions.shape == targets.shape, (predictions.shape, targets.shape)
            matches = predictions == targets
            correct += matches.sum().item()

            # Convert to Python scalars once per batch. Storing per-element tensor views
            # (e.g. predictions[i]) in the dicts would keep each batch's tensors alive.
            for file_name, pred, prob_row, label, match in zip(
                file_names, predictions.tolist(), probs.tolist(), targets.tolist(), matches.tolist()
            ):
                file_results[file_name] = int(match)
                file_detailed_results[file_name] = {
                    'pred': pred,
                    'out': prob_row[pred],
                    'label': label,
                }
            total_size += batch_size
    accuracy = correct / total_size if total_size else float('nan')
    print("Total Test Loss: {:.4f}; Test Accuracy: {:.2f}%".format(np.sum(total_loss), accuracy * 100))
    return file_results, file_detailed_results

In [None]:
# Run the test-set evaluation; both result dicts are keyed by each image's full file path.
vit_file_results, vit_file_detailed_results = TestTransformersWithFileNames()

In [None]:
# Filename pairs used below as (key, value) pairs of test images.
# NOTE(review): the name suggests "modified -> original" filenames — confirm.
with open('/home/vedaant/send/mod2org.json', encoding='utf-8') as mod2org_file:
    mod2org = json.load(mod2org_file)

In [None]:
# Re-key per-file results by bare filename (dropping the ".../Test/<class>/" prefix) so
# they can be matched against the filenames stored in mod2org. basename replaces the
# previous length-based slicing, which hard-coded the Funny/Not_Funny name lengths.
vit_file_results_key_shorten = {
    os.path.basename(path): is_correct for path, is_correct in vit_file_results.items()
}
len(vit_file_results_key_shorten)

In [None]:
# Split mod2org pairs (both files present in the results) by classification outcome:
#   accurate_pairs   - both files classified correctly
#   inaccurate_pairs - at least one file misclassified
#   special          - both files misclassified (also included in inaccurate_pairs)
accurate_pairs, inaccurate_pairs, special = [], [], []
for mod_name, org_name in mod2org.items():
    if mod_name not in vit_file_results_key_shorten or org_name not in vit_file_results_key_shorten:
        continue
    mod_ok = vit_file_results_key_shorten[mod_name]
    org_ok = vit_file_results_key_shorten[org_name]
    pair = (mod_name, org_name)
    if mod_ok == 1 and org_ok == 1:
        accurate_pairs.append(pair)
    else:
        if mod_ok == 0 and org_ok == 0:
            special.append(pair)
        inaccurate_pairs.append(pair)
len(accurate_pairs), len(inaccurate_pairs), len(special)

In [None]:
# Accuracy restricted to files whose name does not contain 'M'.
# NOTE(review): assumes an 'M' in the filename marks a modified image — confirm naming convention.
unmodified_results = [
    is_correct for name, is_correct in vit_file_results_key_shorten.items() if 'M' not in name
]
correct = sum(1 for is_correct in unmodified_results if is_correct == 1)
total = len(unmodified_results)
# NaN instead of ZeroDivisionError when no file matches the filter.
correct, total, (correct / total if total else float('nan'))

In [None]:
def _attention_rollout(att_mat):
    """Attention rollout across layers for head-averaged attention of shape (layers, tokens, tokens)."""
    # Add the identity for the residual connection, then re-normalise each row to sum to 1.
    augmented = att_mat + torch.eye(att_mat.size(-1))
    augmented = augmented / augmented.sum(dim=-1, keepdim=True)
    rollout = augmented[0]
    for layer_att in augmented[1:]:
        rollout = torch.matmul(layer_att, rollout)
    return rollout


def get_attention_map(image_path, model, processor):
    """Compute an attention-rollout heat map for one image, resized to the image's size.

    Args:
        image_path: path to the image file.
        model: classifier accepting ``pixel_values`` and ``output_attentions=True``.
        processor: image processor producing ``pixel_values``.

    Returns:
        float numpy array of shape (height, width), scaled so its maximum is 1.
    """
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    pixel_values = processor(images=image, return_tensors='pt').pixel_values.to(device)
    with torch.no_grad():  # inference only; avoid building an autograd graph
        output = model(pixel_values=pixel_values, output_attentions=True)
    # One tensor per layer, each (1, heads, tokens, tokens). Move each to the CPU before
    # stacking: with device_map='auto' different layers may live on different GPUs.
    att_mat = torch.stack([layer_att.detach().cpu() for layer_att in output.attentions]).squeeze(1)
    att_mat = att_mat.mean(dim=1)  # average over heads -> (layers, tokens, tokens)
    rollout = _attention_rollout(att_mat)
    # Row 0 is the CLS token; its attention to the remaining (patch) tokens forms a square grid.
    grid_size = int(np.sqrt(rollout.size(-1) - 1))
    mask = rollout[0, 1:].reshape(grid_size, grid_size).numpy()
    # PIL's size is (width, height), which is the dsize order cv2.resize expects.
    return cv2.resize(mask / mask.max(), image.size)

In [None]:
def plot_attention_map(original_img, att_map):
    """Show ``original_img`` and ``att_map`` side by side, with a colour bar for the map."""
    fig, (left, right) = plt.subplots(ncols=2, figsize=(16, 16))

    left.set_title('Original')
    left.imshow(original_img)

    right.set_title('Attention Map Last Layer')
    heatmap = right.imshow(att_map)
    # Colour bar attached to the attention-map panel only.
    fig.colorbar(heatmap, ax=right, orientation='vertical', fraction=0.036, pad=0.03)

    plt.show()

In [None]:
# Helper: show a single image on its own large figure with the axes hidden.

def display_image(image):
    """Draw ``image`` on a new 10x10-inch figure with its axes hidden (no ``plt.show()`` call)."""
    _, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image)
    ax.axis('off')

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import math

def display_images_in_grid(image_paths, labels, grid_size=None):
    """
    Displays a list of images in a grid, with each image's label as its subplot title.

    :param image_paths: List of paths to image files.
    :param labels: List of labels, one per image.
    :param grid_size: Tuple (rows, columns). If None, the smallest square grid that fits
        all images is used.
    :raises ValueError: if the numbers of images and labels differ, or if ``grid_size``
        has fewer cells than there are images.

    Does nothing when ``image_paths`` is empty.
    """
    if len(image_paths) != len(labels):
        raise ValueError("The number of images must match the number of labels")
    if not image_paths:
        return  # plt.subplots(0, 0) would raise

    if grid_size is None:
        n = math.ceil(math.sqrt(len(image_paths)))  # side of the smallest square grid
        grid_size = (n, n)
    rows, cols = grid_size
    if rows * cols < len(image_paths):
        raise ValueError("grid_size is too small for the number of images")

    # squeeze=False: always get a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 5), squeeze=False)
    axes = axes.flatten()

    for ax, img_path, label in zip(axes, image_paths, labels):
        ax.imshow(mpimg.imread(img_path))
        ax.axis('off')
        ax.set_title(label, fontsize=12)  # title is drawn above the image

    # Hide any remaining empty subplots
    for ax in axes[len(image_paths):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# One path per accurate pair, aligned by index with `accurate_pairs`, so a grid label "i"
# below refers to accurate_pairs[i]. The pair's second element (the mod2org value) is
# looked up in Funny first, falling back to Not_Funny.
# (The previous comprehension repeated `for pair in accurate_pairs`, producing a cross
# product whose filter used the shadowed outer variable.)
TEST_FUNNY_DIR = '/home/vedaant/send/NewDataset2/Test/Funny/'
TEST_NOT_FUNNY_DIR = '/home/vedaant/send/NewDataset2/Test/Not_Funny/'
all_image_paths = [
    os.path.join(TEST_FUNNY_DIR, org_name)
    if os.path.exists(os.path.join(TEST_FUNNY_DIR, org_name))
    else os.path.join(TEST_NOT_FUNNY_DIR, org_name)
    for _, org_name in accurate_pairs
]
all_labels = [str(i) for i in range(len(all_image_paths))]

In [None]:
# Show one page of the grid: items with index in [start, cutoff).
start, cutoff = 0, 16
window = slice(start, cutoff)
image_paths, labels = all_image_paths[window], all_labels[window]
display_images_in_grid(image_paths, labels)

In [None]:
# Attention rollout for the first element (mod2org key) of accurate pair #9, an index
# picked from the grid above. The unused image*mask overlay was removed: it was never
# displayed and fails to broadcast for grayscale, non-square images.
image_path = os.path.join('/home/vedaant/send/NewDataset2/Test/Not_Funny/', accurate_pairs[9][0])
attn_map = get_attention_map(image_path, model, processor)
image = Image.open(image_path)
plot_attention_map(image, attn_map)

In [None]:
# Attention rollout for the second element (mod2org value) of accurate pair #9. The unused
# image*mask overlay was removed: it was never displayed and fails to broadcast for
# grayscale, non-square images.
image_path = os.path.join('/home/vedaant/send/NewDataset2/Test/Funny/', accurate_pairs[9][1])
attn_map = get_attention_map(image_path, model, processor)
image = Image.open(image_path)
plot_attention_map(image, attn_map)

In [None]:
# Re-draw the most recently loaded image and its attention map (from the cell above).
plot_attention_map(image, attn_map)