In [1]:
import time, math, gc, random
import pandas as pd

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from trim_transformer.transformer_layers import TrimTransformerEncoderLayer, TrimTransformerEncoder
from doublependulum_loader import load_double_pendulum_data
import matplotlib.pyplot as plt
import os

# Seed Python, NumPy and PyTorch with one shared value.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Use the GPU when one is available (seeding every CUDA device too), otherwise the CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Directory holding the pre-generated double-pendulum trajectories.
DATA_PATH = "/home/ubuntu/HSUT/traj_2s_dt_pow_15_const_params"
time_trimmed = False

# Load everything and move it to `device` in one go. Shapes as annotated by the author:
#   init_conds: (N, 1, Nx=1, Ny=1, Q)
#   trajs:      (N, T, Nx=1, Ny=1, Q)
#   t_coords:   (T+1)
init_conds, trajs, t_coords = (
    loaded.to(device)
    for loaded in load_double_pendulum_data(
        DATA_PATH,
        time_trimmed=time_trimmed,
        const_pendulum_parameters=True,
    )
)

class TokensDataset(Dataset):
    """Trajectories with their initial condition prepended along the time axis.

    Args:
        u: trajectory tensor of shape (N, T, Nx, Ny, Q).
        a: initial-condition tensor that is concatenated in front of ``u`` on dim 1.
        n_timesteps: optional number of time steps to keep from ``u``. When given and
            smaller than T, that many evenly spaced steps (first and last included) are
            selected; otherwise all T steps are kept.

    ``self.data`` has shape (N, n_timesteps + 1, Nx, Ny, Q). Indexing the dataset indexes
    ``self.data`` directly, so an integer returns one sample and a slice returns a tensor.
    """

    def __init__(self, u, a, n_timesteps=None):
        n_samples, n_steps, nx, ny, n_feat = u.shape
        subsample = n_timesteps is not None and n_timesteps < n_steps
        if subsample:
            # Evenly spaced integer step indices covering the whole trajectory.
            keep = np.linspace(0, n_steps - 1, num=n_timesteps, dtype=int)
            u = u[:, keep]
        else:
            n_timesteps = n_steps
        stacked = torch.cat([a, u], dim=1)
        self.data = stacked.reshape(n_samples, n_timesteps + 1, nx, ny, n_feat)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Number of predicted steps: every time coordinate except the first.
N_TIMESTEPS = t_coords.shape[0] - 1
full_ds = TokensDataset(trajs, init_conds, n_timesteps=N_TIMESTEPS)

# Deliberately tiny training set (one trajectory); the regular split would be
# int(0.8 * len(full_ds)).
train_size = 1
randomly_split = False
if not randomly_split:
    # Contiguous split: slicing the dataset returns plain tensors of samples.
    train_ds, val_ds = full_ds[:train_size], full_ds[train_size:]
else:
    val_size = len(full_ds) - train_size
    train_ds, val_ds = random_split(full_ds, [train_size, val_size])

BATCH_SIZE = 1
train_loader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)
# Validation is currently switched off.
# val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
# print(f"Train/Val samples: {len(train_ds)} / {len(val_ds)}")


ICs Loaded
File not saved, data already exists at: /home/ubuntu/HSUT/traj_2s_dt_pow_15_const_params/max_time_2_fps_32_trimmed/ic_0_dt_pow_15.npy


In [3]:
# Unpack the trajectory tensor shape; the loading cell annotates it as (N, T, Nx=1, Ny=1, Q).
N, T, Nx, Ny, Q = trajs.shape
# Patch size along x / y; used as the kernel size and stride of the encoder convolution.
X_COMPRESSION = 1
Y_COMPRESSION = 1
# Spatial grid size after patching.
Nx_ = Nx // X_COMPRESSION
Ny_ = Ny // Y_COMPRESSION
n_tokens = N_TIMESTEPS * Nx_ * Ny_  # sequence length: one token per (time step, patch)
block_size = Q * X_COMPRESSION * Y_COMPRESSION # Included multiplication by Q, should be more general
# NOTE(review): block_size is used as the decoder output width AND as the block length of the
# attention masks; the latter presumably should be Nx_ * Ny_ (tokens per time step) - confirm.

