In [1]:
import time, math, gc, random
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from mat73 import loadmat as loadmat73
from trim_transformer.transformer_layers import TrimTransformerEncoderLayer, TrimTransformerEncoder

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Download the data from https://figshare.com/articles/dataset/Navier_Stokes_Dataset_mat/25606152?file=45665007
DATA_PATH = "./NS_data.mat"
data_dict = loadmat73(DATA_PATH)
u = torch.tensor(data_dict["u"]).to(device)  # (N, Nx, Ny, T)
a = torch.tensor(data_dict["a"]).to(device)  # (N, Nx, Ny)

class TokensDataset(Dataset):
    def __init__(self, u, a, n_timesteps=None):
        N, Nx, Ny, T = u.shape
        u = u.permute(0, 3, 1, 2)
        a = a.unsqueeze(1)
        if n_timesteps is not None and n_timesteps < T:
            idx = np.linspace(0, u.shape[1] - 1, num=n_timesteps, dtype=int)
            u = u[:, idx]
        else:
            n_timesteps = T
        self.data = torch.cat([a, u], dim=1).reshape(N, n_timesteps+1, Nx, Ny, 1)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx]

N_TIMESTEPS = 10
full_ds = TokensDataset(u, a, n_timesteps=N_TIMESTEPS)
train_size = int(0.8 * len(full_ds))
val_size = len(full_ds) - train_size
train_ds, val_ds = random_split(full_ds, [train_size, val_size])

BATCH_SIZE = 8
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)
print(f"Train/Val samples: {len(train_ds)} / {len(val_ds)}")

N, Nx, Ny, T = u.shape
Q = 1
X_COMPRESSION = 2
Y_COMPRESSION = 2
Nx_ = Nx // X_COMPRESSION
Ny_ = Ny // Y_COMPRESSION
n_tokens = N_TIMESTEPS * Nx_ * Ny_  
block_size = X_COMPRESSION * Y_COMPRESSION

Train/Val samples: 4000 / 1000


In [3]:
def make_block_mask_after(n_tokens, block_size):
    idx = torch.arange(n_tokens, dtype=torch.long)
    mask_after = block_size * ((idx // block_size) + 1)-1
    return mask_after

def mask_after_to_dense_mask(mask_after):
    n_tokens = mask_after.shape[0]
    col_indices = torch.arange(n_tokens)
    return (col_indices > mask_after.unsqueeze(1))

mask_after = make_block_mask_after(n_tokens, block_size)
dense_mask = mask_after_to_dense_mask(mask_after)
mask_after = mask_after.to(device)
dense_mask = dense_mask.to(device)

In [4]:
a = make_block_mask_after(10, 2)
b = mask_after_to_dense_mask(a)
print(a)
print(b)

tensor([1, 1, 3, 3, 5, 5, 7, 7, 9, 9])
tensor([[False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False]])


In [5]:
EMBED_DIM = 64
NUM_HEADS = 4
NUM_LAYERS = 4
DROPOUT = 0.1
DIM_FEEDFORWARD = 256

In [6]:
# As in https://arxiv.org/abs/2209.15190, we use an encoder that operates on patches of the spatial
# grid. We add an additional non-linearity and layer normalizations. Also, in their implementation,
# the decoder is an MLP that takes as input the entire spatial grid after the transformer. Instead,
# we structure the decoder similarly to the encoder, operating on the compressed representation of
# a patch and outputting an uncompressed representation of the same patch.
class PatchwiseMLP(nn.Module):
    def __init__(self, dim, hidden_dim=32, out_dim=32,hidden_ff=64,K=[4,4],S=[4,4]):
        super(PatchwiseMLP, self).__init__()
        self.conv_layer1 = nn.Conv2d(dim, hidden_dim,
                                     kernel_size=K,
                                     stride=S)

        self.fc1 = nn.Linear(hidden_dim, hidden_ff)
        self.fc2 = nn.Linear(hidden_ff, out_dim)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_ff)

    def forward(self, x):
        B, T, H, W, Q = x.shape
        
        out = x.permute(0, 1, 4, 2, 3).reshape(B * T, Q, H, W)  # (B*T, Q, H, W)
        out = self.conv_layer1(out)  # (B*T, hidden_dim, H', W')
        out = out.permute(0, 2, 3, 1)  # (B*T, H', W', hidden_dim)
        out = self.norm1(out)
        out = self.relu1(out)
        out = self.fc1(out)  # (B*T, H', W', hidden_ff)
        out = self.norm2(out)
        out = self.relu2(out)
        out = self.fc2(out)  # (B*T, H', W', out_dim)

        _BT, H_prime, W_prime, C_out = out.shape
        out = out.contiguous().view(B, T, H_prime, W_prime, C_out)  # (B, T, H', W', out_dim)
        return out

softmax_encoder = PatchwiseMLP(
    1,
    EMBED_DIM,
    EMBED_DIM,
    hidden_ff=EMBED_DIM,
    K=[X_COMPRESSION, Y_COMPRESSION],
    S=[X_COMPRESSION, Y_COMPRESSION]
)

trim_encoder = PatchwiseMLP(
    1,
    EMBED_DIM,
    EMBED_DIM,
    hidden_ff=EMBED_DIM,
    K=[X_COMPRESSION, Y_COMPRESSION],
    S=[X_COMPRESSION, Y_COMPRESSION]
)

softmax_decoder = PatchwiseMLP(
    EMBED_DIM,
    EMBED_DIM,
    block_size,
    hidden_ff=EMBED_DIM,
    K=[1, 1],
    S=[1, 1]
)

trim_decoder = PatchwiseMLP(
    EMBED_DIM,
    EMBED_DIM,
    block_size,
    hidden_ff=EMBED_DIM,
    K=[1, 1],
    S=[1, 1]
)

In [7]:
# In this example, we use the positional encoding from https://arxiv.org/abs/1706.03762. This is
# an unnatural choice for this dataset, but it is the classic choice for language models.
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)
    def forward(self, x):
        return x + self.pe[:, : x.size(1)]

