In [None]:
import torch
import numpy as np
from models.vqvae import VQVAE
from masked_transformer import *
from torch.utils.data import DataLoader
from pipeline_utils import MRITokenDataset

In [None]:
# Choose the compute device first; falls back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trained VQ-VAE checkpoint. The file name appears to record the same
# hyperparameters that are passed to the constructor below.
model_path = '/home/mingjie/mri230/vqvae_checkpoints/newloss_reshiddens32_n_embeddings64_embed_dim16.pth'

# The checkpoint is deserialized onto the CPU and moved to `device` afterwards.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so this
# must only ever be pointed at a trusted checkpoint file.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

# Rebuild the architecture, then restore the weights stored under "model".
model = VQVAE(
    h_dim=128,
    res_h_dim=32,
    n_res_layers=2,
    n_embeddings=64,
    embedding_dim=16,
    beta=0.25,
)
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()

# Raw training array; it is sliced and converted to a tensor in a later cell.
data = np.load("/home/mingjie/mri230/train_data/train_data.npy")

In [None]:
# Encode the first `n` training slices with the VQ-VAE and inspect the shape
# of the resulting token array.
n = 1

# torch.as_tensor accepts both the NumPy array loaded earlier (first run) and
# the tensor this cell leaves in `data` (re-run), so the cell can be executed
# again without the TypeError that torch.from_numpy raises on a Tensor.
# NOTE: `data` is still overwritten with the n-slice batch, so reload the
# array before increasing `n`.
data = torch.as_tensor(data[:n]).float().to(device)   # shape (n, 1, 256, 256)

# Forward pass only; gradient tracking is disabled.
with torch.no_grad():
    # Presumably the fourth return value holds the codebook indices — TODO confirm
    # against VQVAE.forward.
    _, _, _, tokens = model(data)

# Bring the result back to host memory as a NumPy array for inspection.
tokens = tokens.cpu().numpy()
print(tokens.shape)

In [None]:
# Settings shared by all three masked-token datasets.
context_slices = 3
mask_prob = 0.25
checkpoint_name = 'newloss_reshiddens32_n_embeddings64_embed_dim16.pth'

# `tokenize` yields three token sequences, unpacked as train / val / test.
train_token_seq, val_token_seq, test_token_seq = tokenize(
    checkpoint_name,
    res_h_dim=32,
    embedding_dim=16,
)

# Build one dataset per split, in train -> val -> test order, with identical
# context and masking settings.
train_dataset, val_dataset, test_dataset = (
    MRITokenDataset(
        tokens=split_tokens,
        context_slices=context_slices,
        mask_prob=mask_prob,
    )
    for split_tokens in (train_token_seq, val_token_seq, test_token_seq)
)

In [None]:
# Loader settings for a quick sanity check of the training dataset.
num_workers = 4
batch_size = 5

# shuffle=False keeps the sample order fixed, so the inspected batch is the
# same on every run.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=True,
)

# Pull a single batch and show the shape of its labels. The iterator is kept
# as a temporary so it is released right after the batch is fetched.
x, labels = next(iter(train_dataloader))
print(labels.shape)

In [None]:
def _summarize_tokens(token_seq):
    """Return ``(lowest, highest, n_unique)`` for one token sequence.

    ``np.unique`` returns the sorted distinct values of the flattened input,
    so its first and last entries are the minimum and maximum. Computing it
    once replaces the element-by-element builtin ``min``/``max`` scans and the
    separate ``np.unique`` call.

    Returns ``(None, None, 0)`` when the sequence holds no tokens.
    """
    distinct = np.unique(np.asarray(token_seq))
    if distinct.size == 0:
        return None, None, 0
    return distinct[0], distinct[-1], int(distinct.size)


def _report_token_stats(splits):
    """Print the token range of every split, then the unique-token counts.

    ``splits`` maps a split name to its token sequence. The output keeps the
    grouping of the original cell: all range lines first, then all count lines.
    """
    stats = {name: _summarize_tokens(seq) for name, seq in splits.items()}

    for name, (lowest, highest, n_unique) in stats.items():
        if n_unique == 0:
            # Builtin min()/max() would raise on an empty sequence; report it instead.
            print(f"{name} token range: (no tokens)")
        else:
            print(f"{name} token range: {lowest} to {highest}")

    for name, (lowest, highest, n_unique) in stats.items():
        print(f"{name} unique tokens used: {n_unique}")


_report_token_stats({
    "train": train_token_seq,
    "val": val_token_seq,
    "test": test_token_seq,
})