In [1]:
import argparse
import os
import time
import pytorch_lightning as pl
import torchvision.transforms
from rich import print
from torch.utils.data import DataLoader
import lima
import glob
import h5py
import skimage.io as io
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
import cmcrameri.cm as cmc
from pykitPIV import Particle, FlowField, Motion, Image

In [2]:
import platform

# Record the interpreter version used to produce the outputs in this notebook:
python_version = platform.python_version()
print(python_version)

In [3]:
# Normalisation constant for image intensities: the 16-bit maximum (2**16 - 1),
# the same value passed as `maximum_intensity` when generating the images.
max_of_images = float(2**16 - 1)

<a id=synthetic-images></a>

***

## Generate synthetic images with `pykitPIV`

[Go to the top](#top-page)

In [4]:
# Geometry of the generated images and default figure size for plots.
# `size_buffer` is the buffer width handed to pykitPIV (see `remove_buffers()` in generate_images).
image_size, size_buffer, figsize = (124, 124), 10, (5, 3)

In [5]:
def generate_images(n_images, random_seed):
    """Build a pykitPIV ``Image`` holding ``n_images`` synthetic PIV image pairs.

    Workflow: seed particles, create a random velocity field, advect the
    particles with ``Motion.forward_euler``, attach everything to an
    ``Image``, add reflected laser light, and remove the buffers.
    The elapsed wall-clock time is printed (in minutes) before returning.

    Relies on the module-level ``image_size`` and ``size_buffer``.

    Parameters
    ----------
    n_images : int
        Number of image pairs to generate.
    random_seed : int
        Forwarded to ``Particle``, ``FlowField`` and ``Image``.

    Returns
    -------
    Image
        The populated pykitPIV image object, after ``remove_buffers()``.
    """
    tic = time.perf_counter()

    # Domain geometry shared by the particle and flow-field objects:
    domain = {'size': image_size, 'size_buffer': size_buffer}

    particles = Particle(n_images,
                         **domain,
                         diameters=(4, 4.1),
                         distances=(1, 2),
                         densities=(0.05, 0.1),
                         signal_to_noise=(5, 20),
                         diameter_std=0.2,
                         seeding_mode='random',
                         random_seed=random_seed)

    flowfield = FlowField(n_images, **domain, random_seed=random_seed)
    flowfield.generate_random_velocity_field(gaussian_filters=(10, 11),
                                             n_gaussian_filter_iter=20,
                                             displacement=(0, 10))

    motion = Motion(particles, flowfield, time_separation=0.1)
    motion.forward_euler(n_steps=10)

    image = Image(random_seed=random_seed)

    # Attach the components in the same order as before: particles, flow field, motion.
    for attach, component in ((image.add_particles, particles),
                              (image.add_flowfield, flowfield),
                              (image.add_motion, motion)):
        attach(component)

    image.add_reflected_light(exposures=(0.7, 0.8),
                              maximum_intensity=2**16 - 1,
                              laser_beam_thickness=1,
                              laser_over_exposure=1,
                              laser_beam_shape=0.95,
                              alpha=1/10)
    image.remove_buffers()

    toc = time.perf_counter()
    print(f'Time it took: {(toc - tic)/60:0.1f} minutes.\n')

    return image

<a id=synthetic-images-training-set></a>

### Training set

[Go to the top](#top-page)

In [None]:
# n_images = 100

In [None]:
# training_random_seed = 100

In [None]:
# image_train = generate_images(n_images, training_random_seed)

# image_pairs_train = image_train.image_pairs_to_tensor()
# targets_train = image_train.targets_to_tensor()

In [None]:
# image_train.save_to_h5({'I': image_pairs_train, 'targets': targets_train}, filename='PIV-dataset-train.h5')

In [None]:
# image_train.plot(0,
#                  instance=1,
#                  with_buffer=True,
#                  xlabel='Width [px]',
#                  ylabel='Height [px]',
#                  cmap='Greys_r',
#                  figsize=figsize);

In [None]:
# max_of_images = np.max(image_pairs_train)
# max_of_images

<a id=synthetic-images-testing-set></a>

### Testing set

[Go to the top](#top-page)

In [None]:
# n_images = 20

In [None]:
# test_random_seed = 200

In [None]:
# image_test = generate_images(n_images, test_random_seed)

# image_pairs_test = image_test.image_pairs_to_tensor()
# targets_test = image_test.targets_to_tensor()

In [None]:
# image_train.save_to_h5({'I': image_pairs_test, 'targets': targets_test}, filename='PIV-dataset-test.h5')

In [None]:
# image_test.plot(0,
#                 instance=1,
#                 with_buffer=True,
#                 xlabel='Width [px]',
#                 ylabel='Height [px]',
#                 cmap='Greys_r',
#                 figsize=figsize);

<a id=train-LIMA></a>
***

## Train `LIMA` with the generated images

[Go to the top](#top-page)

<a id=train-LIMA-input-data></a>
### Prepare the input dataset for LIMA

[Go to the top](#top-page)

In [6]:
# Augmentation pipeline (random affine warp, flips, brightness, noise).
# It is kept under its own name so that the plain `transform` defined in the next cell
# does not silently overwrite it. Pass `augmentation_transform` to the datasets
# below to train with augmentation.
# NOTE(review): RandomNoise(std=(0, 0)) presumably adds zero-std noise, i.e. a no-op — confirm in lima.
augmentation_transform = torchvision.transforms.Compose([lima.datatransform.RandomAffine(degrees=17, translate=(0.2, 0.2), scale=(0.9, 2.0)),
                                                         lima.datatransform.RandomHorizontalFlip(),
                                                         lima.datatransform.RandomVerticalFlip(),
                                                         lima.datatransform.ToTensor(),
                                                         lima.datatransform.RandomBrightness(factor=(0.5, 2)),
                                                         lima.datatransform.RandomNoise(std=(0, 0)),])

In [7]:
# Transform actually used by the datasets below: tensor conversion only, no augmentation.
transform = torchvision.transforms.Compose(
    [lima.datatransform.ToTensor()]
)

#### Use a dataset generated on the fly with `pykitPIV`:

[Go to the top](#top-page)

In [None]:
# class pykitPIVDataset(Dataset):
#     """Load pykitPIV-generated dataset"""

#     def __init__(self, image_pairs, targets, transform=None, n_samples=None, pin_to_ram=False):

#         self.data = image_pairs.astype(np.float32)
#         self.target = targets.astype(np.float32)

#         if n_samples:
#             self.data = self.data[:n_samples]
#             self.target = self.target[:n_samples]
#         if pin_to_ram:
#             self.data = np.array(self.data)
#             self.target = np.array(self.target)

#         self.transform = transform

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         if torch.is_tensor(idx):
#             idx = idx.tolist()
#         sample = self.data[idx], self.target[idx]
#         if self.transform:
#             sample = self.transform(sample)

#         return sample

In [None]:
# train_dataset = pykitPIVDataset(image_pairs=image_pairs_train/max_of_images,
#                                 targets=targets_train,
#                                 transform=transform)

In [None]:
# test_dataset = pykitPIVDataset(image_pairs=image_pairs_test/max_of_images,
#                                 targets=targets_test,
#                                 transform=transform)

#### Use a pre-saved dataset generated with `pykitPIV`:

[Go to the top](#top-page)

In [8]:
class pykitPIVDatasetFromPath(Dataset):
    """Load a pykitPIV-generated dataset from an HDF5 file.

    The file must contain an ``"I"`` dataset (image pairs) and a ``"targets"``
    dataset. Images are cast to float32 and divided by ``max_value``; targets
    are cast to float32. Both arrays are read fully into memory, so the file
    is closed as soon as loading finishes.

    Parameters
    ----------
    path : str
        Path to the HDF5 file.
    transform : callable, optional
        Applied to each ``(image_pair, target)`` sample in ``__getitem__``.
    n_samples : int, optional
        If truthy, keep only the first ``n_samples`` samples; only these rows
        are read from disk.
    pin_to_ram : bool
        Kept for backward compatibility. Data are always loaded into RAM, so
        this flag no longer has any effect.
    max_value : float, optional
        Normalisation constant for the images. ``None`` (default) uses the
        notebook-level ``max_of_images`` at construction time.
    """

    def __init__(self, path, transform=None, n_samples=None, pin_to_ram=False, max_value=None):

        if max_value is None:
            # Resolved at call time so later changes to the global are still honoured.
            max_value = max_of_images

        # Read everything eagerly, then release the handle. Previously the file
        # was closed only when pin_to_ram=True and leaked otherwise.
        with h5py.File(path, "r") as f:
            self.data, self.target = self._read_arrays(f, n_samples, max_value)

        self.transform = transform

    @staticmethod
    def _read_arrays(source, n_samples, max_value):
        """Return ``(images, targets)`` as float32 arrays from a mapping with ``"I"`` and ``"targets"``.

        ``source`` may be an open ``h5py.File`` or any mapping of sliceable
        array-likes. A falsy ``n_samples`` (``None`` or 0) keeps every sample,
        matching the former ``if n_samples:`` check.
        """
        stop = n_samples or None
        images = np.asarray(source["I"][:stop]).astype(np.float32) / max_value
        targets = np.asarray(source["targets"][:stop]).astype(np.float32)
        return images, targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        # Samplers may hand over tensor indices; convert them to plain Python ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.data[idx], self.target[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample

In [9]:
# Training split, as written to disk by the (commented-out) generation cells above:
train_dataset = pykitPIVDatasetFromPath(
    path='PIV-dataset-train.h5',
    transform=transform,
)

In [10]:
# Test split, loaded with the same transform as the training split:
test_dataset = pykitPIVDatasetFromPath(
    path='PIV-dataset-test.h5',
    transform=transform,
)

#### Use a dataset generated with MATLAB:

[Go to the top](#top-page)

In [None]:
# path = 'PIV_n3_s180_maxd10_rnd_v1.h5'

In [None]:
# f = h5py.File(path, "r")

# images = f["I"]
# images = np.array(images)
# targets = f["target"]
# targets = np.array(targets)[:,2:4,:,:]

# f.close()

In [None]:
# plt.imshow(np.array(images)[0,0,:,:], cmap='Greys_r')

In [None]:
# class HDF5Dataset(Dataset):
#     """HDF5Dataset loaded"""

#     def __init__(self, path, transform=None, n_samples=None, pin_to_ram=False):
#         f = h5py.File(path, "r")
#         self.data = f["I"]
#         self.target = np.array(f["target"])[:,2:4,:,:]

#         if n_samples:
#             self.data = self.data[:n_samples]
#             self.target = self.target[:n_samples]
#         if pin_to_ram:
#             self.data = np.array(self.data)
#             self.target = np.array(self.target)
#             f.close()
#         self.transform = transform

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         if torch.is_tensor(idx):
#             idx = idx.tolist()
#         sample = self.data[idx], self.target[idx]
#         if self.transform:
#             sample = self.transform(sample)

#         return sample

In [None]:
# train_dataset = HDF5Dataset(path=path,
#                             transform=transform,)

In [None]:
# test_dataset = HDF5Dataset(path=path,
#                            transform=transform,)

<a id=train-LIMA-train></a>
### Begin training

[Go to the top](#top-page)

In [11]:
# Training batches are shuffled; the test loader keeps a fixed order.
# Pinned (page-locked) host memory only speeds up host-to-GPU copies. Request it
# only when CUDA is present, to avoid the pin_memory warning on CPU-only runs.
train_loader = DataLoader(train_dataset,
                          batch_size=5,
                          shuffle=True,
                          num_workers=1,
                          pin_memory=torch.cuda.is_available())

test_loader = DataLoader(test_dataset,
                         batch_size=10)

In [12]:
# Global seed handed to Lightning's seed_everything in the next cell.
random_seed: int = 100

In [13]:
# Seed all RNGs through Lightning; workers=True also seeds DataLoader worker processes.
pl.seed_everything(seed=random_seed, workers=True)

Global seed set to 100


100

In [14]:
# Build the LIMA network. Keyword arguments are grouped by their names
# (architecture / loss / optimiser / LR scheduler / misc) for readability only.
model = lima.LIMA(
    # architecture
    output_level=1,
    div_flow=0.05,
    search_range=4,
    num_chs=[1, 16, 32, 64, 96, 128, 196],
    full_res=False,
    # loss
    loss='l1_loss',
    loss_weights=[0.005, 0.01, 0.02, 0.04, 0.08, 0.04, 0.04],
    loss_weights_order='inc',
    loss_J='abs',
    loss_J_gamma=1e-1,
    full_res_loss_weight_multiplier=2.0,
    reduction="sum",
    # optimiser
    epochs=10,
    optimizer='Adam',
    base_lr=0.001,
    weight_decay=4e-4,
    momentum=0.9,
    beta=0.999,
    # LR scheduler
    scheduler='ReduceLROnPlateau',
    lr_decay=0.2,
    patience=5,
    # misc
    num_workers=20,
    debug=0,
)

In [15]:
# Stop after the same number of epochs that lima.LIMA was configured with (epochs=10).
# Without max_epochs, Lightning applies its own default (1000 epochs when neither
# max_epochs nor max_steps is given), so training would ignore the model's setting.
trainer = pl.Trainer(max_epochs=10)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [16]:
# trainer = pl.Trainer(gpus=1)

In [17]:
# Train on train_loader, using test_loader as the validation set.
trainer.fit(model,
            train_dataloaders=train_loader,
            val_dataloaders=test_loader)

  rank_zero_warn(

  | Name                      | Type             | Params
---------------------------------------------------------------
0 | feature_pyramid_extractor | FeatureExtractor | 1.0 M 
1 | warping_layer             | WarpingLayer     | 0     
2 | flow_estimators           | ContextNetwork   | 576 K 
---------------------------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.466     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|                                                                                                                                      | 0/2 [00:00<?, ?it/s]torch.Size([10, 81, 2, 2])


  rank_zero_warn(


torch.Size([10, 81, 4, 4])
Sanity Checking DataLoader 0:  50%|███████████████████████████████████████████████████████████████                                                               | 1/2 [00:00<00:00,  3.73it/s]torch.Size([10, 81, 2, 2])
torch.Size([10, 81, 4, 4])
                                                                                                                                                                                               

  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|                                                                                                                                                          | 0/22 [00:00<?, ?it/s]*** LIMA  Data:: torch.float32 cpu    ****
torch.Size([5, 81, 2, 2])
torch.Size([5, 81, 4, 4])
Epoch 0:   5%|████▎                                                                                           | 1/22 [00:09<03:12,  9.16s/it, loss=6.44e+04, v_num=56, train_loss_step=6.44e+4]torch.Size([5, 81, 2, 2])
torch.Size([5, 81, 4, 4])
Epoch 0:   9%|████████▋                                                                                       | 2/22 [00:13<02:16,  6.84s/it, loss=8.01e+04, v_num=56, train_loss_step=95745.5]torch.Size([5, 81, 2, 2])
torch.Size([5, 81, 4, 4])
Epoch 0:  14%|█████████████▏                                                                                   | 3/22 [00:18<01:55,  6.08s/it, loss=6.3e+04, v_num=56, train_loss_step=2.87e+4]torch.Size([5, 81, 2, 2])
torch.Size(

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


<a id=predict></a>
***

## Make predictions from the trained network

[Go to the top](#top-page)

Switch the LIMA model to evaluation mode:

In [None]:
# Module.eval() switches the model in place (and returns it). The trailing semicolon
# stops Jupyter from printing the model repr.
model.eval();

In [None]:
# Which sample, and which velocity-component channel, to visualise below.
image_to_predict, velocity_component = 0, 0

In [None]:
# `image_pairs_train` is only defined when the on-the-fly generation cells above are
# uncommented. With the pre-saved HDF5 dataset this line raises NameError on a fresh
# kernel, so it is disabled like the rest of the on-the-fly path.
# predicted_flow = model.inference(image_pairs_train.astype(np.float32)/max_of_images)

In [None]:
# `train_dataset.data` is already float32 and divided by `max_of_images` when the
# dataset is built, so pass it as-is. Dividing again fed the network images
# shrunk by a second factor of max_of_images.
predicted_flow = model.inference(train_dataset.data)

In [None]:
# Shape of the network prediction; compare with train_dataset.target.shape.
predicted_flow.shape

In [None]:
# Shape of the ground-truth targets, for comparison with the prediction.
train_dataset.target.shape

In [None]:
# Predicted velocity component for the selected training sample.
fig, ax = plt.subplots()
im = ax.imshow(predicted_flow[image_to_predict, velocity_component, :, :],
               cmap=cmc.batlow,
               origin='lower')
ax.set(title=f'Predicted velocity component {velocity_component} (sample {image_to_predict})',
       xlabel='Width [px]',
       ylabel='Height [px]')
fig.colorbar(im, ax=ax);

In [None]:
# Ground-truth velocity component for the same sample, for visual comparison.
fig, ax = plt.subplots()
im = ax.imshow(train_dataset.target[image_to_predict, velocity_component, :, :],
               cmap=cmc.batlow,
               origin='lower')
ax.set(title=f'Target velocity component {velocity_component} (sample {image_to_predict})',
       xlabel='Width [px]',
       ylabel='Height [px]')
fig.colorbar(im, ax=ax);

In [None]:
# Scratch tensor with gradient tracking. NOTE(review): it is overwritten in the next
# code cell and never used; it looks like a leftover experiment.
x = torch.rand(1, 1, 100, 100, requires_grad=True)

#### Predict from a random tensor:

In [None]:
# Batch of 10 uniformly random inputs with 2 channels of 100x100 pixels
# (presumably the two frames of an image pair), run through the network.
x = torch.rand((10, 2, 100, 100)).cpu()
predicted_flow = model.inference(x)

In [None]:
# Shape of the prediction for the random input batch.
predicted_flow.shape

In [None]:
# First channel of the prediction for the first random input.
fig, ax = plt.subplots()
im = ax.imshow(predicted_flow[0, 0, :, :],
               cmap=cmc.batlow,
               origin='lower')
ax.set(title='Prediction from random input (sample 0, component 0)',
       xlabel='Width [px]',
       ylabel='Height [px]')
fig.colorbar(im, ax=ax);

***