In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torchvision
from tqdm import tqdm

In [101]:
# Load the raw tensors.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load trusted files (if both files hold plain tensors, passing
# weights_only=True makes this safe on torch >= 1.13).
vid_dataset = torch.load("videos.pt")
meta_data = torch.load("metadata.pt")
# Fail early with a clear message instead of deep inside TensorDataset below.
assert len(vid_dataset) == len(meta_data), "videos.pt and metadata.pt must have the same number of samples"

TRAIN_FRACTION = 0.8  # leading fraction of samples used for training (sequential split, no shuffling)
IMG_SIZE = 64         # spatial size frames are resized to; must match the model's expected input
BATCH_SIZE = 64

train_val_split = int(len(vid_dataset) * TRAIN_FRACTION)

# Scale pixels by 1/255.
# NOTE(review): assumes videos.pt holds 8-bit pixel values (0..255) — TODO confirm.
train_dataset = vid_dataset[:train_val_split].float() / 255
val_dataset = vid_dataset[train_val_split:].float() / 255

train_meta_data = meta_data[:train_val_split]
val_meta_data = meta_data[train_val_split:]

# Only keep columns 0 and 1 (the two regression targets) and scale them by 1/255.
# NOTE(review): assumes these coordinates lie in a 0..255 range — confirm.
train_meta_data = train_meta_data[:, :2].float() / 255
val_meta_data = val_meta_data[:, :2].float() / 255

# Resize to IMG_SIZE x IMG_SIZE (F.interpolate's default mode is "nearest").
train_dataset = F.interpolate(train_dataset, size=IMG_SIZE)
val_dataset = F.interpolate(val_dataset, size=IMG_SIZE)

# create x, y dataset
train_x, train_y = train_dataset, train_meta_data
val_x, val_y = val_dataset, val_meta_data

# Wrap in DataLoaders; only the training batches are shuffled.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(val_x, val_y), batch_size=BATCH_SIZE, shuffle=False
)

In [102]:
class Model(nn.Module):
    """Small CNN regressing two values from a 3-channel image (64x64 in this notebook).

    Three Conv/ReLU/MaxPool stages each halve the resolution; a fourth conv
    widens to 64 channels and PixelUnshuffle(2) folds every 2x2 neighbourhood
    into channels (64 -> 256). The feature map is globally average-pooled and
    passed to a linear head that produces 2 outputs.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # (in_channels, out_channels) for the three downsampling stages
        for c_in, c_out in ((3, 6), (6, 16), (16, 32)):
            stages.extend([nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2)])
        stages.extend([nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.PixelUnshuffle(2)])
        self.layers = nn.Sequential(*stages)
        # 64 channels * (2 * 2) after PixelUnshuffle(2)
        self.head = nn.Linear(256, 2)

    def forward(self, x):
        # Unpacking enforces a 4-D (batch, channels, height, width) input.
        batch_size, _channels, _height, _width = x.shape
        features = self.layers(x)
        pooled = F.adaptive_avg_pool2d(features, 1)
        return self.head(pooled.view(batch_size, -1))
    
model = Model()
# torchsummary's summary() defaults to device="cuda" and feeds CUDA inputs whenever
# a GPU is available, which crashes for this CPU-resident model; pass the model's
# own device type instead.
summary(model, (3, 64, 64), device=next(model.parameters()).device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 64, 64]             168
              ReLU-2            [-1, 6, 64, 64]               0
         MaxPool2d-3            [-1, 6, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]             880
              ReLU-5           [-1, 16, 32, 32]               0
         MaxPool2d-6           [-1, 16, 16, 16]               0
            Conv2d-7           [-1, 32, 16, 16]           4,640
              ReLU-8           [-1, 32, 16, 16]               0
         MaxPool2d-9             [-1, 32, 8, 8]               0
           Conv2d-10             [-1, 64, 8, 8]          18,496
             ReLU-11             [-1, 64, 8, 8]               0
   PixelUnshuffle-12            [-1, 256, 4, 4]               0
           Linear-13                    [-1, 2]             514
Total params: 24,698
Trainable params: 

In [103]:
LEARNING_RATE = 8e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)


In [104]:
NUM_EPOCHS = 10

# Mean squared error on the validation set from the previous epoch (0 before the first one).
val_loss = 0
for epoch in range(NUM_EPOCHS):
    # --- training pass ---
    model.train()
    pbar = tqdm(train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}")
    for x, y in pbar:
        # flatten(1) keeps the batch axis even for a size-1 last batch;
        # y.squeeze() would drop it and make the [:, :2] indexing fail.
        target = y.flatten(1)[:, :2]
        optimizer.zero_grad()
        y_pred = model(x)
        loss = F.mse_loss(y_pred, target)
        loss.backward()
        optimizer.step()
        train_loss = loss.item()
        # Heuristic flag: training loss more than 10% below the last validation loss.
        pbar.set_postfix({"loss": train_loss, "val_loss": val_loss, "overfitting?": (val_loss - train_loss) > (val_loss / 10)})

    # --- validation pass ---
    model.eval()
    with torch.no_grad():
        # Sum squared errors over all elements so a short last batch is weighted
        # correctly (averaging per-batch means would overweight it).
        sq_err_sum = 0.0
        n_elems = 0
        for x, y in val_loader:
            target = y.flatten(1)[:, :2]
            y_pred = model(x)
            sq_err_sum += F.mse_loss(y_pred, target, reduction="sum").item()
            n_elems += target.numel()
        val_loss = sq_err_sum / max(n_elems, 1)


 75%|███████▍  | 56/75 [00:03<00:01, 15.37it/s, loss=0.0428, val_loss=0, overfitting?=0]


KeyboardInterrupt: 

In [118]:
from IPython.display import Video

with torch.no_grad():
    model.eval()
    x, y = next(iter(val_loader))
    y_pred = model(x)

    # (T, C, H, W) -> (T, H, W, C): write_video expects channels-last frames.
    frames = x.permute(0, 2, 3, 1).clone()
    num_frames, height, width, _ = frames.shape

    # Predictions are normalized coordinates: column 1 is used as the row and
    # column 0 as the column. Clamp so a prediction outside [0, 1) can neither
    # index out of bounds nor wrap around to the opposite edge.
    rows = (y_pred[:, 1] * height).long().clamp(0, height - 1)
    cols = (y_pred[:, 0] * width).long().clamp(0, width - 1)

    # Put a green mark on the predicted position in every frame of the clip.
    frames[torch.arange(num_frames), rows, cols] = torch.tensor([0.0, 1.0, 0.0])

    # Scale ALL frames to uint8 [0, 255]. Previously only frame 0 was scaled, so
    # write_video's uint8 cast turned every other frame (still in [0, 1]) black.
    frames_u8 = (frames.clamp(0, 1) * 255).round().to(torch.uint8)

    torchvision.io.write_video("output.mp4", frames_u8, 30)

Video("output.mp4", width=320, height=320)