In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


hidden_size = 64
output_size = 10
beta_init = 1e-4
learning_rate = 1e-4 
dt = 1e-4
num_epochs = 100 
dropout_rate = 0.3  

# basic demonstration of travel wave net(it is not a real implementation) 

class TravelingWaveLayer(nn.Module):
    def __init__(self, input_size, output_size, beta_init):
        super(TravelingWaveLayer, self).__init__()
        self.w = nn.Parameter(torch.randn(output_size, input_size) * 0.001)
        self.beta = nn.Parameter(torch.tensor(beta_init).float().to(device))
        self.x = torch.arange(input_size, dtype=torch.float).unsqueeze(0).to(device)

    def forward(self, x, t):
        w2 = self.w * (self.x * torch.cosh(5*self.beta) + t * torch.sinh(5*self.beta)).unsqueeze(1)
        return torch.matmul(w2, x.unsqueeze(-1)).squeeze(-1)


    def update_beta(self, dt):
        with torch.no_grad():
            beta_xx = (torch.roll(self.beta, -1) - 2*self.beta + torch.roll(self.beta, 1))
            beta_t = (self.beta - self.beta.clone().detach()) / dt
            self.beta.data += dt * (beta_t - beta_xx + self.beta - self.beta**3 - 0.01 * self.beta)  
            self.beta.data = torch.clamp(self.beta.data, -0.10, 0.10) 

class TravelingWaveNN(nn.Module):
    def __init__(self, output_size, beta_init, dropout_rate):
        super(TravelingWaveNN, self).__init__()

        self.input_size = 28 * 28 

        self.wave_layer1 = TravelingWaveLayer(self.input_size, hidden_size // 2, beta_init)
        self.wave_layer2 = TravelingWaveLayer(hidden_size // 2, hidden_size // 4, beta_init)
        self.wave_layer3 = TravelingWaveLayer(hidden_size // 4, hidden_size // 2, beta_init)
        self.wave_layer4 = TravelingWaveLayer(hidden_size // 2, output_size, beta_init)
        self.wave_layer5 = TravelingWaveLayer(output_size, output_size, beta_init)
        


        self.t1 = nn.Parameter(torch.tensor([0.9], device=device))
        self.t2 = nn.Parameter(torch.tensor([0.9], device=device))
        self.t3 = nn.Parameter(torch.tensor([0.9], device=device))
        self.t4 = nn.Parameter(torch.tensor([0.9], device=device))
        self.t5 = nn.Parameter(torch.tensor([0.9], device=device))

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.wave_layer1(x, self.t1)
        x = self.wave_layer2(x, self.t2)
        x = self.wave_layer3(x, self.t3)
        x = self.wave_layer4(x, self.t4)
        x = self.wave_layer5(x, self.t5)
        return x

train_dataset = datasets.MNIST(root='./data', 
                               train=True, 
                               download=True, 
                               transform=transforms.ToTensor()) 
test_dataset = datasets.MNIST(root='./data', 
                              train=False, 
                              download=True, 
                              transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=4)

model = TravelingWaveNN(output_size, beta_init, dropout_rate).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

max_grad_norm = 1.0  

for epoch in range(num_epochs):
    start_time = time.time()
    train_loss = 0
    correct = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        model.wave_layer1.update_beta(dt)
        model.wave_layer2.update_beta(dt)  
        model.wave_layer3.update_beta(dt)
        model.wave_layer4.update_beta(dt)
        model.wave_layer5.update_beta(dt)

        train_loss += loss.item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    train_accuracy = 100. * correct / len(train_loader.dataset)
    
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.0f}%, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.0f}%, Time: {time.time() - start_time:.2f}s")

model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += criterion(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
test_accuracy = 100. * correct / len(test_loader.dataset)
print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {test_accuracy:.0f}%')

Using device: cuda
Epoch 1/100, Loss: 1.3238, Accuracy: 47%, Test Loss: 0.0011, Test Accuracy: 59%, Time: 3.39s
Epoch 2/100, Loss: 1.0588, Accuracy: 61%, Test Loss: 0.0011, Test Accuracy: 62%, Time: 3.16s
Epoch 3/100, Loss: 0.9866, Accuracy: 65%, Test Loss: 0.0007, Test Accuracy: 78%, Time: 3.18s
Epoch 4/100, Loss: 0.7185, Accuracy: 78%, Test Loss: 0.0007, Test Accuracy: 79%, Time: 3.16s
Epoch 5/100, Loss: 0.7046, Accuracy: 79%, Test Loss: 0.0007, Test Accuracy: 79%, Time: 3.25s
Epoch 6/100, Loss: 0.6996, Accuracy: 79%, Test Loss: 0.0007, Test Accuracy: 80%, Time: 3.17s
Epoch 7/100, Loss: 0.6850, Accuracy: 80%, Test Loss: 0.0005, Test Accuracy: 84%, Time: 3.14s
Epoch 8/100, Loss: 0.4896, Accuracy: 86%, Test Loss: 0.0005, Test Accuracy: 86%, Time: 3.26s
Epoch 9/100, Loss: 0.4712, Accuracy: 87%, Test Loss: 0.0005, Test Accuracy: 86%, Time: 3.23s
Epoch 10/100, Loss: 0.4662, Accuracy: 87%, Test Loss: 0.0005, Test Accuracy: 86%, Time: 3.23s
Epoch 11/100, Loss: 0.4634, Accuracy: 87%, Test Lo