In [31]:
# TODO: Fix attention to more closely mimic standard or GPT transformer block 

In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=1, batch_first=True)

        # self.norm2 = nn.LayerNorm(dim)

        # hidden_dim = int(dim * 4)
        # self.mlp = nn.Sequential(
        #     nn.Linear(dim, hidden_dim),
        #     nn.GELU(),
        #     nn.Linear(hidden_dim, dim),
        # )

    def forward(self, x):
        # x: (B*L, 5, C)
        
        # Attention block
        # x_norm = self.norm1(x)
        # attn_output, _ = self.attn(x_norm, x_norm, x_norm)  # self-attention
        # x = x + attn_output  # residual

        # # MLP block
        # x_norm = self.norm2(x)
        # x = x + self.mlp(x_norm)  # residual

        # return x

        attn_output, _ = self.attn(x, x, x)  # self-attention
        return attn_output

class AttentionTanhConv2d(nn.Module):
    def __init__(self, in_channels, patch_size=2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        self.attn = nn.MultiheadAttention(embed_dim=in_channels, num_heads=1, batch_first=True)
        self.pos_embed = nn.Parameter(torch.randn(4, in_channels))
        self.stride_embed = nn.Parameter(torch.randn(4, in_channels))

        self.summary_token = nn.Parameter(torch.randn(1, in_channels))
        self.blocks = nn.ModuleList([
            TransformerBlock(dim=in_channels)
            for _ in range(1)
        ])

    def forward(self, x, stride=1, patch_size=2):
        B, C, H, W = x.shape

        # Unfold to get 2x2 patches: shape (B, C*4, L) where L = num patches
        patches = F.unfold(x, kernel_size=patch_size, stride=stride)
        L = patches.shape[-1]
        
        # Reshape to (B, L, 4, C)
        patches = patches.transpose(1, 2).reshape(B, L, patch_size * patch_size, C)
        
        # (B*L, 4, C) for attention
        tokens = patches.reshape(B * L, 4, C)
        tokens = tokens + self.pos_embed.unsqueeze(0)

        if stride == 2:
            tokens = tokens + self.stride_embed

        # Append summary token: (B*L, 5, C)
        summary = self.summary_token.expand(B * L, 1, C)
        tokens = torch.cat([summary, tokens], dim=1)

        for block in self.blocks:
            tokens = block(tokens)  # Each block does LayerNorm + Attention + MLP + Residual

        # Extract the summary output: (B*L, C)
        out = tokens[:, 0, :]
        
        # Reshape to (B, out_channels, H_out, W_out)
        H_out = (H - patch_size) // stride + 1
        W_out = (W - patch_size) // stride + 1
        out = out.view(B, H_out, W_out, self.out_channels).permute(0, 3, 1, 2)

        return out

In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [34]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

mps


In [65]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

In [66]:
train_data = train_dataset.data.to(device).float() / 255.0
train_targets = train_dataset.targets.to(device)

test_data = test_dataset.data.to(device).float() / 255.0
test_targets = test_dataset.targets.to(device)

train_data = train_data.unsqueeze(1)
test_data = test_data.unsqueeze(1)

def get_batches(data, targets, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size], targets[i:i + batch_size]

batch_size = 1000
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [76]:
class MNISTAttentionGRUCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=2, stride=2, padding=1)
        # self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        # self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=1)

        # self.weights = nn.Parameter(torch.randn(16, 1, 1))
        # self.channel_projector = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=1)
        self.attn_gru_conv = AttentionTanhConv2d(in_channels=16, patch_size=2)
        # self.fc1 = nn.Linear(16, 10)
        # self.fc2 = nn.Linear(7 * 7 * 16, 10)

        self.conv1 = nn.Conv2d(1, 16, kernel_size=2, stride=2, padding=1)  # 28x28 -> 15x15
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1) # 15x15 -> 8x8
        # self.conv3 = nn.Conv2d(16, 16, kernel_size=4, stride=1, padding=1) # 8x8   -> 7x7
        # self.conv3 = nn.Conv2d(16, 16, kernel_size=2, stride=1, padding=0) # 8x8   -> 7x7
        self.conv4 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=1) # 7x7   -> 5x5
        self.conv5 = nn.Conv2d(16, 16, kernel_size=5, stride=1, padding=0) # 5x5   -> 1x1

        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(16, 10)

    def forward(self, x): # (B, 1, 28, 28)
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        # x = self.relu(self.conv3(x))
        x = self.attn_gru_conv(x)
        x = x.contiguous()
        x = self.relu(self.conv4(x))
        x = self.relu(self.conv5(x))

        # x = self.channel_projector(x)
        # x = F.relu(x)

        # x = self.attn_gru_conv(x, 2) # (B, ... , 14)
        # x2 = self.attn_gru_conv(x, 2) # (B, ... , 7)
        # x = self.attn_gru_conv(x2, 1) # (B, ... , 6)
        # x = self.attn_gru_conv(x, 2) # (B, ... , 3)
        # x = self.attn_gru_conv(x, 1) # (B, ... , 2)
        # x = self.attn_gru_conv(x, 1) # (B, ... , 1)

        # x = self.conv1(x) # (B, 16, 15, 15)
        # x = self.conv2(x) # (B, 16, 8, 8)
        # x = self.conv3(x) # (B, 16, 5, 5)
        # x = self.attn_gru_conv(x) # (B, 16, 4, 4)
        # x = self.attn_gru_conv(x) # (B, 16, 3, 3)
        # x = self.attn_gru_conv(x) # (B, 16, 2, 2)
        # x = self.attn_gru_conv(x) # (B, 16, 1, 1)

        x = x.view(x.size(0), -1)
        x = x.reshape(x.size(0), -1)  # Flatten
        x = self.fc1(x)
        return x

