In [32]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [28]:
NUM_SAMPLES = 2000
INPUT_DIM = 10
NUM_CLASSES = 4
LR = 0.001
EPOCH = 10
SEED = 42

# Seed torch's global RNG so the synthetic data (torch.randn / torch.randint)
# and the nn.Linear weight initialization are reproducible on Restart & Run All.
torch.manual_seed(SEED)

In [22]:
# Synthetic data: standard-normal features and uniformly random labels drawn
# independently of the features, so a loss near ln(NUM_CLASSES) (~1.386 for 4
# classes) is the expected floor -- there is no signal to learn.
X_train = torch.randn(NUM_SAMPLES, INPUT_DIM)
# randint's upper bound is exclusive, so labels lie in [0, NUM_CLASSES).
y_train = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,))

In TensorFlow/Keras, you can put a softmax activation on the output layer. However, returning raw logits and passing `from_logits=True` to the loss is generally preferred for numerical stability. In PyTorch, `nn.CrossEntropyLoss` expects raw, unnormalized logits and applies `log_softmax` internally. The model's final layer should therefore not apply softmax; use `F.softmax` only at inference time, when actual probabilities are needed.

In [17]:
class PrefferedModel(nn.Module):
    """Small MLP classifier (input -> 25 -> 15 -> num_classes) that outputs raw logits.

    The final layer is deliberately left linear (no softmax): `nn.CrossEntropyLoss`
    applies log-softmax itself. Apply `F.softmax` at inference time if probabilities
    are needed.

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output logits. Defaults to 4, the width the output
            layer previously had hard-coded; pass `NUM_CLASSES` to keep them in sync.
    """

    def __init__(self, input_dim, num_classes=4):
        super().__init__()
        # Layer order/indices are kept stable so state_dict keys ("model.0.weight", ...) don't change.
        self.model = nn.Sequential(
            nn.Linear(input_dim, 25),
            nn.ReLU(),
            nn.Linear(25, 15),
            nn.ReLU(),
            nn.Linear(15, num_classes),  # logits -- no activation
        )

    def forward(self, x):
        """Map x of shape (..., input_dim) to unnormalized class scores of shape (..., num_classes)."""
        return self.model(x)

In [23]:
input_dim = X_train.size(1)  # width of the feature (second) axis of X_train

In [24]:
model = PrefferedModel(input_dim)  # input width inferred from X_train in the previous cell

In [25]:
criterion = nn.CrossEntropyLoss(reduction="mean")  # consumes raw logits plus integer class targets

In [27]:
optimizer = optim.Adam(params=model.parameters(), lr=LR)

In [30]:
# Full-batch training: each epoch is a single optimizer step over all of X_train.
for epoch in range(EPOCH):
    model.train()
    optimizer.zero_grad()

    outputs = model(X_train)  # raw logits; CrossEntropyLoss applies log-softmax itself
    loss = criterion(outputs, y_train)

    loss.backward()
    optimizer.step()

    epoch_loss = loss.item()
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss:.4f}")

Epoch 1, Loss: 1.3936
Epoch 2, Loss: 1.3929
Epoch 3, Loss: 1.3922
Epoch 4, Loss: 1.3916
Epoch 5, Loss: 1.3910
Epoch 6, Loss: 1.3905
Epoch 7, Loss: 1.3900
Epoch 8, Loss: 1.3895
Epoch 9, Loss: 1.3890
Epoch 10, Loss: 1.3886


In [44]:
# Inference: put the model in eval mode and skip autograd bookkeeping.
model.eval()

with torch.no_grad():
    logits = model(X_train)
    probs = logits.softmax(dim=1)       # per-row class probabilities
    pred_indices = probs.argmax(dim=1)  # same result as logits.argmax(dim=1); softmax is monotonic

print(probs[:5])
print(pred_indices[:5])

tensor([[0.2440, 0.2494, 0.2683, 0.2382],
        [0.2407, 0.2456, 0.2836, 0.2300],
        [0.2258, 0.2563, 0.2850, 0.2329],
        [0.2266, 0.2445, 0.2852, 0.2437],
        [0.2297, 0.2554, 0.2852, 0.2297]])
tensor([2, 2, 2, 2, 2])
