In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch

import generators
import modules
from orbit_dataset import get_datasets

from torch.nn import MSELoss

from model import AutoEncoderModel, PDNetOrbit5k
from utils import ChamferLoss, HungarianLoss, SlicedWasserstein

from torch.optim import Adam, AdamW

from torch.nn import MSELoss

from model_train import train_epoch_ae, train_epoch_full

from calculate_metrics import pd_to_pd_ae_metrics, orbit5k_metrics

from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau, SequentialLR

In [2]:
# Build the train/test datasets and their loaders in one call.
# NOTE(review): the positional arguments are not named here. The captured output
# (1000 items processed, 22 train batches, 10 test batches) is consistent with
# 700 train / 300 test samples at batch size 32 -- confirm against the
# get_datasets signature, including the meaning of the leading 200.
(
    dataset_train,
    dataset_test,
    dataloader_train,
    dataloader_test,
    n_max,  # forwarded to the set generator as max_n in the model-setup cell
) = get_datasets(200, 32, 700, 300)

Rips(maxdim=1, thresh=inf, coeff=2, do_cocycles=False, n_perm = None, verbose=True)


  0%|          | 0/1000 [00:00<?, ?it/s]

In [3]:
# Run on the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
lr = 0.001
n_points = 300
set_channels = 3
warmup_iters = 2                # epochs driven by the LinearLR warm-up scheduler
n_epochs = 1 + warmup_iters     # warm-up epochs plus one epoch after warm-up

# Losses. `criterion` and `crit_chamfer` are the ones handed to train_epoch_full
# in the training cell; the other two are kept available for experiments.
criterion = SlicedWasserstein(n_projections=2)
crit_chamfer = ChamferLoss()
criterion_hungarian = HungarianLoss()
mse = MSELoss(reduction='mean')

# model init
# Persistence-diagram branch: encoder, set generator, transformer decoder.
encoder_pd = modules.MLPEncoder(3, 64, 64, 3)
generator = generators.TopNGenerator(set_channels=set_channels, cosine_channels=32, max_n=n_max, latent_dim=64)
decoder_pd = modules.TransformerDecoder(n_in=3, latent_dim=64, fc_dim=32, num_heads=2, num_layers=2, n_out=3,
                                        generator=generator, n_out_lin=32, n_hidden=32, num_layers_lin=2,
                                        dropout=0.1, use_conv=True)

# Data branch.
# NOTE(review): n_out = n_points * 2 presumably means two coordinates per point -- confirm.
encoder_data = modules.MLPEncoder(2, 64, 64, 3)
decoder_data = modules.MLPDecoder(n_in=64, n_hidden=64, n_out=n_points * 2, num_layers=1, set_channels=2)

model = PDNetOrbit5k(encoder_data, decoder_data, encoder_pd, decoder_pd).to(device)

# A single optimizer is created. The previous version first built an Adam
# optimizer plus a ReduceLROnPlateau bound to it and then immediately replaced
# the optimizer with AdamW, leaving that scheduler attached to a discarded
# optimizer (stepping it could never change the learning rate used in training).
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

# Linear warm-up: the LR factor ramps from start_factor up to 1 over
# `warmup_iters` scheduler steps. The plateau scheduler that takes over after
# warm-up is created inside the training loop.
scheduler1 = LinearLR(optimizer, start_factor=1e-7, total_iters=warmup_iters)

In [4]:
# Pretrained classifier that the training cell passes to orbit5k_metrics.
model_classificator = modules.CustomPersformer(
    n_in=3,
    embed_dim=64,
    fc_dim=128,
    num_heads=4,
    num_layers=5,
    n_out_enc=5,
    dropout=0.0,
    reduction="attention",
    use_skip=False,
)
model_classificator = model_classificator.to(device)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. map_location places the loaded tensors on the active device.
checkpoint = torch.load(
    "pretrained_models/persformer_orbit5k_77_test_acc_only_one_dim.pt",
    map_location=device,
)

# Left as the last expression so the key-matching report is displayed.
model_classificator.load_state_dict(checkpoint)

<All keys matched successfully>

In [5]:
# Per-epoch loss history (reset every time this cell is run).
loss_train = []
loss_test = []

for epoch_idx in range(n_epochs):
    # Run one epoch; train_epoch_full returns the train and test losses.
    train_loss, test_loss = train_epoch_full(
        model,
        dataloader_train,
        dataloader_test,
        optimizer,
        criterion,
        crit_chamfer,
        progress=False,
    )
    loss_train.append(train_loss)
    loss_test.append(test_loss)

    # Evaluate with the pretrained classifier and print the three metrics.
    test_acc_approx, w2, chamfer = orbit5k_metrics(
        model, model_classificator, dataloader_train, dataloader_test
    )
    print(test_acc_approx, w2, chamfer)

    # Learning-rate schedule: linear warm-up for the first `warmup_iters`
    # epochs, then a plateau scheduler driven by the test loss.
    if epoch_idx >= warmup_iters:
        if epoch_idx == warmup_iters:
            # Created on the first post-warm-up epoch only.
            scheduler2 = ReduceLROnPlateau(optimizer, patience=20, min_lr=1e-6)
        scheduler2.step(test_loss)
    else:
        scheduler1.step()

100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [00:04<00:00,  4.81it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 11.56it/s]


0.24666666666666667 0.03777184337377548 0.007398912683129311


100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [00:04<00:00,  5.26it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 12.00it/s]


0.24666666666666667 0.0016885397490113974 0.0025124719832092524


100%|██████████████████████████████████████████████████████████████████████████████████| 22/22 [00:03<00:00,  5.66it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 12.12it/s]


0.19666666666666666 0.0008151328074745834 0.0012513446854427457
