In [None]:
import utils.gn as gn 
import utils.generator as gen
import utils.visualization as viz

import torch_geometric
from torch_geometric.data import Data
import torch

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from tqdm import tqdm

from plot2vid import PlotRecorder


In [None]:
path = 'datasets/fluids'

water_particles, elastic_particles, snow_particles, sand_particles = [], [], [], []
water_frames, elastic_frames, snow_frames, sand_frames = [], [], [], []

# Destination list for every (data kind, material) pair. The file for each pair is
# f'{path}/{kind}/{material}_{sample}'. Dicts keep insertion order, so files are read
# in the same order as an explicit per-line listing would read them.
# NOTE(review): torch.load defaults to weights_only=True from PyTorch 2.6; if these
# files hold non-tensor objects, loading may need weights_only=False (trusted files only).
load_targets = {
    'particles': {
        'water': water_particles,
        'elastic': elastic_particles,
        'snow': snow_particles,
        'sand': sand_particles,
    },
    'frames': {
        'water': water_frames,
        'elastic': elastic_frames,
        'snow': snow_frames,
        'sand': sand_frames,
    },
}

for sample in tqdm(range(30)):
    for kind, destinations in load_targets.items():
        for material, destination in destinations.items():
            destination.append(torch.load(f'{path}/{kind}/{material}_{sample}'))

# Concatenate per material; this fixes the trajectory order water, elastic, snow, sand.
all_particles = [*water_particles, *elastic_particles, *snow_particles, *sand_particles]
all_frames = [*water_frames, *elastic_frames, *snow_frames, *sand_frames]

In [None]:
# Basic dimensions of the loaded data.
video_shape = all_frames[0].shape                # shape of the first sample's frame data
n_trajectories = len(all_particles)              # trajectories across all materials
trajectory_length = water_particles[0].shape[0]  # leading axis of one particle trajectory

In [None]:
# Shallow copy: later cells replace entries of `particle_trajectories` with float32
# tensors. A plain alias (`= all_particles`) would silently rewrite `all_particles` too.
particle_trajectories = list(all_particles)
# One preallocated float32 buffer holding every video; filled frame-array by frame-array below.
raw_videos = torch.zeros((len(all_frames), *video_shape), dtype=torch.float32)

In [None]:
# Convert each video to float32 and divide by 255 (presumably 8-bit pixel values, so
# the result lies in [0, 1] — TODO confirm the stored dtype), writing into `raw_videos`.
# torch.as_tensor avoids the extra copy (and the copy-construct warning) that
# torch.tensor incurs when `frame` is already a tensor. The slice assignment copies anyway.
for i, frame in enumerate(all_frames):
    raw_videos[i] = torch.as_tensor(frame, dtype=torch.float32) / 255

In [None]:
# Build a fresh list of float32 copies instead of assigning into `particle_trajectories`,
# which may alias `all_particles`. This leaves `all_particles` untouched and makes the
# cell safe to re-run. torch.tensor always copies, so the raw loaded objects are never shared.
particle_trajectories = [torch.tensor(particles, dtype=torch.float32) for particles in all_particles]

In [None]:
# ---- Data / video settings -------------------------------------------------
STEP = 0.01                 # not referenced in the visible cells; presumably the simulation time step — TODO confirm
MAX_FPS = 100               # used as the numerator of the temporal subsampling stride
RESOLUTION = 64             # sizes the frame encoder input as RESOLUTION * RESOLUTION * 3

# ---- Physical (video) encoding ---------------------------------------------
PHYSICAL_HIDDEN_ENCODING_SIZE = 32
PHYSICAL_ENCODING_SIZE = 4  # contributes to the graph network's node input size

# ---- Graph network architecture --------------------------------------------
EDGE_EMBEDDING_SIZE = 48
NODE_EMBEDDING_SIZE = 48
NUM_GN_LAYERS = 3
GN_PROCESSOR_DEPTH = 1
PAST_VELOCITIES = 3
GRAPH_NETWORK_EPOCHS = 2

