In [2]:
import torch
import h5py
import time
import pickle
import numpy as np
from tqdm.notebook import trange, tqdm


# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model Definition

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle


class WolfPQEncoder(nn.Module):
    """Selects one codebook entry per sub-vector of an embedding.

    An embedding of size ``dim`` is viewed as ``M`` sub-vectors of size
    ``dim // M``. For every sub-vector the encoder returns a one-hot vector of
    length ``K`` that selects one of the ``K`` codebook entries of that
    sub-space.

    Args:
        dim: Size of the input embeddings. Must be divisible by ``M``.
        M: Number of sub-vectors each embedding is split into.
        K: Number of codebook entries per sub-vector.
        min_dist_factor: Weight applied to the one-hot "nearest entry" indicator
            before the argmax.
        semantic_sampling: When true, the scores of a small feed-forward network
            are blended with the nearest-entry indicator. When false, only the
            nearest-entry indicator is used and the network parameters are
            frozen.

    Raises:
        ValueError: If ``dim`` is not divisible by ``M``.
    """

    def __init__(self, dim, M, K, min_dist_factor, semantic_sampling) -> None:
        super(WolfPQEncoder, self).__init__()

        # Fail at construction time instead of with an opaque reshape error in forward().
        if dim % M != 0:
            raise ValueError(f'Embedding size {dim} must be divisible by M={M}')

        self.dim = dim
        self.M = M
        self.K = K
        self.min_dist_factor = min_dist_factor
        self.norm = nn.LayerNorm(dim)
        self.layer1 = nn.Linear(dim, M * K // 2)
        self.layer2 = nn.Linear(M * K // 2, M * K)
        # Raw blending weight; forward() squashes it into (0, 1) with a sigmoid.
        self.interp_param = nn.Parameter(torch.randn(1))

        self.semantic_sampling = semantic_sampling

        if not self.semantic_sampling:
            # The feed-forward path is never evaluated in this mode, so freeze it.
            self.layer1.requires_grad_(False)
            self.layer2.requires_grad_(False)
            self.interp_param.requires_grad_(False)

    def _compute_distance(self, x, codebook):
        """Return the squared Euclidean distance of each sub-vector to each codebook entry.

        Args:
            x: Tensor that is reshaped to ``(batch, M, 1, dim // M)``, where
                ``batch`` is ``x.shape[0]``.
            codebook: Tensor that is reshaped to ``(1, M, K, dim // M)``.

        Returns:
            Tensor of shape ``(batch, M, K)`` holding the squared distances.
        """
        sub_dim = self.dim // self.M
        x_reshaped = x.reshape(x.shape[0], self.M, 1, sub_dim)
        codebook_reshaped = codebook.reshape(1, self.M, self.K, sub_dim)

        # Broadcasting produces a (batch, M, K, sub_dim) difference tensor.
        return torch.sum((codebook_reshaped - x_reshaped) ** 2, dim=-1)

    def forward(self, x, codebook):
        """Compute the one-hot codebook selection for a batch of embeddings.

        Args:
            x: Batch of embeddings whose last dimension equals ``dim``.
            codebook: Codebook that can be reshaped to ``(1, M, K, dim // M)``.

        Returns:
            Integer one-hot tensor of shape ``(batch, M, K)``.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``dim``.
        """
        if x.shape[-1] != self.dim:
            raise ValueError(f'Expected embedding of size {self.dim}')

        dist = self._compute_distance(x, codebook)
        # One-hot indicator of the nearest codebook entry in every sub-space.
        min_dist = F.one_hot(torch.argmin(dist, dim=-1), self.K).to(dtype=torch.float)

        if self.semantic_sampling:
            # Propagate the vector through the feed-forward network.
            h = torch.tanh(self.layer1(self.norm(x)))
            a = F.relu(self.layer2(h))
            a_reshaped = a.reshape(-1, self.M, self.K)

            # Compute the interpolation parameter.
            interp = torch.sigmoid(self.interp_param)

            # Nearest-entry indicator corrected by the scores derived from the vector.
            logits = (1.0 - interp) * a_reshaped + interp * self.min_dist_factor * min_dist
        else:
            # Selection based on the minimum distance only.
            logits = self.min_dist_factor * min_dist

        # Deterministic selection. A stochastic variant would instead draw
        # F.gumbel_softmax(logits, hard=True, dim=-1).
        return F.one_hot(torch.argmax(logits, dim=-1), num_classes=self.K)



class WolfPQ(nn.Module):
    """Product quantizer with a learnable codebook and per-sub-space rotation.

    The codebook has shape ``(M, K, dim // M)``. Each of its ``M`` sub-space
    slices is multiplied by its own ``(dim // M, dim // M)`` matrix, initialised
    to the identity, before it is handed to the encoder.

    Args:
        dim: Size of the input embeddings.
        M: Number of sub-vectors each embedding is split into.
        K: Number of codebook entries per sub-vector.
        min_dist_factor: Forwarded to ``WolfPQEncoder``.
        semantic_sampling: Forwarded to ``WolfPQEncoder``.
        pq_index_path: Path to a pickled dict whose ``'codebook'`` entry is a
            numpy array used as the initial codebook, or ``None`` to start from
            a random codebook.
    """

    def __init__(self, dim, M, K, min_dist_factor, semantic_sampling: bool, pq_index_path: str) -> None:
        super(WolfPQ, self).__init__()

        self.dim = dim
        self.M = M
        self.K = K
        self.min_dist_factor = min_dist_factor
        # One identity matrix per sub-space, stacked to shape (M, dim // M, dim // M).
        self.rotation_matrix = nn.Parameter(torch.stack([torch.eye(dim // M) for _ in range(M)]))

        self.encoder = WolfPQEncoder(dim, M, K, min_dist_factor, semantic_sampling)

        if pq_index_path is None:
            self.codebook = nn.Parameter(torch.randn((M, K, dim // M)))
        else:
            # NOTE(security): pickle.load can execute arbitrary code; only pass
            # index files that come from a trusted source.
            with open(pq_index_path, 'rb') as index_file:
                pq_index = pickle.load(index_file)
            initial_codebook = torch.from_numpy(pq_index['codebook'])
            # Match the rotation matrix dtype so that `codebook @ rotation_matrix`
            # in forward() does not fail for, e.g., a float64 numpy codebook.
            self.codebook = nn.Parameter(initial_codebook.to(dtype=self.rotation_matrix.dtype))

    def forward(self, x):
        """Quantize a batch of embeddings.

        Args:
            x: Batch of embeddings whose last dimension equals ``dim``.

        Returns:
            A tuple ``(quantized, s)``: ``quantized`` is the reconstruction built
            from the selected codebook entries, flattened to two dimensions with
            the batch first, and ``s`` is the one-hot selection returned by the
            encoder.

        Raises:
            ValueError: If the last dimension of ``x`` is not ``dim``.
        """
        if x.shape[-1] != self.dim:
            raise ValueError(f'Expected embedding of size {self.dim}')

        # Apply the rotation matrices to the codebook (not to the input).
        codebook_rot = self.codebook @ self.rotation_matrix

        # Get the one-hot selection into the codebook based on the vector.
        s = self.encoder(x, codebook_rot)

        # Construct the quantized vector: mask the codebook with the one-hot
        # selection, sum over the K axis, then flatten the sub-vectors.
        selected = codebook_rot.unsqueeze(0) * s.unsqueeze(-1)
        sub_vectors = selected.sum(dim=2)
        quantized = sub_vectors.reshape(sub_vectors.shape[0], -1)

        return quantized, s

# Set up the quantizer path and other parameters

In [4]:
# --- Model parameters ---
# Checkpoint of the trained quantization model.
MODEL_PATH = './drive/MyDrive/MasterThesis/final_trained_models/tct_24_1024_listwise/tct_24_1024_sem_mdf30_listwise_epoch_1.pt'

# Arguments handed to the WolfPQ constructor.
DIM = 768                  # embedding size
M = 24                     # sub-vectors per embedding
K = 1024                   # codebook entries per sub-vector
MDF = 30.0                 # min_dist_factor
SEMANTIC_SAMPLING = True   # blend network scores with the nearest-entry indicator

In [5]:
# --- Dataset parameters ---
# HDF5 file that is read for the vectors and their document ids.
DATASET_PATH = './drive/MyDrive/MasterThesis/datasets/tct_colbert.h5'
# Destination of the pickled index object.
OUTPUT_PATH = './drive/MyDrive/MasterThesis/wolfpq_indices/final_indices/wolfpq_final_tct_24_1024_sem_mdf30_listwise_epoch_1_backup.pickle'
# Number of vectors pushed through the model per step.
BATCH_SIZE = 500

# Initialize the model

In [6]:
# Build the quantizer with a random codebook (pq_index_path=None); the trained
# weights are then restored from the checkpoint.
model = WolfPQ(
    dim=DIM,
    M=M,
    K=K,
    min_dist_factor=MDF,
    semantic_sampling=SEMANTIC_SAMPLING,
    pq_index_path=None,
)
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

WolfPQ(
  (encoder): WolfPQEncoder(
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layer1): Linear(in_features=768, out_features=12288, bias=True)
    (layer2): Linear(in_features=12288, out_features=24576, bias=True)
  )
)

In [7]:
# Move the model's parameters to the selected device.
model = model.to(device=DEVICE)

# Run the quantization process

In [8]:
with h5py.File(DATASET_PATH, 'r') as dataset:

    # Handle to the stored vectors; batches are sliced out of it inside the loop.
    vectors = dataset['vectors']
    vector_count = int(vectors.shape[0])
    dim = vectors.shape[1]

    # COMPUTE THE Q_CODES
    code_batches = []

    with torch.no_grad():

        # Stepping the range by BATCH_SIZE visits every batch start, so the
        # final, possibly shorter, batch is included without manual ceil-division.
        for index_start in trange(0, vector_count, BATCH_SIZE):
            index_end = min(index_start + BATCH_SIZE, vector_count)

            vector_batch = np.array(vectors[index_start:index_end, :])
            vector_batch_torch = torch.from_numpy(vector_batch).to(DEVICE)

            # The model returns (quantized vectors, one-hot codes); only the
            # codes are kept, reduced to the selected index per sub-vector.
            _, code_batch_torch = model(vector_batch_torch)
            code_batch = torch.argmax(code_batch_torch, dim=-1).cpu().numpy()

            code_batches.append(code_batch)

        # GET THE CODEBOOK (the rotated codebook, as used inside the model's forward pass)
        codebook = (model.codebook @ model.rotation_matrix).detach().cpu().numpy()

    # np.concatenate raises on an empty list, so an empty dataset gets an
    # explicit (0, M) code array instead of crashing.
    if code_batches:
        q_codes = np.concatenate(code_batches, axis=0)
    else:
        q_codes = np.empty((0, M), dtype=np.int64)

    # BUILD THE INDEX OBJECT
    index_obj = {
        'codebook': codebook,
        'quantized_index': q_codes,
        'M': M,
        'K': K,
        'vector_size': dim,
        'doc_ids': np.array(dataset['docids'][:])
    }

# SAVE THE INDEX OBJECT
# The context manager guarantees the file is flushed and closed once the dump finishes.
with open(OUTPUT_PATH, 'wb') as output_file:
    pickle.dump(index_obj, output_file)


  0%|          | 0/17684 [00:00<?, ?it/s]

In [25]:
# Re-save the index object. This relies on `index_obj` still being defined by
# the quantization cell, so that cell must have been run first in this kernel.
# The context manager guarantees the file is flushed and closed.
with open(OUTPUT_PATH, 'wb') as output_file:
    pickle.dump(index_obj, output_file)