In [1]:
import functools
import matplotlib
import os
import pdb
import time
from typing import Any, Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

print(torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

True
cuda


In [2]:
# Load data
DATA_PATH = "data.npz"

with open(DATA_PATH, "rb") as f:
    data = np.load(f)
    RAW_DATA = data["arr_0"]

print(RAW_DATA.shape)

(200, 1001, 2)


In [3]:
MAX_LEN = 1001
TEST_SIZE = 32
TRAIN_SIZE = 200 - TEST_SIZE

assert TRAIN_SIZE + TEST_SIZE <= RAW_DATA.shape[0]
assert MAX_LEN <= RAW_DATA.shape[1]
# PAD_VALUE = -1e10
# is_pad = lambda x: np.isclose(x, PAD_VALUE)


class DictDataset(torch.utils.data.Dataset):
    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __getitem__(self, index):
        return {key: values[index] for key, values in self.dictionary.items()}

    def __len__(self):
        return len(next(iter(self.dictionary.values())))
        

def get_datasets(
    batch_size,
    data=RAW_DATA,
    max_len=MAX_LEN,
    train_size=TRAIN_SIZE,
    test_size=TEST_SIZE,
):
    data = data[:, :max_len, :]
    train_x = torch.FloatTensor(data[:train_size, :max_len, 0][..., None])
    train_y = torch.FloatTensor(data[:train_size, :max_len, 1][..., None])
    # train_pad = torch.BoolTensor(is_pad(data[:train_size, :max_len, 0]))
    test_x = torch.FloatTensor(data[-test_size:, :max_len, 0][..., None])
    test_y = torch.FloatTensor(data[-test_size:, :max_len, 1][..., None])
    # test_pad = torch.BoolTensor(is_pad(data[-test_size:, :max_len, 0]))

    train_ds = DictDataset({"x": train_x, "y": train_y})
    test_ds = DictDataset({"x": test_x, "y": test_y})

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [4]:
# Utils.


def create_learning_rate_scheduler(warmup_steps=1000, total_steps=10000):

    def lr_lambda(step):
        lr = 1.0
        lr *= min(1.0, step / warmup_steps)
        lr *= min(1.0, (total_steps - step) / (total_steps - warmup_steps))
        return lr

    return lr_lambda


def compute_l2(predictions, targets, padding=None):
    if predictions.ndim != targets.ndim:
        raise ValueError(
            f"Incorrect shapes. Got shape {predictions.shape} predictions and {targets.shape} targets"
        )
    padding = padding or torch.zeros(
        predictions.shape[:-1], dtype=torch.bool, device=predictions.device
    )

    predictions = predictions * ~padding.unsqueeze(-1)
    targets = targets * ~padding.unsqueeze(-1)
    loss = ((predictions - targets) ** 2).sum(dim=-1)
    return loss.mean()


def compute_hinge(values):
    assert values.dim() == 2, f"{values.dim()} != 2"
    loss = torch.clamp(values, min=0)
    return loss.mean()


def compute_losses(
    py,
    pdy,
    physics_aux,
    y,
    dy,
    padding=None,
    deltas_loss_weight=0.0,
    physics_loss_weight=0.0,
):
    l = compute_l2(py, y)
    ld = compute_l2(pdy, dy)
    l2_loss = (1 - deltas_loss_weight) * l + deltas_loss_weight * ld
    if physics_aux is not None:
        physics_loss = compute_hinge(physics_aux)
    else:
        physics_loss = torch.zeros([])
        assert physics_loss_weight == 0.0

    loss = l2_loss + physics_loss_weight * physics_loss
    return {
        "loss": loss,
        "l2_loss": l2_loss,
        "physics_loss": physics_loss,
    }

In [5]:
def build_deltas(x):
    dx = x[:, 1:, :] - x[:, :-1, :]
    first_dx = torch.zeros((x.shape[0], 1, x.shape[2]), dtype=x.dtype, device=x.device)
    dx = torch.cat([first_dx, dx], dim=1)
    return dx


class TransformerConfig:
    def __init__(
        self,
        output_size: int = 1,
        max_len: int = MAX_LEN,
        num_layers: int = 2,
        hidden_dim: int = 16,
        mlp_dim: int = 64,
        num_heads: int = 4,
        dropout_rate: float = 0.0,
        attention_dropout_rate: float = 0.0,
        deterministic: bool = False,
        decode: bool = False,
        causal_x: bool = True,
        physics_decoder: bool = False,
    ):
        self.output_size = output_size
        self.max_len = max_len
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention_dropout_rate = attention_dropout_rate
        self.deterministic = deterministic
        self.decode = decode
        self.causal_x = causal_x
        self.physics_decoder = physics_decoder


class AddPositionEmbs(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.pos_embedding = nn.Parameter(
            torch.randn(1, config.max_len, config.hidden_dim) * 0.02
        )

    def forward(self, inputs):
        assert (
            inputs.ndim == 3
        ), f"Number of dimensions should be 3, but it is: {inputs.ndim}"
        return inputs + self.pos_embedding[:, : inputs.shape[1], :]


class MlpBlock(nn.Module):
    def __init__(self, config: TransformerConfig, out_dim: Optional[int] = None):
        super().__init__()
        self.config = config
        self.out_dim = out_dim if out_dim is not None else config.hidden_dim
        self.dense1 = nn.Linear(config.hidden_dim, config.mlp_dim)
        self.dense2 = nn.Linear(config.mlp_dim, self.out_dim)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, inputs):
        x = self.dense1(inputs)
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.dense2(x)
        return self.dropout(x)


class EncoderDecoder1DBlock(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.ln1 = nn.LayerNorm(config.hidden_dim)
        self.attention = nn.MultiheadAttention(
            config.hidden_dim,
            config.num_heads,
            dropout=config.attention_dropout_rate,
            batch_first=True,
        )
        self.dropout = nn.Dropout(config.dropout_rate)
        self.ln2 = nn.LayerNorm(config.hidden_dim)
        self.mlp = MlpBlock(config)

    def forward(self, inputs, decoder_mask=None):
        x = self.ln1(inputs)
        x, _ = self.attention(x, x, x, attn_mask=decoder_mask)
        x = self.dropout(x)
        x = x + inputs
        z = self.ln2(x)
        z = self.mlp(z)
        return x + z


class Decoder(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.embed_x = nn.Linear(1, config.hidden_dim // 2)
        self.embed_dx = nn.Linear(1, config.hidden_dim // 2)
        self.pos_embed = AddPositionEmbs(config)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.layers = nn.ModuleList(
            [EncoderDecoder1DBlock(config) for _ in range(config.num_layers)]
        )
        self.ln = nn.LayerNorm(config.hidden_dim)
        self.logits_x = nn.Linear(config.hidden_dim, config.output_size)
        self.logits_dx = nn.Linear(config.hidden_dim, config.output_size)

    def forward(self, inputs, decoder_mask=None):
        x = inputs["x"]
        dx = build_deltas(x)

        x = self.embed_x(x)
        dx = self.embed_dx(dx)
        x = torch.cat([x, dx], dim=-1)

        x = self.dropout(x)
        x = self.pos_embed(x)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, decoder_mask=decoder_mask)

        x = self.ln(x)
        logits_x = self.logits_x(x)
        logits_dx = self.logits_dx(x)

        return logits_x, logits_dx, None


# TODO add physics transformer.
class Transformer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.decoder = Decoder(config)

    def forward(self, inputs):
        cfg = self.config
        decoder_mask = None
        if cfg.causal_x:
            decoder_mask = nn.Transformer.generate_square_subsequent_mask(
                inputs["x"].shape[1]
            ).to(inputs["x"].device)

        logits_x, logits_dx, aux = self.decoder(inputs, decoder_mask=decoder_mask)
        return logits_x, logits_dx, aux

In [8]:
hparams = {
    "random_seed": 0,
    "model_dir": "/tmp/test",
    "physics_decoder": False,
    "max_len": MAX_LEN,
    "num_layers": 4,
    "hidden_dim": 16,
    "mlp_dim": 64,
    "num_heads": 2,
    "dropout_rate": 0.0,
    "attention_dropout_rate": 0.0,
    "deltas_loss_weight": 0.0,
    "physics_loss_weight": 0.0,
    "causal_x": True,  # TODO add
    "batch_size": 16,
    "learning_rate": 1e-2,
    "weight_decay": 0.0,
    "warmup_steps": 500,
    "total_steps": 5000,
    "eval_freq": 500,
}

torch.cuda.empty_cache()
torch.manual_seed(hparams["random_seed"])

train_loader, eval_loader = get_datasets(batch_size=hparams["batch_size"])

config = TransformerConfig(
    max_len=hparams["max_len"],
    num_layers=hparams["num_layers"],
    hidden_dim=hparams["hidden_dim"],
    mlp_dim=hparams["mlp_dim"],
    num_heads=hparams["num_heads"],
    dropout_rate=hparams["dropout_rate"],
    attention_dropout_rate=hparams["attention_dropout_rate"],
    causal_x=hparams["causal_x"],
    physics_decoder=hparams["physics_decoder"],
)
model = Transformer(config).to(device)
model.train()
optimizer = optim.AdamW(
    model.parameters(),
    lr=hparams["learning_rate"],
    weight_decay=hparams["weight_decay"],
)
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    create_learning_rate_scheduler(
        warmup_steps=hparams["warmup_steps"], total_steps=hparams["total_steps"]
    ),
)


def forward(model, inputs, deltas_loss_weight, physics_loss_weight):
    optimizer.zero_grad()
    py, pdy, physics_aux = model(inputs)
    y = inputs["y"]
    dy = build_deltas(y)
    return compute_losses(
        py=py,
        pdy=pdy,
        physics_aux=physics_aux,
        y=y,
        dy=dy,
        deltas_loss_weight=deltas_loss_weight,
        physics_loss_weight=physics_loss_weight,
    )


metrics_all = []
total_steps = 0
tick = time.time()

while total_steps < hparams["total_steps"]:
    for batch in train_loader:

        if total_steps == 1 or (
            total_steps % hparams["eval_freq"] == 0 and total_steps > 0
        ):
            summary = {k: np.mean([m[k] for m in metrics_all]) for k in metrics_all[0]}
            summary["learning_rate"] = scheduler.get_last_lr()[0]
            metrics_all = []

            tock = time.time()
            steps_per_sec = hparams["eval_freq"] / (tock - tick)
            tick = tock

            model.eval()

            eval_metrics = []
            with torch.no_grad():
                for eval_batch in eval_loader:
                    eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
                    metrics = forward(
                        model,
                        eval_batch,
                        hparams["deltas_loss_weight"],
                        hparams["physics_loss_weight"],
                    )
                    eval_metrics.append(
                        {
                            k: v.detach().item() if v is not None else None
                            for k, v in metrics.items()
                        }
                    )

            eval_summary = {
                k: np.mean([m[k] for m in eval_metrics]) for k in eval_metrics[0]
            }

            print(
                f"Step: {total_steps:04d},\ttrain loss {summary['loss']:.3f},\t"
                f"train l2 {summary['l2_loss']:.3f},\ttrain aux {summary['physics_loss']:.3f},\t"
                f"eval loss {eval_summary['loss']:.3f},\teval l2 {eval_summary['l2_loss']:.3f},\t"
                f"eval aux {eval_summary['physics_loss']:.3f},\tsteps/s {steps_per_sec:.1f},\t"
                f"lr {summary['learning_rate']:.5f}"
            )

            model.train()

        batch = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        metrics = forward(
            model,
            batch,
            hparams["deltas_loss_weight"],
            hparams["physics_loss_weight"],
        )
        metrics["loss"].backward()
        optimizer.step()
        metrics_all.append(
            {
                k: v.detach().item() if v is not None else None
                for k, v in metrics.items()
            }
        )
        scheduler.step()
        total_steps += 1

        if total_steps >= hparams["total_steps"]:
            break

print(f"Training completed after {total_steps} steps.")

Step: 0001,	train loss 0.480,	train l2 0.480,	train aux 0.000,	eval loss 0.489,	eval l2 0.489,	eval aux 0.000,	steps/s 1576.7,	lr 0.00002
Step: 0500,	train loss 0.093,	train l2 0.093,	train aux 0.000,	eval loss 0.012,	eval l2 0.012,	eval aux 0.000,	steps/s 42.5,	lr 0.01000
Step: 1000,	train loss 0.008,	train l2 0.008,	train aux 0.000,	eval loss 0.006,	eval l2 0.006,	eval aux 0.000,	steps/s 42.4,	lr 0.00889
Step: 1500,	train loss 0.004,	train l2 0.004,	train aux 0.000,	eval loss 0.004,	eval l2 0.004,	eval aux 0.000,	steps/s 42.7,	lr 0.00778
Step: 2000,	train loss 0.002,	train l2 0.002,	train aux 0.000,	eval loss 0.005,	eval l2 0.005,	eval aux 0.000,	steps/s 42.4,	lr 0.00667
Step: 2500,	train loss 0.002,	train l2 0.002,	train aux 0.000,	eval loss 0.003,	eval l2 0.003,	eval aux 0.000,	steps/s 42.7,	lr 0.00556
Step: 3000,	train loss 0.002,	train l2 0.002,	train aux 0.000,	eval loss 0.003,	eval l2 0.003,	eval aux 0.000,	steps/s 42.5,	lr 0.00444
Step: 3500,	train loss 0.001,	train l2 0.001,	