# Number of leading feature channels (last axis) that enter the loss.
num_vars_predicting = 2

In [4]:
def make_block_mask_after(n_tokens, block_size):
    """Return, for each token position, the last position of the block it belongs to.

    Positions 0..n_tokens-1 are grouped into consecutive blocks of ``block_size``; the
    result is a 1-D long tensor whose i-th entry is the index of the final token in
    position i's block (e.g. n_tokens=10, block_size=2 -> [1, 1, 3, 3, 5, 5, 7, 7, 9, 9]).
    """
    positions = torch.arange(n_tokens, dtype=torch.long)
    # Start of the block (position minus its offset inside the block) plus block_size - 1.
    block_start = positions - positions % block_size
    return block_start + (block_size - 1)

def mask_after_to_dense_mask(mask_after):
    """Expand a per-row "last visible column" vector into a dense boolean mask.

    Args:
        mask_after: 1-D integer tensor of length n; entry i is the last column index
            that row i may attend to.

    Returns:
        Bool tensor of shape (n, n), on the same device as ``mask_after``, where
        ``out[i, j]`` is True exactly when ``j > mask_after[i]`` (column j is masked).
    """
    n_tokens = mask_after.shape[0]
    # Build the column indices on the input's device so CUDA inputs do not hit a
    # CPU/GPU mismatch in the comparison below.
    col_indices = torch.arange(n_tokens, device=mask_after.device)
    return col_indices > mask_after.unsqueeze(1)

# Block-causal masks over the flattened (time, x, y) token sequence. A block is the set of
# tokens belonging to one time step, i.e. Nx_ * Ny_ tokens. The previous code used
# `block_size` here, which also includes the Q feature channels (it is the decoder output
# width) and would therefore let a token see several future time steps.
tokens_per_step = Nx_ * Ny_
mask_after = make_block_mask_after(n_tokens, tokens_per_step)
dense_mask = mask_after_to_dense_mask(mask_after)
# Both masks are built on the CPU first and only then moved to the training device.
mask_after = mask_after.to(device)
dense_mask = dense_mask.to(device)

In [5]:
# Toy example of the two mask formats: 10 tokens grouped into blocks of 2.
a = make_block_mask_after(10, 2)
b = mask_after_to_dense_mask(a)
print(a, b, sep="\n")

tensor([1, 1, 3, 3, 5, 5, 7, 7, 9, 9])
tensor([[False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False]])


In [6]:
# Transformer hyperparameters shared by the softmax and trim models.
EMBED_DIM = 64  # d_model of both transformers; also the encoder/decoder hidden width
NUM_HEADS = 4  # attention heads; the q/k/v LayerNorms use EMBED_DIM // NUM_HEADS features
NUM_LAYERS = 4  # number of stacked encoder layers
DROPOUT = 0.1  # NOTE(review): not passed to any layer in the visible cells - confirm intent
DIM_FEEDFORWARD = 256  # passed as dim_feedforward to every encoder layer

In [7]:
class PatchwiseMLP(nn.Module):
    """Per-frame patch convolution followed by a two-layer MLP over the channel axis.

    Args:
        dim: number of input channels (size of the last axis of the input).
        hidden_dim: output channels of the convolution and width of the first LayerNorm.
        out_dim: size of the last axis of the output.
        hidden_ff: width of the hidden MLP layer.
        K: convolution kernel size, (height, width).
        S: convolution stride, (height, width).

    ``forward`` maps (B, T, H, W, dim) to (B, T, H', W', out_dim), where H' and W' are the
    spatial sizes produced by the convolution.
    """

    def __init__(self, dim, hidden_dim=32, out_dim=32, hidden_ff=64, K=[1, 1], S=[1, 1]):
        super().__init__()
        # The K / S defaults are only read, never modified.
        self.conv_layer1 = nn.Conv2d(dim, hidden_dim, kernel_size=K, stride=S)

        self.fc1 = nn.Linear(hidden_dim, hidden_ff)
        self.fc2 = nn.Linear(hidden_ff, out_dim)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_ff)

    def forward(self, x):
        batch, steps, height, width, channels = x.shape

        # Fold time into the batch axis and move channels first for Conv2d.
        frames = x.permute(0, 1, 4, 2, 3).reshape(batch * steps, channels, height, width)
        # Convolve, then go back to channels-last so LayerNorm / Linear act on channels.
        feats = self.conv_layer1(frames).permute(0, 2, 3, 1)
        feats = self.fc1(self.relu1(self.norm1(feats)))
        feats = self.fc2(self.relu2(self.norm2(feats)))

        # Unfold the batch axis back into (batch, steps).
        return feats.contiguous().view(batch, steps, *feats.shape[1:])

