In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch

from generators import TopNGenerator
from modules import MLPEncoder, MLPDecoder, TransformerEncoder, TransformerDecoder, GeneratorMLP
from orbit_dataset import get_datasets

from model import AutoEncoderModel
from utils import ChamferLoss, HungarianLoss, HungarianLossCustom, ChamferLossWeighted, HungarianLossDimensionMatching

from torch.optim import Adam, AdamW

from torch.nn import CrossEntropyLoss, MSELoss

from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau, SequentialLR

In [2]:
dataset_train, dataset_test, dataloader_train, dataloader_test, n_max = get_datasets(1000, 64, 3500, 1500)

Rips(maxdim=1, thresh=inf, coeff=2, do_cocycles=False, n_perm = None, verbose=True)


  0%|          | 0/5000 [00:00<?, ?it/s]

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

n_epochs = 300
lr = 0.001
set_channels = 3

criterion = HungarianLoss()

#encoder = TransformerEncoder(n_in=set_channels, embed_dim=64, fc_dim=128, num_heads=8, num_layers=5, n_out_enc=16)
#generator = TopNGenerator(set_channels=8, cosine_channels=8, max_n=n_max + 10, latent_dim=16)
#decoder = GeneratorMLP(n_in=8, n_hidden=128, n_out=3, num_layers=4, generator=generator)

encoder = MLPEncoder(n_in=3, n_hidden=64, n_out=16, num_layers=4)
decoder = MLPDecoder(n_in=16, n_hidden=64, n_out=n_max * set_channels, num_layers=4, set_channels=set_channels)

model = AutoEncoderModel(encoder, decoder).to(device)

optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler1 = LinearLR(optimizer, start_factor=0.01, total_iters=15)
scheduler2 = ReduceLROnPlateau(optimizer, patience=10, min_lr=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[15])

In [10]:
for epoch_idx in range(15, n_epochs):
        
    # train
    model.train()
        
    loss = 0
    for batch in tqdm(dataloader_train):
        src = batch[0].to(device)
        mask = batch[1].to(device)
        tgt = model(src, mask)
        loss_batch = criterion(src.to(torch.float), tgt)
        loss_batch.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss += loss_batch.detach().cpu()
        
    loss_train = loss / len(dataloader_train.dataset)
        
    # test
    model.eval()
    loss = 0
    for batch in tqdm(dataloader_test):
        src = batch[0].to(device)
        mask = batch[1].to(device)
        
        with torch.no_grad():
            tgt = model(src, mask)
            loss_batch = criterion(src.to(torch.float), tgt)
            loss += loss_batch
            
            
    loss_test = loss / len(dataloader_test.dataset)
    if epoch_idx < 15:
        scheduler.step()
    else:
        scheduler.step(loss_test)
        
    print("Epoch: {:3} {:.10f} {:.10f}".format(epoch_idx, loss_train, loss_test))

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

TypeError: step() takes 1 positional argument but 2 were given