In [77]:
learning_rate = 0.001 * 1
epochs = 500

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = MNISTAttentionGRUCNN().to(device)
# model = torch.compile(MNIST2DLSTMClassifier()).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [78]:
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()} params, requires_grad={param.requires_grad}")

total_params = sum(p.numel() for p in model.parameters())
print()
print(total_params)

attn_gru_conv.pos_embed: 64 params, requires_grad=True
attn_gru_conv.stride_embed: 64 params, requires_grad=True
attn_gru_conv.summary_token: 16 params, requires_grad=True
attn_gru_conv.attn.in_proj_weight: 768 params, requires_grad=True
attn_gru_conv.attn.in_proj_bias: 48 params, requires_grad=True
attn_gru_conv.attn.out_proj.weight: 256 params, requires_grad=True
attn_gru_conv.attn.out_proj.bias: 16 params, requires_grad=True
attn_gru_conv.blocks.0.attn.in_proj_weight: 768 params, requires_grad=True
attn_gru_conv.blocks.0.attn.in_proj_bias: 48 params, requires_grad=True
attn_gru_conv.blocks.0.attn.out_proj.weight: 256 params, requires_grad=True
attn_gru_conv.blocks.0.attn.out_proj.bias: 16 params, requires_grad=True
conv1.weight: 64 params, requires_grad=True
conv1.bias: 16 params, requires_grad=True
conv2.weight: 2304 params, requires_grad=True
conv2.bias: 16 params, requires_grad=True
conv4.weight: 6400 params, requires_grad=True
conv4.bias: 16 params, requires_grad=True
conv5.weig

In [79]:

patience = 40
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        print(best_val_loss)
        break


Epoch [1/500], Training Loss: 2.1117
Epoch [1/500], Validation Loss: 1.3833, Validation Accuracy: 52.37%
Output Summary: Max=8.1235, Min=-12.0908, Median=-0.2159, Mean=-0.2843

Epoch [2/500], Training Loss: 0.9710
Epoch [2/500], Validation Loss: 0.7537, Validation Accuracy: 72.05%
Output Summary: Max=16.1052, Min=-18.1718, Median=-0.6576, Mean=-0.9242

Epoch [3/500], Training Loss: 0.6722
Epoch [3/500], Validation Loss: 0.6347, Validation Accuracy: 76.07%
Output Summary: Max=17.8981, Min=-18.3741, Median=-0.7355, Mean=-1.0063

Epoch [4/500], Training Loss: 0.5974
Epoch [4/500], Validation Loss: 0.5839, Validation Accuracy: 77.83%
Output Summary: Max=19.8420, Min=-19.6280, Median=-0.7308, Mean=-1.0446

Epoch [5/500], Training Loss: 0.5491
Epoch [5/500], Validation Loss: 0.5402, Validation Accuracy: 79.40%
Output Summary: Max=21.4533, Min=-20.2688, Median=-0.9171, Mean=-0.9980

Epoch [6/500], Training Loss: 0.5049
Epoch [6/500], Validation Loss: 0.5017, Validation Accuracy: 81.42%
Output

In [70]:

patience = 40
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        print(best_val_loss)
        break


Epoch [1/500], Training Loss: 1.9182
Epoch [1/500], Validation Loss: 1.2326, Validation Accuracy: 56.59%
Output Summary: Max=11.2913, Min=-14.3222, Median=0.5692, Mean=0.3573

