In [9]:
from utils.dataloader import DataPartitions, DataGenerator
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

In [10]:
# Partition the dataset into windows of 4 past (input) and 4 future (target) frames.
# NOTE(review): `partial=0.3` presumably restricts loading to a fraction of the
# data — confirm against utils.dataloader.DataPartitions.
partitions = DataPartitions(
    root="../datasets/arda/04_21_full/",
    past_frames=4,
    future_frames=4,
    partial=0.3,
)

In [11]:
# Generator over the training partition: inputs are `past_frames` frames of
# shape (256, 256, 3), targets are `future_frames` frames of shape (256, 256, 1).
# NOTE(review): `n_channels=1` while `input_dim` has 3 channels, and
# `buffer_size` is the float 1e3 — confirm both are what DataGenerator expects.
dataset = DataGenerator(
    root="../datasets/arda/04_21_full/",
    filenames=partitions.get_areas(),
    dataset_partitions=partitions.get_train(),
    # window lengths and tensor shapes
    past_frames=partitions.past_frames,
    future_frames=partitions.future_frames,
    input_dim=(partitions.past_frames, 256, 256, 3),
    output_dim=(partitions.future_frames, 256, 256, 1),
    n_channels=1,
    # batching / buffering
    batch_size=16,
    shuffle=True,
    buffer_size=1e3,
    buffer_memory=100,
)

In [12]:
# Materialise the full input (X) and target (Y) arrays from the generator.
X, Y = dataset.get_X(), dataset.get_Y()

# Zero out, in place, every value above 10e5 (== 1e6).
# NOTE(review): presumably these are fill/no-data sentinels; also confirm the
# threshold was not meant to be 1e5.
for_clamp = (X, Y)
for arr in for_clamp:
    arr[arr > 10e5] = 0
del for_clamp, arr

100%|██████████| 14/14 [00:44<00:00,  3.16s/it]
100%|██████████| 14/14 [00:26<00:00,  1.91s/it]


In [13]:
# Hold out 30% of the windows for testing; fixed seed for a reproducible split.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    Y,
    test_size=0.3,
    random_state=42,
)

In [14]:
# One StandardScaler per input channel, fitted incrementally frame by frame
# over the training inputs only.
sc_img = StandardScaler()  # channel 0: image
sc_vvx = StandardScaler()  # channel 1: vvx
sc_vvy = StandardScaler()  # channel 2: vvy

# Index i of this tuple is the scaler for channel i. Previously sc_vvy was fit
# on channel 1 (vvx) while later being applied to channel 2.
_channel_scalers = (sc_img, sc_vvx, sc_vvy)

for sample in X_train:
    for batch in sample:
        for frame in batch:
            for channel, scaler in enumerate(_channel_scalers):
                # Each (H, W) slice is passed as a 2-D array, so the scaler
                # treats rows as samples and columns as features.
                scaler.partial_fit(frame[:, :, channel])

In [15]:
def _standardize_channels_inplace(frames_array, scalers):
    """Standardise a 6-D frames array in place, one fitted scaler per channel.

    For every (sample, batch, frame) entry, channel ``c`` of the frame is
    replaced by ``scalers[c].transform(frame[:, :, c])``. Channels with no
    matching scaler are left untouched.

    Args:
        frames_array: array indexed as [sample, batch, frame, H, W, channel].
        scalers: sequence of fitted scalers; index i handles channel i.
    """
    for sample in frames_array:
        for batch in sample:
            for frame in batch:
                # Iterating a NumPy array yields views, so writing into
                # `frame` updates `frames_array` itself.
                for channel, scaler in enumerate(scalers):
                    frame[:, :, channel] = scaler.transform(frame[:, :, channel])


_standardize_channels_inplace(X_train, (sc_img, sc_vvx, sc_vvy))
print("X_train ready")

_standardize_channels_inplace(X_test, (sc_img, sc_vvx, sc_vvy))
print("X_test transformed")

# Targets carry only the image channel, so they reuse the image scaler.
_standardize_channels_inplace(y_train, (sc_img,))
print("y_train transformed")

