In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.v2
from torchinfo import summary
from tqdm import tqdm
from ema_pytorch import EMA
import matplotlib.pyplot as plt


In [2]:
train_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=torchvision.transforms.v2.Compose(
        [
            torchvision.transforms.Resize((32, 32)),
            torchvision.transforms.ToTensor(),
        ]
    ),
)
val_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=torchvision.transforms.v2.Compose(
        [
            torchvision.transforms.Resize((32, 32)),
            torchvision.transforms.ToTensor(),
        ]
    ),
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2, persistent_workers=True
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=128, shuffle=False
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
@torch.no_grad()
def evaluate():
    model.eval()
    losses = []

    for x, y in val_loader:
        x, y = x.to("mps"), y.to("mps")
        y_hat = model(x)
        loss = criterion(y_hat, y)


        losses.append(loss.item())

    model.train()
    return {
        "loss": sum(losses) / len(losses),
        "psnr" : 10 * torch.log10(1 / sum(losses) / len(losses))
    }

In [8]:
class DepthwiseSeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels
            ),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.GroupNorm(16, out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        return self.layers(x)


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=True):
        super().__init__()
        self.layers = nn.ModuleDict(
            {
                "conv1": DepthwiseSeparableConv2d(in_channels, out_channels),
                "sample": nn.MaxPool2d(2)
                if downsample
                else nn.Upsample(scale_factor=2),
            }
        )

    def forward(self, x):
        x = self.layers["conv1"](x)
        x = self.layers["sample"](x)
        return x


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            Block(3, 16),
            Block(16, 32),
            Block(32, 64),
            Block(64, 128),
        )

        self.decoder = nn.Sequential(

            Block(256, 128, downsample=False),
            Block(128, 64, downsample=False),
            Block(64, 32, downsample=False),
            Block(32, 16, downsample=False),
            nn.Conv2d(16, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = Model()
print(summary(model, (1, 3, 32, 32)))
model = model.to("mps")

Layer (type:depth-idx)                                  Output Shape              Param #
Model                                                   [1, 3, 16, 16]            --
├─Sequential: 1-1                                       [1, 128, 2, 2]            --
│    └─Block: 2-1                                       [1, 16, 16, 16]           --
│    │    └─ModuleDict: 3-1                             --                        126
│    └─Block: 2-2                                       [1, 32, 8, 8]             --
│    │    └─ModuleDict: 3-2                             --                        768
│    └─Block: 2-3                                       [1, 64, 4, 4]             --
│    │    └─ModuleDict: 3-3                             --                        2,560
│    └─Block: 2-4                                       [1, 128, 2, 2]            --
│    │    └─ModuleDict: 3-4                             --                        9,216
├─Sequential: 1-2                                   

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ema = EMA(model, beta=0.999, update_after_step=1000)
criterion = nn.MSELoss()

In [6]:
for epoch in range(10):
    pbar = tqdm(train_loader)
    for x, y in pbar:
        x, y = x.to("mps"), y.to("mps")
        x_hat = model(x)
        loss = criterion(x_hat, x)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ema.update()
        
        pbar.set_postfix_str(f"loss: {loss.item():.4f}")

  0%|          | 0/782 [00:00<?, ?it/s]

100%|██████████| 782/782 [00:34<00:00, 22.48it/s, loss: 0.0244]
 38%|███▊      | 300/782 [00:11<00:18, 26.67it/s, loss: 0.0238]


KeyboardInterrupt: 

In [11]:
# plot reconstructed images
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to("mps"), y.to("mps")
        x_hat = model(x)
        break


plt.subplot(1, 2, 1)
plt.imshow(x[0].cpu().permute(1, 2, 0))
plt.title("original")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(x_hat[0].cpu().permute(1, 2, 0))
plt.title("reconstructed")
plt.axis("off")

plt.show()

NameError: name 'plt' is not defined