# Data

In [None]:
from dataset import GTSequenceDataset
from torch.utils.data import DataLoader

# Sequence-length settings forwarded to GTSequenceDataset.from_roots
# (seq_in_len / seq_out_len / seq_total_len).
SEQ_IN_LEN = 10
SEQ_OUT_LEN = 10
SEQ_TOTAL_LEN = 20
BATCH_SIZE = 128

BASE_DIR = '../../Datasets/'
# Benchmarks combined into each split; every name contributes '<name>/train'
# to the training roots and '<name>/val' to the validation roots.
DATASET_NAMES = ['DanceTrack', 'MOT17', 'MOT20']

train_roots = [f'{BASE_DIR}{name}/train' for name in DATASET_NAMES]
val_roots = [f'{BASE_DIR}{name}/val' for name in DATASET_NAMES]
seq_kwargs = dict(seq_in_len=SEQ_IN_LEN, seq_out_len=SEQ_OUT_LEN, seq_total_len=SEQ_TOTAL_LEN)

train_dataset = GTSequenceDataset.from_roots(train_roots, **seq_kwargs)

# NOTE(review): noise_prob / noise_coeff are passed for the *validation* split only;
# the training split relies on from_roots' defaults for these. If the noise is meant
# as augmentation it likely belongs on train, and a noisy validation set also makes
# the val-loss-based checkpoint selection stochastic. Confirm this is intended.
val_dataset = GTSequenceDataset.from_roots(val_roots, **seq_kwargs, noise_prob=0.6, noise_coeff=2)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation does not need shuffling: a fixed order keeps the batch composition
# (including the final partial batch) identical across epochs, so val loss is
# comparable from one epoch to the next.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f'Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}')

# Model

In [None]:
from transformer_encoder import MotionTransformer
from loss import LossFunction
from torch import optim

import torch

# Use the GPU when available; fall back to CPU instead of failing at .to('cuda').
# The chosen device is printed so a fallback is never silent.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {DEVICE}')

# Alternative configurations, left disabled:
#   MotionTransformer(num_enc_layers=1, num_dec_layers=1, dim_ff=64, d_model=32, dropout=0, nhead=4)
#   LSTMPredictor(middle_dim=64, hidden_dim=256, num_layers=1)
#     (LSTMPredictor is not imported in the cells shown; import it before enabling.)
model = MotionTransformer(d_model=128, dim_ff=256).to(DEVICE)
criterion = LossFunction()

# Train

In [None]:
import os

LR = 2e-3
NUM_EPOCHS = 40
# Destination for the weights with the lowest validation loss seen so far.
CHECKPOINT_PATH = 'pretrained/transformer-encoder-d128-ff256-3l.pth'

optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# scheduler.step() is called once per epoch (NUM_EPOCHS times in total). Because
# T_max is one larger, the cosine curve does not reach its minimum (eta_min,
# default 0) within this run.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS + 1)

# Create the checkpoint directory up front, in case save_weight does not create
# parent directories itself.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

best_val_loss = float("inf")

for epoch in range(1, NUM_EPOCHS + 1):
    # Capture the LR actually used for this epoch *before* the scheduler
    # advances it; reading it after scheduler.step() would report the next
    # epoch's value.
    epoch_lr = optimizer.param_groups[0]['lr']

    train_loss = model.train_one_epoch(train_loader, optimizer, criterion)
    val_loss = model.evaluate(val_loader, criterion)

    scheduler.step()

    # Keep only the checkpoint with the lowest validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model.save_weight(CHECKPOINT_PATH)

    print(f"Epoch {epoch}: Train Loss = {train_loss:.8f}, Val Loss = {val_loss:.8f}, LR = {epoch_lr:.8f}")

print("Training complete. Best Val Loss:", best_val_loss)