_standardize_channels_inplace(y_test, (sc_img,))
print("y_test transformed")

X_train ready
X_test transformed
y_train transformed
y_test transformed


In [None]:
# Release the unsplit arrays now that the train/test splits exist.
del X, Y

### Model definition and training with Flax

In [2]:
import numpy as np

import jax
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze

from flax import linen as nn
from flax import optim              # Optimizers

from typing import Any, Callable, Sequence, Optional, Tuple, List

from functools import partial


In [3]:
from tqdm import tqdm

In [4]:
class ResNetBlock(nn.Module):
    """3-D residual block: conv -> relu -> conv, plus a (possibly projected) shortcut.

    Attributes:
        filters: output features of both branch convolutions (and of the projection).
        strides: strides of the first branch convolution and of the projection.
    """
    filters: int
    strides: Tuple[int, int, int] = (1, 1, 1)

    @nn.compact
    def __call__(self, x):
        # Main branch. BatchNorm after each convolution is currently disabled.
        branch = nn.relu(nn.Conv(self.filters, (3, 3, 3), self.strides)(x))
        branch = nn.Conv(self.filters, (3, 3, 3))(branch)

        # Shortcut: identity when shapes already match, otherwise a 1x1x1
        # projection so the two tensors can be summed.
        shortcut = x
        if shortcut.shape != branch.shape:
            shortcut = nn.Conv(self.filters, (1, 1, 1),
                               self.strides, name='conv_proj')(x)

        return nn.relu(shortcut + branch)
    
class Encoder(nn.Module):
    """3-D ResNet-style encoder: a strided stem followed by residual stages.

    Attributes:
        channels: feature width of each stage.
        blocks: number of ``ResNetBlock``s stacked in the matching stage;
            must have the same length as ``channels``.
    """
    channels: Sequence[int] = (64, 128, 256, 512)
    blocks: Sequence[int] = (3, 3, 5, 2)

    @nn.compact
    def __call__(self, x):
        if len(self.channels) != len(self.blocks):
            raise ValueError(
                f'channels ({len(self.channels)}) and blocks ({len(self.blocks)}) '
                'must have the same length')

        # Stem: strided 7x7x7 convolution, then 2x2x2-window average pooling.
        x = nn.Conv(64, (7, 7, 7), strides=(2, 2, 2))(x)
        x = nn.avg_pool(x, (2, 2, 2))

        for i, (c, n_blocks) in enumerate(zip(self.channels, self.blocks)):
            # Transition into the stage: stride-2 conv then a refining conv.
            # `!=` replaces the original `is not` identity test on an int literal.
            # NOTE(review): stage 1 is skipped, as before. Since the stem already
            # emits 64 channels, skipping stage 0 may have been the intent — confirm.
            if i != 1:
                x = nn.Conv(c, (3, 3, 3), strides=(2, 2, 2))(x)
                x = nn.Conv(c, (3, 3, 3))(x)

            # Stack exactly `n_blocks` residual blocks for this stage. The previous
            # `for b in blocks` stacked len(blocks) blocks in every stage.
            for _ in range(n_blocks):
                x = ResNetBlock(filters=c)(x)

        return x
    
# Original note (translated from Italian): "the decoder saturates memory".
# A deeper variant built from inverse residual blocks was disabled, presumably
# for that reason, leaving this shallow transposed-convolution stack.
class Decoder(nn.Module):
    """Shallow 3-D transposed-convolution decoder with a single output channel."""

    @nn.compact
    def __call__(self, x):
        upsample = (2, 2, 2)
        x = nn.ConvTranspose(64, (1, 1, 1), strides=upsample)(x)
        x = nn.ConvTranspose(64, (7, 7, 7), strides=upsample)(x)
        # Final 1x1x1 projection down to one feature.
        return nn.ConvTranspose(1, (1, 1, 1))(x)
    
class Autoencoder(nn.Module):
    """Encoder-decoder network: ``Decoder`` applied to the output of ``Encoder``."""

    def setup(self):
        # Submodules are registered under the attribute names `encoder` and `decoder`.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def __call__(self, x):
        latent = self.encoder(x)
        return self.decoder(latent)

