In [None]:
# Device every tensor and the model are moved to; this notebook runs on the CPU.
device = "cpu"


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm
from ema_pytorch import EMA


In [None]:
# Both splits share one high-resolution preprocessing pipeline: resize every
# image to 512x512, then convert it to a tensor. The pipeline holds no
# per-dataset state, so a single instance can serve both ImageFolders.
_hr_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize((512, 512)),
    torchvision.transforms.ToTensor(),
])

train_dataset = torchvision.datasets.ImageFolder(
    "data/celeba_hq/train",
    transform=_hr_preprocess,
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)

val_dataset = torchvision.datasets.ImageFolder(
    "data/celeba_hq/val",
    transform=_hr_preprocess,
)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=True)


In [None]:
class Model(nn.Module):
    """Super-resolution CNN built from conv -> ReLU -> PixelShuffle(2) stages.

    Each stage widens the features to ``hidden_channels`` with a 3x3 conv and
    then ``nn.PixelShuffle(2)`` trades channels for space: height and width
    double and the channel count drops by 4. A final 3x3 conv maps to
    ``out_channels``, and a sigmoid squashes the result into (0, 1).

    The notebook trains on 32x32 inputs against 512x512 targets, a 16x upscale.
    That requires ``num_stages=4`` (2**4 == 16). Eight stages would upscale 256x
    and yield outputs whose size no longer matches the targets.

    Args:
        in_channels: channels of the low-resolution input image.
        out_channels: channels of the reconstructed image.
        hidden_channels: conv width before every PixelShuffle; must be divisible by 4.
        num_stages: number of 2x upsampling stages; total scale is ``2 ** num_stages``.

    Raises:
        ValueError: if ``hidden_channels`` is not divisible by 4 or
            ``num_stages`` is negative.
    """

    def __init__(self, in_channels=3, out_channels=3, hidden_channels=64, num_stages=4):
        super().__init__()
        if hidden_channels % 4 != 0:
            raise ValueError(f"hidden_channels must be divisible by 4, got {hidden_channels}")
        if num_stages < 0:
            raise ValueError(f"num_stages must be non-negative, got {num_stages}")

        # Overall upscaling factor between input and output spatial size.
        self.scale_factor = 2 ** num_stages

        layers = []
        channels = in_channels
        for _ in range(num_stages):
            layers += [
                nn.Conv2d(channels, hidden_channels, 3, padding=1),
                nn.ReLU(),
                # (B, 4C, H, W) -> (B, C, 2H, 2W)
                nn.PixelShuffle(2),
            ]
            channels = hidden_channels // 4
        layers += [
            nn.Conv2d(channels, out_channels, 3, padding=1),
            nn.Sigmoid(),
        ]
        # Kept as a ModuleList named `layers` so parameter names match the
        # original layout (layers.0, layers.3, ...).
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Upscale a (B, C, H, W) batch to (B, out_channels, H*s, W*s), s = ``self.scale_factor``."""
        for layer in self.layers:
            x = layer(x)
        return x
    
# Instantiate the network on `device`; the training objective is the
# element-wise mean squared error between reconstruction and target.
model = Model()
model = model.to(device)
criterion = nn.MSELoss(reduction='mean')


In [None]:
# Adam over every model parameter with a fixed learning rate of 2e-4.
optimizer = optim.Adam(params=model.parameters(), lr=2e-4)


In [None]:
# Exponential moving average of the model weights (beta=0.9999). With
# update_every=1, each `ema.update()` call in the training loop asks for a refresh.
# NOTE(review): nothing visible in this notebook evaluates the EMA weights.
ema = EMA(model, update_every=1, beta=0.9999)


In [None]:
# Degradation that produces the model's low-resolution input: each 512x512
# target batch is shrunk to 32x32, a 16x reduction.
# `antialias=True` is explicit so the downsample is low-pass filtered on every
# torchvision version. Older releases did not antialias tensor inputs by
# default, which aliases heavily at this scale.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32), antialias=True),
])


In [None]:
@torch.no_grad()
def evaluate(loader, net=None):
    """Compute the average reconstruction loss of a network over ``loader``.

    For every ``(x, _)`` batch:
      - the high-resolution images ``x`` are shrunk with the global ``transform``;
      - the result is upscaled by the network;
      - the output is scored against ``x`` with the global ``criterion``.

    Batch losses are weighted by batch size, so a short final batch counts
    proportionally rather than as much as a full one.

    Args:
        loader: iterable of ``(images, labels)`` pairs; labels are ignored.
        net: module to evaluate; defaults to the global ``model``. Pass another
            module (e.g. the EMA copy of the weights) to score it instead.

    Returns:
        The batch-size-weighted mean loss as a float, or ``nan`` when
        ``loader`` yields no batches.
    """
    net = model if net is None else net
    was_training = net.training
    net.eval()

    total_loss = 0.0
    total_samples = 0
    try:
        for x, _ in tqdm.tqdm(loader):
            # Build the low-res input on the CPU, then move both tensors over.
            xs = transform(x).to(device)
            x = x.to(device)

            y = net(xs)
            batch_size = x.size(0)
            total_loss += criterion(y, x).item() * batch_size
            total_samples += batch_size
    finally:
        # Hand the network back in whatever mode the caller left it.
        net.train(was_training)

    if total_samples == 0:
        return float("nan")
    return total_loss / total_samples


In [None]:
# Ten epochs of training. Each step downsamples the batch, reconstructs it and
# takes one Adam step, then refreshes the EMA. The validation loss is printed
# after every epoch.
for epoch in range(10):
    model.train()
    with tqdm.tqdm(train_loader) as t:
        for i, (x, _) in enumerate(t):
            # The low-res input is derived on the CPU before both tensors move.
            xs = transform(x)
            x, xs = x.to(device), xs.to(device)

            optimizer.zero_grad()
            y = model(xs)
            loss = criterion(y, x)
            loss.backward()
            optimizer.step()
            ema.update()

            t.set_postfix_str(f"loss: {loss.item()}, est acc: {torch.exp(-loss).item()}")

    val_loss = evaluate(val_loader)
    print(f"val loss: {val_loss}")

    # torch.save(model.state_dict(), f"model_{epoch}.pt")


In [None]:
# Final validation loss once training has finished.
print(evaluate(loader=val_loader))


In [None]:
import matplotlib.pyplot as plt

# Visual check on one validation batch, one row per stage:
# low-resolution input, model reconstruction, and the 512x512 ground truth.
x = next(iter(val_loader))[0]
xs = transform(x)
xs = xs.to(device)
with torch.no_grad():  # inference only; no autograd graph needed
    y = model(xs)

n_show = min(4, x.size(0))  # the batch may hold fewer than 4 images
fig, axes = plt.subplots(3, n_show, figsize=(5 * n_show, 15), squeeze=False)
rows = [("input", xs), ("output", y), ("target", x)]
for row, (title, batch) in enumerate(rows):
    for col in range(n_show):
        ax = axes[row, col]
        ax.imshow(batch[col].permute(1, 2, 0).detach().cpu())
        ax.set_title(f"{title} {col}")
        ax.axis("off")

plt.show()
