In [1]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import os
import matplotlib.pyplot as plt


In [2]:
plot_dir = 'imgs'
os.makedirs(plot_dir, exist_ok=True)

In [3]:
# Configuration (same as before)
dataset_name = 'mnist'
img_size = 28
n_channels = 1
img_coords = 2
lr = 1e-4
batch_size = 64
num_latent = 32
hidden_features = 256
num_layers = 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# NEW: Mask parameters
mask_type = 'center'  # 'center' or 'random'
mask_size = 8  # For center mask


In [5]:
# SIREN Model (same as before)
class SirenLayer(nn.Module):
    def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):
        super().__init__()
        self.in_f = in_f
        self.w0 = w0
        self.linear = nn.Linear(in_f, out_f)
        self.is_first = is_first
        self.is_last = is_last
        self.init_weights()

    def init_weights(self):
        b = 1 / self.in_f if self.is_first else np.sqrt(6 / self.in_f) / self.w0
        with torch.no_grad():
            self.linear.weight.uniform_(-b, b)

    def forward(self, x):
        x = self.linear(x)
        return x if self.is_last else torch.sin(self.w0 * x)

def gon_model(dimensions):
    layers = [SirenLayer(dimensions[0], dimensions[1], is_first=True)]
    for i in range(1, len(dimensions)-2):
        layers.append(SirenLayer(dimensions[i], dimensions[i+1]))
    layers.append(SirenLayer(dimensions[-2], dimensions[-1], is_last=True))
    return nn.Sequential(*layers)

In [6]:
# NEW: Mask creation function
def create_mask(batch_size, img_size):
    if mask_type == 'center':
        mask = torch.ones(batch_size, img_size, img_size, 1)
        c = img_size//2
        mask[:, c-mask_size//2:c+mask_size//2, c-mask_size//2:c+mask_size//2, :] = 0
    elif mask_type == 'random':
        mask = torch.rand(batch_size, img_size, img_size, 1) > 0.25
        mask = mask.float()
    return mask.reshape(batch_size, -1, 1).to(device)

In [7]:
# Helper functions (same as before)
def get_mgrid(sidelen, dim=2):
    tensors = tuple(dim * [torch.linspace(-1, 1, steps=sidelen)])
    mgrid = torch.stack(torch.meshgrid(*tensors), dim=-1)
    return mgrid.reshape(-1, dim)

In [8]:
# Dataset and model setup
dataset = torchvision.datasets.MNIST('data', train=True, download=True,
                                           transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
gon_shape = [img_coords+num_latent] + [hidden_features]*num_layers + [n_channels]
F = gon_model(gon_shape).to(device)
optim = torch.optim.Adam(F.parameters(), lr=lr)
c = get_mgrid(img_size, 2).repeat(batch_size, 1, 1).to(device)  # Coordinates

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [9]:
# Training loop with inpainting
for step in range(1001):
    # Get batch and create masked version
    x, _ = next(iter(train_loader))
    x = x.permute(0, 2, 3, 1).reshape(batch_size, -1, n_channels).to(device)
    mask = create_mask(batch_size, img_size)  # NEW: Create mask
    x_masked = x * mask  # NEW: Apply mask

    # Inner loop: Find z using MASKED pixels
    z = torch.zeros(batch_size, 1, num_latent).to(device).requires_grad_()
    z_rep = z.repeat(1, c.size(1), 1)
    g = F(torch.cat((c, z_rep), dim=-1))
    L_inner = ((g - x_masked)**2 * mask).sum() / mask.sum()  # NEW: Masked loss
    z = -torch.autograd.grad(L_inner, z, create_graph=True)[0]

    # Outer loop: Reconstruct FULL image
    z_rep = z.repeat(1, c.size(1), 1)
    g = F(torch.cat((c, z_rep), dim=-1))
    L_outer = ((g - x)**2).mean()

    optim.zero_grad()
    L_outer.backward()
    optim.step()

    # Save visualizations
    if step % 100 == 0:
        with torch.no_grad():
            # Get example from first batch element
            example_idx = 0
            masked = x_masked[example_idx].reshape(img_size, img_size, 1)
            reconstructed = g[example_idx].reshape(img_size, img_size, 1)

            # Create side-by-side plot
            fig, (ax1, ax2) = plt.subplots(1, 2)
            ax1.imshow(masked.cpu().numpy(), cmap='gray')
            ax1.set_title('Masked Input')
            ax2.imshow(reconstructed.cpu().numpy(), cmap='gray')
            ax2.set_title('Reconstruction')
            plt.savefig(f'imgs/inpaint_{step}.png')
            plt.close()

        print(f'Step {step}, Loss: {L_outer.item():.4f}')

print("Training complete! Check 'imgs/' for inpainting results!")

Step 0, Loss: 0.1253
Step 100, Loss: 0.0542
Step 200, Loss: 0.0543
Step 300, Loss: 0.0539
Step 400, Loss: 0.0525
Step 500, Loss: 0.0532
Step 600, Loss: 0.0566
Step 700, Loss: 0.0548
Step 800, Loss: 0.0552
Step 900, Loss: 0.0537
Step 1000, Loss: 0.0520
Training complete! Check 'imgs/' for inpainting results!
