In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm import tqdm

In [2]:
dataset = torch.load('videos.pt')

train_val_split = 0.8
train_val_split = int(len(dataset) * train_val_split)
train_dataset = dataset[:train_val_split].float() / 255
val_dataset = dataset[train_val_split:].float() / 255

# interpolate to 64x64
train_dataset = F.interpolate(train_dataset, size=64)
val_dataset = F.interpolate(val_dataset, size=64)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

In [7]:
class BasicResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(BasicResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += self.shortcut(x)
        return F.relu(out)


class ResidualStack(nn.Module):
    def __init__(self, in_channels, out_channels, depth):
        super(ResidualStack, self).__init__()
        self.blocks = nn.ModuleList(
            [
                BasicResidualBlock(in_channels, out_channels)
                if i == 0
                else BasicResidualBlock(out_channels, out_channels)
                for i in range(depth)
            ]
        )

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        in_channels *= 4
        self.px_unshuffle = nn.PixelUnshuffle(2)
        self.stack = ResidualStack(in_channels, out_channels, 2)

    def forward(self, x):
        x = self.px_unshuffle(x)
        x = self.stack(x)
        return x


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        out_channels *= 4
        self.px_shuffle = nn.PixelShuffle(2)
        self.stack = ResidualStack(in_channels, out_channels, 2)

    def forward(self, x):
        x = self.stack(x)
        x = self.px_shuffle(x)
        return x


class Autoencoder(nn.Module):
    def __init__(self, width, depth):
        super(Autoencoder, self).__init__()
        self.conv1 = nn.Conv2d(3, width, 3, padding=1)
        self.down_blocks = nn.ModuleList(
            [DownBlock(width, width) for _ in range(depth)]
        )
        self.up_blocks = nn.ModuleList([UpBlock(width, width) for _ in range(depth)])
        self.conv2 = nn.Conv2d(width, 3, 3, padding=1)

    def forward(self, x):
        x = self.encode(x)
        x = self.decode(x)

        return F.sigmoid(x)

    def encode(self, x):
        x = F.relu(self.conv1(x))
        for block in self.down_blocks:
            x = block(x)
        return x

    def decode(self, x):
        for block in self.up_blocks:
            x = block(x)
        x = self.conv2(x)
        return F.sigmoid(x)


model = Autoencoder(32, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)

print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")
latent_shape = model.encode(torch.zeros(1, 3, 64, 64))
out = model.decode(latent_shape)
print(
    f"Latent shape: {latent_shape.shape}\nDownsample Ratio: {out.shape[-1] // latent_shape.shape[-1]}"
)

Model has 1,107,235 parameters
Latent shape: torch.Size([1, 32, 16, 16])
Downsample Ratio: 4


In [8]:
val_losses = 0
for epoch in range(10):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=True)
    for batch in pbar:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.binary_cross_entropy(out, batch)
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            out = model(batch)
            val_losses += F.binary_cross_entropy(out, batch)

    print(f"Epoch {epoch} | Val Loss: {val_losses / len(val_loader)}")
    val_losses = 0

Epoch 1: 100%|██████████| 75/75 [02:20<00:00,  1.87s/it, loss=0.329]


Epoch 0 | Val Loss: 0.3272697627544403


Epoch 2: 100%|██████████| 75/75 [02:53<00:00,  2.31s/it, loss=0.329]


Epoch 1 | Val Loss: 0.3272697627544403


Epoch 3: 100%|██████████| 75/75 [03:06<00:00,  2.48s/it, loss=0.329]


Epoch 2 | Val Loss: 0.3272697627544403


Epoch 4: 100%|██████████| 75/75 [05:48<00:00,  4.65s/it, loss=0.33] 


Epoch 3 | Val Loss: 0.3272697627544403


Epoch 5: 100%|██████████| 75/75 [02:26<00:00,  1.95s/it, loss=0.329]


Epoch 4 | Val Loss: 0.3272697627544403


Epoch 6: 100%|██████████| 75/75 [03:03<00:00,  2.45s/it, loss=0.329]


Epoch 5 | Val Loss: 0.3272697627544403


Epoch 7:  56%|█████▌    | 42/75 [01:45<01:22,  2.50s/it, loss=0.33] 


KeyboardInterrupt: 

In [None]:
from IPython.display import Video
# plot reconstructions
with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(next(model.parameters()).device)
        out = model(batch)
        
        torchvision.io.write_video('reconstructions.mp4', out.permute(0, 2, 3, 1).cpu().numpy() * 255, 30)
        torchvision.io.write_video('originals.mp4', batch.permute(0, 2, 3, 1).cpu().numpy() * 255, 30)

        break

display(Video('reconstructions.mp4', width=400))
display(Video('originals.mp4', width=400))