# Dependencies

In [8]:
from hipt_4k import HIPT_4K
from hipt_model_utils import get_vit256, get_vit4k, eval_transforms
from hipt_heatmap_utils import *
# Checkpoint files for the 256-level and 4K-level Vision Transformers.
pretrained_weights256 = '../Checkpoints/vit256_small_dino.pth'
pretrained_weights4k = '../Checkpoints/vit4k_xs_dino.pth'

# Each stage gets its own device handle; both run on the CPU in this demo.
device256, device4k = torch.device('cpu'), torch.device('cpu')

# Colormap used by the heatmap demos: the jet colormap passed through
# cmap_map with the function v -> v/2 + 0.5.
# NOTE(review): presumably this lightens jet by remapping its colour values
# into [0.5, 1.0] -- confirm against cmap_map in hipt_heatmap_utils.
light_jet = cmap_map(lambda value: value / 2 + 0.5, matplotlib.cm.jet)

# Standalone ViT_256 and ViT_4K instances (used for the attention heatmaps).
model256 = get_vit256(
    pretrained_weights=pretrained_weights256,
    device=device256,
)
model4k = get_vit4k(
    pretrained_weights=pretrained_weights4k,
    device=device4k,
)

# The same two checkpoints wrapped by the HIPT_4K API, switched to eval mode.
model = HIPT_4K(
    pretrained_weights256,
    pretrained_weights4k,
    device256,
    device4k,
)
model.eval()

# of Patches: 196
# of Patches: 196


# Standalone HIPT_4K Model Inference

In [4]:
# Load the demo region and turn it into a batched tensor: eval_transforms()
# builds the preprocessing callable, unsqueeze(dim=0) adds the batch axis.
region = Image.open('./image_demo/image_4k.png')
x = eval_transforms()(region).unsqueeze(dim=0)
print('Input Shape:', x.shape)

# Call the module itself instead of .forward() so that nn.Module.__call__
# runs any registered hooks, and disable autograd since this is inference only.
with torch.no_grad():
    features = model(x)
print('Output Shape:', features.shape)

# of Patches: 196
Input Shape: torch.Size([1, 3, 4096, 4096])
Output Shape: torch.Size([1, 192])


# HIPT_4K Attention Heatmaps
Code for producing attention results (for [256 x 256], [4096 x 4096], and hierarchical [4096 x 4096]) can be run as-is below. There are two ways to produce these results:
1. **hipt_4k.py** Class (Preferred): This class blends inference and heatmap creation in a seamless and more object-oriented manner, and is where I am focusing future code development.
2. Helper Functions in **hipt_heatmap_utils.py** (Soon-to-be-deprecated): Heatmap creation was originally written as helper functions. These may be more useful and easier to work with from a research perspective.

Please use whatever is most helpful for your use case :) 

### 256 x 256 Demo (Saving Attention Maps Individually)

In [3]:
# 256 x 256 demo: open the demo patch, make sure the output folder exists,
# then hand both to create_patch_heatmaps_indiv (one file per attention map,
# per the section heading).
patch = Image.open('./image_demo/image_256.png')
output_dir = './attention_demo/256_output_indiv/'
os.makedirs(output_dir, exist_ok=True)  # no error if the folder already exists
create_patch_heatmaps_indiv(
    patch=patch,
    model256=model256,
    output_dir=output_dir,
    fname='patch',
    cmap=light_jet,
    device256=device256,
)

### 256 x 256 Demo (Concatenating + Saving Attention Maps)

In [9]:
# 256 x 256 demo: open the demo patch, make sure the output folder exists,
# then hand both to create_patch_heatmaps_concat (attention maps concatenated
# before saving, per the section heading).
patch = Image.open('./image_demo/image_256.png')
output_dir = './attention_demo/256_output_concat/'
os.makedirs(output_dir, exist_ok=True)  # no error if the folder already exists
create_patch_heatmaps_concat(
    patch=patch,
    model256=model256,
    output_dir=output_dir,
    fname='patch',
    cmap=light_jet,
    device256=device256,
)

### 4096 x 4096 Demo (Saving Attention Maps Individually)

In [3]:
# 4096 x 4096 demo: open the demo region, make sure the output folder exists,
# then run the hierarchical heatmap helper with both ViT stages (attention
# maps saved individually, per the section heading).
region = Image.open('./image_demo/image_4k.png')
output_dir = './attention_demo/4k_output_indiv/'
os.makedirs(output_dir, exist_ok=True)  # no error if the folder already exists
create_hierarchical_heatmaps_indiv(
    region,
    model256,
    model4k,
    output_dir,
    fname='region',
    scale=2,
    threshold=0.5,
    cmap=light_jet,
    alpha=0.5,
    device256=device256,
    device4k=device4k,
)

### 4096 x 4096 Demo (Concatenating + Saving Attention Maps)

In [19]:
# 4096 x 4096 demo: open the demo region, make sure the output folder exists,
# then run the hierarchical heatmap helper with both ViT stages (attention
# maps concatenated before saving, per the section heading).
region = Image.open('./image_demo/image_4k.png')
output_dir = './attention_demo/4k_output_concat/'
os.makedirs(output_dir, exist_ok=True)  # no error if the folder already exists
create_hierarchical_heatmaps_concat(
    region,
    model256,
    model4k,
    output_dir,
    fname='region',
    scale=2,
    threshold=0.5,
    cmap=light_jet,
    alpha=0.5,
    device256=device256,
    device4k=device4k,
)