def _build_encoder():
    """Patch encoder: Q input channels (not hard-coded to 1) -> EMBED_DIM features per patch."""
    patch = [X_COMPRESSION, Y_COMPRESSION]
    return PatchwiseMLP(
        Q,
        EMBED_DIM,
        EMBED_DIM,
        hidden_ff=EMBED_DIM,
        K=patch,
        S=list(patch),
    )


def _build_decoder():
    """Token decoder: EMBED_DIM features -> block_size outputs per token, 1x1 convolution."""
    return PatchwiseMLP(
        EMBED_DIM,
        EMBED_DIM,
        block_size,
        hidden_ff=EMBED_DIM,
        K=[1, 1],
        S=[1, 1],
    )


# Each model gets its own encoder and decoder. The construction order (both encoders, then
# both decoders; softmax before trim) is kept so the seeded initial weights do not change.
softmax_encoder = _build_encoder()
trim_encoder = _build_encoder()
softmax_decoder = _build_decoder()
trim_decoder = _build_decoder()

In [8]:
# In this example, we use the positional encoding from https://arxiv.org/abs/1706.03762. This is
# an unnatural choice for this dataset, but it is the classic choice for language models.
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a (batch, seq, d_model) tensor.

    Args:
        d_model: embedding width. Even and odd widths are both supported; an odd width
            simply has one more sine column than cosine columns.
        max_len: number of positions for which the table is precomputed.

    The table is stored in the buffer ``pe`` with shape (1, max_len, d_model).
    """

    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there are only d_model // 2 odd columns, so drop the surplus angle.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            # Without this check a too-short table fails with an opaque broadcast error,
            # or is silently broadcast when max_len == 1.
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

In [9]:
# Positional table with T * Nx * Ny positions.
pos_enc = PositionalEncoding(EMBED_DIM, max_len=T*Nx*Ny)

# Reference model built from PyTorch's own encoder layer.
softmax_layer = nn.TransformerEncoderLayer(
    d_model=EMBED_DIM,
    nhead=NUM_HEADS,
    dim_feedforward=DIM_FEEDFORWARD,
    batch_first=True,
)
softmax_model = nn.TransformerEncoder(softmax_layer, num_layers=NUM_LAYERS)

# Trim model: LayerNorms sized to one attention head for q, k and v, and a scale equal to
# one over the number of tokens.
head_dim = EMBED_DIM // NUM_HEADS
norm_q, norm_k, norm_v = (nn.LayerNorm(head_dim) for _ in range(3))
scale = 1 / (N_TIMESTEPS * Nx_ * Ny_)

trim_layer = TrimTransformerEncoderLayer(
    d_model=EMBED_DIM,
    nhead=NUM_HEADS,
    dim_feedforward=DIM_FEEDFORWARD,
    norm_q=norm_q,
    norm_k=norm_k,
    norm_v=norm_v,
    scale=scale,
    batch_first=True,
)
trim_model = TrimTransformerEncoder(trim_layer, num_layers=NUM_LAYERS)

# Compose the encoder, positional encoding, transformer, and decoder into a single module.
class Pipeline(nn.Module):
    """Encoder -> positional encoding -> transformer -> decoder, composed as one module.

    Args:
        encoder: module applied to the raw input; its output must have dims 1..3 that can
            be flattened into a token axis.
        decoder: module applied to the transformer output after it is reshaped back to
            the encoder's output shape.
        pos_enc: module applied to the flattened token sequence.
        transformer: module called as ``transformer(tokens, mask=mask)``.
        mask: attention mask forwarded to the transformer, or None. A tensor mask is
            registered as a non-persistent buffer so it follows ``.to(device)`` with the
            rest of the module while staying out of ``state_dict`` (saved checkpoints keep
            the same keys as before).
    """

    def __init__(self, encoder, decoder, pos_enc, transformer, mask):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enc = pos_enc
        self.transformer = transformer
        if mask is None or isinstance(mask, torch.Tensor):
            self.register_buffer("mask", mask, persistent=False)
        else:
            # Non-tensor masks are stored untouched, exactly as before.
            self.mask = mask

    def forward(self, x):
        """Map ``x`` to a prediction with the same shape as ``x``."""
        z = self.encoder(x)
        # Merge dims 1..3 of the encoder output into a single token axis.
        y = z.flatten(1, 3)
        y = self.pos_enc(y)
        y = self.transformer(y, mask=self.mask)
        y = y.reshape_as(z)
        y = self.decoder(y)
        # The decoder output must hold exactly as many elements as the input.
        return y.reshape_as(x)

# Both pipelines reuse the same positional-encoding module and run without a mask.
softmax_pipeline = Pipeline(softmax_encoder, softmax_decoder, pos_enc, softmax_model, None).to(device)
trim_pipeline = Pipeline(trim_encoder, trim_decoder, pos_enc, trim_model, None).to(device)


def _count_parameters(module):
    """Total number of scalar parameters in ``module``."""
    return sum(p.numel() for p in module.parameters())


for _label, _module in [
    ("softmax_encoder", softmax_encoder),
    ("softmax_decoder", softmax_decoder),
    ("trim_encoder", trim_encoder),
    ("trim_decoder", trim_decoder),
    ("softmax_model", softmax_model),
    ("trim_model", trim_model),
]:
    print(f"{_label} parameter count:", _count_parameters(_module))

softmax_encoder parameter count: 9152
softmax_decoder parameter count: 9096
trim_encoder parameter count: 9152
trim_decoder parameter count: 9096
softmax_model parameter count: 199936
trim_model parameter count: 200320


In [10]:
# A significant amount of memory (~4GB) is consumed by the dataset. So that the displayed memory
# usage reflects only the cost of training the model, we record the memory allocated before
# training as a baseline and subtract it later. Note that the models have nearly the same
# parameter count, so the memory usage of their weights is essentially the same.
if device.type != "cuda":
    # No CUDA allocator to query; peak_mem() reports 0.0 on the CPU as well.
    baseline_memory = 0.0
    print("Using CPU - memory tracking not available")
else:
    # Wait for pending kernels, record what is allocated right now (in MiB) and restart
    # the peak counter so later readings only reflect training.
    torch.cuda.synchronize()
    baseline_memory = torch.cuda.memory_allocated() / 2**20
    torch.cuda.reset_peak_memory_stats()
    print(f"Baseline memory usage: {baseline_memory:.1f}MB")

Baseline memory usage: 1.7MB


In [None]:
def train_epoch(model, loader, optimizer, criterion, n_vars=None):
    """Run one optimisation pass over ``loader``.

    Each batch ``traj`` has its initial condition at time index 0 followed by the target
    steps. The model receives the initial condition repeated once per target step and is
    trained to reproduce the targets on the first ``n_vars`` feature channels.

    Args:
        model: module mapping (B, steps, Nx, Ny, Q) to a tensor of the same shape.
        loader: DataLoader yielding trajectory batches of shape (B, steps + 1, Nx, Ny, Q).
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable ``criterion(pred, target)`` returning a scalar tensor.
        n_vars: number of leading feature channels in the loss; defaults to the
            notebook-level ``num_vars_predicting``.

    Returns:
        (mean loss per sample, per-variable loss tensor averaged over batches,
        elapsed seconds).
    """
    if n_vars is None:
        n_vars = num_vars_predicting
    model.train()
    start = time.perf_counter()  # monotonic clock, suited to measuring intervals
    running = 0.0
    running_per_var = torch.zeros(n_vars)
    for traj in loader:
        optimizer.zero_grad()
        # Take the rollout length from the batch itself rather than a notebook global, so
        # prediction and target always agree (e.g. when the dataset was subsampled).
        n_steps = traj.size(1) - 1
        pred = model(traj[:, :1].repeat(1, n_steps, 1, 1, 1))
        values = traj[:, 1:, :, :, :]
        with torch.no_grad():
            # Per-variable losses are bookkeeping only; keep them out of the graph.
            for i in range(n_vars):
                running_per_var[i] += criterion(pred[..., i], values[..., i]).item()
        loss = criterion(pred[..., :n_vars], values[..., :n_vars])
        loss.backward()
        optimizer.step()
        running += loss.item() * traj.size(0)
    elapsed = time.perf_counter() - start
    # The total is weighted by batch size; the per-variable values are plain batch means.
    return running / len(loader.dataset), running_per_var / len(loader), elapsed

@torch.no_grad()
def evaluate(model, loader, criterion, n_vars=None):
    """Compute the loss of ``model`` over ``loader`` without updating any weights.

    Uses the same input construction as training: the initial condition (time index 0)
    repeated once per target step, compared with the remaining steps on the first
    ``n_vars`` feature channels.

    Args:
        model: module mapping (B, steps, Nx, Ny, Q) to a tensor of the same shape.
        loader: DataLoader yielding trajectory batches of shape (B, steps + 1, Nx, Ny, Q).
        criterion: loss callable ``criterion(pred, target)`` returning a scalar tensor.
        n_vars: number of leading feature channels in the loss; defaults to the
            notebook-level ``num_vars_predicting``.

    Returns:
        (mean loss per sample, per-variable loss tensor averaged over batches).
    """
    if n_vars is None:
        n_vars = num_vars_predicting
    model.eval()
    running = 0.0
    running_per_var = torch.zeros(n_vars)
    for traj in loader:
        # Take the rollout length from the batch itself rather than a notebook global, so
        # prediction and target always agree.
        n_steps = traj.size(1) - 1
        pred = model(traj[:, :1].repeat(1, n_steps, 1, 1, 1))
        values = traj[:, 1:, :, :, :]
        for i in range(n_vars):  # only the leading channels (angles, per the author)
            running_per_var[i] += criterion(pred[..., i], values[..., i]).item()
        loss = criterion(pred[..., :n_vars], values[..., :n_vars])
        running += loss.item() * traj.size(0)
    # The total is weighted by batch size; the per-variable values are plain batch means.
    return running / len(loader.dataset), running_per_var / len(loader)

def peak_mem():
    """Peak CUDA memory (MiB) since the last reset, relative to ``baseline_memory``.

    Resets the peak counter after reading it, so each call reports the peak reached
    since the previous call. Returns 0.0 when not running on CUDA.
    """
    if device.type != "cuda":
        return 0.0
    torch.cuda.synchronize()
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    torch.cuda.reset_peak_memory_stats()
    # Report usage on top of what was already allocated before training.
    return peak_mb - baseline_memory

# Optimisation settings used by the training loop in this cell.
EPOCHS = 10000  # passes over train_loader per model
lr = 1e-3  # Adam learning rate
weight_decay = 1e-5  # Adam weight decay
criterion = nn.MSELoss()  # used for both the total and the per-variable losses

def get_prediction(model, trajectory_batch):
    """Predict a full rollout from the initial state of each trajectory in the batch.

    Args:
        model: module mapping (B, steps, Nx, Ny, Q) to a tensor of the same shape.
        trajectory_batch: tensor of shape (B, steps + 1, Nx, Ny, Q) whose time index 0
            holds the initial state.

    Returns:
        The model output for the initial state repeated ``steps`` times, computed
        without gradient tracking.

    Note:
        Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    with torch.no_grad():
        initial_state = trajectory_batch[:, :1]
        # Rollout length comes from the batch itself instead of a notebook global.
        n_steps = trajectory_batch.size(1) - 1
        pred = model(initial_state.repeat(1, n_steps, 1, 1, 1))
    return pred

