In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class MaskedConv2d(nn.Conv2d):
    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv2d, self).__init__(*args, **kwargs)
        assert mask_type in ('A', 'B')
        self.register_buffer('mask', self.weight.data.clone())
        _, depth, height, width = self.weight.size()
        self.mask.fill_(1)
        if mask_type == 'A':
            self.mask[:, :, height // 2, width // 2:] = 0
            self.mask[:, :, height // 2 + 1:] = 0
        else:  # Mask type B
            self.mask[:, :, height // 2, width // 2 + 1:] = 0
            self.mask[:, :, height // 2 + 1:] = 0

    def forward(self, x):
        self.weight.data *= self.mask
        return super(MaskedConv2d, self).forward(x)

class PixelCNN(nn.Module):
    def __init__(self, input_channels=1, hidden_channels=64):
        super(PixelCNN, self).__init__()
        self.layer1 = MaskedConv2d('A', input_channels, hidden_channels, 7, padding=3)
        self.layer2 = MaskedConv2d('B', hidden_channels, hidden_channels, 7, padding=3)
        self.layer3 = MaskedConv2d('B', hidden_channels, hidden_channels, 7, padding=3)
        self.layer4 = MaskedConv2d('B', hidden_channels, input_channels, 7, padding=3)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        return torch.sigmoid(self.layer4(x))

# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PixelCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCELoss()

# Assuming the use of MNIST
train_loader = DataLoader(datasets.MNIST('./data', train=True, download=True,
                                         transform=transforms.ToTensor()),
                          batch_size=32, shuffle=True)

# Training loop
for epoch in range(10):
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, images)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch}, Loss: {loss.item()}')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 79913026.98it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 12105095.82it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 43392739.30it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3579580.75it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch 0, Loss: 0.0834907591342926


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

losses = []
for epoch in range(10):
    epoch_losses = []
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, images)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    avg_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(avg_loss)
    print(f'Epoch {epoch}, Loss: {avg_loss}')

plt.plot(losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.show()

In [None]:
def generate_images(model, device, num_images=10):
    model.eval()
    sample = torch.Tensor(num_images, 1, 28, 28).fill_(0.5).to(device)  # Start with half-gray images
    with torch.no_grad():
        for i in range(28):
            for j in range(28):
                out = model(sample)
                probs = torch.sigmoid(out[:, :, i, j])
                sample[:, :, i, j] = torch.bernoulli(probs)
    return sample.cpu()

generated_images = generate_images(model, device)
fig, axes = plt.subplots(1, num_images, figsize=(20, 2))
for i, img in enumerate(generated_images):
    axes[i].imshow(img.squeeze(), cmap='gray')
    axes[i].axis('off')
plt.show()

In [None]:
def plot_activations(model, image, device):
    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook

    model.layer1.register_forward_hook(get_activation('layer1'))
    model.layer2.register_forward_hook(get_activation('layer2'))

    model.eval()
    with torch.no_grad():
        _ = model(image.to(device))

    for layer in ['layer1', 'layer2']:
        act = activation[layer].squeeze()
        fig, axarr = plt.subplots(act.size(0))
        for idx in range(act.size(0)):
            axarr[idx].imshow(act[idx].cpu())
            axarr[idx].axis('off')
        plt.show()

plot_activations(model, images[0].unsqueeze(0), device)