In [8]:
pos_enc = PositionalEncoding(EMBED_DIM, max_len=T*Nx*Ny)

softmax_layer = nn.TransformerEncoderLayer(d_model=EMBED_DIM,
                                           nhead=NUM_HEADS,
                                           dim_feedforward=DIM_FEEDFORWARD,
                                           batch_first=True)
softmax_model = nn.TransformerEncoder(softmax_layer, num_layers=NUM_LAYERS)

trim_layer = TrimTransformerEncoderLayer(d_model=EMBED_DIM,
                                         nhead=NUM_HEADS,
                                         dim_feedforward=DIM_FEEDFORWARD,
                                         batch_first=True)
trim_model = TrimTransformerEncoder(trim_layer, num_layers=NUM_LAYERS)

# Compose the encoder, positional encoding, transformer, and decoder into a single module.
class Pipeline(nn.Module):
    def __init__(self, encoder, decoder, pos_enc, transformer, mask):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pos_enc = pos_enc
        self.transformer = transformer
        self.mask = mask

    def forward(self, x):
        B, T, _, _, _ = x.shape
        z = self.encoder(x)
        y = z.flatten(1, 3)
        y = self.pos_enc(y)
        y = self.transformer(y, mask=self.mask)
        y = y.reshape_as(z)
        y = self.decoder(y)
        return y.reshape_as(x)

softmax_pipeline = Pipeline(softmax_encoder, softmax_decoder, pos_enc, softmax_model, dense_mask).to(device)
trim_pipeline = Pipeline(trim_encoder, trim_decoder, pos_enc, trim_model, mask_after).to(device)

print("softmax_encoder parameter count:", sum(p.numel() for p in softmax_encoder.parameters()))
print("softmax_decoder parameter count:", sum(p.numel() for p in softmax_decoder.parameters()))
print("softmax_model parameter count:", sum(p.numel() for p in softmax_model.parameters()))
print("trim_encoder parameter count:", sum(p.numel() for p in trim_encoder.parameters()))
print("trim_decoder parameter count:", sum(p.numel() for p in trim_decoder.parameters()))
print("trim_model parameter count:", sum(p.numel() for p in trim_model.parameters()))

softmax_encoder parameter count: 8896
softmax_decoder parameter count: 8836
softmax_model parameter count: 199936
trim_encoder parameter count: 8896
trim_decoder parameter count: 8836
trim_model parameter count: 199936


