In [1]:
import os
from math import pi
from glob import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision import transforms as TF
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# Run on the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still executes (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Global switch for phase modulation; it is passed to both FMLP and
# FramesDataset when they are constructed further down.
phase_mod = True


Bad key "text.kerning_factor" on line 4 in
/opt/conda/envs/pytorch/lib/python3.7/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.1.3/matplotlibrc.template
or from the matplotlib source distribution


In [2]:
class GaussianFourierFeatureTransform(nn.Module):
    """
    An implementation of Gaussian Fourier feature mapping.
    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
       https://arxiv.org/abs/2006.10739
       https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    Given an input of size [batches, num_input_channels, width, height],
    returns a tensor of size [batches, mapping_dim*2, width, height]: the
    first ``mapping_dim`` channels are ``sin(2*pi*x@B + phase)`` and the last
    ``mapping_dim`` channels are the matching ``cos``.

    Args:
        num_input_channels: number of channels expected in the input.
        mapping_dim: number of random frequencies (columns of ``B``).
        scale: multiplier applied to the standard-normal samples in ``B``.
    """

    def __init__(self, num_input_channels=2, mapping_dim=256, scale=10):
        super().__init__()

        self._num_input_channels = num_input_channels
        self.mapping_dim = mapping_dim
        # The projection is random and fixed (not trained). Registering it as
        # a buffer makes it part of ``state_dict()`` and lets it follow
        # ``module.to(...)``; as a plain attribute it was silently dropped
        # from checkpoints, so a reloaded model used a different mapping.
        self.register_buffer('_B', torch.randn((num_input_channels, mapping_dim)) * scale)

    def forward(self, x, phase=None):
        """Map ``x`` to Fourier features.

        Args:
            x: 4D tensor [batches, num_input_channels, width, height].
            phase: optional value added to the argument of sin/cos; it must
                broadcast against [batches, mapping_dim, width, height].

        Returns:
            Tensor of size [batches, mapping_dim*2, width, height].
        """
        assert x.dim() == 4, 'Expected 4D input (got {}D input)'.format(x.dim())

        batches, channels, width, height = x.shape
        assert channels == self._num_input_channels,\
            "Expected input to have {} channels (got {} channels)".format(self._num_input_channels, channels)

        # Make shape compatible for matmul with _B.
        # From [B, C, W, H] to [(B*W*H), C].
        x = x.permute(0, 2, 3, 1).reshape(batches * width * height, channels)

        # ``.to(x.device)`` is a no-op once the module lives on the input's
        # device; it is kept so callers that never moved the module still work.
        x = x @ self._B.to(x.device)

        # From [(B*W*H), mapping_dim] to [B, W, H, mapping_dim]
        x = x.view(batches, width, height, self.mapping_dim)
        # From [B, W, H, mapping_dim] to [B, mapping_dim, W, H]
        x = x.permute(0, 3, 1, 2)

        x = 2 * pi * x
        if phase is not None:
            x = x + phase

        return torch.cat([torch.sin(x), torch.cos(x)], dim=1)

In [3]:
class LFF(nn.Module):
    """Learned Fourier features: a 1x1 convolution followed by a sine.

    Args:
        mapping_size: number of output channels of the 1x1 convolution.
        num_input_channels: number of channels of the incoming tensor.
    """

    def __init__(self, mapping_size, num_input_channels=2):
        super().__init__()
        self.ffm = ConLinear(num_input_channels, mapping_size, is_first=True)
        self.activation = SinActivation()

    def forward(self, x, phase=None):
        """Project ``x``, add ``phase`` when one is given, and apply the sine."""
        projected = self.ffm(x)
        pre_activation = projected if phase is None else projected + phase
        return self.activation(pre_activation)
    
