In [2]:
import einops
import os

import matplotlib.pyplot as plt
import numpy as np
from tianshou.policy import PPOPolicy
import uuid
from tianshou.utils import WandbLogger, LazyLogger
from torch.utils.tensorboard import SummaryWriter
from gymnasium.wrappers import TimeLimit
from customs import CustomDQNPolicy, CustomOffpolicyTrainer
from dataset import *
from env import *
from networks.qnet import *
from networks.vit import ViTTrailEncoder
from networks.SimpleAC import Actor as SimpleActor, Critic as SimpleCritic
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
from tianshou.trainer import OnpolicyTrainer
import torch
import tianshou as ts
from networks.mul_AC import Actor as MulActor, Critic as MulCritic
from networks.path_vit import BaseNetwork as MulBaseNetwork
import torch.nn.functional as F
import random
from networks.simmim import BaseNetwork as SimmimBase
from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling, BeitImageProcessor
from PIL import Image
import requests
import seaborn as sns
import IPython
import time

In [3]:
# Load a sample COCO image and a BEiT checkpoint pre-trained with masked image
# modeling, then run one forward pass to get the per-patch logits used below.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()  # fail loudly instead of handing an HTTP error body to PIL
image = Image.open(response.raw).convert('RGB')  # guard against grayscale / RGBA inputs

# Optional local proxy for reaching the Hugging Face Hub. Override it with the
# BEIT_PROXY environment variable; set it to an empty string to disable it.
hub_proxy = os.environ.get('BEIT_PROXY', 'http://127.0.0.1:10809')
hub_proxies = {'http': hub_proxy, 'https': hub_proxy} if hub_proxy else None

# Both downloads go through the same proxy settings.
feature_extractor = BeitImageProcessor.from_pretrained('microsoft/beit-base-patch16-224-pt22k',
                                                       proxies=hub_proxies)
model = BeitForMaskedImageModeling.from_pretrained('microsoft/beit-base-patch16-224-pt22k',
                                                   proxies=hub_proxies)

inputs = feature_extractor(images=image, return_tensors="pt")
# Inference only: skip autograd bookkeeping (the cells below never backpropagate).
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
# Normalized pixel tensor of the single image in the batch, channels first.
image_tensor = inputs.pixel_values[0]

In [None]:
# Animate, patch by patch, the model's predicted distribution over visual tokens
# (left) next to the input image with the current patch darkened (right).
# Assumes logits is (batch, num_patches, vocab) with patches in raster order over
# a 224px image cut into 16px patches -- TODO confirm for other checkpoints.
PATCH_SIZE = 16
PATCHES_PER_SIDE = 224 // PATCH_SIZE

# pixel_values are normalized by the processor; undo that so imshow receives
# values in [0, 1] (it clips anything outside that range for float RGB data).
pixel_mean = torch.as_tensor(feature_extractor.image_mean, dtype=image_tensor.dtype).reshape(-1, 1, 1)
pixel_std = torch.as_tensor(feature_extractor.image_std, dtype=image_tensor.dtype).reshape(-1, 1, 1)
display_image = (image_tensor * pixel_std + pixel_mean).clamp(0, 1)

fig, axes = plt.subplots(1, 2, figsize=(20, 10))
for idx in range(logits.shape[1]):
    probs = F.softmax(logits[0, idx, :].detach(), dim=0).cpu().numpy()
    axes[0].cla()
    axes[0].set(ylim=(0, .5))
    sns.lineplot(x=np.arange(len(probs)), y=probs, color="salmon", ax=axes[0])

    row = (idx // PATCHES_PER_SIDE) * PATCH_SIZE
    col = (idx % PATCHES_PER_SIDE) * PATCH_SIZE
    highlighted_image = display_image.clone()
    highlighted_image[:, row:row+PATCH_SIZE, col:col+PATCH_SIZE] *= 0.5
    axes[1].cla()  # otherwise one AxesImage is stacked per frame
    axes[1].imshow(highlighted_image.permute(1, 2, 0).cpu().numpy())

    IPython.display.clear_output(wait=True)
    IPython.display.display(fig)
    time.sleep(.2)
plt.close(fig)  # keep the inline backend from rendering the figure a second time

In [None]:
def get_row_col(idx, image_size=224, patch_size=16):
    """Return the top-left pixel offset (row, col) of patch ``idx``.

    Patches are assumed to tile a square image in row-major (raster) order on a
    ``image_size // patch_size`` by ``image_size // patch_size`` grid.

    Args:
        idx: Flat patch index.
        image_size: Side length of the square image, in pixels (default 224).
        patch_size: Side length of a square patch, in pixels (default 16).

    Returns:
        Tuple ``(row, col)`` giving the pixel coordinates of the patch's
        top-left corner.
    """
    patches_per_side = image_size // patch_size
    patch_row, patch_col = divmod(idx, patches_per_side)
    return patch_row * patch_size, patch_col * patch_size

In [None]:
# For every query patch, keep it at full brightness and scale every other
# patch's brightness by how similar its logit vector is to the query's
# (raw dot product, min-max normalized over the other patches).
# Assumes logits is (batch, num_patches, vocab) with patches in raster order,
# matching get_row_col's defaults -- TODO confirm for other checkpoints.

# pixel_values are normalized by the processor; undo that so imshow receives
# values in [0, 1] instead of clipping the negative ones.
pixel_mean = torch.as_tensor(feature_extractor.image_mean, dtype=image_tensor.dtype).reshape(-1, 1, 1)
pixel_std = torch.as_tensor(feature_extractor.image_std, dtype=image_tensor.dtype).reshape(-1, 1, 1)
display_image = (image_tensor * pixel_std + pixel_mean).clamp(0, 1)

patch_logits = logits[0].detach()
n_patches = patch_logits.shape[0]
# All pairwise dot products at once, instead of recomputing each pair twice per frame.
similarity = patch_logits @ patch_logits.T

fig, axes = plt.subplots(1, 2, figsize=(20, 10))
axes[0].imshow(display_image.permute(1, 2, 0).cpu().numpy())
for idx in range(n_patches):
    scores = similarity[idx]
    others = torch.cat([scores[:idx], scores[idx + 1:]])
    if others.numel() == 0:
        break  # a single patch has nothing to compare against
    # lo/hi rather than min/max so the Python builtins are not shadowed.
    lo, hi = others.min(), others.max()
    span = hi - lo
    # Equal scores everywhere would divide by zero; treat them as fully similar.
    weights = (scores - lo) / span if span > 0 else torch.ones_like(scores)

    highlighted_image = display_image.clone()
    for idxc in range(n_patches):
        if idxc == idx:
            continue  # the query patch stays at its original brightness
        rowc, colc = get_row_col(idxc)
        highlighted_image[:, rowc:rowc+16, colc:colc+16] *= weights[idxc]

    axes[1].cla()  # otherwise one AxesImage is stacked per frame
    axes[1].imshow(highlighted_image.permute(1, 2, 0).cpu().numpy())

    IPython.display.clear_output(wait=True)
    IPython.display.display(fig)
    time.sleep(.5)
plt.close(fig)  # keep the inline backend from rendering the figure a second time