In [None]:
import torch
from torchvision import transforms
from einops import rearrange, repeat
from hipt_model_utils import get_vit256, get_vit4k
from hipt_heatmap_utils import Image

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device is {device}')

In [None]:
class HIPT_4K(torch.nn.Module):
    """
    HIPT model (ViT_4K-256) for encoding non-square images.

    The image is tiled into [256 x 256] patches. Each patch is encoded by
    ViT_256-16 (which uses [16 x 16] patch tokens internally), and the grid
    of resulting ViT_256 [CLS] tokens is encoded by ViT_4K-256.

    Args:
        - model256_path (str): checkpoint passed to `get_vit256`.
        - model4k_path (str): checkpoint passed to `get_vit4k`.
        - device256 (torch.device | str | None): device holding ViT_256 and the
          image minibatches fed to it. None -> 'cuda' if available, else 'cpu'.
        - device4k (torch.device | str | None): device holding ViT_4K and the
          ViT_256 feature grid fed to it. Same default as device256.
    """
    def __init__(self,
                 model256_path='../HIPT_4K/Checkpoints/vit256_small_dino.pth',
                 model4k_path='../HIPT_4K/Checkpoints/vit4k_xs_dino.pth',
                 device256=None,
                 device4k=None):

        super().__init__()
        # Defaults are resolved here (not in the signature) so the class does not
        # depend on a notebook-global `device` existing when it is defined.
        default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device256 = device256 if device256 is not None else default_device
        self.device4k = device4k if device4k is not None else default_device
        # Each sub-model must sit on the same device that forward() sends its inputs to.
        self.model256 = get_vit256(pretrained_weights=model256_path).to(self.device256)
        self.model4k = get_vit4k(pretrained_weights=model4k_path).to(self.device4k)

    def forward(self, x, verbose=False, minibatch_size=256):
        """
        Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT_4K.
        1. x is center-cropped such that the W / H is divisible by the patch token size in ViT_4K.
        2. x then gets unfolded into a "batch" of [256 x 256] images.
        3. A pretrained ViT_256-16 model extracts the CLS token from each [256 x 256] image in
           the batch, `minibatch_size` images at a time to bound memory use.
        4. These batch-of-features are then reshaped into a 2D features grid (of width "w_256" and height "h_256").
        5. This feature grid is then used as the input to ViT_4K-256, outputting [CLS]_4K.

        Args:
            - x (torch.Tensor): [1 x C x W' x H'] image tensor.
            - verbose (bool): if True, print the intermediate tensor shapes.
            - minibatch_size (int): max number of [256 x 256] tiles passed to ViT_256 at once.
        Return:
            - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default).
        Raises:
            - ValueError: if minibatch_size < 1, or (via prepare_img_tensor) if x is
              smaller than one 256 x 256 patch along either side.
        """
        if minibatch_size < 1:
            raise ValueError(f'minibatch_size must be >= 1, got {minibatch_size}')
        log = print if verbose else (lambda *args, **kwargs: None)

        # 1. [1 x 3 x W x H], with W and H multiples of 256.
        batch_256, w_256, h_256 = self.prepare_img_tensor(x)
        log(f'1. [1 x 3 x W x H] {batch_256.shape}')
        # 2. [1 x 3 x w_256 x h_256 x 256 x 256]
        batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256)
        log(f'2. [1 x 3 x w_256 x h_256 x 256 x 256] {batch_256.shape}')
        # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256), tiles ordered row-major over (w_256, h_256)
        batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h')
        log(f'2. [B x 3 x 256 x 256], where B = (1*w_256*h_256) {batch_256.shape}')

        # 3. B may be too large for ViT_256 in one pass, so encode it in minibatches.
        features_cls256 = []
        for start in range(0, batch_256.shape[0], minibatch_size):
            log(f'Minibatch starting at tile {start}')
            minibatch_256 = batch_256[start:start + minibatch_size].to(self.device256, non_blocking=True)
            # Features are gathered on the CPU so per-minibatch device memory is released.
            features_cls256.append(self.model256(minibatch_256).detach().cpu())

        log(f'Length of the list of minibatches is {len(features_cls256)}')
        log(f'Shape of element 0 inside the list {features_cls256[0].shape}')

        # 3. [B x D_256], where D_256 == dim of ViT_256 [CLS] token (384 for the default checkpoint)
        features_cls256 = torch.vstack(features_cls256)
        log(f'3. [B x 384], where 384 == dim of ViT_256 [CLS] token {features_cls256.shape}')
        # 4. [1 x D_256 x w_256 x h_256]
        features_cls256 = self._cls256_to_grid(features_cls256, w_256, h_256)
        features_cls256 = features_cls256.to(self.device4k, non_blocking=True)
        log(f'4. [1 x 384 x w_256 x h_256] {features_cls256.shape}')
        # 5. [1 x 192], where 192 == dim of ViT_4K [CLS] token.
        # Calling the module (not .forward) so any registered hooks run.
        features_cls4k = self.model4k(features_cls256)
        log(f'5. [1 x 192], where 192 == dim of ViT_4K [CLS] token {features_cls4k.shape}')

        return features_cls4k

    @staticmethod
    def _cls256_to_grid(features_cls256, w_256, h_256):
        """
        Arrange a flat [B x D] stack of ViT_256 [CLS] tokens into a
        [1 x D x w_256 x h_256] feature map.

        B must equal w_256 * h_256, with tokens ordered row-major over
        (w_256, h_256), which is the order produced by the unfold/rearrange in forward().
        """
        grid = features_cls256.reshape(w_256, h_256, -1)
        # (w_256, h_256, D) -> (D, w_256, h_256): channels-first for ViT_4K.
        return grid.permute(2, 0, 1).unsqueeze(dim=0)

    def prepare_img_tensor(self, img: torch.Tensor, patch_size=256):
        """
        Helper function that takes a non-square image tensor, and takes a center crop s.t.
        the width / height are divisible by patch_size.

        (Note: "_256" for w / h should technically be renamed as "_ps",
        but may not be easier to read.
        Until I need to make HIPT with patch_sizes != 256,
        keeping the naming convention as-is.)

        Note: torchvision image tensors are laid out [C x H x W]. The "W" / "w_256"
        names used here refer to dim 2 of the batched tensor, i.e. what torchvision
        calls the height.

        Args:
            - img (torch.Tensor): [1 x C x W' x H'] image tensor.
            - patch_size (int): Desired patch size to evenly subdivide the image.
        Return:
            - img_new (torch.Tensor): [1 x C x W x H] image tensor, where W and H
            are divisible by patch_size.
            - w_256 (int): # of [256 x 256] patches of img_new's width (e.g. - W/256)
            - h_256 (int): # of [256 x 256] patches of img_new's height (e.g. - H/256)
        Raises:
            - ValueError: if W' or H' is smaller than patch_size, because no full
            patch would fit and there would be nothing to encode.
        """
        _, _, w, h = img.shape
        w_256, h_256 = w // patch_size, h // patch_size
        if w_256 == 0 or h_256 == 0:
            raise ValueError(
                f'Image spatial size ({w}, {h}) is smaller than patch_size={patch_size}; '
                'at least one full patch is required along each side.')
        # Largest multiples of patch_size that fit, so the cropped image tiles exactly.
        load_size = (w_256 * patch_size, h_256 * patch_size)
        img_new = transforms.CenterCrop(load_size)(img)

        return img_new, w_256, h_256