class ConLinear(nn.Module):
    """Per-pixel linear layer implemented as a 1x1 ``nn.Conv2d``.

    The convolution weight is re-drawn uniformly from ``[-bound, bound]`` with
    ``bound = sqrt(9 / ch_in)`` when ``is_first`` is true and
    ``sqrt(3 / ch_in)`` otherwise. The bias keeps ``nn.Conv2d``'s own
    initialisation.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        is_first: selects the wider initialisation range.
        bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, ch_in, ch_out, is_first=False, bias=True):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=1, padding=0, bias=bias)
        numerator = 9 if is_first else 3
        bound = np.sqrt(numerator / ch_in)
        nn.init.uniform_(self.conv.weight, -bound, bound)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        return self.conv(x)


class SinActivation(nn.Module):
    """Element-wise sine activation with an optional additive phase."""

    def __init__(self,):
        super().__init__()

    def forward(self, x, phase=None):
        """Return ``sin(x + phase)``, or ``sin(x)`` when ``phase`` is None.

        Args:
            x: input tensor.
            phase: optional scalar or tensor that broadcasts against ``x``.
        """
        # Compare against None explicitly: truth-testing a tensor with more
        # than one element raises a RuntimeError, so ``if phase:`` crashed for
        # any real phase tensor.
        if phase is not None:
            return torch.sin(x + phase)
        return torch.sin(x)
        
        
def get_grid(h, w, b=0, norm=True, device='cpu'):
    """Build a 2-channel coordinate grid for an ``h`` x ``w`` image.

    Channel 0 holds the x coordinate (varies along the width) and channel 1
    the y coordinate (varies along the height). Coordinates are ``w`` (resp.
    ``h``) evenly spaced samples of ``[0, w]`` (resp. ``[0, h]``), both
    endpoints included; with ``norm=True`` they are divided by ``w`` / ``h``
    so they span ``[0, 1]``.

    Args:
        h: grid height.
        w: grid width.
        b: batch size; when ``b > 0`` the grid is expanded to ``b`` copies.
        norm: whether to normalise coordinates to ``[0, 1]``.
        device: device the resulting float32 tensor is moved to.

    Returns:
        ``[b, 2, h, w]`` tensor when ``b > 0``, otherwise ``[2, h, w]``.
    """
    xs = np.linspace(0, w, num=w)
    ys = np.linspace(0, h, num=h)
    if norm:
        xs = xs / w
        ys = ys / h

    mesh_x, mesh_y = np.meshgrid(xs, ys, indexing='xy')
    stacked = np.stack([mesh_x, mesh_y], axis=-1)[None]  # [1, H, W, UV]
    grid = torch.from_numpy(stacked).float().to(device)

    if b > 0:
        batched = grid.expand(b, -1, -1, -1)  # [Batch, H, W, UV]
        return batched.permute(0, 3, 1, 2)  # [Batch, UV, H, W]
    return grid[0].permute(2, 0, 1)  # [UV, H, W]


In [4]:
class FMLP(nn.Module):
    """Fourier-feature MLP built from 1x1 convolutions, mapping coordinates to 3 channels.

    The spatial coordinates are encoded by ``ff_func`` and fed through a stack
    of ``ConLinear`` layers that are ``internal_dim * 2`` channels wide. With
    ``phase_mod=True`` a second, smaller network turns 3-channel space-time
    coordinates into a 1-channel phase that is passed to the encoder.

    Args:
        internal_dim: ``mapping_dim`` handed to ``ff_func``; the trunk is
            ``internal_dim * 2`` channels wide.
        num_layers: number of hidden ``ConLinear`` + ``act`` pairs added after
            the first pair.
        act: activation module; the same instance is reused at every position.
        ff_func: encoder class. It is called as
            ``ff_func(mapping_dim=..., num_input_channels=...)``, its output
            must have ``2 * mapping_dim`` channels, and its ``forward`` must
            accept ``(x, phase)``.
            NOTE(review): the default ``LFF`` takes ``mapping_size`` rather
            than ``mapping_dim``, so it does not match this call; pass
            ``GaussianFourierFeatureTransform`` explicitly.
        phase_mod: whether to build and use the phase network.
        num_input_channels: channels of the spatial coordinates.
    """

    def __init__(self,
                 internal_dim,
                 num_layers=3,
                 act=nn.LeakyReLU(),
                 ff_func=LFF,
                 phase_mod=False,
                 num_input_channels=2,
                ):
        super().__init__()
        # Remember the mode on the instance so ``forward`` does not depend on
        # a module-level global that may differ from the constructor argument.
        self.phase_mod = phase_mod
        hidden = internal_dim * 2

        self.lff = ff_func(mapping_dim=internal_dim, num_input_channels=num_input_channels)

        layers = [ConLinear(hidden, hidden, is_first=False), act]
        for _ in range(num_layers):
            layers.append(ConLinear(hidden, hidden, is_first=False))
            layers.append(act)
        layers.append(ConLinear(hidden, 3, is_first=False))
        self.net = nn.Sequential(*layers)

        if phase_mod:
            # The encoder with mapping_dim = internal_dim // 2 is expected to
            # emit internal_dim channels (sin and cos halves).
            phase_layers = [ff_func(mapping_dim=internal_dim // 2,
                                    num_input_channels=3)]
            for _ in range(3):
                phase_layers.append(ConLinear(internal_dim, internal_dim, is_first=False))
                phase_layers.append(act)
            phase_layers.append(ConLinear(internal_dim, 1, is_first=False))
            self.phase_net = nn.Sequential(*phase_layers)

    def forward(self, coords):
        """Evaluate the network.

        Args:
            coords: with ``phase_mod`` enabled, a pair
                ``[spatial_coords [B, 2, H, W], spacetime_coords [B, 3, H, W]]``;
                otherwise a single coordinate tensor.

        Returns:
            Tensor with 3 channels (no output non-linearity is applied here).
        """
        if self.phase_mod:
            spatial_coords, time_coords = coords
            phase_feats = self.phase_net(time_coords)
            fourier_feats = self.lff(spatial_coords, phase_feats)
        else:
            fourier_feats = self.lff(coords)
        return self.net(fourier_feats)

In [5]:
class FramesDataset(Dataset):
    """Dataset of video frames stored as ``*.png`` files in one folder.

    Each item is ``(img, coords)``, where ``img`` is the resized frame as a
    tensor and ``coords`` describes the pixel grid plus a per-frame timestamp.
    The timestamp is the integer after the last ``_`` in the file name (before
    the first ``.``), divided by the number of frames found in the folder.

    Args:
        folder: directory that contains the frames.
        resolution: size passed to ``torchvision.transforms.Resize``.
        phase_mod: if true, ``coords`` is the list
            ``[spatial [2, H, W], spacetime [3, H, W]]``; otherwise it is only
            the ``[3, H, W]`` space-time tensor.
    """

    def __init__(self, folder, resolution=256, phase_mod=False):
        super().__init__()
        self.files = sorted(glob(os.path.join(folder, '*.png')))
        self.num_frames = len(self.files)
        self.resolution = resolution
        self.phase_mod = phase_mod

    def __len__(self):
        return self.num_frames

    def _timestamp(self, idx):
        """Return the frame's timestamp as ``frame_number / num_frames``."""
        # basename instead of split('/') so the parsing also works with
        # platform-specific path separators.
        stem = os.path.basename(self.files[idx]).split('.')[0]
        return int(stem.split('_')[-1]) / self.num_frames

    def __getitem__(self, idx):
        """Load frame ``idx`` and build its coordinate tensors.

        Raises:
            ValueError: if the file name does not end in ``_<integer>``.
        """
        # Resize and convert inside the ``with`` block so the pixel data is
        # read before the underlying file handle is closed.
        with Image.open(self.files[idx]) as frame:
            img = TF.ToTensor()(TF.Resize(self.resolution)(frame))

        coords = get_grid(*img.shape[-2:], 0, True, 'cpu')
        # Constant plane carrying the timestamp, same shape as one coord channel.
        time_plane = torch.full_like(coords[[0]], self._timestamp(idx))
        spacetime = torch.cat([coords, time_plane], dim=0)

        if self.phase_mod:
            return img, [coords, spacetime]
        return img, spacetime

