In [None]:
# Enable autoreload so that edits to imported project modules are picked up
# without restarting the kernel (mode 2 reloads modules before each execution).
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from vector_quantize_pytorch import VectorQuantize
import os
from data.neural_field_datasets_shapenet import ShapeNetDataset, FlattenTransform3D, TokenTransform3D, ModelTransform3DFromTokens, ModelTransform3D
from training import training_nano_gpt
from utils.visualization3d import visualize_model3d, model_to_mesh

from networks.nano_gpt import GPTConfig

# Show whether a CUDA device is visible; later cells move the quantizer and
# models to "cuda" explicitly.
torch.cuda.is_available()

In [None]:
# Options dictionary; no other visible cell reads it.
kwargs = dict(type="pretrained", fixed_label=None)

# Data root is "<parent of the working directory>/adl4cv".
dir_path = os.path.dirname(os.path.abspath(os.getcwd()))
data_root = os.path.join(dir_path, "adl4cv")

# Restore the vector quantizer that was used to tokenize the weights.
vq_checkpoint_path = os.path.join(
    data_root,
    "models",
    "vq_search_results",
    "vq_model_dim_1_vocab_255_batch_size_32768_threshold_ema_dead_code_0_kmean_iters_0.pth",
)
vq_dicts = torch.load(vq_checkpoint_path)

# The checkpoint carries both the constructor arguments and the learned state.
vq = VectorQuantize(**vq_dicts["vq_config"])
vq.load_state_dict(vq_dicts["state_dict"])
vq.to("cuda")

In [None]:
# All three dataset views read the same "plane_mlp_weights" folder and differ
# only in the transform applied to each item.
plane_weights_dir = os.path.join(data_root, "datasets", "plane_mlp_weights")

dataset = ShapeNetDataset(plane_weights_dir, transform=TokenTransform3D(vq))
dataset_flatten = ShapeNetDataset(plane_weights_dir, transform=FlattenTransform3D())
dataset_model = ShapeNetDataset(plane_weights_dir, transform=ModelTransform3D())

# Used further down to turn a token sequence back into a model.
backtransform = ModelTransform3DFromTokens(vq)

In [None]:
# Rank every checkpoint in models/vq_search_results by its final recorded loss.


dir_path = os.path.dirname(os.path.abspath(os.getcwd()))
data_root = os.path.join(dir_path, "adl4cv")

test_idx = 14


def rank_vq_checkpoints(search_dir):
    """Collect the final loss of every ``.pth`` checkpoint below ``search_dir``.

    Each file name is stored together with its own loss, so files that are not
    checkpoints cannot shift the pairing, and every checkpoint is loaded from
    the directory it was actually found in.

    Args:
        search_dir: Directory that is walked recursively.

    Returns:
        List of ``(file_name, final_loss)`` tuples sorted by ascending loss.
        Empty when the directory is missing or holds no ``.pth`` file.
    """
    results = []
    for root, _dirs, files in os.walk(search_dir):
        for file in files:
            if not file.endswith('.pth'):
                continue
            vq_dict = torch.load(os.path.join(root, file))
            # The last entry of the recorded loss history is the final loss.
            results.append((file, vq_dict["loss"][-1]))
    # list.sort is stable, so ties keep their discovery order.
    results.sort(key=lambda item: item[1])
    return results


ranked_checkpoints = rank_vq_checkpoints(os.path.join(data_root, "models", "vq_search_results"))

# Report the n configurations with the lowest final loss.
n = 5
for file, loss in ranked_checkpoints[:n]:
    print(f"Found best configuration {file} with loss {loss} from {len(ranked_checkpoints)} configurations")

In [None]:
# Mesh of item 0 as loaded with ModelTransform3D (no quantizer involved).
reference_model = dataset_model[0][0].cuda()
mesh, sdf = model_to_mesh(reference_model)
mesh.show()

In [None]:
# Mesh of item 0 of the tokenized dataset after mapping its tokens back to a model.
sample_tokens = dataset[0][0].cuda()
reconstructed_model = backtransform(sample_tokens, None)[0]
mesh, sdf = model_to_mesh(reconstructed_model.cuda())
mesh.show()

In [None]:
# Training configuration
config = training_nano_gpt.Config()

# Run length, optimiser and batching
config.max_iters = 30000
config.learning_rate = 3e-3
config.weight_decay = 0.00
config.batch_size = 8
config.gradient_accumulation_steps = 1

