In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
import ffmpeg

from IPython.display import HTML

In [2]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing
# when the model is moved to the device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Network shape: N_HIDDEN hidden Linear+Tanh layers, each WIDTH units wide.
N_HIDDEN = 8
WIDTH = 20

# STEP_SIZE = 0.005
# MAX_T is only referenced by a commented-out linear time sweep.
MAX_T = 3
# Number of video frames to render.
FRAMES = 300

# Output resolution presets as (width, height); choose one by key.
RESOLUTIONS = {
    '4K': (3840, 2160),
    '1080p': (1920, 1080),
}
RESOLUTION_X, RESOLUTION_Y = RESOLUTIONS['1080p']

In [3]:
def getLinear(in_dims, out_dims, weight_range=2.0):
    """Create a torch.nn.Linear layer with weights drawn from U(-weight_range, weight_range).

    Args:
        in_dims: number of input features.
        out_dims: number of output features.
        weight_range: half-width of the uniform weight distribution (default 2.0).

    Returns:
        The new torch.nn.Linear layer (on the CPU). Only the weight matrix is
        re-initialised; the bias keeps torch.nn.Linear's default initialisation.
    """
    layer = torch.nn.Linear(in_dims, out_dims)
    # Overwrite PyTorch's default (much narrower) weight init; the wide range is
    # presumably chosen to drive the downstream Tanh layers into richer patterns.
    torch.nn.init.uniform_(layer.weight, -weight_range, weight_range)
    return layer

In [4]:
def getModel():
    """Build the coordinate -> colour network and move it to `device`.

    Architecture: Linear(3, WIDTH) + Tanh, then N_HIDDEN repeats of
    Linear(WIDTH, WIDTH) + Tanh, then Linear(WIDTH, 3) + Sigmoid, so every
    output channel lies in [0, 1]. Linear layers are created with getLinear
    in input-to-output order.
    """
    def stage(in_dims, out_dims, activation):
        # One Linear layer followed by its activation.
        return [getLinear(in_dims, out_dims), activation]

    stack = stage(3, WIDTH, torch.nn.Tanh())
    for _ in range(N_HIDDEN):
        stack.extend(stage(WIDTH, WIDTH, torch.nn.Tanh()))
    stack.extend(stage(WIDTH, 3, torch.nn.Sigmoid()))

    return torch.nn.Sequential(*stack).to(device)

In [5]:
# Set SEED to an int to reproduce a specific pattern. With None, a fresh seed is
# drawn on every run and printed, so a pattern you like can be recreated later.
SEED = None
seed = SEED if SEED is not None else torch.seed()
torch.manual_seed(seed)
print('seed:', seed)

model = getModel()

# TODO: animate by drifting the weights along random per-parameter velocities
# (an earlier commented-out prototype did `param += v * STEP_SIZE` under
# requires_grad_(False)).

In [16]:
# Map pixel / frame indices to the (x, y, t) coordinates fed to the network.
# The sin-based maps send both image edges to -1 and the centre to +1, so the
# picture is mirror-symmetric about its centre lines.
# coords_x = np.linspace(start=-1, stop=1, num=RESOLUTION_X, dtype=np.float32)
coords_x = np.sin(np.linspace(start=0, stop=np.pi, num=RESOLUTION_X, dtype=np.float32))
coords_x = (coords_x * 2) - 1

# coords_y = np.linspace(start=-1, stop=1, num=RESOLUTION_Y, dtype=np.float32)
coords_y = np.sin(np.linspace(start=0, stop=np.pi, num=RESOLUTION_Y, dtype=np.float32))
coords_y = (coords_y * 2) - 1


# coords_t = np.linspace(start=-MAX_T, stop=MAX_T, num=FRAMES, dtype=np.float32)
# tan over [0, pi] is ~0 at both ends, so the first and last frames (nearly)
# coincide. It diverges near pi/2, though: the middle frames get very large |t|,
# jumping from large positive to large negative values between neighbours.
# NOTE(review): confirm that jump is intended.
coords_t = np.tan(np.linspace(start=0, stop=np.pi, num=FRAMES, dtype=np.float32))
coords_t = (2 * coords_t) - 1