Epoch [2/500], Training Loss: 0.8644
Epoch [2/500], Validation Loss: 0.7375, Validation Accuracy: 72.58%
Output Summary: Max=20.4001, Min=-20.5289, Median=2.7531, Mean=2.7068

Epoch [3/500], Training Loss: 0.6860
Epoch [3/500], Validation Loss: 0.6699, Validation Accuracy: 75.01%
Output Summary: Max=21.3139, Min=-21.5421, Median=3.3929, Mean=3.0213

Epoch [4/500], Training Loss: 0.6417
Epoch [4/500], Validation Loss: 0.6359, Validation Accuracy: 76.19%
Output Summary: Max=22.2096, Min=-21.9511, Median=3.9949, Mean=3.3501

Epoch [5/500], Training Loss: 0.6117
Epoch [5/500], Validation Loss: 0.6099, Validation Accuracy: 77.13%
Output Summary: Max=22.9944, Min=-21.2470, Median=4.3341, Mean=3.5689

Epoch [6/500], Training Loss: 0.5858
Epoch [6/500], Validation Loss: 0.5880, Validation Accuracy: 78.20%
Output Summary:

In [74]:

patience = 40
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break


Epoch [1/500], Training Loss: 2.1232
Epoch [1/500], Validation Loss: 1.4701, Validation Accuracy: 40.80%
Output Summary: Max=10.2939, Min=-10.7633, Median=0.2907, Mean=0.1655

Epoch [2/500], Training Loss: 0.9177
Epoch [2/500], Validation Loss: 0.7180, Validation Accuracy: 72.62%
Output Summary: Max=31.6700, Min=-34.7090, Median=-0.2609, Mean=0.0209

Epoch [3/500], Training Loss: 0.6307
Epoch [3/500], Validation Loss: 0.5989, Validation Accuracy: 78.11%
Output Summary: Max=30.9822, Min=-30.8298, Median=0.0657, Mean=-0.0007

Epoch [4/500], Training Loss: 0.5357
Epoch [4/500], Validation Loss: 0.5289, Validation Accuracy: 80.80%
Output Summary: Max=33.5743, Min=-31.3023, Median=0.1645, Mean=-0.0661

Epoch [5/500], Training Loss: 0.4897
Epoch [5/500], Validation Loss: 0.4936, Validation Accuracy: 81.92%
Output Summary: Max=34.3337, Min=-29.7314, Median=0.3283, Mean=-0.0808

Epoch [6/500], Training Loss: 0.4611
Epoch [6/500], Validation Loss: 0.4706, Validation Accuracy: 82.56%
Output Summ

In [51]:

patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break


Epoch [1/500], Training Loss: 2.3282
Epoch [1/500], Validation Loss: 2.3035, Validation Accuracy: 10.28%
Output Summary: Max=0.4745, Min=0.0816, Median=0.2738, Mean=0.2726

Epoch [2/500], Training Loss: 2.3005
Epoch [2/500], Validation Loss: 2.2889, Validation Accuracy: 10.28%
Output Summary: Max=0.4730, Min=0.0296, Median=0.2607, Mean=0.2452

Epoch [3/500], Training Loss: 1.1786
Epoch [3/500], Validation Loss: 0.4489, Validation Accuracy: 85.79%
Output Summary: Max=17.3885, Min=-16.1064, Median=0.8293, Mean=0.9764

Epoch [4/500], Training Loss: 0.3915
Epoch [4/500], Validation Loss: 0.3475, Validation Accuracy: 89.59%
Output Summary: Max=17.1890, Min=-14.7909, Median=0.8270, Mean=1.0324

Epoch [5/500], Training Loss: 0.3467
Epoch [5/500], Validation Loss: 0.3206, Validation Accuracy: 90.50%
Output Summary: Max=18.5099, Min=-15.6053, Median=0.8729, Mean=1.0608

Epoch [6/500], Training Loss: 0.3244
Epoch [6/500], Validation Loss: 0.3039, Validation Accuracy: 91.03%
Output Summary: Max=1

KeyboardInterrupt: 

In [12]:
import torch
import torch.nn.functional as F

# Metal Backend device or CPU device
device = "mps" if torch.backends.mps.is_available() else "cpu"

if __name__ == '__main__':

    tensor = torch.empty(4, 2, 40, 40).to(device)
    unfolded_tensor = F.unfold(input=tensor, kernel_size=3, padding=1, stride=1)
    print("torch version:", torch.__version__)

torch version: 2.6.0


In [16]:
patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

here
here
here


KeyboardInterrupt: 

In [None]:
patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.1728
Epoch [1/1000], Validation Loss: 1.6802, Validation Accuracy: 56.44%
Output Summary: Max=2.3652, Min=-2.7602, Median=0.0619, Mean=0.0334