In [5]:
def get_initial_params(key, shape):
    """Initialise ``Autoencoder`` parameters.

    Args:
        key: JAX PRNG key used for initialisation.
        shape: shape of the dummy all-ones float32 input used to trace the model.

    Returns:
        The ``'params'`` collection of the initialised variables.
    """
    variables = Autoencoder().init(key, jnp.ones(shape, dtype=jnp.float32))
    return variables['params']

In [6]:
# Dummy-input shape for parameter initialisation: one window of
# 4 past frames, each 256 x 256 with 3 channels (image, vvx, vvy).
input_shape = (
    1,    # batch
    4,    # past frames
    256,  # spatial dim 1
    256,  # spatial dim 2
    3,    # channels
)

In [7]:
# Seed the PRNG, split off a dedicated key for initialisation, and keep `rng` for later use.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))
params = get_initial_params(init_rng, input_shape)

In [39]:
# Shape of one training entry (a batch of windows); use indexing, not the dunder.
X_train[1].shape

(16, 4, 256, 256, 3)

In [41]:
# Forward pass on one training entry.
# NOTE(review): on the full 16-window batch this previously failed with
# "Resource exhausted" (~10.5 GB allocation). To smoke-test, try a smaller
# slice such as X_train[1, :1].
logits = Autoencoder().apply({'params': params}, X_train[1])

RuntimeError: Resource exhausted: Out of memory while trying to allocate 10480615488 bytes.

In [None]:
def cross_entropy_loss(logits, labels, num_classes=10):
    """Mean categorical cross-entropy of integer ``labels`` under ``logits``.

    Args:
        logits: array whose last axis has size ``num_classes``; may hold raw
            scores or log-probabilities.
        labels: integer class indices, shaped like ``logits`` without its last axis.
        num_classes: size of the class axis (default 10, as before).

    Returns:
        Scalar mean negative log-likelihood.
    """
    # Normalise to log-probabilities. log_softmax is idempotent on inputs that
    # are already log-probabilities, so those callers see the same loss as before.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))

In [None]:
def create_optimizer(params, learning_rate, beta):
    """Wrap ``params`` in a ``flax.optim.Momentum`` optimizer.

    Args:
        params: model parameter pytree to optimise (becomes ``optimizer.target``).
        learning_rate: step size.
        beta: momentum coefficient.

    NOTE(review): ``flax.optim`` is deprecated/removed in newer Flax releases in
    favour of optax — confirm the pinned Flax version.
    """
    return optim.Momentum(learning_rate=learning_rate, beta=beta).create(params)

In [None]:
def compute_metrics(logits, labels):
    """Return ``{'loss': mean cross-entropy, 'accuracy': top-1 accuracy}``."""
    predictions = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predictions == labels),
    }

In [None]:
def get_datasets():
    """Download MNIST with TensorFlow Datasets and return ``(train_ds, test_ds)``.

    Each split is loaded in full (``batch_size=-1``) as NumPy data, and its
    'image' entry is converted to float32 and divided by 255.
    """
    # `tfds` is used here but was never imported anywhere in the notebook.
    import tensorflow_datasets as tfds

    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()

    def _load_split(split):
        """Load one split as NumPy data and rescale its images by 1/255."""
        ds = tfds.as_numpy(ds_builder.as_dataset(split=split, batch_size=-1))
        ds['image'] = jnp.float32(ds['image']) / 255.0
        return ds

    return _load_split('train'), _load_split('test')

In [None]:
def eval_model(model, test_ds):
    """Evaluate parameters ``model`` on the whole ``test_ds``.

    Returns:
        ``(loss, accuracy)`` as Python scalars.
    """
    host_metrics = jax.device_get(eval_step(model, test_ds))
    summary = jax.tree_map(lambda leaf: leaf.item(), host_metrics)
    return summary['loss'], summary['accuracy']