# ---- Graph construction and training noise ---------------------------------
CONNECTIVITY_RADIUS = 0.12
EDGE_NOISE_STD = 0.05
PARTICLE_NOISE_STD = 0.002
MIN_STEP = 50
LIMIT_EDGES_PER_PARTICLE = 100_000

In [None]:
# Temporal subsampling: keep every (MAX_FPS // TARGET_FPS)-th frame. Assuming the raw
# videos run at MAX_FPS, this yields TARGET_FPS. The divisibility check prevents
# a zero stride (MAX_FPS < TARGET_FPS) or a silently different effective frame rate.
TARGET_FPS = 20
assert MAX_FPS % TARGET_FPS == 0, 'MAX_FPS must be a positive multiple of TARGET_FPS'
videos = raw_videos[:, ::MAX_FPS // TARGET_FPS]

In [None]:
# Shape of the subsampled video tensor (torch.Size).
videos.size()

In [None]:
# Build the graph dataset from every trajectory and its (subsampled) video.
# No video encoder is passed here; `graph_dataset.video_encoder` is assigned in a later cell.
graph_dataset = gn.ParticleVideoDataset(
    raw_particles=particle_trajectories,
    videos=videos,
    number_of_classes=4,
    system='fluid',
    video_encoder=None,
    # graph construction
    past_velocities=PAST_VELOCITIES,
    connectivity_radius=CONNECTIVITY_RADIUS,
    limit_edges_per_particle=LIMIT_EDGES_PER_PARTICLE,
    minimum_rollout_step=MIN_STEP,
    # noise levels
    edge_noise_std=EDGE_NOISE_STD,
    particle_noise_std=PARTICLE_NOISE_STD,
)

# Persist the freshly built dataset (4 materials, 30 samples each were loaded above).
torch.save(graph_dataset, 'datasets/fluids/classes=4_samples=30.torch')

In [None]:
# Dataset options changed after construction. These assignments run after the
# torch.save in the previous cell, so the saved dataset file does not include them.
graph_dataset.one_hot_encode_class = False
# Same value that was already passed to the constructor as `minimum_rollout_step`.
graph_dataset.minimum_rollout_step = MIN_STEP

In [None]:
from torch import nn
import time

def _noise_levels(dataset):
    """Return ``(edge_noise_std, particle_noise_std)`` of ``dataset``, or None when it has no such attributes."""
    if dataset is None or not (hasattr(dataset, 'edge_noise_std') and hasattr(dataset, 'particle_noise_std')):
        return None
    return dataset.edge_noise_std, dataset.particle_noise_std


def _set_noise_levels(dataset, levels):
    """Assign ``levels = (edge_std, particle_std)`` to ``dataset``; no-op when ``levels`` is None."""
    if levels is not None:
        dataset.edge_noise_std, dataset.particle_noise_std = levels


def train(model: nn.Module,
          train_loader,
          validation_loader=None,
          n_epochs=10,
          lr=0.001,
          lr_decay=1.0,
          verbose=False,
          additional_models=None,
          checkpoint_name=None,
          edge_noise_std=None,
          particle_noise_std=None,
          checkpoint_dir='models/fluids/experiments'):
    """Train ``model`` (and jointly ``additional_models``) with an MSE objective.

    Each training batch ``graph`` is cloned and passed through ``model``. The loss is
    ``MSELoss(predicted_graph.x, graph.y)``. Every model gets its own Adam optimizer and
    an ``ExponentialLR`` scheduler that is stepped once per epoch.

    Noise handling: if a loader exposes a ``.dataset`` with ``edge_noise_std`` and
    ``particle_noise_std`` attributes, training batches use the requested noise levels
    (by default the values the dataset holds when ``train`` is called). Validation runs
    with both set to 0. The dataset's original values and the models' train/eval modes
    are restored on exit, even if training raises.

    Args:
        model: network whose output ``.x`` is compared against ``graph.y``.
        train_loader: iterable of training batches.
        validation_loader: optional iterable of validation batches, evaluated once per epoch.
        n_epochs: number of epochs.
        lr: initial learning rate for every optimizer.
        lr_decay: per-epoch multiplicative learning-rate decay (``ExponentialLR`` gamma).
        verbose: print running/epoch/validation losses.
        additional_models: extra modules optimized alongside ``model`` (e.g. an encoder
            used inside the dataset). Defaults to none.
        checkpoint_name: if set, every model is saved after each epoch under ``checkpoint_dir``.
        edge_noise_std: training edge noise; None keeps the dataset's current value.
        particle_noise_std: training particle noise; None keeps the dataset's current value.
        checkpoint_dir: directory for checkpoints (created if missing).

    Returns:
        dict with lists ``'loss_iteration'`` (per batch), ``'loss_epoch'`` (mean per epoch)
        and ``'loss_validation'`` (mean per epoch, only when ``validation_loader`` is given).
    """
    import os

    # None instead of a mutable [] default; copy so the caller's list is never touched.
    additional_models = list(additional_models) if additional_models is not None else []
    all_models = [model, *additional_models]

    stats = {
        'loss_iteration': [],
        'loss_epoch': [],
        'loss_validation': []
    }

    # Training setup: one Adam optimizer and one exponential LR schedule per model.
    criterion = nn.MSELoss()
    optimizers = [torch.optim.Adam(m.parameters(), lr=lr) for m in all_models]
    schedulers = [torch.optim.lr_scheduler.ExponentialLR(o, gamma=lr_decay) for o in optimizers]

    # Use the loaders' own datasets rather than a global; they may be the same object.
    train_dataset = getattr(train_loader, 'dataset', None)
    validation_dataset = getattr(validation_loader, 'dataset', None)

    # Remember the caller's settings so they can be restored on exit.
    saved_train_noise = _noise_levels(train_dataset)
    saved_validation_noise = _noise_levels(validation_dataset)
    saved_modes = [m.training for m in all_models]

    # Explicit arguments win; otherwise keep whatever the dataset was configured with.
    train_noise = None
    if saved_train_noise is not None:
        train_noise = (
            saved_train_noise[0] if edge_noise_std is None else edge_noise_std,
            saved_train_noise[1] if particle_noise_std is None else particle_noise_std,
        )
    validation_noise = None if saved_validation_noise is None else (0.0, 0.0)

    try:
        for epoch in range(1, n_epochs + 1):
            # Validation may have zeroed the noise on a shared dataset, so re-apply it every epoch.
            _set_noise_levels(train_dataset, train_noise)
            for m in all_models:
                m.train()

            train_loss = 0.0
            n_batches = 0
            # tqdm wraps the loader itself so the progress bar knows the total length.
            for i, graph in enumerate(tqdm(train_loader)):
                for o in optimizers:
                    o.zero_grad()

                # Clone so the original batch (and graph.y) stays untouched in case the
                # model modifies its input in place.
                predicted_graph = model(graph.clone())

                loss = criterion(predicted_graph.x, graph.y)
                loss.backward()

                for o in optimizers:
                    o.step()

                batch_loss = loss.item()
                train_loss += batch_loss
                n_batches += 1
                stats['loss_iteration'].append(batch_loss)

                if verbose and i > 0 and i % 500 == 0:
                    # Running mean over the n_batches (= i + 1) batches seen so far.
                    print(f"Epoch: {epoch} Batch: {i} \t Loss: {train_loss / n_batches}")

            for s in schedulers:
                s.step()

            train_loss = train_loss / n_batches if n_batches else float('nan')
            stats['loss_epoch'].append(train_loss)
            if verbose:
                print(f"Epoch: {epoch} \t Loss: {train_loss}")

            if validation_loader is not None:
                _set_noise_levels(validation_dataset, validation_noise)
                for m in all_models:
                    m.eval()

                with torch.no_grad():
                    validation_loss = []
                    for graph in validation_loader:
                        predicted_graph = model(graph.clone())
                        loss = criterion(predicted_graph.x, graph.y)
                        validation_loss.append(loss.item())

                stats['loss_validation'].append(
                    sum(validation_loss) / len(validation_loss) if validation_loss else float('nan')
                )

                if verbose:
                    print(f'\tValidation loss: {stats["loss_validation"][-1]}')

            if checkpoint_name:
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(model, f'{checkpoint_dir}/model_{checkpoint_name}_epoch={epoch}.torch')
                for i, m in enumerate(additional_models):
                    torch.save(m, f'{checkpoint_dir}/additional_{i}_{checkpoint_name}_epoch={epoch}.torch')
    finally:
        # Restore noise levels and train/eval modes as the caller left them.
        _set_noise_levels(validation_dataset, saved_validation_noise)
        _set_noise_levels(train_dataset, saved_train_noise)
        for m, was_training in zip(all_models, saved_modes):
            m.train(was_training)

    return stats

In [None]:
from utils.video import ConvVideoEncoder
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from utils.builder import build_mlp

class VideoEncoder(nn.Module):
    """Encode a batch of videos into one feature vector per frame.

    Every frame is encoded independently by ``frame_encoder``. The resulting sequence of
    latent frames is projected to ``hidden_state_size``, run through a (batch-first) LSTM
    and projected to ``encoding_size``. A tanh follows every stage, so each output value
    lies in [-1, 1].

    Args:
        frame_encoder: module mapping a ``(N, height, width, channels)`` stack of frames to
            ``N`` latent vectors with ``latent_frame_size`` features each.
        latent_frame_size: number of features per frame produced by ``frame_encoder``.
        encoding_size: number of output features per time step.
        hidden_state_size: width of the input projection and of the LSTM state.
        num_lstm_layers: number of stacked LSTM layers.
    """

    def __init__(self, 
                 frame_encoder: nn.Module,
                 latent_frame_size: int,
                 encoding_size: int,
                 hidden_state_size=32,
                 num_lstm_layers=1):
        super(VideoEncoder, self).__init__()

        self.hidden_state_size = hidden_state_size

        self.frame_encoder = frame_encoder

        self.linear_in = nn.Linear(latent_frame_size, hidden_state_size)
        self.lstm = nn.LSTM(hidden_state_size, hidden_state_size, num_layers=num_lstm_layers, batch_first=True)
        self.linear_out = nn.Linear(hidden_state_size, encoding_size)
        self.activation = nn.Tanh()
        # Not used in forward(); kept so the module structure stays unchanged.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, seq, height, width, channels)``.

        Returns:
            Tensor of shape ``(batch, seq, encoding_size)`` with values in [-1, 1].
        """
        batch_size, seq_length, height, width, channels = x.shape

        # reshape (not view): inputs are often strided time slices such as
        # `videos[:, ::k]`, which are non-contiguous and cannot always be viewed.
        frames = x.reshape(batch_size * seq_length, height, width, channels)
        x = self.frame_encoder(frames).reshape(batch_size, seq_length, -1)

        x = self.linear_in(x)
        x = self.activation(x)

        x, _ = self.lstm(x)

        x = self.activation(x)
        x = self.linear_out(x)
        x = self.activation(x)

        return x

# Frame encoder architecture: flatten a RESOLUTION x RESOLUTION x 3 frame, then an MLP
# with hidden widths 8*RESOLUTION and 2*RESOLUTION down to 128 features.
# NOTE(review): this module is not used below — the video encoder is loaded from a
# checkpoint instead. Keep it only if you intend to train an encoder from scratch.
frame_encoder = nn.Sequential(
    nn.Flatten(),
    build_mlp(
        activations=['relu'] * 3,
        hidden_layers=[RESOLUTION * 8, RESOLUTION * 2],
        input_dim=3 * RESOLUTION ** 2,
        out_dim=128,
    ),
)

# Load a previously trained video encoder. The file name matches train()'s
# 'additional_{i}_{checkpoint_name}_epoch={epoch}.torch' pattern, i.e. a whole pickled module.
# NOTE(review): from PyTorch 2.6 torch.load defaults to weights_only=True, which refuses
# pickled nn.Module objects; pass weights_only=False for this trusted local file if loading fails.
video_encoder = torch.load('models/fluids/experiments/additional_0_fps=20_classes=4_epoch=17.torch')

In [None]:
# Display the loaded video encoder's module structure.
video_encoder

In [None]:
# Attach the loaded encoder to the dataset, which was constructed with video_encoder=None.
graph_dataset.video_encoder = video_encoder

In [None]:
# Shape of the video tensor fed to the encoder below (torch.Size).
videos.size()

In [None]:
# Smoke test: encode the first five videos and display the output shape.
# no_grad avoids building an autograd graph for this throwaway forward pass.
with torch.no_grad():
    encoded_sample = video_encoder(videos[:5])
encoded_sample.shape

In [None]:
# Training batches are reshuffled every epoch.
graph_dataset_loader = DataLoader(graph_dataset, batch_size=8, shuffle=True)
# Validation only averages per-batch losses, so shuffling adds nothing. A fixed order
# makes the per-batch grouping identical from epoch to epoch.
# NOTE(review): both loaders read the same `graph_dataset`, so the "validation" loss is
# measured on training data. Split trajectories into disjoint train/validation sets for
# a real held-out estimate.
validation_graph_dataset_loader = DataLoader(graph_dataset, batch_size=8, shuffle=False)

In [None]:
# Normalisation statistics from the dataset. Detached clones, so the network holds its
# own tensors rather than references into the dataset.
preprocessing_data = {
    name: getattr(graph_dataset, name).detach().clone()
    for name in ('node_mean', 'node_std', 'edge_mean', 'edge_std')
}

postprocessing_data = {
    name: getattr(graph_dataset, name).detach().clone()
    for name in ('out_mean', 'out_std')
}

# Node input = 2 * PAST_VELOCITIES (presumably 2-D velocities for each past step)
# + PHYSICAL_ENCODING_SIZE + 4 extra features (meaning not visible here — TODO confirm).
graph_network = gn.build_encoder_processor_decoder(
    node_size=PAST_VELOCITIES * 2 + PHYSICAL_ENCODING_SIZE + 4,
    edge_size=3,
    output_size=2,
    node_latent=NODE_EMBEDDING_SIZE,
    edge_latent=EDGE_EMBEDDING_SIZE,
    num_gn_layers=NUM_GN_LAYERS,
    shared_gn_layers=False,
    processor_depth=GN_PROCESSOR_DEPTH,
    aggregation_fun='sum',
    preprocessing_data=preprocessing_data,
    postprocessing_data=postprocessing_data,
)

In [None]:
# Reduced noise for this run: 1/5 of the configured edge noise, 1/3 of the particle noise.
# NOTE(review): confirm the training loop does not overwrite these values each epoch.
graph_dataset.edge_noise_std, graph_dataset.particle_noise_std = EDGE_NOISE_STD / 5, PARTICLE_NOISE_STD / 3


In [None]:
print('Training the graph network')
# The video encoder is optimised jointly as an additional model. train() saves files as
# 'additional_{i}_{checkpoint_name}_epoch={epoch}.torch'. The leading '_' in the name
# therefore yields 'additional_0__fps=20_classes=4_...', distinct from the checkpoint loaded
# earlier — presumably intentional so that file is not overwritten.
graph_network_stats = train(
    graph_network,
    graph_dataset_loader,
    validation_loader=validation_graph_dataset_loader,
    additional_models=[video_encoder],
    n_epochs=100,
    lr=0.0001,
    lr_decay=0.95,
    verbose=True,
    checkpoint_name='_fps=20_classes=4',
)