Epoch [2/1000], Training Loss: 1.2055
Epoch [2/1000], Validation Loss: 0.7723, Validation Accuracy: 84.77%
Output Summary: Max=4.2014, Min=-2.8849, Median=-0.0332, Mean=0.1487

Epoch [3/1000], Training Loss: 0.5983
Epoch [3/1000], Validation Loss: 0.4247, Validation Accuracy: 90.35%
Output Summary: Max=5.0898, Min=-3.8560, Median=-0.2723, Mean=0.0734

Epoch [4/1000], Training Loss: 0.3736
Epoch [4/1000], Validation Loss: 0.3020, Validation Accuracy: 92.28%
Output Summary: Max=5.7756, Min=-4.4066, Median=-0.3293, Mean=0.0801

Epoch [5/1000], Training Loss: 0.2840
Epoch [5/1000], Validation Loss: 0.2441, Validation Accuracy: 93.43%
Output Summary: Max=6.2063, Min=-4.4058, Median=-0.3786, Mean=0.0622

Epoch [6/1000], Training Loss: 0.2379
Epoch [6/1000], Validation Loss: 0.2234, Validation Accuracy: 93.91%
Output Su

KeyboardInterrupt: 

In [None]:
patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

Epoch [1/1000], Training Loss: 2.3023
Epoch [1/1000], Validation Loss: 2.3016, Validation Accuracy: 15.66%
Output Summary: Max=0.1028, Min=-0.0965, Median=0.0467, Mean=0.0335

Epoch [2/1000], Training Loss: 2.2845
Epoch [2/1000], Validation Loss: 2.1563, Validation Accuracy: 27.01%
Output Summary: Max=0.7537, Min=-0.6424, Median=0.0407, Mean=0.0154

Epoch [3/1000], Training Loss: 1.7491
Epoch [3/1000], Validation Loss: 1.2844, Validation Accuracy: 61.41%
Output Summary: Max=3.8563, Min=-3.6543, Median=0.1253, Mean=0.0692

Epoch [4/1000], Training Loss: 1.0003
Epoch [4/1000], Validation Loss: 0.7231, Validation Accuracy: 80.32%
Output Summary: Max=5.1553, Min=-4.6620, Median=-0.1656, Mean=0.0501

Epoch [5/1000], Training Loss: 0.6128
Epoch [5/1000], Validation Loss: 0.4861, Validation Accuracy: 86.23%
Output Summary: Max=6.4259, Min=-4.8845, Median=-0.1543, Mean=0.0677

Epoch [6/1000], Training Loss: 0.4516
Epoch [6/1000], Validation Loss: 0.3873, Validation Accuracy: 88.88%
Output Summ

KeyboardInterrupt: 

In [None]:
patience = 10000
best_val_loss = float('inf')
no_improvement_epochs = 0

all_outputs = []

for epoch in range(10000):
    model.train()
    running_loss = 0.0
    num_batches = 0
    
    for data, target in get_batches(train_data, train_targets, batch_size):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    print(f"Epoch [{epoch + 1}/{epochs}], Training Loss: {running_loss / num_batches:.4f}")

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    epoch_outputs = []

    with torch.no_grad():
        for data, target in get_batches(test_data, test_targets, batch_size):
            outputs = model(data)
            loss = criterion(outputs, target)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            num_batches += 1
            correct += (predicted == target).sum().item()

            epoch_outputs.append(outputs)

    all_outputs_tensor = torch.cat(epoch_outputs, dim=0)
    all_outputs.append(all_outputs_tensor)

    max_val = torch.max(all_outputs_tensor).item()
    min_val = torch.min(all_outputs_tensor).item()
    median_val = torch.median(all_outputs_tensor).item()
    mean_val = torch.mean(all_outputs_tensor).item()

    accuracy = 100 * correct / total
    val_loss /= num_batches
    print(f"Epoch [{epoch + 1}/{epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {accuracy:.2f}%")
    print(f"Output Summary: Max={max_val:.4f}, Min={min_val:.4f}, Median={median_val:.4f}, Mean={mean_val:.4f}")
    print()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improvement_epochs = 0
    else:
        no_improvement_epochs += 1

    if no_improvement_epochs >= patience:
        print(f"Early stopping triggered after {epoch + 1} epochs.")
        break

torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
Epoch [1/1000], Training Loss: 2.3031
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
torch.Size([2500, 32, 5, 5])
Epoch [1/1000], Validation Loss: 2.3011, Validation Accuracy: 10.28%
Output Summary: Max=0.0767, Min=-0.0834, Median=0.0046, Mean=0.0008

torch.Size([

KeyboardInterrupt: 