In [None]:
@jax.jit
def train_step(optimizer, batch):
    """Run one jitted gradient step on ``batch``.

    Args:
        optimizer: flax ``optim`` optimizer whose ``target`` holds the params.
        batch: mapping with an 'image' entry (model input) and a 'label' entry.

    Returns:
        ``(updated_optimizer, metrics)``, where ``metrics`` comes from ``compute_metrics``.
    """
    def loss_fn(params):
        # Consistent 8-space body indentation; the over-indented first line
        # was an IndentationError.
        logits = Autoencoder().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits, batch['label'])
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)
    metrics = compute_metrics(logits, batch['label'])
    return optimizer, metrics

In [None]:
@jax.jit
def eval_step(params, batch):
    """Compute metrics of the ``Autoencoder`` with ``params`` on ``batch``.

    Uses the same model that ``train_step`` optimises. ``CNN()`` was used before,
    but it fails to construct without its required ``train`` field.
    """
    logits = Autoencoder().apply({'params': params}, batch['image'])
    return compute_metrics(logits, batch['label'])

In [None]:
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
    """Train for one epoch over shuffled, full-size batches of ``train_ds``.

    Args:
        optimizer: flax optimizer passed to ``train_step``.
        train_ds: mapping of equally long arrays; must contain 'image'.
        batch_size: examples per step; a trailing incomplete batch is dropped.
        epoch: epoch number, used only for the log line.
        rng: JAX PRNG key used to shuffle the examples.

    Returns:
        ``(optimizer, epoch_metrics)``, where ``epoch_metrics`` averages each
        per-batch metric over the epoch.

    Raises:
        ValueError: if ``batch_size`` is not positive or exceeds the dataset
            size (no full batch could be formed).
    """
    train_ds_size = len(train_ds['image'])
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        raise ValueError(
            f'batch_size={batch_size} exceeds the training set size '
            f'({train_ds_size}); no full batch can be formed')

    # Shuffle, drop the trailing incomplete batch, and lay out one index row per step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        optimizer, metrics = train_step(optimizer, batch)
        batch_metrics.append(metrics)

    training_batch_metrics = jax.device_get(batch_metrics)
    training_epoch_metrics = {
        k: np.mean([metrics[k] for metrics in training_batch_metrics])
        for k in training_batch_metrics[0]}

    print('Training - epoch: %d, loss: %.4f, accuracy: %.2f' % (
        epoch, training_epoch_metrics['loss'], training_epoch_metrics['accuracy'] * 100))

    return optimizer, training_epoch_metrics

In [None]:
# Load the MNIST train/test splits fully into memory.
(train_ds, test_ds) = get_datasets()

In [None]:
print(train_ds["image"].shape)
print(train_ds["label"].shape)

In [None]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# get_initial_params requires the dummy-input shape; calling it with only the
# key raised a TypeError.
params = get_initial_params(init_rng, input_shape)

In [None]:
# Momentum-SGD hyperparameters.
learning_rate, beta = 0.1, 0.9
optimizer = create_optimizer(params, beta=beta, learning_rate=learning_rate)

In [None]:
# Training schedule: number of epochs and examples per gradient step.
num_epochs, batch_size = 10, 32

In [None]:
for epoch in range(1, num_epochs + 1):
    # A fresh key per epoch for shuffling; `rng` carries the remainder forward.
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_epoch(
        optimizer, train_ds, batch_size, epoch, input_rng)

    test_loss, test_accuracy = eval_model(optimizer.target, test_ds)
    print('Testing - epoch: %d, loss: %.2f, accuracy: %.2f'
          % (epoch, test_loss, test_accuracy * 100))

In [None]:
class CNN(nn.Module):
    """Single 3-D stage: conv (no bias) -> relu -> BatchNorm -> 2x2x2 average pooling.

    Attributes:
        train: passed to BatchNorm as ``use_running_average=not train``.
    """
    train: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3, 3), use_bias=False,
                    dtype=jnp.float32)(x)
        x = nn.relu(x)
        x = nn.BatchNorm(
            use_running_average=not self.train,
            momentum=0.9,
            epsilon=1e-5,
            dtype=jnp.float32,
        )(x)
        return nn.avg_pool(x, window_shape=(2, 2, 2))