In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import imageio, warnings, os

from interpretation_methods import *
from utils.imagenet_seg_loader import ImagenetSegLoader
from utils.model_loaders import vit_base_patch16_224_dino, vit_base_patch16_224
from utils.input_arguments import get_arg_parser
from utils.saver import Saver
from utils.sideplot import side_plot
from utils.image_denorm import image_vizformat

# Session-wide setup: drop every warning filter match to "ignore" (this also
# hides deprecation notices) ...
warnings.simplefilter("ignore")
# ... and switch matplotlib to the non-interactive Agg backend, which renders
# off-screen (e.g. for writing figures to files).
plt.switch_backend("agg")

In [None]:
# Dataset / loader settings.
data_path = "lib/dataset/gtsegs_ijcv.mat"  # segmentation ground-truth file read by ImagenetSegLoader
# data_length is passed straight to ImagenetSegLoader (presumably the number of samples to load — confirm).
data_length, batch_size, num_workers = 3, 1, 7

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Images: resize to 224x224, convert to a tensor, then map every channel with
# (x - 0.5) / 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
# Labels: nearest-neighbour resize only, so label values are never blended.
label_transform = transforms.Compose(
    [transforms.Resize((224, 224), interpolation=Image.NEAREST)]
)

In [None]:
dataset = ImagenetSegLoader(
    data_path,
    data_length,
    transform=image_transform,
    target_transform=label_transform,
)
# Unshuffled, keep the last partial batch.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    drop_last=False,
)
# Rebind to a tqdm wrapper so any loop over `dataloader` shows a progress bar.
dataloader = tqdm(dataloader)

# Alternative backbone, currently disabled:
# model = vit_base_patch16_224_dino(pretrained=True).to(device)

In [None]:
def image_vizformat(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), index=0):
    """Undo input normalization for one image of a batch and return it as an HWC array.

    NOTE: this definition shadows the ``image_vizformat`` imported from
    ``utils.image_denorm`` in the first cell.

    Args:
        img: batched image tensor; ``img[index]`` must be a 3-D (C, H, W)
            floating-point tensor.
        mean: per-channel mean used by the forward normalization. The default
            matches ``Normalize(mean=0.5, std=0.5)`` in ``image_transform``.
        std: per-channel std used by the forward normalization (default 0.5).
        index: which element of the batch to convert (default 0).

    Returns:
        numpy array of shape (H, W, C) on the CPU, equal to ``x * std + mean``
        per channel (gradients detached).

    Raises:
        TypeError: if the selected image is not a floating-point tensor.
    """
    frame = img[index]
    if not frame.is_floating_point():
        raise TypeError(f"Input tensor should be a float tensor. Got {frame.dtype}.")
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W.
    mean_t = torch.as_tensor(mean, dtype=frame.dtype, device=frame.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=frame.dtype, device=frame.device).view(-1, 1, 1)
    # Inverse of x_norm = (x - mean) / std.
    restored = frame * std_t + mean_t
    return restored.permute(1, 2, 0).detach().cpu().numpy()

In [None]:
# Cache every batch in memory so later cells can index into them.
# `dataloader` was already wrapped in tqdm when it was created, so iterate it
# directly. Wrapping it again, and through `enumerate`, which has no length,
# would draw a second progress bar with no total.
imgs = []
masks = []

for batch in dataloader:
    imgs.append(batch[0])   # first element: the input images fed to the model later
    masks.append(batch[1])  # second element: the target / mask

In [None]:
# Where the re-saved ViT state_dict is written and later reloaded from.
# Set FF_CHECKPOINT_PATH to override this machine-specific absolute path
# without editing the notebook.
mdpath = os.environ.get(
    "FF_CHECKPOINT_PATH",
    "C:/Users/muimr/Research/Vit Interpret/Codes/beyond_intuition/lib/benchmark__trained_on_noisy_data/ff.pth",
)

In [None]:
from functools import partial
from utils.load_pretrained import load_pretrained
from utils.model_loaders import _conv_filter
from vision_transformer.vit import VisionTransformer
from utils.config import default_config
import torch.nn as nn


def vit_base_patch16_224(pretrained=False, url_given=None, **kwargs):
    """Build a ViT-B/16 (224 px) VisionTransformer, optionally loading weights.

    This redefines (and shadows) ``vit_base_patch16_224`` imported from
    ``utils.model_loaders`` so that a custom weights location can be supplied.

    Args:
        pretrained: if True, call ``load_pretrained`` on the new model.
        url_given: optional override for the config's ``'url'`` entry. The
            callers in this notebook pass local ``.pth`` file paths.
        **kwargs: forwarded to ``VisionTransformer``. ``in_chans``
            (default 3) is also passed on to ``load_pretrained``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs,
    )

    # Use a per-model copy of the shared config. Writing `url_given` into
    # `default_config` itself would leak into every later call: a call made
    # without `url_given` would silently load the previous file. And since
    # `model.default_cfg` used to be that very dict, it would also rewrite the
    # `default_cfg` of models that were already built.
    cfg = dict(default_config['vit_base_patch16_224'])
    if url_given is not None:
        cfg['url'] = url_given

    model.default_cfg = cfg
    if pretrained:
        # NOTE(review): assumes load_pretrained reads the weights location from
        # model.default_cfg (timm-style helper) — confirm in utils.load_pretrained.
        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
    return model

In [None]:
# Display the ViT-B/16 entry of the shared model config (e.g. its 'url').
default_config["vit_base_patch16_224"]

In [None]:
# Build ViT-B/16 and load the checkpoint stored under lib/pretrained_model.
model = vit_base_patch16_224(
    url_given="C:/Users/muimr/Research/Vit Interpret/Codes/beyond_intuition/lib/pretrained_model/jx_vit_base_p16_224-80ecf9dd.pth",
    pretrained=True,
)

In [None]:
# Persist only the weights (state_dict) of `model` to `mdpath`.
# torch.save does not create missing folders, so make sure the parent exists.
os.makedirs(os.path.dirname(mdpath) or ".", exist_ok=True)
torch.save(model.state_dict(), mdpath)

In [None]:
# Rebuild the same architecture, this time loading the weights just saved to `mdpath`.
model2 = vit_base_patch16_224(url_given=mdpath, pretrained=True)

In [None]:
# Parameter names stored in the original pretrained checkpoint. Only the keys
# are needed, so map every tensor to the CPU: the file then loads even if it
# was saved from a GPU and none is available here.
m1keys = list(
    torch.load(
        "C:/Users/muimr/Research/Vit Interpret/Codes/beyond_intuition/lib/pretrained_model/jx_vit_base_p16_224-80ecf9dd.pth",
        map_location="cpu",
    ).keys()
)

In [None]:
# Parameter names in the re-saved checkpoint. Load on the CPU, as for m1keys:
# only the keys are needed, so there is no reason to require the saving device.
m2keys = list(torch.load(mdpath, map_location="cpu").keys())

In [None]:
# Forward the first cached batch through the model built from the original checkpoint.
out1 = model(imgs[0])

In [None]:
# Same batch through the model reloaded from `mdpath`; presumably compared with
# out1 to check the save/reload round-trip.
out2 = model2(imgs[0])