In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [34]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [35]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [123]:
train_data = train_dataset.data.to(device).float() / 255.0
train_targets = train_dataset.targets.to(device)

test_data = test_dataset.data.to(device).float() / 255.0
test_targets = test_dataset.targets.to(device)

train_data = train_data.unsqueeze(1)
test_data = test_data.unsqueeze(1)

def get_batches(data, targets, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size], targets[i:i + batch_size]

batch_size = 500
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [176]:
class SimpleConcatRNNCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.fc = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x_t, h_prev):
        combined = torch.cat([x_t, h_prev], dim=1)
        h_t = F.tanh(self.fc(combined))
        return h_t

In [None]:
class ExperimentalModel(nn.Module):
    def __init__(self, hidden_size=20):
        super().__init__()
        self.rnn_cell = SimpleConcatRNNCell(1, hidden_size)

        self.x0 = nn.Parameter(torch.zeros(1, 28, 28))
        self.h0 = nn.Parameter(torch.randn(1, hidden_size))

        self.fc = nn.Linear(28 * 28, 10)

    def forward(self, x): # (B, 1, 28, 28)
        B, H, W = x.size(0), x.size(2), x.size(3)

        x_t = self.x0.expand(B, 1, H, W)
        h_t = self.h0.expand(B, -1)

        for _ in range(5):
            x_t_flat = x_t.view(B, -1)
            h_t = self.rnn_cell(x_t_flat.mean(dim=1).unsqueeze(1), h_t)

            x_cat = torch.cat([x_t, x], dim=1)
            weight = h_t[:, :18].view(B, 1, 2, 3, 3)
            bias = h_t[:, 18]

            x_cat = x_cat.view(1, B * 2, H, W)
            weight = weight.view(B, 2, 3, 3)

            x_t = F.conv2d(x_cat, weight, bias=bias, padding=1, stride=1, groups=B)
            x_t = F.relu(x_t.view(B, 1, H, W))

        return self.fc(x_t.view(B, -1))

In [190]:
learning_rate = 0.001 * 1
epochs = 10000

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = ExperimentalModel().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [191]:
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

total_params = sum(p.numel() for p in model.parameters())
print()
print(total_params)

x0: 784 params, requires_grad=True
h0: 19 params, requires_grad=True
rnn_cell.fc.weight: 380 params, requires_grad=True
rnn_cell.fc.bias: 19 params, requires_grad=True
fc.weight: 7840 params, requires_grad=True
fc.bias: 10 params, requires_grad=True

9052


In [192]:
patience = 40
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        print(best_val_loss)
        break

Epoch [1/10000], Training Loss: 0.8581
Epoch [1/10000], Validation Loss: 0.3279, Validation Accuracy: 90.52%
Output Summary: Max=13.4369, Min=-24.3840, Median=-3.7755, Mean=-3.7363

Epoch [2/10000], Training Loss: 0.3192
Epoch [2/10000], Validation Loss: 0.2778, Validation Accuracy: 91.90%
Output Summary: Max=14.5049, Min=-27.8686, Median=-3.8174, Mean=-3.8794

Epoch [3/10000], Training Loss: 0.2859
Epoch [3/10000], Validation Loss: 0.2585, Validation Accuracy: 92.50%
Output Summary: Max=15.5644, Min=-29.8379, Median=-3.7387, Mean=-3.8870

Epoch [4/10000], Training Loss: 0.2684
Epoch [4/10000], Validation Loss: 0.2459, Validation Accuracy: 92.81%
Output Summary: Max=16.3785, Min=-31.5484, Median=-3.7306, Mean=-3.9521

Epoch [5/10000], Training Loss: 0.2561
Epoch [5/10000], Validation Loss: 0.2370, Validation Accuracy: 93.28%
Output Summary: Max=17.0178, Min=-32.5589, Median=-3.7082, Mean=-3.9872

Epoch [6/10000], Training Loss: 0.2452
Epoch [6/10000], Validation Loss: 0.2283, Validatio

KeyboardInterrupt: 

In [94]:
for name, param in model.named_parameters():
    print(name, param)

x0 Parameter containing:
tensor([[[-4.8991e-03,  2.5201e-02, -1.1596e-02,  9.7011e-04,  8.4435e-03,
           5.4457e-03,  1.2147e-02,  2.1471e-02,  7.7073e-03, -3.0992e-03,
          -1.3756e-02,  2.0155e-02,  1.1545e-02, -9.1131e-02, -2.3543e-02,
          -2.2076e-02,  7.2532e-03,  1.0363e-02, -1.7747e-02, -1.6975e-02,
          -7.4943e-03, -1.7222e-02, -6.3775e-03,  6.7688e-02,  4.6023e-04,
           3.6882e-03,  8.9947e-03,  1.8488e-02],
         [-5.2839e-03,  8.6376e-04,  3.4281e-02,  2.9808e-03,  3.6648e-03,
           4.5234e-02,  2.3797e-02,  2.6553e-03,  1.0506e-02, -4.4785e-03,
           6.2421e-03,  2.2334e-02, -3.2451e-02,  2.3809e-02, -2.4318e-02,
          -1.3675e-01,  5.2540e-03, -9.0208e-04, -9.9969e-03,  2.3349e-02,
           1.3842e-02, -1.1336e-03,  4.6699e-03,  2.9924e-02,  1.9460e-02,
           3.9729e-03,  4.3285e-02,  4.0165e-02],
         [ 9.5744e-03,  3.5709e-02, -1.2582e-03,  1.4985e-03, -5.3291e-03,
           1.3135e-02,  2.2672e-02,  3.1037e-02, -