In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("mps")
print(f"Using {device} device")

Using mps device


In [11]:
class ImageEmbedding(nn.Module):
    def __init__(self, input_channels, embed_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, embed_dim, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )

    def forward(self, x):
        return self.conv(x)

In [20]:
class MiniStableDiffusion(nn.Module):
    def __init__(self, embed_dim, artist_dim):
        super().__init__()
        self.embedding = ImageEmbedding(3, embed_dim)
        self.model = nn.Sequential(
            nn.Linear(embed_dim + artist_dim, 256),
            nn.Linear(256, embed_dim + artist_dim)
        )

    def forward(self, input_image, artist_id):
        embed = self.embedding(input_image)
        return self.model(embed + artist_id)

In [13]:
# Hyperparameters
embed_dim = 64 # 64 x 64 image
artist_dim = 16 # 16 possible artists
learning_rate = 0.001
num_epochs = 10

In [23]:
# Initialize the model, loss function and optimizer
model = MiniStableDiffusion(embed_dim, artist_dim).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [24]:
# Generate a 64 x 64 input image
input_image = torch.randn(1, 3, 64, 64).to(device)
# Generate a artist ID
artist_id = torch.randn(1, artist_dim).to(device)

print(input_image)
print(artist_id)

tensor([[[[ 0.0382, -1.0326,  1.6651,  ...,  0.4060,  1.3885,  1.5412],
          [ 0.0368,  0.1334,  0.4681,  ...,  1.0872,  1.2975,  0.6330],
          [ 1.3485,  0.4951, -0.8267,  ..., -0.2671, -0.6901, -0.6826],
          ...,
          [ 1.6119, -0.1807,  0.3658,  ...,  0.1559,  1.1326, -1.2212],
          [ 1.0478,  0.1940,  1.0301,  ...,  0.6800,  1.7220, -0.5570],
          [ 0.1039,  0.5230,  0.5104,  ...,  1.0293,  0.5324, -0.8728]],

         [[ 0.2333,  0.3853, -0.5016,  ..., -0.0856,  0.8257, -0.0045],
          [ 1.2666, -0.3176,  0.2978,  ..., -0.6036,  0.2281, -0.3135],
          [-1.4242, -0.3590,  0.5917,  ...,  0.2864,  0.8542, -0.4944],
          ...,
          [-0.9287, -0.3396, -0.5429,  ...,  0.7194, -0.3461,  0.8573],
          [-0.6420,  0.5528,  0.3257,  ...,  1.0379,  1.5631,  0.2282],
          [ 1.5854, -0.2417,  1.3674,  ...,  1.6776,  0.1210,  0.6292]],

         [[-0.4042, -1.8183, -0.3230,  ...,  1.6800,  1.5950, -0.0246],
          [ 0.6559, -0.6064, -

In [25]:
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(input_image, artist_id)
    loss = criterion(output, input_image.view(input_image.size(0), -1))
    loss.backward()
    optimizer.step()
    print(f'Epoch: {epoch+1}, Loss: {loss.item()}')

RuntimeError: The size of tensor a (32) must match the size of tensor b (16) at non-singleton dimension 3

In [None]:
# Generate a sample output image and list of contributing artist IDs
output_image = model(input_image, artist_id)
contributing_artists = [artist_id]

# Print it
print(output_image)
print(contributing_artists)