# Kolmogorov Arnold Networks

### Kolmogorov-Arnold Networks (KANs): Overview  
- **Theoretical Foundation**: Based on the Kolmogorov-Arnold representation theorem, which states that any continuous multivariate function on a bounded domain can be written as a finite composition of continuous univariate functions and addition, so a complex function decomposes into one-dimensional pieces.  
- **Architecture**: Places learnable univariate functions (here, B-splines) on the edges of the network and sums their outputs at the nodes, so a high-dimensional mapping is built from simpler one-dimensional functions.  
- **Applications**: Particularly useful in fields requiring complex function approximations, such as machine learning, physics, and computational mathematics.
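
For reference, the theorem behind the first bullet is usually written as follows for a continuous function $f : [0,1]^n \to \mathbb{R}$. The symbols $\Phi_q$ (outer functions) and $\phi_{q,p}$ (inner functions) are notation introduced here for illustration; all of them are continuous functions of a single variable.

$$
f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
$$

A KAN layer can be read as a learnable, stackable version of the inner sum: each input-output pair gets its own univariate function, and the results are added.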

### KANs for Noise Removal in Partial Differential Equations (PDEs)  
- **Denoising Mechanism**: KANs excel at separating structured signals from stochastic noise by learning the underlying deterministic relationships within the data governed by PDEs.  
- **Training Process**: The network is trained to map noisy inputs to noiseless PDE solutions, relying on its bias towards representing smooth, well-structured functions.  
- **Advantage**: Provides robust and efficient recovery of clean PDE solutions from noisy datasets, combining theoretical rigor with practical performance.  
- **Practical Benefit**: Enhances the accuracy and reliability of numerical solutions for PDEs in scenarios where data is corrupted by noise.

## Implementation

This KAN implementation is taken from https://github.com/Blealtan/efficient-kan.

This code defines a **Kolmogorov-Arnold Network (KAN) Linear Layer** in PyTorch. It is designed to efficiently model complex relationships by combining linear transformations and B-spline interpolation, as informed by the Kolmogorov-Arnold representation theorem.

### Components
1. **Initialization and Grid Setup**:
   - The grid holds the knots over which the splines are defined. It is extended by `spline_order` knots on each side so that every basis function is fully defined across the whole `grid_range`.
   - Parameters for the layer include base weights (`base_weight`) for linear transformations and spline weights (`spline_weight`) for B-spline coefficients.

2. **B-Spline Basis Calculation**:
   - The `b_splines` function computes B-spline bases for given inputs, which are then used to interpolate values in the input space.

3. **Spline Coefficient Computation**:
   - The `curve2coeff` function fits spline coefficients to the data by solving a least-squares problem, so that the spline matches the given input-output pairs as closely as possible.

4. **Forward Pass**:
   - Combines the linear transformation of the input using `base_weight` and the spline interpolation using `spline_weight` to produce the output.

5. **Grid Adaptation**:
   - The `update_grid` method adjusts the spline grid based on input distribution, allowing the model to adapt dynamically to new data.

6. **Regularization**:
   - A custom loss (`regularization_loss`) penalizes the spline weights to control overfitting. It combines an L1-style sparsity term (mean absolute spline weight per input-output pair) with an entropy term over the normalized magnitudes.

This implementation is designed to be flexible and efficient, enabling it to approximate and learn high-dimensional functions, especially useful in tasks like function approximation, noise filtering, or PDE solutions.
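
The B-spline basis and coefficient-fitting steps in the list above can be sketched in isolation. The snippet below is a minimal one-dimensional illustration, not the layer itself: the helper `bspline_basis`, the target `sin(pi * x)`, and the sample points are assumptions chosen for the demo. It uses the same Cox-de Boor recursion and the same `torch.linalg.lstsq` call as the layer.

```python
import math
import torch


def bspline_basis(x, grid, order):
    """Evaluate B-spline basis functions with the Cox-de Boor recursion.

    x:    (n,) sample points
    grid: (m,) strictly increasing knots
    Returns a tensor of shape (n, m - order - 1).
    """
    x = x.unsqueeze(-1)  # (n, 1), broadcasts against the knot vector
    # Order 0: indicator of the half-open interval [t_i, t_{i+1})
    bases = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)
    for k in range(1, order + 1):
        left = (x - grid[: -(k + 1)]) / (grid[k:-1] - grid[: -(k + 1)]) * bases[:, :-1]
        right = (grid[k + 1 :] - x) / (grid[k + 1 :] - grid[1:-k]) * bases[:, 1:]
        bases = left + right
    return bases


grid_size, order = 5, 3
h = 2.0 / grid_size  # knot spacing on the range [-1, 1]
# Extend the grid by `order` knots on each side, as the layer does
grid = torch.arange(-order, grid_size + order + 1) * h - 1.0

x = torch.linspace(-0.99, 0.99, 50)
basis = bspline_basis(x, grid, order)  # (50, grid_size + order)

# Least-squares fit of spline coefficients to a demo target
y = torch.sin(math.pi * x)
coeff = torch.linalg.lstsq(basis, y.unsqueeze(-1)).solution  # (8, 1)
fit = (basis @ coeff).squeeze(-1)

print("basis shape:", tuple(basis.shape))
print("max abs fit error:", (fit - y).abs().max().item())
```

Inside the grid range the basis functions sum to one at every point, which is why the extra knots on each side are needed.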


## Layers

In [1]:
import torch
import torch.nn.functional as F
import math
import torch.nn as nn


class KANLinear(torch.nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(
            self.base_weight, a=math.sqrt(5) * self.scale_base
        )
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output

        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a crude approximation of the original L1 regularization as stated in
        the paper, since the original one requires computing absolute values and entropy
        from the expanded (batch, in_features, out_features) intermediate tensor, which
        is hidden behind the F.linear function if we want a memory-efficient
        implementation.

        The L1 regularization is now computed as the mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )
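
To make the `regularization_loss` method above more concrete, here is a minimal standalone sketch of the same computation on a plain tensor. The function name `spline_weight_regularization` and the all-ones demo weights are assumptions for illustration only.

```python
import math
import torch


def spline_weight_regularization(spline_weight, regularize_activation=1.0, regularize_entropy=1.0):
    """Weight-based regularization for a tensor of shape (out, in, coeff)."""
    l1 = spline_weight.abs().mean(-1)   # (out, in): mean |coefficient| per input-output pair
    total = l1.sum()                    # sparsity term
    p = l1 / total                      # normalize magnitudes into a distribution
    entropy = -torch.sum(p * p.log())   # entropy term
    loss = regularize_activation * total + regularize_entropy * entropy
    return loss, total, entropy


# Demo: with identical weights, p is uniform over the 2 * 3 = 6 pairs
weights = torch.ones(2, 3, 8)
loss, total, entropy = spline_weight_regularization(weights)
print(total.item(), entropy.item(), math.log(6))
```

With uniform weights the entropy term reaches its maximum, `log(6)`; concentrating the magnitude on fewer input-output pairs lowers it. Note that an exactly zero row in `l1` would produce `0 * log(0)`, which evaluates to NaN in floating point, so this form assumes no pair has all-zero coefficients.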

## Complete Model

- Defines a complete Kolmogorov-Arnold Network (KAN) as a sequence of KANLinear layers for multivariate function approximation.  
- Accepts hyperparameters like grid size, spline order, scaling factors, activation functions, and grid range to configure the architecture.  
- Uses a `ModuleList` to stack multiple KANLinear layers based on the input hidden layer configuration (`layers_hidden`).  
- Implements a forward pass that processes input through each KANLinear layer, optionally updating the grid dynamically for adaptability.  
- Provides a method to compute regularization loss by aggregating the regularization terms from all layers to promote sparsity and smoothness.

In [2]:
class KAN(torch.nn.Module):
    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        self.layers = torch.nn.ModuleList()
        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
            self.layers.append(
                KANLinear(
                    in_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
            )

    def forward(self, x: torch.Tensor, update_grid=False):
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        return sum(
            layer.regularization_loss(regularize_activation, regularize_entropy)
            for layer in self.layers
        )

### Regularisation

- Extends the KAN architecture by integrating dropout layers to improve regularization and prevent overfitting.  
- Uses a `ModuleList` to stack layers, where each layer includes a KAN module, batch normalization, SiLU activation, and dropout.  
- Configurable dropout probability allows control over the level of regularization applied during training.  
- Implements a forward pass that processes input through each stacked layer sequentially, applying all operations in the layer pipeline.
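
The per-layer pipeline described above behaves differently in training and evaluation mode. The sketch below illustrates this with `nn.Linear` as a stand-in for the KAN module (an assumption made to keep the snippet self-contained); the sizes and seed are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One block of the pipeline: layer -> batch norm -> SiLU -> dropout.
# nn.Linear is a stand-in for the KAN module used in the notebook.
block = nn.Sequential(
    nn.Linear(16, 8),
    nn.BatchNorm1d(8),
    nn.SiLU(),
    nn.Dropout(0.3),
)

x = torch.randn(4, 16)

block.train()          # dropout active, batch norm uses batch statistics
train_out = block(x)

block.eval()           # dropout disabled, batch norm uses running statistics
with torch.no_grad():
    eval_a = block(x)
    eval_b = block(x)

print("train output shape:", tuple(train_out.shape))
print("eval passes identical:", torch.equal(eval_a, eval_b))
```

This is why the training loop later calls `model.train()` before each epoch and `model.eval()` before validation.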

In [None]:
class KANWithDropout(torch.nn.Module):
    def __init__(self, layers_hidden, dropout_prob=0.3):
        super(KANWithDropout, self).__init__()
        self.layers = nn.ModuleList()
        for in_features, out_features in zip(layers_hidden[:-1], layers_hidden[1:]):
            self.layers.append(
                nn.Sequential(
                    KAN(
                        layers_hidden=[in_features, out_features],
                        grid_size=5,
                        spline_order=3,
                        scale_noise=0.1,
                        scale_base=1.0,
                        scale_spline=1.0,
                        base_activation=torch.nn.SiLU,
                        grid_eps=0.02,
                        grid_range=[-1, 1],
                    ),
                    nn.BatchNorm1d(out_features),
                    nn.SiLU(),
                    nn.Dropout(dropout_prob),
                )
            )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

## Custom Loss Function

- A custom loss function combining Mean Squared Error (MSE) and L1 loss for balanced optimization between precision and robustness.  
- The weighting factor `alpha` controls the contribution of each loss term, allowing flexibility based on the problem's requirements.  
- Uses the `forward` method to compute the weighted sum of MSE and L1 losses between the model's output and the target values.  
- Balances sensitivity to large errors (from the MSE term) against robustness to outliers (from the L1 term).
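
A small worked example of the weighted sum may help. The function `combined_loss` below is a functional sketch of the same formula, and the three-element tensors are made-up values chosen so the arithmetic is easy to follow.

```python
import torch
import torch.nn.functional as F


def combined_loss(output, target, alpha=0.5):
    """alpha * MSE + (1 - alpha) * L1, both averaged over all elements."""
    return alpha * F.mse_loss(output, target) + (1 - alpha) * F.l1_loss(output, target)


output = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])

# Only the last element differs, by 2:
#   MSE = (0 + 0 + 4) / 3 = 4/3
#   L1  = (0 + 0 + 2) / 3 = 2/3
#   alpha = 0.5 gives 0.5 * 4/3 + 0.5 * 2/3 = 1.0
loss = combined_loss(output, target)
print(loss.item())
```

Raising `alpha` towards 1 weights the squared error more heavily, which penalizes large deviations more; lowering it towards 0 moves the loss closer to a pure mean absolute error.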

In [4]:
class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()

    def forward(self, output, target):
        return self.alpha * self.mse(output, target) + (1 - self.alpha) * self.l1(
            output, target
        )

## Data Loading

The clean and noisy samples of the solution are loaded for training.

In [5]:
import h5py
import numpy as np
with h5py.File("/kaggle/input/burgers-noisy/simulation_data.h5", "r") as f:
    a = list(f.keys())
    clean = []
    noisy = []
    for i in a[:-1]:
        clean.append(f[i]["clean"][:])
        noisy.append(f[i]["noisy"][:])

clean = np.array(clean)
noisy = np.array(noisy)

## Data Preparation

This script processes noisy and clean data for training a model, preparing it for PyTorch's DataLoader.

- **Data Reshaping**:
  - The noisy and clean datasets are flattened from `(num_samples, height, width)` into `(num_samples, height × width)` for compatibility with PyTorch models.

- **Tensor Conversion**:
  - The flattened arrays are converted into PyTorch tensors (`X_tensor` for noisy data and `Y_tensor` for clean data) of type `float32`.

- **Dataset Creation**:
  - A `TensorDataset` pairs the noisy data (`X_tensor`) with the clean target data (`Y_tensor`) for supervised learning.

- **Train/Test Split**:
  - The dataset is split into training and testing subsets, with 80% for training and 20% for testing, using PyTorch's `random_split`.

- **DataLoaders**:
  - Training and testing datasets are wrapped into DataLoaders with a batch size of 32, enabling shuffled training and sequential testing.
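
The steps above can be tried on a small synthetic array before running them on the real data. Everything in this sketch is an assumption for illustration: the array shape `(10, 4, 8)`, the noise level, the demo batch size of 4, and the seeded `torch.Generator`, which makes the split reproducible.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-ins for the clean and noisy arrays: (num_samples, height, width)
rng = np.random.default_rng(0)
clean_demo = rng.standard_normal((10, 4, 8)).astype(np.float32)
noisy_demo = clean_demo + 0.1 * rng.standard_normal((10, 4, 8)).astype(np.float32)

# Flatten each sample to a vector of length height * width
num_samples = noisy_demo.shape[0]
X = torch.from_numpy(noisy_demo.reshape(num_samples, -1))
Y = torch.from_numpy(clean_demo.reshape(num_samples, -1))

dataset = TensorDataset(X, Y)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# A seeded generator makes the random split repeatable across runs
generator = torch.Generator().manual_seed(0)
train_ds, test_ds = random_split(dataset, [train_size, test_size], generator=generator)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
xb, yb = next(iter(train_loader))
print(len(train_ds), len(test_ds), tuple(xb.shape), tuple(yb.shape))
```

Each batch pairs a flattened noisy sample with its flattened clean counterpart, which is the input-target layout the training loop expects.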

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import h5py
import numpy as np


print("Noisy data shape:", noisy.shape)

num_samples, height, width = noisy.shape
noisy_flattened = noisy.reshape(num_samples, -1)
clean_flattened = clean.reshape(num_samples, -1)

X_tensor = torch.tensor(noisy_flattened, dtype=torch.float32)
Y_tensor = torch.tensor(clean_flattened, dtype=torch.float32)

dataset = TensorDataset(X_tensor, Y_tensor)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)

trainloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print("Trainloader size:", len(trainloader.dataset))
print("Testloader size:", len(testloader.dataset))

Noisy data shape: (100, 201, 1024)
Trainloader size: 80
Testloader size: 20


## Model Definition

- **Model**:
  - A `KANWithDropout` model is initialized with layer widths `[input_dim, 512, 256, 128, 64, 128, 256, 512, input_dim]`: the stack narrows to 64 units and widens back to `input_dim`, so the output matches the input shape.
  - The model is moved to the appropriate device (GPU if available, otherwise CPU).

- **Optimizer**:
  - Uses the `AdamW` optimizer, which combines the benefits of Adam with weight decay regularization to reduce overfitting.  
  - Learning rate is set to `1e-3`, and weight decay to `1e-4`.

- **Learning Rate Scheduler**:
  - A `ReduceLROnPlateau` scheduler reduces the learning rate by a factor of 0.5 when the monitored metric (e.g., loss) stops improving for 4 epochs (`patience=4`).
  - Minimum learning rate (`min_lr`) is set to `1e-6`, and verbose mode is enabled for logging changes.

- **Loss Function**:
  - Combines MSE and L1 losses using the `CombinedLoss` class, balancing precision and robustness during optimization.
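
The scheduler behaviour described above can be checked without training anything. The sketch below feeds a constant loss to `ReduceLROnPlateau` with the same `factor`, `patience`, and `min_lr`; the single dummy parameter and the constant loss value are assumptions for the demo.

```python
import torch
import torch.optim as optim

# A single dummy parameter is enough to build an optimizer
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=4, min_lr=1e-6
)

lrs = []
for epoch in range(12):
    scheduler.step(1.0)  # pretend the validation loss never improves
    lrs.append(optimizer.param_groups[0]["lr"])

for epoch, lr in enumerate(lrs, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:g}")
```

The first call sets the best value; after that, the rate is halved once more than `patience` consecutive calls pass without improvement, and the counter then resets. With a flat loss this gives a reduction at the sixth call and every five calls thereafter, until `min_lr` is reached.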

In [7]:
input_dim = noisy_flattened.shape[1]  
model = KANWithDropout([input_dim, 512, 256, 128, 64, 128, 256, 512, input_dim])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=4,
    min_lr=1e-6,
    verbose=True,
)

criterion = CombinedLoss()



## Training Loop

This code trains the `KANWithDropout` model for 200 epochs, evaluates its validation loss, adjusts the learning rate using a scheduler, and saves the trained model.

- **Training Loop**:
  - Iterates over 200 epochs.
  - For each batch:
    - Moves data (`inputs` and `targets`) to the appropriate device.
    - Clears previous gradients using `optimizer.zero_grad()`.
    - Computes model outputs, loss, and gradients, and updates weights using `optimizer.step()`.
  - After each epoch, switches the model to evaluation mode and computes the average validation loss over the test batches.

- **Learning Rate Adjustment**:
  - The learning rate scheduler (`ReduceLROnPlateau`) adjusts the learning rate based on the validation loss, reducing it when the loss plateaus.

- **Model Saving**:
  - Saves the trained model's state dictionary (`model.state_dict()`) to a file named `"kan_model.pth"`, allowing for later use or deployment.
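
For the later-use case mentioned in the last bullet, a saved state dictionary is loaded back into a model built with the same architecture. The sketch below shows the round trip with `nn.Linear` as a stand-in model and an in-memory buffer instead of a file; both are assumptions to keep the example self-contained. With the real model, the file name `"kan_model.pth"` would take the buffer's place.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the trained model
model_demo = nn.Linear(8, 8)

# Save the state dict (here to an in-memory buffer rather than a file)
buffer = io.BytesIO()
torch.save(model_demo.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same architecture, then load the saved weights into it
restored = nn.Linear(8, 8)
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
restored.eval()  # disable dropout and use running batch-norm statistics at inference

x = torch.randn(2, 8)
with torch.no_grad():
    same = torch.equal(model_demo(x), restored(x))
print("outputs match after reload:", same)
```

A state dictionary stores only parameters and buffers, so the model class and its constructor arguments must be available wherever the weights are reloaded.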

In [8]:
for epoch in range(200):
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            output = model(inputs)
            val_loss += criterion(output, targets).item()
    val_loss /= len(testloader)

    # Update learning rate
    scheduler.step(val_loss)

    print(f"Epoch {epoch + 1}, Val Loss: {val_loss}")

# Save the model
torch.save(model.state_dict(), "kan_model.pth")

100%|██████████| 25/25 [00:03<00:00,  7.48it/s, loss=0.364, lr=0.001]


Epoch 1, Val Loss: 0.40956142118998934


100%|██████████| 25/25 [00:02<00:00,  8.95it/s, loss=0.439, lr=0.001]


Epoch 2, Val Loss: 0.401749781199864


100%|██████████| 25/25 [00:02<00:00,  9.09it/s, loss=0.377, lr=0.001]


Epoch 3, Val Loss: 0.39422949297087534


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.381, lr=0.001]


Epoch 4, Val Loss: 0.37822872400283813


100%|██████████| 25/25 [00:02<00:00,  8.97it/s, loss=0.369, lr=0.001]


Epoch 5, Val Loss: 0.35787443816661835


100%|██████████| 25/25 [00:02<00:00,  9.37it/s, loss=0.383, lr=0.001]


Epoch 6, Val Loss: 0.35934974040303913


100%|██████████| 25/25 [00:02<00:00,  8.91it/s, loss=0.364, lr=0.001]


Epoch 7, Val Loss: 0.3615292693887438


100%|██████████| 25/25 [00:02<00:00,  8.88it/s, loss=0.312, lr=0.001]


Epoch 8, Val Loss: 0.34612626688820974


100%|██████████| 25/25 [00:02<00:00,  8.80it/s, loss=0.297, lr=0.001]


Epoch 9, Val Loss: 0.34239550147737774


100%|██████████| 25/25 [00:02<00:00,  8.93it/s, loss=0.313, lr=0.001]


Epoch 10, Val Loss: 0.3340603304760797


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.254, lr=0.001]


Epoch 11, Val Loss: 0.33591480766023907


100%|██████████| 25/25 [00:02<00:00,  8.84it/s, loss=0.312, lr=0.001]


Epoch 12, Val Loss: 0.3226600629942758


100%|██████████| 25/25 [00:02<00:00,  8.70it/s, loss=0.31, lr=0.001]


Epoch 13, Val Loss: 0.3271221007619585


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.294, lr=0.001]


Epoch 14, Val Loss: 0.32420663961342405


100%|██████████| 25/25 [00:02<00:00,  8.78it/s, loss=0.281, lr=0.001]


Epoch 15, Val Loss: 0.3291414422648294


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.23, lr=0.001]


Epoch 16, Val Loss: 0.3053822730268751


100%|██████████| 25/25 [00:02<00:00,  8.94it/s, loss=0.403, lr=0.001]


Epoch 17, Val Loss: 0.30949369285787853


100%|██████████| 25/25 [00:02<00:00,  9.37it/s, loss=0.256, lr=0.001]


Epoch 18, Val Loss: 0.30623456835746765


100%|██████████| 25/25 [00:02<00:00,  8.80it/s, loss=0.234, lr=0.001]


Epoch 19, Val Loss: 0.30448477395943235


100%|██████████| 25/25 [00:02<00:00,  8.91it/s, loss=0.225, lr=0.001]


Epoch 20, Val Loss: 0.30398223229816984


100%|██████████| 25/25 [00:02<00:00,  8.78it/s, loss=0.246, lr=0.001]


Epoch 21, Val Loss: 0.29428372212818693


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.347, lr=0.001]


Epoch 22, Val Loss: 0.29082936474255155


100%|██████████| 25/25 [00:02<00:00,  9.14it/s, loss=0.229, lr=0.001]


Epoch 23, Val Loss: 0.297384432383946


100%|██████████| 25/25 [00:02<00:00,  9.08it/s, loss=0.272, lr=0.001]


Epoch 24, Val Loss: 0.2941744902304241


100%|██████████| 25/25 [00:02<00:00,  8.64it/s, loss=0.196, lr=0.001]


Epoch 25, Val Loss: 0.2935093215533665


100%|██████████| 25/25 [00:02<00:00,  8.95it/s, loss=0.204, lr=0.001]


Epoch 26, Val Loss: 0.29773129309926716


100%|██████████| 25/25 [00:02<00:00,  9.22it/s, loss=0.235, lr=0.001]


Epoch 27, Val Loss: 0.2890680858067104


100%|██████████| 25/25 [00:02<00:00,  9.06it/s, loss=0.235, lr=0.001]


Epoch 28, Val Loss: 0.2919142799718039


100%|██████████| 25/25 [00:02<00:00,  8.97it/s, loss=0.23, lr=0.001]


Epoch 29, Val Loss: 0.2984813856227057


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.256, lr=0.001]


Epoch 30, Val Loss: 0.2914662148271288


100%|██████████| 25/25 [00:02<00:00,  8.86it/s, loss=0.345, lr=0.001]


Epoch 31, Val Loss: 0.29214002404894146


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.23, lr=0.001]


Epoch 32, Val Loss: 0.286211510854108


100%|██████████| 25/25 [00:02<00:00,  9.00it/s, loss=0.288, lr=0.001]


Epoch 33, Val Loss: 0.287079886666366


100%|██████████| 25/25 [00:02<00:00,  9.18it/s, loss=0.252, lr=0.001]


Epoch 34, Val Loss: 0.29013362739767345


100%|██████████| 25/25 [00:02<00:00,  8.88it/s, loss=0.247, lr=0.001]


Epoch 35, Val Loss: 0.2820343183619635


100%|██████████| 25/25 [00:02<00:00,  9.02it/s, loss=0.249, lr=0.001]


Epoch 36, Val Loss: 0.2866498976945877


100%|██████████| 25/25 [00:02<00:00,  9.09it/s, loss=0.276, lr=0.001]


Epoch 37, Val Loss: 0.2944352147834642


100%|██████████| 25/25 [00:02<00:00,  9.08it/s, loss=0.211, lr=0.001]


Epoch 38, Val Loss: 0.2854394295385906


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.223, lr=0.001]


Epoch 39, Val Loss: 0.2882209047675133


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.222, lr=0.001]


Epoch 40, Val Loss: 0.2746558295828955


100%|██████████| 25/25 [00:02<00:00,  8.82it/s, loss=0.287, lr=0.001]


Epoch 41, Val Loss: 0.2842698097229004


100%|██████████| 25/25 [00:02<00:00,  9.38it/s, loss=0.243, lr=0.001]


Epoch 42, Val Loss: 0.27912837479795727


100%|██████████| 25/25 [00:02<00:00,  9.11it/s, loss=0.198, lr=0.001]


Epoch 43, Val Loss: 0.2884527061666761


100%|██████████| 25/25 [00:02<00:00,  8.88it/s, loss=0.262, lr=0.001]


Epoch 44, Val Loss: 0.28494105275188175


100%|██████████| 25/25 [00:02<00:00,  9.05it/s, loss=0.284, lr=0.001]


Epoch 45, Val Loss: 0.28651670047215055


100%|██████████| 25/25 [00:02<00:00,  9.45it/s, loss=0.214, lr=0.0005]


Epoch 46, Val Loss: 0.28400589738573345


100%|██████████| 25/25 [00:02<00:00,  9.00it/s, loss=0.208, lr=0.0005]


Epoch 47, Val Loss: 0.29055782726832796


100%|██████████| 25/25 [00:02<00:00,  8.97it/s, loss=0.227, lr=0.0005]


Epoch 48, Val Loss: 0.28805231409413473


100%|██████████| 25/25 [00:02<00:00,  9.09it/s, loss=0.217, lr=0.0005]


Epoch 49, Val Loss: 0.28348738167967114


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.234, lr=0.0005]


Epoch 50, Val Loss: 0.28367969180856434


100%|██████████| 25/25 [00:02<00:00,  8.91it/s, loss=0.191, lr=0.00025]


Epoch 51, Val Loss: 0.2800870026860918


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.237, lr=0.00025]


Epoch 52, Val Loss: 0.2773030400276184


100%|██████████| 25/25 [00:02<00:00,  8.94it/s, loss=0.228, lr=0.00025]


Epoch 53, Val Loss: 0.2773850177015577


100%|██████████| 25/25 [00:02<00:00,  9.05it/s, loss=0.255, lr=0.00025]


Epoch 54, Val Loss: 0.2864741597856794


100%|██████████| 25/25 [00:02<00:00,  9.06it/s, loss=0.286, lr=0.00025]


Epoch 55, Val Loss: 0.281657521213804


100%|██████████| 25/25 [00:02<00:00,  8.85it/s, loss=0.207, lr=0.000125]


Epoch 56, Val Loss: 0.27371829535279957


100%|██████████| 25/25 [00:02<00:00,  9.30it/s, loss=0.284, lr=0.000125]


Epoch 57, Val Loss: 0.276943781546184


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.229, lr=0.000125]


Epoch 58, Val Loss: 0.2818236542599542


100%|██████████| 25/25 [00:02<00:00,  8.95it/s, loss=0.207, lr=0.000125]


Epoch 59, Val Loss: 0.2836337025676455


100%|██████████| 25/25 [00:02<00:00,  8.84it/s, loss=0.415, lr=0.000125]


Epoch 60, Val Loss: 0.28896704209702356


100%|██████████| 25/25 [00:02<00:00,  9.07it/s, loss=0.2, lr=0.000125]


Epoch 61, Val Loss: 0.2813756231750761


100%|██████████| 25/25 [00:02<00:00,  9.36it/s, loss=0.211, lr=6.25e-5]


Epoch 62, Val Loss: 0.2761945405176708


100%|██████████| 25/25 [00:02<00:00,  9.05it/s, loss=0.24, lr=6.25e-5]


Epoch 63, Val Loss: 0.28251836768218447


100%|██████████| 25/25 [00:02<00:00,  9.20it/s, loss=0.251, lr=6.25e-5]


Epoch 64, Val Loss: 0.2797381005116871


100%|██████████| 25/25 [00:02<00:00,  8.88it/s, loss=0.224, lr=6.25e-5]


Epoch 65, Val Loss: 0.2832726985216141


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.231, lr=6.25e-5]


Epoch 66, Val Loss: 0.28555460274219513


100%|██████████| 25/25 [00:02<00:00,  8.94it/s, loss=0.336, lr=3.13e-5]


Epoch 67, Val Loss: 0.2876193342464311


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.183, lr=3.13e-5]


Epoch 68, Val Loss: 0.28653123123305185


100%|██████████| 25/25 [00:02<00:00,  9.07it/s, loss=0.218, lr=3.13e-5]


Epoch 69, Val Loss: 0.27439977122204645


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.216, lr=3.13e-5]


Epoch 70, Val Loss: 0.285542120890958


100%|██████████| 25/25 [00:02<00:00,  9.15it/s, loss=0.22, lr=3.13e-5]


Epoch 71, Val Loss: 0.2783579336745398


100%|██████████| 25/25 [00:02<00:00,  9.18it/s, loss=0.231, lr=1.56e-5]


Epoch 72, Val Loss: 0.2860529635633741


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.214, lr=1.56e-5]


Epoch 73, Val Loss: 0.2778893709182739


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.175, lr=1.56e-5]


Epoch 74, Val Loss: 0.2856562797512327


100%|██████████| 25/25 [00:02<00:00,  9.13it/s, loss=0.213, lr=1.56e-5]


Epoch 75, Val Loss: 0.2869007395846503


100%|██████████| 25/25 [00:02<00:00,  9.18it/s, loss=0.253, lr=1.56e-5]


Epoch 76, Val Loss: 0.2865152114203998


100%|██████████| 25/25 [00:02<00:00,  8.97it/s, loss=0.186, lr=7.81e-6]


Epoch 77, Val Loss: 0.28318923179592403


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.341, lr=7.81e-6]


Epoch 78, Val Loss: 0.28000139338629587


100%|██████████| 25/25 [00:02<00:00,  8.94it/s, loss=0.201, lr=7.81e-6]


Epoch 79, Val Loss: 0.2823889223592622


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.248, lr=7.81e-6]


Epoch 80, Val Loss: 0.2881431558302471


100%|██████████| 25/25 [00:02<00:00,  9.11it/s, loss=0.231, lr=7.81e-6]


Epoch 81, Val Loss: 0.28005271617855343


100%|██████████| 25/25 [00:02<00:00,  9.39it/s, loss=0.204, lr=3.91e-6]


Epoch 82, Val Loss: 0.2730787417718342


100%|██████████| 25/25 [00:02<00:00,  8.92it/s, loss=0.228, lr=3.91e-6]


Epoch 83, Val Loss: 0.2778947151132992


100%|██████████| 25/25 [00:02<00:00,  8.81it/s, loss=0.297, lr=3.91e-6]


Epoch 84, Val Loss: 0.28296755254268646


100%|██████████| 25/25 [00:02<00:00,  9.25it/s, loss=0.298, lr=3.91e-6]


Epoch 85, Val Loss: 0.2868628523179463


100%|██████████| 25/25 [00:02<00:00,  9.19it/s, loss=0.208, lr=3.91e-6]


Epoch 86, Val Loss: 0.2837820053100586


100%|██████████| 25/25 [00:02<00:00,  8.95it/s, loss=0.251, lr=3.91e-6]


Epoch 87, Val Loss: 0.28070418004478725


100%|██████████| 25/25 [00:02<00:00,  8.91it/s, loss=0.227, lr=1.95e-6]


Epoch 88, Val Loss: 0.28403542190790176


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.216, lr=1.95e-6]


Epoch 89, Val Loss: 0.2796029459152903


100%|██████████| 25/25 [00:02<00:00,  9.42it/s, loss=0.203, lr=1.95e-6]


Epoch 90, Val Loss: 0.2780973528112684


100%|██████████| 25/25 [00:02<00:00,  8.86it/s, loss=0.237, lr=1.95e-6]


Epoch 91, Val Loss: 0.2693714935864721


100%|██████████| 25/25 [00:02<00:00,  9.03it/s, loss=0.22, lr=1.95e-6]


Epoch 92, Val Loss: 0.28098506161144804


100%|██████████| 25/25 [00:02<00:00,  9.03it/s, loss=0.255, lr=1.95e-6]


Epoch 93, Val Loss: 0.287144090448107


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.192, lr=1.95e-6]


Epoch 94, Val Loss: 0.2776591511709349


100%|██████████| 25/25 [00:02<00:00,  8.97it/s, loss=0.198, lr=1.95e-6]


Epoch 95, Val Loss: 0.28528406471014023


100%|██████████| 25/25 [00:02<00:00,  8.95it/s, loss=0.246, lr=1.95e-6]


Epoch 96, Val Loss: 0.2798644570367677


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.203, lr=1e-6]


Epoch 97, Val Loss: 0.2805001458951405


100%|██████████| 25/25 [00:02<00:00,  9.05it/s, loss=0.214, lr=1e-6]


Epoch 98, Val Loss: 0.2756702293242727


100%|██████████| 25/25 [00:02<00:00,  9.14it/s, loss=0.215, lr=1e-6]


Epoch 99, Val Loss: 0.28256663999387194


100%|██████████| 25/25 [00:02<00:00,  9.13it/s, loss=0.193, lr=1e-6]


Epoch 100, Val Loss: 0.2821988207953317


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.204, lr=1e-6]


Epoch 101, Val Loss: 0.28228733049971716


100%|██████████| 25/25 [00:02<00:00,  9.08it/s, loss=0.203, lr=1e-6]


Epoch 102, Val Loss: 0.2799392704452787


100%|██████████| 25/25 [00:02<00:00,  8.91it/s, loss=0.254, lr=1e-6]


Epoch 103, Val Loss: 0.2772515064903668


100%|██████████| 25/25 [00:02<00:00,  8.93it/s, loss=0.184, lr=1e-6]


Epoch 104, Val Loss: 0.2781886447753225


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.217, lr=1e-6]


Epoch 105, Val Loss: 0.28061978944710325


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.18, lr=1e-6]


Epoch 106, Val Loss: 0.2794663927384785


100%|██████████| 25/25 [00:02<00:00,  9.03it/s, loss=0.226, lr=1e-6]


Epoch 107, Val Loss: 0.28233749845198225


100%|██████████| 25/25 [00:02<00:00,  9.05it/s, loss=0.192, lr=1e-6]


Epoch 108, Val Loss: 0.2857017857687814


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.255, lr=1e-6]


Epoch 109, Val Loss: 0.27522146701812744


100%|██████████| 25/25 [00:02<00:00,  8.74it/s, loss=0.197, lr=1e-6]


Epoch 110, Val Loss: 0.27538417492594036


100%|██████████| 25/25 [00:02<00:00,  9.17it/s, loss=0.228, lr=1e-6]


Epoch 111, Val Loss: 0.2857965880206653


100%|██████████| 25/25 [00:02<00:00,  9.05it/s, loss=0.262, lr=1e-6]


Epoch 112, Val Loss: 0.28392783978155683


100%|██████████| 25/25 [00:02<00:00,  9.17it/s, loss=0.209, lr=1e-6]


Epoch 113, Val Loss: 0.2733562876071249


100%|██████████| 25/25 [00:02<00:00,  9.25it/s, loss=0.252, lr=1e-6]


Epoch 114, Val Loss: 0.27337586347545895


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.217, lr=1e-6]


Epoch 115, Val Loss: 0.28005022555589676


100%|██████████| 25/25 [00:02<00:00,  9.02it/s, loss=0.301, lr=1e-6]


Epoch 116, Val Loss: 0.2764697585787092


100%|██████████| 25/25 [00:02<00:00,  8.97it/s, loss=0.223, lr=1e-6]


Epoch 117, Val Loss: 0.27203425019979477


100%|██████████| 25/25 [00:02<00:00,  9.43it/s, loss=0.27, lr=1e-6]


Epoch 118, Val Loss: 0.2755668024931635


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.237, lr=1e-6]


Epoch 119, Val Loss: 0.27425956619637354


100%|██████████| 25/25 [00:02<00:00,  8.89it/s, loss=0.195, lr=1e-6]


Epoch 120, Val Loss: 0.28437526736940655


100%|██████████| 25/25 [00:02<00:00,  9.19it/s, loss=0.225, lr=1e-6]


Epoch 121, Val Loss: 0.2849443991269384


100%|██████████| 25/25 [00:02<00:00,  9.04it/s, loss=0.208, lr=1e-6]


Epoch 122, Val Loss: 0.27697913135801044


100%|██████████| 25/25 [00:02<00:00,  9.14it/s, loss=0.207, lr=1e-6]


Epoch 123, Val Loss: 0.2826561267886843


100%|██████████| 25/25 [00:02<00:00,  9.13it/s, loss=0.233, lr=1e-6]


Epoch 124, Val Loss: 0.2753948430929865


100%|██████████| 25/25 [00:02<00:00,  8.97it/s, loss=0.247, lr=1e-6]


Epoch 125, Val Loss: 0.2837658907685961


100%|██████████| 25/25 [00:02<00:00,  9.44it/s, loss=0.195, lr=1e-6]


Epoch 126, Val Loss: 0.2830904041017805


100%|██████████| 25/25 [00:02<00:00,  9.07it/s, loss=0.199, lr=1e-6]


Epoch 127, Val Loss: 0.27592690714768003


100%|██████████| 25/25 [00:02<00:00,  9.11it/s, loss=0.219, lr=1e-6]


Epoch 128, Val Loss: 0.2803060348544802


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.199, lr=1e-6]


Epoch 129, Val Loss: 0.28241694825036184


100%|██████████| 25/25 [00:02<00:00,  9.16it/s, loss=0.209, lr=1e-6]


Epoch 130, Val Loss: 0.27445264905691147


100%|██████████| 25/25 [00:02<00:00,  8.88it/s, loss=0.252, lr=1e-6]


Epoch 131, Val Loss: 0.282078212925366


100%|██████████| 25/25 [00:02<00:00,  8.80it/s, loss=0.242, lr=1e-6]


Epoch 132, Val Loss: 0.27904557436704636


100%|██████████| 25/25 [00:02<00:00,  9.08it/s, loss=0.212, lr=1e-6]


Epoch 133, Val Loss: 0.27878786410604206


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.223, lr=1e-6]


Epoch 134, Val Loss: 0.27955677679606844


100%|██████████| 25/25 [00:02<00:00,  8.87it/s, loss=0.247, lr=1e-6]


Epoch 135, Val Loss: 0.2804058641195297


100%|██████████| 25/25 [00:02<00:00,  8.87it/s, loss=0.271, lr=1e-6]


Epoch 136, Val Loss: 0.26920067944696974


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.243, lr=1e-6]


Epoch 137, Val Loss: 0.27126651683024


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.257, lr=1e-6]


Epoch 138, Val Loss: 0.2833107198987688


100%|██████████| 25/25 [00:02<00:00,  9.16it/s, loss=0.32, lr=1e-6]


Epoch 139, Val Loss: 0.2840440848043987


100%|██████████| 25/25 [00:02<00:00,  9.09it/s, loss=0.253, lr=1e-6]


Epoch 140, Val Loss: 0.2825264888150351


100%|██████████| 25/25 [00:02<00:00,  9.14it/s, loss=0.247, lr=1e-6]


Epoch 141, Val Loss: 0.27322551714522497


100%|██████████| 25/25 [00:02<00:00,  9.45it/s, loss=0.186, lr=1e-6]


Epoch 142, Val Loss: 0.28264378756284714


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.228, lr=1e-6]


Epoch 143, Val Loss: 0.27224226615258623


100%|██████████| 25/25 [00:02<00:00,  8.91it/s, loss=0.223, lr=1e-6]


Epoch 144, Val Loss: 0.2701794909579413


100%|██████████| 25/25 [00:02<00:00,  8.86it/s, loss=0.236, lr=1e-6]


Epoch 145, Val Loss: 0.28441103867122103


100%|██████████| 25/25 [00:02<00:00,  9.26it/s, loss=0.211, lr=1e-6]


Epoch 146, Val Loss: 0.2807835776891027


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.215, lr=1e-6]


Epoch 147, Val Loss: 0.2838139363697597


100%|██████████| 25/25 [00:02<00:00,  8.82it/s, loss=0.224, lr=1e-6]


Epoch 148, Val Loss: 0.2816969965185438


100%|██████████| 25/25 [00:02<00:00,  9.04it/s, loss=0.193, lr=1e-6]


Epoch 149, Val Loss: 0.2794541386621339


100%|██████████| 25/25 [00:02<00:00,  9.19it/s, loss=0.228, lr=1e-6]


Epoch 150, Val Loss: 0.2824270895549229


100%|██████████| 25/25 [00:02<00:00,  8.90it/s, loss=0.262, lr=1e-6]


Epoch 151, Val Loss: 0.27934794553688597


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.274, lr=1e-6]


Epoch 152, Val Loss: 0.2744577271597726


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.21, lr=1e-6]


Epoch 153, Val Loss: 0.2845582813024521


100%|██████████| 25/25 [00:02<00:00,  9.04it/s, loss=0.223, lr=1e-6]


Epoch 154, Val Loss: 0.28049849612372263


100%|██████████| 25/25 [00:02<00:00,  9.16it/s, loss=0.232, lr=1e-6]


Epoch 155, Val Loss: 0.28502114755766733


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.243, lr=1e-6]


Epoch 156, Val Loss: 0.2786956599780491


100%|██████████| 25/25 [00:02<00:00,  8.91it/s, loss=0.237, lr=1e-6]


Epoch 157, Val Loss: 0.279442551944937


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.255, lr=1e-6]


Epoch 158, Val Loss: 0.28125459275075365


100%|██████████| 25/25 [00:02<00:00,  8.79it/s, loss=0.242, lr=1e-6]


Epoch 159, Val Loss: 0.27502445770161493


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.17, lr=1e-6]


Epoch 160, Val Loss: 0.2801616596324103


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.236, lr=1e-6]


Epoch 161, Val Loss: 0.28310875381742207


100%|██████████| 25/25 [00:02<00:00,  9.22it/s, loss=0.245, lr=1e-6]


Epoch 162, Val Loss: 0.28247302557740894


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.236, lr=1e-6]


Epoch 163, Val Loss: 0.27796555523361477


100%|██████████| 25/25 [00:02<00:00,  8.88it/s, loss=0.187, lr=1e-6]


Epoch 164, Val Loss: 0.27291418612003326


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.221, lr=1e-6]


Epoch 165, Val Loss: 0.2806436877165522


100%|██████████| 25/25 [00:02<00:00,  8.89it/s, loss=0.228, lr=1e-6]


Epoch 166, Val Loss: 0.28532709713493076


100%|██████████| 25/25 [00:02<00:00,  9.07it/s, loss=0.235, lr=1e-6]


Epoch 167, Val Loss: 0.27866476135594503


100%|██████████| 25/25 [00:02<00:00,  9.14it/s, loss=0.22, lr=1e-6]


Epoch 168, Val Loss: 0.2710730582475662


100%|██████████| 25/25 [00:02<00:00,  8.95it/s, loss=0.276, lr=1e-6]


Epoch 169, Val Loss: 0.2851007751056126


100%|██████████| 25/25 [00:02<00:00,  8.83it/s, loss=0.235, lr=1e-6]


Epoch 170, Val Loss: 0.28042394880737576


100%|██████████| 25/25 [00:02<00:00,  9.24it/s, loss=0.298, lr=1e-6]


Epoch 171, Val Loss: 0.278777539730072


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.382, lr=1e-6]


Epoch 172, Val Loss: 0.284740503345217


100%|██████████| 25/25 [00:02<00:00,  9.05it/s, loss=0.26, lr=1e-6]


Epoch 173, Val Loss: 0.27483914366790224


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.219, lr=1e-6]


Epoch 174, Val Loss: 0.2816484868526459


100%|██████████| 25/25 [00:02<00:00,  9.02it/s, loss=0.306, lr=1e-6]


Epoch 175, Val Loss: 0.2837367206811905


100%|██████████| 25/25 [00:02<00:00,  9.17it/s, loss=0.267, lr=1e-6]


Epoch 176, Val Loss: 0.27215792877333506


100%|██████████| 25/25 [00:02<00:00,  9.07it/s, loss=0.229, lr=1e-6]


Epoch 177, Val Loss: 0.26830970283065525


100%|██████████| 25/25 [00:02<00:00,  9.40it/s, loss=0.217, lr=1e-6]


Epoch 178, Val Loss: 0.28448394260236193


100%|██████████| 25/25 [00:02<00:00,  9.26it/s, loss=0.212, lr=1e-6]


Epoch 179, Val Loss: 0.2823911053793771


100%|██████████| 25/25 [00:02<00:00,  8.93it/s, loss=0.248, lr=1e-6]


Epoch 180, Val Loss: 0.2796053183930261


100%|██████████| 25/25 [00:02<00:00,  9.03it/s, loss=0.221, lr=1e-6]


Epoch 181, Val Loss: 0.28600056895187925


100%|██████████| 25/25 [00:02<00:00,  9.21it/s, loss=0.278, lr=1e-6]


Epoch 182, Val Loss: 0.28485369575875147


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.222, lr=1e-6]


Epoch 183, Val Loss: 0.2828455716371536


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.223, lr=1e-6]


Epoch 184, Val Loss: 0.2830220716340201


100%|██████████| 25/25 [00:02<00:00,  9.21it/s, loss=0.334, lr=1e-6]


Epoch 185, Val Loss: 0.2773524139608656


100%|██████████| 25/25 [00:02<00:00,  8.99it/s, loss=0.281, lr=1e-6]


Epoch 186, Val Loss: 0.27952112576791216


100%|██████████| 25/25 [00:02<00:00,  9.13it/s, loss=0.216, lr=1e-6]


Epoch 187, Val Loss: 0.2725838827235358


100%|██████████| 25/25 [00:02<00:00,  9.12it/s, loss=0.243, lr=1e-6]


Epoch 188, Val Loss: 0.2799153136355536


100%|██████████| 25/25 [00:02<00:00,  8.93it/s, loss=0.274, lr=1e-6]


Epoch 189, Val Loss: 0.2767013907432556


100%|██████████| 25/25 [00:02<00:00,  9.24it/s, loss=0.204, lr=1e-6]


Epoch 190, Val Loss: 0.27257933574063437


100%|██████████| 25/25 [00:02<00:00,  8.86it/s, loss=0.207, lr=1e-6]


Epoch 191, Val Loss: 0.2817013348851885


100%|██████████| 25/25 [00:02<00:00,  8.96it/s, loss=0.221, lr=1e-6]


Epoch 192, Val Loss: 0.27838866838387083


100%|██████████| 25/25 [00:02<00:00,  9.22it/s, loss=0.226, lr=1e-6]


Epoch 193, Val Loss: 0.27558340983731405


100%|██████████| 25/25 [00:02<00:00,  9.04it/s, loss=0.223, lr=1e-6]


Epoch 194, Val Loss: 0.2748587056994438


100%|██████████| 25/25 [00:02<00:00,  8.88it/s, loss=0.321, lr=1e-6]


Epoch 195, Val Loss: 0.283822520502976


100%|██████████| 25/25 [00:02<00:00,  9.02it/s, loss=0.291, lr=1e-6]


Epoch 196, Val Loss: 0.27537852632147924


100%|██████████| 25/25 [00:02<00:00,  8.98it/s, loss=0.263, lr=1e-6]


Epoch 197, Val Loss: 0.2763516360095569


100%|██████████| 25/25 [00:02<00:00,  9.16it/s, loss=0.21, lr=1e-6]


Epoch 198, Val Loss: 0.2786255553364754


100%|██████████| 25/25 [00:02<00:00,  9.01it/s, loss=0.21, lr=1e-6]


Epoch 199, Val Loss: 0.2747374081185886


100%|██████████| 25/25 [00:02<00:00,  9.10it/s, loss=0.202, lr=1e-6]


Epoch 200, Val Loss: 0.2820498155696051


## Prediction

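The sample stored under key `"10"` is loaded from the HDF5 file. Its noisy field is flattened into a single row and moved to the model's device; the clean field and the spatial coordinates are kept for later comparison and plotting.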

In [9]:
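with h5py.File("/kaggle/input/burgers-noisy/simulation_data.h5", "r") as f:
    a = f["10"]["noisy"][:]  # noisy solution, used as model input
    b = f["10"]["clean"][:]  # clean reference solution
    x = f["coords"]["x-coordinates"][:]  # spatial grid
# The `with` block closes the file on exit, so no explicit f.close() is needed.
a = torch.Tensor(a).to(device)
a = a.view(1, -1)  # flatten to a single sample of shape (1, n_features)
a.shape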

torch.Size([1, 205824])
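
The flattened length 205824 equals 201 × 1024, which matches the reshape applied to the model output later in the notebook. The following is a minimal sketch of that round trip using NumPy on stand-in data; the names `n_t`, `n_x`, and `sample` are illustrative assumptions and not part of the original notebook.

```python
import numpy as np

# Assumed layout: 201 time steps by 1024 spatial points (taken from the later reshape)
n_t, n_x = 201, 1024

# Stand-in array in place of the real noisy field
sample = np.arange(n_t * n_x, dtype=np.float32).reshape(n_t, n_x)

# Flatten to one row, as done with a.view(1, -1) in the notebook
flat = sample.reshape(1, -1)

# Reshape back to (time steps, spatial points)
restored = flat.reshape(n_t, n_x)

print(flat.shape)
print(np.array_equal(sample, restored))
```

Because both operations use row-major order, flattening and reshaping back recovers the original array exactly.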

### Run Predictions

The model is run without gradient tracking, and its flat output is reshaped to `(201, 1024)`, that is, 201 time steps by 1024 spatial points, for visualisation.

In [10]:
with torch.no_grad():
    u = model(a).cpu()  # `a` is already on `device`

# Reshape the flat output to (time steps, spatial points)
u = u.numpy().reshape((201, 1024))
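
The clean field `b` is loaded above but never used. One possible use is to check how much closer the prediction is to the clean solution than the noisy input was. The sketch below is an illustration only: `mse` and `denoising_report` are hypothetical helpers, not part of the notebook or of efficient-kan, and the demo runs on synthetic data, not on the notebook's dataset.

```python
import numpy as np


def mse(x, y):
    """Mean squared error between two arrays of the same shape."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    return float(np.mean((x - y) ** 2))


def denoising_report(noisy, clean, pred):
    """Compare the noisy input and the prediction against the clean field.

    All three arrays are reshaped to pred's shape, so they must hold the
    same number of elements (an assumption about the dataset layout).
    """
    pred = np.asarray(pred)
    noisy = np.asarray(noisy).reshape(pred.shape)
    clean = np.asarray(clean).reshape(pred.shape)
    return {"mse_noisy": mse(noisy, clean), "mse_pred": mse(pred, clean)}


# Synthetic stand-in data: a travelling sine wave on a 201 x 1024 grid
rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 201)[:, None]
xs = np.linspace(0.0, 1.0, 1024)[None, :]
clean_demo = np.sin(2 * np.pi * (xs - t))
noisy_demo = clean_demo + 0.3 * rng.standard_normal(clean_demo.shape)

# Stand-in "prediction" that removes half of the noise
pred_demo = clean_demo + 0.5 * (noisy_demo - clean_demo)

report = denoising_report(noisy_demo.reshape(1, -1), clean_demo, pred_demo)
print(report)
```

In the notebook, the equivalent call would be `denoising_report(a.cpu().numpy(), b, u)`, assuming `b` holds the same 201 × 1024 values as `u`. A prediction error well below the noisy-input error would indicate that the network removes noise and does not merely copy its input.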

## Visualisation

The `visualize_burgers` function generates an animated GIF of the Burgers equation solution over time. It takes the spatial coordinates (`xcrd`) and the simulation data (`data`, one row per time step), plots each time step as a frame, and collects the frames. It then builds the animation with `matplotlib.animation.ArtistAnimation` and saves it as `burgerCombo.gif` at 15 frames per second using `PillowWriter`.

In [11]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
from tqdm import tqdm


def visualize_burgers(xcrd, data):
    """
    Animate a solution of the Burgers equation and save it as a GIF.

    Args:
    xcrd : 1D array of spatial coordinates
    data : array of shape (time steps, spatial points) holding the solution
    """

    fig, ax = plt.subplots()

    ims = []

    # Draw one line artist per time step; each artist becomes one frame
    for i in tqdm(range(data.shape[0])):
        im = ax.plot(xcrd, data[i].squeeze(), animated=True, color="blue")
        ims.append([im[0]])

    # Animate the plot
    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)

    writer = animation.PillowWriter(fps=15, bitrate=1800)
    ani.save("burgerCombo.gif", writer=writer)
    plt.close(fig)

visualize_burgers(x[:], u)

100%|██████████| 201/201 [00:00<00:00, 1941.38it/s]
