## Approaches on how to represent text embedding data

* Pad and transpose the embeddings so that:
    1. No information is lost
    2. Spatial integrity is maintained for the most part
    3. Caveat: this results in a huge, sparse (mostly zero-padded) matrix

In [70]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt, floor, ceil

In [127]:
def isPerfectSquare(n):
    """
    Checks whether ``n`` is a perfect square (0, 1, 4, 9, ...).

    The check is done with exact integer arithmetic, so it stays correct
    for integers too large to be represented exactly as a float.

    Args:
        n: An integer, or a float holding an integer value (e.g. ``16.0``).

    Returns:
        True if some non-negative integer ``r`` satisfies ``r * r == n``,
        otherwise False. Negative and non-integral values are never
        perfect squares, so they return False.
    """
    # Local import: `isqrt` is not among the names imported at the top of
    # the notebook (only sqrt, floor and ceil are).
    from math import isqrt

    if n < 0:
        return False
    if isinstance(n, float):
        # isqrt only accepts integers; keep supporting values such as 16.0.
        if not n.is_integer():
            return False
        n = int(n)

    root = isqrt(n)
    return root * root == n
    
def nextPerfectSquare(num):
    """
    Gets you the next perfect square from a given number

    "Next" is inclusive: if ``num`` is already a perfect square it is
    returned unchanged. The result is computed directly from the integer
    square root rather than by scanning candidates, so it works for any
    size of input (there is no search window that can be exceeded).

    Args:
        num: Starting value. Non-integral values are rounded up first;
            negative values yield 0, the smallest perfect square.

    Returns:
        The smallest perfect square that is >= ``num``, as an int.
        A short message announcing the result is also printed.
    """
    # Local import: `isqrt` is not among the names imported at the top of
    # the notebook.
    from math import ceil, isqrt

    # Smallest non-negative integer that is >= num.
    start = max(ceil(num), 0)

    root = isqrt(start)
    if root * root < start:
        # start lies strictly between root**2 and (root + 1)**2.
        root += 1

    square = root * root
    print(f"{square} is the next perfect square")
    return square

In [5]:
def load_data(
    img_path="working/image_embeddings_large.npy",
    text_path="working/text_embeddings_large.npy",
):
    """
    Loads the embeddings saved in npy format into memory and converts them into torch tensors

    Args:
        img_path: Path of the ``.npy`` file holding the image embeddings.
        text_path: Path of the ``.npy`` file holding the text embeddings.

    Returns:
        A tuple ``(img_embeddings, text_embeddings)``. Each is a list with
        one torch tensor per entry along the first axis of the loaded array.
    """
    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # and unpickling can execute arbitrary code. Only load trusted files.
    # (Presumably needed because the embeddings are stored as an object
    # array of variable-length arrays -- confirm against the writer.)
    img_emb = np.load(img_path, allow_pickle=True)
    text_emb = np.load(text_path, allow_pickle=True)

    # torch.from_numpy shares memory with each numpy array (no copy).
    img_embeddings = [torch.from_numpy(np_arr) for np_arr in img_emb]
    text_embeddings = [torch.from_numpy(np_arr) for np_arr in text_emb]

    return img_embeddings, text_embeddings

# Load both embedding sets once for the rest of the notebook; each one is a
# list of torch tensors (see load_data).
img_emb, text_emb = load_data()

In [10]:
# Inspect the shape of the first text embedding.
text_emb[0].shape

torch.Size([7, 768])

In [12]:
# Peek at the raw values of the first text embedding.
text_emb[0]

tensor([[ 0.0580, -0.2076, -0.1875,  ..., -0.0379,  0.3235,  0.1966],
        [ 0.4263, -0.3844, -0.4158,  ..., -0.1270,  0.5972, -0.0761],
        [-0.1596, -0.1977, -0.0715,  ..., -0.3146,  0.0337, -0.5706],
        ...,
        [ 0.6532, -0.1045, -0.3341,  ..., -0.2913, -0.0313, -0.1788],
        [ 0.1402, -0.7515, -0.4805,  ...,  0.2352,  0.0075, -0.6129],
        [ 0.8141,  0.0583, -0.4363,  ..., -0.0272, -0.3467, -0.4803]])

In [46]:
def num_patches(image_shape, patch_shape):
    """
    Computes how many square patches of side ``patch_shape`` cover an image.

    Args:
        image_shape: Sequence whose first two entries are the image height
            and width (e.g. ``[H, W, C]``, channels last). Entries after the
            first two are ignored.
        patch_shape: Side length of one square patch.

    Returns:
        Image area divided by patch area. True division is used, so the
        result is a float and is fractional when the patches do not tile
        the image exactly.
    """
    height, width = image_shape[0], image_shape[1]
    patch_area = patch_shape ** 2
    return (height * width) / patch_area

In [130]:
# Patch count for a 768x768 image cut into 32x32 patches. Only height and
# width are passed here; num_patches reads just the first two entries.
num_patches(
    image_shape=(768, 768),
    patch_shape=32
)

576.0

In [83]:
# Take the first text embedding as the working example and count its
# total number of scalar elements.
a = text_emb[0]
a.numel()

5376

In [124]:
# Try padding
# a.T swaps the two dims of `a`. A 2-element `pad` only affects the last dim:
# 1 column is added on the left and (rows - cols - 1) on the right, so the
# width becomes 1 + cols + (rows - cols - 1) == rows and the result is square.
# F.pad's default mode is constant padding with zeros.
a_new = F.pad(a.T, pad=(1, a.T.shape[0] - a.T.shape[1] - 1))
print(a_new.shape, a_new.numel())

torch.Size([768, 768]) 589824