# Learning-rate schedule: lr_decay_iters is tied to max_iters and warmup_iters
# is 5% of max_iters (note that this product is a float).
config.decay_lr = True
config.lr_decay_iters = config.max_iters
config.warmup_iters = 0.05 * config.max_iters

# Initialisation, output location and experiment tracking
config.init_from = "scratch"
config.out_dir = "models/shapenet_token_transformer"
config.detailed_folder = "training_sample_5"
config.wandb_project = "shapenet_token_transformer"

# Evaluation and metric cadence (in iterations)
config.eval_interval = 250
config.metric_interval = 250

# Model configuration. The vocabulary gets one slot beyond the codebook size,
# and max_len is the token count of the first dataset item.
max_len = dataset[0][0].size(0)
vocab_size = vq_dicts["vq_config"]["codebook_size"] + 1
model_config = GPTConfig(
    n_embd=64,
    block_size=512,
    n_head=8,
    n_layer=8,
    vocab_size=vocab_size,
    dropout=0.0,
    max_len=max_len,
)


In [None]:
# The start-of-sequence token takes the id equal to the codebook size, i.e. the
# extra vocabulary slot reserved by `codebook_size + 1` in the model config.
cb_size = vq_dicts["vq_config"]["codebook_size"]
token_dict = dict(SOS=cb_size)


In [None]:
# TODO: decide where the split and batching helpers below should live.
# One option is to let the dataset handle the train/val split itself and to rewrite
# TokenTransform so that batching can be done by a PyTorch DataLoader
# (get_batch would then correspond to the DataLoader's __call__).

def create_split_indices(n, train_ratio=0.9):
    """Randomly partition the indices ``0 .. n-1`` into training and validation sets.

    Args:
        n: Number of items to split.
        train_ratio: Fraction of the indices assigned to training; the count is
            truncated via ``int(train_ratio * n)``.

    Returns:
        Tuple ``(train_indices, val_indices)`` of 1-D index tensors. Both are
        slices of one random permutation, so they are disjoint and together
        contain every index exactly once.
    """
    n_train = int(train_ratio * n)
    # A single permutation is drawn; its head is the training part, its tail validation.
    permutation = torch.randperm(n)
    return permutation[:n_train], permutation[n_train:]

# Notebook-level split that get_batch_lambda reads as globals. No seed is set in
# the visible cells, so the split presumably differs between kernel sessions.
train_indices, val_indices = create_split_indices(len(dataset))

def get_batch_lambda(config, dataset, model_config, split):
    """Sample a random batch of fixed-length token windows for next-token training.

    Relies on the notebook-level ``train_indices`` / ``val_indices`` (index
    tensors of the two splits) and ``token_dict`` (provides the ``"SOS"`` id).

    Args:
        config: Training configuration; ``batch_size`` and ``device`` are read.
        dataset: Indexable collection of ``(token_sequence, label)`` pairs whose
            token sequences are 1-D tensors.
        model_config: Model configuration; ``block_size`` is the window length.
        split: Either ``'train'`` or ``'val'``.

    Returns:
        Tuple ``(x, y, offsets)`` moved to ``config.device``. ``x`` and ``y`` are
        ``(batch_size, block_size)`` long tensors, with ``y`` holding the tokens
        one position after those in ``x``. ``offsets`` is the start position of
        each window inside its SOS-prefixed sequence.

    Raises:
        ValueError: If ``split`` is unknown, or if the sequences are too short to
            hold a single window of ``block_size`` tokens.
    """
    batch_size = config.batch_size
    block_size = model_config.block_size

    # Pick the index pool for the requested split; fail loudly on anything else
    # instead of hitting an unbound variable further down.
    if split == 'train':
        pool = train_indices
    elif split == 'val':
        pool = val_indices
    else:
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'val'")

    # Draw batch_size dataset indices (with replacement) from the pool.
    indices = pool[torch.randint(0, len(pool), (batch_size,))]

    # Prefix every token sequence with the start-of-sequence token.
    samples = []
    for idx in indices:
        sample, _label = dataset[idx]
        start_tokens = torch.tensor([token_dict["SOS"]], dtype=sample.dtype, device=sample.device)
        samples.append(torch.cat((start_tokens, sample), dim=0))

    # Build shifted input/target pairs. The width follows the longest sample so
    # a shorter first sample cannot cause a shape mismatch; shorter sequences
    # are right-padded with token id 0.
    seq_len = max(sample.size(0) for sample in samples) - 1
    x = torch.zeros((batch_size, seq_len), dtype=torch.long)
    y = torch.zeros((batch_size, seq_len), dtype=torch.long)
    for i, sample in enumerate(samples):
        end_index = sample.size(0) - 1
        x[i, :end_index] = sample[:-1]  # Exclude the last token for x
        y[i, :end_index] = sample[1:]   # Exclude the first token for y

    if seq_len < block_size:
        raise ValueError(
            f"Sequence length {seq_len} is shorter than block_size {block_size}"
        )

    # torch.randint excludes its upper bound, so add 1 to make the last valid
    # window start (seq_len - block_size) reachable. This also covers the case
    # where the sequence is exactly block_size long.
    idx = torch.randint(0, seq_len - block_size + 1, (batch_size,))

    # Cut one window per row at its sampled offset.
    offsets = idx.tolist()
    x_cutted = torch.stack([x[i, start:start + block_size] for i, start in enumerate(offsets)])
    y_cutted = torch.stack([y[i, start:start + block_size] for i, start in enumerate(offsets)])

    # Inputs, targets and offsets all have to live on the training device.
    x_cutted = x_cutted.to(config.device)
    y_cutted = y_cutted.to(config.device)
    idx = idx.to(config.device)

    return x_cutted, y_cutted, idx