In [6]:
# Build the coordinate network with Gaussian Fourier features and place it on
# the configured device.
fmlp = FMLP(
    internal_dim=128,
    num_layers=8,
    act=nn.LeakyReLU(),
    ff_func=GaussianFourierFeatureTransform,
    phase_mod=phase_mod,
    num_input_channels=2,
)
fmlp = fmlp.to(device)

In [7]:
# Frames are resized to 256x256 and served in shuffled batches of four.
dset = FramesDataset(folder='./frames/71', resolution=(256, 256), phase_mod=phase_mod)
dloader = DataLoader(dataset=dset, batch_size=4, shuffle=True, num_workers=8)

In [8]:
# Adam over every trainable parameter of the model, library-default hyperparameters.
opt = torch.optim.Adam(params=fmlp.parameters())

In [None]:
# Train for 1000 epochs with a mean-squared error between the frame and the
# sigmoid of the network output. Every fifth epoch a 200-frame preview video
# is rendered.
for i in range(1000):
    losses = []
    for img, coords in tqdm(dloader):
        img = img.to(device)
        # The dataset yields a list of tensors with phase modulation and a
        # single tensor without it; iterating over a single tensor would
        # split it along the batch dimension, so handle both cases.
        if isinstance(coords, (list, tuple)):
            coords = [c.to(device) for c in coords]
        else:
            coords = coords.to(device)
        opt.zero_grad()
        out = torch.sigmoid(fmlp(coords))
        loss = ((img - out) ** 2).mean()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    print(i, np.mean(losses))

    if i % 5 == 0:
        out_file = cv2.VideoWriter(f'./output_{i}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 60.0, (256, 256))
        try:
            coords = get_grid(256, 256, 1, True, device)
            # Rendering is inference only: skip autograd bookkeeping.
            with torch.no_grad():
                # Timestamps run from -0.2 to 1.2, i.e. beyond the [0, 1)
                # range produced by the dataset.
                for frame in tqdm(np.linspace(-0.2, 1.2, 200)):
                    # NOTE(review): this passes the [spatial, spacetime] pair,
                    # i.e. it assumes the model was built with phase_mod=True.
                    coords_stacked = [coords, torch.cat([coords, coords[:, [0]] * 0. + frame], dim=1)]
                    out = torch.sigmoid(fmlp(coords_stacked))
                    rgb = out[0].permute(1, 2, 0).cpu().numpy()
                    # Reverse the channel axis (RGB -> BGR) and scale to 8-bit.
                    out_file.write((rgb[:, :, ::-1] * 255).clip(0, 255).astype(np.uint8))
        finally:
            # Always finalise the file, even if rendering fails half-way.
            out_file.release()

HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))