# Reference trajectory whose prediction is snapshotted during training; it is saved once
# so the snapshots can be compared against it later.
predictee_full_trajectory = next(iter(train_loader))[0:1] # Batch of 1: shape (1, 64, 1, 1, 8)
prediction_dir = "predictions_over_time"
os.makedirs(prediction_dir, exist_ok=True)
np.save(os.path.join(prediction_dir, "predictee_trajectory.npy"), predictee_full_trajectory.cpu().numpy())

weights_dir = "trim_weights"
os.makedirs(weights_dir, exist_ok=True)

# Every SAVE_EVERY epochs the current prediction and the model weights are written to disk.
SAVE_EVERY = 20

results = {}
for name, model in [("trim", trim_pipeline)]:#, ("softmax", softmax_pipeline)]:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # "val" / "val_per_var" stay empty while validation is disabled.
    hist = {"train": [], "train_per_var": [], "val": [], "val_per_var": [], "time": [], "mem": []}
    for ep in range(1, EPOCHS+1):
        train_loss, train_loss_per_var, t = train_epoch(model, train_loader, optimizer, criterion)
        # val_loss, val_loss_per_var = evaluate(model, val_loader, criterion)
        mem = peak_mem()
        hist["train"].append(train_loss)
        hist["train_per_var"].append(train_loss_per_var)
        # hist["val"].append(val_loss)
        # hist["val_per_var"].append(val_loss_per_var)
        hist["time"].append(t)
        hist["mem"].append(mem)
        train_per_var_str = ", ".join([f"{v:.3e}" for v in train_loss_per_var])
        # val_per_var_str = ", ".join([f"{v:.3e}" for v in val_loss_per_var])
        print(f"{name:10s} | epoch {ep}/{EPOCHS} | train {train_loss:.3e} | train_per_var [{train_per_var_str}]")#| val {val_loss:.3e} | val_per_var [{val_per_var_str}] | {t:.2f}s | mem {mem:.1f}MB")
        if ep % SAVE_EVERY == 0:
            current_pred = get_prediction(model, predictee_full_trajectory)
            save_path = os.path.join(prediction_dir, f"prediction_{name}_epoch_{ep:04d}.npy")
            np.save(save_path, current_pred.cpu().numpy())
            # NOTE(review): the weights file name does not include `name`, so enabling a
            # second model in the loop above would overwrite these checkpoints.
            weights_save_path = os.path.join(weights_dir, f"weights_epoch_{ep}.pt")
            torch.save({"state_dict": model.state_dict()}, weights_save_path)
            print(f"Saved model at epoch {ep}")
    results[name] = hist
    #plt.plot(hist.get("train"), label="Training Loss")
    #plt.plot(hist.get("val"), label="Validation Loss")
    #plt.title(f"{name} Double Pendulum (Easier ICs) Training Curve")
    #plt.legend()
    #plt.show()

    if device.type == "cuda":
        # Collect garbage first so tensors kept alive only by unreachable Python objects
        # are freed before the caching allocator is asked to release its unused blocks.
        gc.collect()
        torch.cuda.empty_cache()