def create_get_batch(config, dataset, model_config):
    """Bind config, dataset and model_config so the result only needs a split name."""
    def get_batch_for_split(split):
        return get_batch_lambda(config, dataset, model_config, split)
    return get_batch_for_split

# Batch sampler handed to the training loop; call it as get_batch('train') or get_batch('val').
get_batch = create_get_batch(config, dataset, model_config)

In [None]:
# Prepare model parameters and train.
# wandb is imported here but not referenced directly in this cell.
import wandb
# The quantizer and its config are handed to train() together with the token ids.
trained_model = training_nano_gpt.train(get_batch, config, model_config, vq, vq_dicts["vq_config"], token_dict=token_dict)


In [None]:
import matplotlib.pyplot as plt
import torch
from networks.nano_gpt import GPT
from utils import get_default_device

# Restore the transformer and its quantizer from the saved checkpoint.
device = get_default_device()
model_dict = torch.load("./models/shapenet_token_transformer/ckpt.pt")
print(model_dict.keys())
idx = 3

# Transformer: built from the entry stored under "c" (presumably the model
# config), moved to the device, then given the saved weights.
model = GPT(model_dict["c"])
model.to(device=device)
model.load_state_dict(model_dict["model"])
model.eval()

# Quantizer saved with the checkpoint; this rebinds the notebook-level `vq`.
vq = VectorQuantize(**model_dict["vq_config"])
vq.load_state_dict(model_dict["vq_state_dict"])
vq.to(device=device)
vq.eval()

In [None]:
model.eval()

# Time the generation of one sequence of max_len new tokens.
import time

# The prompt is a single SOS token. It is created on `device`, the same device
# the model was moved to, instead of a hard-coded "cuda".
start_sequence = torch.tensor([[token_dict["SOS"]]], dtype=torch.long, device=device)

# NOTE(review): with top_k=1 only one candidate survives, so the temperature
# presumably has no effect - confirm against GPT.generate.
start = time.perf_counter()  # monotonic clock, intended for measuring intervals
novel_tokens = model.generate(start_sequence, max_len, temperature=20, top_k=1)
if torch.cuda.is_available():
    # CUDA kernels run asynchronously; wait for them before reading the clock.
    torch.cuda.synchronize()
print(f"Time for generation: {time.perf_counter() - start}")


In [None]:
# Drop the first position (the SOS prompt, assuming generate returns the prompt
# followed by the new tokens) and map the remaining tokens back to a model.
tokens = novel_tokens.squeeze(0)[1:]
token_decoder = ModelTransform3DFromTokens(vq)
mlp_model, label = token_decoder(tokens.detach(), None)

In [None]:
# Mesh of the model decoded from the generated tokens.
generated_model = mlp_model.cuda()
mesh, sdf = model_to_mesh(generated_model)
mesh.show()