In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import numpy as np
from IPython.display import clear_output

In [5]:
x_train = torch.load('./models/x_train.pt').unsqueeze(1)
y_train = torch.load('./models/y_train.pt').unsqueeze(1)

x_train.shape, y_train.shape

(torch.Size([23376, 1, 15, 15]), torch.Size([23376, 1, 15, 15]))

In [49]:
dataset = data.TensorDataset(x_train, y_train)
dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True)

In [50]:
model = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(128, 256, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(256, 1, kernel_size=1)
)
cost = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

In [51]:
epochs = 60

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

for epoch in range(1, epochs + 1):
    run_loss = 0.
    loss_total = 0.

    for i, (x, y) in enumerate(dataloader):
        prediction = model(x)
        loss = cost(prediction, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        run_loss += loss.item() * x.size(0)

        clear_output(wait=True)
        print(f'epoch: {epoch:2d}/{epochs} {i}/{len(dataloader)} cost: {loss_total:.6f}')

    loss_total = run_loss/len(dataloader)
    print(f'epoch: {epoch:2d}/{epochs} cost: {loss_total:.6f}')

torch.save(model, './models/model.pt')

epoch: 60/60 730/731 cost: 0.000000
epoch: 60/60 cost: 54.603101


In [10]:
model = torch.load('./models/model.pt')

loss = 0.
for i in range(len(x_train)):
    x = x_train[i]
    y = y_train[i]
    prediction = model(x)

    cos = torch.nn.CosineSimilarity(dim=1)
    mean = (torch.mean(cos(prediction, y)).item() + 1) / 2
    print(f'{i} cost: {mean:.4f}')

    loss += mean

similarity = loss / len(x_train)
similarity

0 cost: 0.5649
1 cost: 0.5502
2 cost: 0.5617
3 cost: 0.5838
4 cost: 0.6318
5 cost: 0.6214
6 cost: 0.6303
7 cost: 0.6560
8 cost: 0.6486
9 cost: 0.6485
10 cost: 0.6636
11 cost: 0.6985
12 cost: 0.6792
13 cost: 0.6901
14 cost: 0.6874
15 cost: 0.6926
16 cost: 0.7003
17 cost: 0.7271
18 cost: 0.7321
19 cost: 0.7284
20 cost: 0.7292
21 cost: 0.7584
22 cost: 0.7548
23 cost: 0.7613
24 cost: 0.7629
25 cost: 0.7579
26 cost: 0.7607
27 cost: 0.7635
28 cost: 0.7624
29 cost: 0.7584
30 cost: 0.7641
31 cost: 0.7997
32 cost: 0.7994
33 cost: 0.7987
34 cost: 0.7981
35 cost: 0.7988
36 cost: 0.7994
37 cost: 0.7927
38 cost: 0.7996
39 cost: 0.8313
40 cost: 0.8645
41 cost: 0.8611
42 cost: 0.8647
43 cost: 0.8567
44 cost: 0.8663
45 cost: 0.8664
46 cost: 0.8651
47 cost: 0.8637
48 cost: 0.8664
49 cost: 0.8660
50 cost: 0.8647
51 cost: 0.8653
52 cost: 0.8663
53 cost: 0.8596
54 cost: 0.8658
55 cost: 0.8995
56 cost: 0.5164
57 cost: 0.5622
58 cost: 0.5640
59 cost: 0.5777
60 cost: 0.6314
61 cost: 0.6131
62 cost: 0.6218
63

0.7293425263073786