In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("mps")
print(f"Using {device} device")

Using mps device


In [2]:
# Hyperparameters
embed_dim = 64 # 64 x 64 image
artist_dim = 16 # 16 possible artists
learning_rate = 0.001
num_epochs = 10

In [3]:
# Generate a 64 x 64 input image
input_image = torch.randn(1, 64, 64).to(device)
# Generate a artist ID
artist_id = torch.zeros(artist_dim).to(device)
artist_id[1] = 1

In [4]:
input_image.dim()
print(input_image)
print(artist_id)

tensor([[[ 0.8791, -0.4170,  0.4367,  ...,  1.5764,  0.1596, -0.2268],
         [ 0.8000,  2.1211,  0.7472,  ..., -0.4380, -0.0661,  0.7586],
         [-0.8151, -0.6690,  0.5545,  ...,  1.6681,  2.2424,  3.5777],
         ...,
         [-1.2877,  1.0814, -0.6811,  ..., -1.3663,  1.9063,  0.4265],
         [-0.7138,  1.5651, -0.2588,  ..., -0.7623, -1.3313,  0.8528],
         [-1.1911,  0.3793,  0.4360,  ..., -0.8034, -0.5732,  0.2622]]],
       device='mps:0')
tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='mps:0')


In [5]:
class ImageEmbedding(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, embed_dim, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

    def forward(self, x):
        return self.conv(x)

In [6]:
m = ImageEmbedding(embed_dim).to(device)
print(m)

res = m(input_image)

print(len(res))
print(res.dim())
print(res)

ImageEmbedding(
  (conv): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)
64
3
tensor([[[2.9150e-01, 3.1017e-01, 7.5294e-01,  ..., 4.4214e-01,
          4.8840e-01, 5.5453e-02],
         [1.8904e-01, 0.0000e+00, 3.0637e-01,  ..., 2.8130e-01,
          3.3281e-01, 6.4625e-01],
         [1.8673e-01, 3.6913e-01, 9.0881e-01,  ..., 2.7862e-01,
          0.0000e+00, 4.6233e-01],
         ...,
         [2.3606e-01, 7.6784e-01, 0.0000e+00,  ..., 1.0417e+00,
          1.8969e+00, 5.7396e-01],
         [3.4005e-02, 3.7618e-01, 1.2480e+00,  ..., 2.3852e-01,
          2.1087e-01, 0.0000e+00],
         [0.0000e+00, 5.1832e-01, 1.1842e+00,  ..., 1.1882e+00,
          8.1199e-01, 0.0000e+00]],

        [[4.3713e-01, 1.8483e-01, 0.0000e+00,  ..., 9.7127e-01,
          8.0660e-01, 6.2276e-01],
         [2.0018e-01, 0.0000e+00, 1.3432e+00,  ..., 6.3587e-01,
    

In [7]:
model = nn.Sequential(
            nn.Linear(embed_dim, 256),
            nn.ReLU(),
            nn.Linear(256, artist_dim),
            nn.ReLU()
        ).to(device)

In [16]:
res = model(input_image)
print("res =", res)
print(res.dim())
print(len(res), len(res[0]), len(res[0][0]))

# get max
max = res.argmax()
print(max)

res = tensor([[[0.4699, 0.0000, 0.0000,  ..., 0.1925, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0254, 0.0000],
         [0.3879, 0.0000, 0.0000,  ..., 0.1631, 0.0377, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.2277, 0.3010, 0.0000],
         [0.0000, 0.0000, 0.0684,  ..., 0.3161, 0.0000, 0.0000],
         [0.1831, 0.0000, 0.0000,  ..., 0.0000, 0.2664, 0.0000]]],
       device='mps:0', grad_fn=<ReluBackward0>)
3
1 64 16
tensor(807, device='mps:0')


In [None]:
## Testing

# pool of square window of size=3, stride=2
m = nn.MaxPool2d(3, stride=2)
print(m)
input_image = torch.randn(1, 3, 64, 64).to(device)
print(input_image.dim())
output = m(input_image)

print(output.dim())