In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import numpy as np
from IPython.display import clear_output

In [2]:
x_train = torch.load('./models/x_train.pt').unsqueeze(1)
y_train = torch.load('./models/y_train.pt').unsqueeze(1)

x_train.shape, y_train.shape

(torch.Size([23376, 1, 15, 15]), torch.Size([23376, 1, 15, 15]))

In [3]:
dataset = data.TensorDataset(x_train, y_train)
dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)

In [4]:
model = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(128, 256, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(256, 1, kernel_size=1)
)
cost = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

In [5]:
epochs = 10

for epoch in range(1, epochs + 1):
    run_loss = 0.
    loss_total = 0.

    for i, (x, y) in enumerate(dataloader):
        prediction = model(x)
        loss = cost(prediction, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        run_loss += loss.item() * x.size(0)

        clear_output(wait=True)
        print(f'epoch: {epoch:2d}/{epochs} {i}/{len(dataloader)} cost: {loss_total:.6f}')

    loss_total = run_loss/len(dataloader.dataset)
    print(f'epoch: {epoch:2d}/{epochs} cost: {loss_total:.6f}')

torch.save(model, './models/model.pt')

epoch: 10/10 730/731 cost: 0.000000
epoch: 10/10 cost: 4.106892


In [30]:
model = torch.load('./models/model.pt')

v = torch.zeros((15, 15))
v[2, 1] = 1.

# for row in range(15):
#     for col in range(15):
#         print(v[row, col].item(), end=' ')
#     print()

output = model(v.reshape((1, 15, 15)))[0]

# print()

# for row in range(15):
#     for col in range(15):
#         print(f'{output[row, col].item():.2f}', end=' ')
#     print()

# x = torch.topk(output.reshape((225)), k=2)
# value, idx = x.values[1].item(), x.indices[1].item()
# print(f'output[{idx // 15, idx % 15}] = {value}')
output.shape

torch.Size([15, 15])

In [31]:
torch.FloatTensor([1]) < torch.FloatTensor([2])

tensor([True])