In [1]:
import torch
import torch.nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from boardGPT.models import GameAutoEncoder
from transformers import AutoTokenizer

In [2]:
# Fetch the pretrained Othello game auto-encoder together with its configuration.
model, model_config = GameAutoEncoder.from_pretrained(
    repo_id="theartificialis/Othello-Synthetic-AutoEncoder-20m-S",
)

In [3]:
# The tokenizer is stored in the "tokenizer" subfolder of the same model repository.
tokenizer = AutoTokenizer.from_pretrained(
    "theartificialis/Othello-Synthetic-AutoEncoder-20m-S",
    subfolder="tokenizer",
)

In [4]:
import faiss

# Dimensionality of the indexed vectors.
# NOTE(review): should equal the width of the encoder output -- confirm against enc.shape.
d = 64

# Flat (exact) L2 index used as the coarse quantizer of the IVF structure.
quantizer = faiss.IndexFlatL2(d)

# IVF + product-quantization index; the faiss constructor takes positional
# arguments only: (quantizer, d, nlist, M, nbits).
# NOTE(review): 8192 lists and 12-bit PQ codes call for far more training
# vectors than a few thousand -- confirm the training-set size is adequate.
index = faiss.IndexIVFPQ(
    quantizer,
    d,
    8192,  # nlist: number of inverted lists (coarse centroids)
    16,    # M: number of PQ sub-quantizers (d must be divisible by M)
    12,    # nbits: bits per sub-quantizer code
)

In [5]:
from boardGPT.datasets import GameDataset, collate_fn

# Value forwarded to GameDataset as num_samples.
n_samples = 10000

# Training split of the synthetic Othello data, read from a path relative to this notebook.
dataset = GameDataset(
    data_dir="../../data/othello/othello-synthetic",
    num_samples=n_samples,
    split="train",
)

# Ordered (unshuffled) loader that keeps the final partial batch.
# NOTE(review): a lambda collate_fn combined with num_workers > 0 may not be
# picklable under the "spawn" start method -- confirm on the target platform.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=lambda samples: collate_fn(samples, tokenizer),
)

In [6]:
# Number of batches the loader yields per pass over the dataset.
len(dataloader)

10

In [7]:
# Pull a single batch from the loader for inspection.
batch = next(iter(dataloader))
# Shape of the first element of the batch.
batch[0].shape

torch.Size([1024, 60])

In [8]:
# Encode the sample batch. This is inference only, so disable autograd to
# avoid building (and holding on to) a computation graph for the result.
with torch.no_grad():
    enc = model.encode_indices(batch[0])

In [9]:
# Shape of the encoded sample batch.
enc.shape

torch.Size([1024, 64])

In [10]:
# Encode every batch of the loader and stack the encodings row-wise into a
# single tensor (one row per sample) to be used as index training data.
encoded_batches = []
# Inference only: without no_grad each stored encoding would keep its autograd
# graph alive, and the concatenated tensor could require grad, which prevents
# converting it to NumPy later on.
with torch.no_grad():
    for batch in dataloader:
        x = batch[0]
        enc_x = model.encode_indices(x)
        # Store a detached CPU copy so accelerator memory is released per batch.
        encoded_batches.append(enc_x.detach().cpu())
    # end for
train_vectors = torch.cat(encoded_batches, dim=0)

In [12]:
# faiss consumes a C-contiguous float32 NumPy array of shape (n, d), not a
# torch tensor. Detach first (a tensor that requires grad cannot be converted),
# move to CPU, and enforce dtype/layout before handing it over.
train_matrix = train_vectors.detach().cpu().float().contiguous().numpy()
index.train(train_matrix)  # Train the coarse quantizer and the PQ codebooks