In [None]:
def eval_transforms(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """
    Build the evaluation-time preprocessing pipeline for HIPT inputs.

    `ToTensor` converts a PIL image / ndarray to a float [C x H x W] tensor
    (scaled to [0, 1] for 8-bit images). `Normalize` then applies
    (x - mean) / std per channel. With the default mean = std = 0.5, this maps
    [0, 1] to [-1, 1].

    Args:
        - mean (sequence of float): per-channel mean passed to `Normalize`.
        - std (sequence of float): per-channel std passed to `Normalize`.
    Return:
        - eval_t (transforms.Compose): the composed transform.
    """
    eval_t = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return eval_t

In [None]:
# Build the two-stage HIPT encoder in inference mode. nn.Module.eval() returns
# the module itself, so chaining it inside the assignment avoids dumping the
# full model repr as this cell's output.
model = HIPT_4K().eval()

In [None]:
# Load the demo 4K region. This assumes `Image` (from hipt_heatmap_utils) is PIL's
# Image module — TODO confirm.
# - convert('RGB') guarantees the 3 channels the model expects; a PNG may carry an alpha channel.
# - The context manager closes the file handle once the pixels have been copied out.
with Image.open('../HIPT_4K/image_demo/image_4k.png') as region_file:
    region = region_file.convert('RGB')

In [None]:
# Normalize the region and add a leading batch dimension of 1.
x = eval_transforms()(region).unsqueeze(dim=0)
# Fail fast on a malformed input. forward() expects a single [1 x 3 x W x H] image
# and feeds 3-channel [256 x 256] tiles to ViT_256.
assert x.dim() == 4 and x.shape[:2] == (1, 3), f'expected [1 x 3 x W x H], got {tuple(x.shape)}'
print(x.shape)

In [None]:
# Inference only:
# - no_grad stops autograd from recording a graph across every ViT_256 tile
#   pass and the ViT_4K pass, which saves memory.
# - Calling the module (rather than .forward) runs any registered hooks.
with torch.no_grad():
    out = model(x)