# Build the (RESOLUTION_Y, RESOLUTION_X, FRAMES, 3) grid of (x, y, t) triples.
# copy=False makes meshgrid return broadcast views rather than three full-size
# arrays. Stacking on the last axis yields a contiguous array, so the reshape
# below is a view rather than another full-size copy.
grids = np.meshgrid(coords_x, coords_y, coords_t, copy=False)
coords = np.stack(grids, axis=-1)

coords_flat = coords.reshape(-1, 3)
# torch.from_numpy shares the NumPy buffer instead of copying it before the device transfer.
coords_flat = torch.from_numpy(coords_flat).to(device=device, dtype=torch.float)

In [17]:
N_CHUNKS = 1000


def _run_chunk(chunk_flat):
    """Evaluate `model` on one chunk of coordinates and return the output tensor on the CPU."""
    with torch.no_grad():
        # .to(device) is a no-op when the chunk already lives on `device`.
        return model(chunk_flat.to(device)).cpu()


def get_outs(chunk_flat):
    """Evaluate `model` on one chunk of coordinates and return the output as a NumPy array."""
    return _run_chunk(chunk_flat).numpy()


# Evaluate the whole coordinate grid in (up to) N_CHUNKS pieces to bound the
# memory used by intermediate activations; each result is moved to the CPU
# as soon as it is produced.
out_chunks = [_run_chunk(chunk) for chunk in torch.chunk(coords_flat, N_CHUNKS)]

In [18]:
# Stitch the per-chunk CPU results back into one flat tensor, then restore the
# (RESOLUTION_Y, RESOLUTION_X, FRAMES, 3) layout the coordinate grid was flattened from.
out_flat = torch.cat(out_chunks, dim=0)
out = out_flat.view(RESOLUTION_Y, RESOLUTION_X, FRAMES, 3).numpy()

In [19]:
# Reorder to (frame, height, width, RGB) and scale values from [0, 1] to [0, 255].
frames = np.transpose(out, [2, 0, 1, 3]) * 255

# From: https://github.com/kkroening/ffmpeg-python/issues/246

def vidwrite(fn, images, framerate=60, vcodec='libx264'):
    """Encode a stack of RGB frames to a video file by piping raw bytes into ffmpeg.

    Args:
        fn: output path passed to ffmpeg; an existing file is overwritten.
        images: array-like of shape (n_frames, height, width, 3) with values in
            [0, 255]. Values outside that range are clipped, and fractional
            values are truncated, when converted to uint8.
        framerate: playback rate of the frames, in frames per second.
        vcodec: ffmpeg video codec name.

    Raises:
        ValueError: if images is not 4-D with 3 colour channels.
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    images = np.asarray(images)
    if images.ndim != 4 or images.shape[-1] != 3:
        raise ValueError(
            'expected images of shape (n_frames, height, width, 3), got {}'.format(images.shape))
    _, height, width, _ = images.shape
    process = (
        ffmpeg
            # rawvideo carries no timing information and ffmpeg assumes 25 fps
            # unless told otherwise. Without framerate on the input, the output
            # r= option only duplicates frames, and the clip plays too slowly.
            .input('pipe:', format='rawvideo', pix_fmt='rgb24',
                   s='{}x{}'.format(width, height), framerate=framerate)
            .output(fn, pix_fmt='yuv420p', vcodec=vcodec, r=framerate)
            .overwrite_output()
            .run_async(pipe_stdin=True)
    )
    try:
        for frame in images:
            # Clip first so out-of-range values saturate instead of wrapping around in uint8.
            process.stdin.write(np.clip(frame, 0, 255).astype(np.uint8).tobytes())
    finally:
        # Always close the pipe and reap ffmpeg, even if a write fails
        # (e.g. a broken pipe because ffmpeg exited early).
        process.stdin.close()
        returncode = process.wait()
    if returncode != 0:
        raise RuntimeError('ffmpeg exited with status {}'.format(returncode))

In [20]:
# Encode the rendered frames to out.mp4 in the current working directory,
# using vidwrite's default framerate (60) and codec (libx264).
vidwrite(fn="out.mp4", images=frames)