In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
def iterated_digit_sum(n):
    while n > 9:
        n = sum(int(digit) for digit in str(n))
    return n

In [4]:
def number_to_digits(n, max_digits=9):
    digits = [int(d) for d in str(n)]
    return [0] * (max_digits - len(digits)) + digits

In [5]:
num_samples = 20000
X = np.random.randint(0, 100_000_000, size=num_samples)
y = np.array([iterated_digit_sum(x) for x in X])
X = np.array([number_to_digits(x) for x in X])

# Splitting data
split = int(0.8 * num_samples)
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]

In [6]:
class SimpleTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers):
        super(SimpleTransformer, self).__init__()
        self.embedding = nn.Embedding(10, d_model)
        self.transformer = nn.Transformer(d_model, nhead, num_layers, num_layers)
        self.fc = nn.Linear(d_model, 10)

    def forward(self, src):
        src = self.embedding(src)
        src = src.permute(1, 0, 2)
        output = self.transformer(src, src)
        return self.fc(output[-1])


d_model = 64
nhead = 8
num_layers = 1

net = SimpleTransformer(d_model, nhead, num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0005)


In [None]:
# Initialize weights using Xavier (Glorot) initialization
def initialize_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

net.apply(initialize_weights)

In [8]:
X_train_tensor = torch.LongTensor(X_train).to(device)
y_train_tensor = torch.LongTensor(y_train).to(device)
X_val_tensor = torch.LongTensor(X_val).to(device)
y_val_tensor = torch.LongTensor(y_val).to(device)

# Training loop
epochs = 3500
for epoch in range(epochs):
    net.train()
    optimizer.zero_grad()
    outputs = net(X_train_tensor).float()
    loss = criterion(outputs, y_train_tensor)
    loss.backward()
    optimizer.step()

    # Validation
    net.eval()
    with torch.no_grad():
        val_outputs = net(X_val_tensor).float()
        val_loss = criterion(val_outputs, y_val_tensor)

    if epoch%100 == 0:
      print(f"Epoch {epoch+1}/{epochs} - Loss: {loss.item()} - Val Loss: {val_loss.item()}")

Epoch 1/3500 - Loss: 3.043935537338257 - Val Loss: 2.700690507888794
Epoch 101/3500 - Loss: 2.2009785175323486 - Val Loss: 2.2028584480285645
Epoch 201/3500 - Loss: 2.171909809112549 - Val Loss: 2.2190098762512207
Epoch 301/3500 - Loss: 2.11291766166687 - Val Loss: 2.2495791912078857
Epoch 401/3500 - Loss: 2.0000765323638916 - Val Loss: 2.323840618133545
Epoch 501/3500 - Loss: 1.8489491939544678 - Val Loss: 2.4132754802703857
Epoch 601/3500 - Loss: 1.667099952697754 - Val Loss: 2.3689167499542236
Epoch 701/3500 - Loss: 0.9895638227462769 - Val Loss: 1.0590077638626099
Epoch 801/3500 - Loss: 0.19639430940151215 - Val Loss: 0.07977096736431122
Epoch 901/3500 - Loss: 0.07481522113084793 - Val Loss: 0.025816738605499268
Epoch 1001/3500 - Loss: 0.04443313553929329 - Val Loss: 0.00972510501742363
Epoch 1101/3500 - Loss: 0.030537474900484085 - Val Loss: 0.00576006667688489
Epoch 1201/3500 - Loss: 0.020608684048056602 - Val Loss: 0.00345575506798923
Epoch 1301/3500 - Loss: 0.016974303871393204

In [9]:
def predict(net, number):
    net.eval()
    with torch.no_grad():
        input_tensor = torch.LongTensor([number_to_digits(number)]).to(device)
        output = net(input_tensor).float()
        return torch.argmax(output).item()

In [10]:
test_number = 4587874515
print(f"Actual Iterated Digit Sum: {iterated_digit_sum(test_number)}")
print(f"Predicted Iterated Digit Sum: {predict(net, test_number)}")

Actual Iterated Digit Sum: 9
Predicted Iterated Digit Sum: 6


In [11]:
test_number = 15415451
print(f"Actual Iterated Digit Sum: {iterated_digit_sum(test_number)}")
print(f"Predicted Iterated Digit Sum: {predict(net, test_number)}")

Actual Iterated Digit Sum: 8
Predicted Iterated Digit Sum: 8


In [12]:
test_number = 52451235
print(f"Actual Iterated Digit Sum: {iterated_digit_sum(test_number)}")
print(f"Predicted Iterated Digit Sum: {predict(net, test_number)}")

Actual Iterated Digit Sum: 9
Predicted Iterated Digit Sum: 9


In [13]:
test_number = 852963254
print(f"Actual Iterated Digit Sum: {iterated_digit_sum(test_number)}")
print(f"Predicted Iterated Digit Sum: {predict(net, test_number)}")

Actual Iterated Digit Sum: 8
Predicted Iterated Digit Sum: 8