trim       | epoch 1/100 | train 1.634e+01 | train_per_var [3.097e+00, 2.957e+01]
trim       | epoch 2/100 | train 1.318e+01 | train_per_var [2.781e+00, 2.357e+01]
trim       | epoch 3/100 | train 1.174e+01 | train_per_var [2.682e+00, 2.080e+01]
trim       | epoch 4/100 | train 1.077e+01 | train_per_var [2.611e+00, 1.892e+01]
trim       | epoch 5/100 | train 1.014e+01 | train_per_var [2.602e+00, 1.768e+01]
trim       | epoch 6/100 | train 9.366e+00 | train_per_var [2.551e+00, 1.618e+01]
trim       | epoch 7/100 | train 8.802e+00 | train_per_var [2.591e+00, 1.501e+01]
trim       | epoch 8/100 | train 8.256e+00 | train_per_var [2.605e+00, 1.391e+01]
trim       | epoch 9/100 | train 7.832e+00 | train_per_var [2.490e+00, 1.318e+01]
trim       | epoch 10/100 | train 7.503e+00 | train_per_var [2.483e+00, 1.252e+01]
trim       | epoch 11/100 | train 7.208e+00 | train_per_var [2.436e+00, 1.198e+01]
trim       | epoch 12/100 | train 6.961e+00 | train_per_var [2.419e+00, 1.150e+01]
trim       | 

In [None]:
# summary = pd.DataFrame.from_dict({
#     k: {
#         "train_loss": min(v["train"]),
#         "val_loss": min(v["val"]),
#         "time/epoch (s)": np.mean(v["time"]),
#         "peak_mem (MB)": max(v["mem"]),
#     } for k, v in results.items()
# }, orient="index")
# print(summary)

ValueError: min() arg is an empty sequence