In [1]:
import os
from PIL import Image
import numpy as np
import pandas as pd
import torch
from torch import nn
from torchvision import models
import torchvision.transforms as T
import torch.nn.functional as F
from tqdm import tqdm
import pylab as plt
import json
import random
import math
import vit_pytorch
import einops
import copy
from munch import Munch
import h5py
import re
from functools import reduce
import cv2

# Load data

In [2]:
# Collect the sample frames in a deterministic (lexicographic) order.
# os.listdir order is filesystem-dependent; sorting keeps each frame's index,
# and therefore the results/<method>/<idx>.png it is saved to later, stable.
# Non-image entries (e.g. a stray .ipynb_checkpoints folder) are skipped so
# they cannot crash Image.open further down.
img_root = '/root/autodl-tmp/samples'
img_paths = [
    os.path.join(img_root, name)
    for name in sorted(os.listdir(img_root))
    if name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp'))
]
img_paths

['/root/autodl-tmp/samples/tt0120903_1134.png',
 '/root/autodl-tmp/samples/tt0120903_1196.png',
 '/root/autodl-tmp/samples/tt0120903_1274.png',
 '/root/autodl-tmp/samples/tt0120903_1349.png',
 '/root/autodl-tmp/samples/tt0120903_146.png',
 '/root/autodl-tmp/samples/tt0120903_161.png',
 '/root/autodl-tmp/samples/tt0120903_173.png',
 '/root/autodl-tmp/samples/tt0120903_299.png',
 '/root/autodl-tmp/samples/tt0120903_352.png',
 '/root/autodl-tmp/samples/tt0120903_353.png',
 '/root/autodl-tmp/samples/tt0120903_369.png',
 '/root/autodl-tmp/samples/tt0120903_543.png',
 '/root/autodl-tmp/samples/tt0120903_557.png',
 '/root/autodl-tmp/samples/tt0120903_582.png',
 '/root/autodl-tmp/samples/tt0120903_587.png',
 '/root/autodl-tmp/samples/tt0120903_736.png',
 '/root/autodl-tmp/samples/tt0120903_938.png',
 '/root/autodl-tmp/samples/tt0120903_982.png']

