In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per sample once the input is
            flattened. Defaults to 28*28.
        hidden_size: Width of the hidden layer. Defaults to 128.
        num_classes: Number of scores produced per sample. Defaults to 10.
    """

    def __init__(self, input_size=28*28, hidden_size=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Flatten ``x`` and return the raw (un-activated) output of ``fc2``.

        Args:
            x: Tensor whose total element count is a multiple of the
                configured input size; it is reshaped to ``(-1, input_size)``.

        Returns:
            Tensor of shape ``(N, num_classes)`` where ``N`` is the inferred
            leading dimension after flattening.
        """
        # reshape() rather than view(): view() raises on non-contiguous
        # inputs (e.g. a transposed batch), reshape() copies only when needed.
        # The width is read from fc1 so it always matches the layer.
        x = x.reshape(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [3]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data when not already present.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

100%|██████████| 9.91M/9.91M [00:12<00:00, 780kB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 191kB/s]
100%|██████████| 1.65M/1.65M [00:02<00:00, 785kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.03MB/s]


In [4]:
# Seed PyTorch's global RNG before the model is built so the randomly
# initialised weights are the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = SimpleNN()
# Loss compared against integer class labels in the training cell below.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters with learning rate 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [5]:
# Show the module tree (layer names and sizes) to confirm the architecture.
print(model)

SimpleNN(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [6]:
# Pull one batch to confirm the tensor shapes the loader yields.
# next(iter(...)) fetches the first batch directly, with no for/break loop.
images, labels = next(iter(train_loader))
print("Input batch shape:", images.shape)
print("Labels batch shape:", labels.shape)

Input batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])


In [7]:
# Forward-pass sanity check on one batch. Only the output shape is inspected,
# so run under no_grad() to avoid building (and keeping alive) an autograd
# graph that is never used.
images, labels = next(iter(train_loader))
with torch.no_grad():
    outputs = model(images)
print("Output shape:", outputs.shape)

Output shape: torch.Size([64, 10])


In [8]:
def check_nan(tensor, name):
    """Print a warning for each kind of non-finite value found in ``tensor``.

    Args:
        tensor: Tensor to inspect.
        name: Label inserted into the warning text to identify the tensor.

    Returns:
        None. The only effect is printing; a tensor containing both NaN and
        Inf produces two lines, NaN first.
    """
    has_nan = bool(torch.isnan(tensor).any())
    has_inf = bool(torch.isinf(tensor).any())

    if has_nan:
        print(f"Warning: NaN detected in {name}")
    if has_inf:
        print(f"Warning: Inf detected in {name}")

# Scan every model parameter for NaN/Inf. Passing each parameter's own name
# makes a warning point at the exact tensor instead of one generic label.
for name, param in model.named_parameters():
    check_nan(param, f"Model Parameter '{name}'")

In [9]:
# Single-batch smoke test of the training step: run one optimisation step,
# print per-parameter gradient norms and the loss, then stop.
for epoch in range(1):
    for images, labels in train_loader:
        # Clear gradients from any earlier step before back-propagating.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # Report gradient norms after backward() and before the update.
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            print(f"Gradient for {name}: {param.grad.norm()}")

        optimizer.step()
        print("Loss:", loss.item())
        # Only the first batch is processed.
        break

Gradient for fc1.weight: 1.667392373085022
Gradient for fc1.bias: 0.06074222922325134
Gradient for fc2.weight: 0.65090012550354
Gradient for fc2.bias: 0.15991102159023285
Loss: 2.368743658065796