0 0.027553329967400605


HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))


1 0.017408519698416485


HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))


2 0.012962188933263806


HBox(children=(FloatProgress(value=0.0, max=85.0), HTML(value='')))

In [None]:
# Render the final 200-frame video from the trained model.
out_file = cv2.VideoWriter('./output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 60.0, (256, 256))
try:
    coords = get_grid(256, 256, 1, True, device)
    # Inference only: disable autograd so no graph is kept per frame.
    with torch.no_grad():
        for frame in tqdm(np.linspace(-0.2, 1.2, 200)):
            # [spatial coords, spatial coords + constant time channel]
            coords_stacked = [coords, torch.cat([coords, coords[:, [0]] * 0. + frame], dim=1)]
            out = torch.sigmoid(fmlp(coords_stacked))
            rgb = out[0].permute(1, 2, 0).cpu().numpy()
            # Reverse the channel axis (RGB -> BGR) and scale to 8-bit.
            out_file.write((rgb[:, :, ::-1] * 255).clip(0, 255).astype(np.uint8))
finally:
    # Release the writer even when rendering raises, so the file is finalised.
    out_file.release()

In [None]:
# Re-encode the rendered clip with ffmpeg; -y overwrites output_recoded.mp4 without asking.
!ffmpeg -i output.mp4 output_recoded.mp4 -y