In [3]:
# Per-channel ImageNet statistics used to normalise every frame before it is
# fed to the ResNet-50 encoders below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
trans = T.Compose([T.ToTensor(), T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

In [4]:
def _load_rgb(path):
    """Open an image file, force-load it as 3-channel RGB, and close the file.

    convert('RGB') reads the pixel data, so the returned image no longer needs
    the file handle that the `with` block closes. Forcing RGB keeps RGBA or
    greyscale PNGs compatible with the 3-value Normalize statistics and with
    the RGB heatmap blended onto each frame later.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


imgs = [_load_rgb(path) for path in img_paths]
# torch.stack requires every frame to share the same height and width.
tensors = torch.stack([trans(img) for img in imgs])

# BaSSL shot encoder

In [5]:
# Checkpoint dict for the BaSSL run; every tensor is mapped onto the CPU.
bassl_params = torch.load('/root/autodl-tmp/bassl40epoch/model-v1-1.ckpt', map_location='cpu')

In [6]:
# Keep only the shot-encoder weights and strip their 'shot_encoder.' prefix so
# the keys match a bare resnet50 state dict.
prefix = 'shot_encoder.'
params = {
    key[len(prefix):]: tensor
    for key, tensor in bassl_params['state_dict'].items()
    if key.startswith(prefix)
}

In [7]:
from resnet.resnet import resnet50

# Frozen BaSSL backbone: eval mode, no gradients, then the stripped weights.
# nn.Module.eval() and requires_grad_() both return the module, so they chain.
bassl = resnet50().eval().requires_grad_(False)
bassl.load_state_dict(params)

<All keys matched successfully>

In [8]:
# Feature extraction only: no autograd graph is needed.
with torch.no_grad():
    features = bassl(tensors)

# ShotCoL shot encoder

In [9]:
# Checkpoint dict for the ShotCoL (simclr_nn) run; tensors are mapped onto the CPU.
shotcol_params = torch.load('/root/autodl-tmp/simclr_nn/model-v1-1.ckpt', map_location='cpu')

In [10]:
# Keep only the shot-encoder weights and strip their 'shot_encoder.' prefix so
# the keys match a bare resnet50 state dict.
prefix = 'shot_encoder.'
params = {
    key[len(prefix):]: tensor
    for key, tensor in shotcol_params['state_dict'].items()
    if key.startswith(prefix)
}

In [11]:
from resnet.resnet import resnet50

# Frozen ShotCoL backbone: eval mode, no gradients, then the stripped weights.
# nn.Module.eval() and requires_grad_() both return the module, so they chain.
shotcol = resnet50().eval().requires_grad_(False)
shotcol.load_state_dict(params)

<All keys matched successfully>

In [12]:
# Feature extraction only: no autograd graph is needed.
with torch.no_grad():
    features = shotcol(tensors)

# ImageNet-pretrained ResNet-50

In [5]:
# torchvision ResNet-50 with ImageNet weights, truncated before its last two
# children (global average pool and fc) so it outputs spatial feature maps.
backbone = models.resnet50(pretrained=True)
imagenet = nn.Sequential(*list(backbone.children())[:-2])
del backbone
imagenet.eval()
imagenet.requires_grad_(False);

In [14]:
# Feature extraction only: no autograd graph is needed.
with torch.no_grad():
    features = imagenet(tensors)

# Places365-pretrained ResNet-50

In [6]:
# Places365 ResNet-50 checkpoint; tensors are mapped onto the CPU.
places_params = torch.load('/root/autodl-tmp/resnet50_places365.pth.tar', map_location='cpu')

In [7]:
# Strip the 'module.' prefix from every key; the checkpoint was presumably
# saved from an nn.DataParallel-wrapped model. Keys without it are dropped.
prefix = 'module.'
params = {
    key[len(prefix):]: tensor
    for key, tensor in places_params['state_dict'].items()
    if key.startswith(prefix)
}

In [8]:
# Places365 ResNet-50: swap in a 365-way head so the checkpoint loads, then drop
# the last two children (global average pool and fc) to keep spatial maps.
backbone = models.resnet50(pretrained=False)
backbone.fc = nn.Linear(2048, 365)
backbone.load_state_dict(params)
places = nn.Sequential(*list(backbone.children())[:-2])
del backbone
places.eval()
places.requires_grad_(False);

In [18]:
# Feature extraction only: no autograd graph is needed.
with torch.no_grad():
    features = places(tensors)

# Forward pass

# Draw heatmap overlays for every frame

In [19]:
# For each encoder, overlay the channel-averaged final feature map on every
# frame as a JET heatmap (60% frame / 40% heatmap) and save it to
# results/<method>/<idx>.png.
# An explicit name -> model mapping replaces eval(method) on strings.
encoders = {
    'bassl': bassl,
    'shotcol': shotcol,
    'imagenet': imagenet,
    'places': places,
}
for method, encoder in encoders.items():
    # Image.save does not create missing directories.
    os.makedirs(f'results/{method}', exist_ok=True)
    with torch.no_grad():
        features = encoder(tensors)
    for idx, img in enumerate(imgs):
        # Average over dim 0 -> one 2-D activation map per frame
        # (assumes features[idx] is (C, h, w), as from the truncated ResNets).
        feature = features[idx].mean(0)

        # cv2 dsize is (width, height), the same order as PIL's img.size.
        weight = cv2.resize(feature.cpu().numpy(), dsize=img.size)
        span = np.max(weight) - np.min(weight)
        # A constant map would give 0/0 -> NaN -> undefined uint8 values.
        weight = (weight - np.min(weight)) / span if span > 0 else np.zeros_like(weight)
        heatmap = cv2.applyColorMap(np.uint8(255 * weight), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        # The frame must be 3-channel RGB to blend with the heatmap.
        frame = np.array(img.convert('RGB'))
        result = cv2.addWeighted(frame, .6, heatmap, .4, 0)

        name = f'results/{method}/{idx}.png'
        print('saving:',name)
        Image.fromarray(result).save(name)

saving: results/bassl/0.png
saving: results/bassl/1.png
saving: results/bassl/2.png
saving: results/bassl/3.png
saving: results/bassl/4.png
saving: results/bassl/5.png
saving: results/bassl/6.png
saving: results/bassl/7.png
saving: results/bassl/8.png
saving: results/bassl/9.png
saving: results/bassl/10.png
saving: results/bassl/11.png
saving: results/bassl/12.png
saving: results/bassl/13.png
saving: results/bassl/14.png
saving: results/bassl/15.png
saving: results/bassl/16.png
saving: results/bassl/17.png
saving: results/shotcol/0.png
saving: results/shotcol/1.png
saving: results/shotcol/2.png
saving: results/shotcol/3.png
saving: results/shotcol/4.png
saving: results/shotcol/5.png
saving: results/shotcol/6.png
saving: results/shotcol/7.png
saving: results/shotcol/8.png
saving: results/shotcol/9.png
saving: results/shotcol/10.png
saving: results/shotcol/11.png
saving: results/shotcol/12.png
saving: results/shotcol/13.png
saving: results/shotcol/14.png
saving: results/shotcol/15.png
sa

In [9]:
# Average the ImageNet and Places365 feature maps and save the same JET
# heatmap overlay for each frame to results/mean/<idx>.png.
# Both backbones are truncated torchvision ResNet-50s, so their outputs add elementwise.
os.makedirs('results/mean', exist_ok=True)  # Image.save does not create folders
with torch.no_grad():
    features = (imagenet(tensors) + places(tensors)) / 2
for idx, img in enumerate(imgs):
    # Average over dim 0 -> one 2-D activation map per frame
    # (assumes features[idx] is (C, h, w)).
    feature = features[idx].mean(0)

    # cv2 dsize is (width, height), the same order as PIL's img.size.
    weight = cv2.resize(feature.cpu().numpy(), dsize=img.size)
    span = np.max(weight) - np.min(weight)
    # A constant map would give 0/0 -> NaN -> undefined uint8 values.
    weight = (weight - np.min(weight)) / span if span > 0 else np.zeros_like(weight)
    heatmap = cv2.applyColorMap(np.uint8(255 * weight), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    # The frame must be 3-channel RGB to blend with the heatmap.
    frame = np.array(img.convert('RGB'))
    result = cv2.addWeighted(frame, .6, heatmap, .4, 0)

    name = f'results/mean/{idx}.png'
    print('saving:',name)
    Image.fromarray(result).save(name)

saving: results/mean/0.png
saving: results/mean/1.png
saving: results/mean/2.png
saving: results/mean/3.png
saving: results/mean/4.png
saving: results/mean/5.png
saving: results/mean/6.png
saving: results/mean/7.png
saving: results/mean/8.png
saving: results/mean/9.png
saving: results/mean/10.png
saving: results/mean/11.png
saving: results/mean/12.png
saving: results/mean/13.png
saving: results/mean/14.png
saving: results/mean/15.png
saving: results/mean/16.png
saving: results/mean/17.png