In [9]:
# A significant amount of memory (~4GB) is consumed by the dataset. So that the displayed memory
# usage demonstrates the memory usage of training the model, we set the memory usage before model
# training as a baseline. Note that the models have equal parameter count, so the memory usage
# of their weights is equal.
if device.type == "cuda":
    torch.cuda.synchronize()
    baseline_memory = torch.cuda.memory_allocated() / 1024**2
    torch.cuda.reset_peak_memory_stats()
    print(f"Baseline memory usage: {baseline_memory:.1f}MB")
else:
    baseline_memory = 0.0
    print("Using CPU - memory tracking not available")

Baseline memory usage: 4918.0MB


In [10]:
def train_epoch(model, loader, optimizer, criterion):
    model.train()
    start = time.time()
    running = 0.0
    for traj in loader:
        optimizer.zero_grad()
        pred = model(traj[:, :-1])
        loss = criterion(pred, traj[:, 1:])
        loss.backward()
        optimizer.step()
        running += loss.item() * traj.size(0)
    elapsed = time.time() - start
    return running / len(loader.dataset), elapsed

@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    running = 0.0
    for traj in loader:
        pred = model(traj[:, :-1])
        loss = criterion(pred, traj[:, 1:])
        running += loss.item() * traj.size(0)
    return running / len(loader.dataset)

def peak_mem():
    if device.type == "cuda":
        torch.cuda.synchronize()
        m = torch.cuda.max_memory_allocated() / 1024**2
        torch.cuda.reset_peak_memory_stats()
        return m - baseline_memory  # Subtract baseline memory usage
    return 0.0

EPOCHS = 100
lr = 1e-3
weight_decay = 1e-5
criterion = nn.MSELoss()

results = {}
for name, model in [("trim", trim_pipeline), ("softmax", softmax_pipeline)]:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    hist = {"train": [], "val": [], "time": [], "mem": []}
    for ep in range(1, EPOCHS+1):
        train_loss, t = train_epoch(model, train_loader, optimizer, criterion)
        val_loss = evaluate(model, val_loader, criterion)
        mem = peak_mem()
        hist["train"].append(train_loss)
        hist["val"].append(val_loss)
        hist["time"].append(t)
        hist["mem"].append(mem)
        print(f"{name:10s} | epoch {ep}/{EPOCHS} | train {train_loss:.3e} | val {val_loss:.3e} | {t:.2f}s | mem {mem:.1f}MB")
    results[name] = hist
    if device.type == "cuda":
        torch.cuda.empty_cache(); gc.collect()

trim       | epoch 1/100 | train 4.599e-01 | val 1.405e-01 | 35.74s | mem 3786.9MB
trim       | epoch 2/100 | train 1.137e-01 | val 6.084e-02 | 35.30s | mem 3787.6MB
trim       | epoch 3/100 | train 6.510e-02 | val 3.985e-02 | 35.30s | mem 3787.6MB
trim       | epoch 4/100 | train 4.771e-02 | val 2.967e-02 | 35.30s | mem 3787.6MB
trim       | epoch 5/100 | train 3.794e-02 | val 2.108e-02 | 35.29s | mem 3787.6MB
trim       | epoch 6/100 | train 3.152e-02 | val 1.898e-02 | 35.29s | mem 3787.6MB
trim       | epoch 7/100 | train 2.695e-02 | val 1.581e-02 | 35.30s | mem 3787.6MB
trim       | epoch 8/100 | train 2.384e-02 | val 1.473e-02 | 35.29s | mem 3787.6MB
trim       | epoch 9/100 | train 2.141e-02 | val 1.292e-02 | 35.29s | mem 3787.6MB
trim       | epoch 10/100 | train 1.967e-02 | val 1.166e-02 | 35.29s | mem 3787.6MB
trim       | epoch 11/100 | train 1.796e-02 | val 1.041e-02 | 35.29s | mem 3787.6MB
trim       | epoch 12/100 | train 1.665e-02 | val 1.064e-02 | 35.30s | mem 3787.6MB
t

In [12]:
summary = pd.DataFrame.from_dict({
    k: {
        "train_loss": min(v["train"]),
        "val_loss": min(v["val"]),
        "time/epoch (s)": np.mean(v["time"]),
        "peak_mem (MB)": max(v["mem"]),
    } for k, v in results.items()
}, orient="index")
print(summary)

         train_loss  val_loss  time/epoch (s)  peak_mem (MB)
trim       0.004626  0.002578       35.301417    3787.612305
softmax    0.003244  0.001938      210.313659   54992.738281
