## Training with IRM2FL

Different neural network frameworks are trained for the artificial labelling of focal adhesion structures in interference reflection microscopy (IRM) images:

- a U-Net with paired IRM-IF images
- a Pix2Pix network with paired IRM-IF images
- a CycleGAN with paired IRM-IF images
- a newly proposed 2LGAN with paired and unpaired IRM-IF images


<img src="../nets.png" alt="Network architectures" width="600" align="center"/>

In [None]:
from irm2fl.data.TFRecords import TFRecords
from irm2fl.models.Trainer import Trainer

from irm2fl.models import UNet, Pix2Pix, CycleGAN, TwoLGAN
from irm2fl.models.modules import CARE, FNet, PatchGAN32, PatchGAN34

import irm2fl.models.Losses as Losses

In [None]:
def get_data(dataset_name_paired=None, dataset_name_unpaired=None, patch_size=192):
    """Build the TFRecords datasets used for training.

    Parameters
    ----------
    dataset_name_paired : str
        Name of the directory under ``tfrecords/`` that holds the paired
        records (input feature ``image_irm``, target feature
        ``image_if_paired``). Required; the ``None`` default is kept only so
        the signature stays backward compatible.
    dataset_name_unpaired : str, optional
        Name of the directory under ``tfrecords/`` that holds the unpaired
        records (input feature ``image_if_unpaired``). When omitted, only the
        paired dataset is returned.
    patch_size : int, optional
        Height and width of the patches requested from ``TFRecords``.
        The ``image_irm`` input uses 3 channels; the ``image_if_*`` features
        use 1 channel. Defaults to 192.

    Returns
    -------
    list
        ``[data_paired]`` when no unpaired dataset is given, otherwise
        ``[data_paired, data_unpaired]``.

    Raises
    ------
    ValueError
        If ``dataset_name_paired`` is missing or empty.
    """
    if not dataset_name_paired:
        # Without this guard the path silently becomes 'tfrecords/None'.
        raise ValueError("dataset_name_paired is required")

    irm_shape = (patch_size, patch_size, 3)  # 'image_irm': 3 channels
    if_shape = (patch_size, patch_size, 1)   # 'image_if_*': 1 channel

    data_paired = TFRecords(
        dir_tfrecords=f'tfrecords/{dataset_name_paired}',
        dict_input={'feature_name': 'image_irm', 'patch_size': irm_shape},
        dict_target={'feature_name': 'image_if_paired', 'patch_size': if_shape},
    )

    if dataset_name_unpaired is None:
        return [data_paired]

    data_unpaired = TFRecords(
        dir_tfrecords=f'tfrecords/{dataset_name_unpaired}',
        dict_input={'feature_name': 'image_if_unpaired', 'patch_size': if_shape},
    )
    return [data_paired, data_unpaired]

### U-Net

In [None]:
ds = 'ds1'

# Train one U-Net (FNet generator) per configuration:
#   - Losses.MSE() with no final activation on the generator
#   - Losses.MS_SSIM(n_scales=1) with a sigmoid final activation
for loss, final_activation in (
    (Losses.MSE(), None),
    (Losses.MS_SSIM(n_scales=1), 'sigmoid'),
):
    # A fresh dataset object is requested for every run.
    data = get_data(dataset_name_paired=f'{ds}_192px_irm2fl')
    model_name = f'UNet_{loss.name}'

    generator = FNet(final_activation=final_activation)
    model = UNet(generator=generator)
    # Output directory: models/<dataset>/<model name>
    model.dir_model = f"models/{ds}/{model_name}"

    MyTrainer = Trainer(data=data, model=model)
    MyTrainer.train(loss=loss, epochs=10)
    MyTrainer.plot_examples(display=True)
    MyTrainer.evaluate()

### Pix2Pix

In [None]:
ds = 'ds1'

# Train one Pix2Pix model (FNet generator + PatchGAN32 discriminator) per
# loss / generator final-activation pair.
for loss, final_activation in [(Losses.MSE(), None),
                               (Losses.MS_SSIM(n_scales=1), 'sigmoid')]:

    data = get_data(dataset_name_paired=f'{ds}_192px_irm2fl')

    model_name = f'P2P_{loss.name}'

    # Build each network once and hand these same instances to Pix2Pix,
    # rather than constructing a second generator/discriminator pair inline.
    generator = FNet(final_activation=final_activation)
    # NOTE(review): the discriminator expects 128x128 inputs while get_data
    # yields 192x192 patches; presumably the Trainer crops - confirm.
    discriminator = PatchGAN32(input_shape=(128, 128, 2))

    model = Pix2Pix(generator=generator, discriminator=discriminator)

    # Output directory: models/<dataset>/<model name>
    model.dir_model = f"models/{ds}/{model_name}"

    MyTrainer = Trainer(data=data, model=model)

    MyTrainer.train(loss=loss, epochs=10)
    MyTrainer.plot_examples(display=True)
    MyTrainer.evaluate()

### CycleGAN

In [None]:
ds = 'ds1'
# Dataset id of the unpaired records ('image_if_unpaired'), kept separate
# from the paired dataset `ds`.
ds_unpaired = 'ds3'

# Train one CycleGAN (two CARE generators, two PatchGAN34 discriminators)
# per loss; both configurations use a sigmoid final activation.
for loss, final_activation in [(Losses.MAE(), 'sigmoid'),
                               (Losses.MS_SSIM(n_scales=1), 'sigmoid')]:

    data = get_data(dataset_name_paired=f'{ds}_192px_irm2fl',
                    dataset_name_unpaired=f'{ds_unpaired}_192px_fl-only')

    model_name = f'CG_{loss.name}'

    # NOTE(review): discriminators expect 128x128 inputs while get_data
    # yields 192x192 patches; presumably the Trainer crops - confirm.
    model = CycleGAN(
        generator=CARE(final_activation=final_activation),
        generator2=CARE(final_activation=final_activation),
        discriminator=PatchGAN34(input_shape=(128, 128, 1)),
        discriminator2=PatchGAN34(input_shape=(128, 128, 1)),
    )

    # Output directory: models/<paired dataset>/<model name>
    model.dir_model = f"models/{ds}/{model_name}"

    MyTrainer = Trainer(data=data, model=model)

    MyTrainer.train(loss=loss, epochs=10)
    MyTrainer.plot_examples(display=True)
    MyTrainer.evaluate()

### 2LGAN

In [None]:
ds = 'ds1'
# Dataset id of the unpaired records ('image_if_unpaired'), kept separate
# from the paired dataset `ds`.
ds_unpaired = 'ds3'

# Train one 2LGAN (two CARE generators, one PatchGAN34 discriminator) per
# loss; both configurations use a sigmoid final activation.
for loss, final_activation in [(Losses.MAE(), 'sigmoid'),
                               (Losses.MS_SSIM(n_scales=1), 'sigmoid')]:

    data = get_data(dataset_name_paired=f'{ds}_192px_irm2fl',
                    dataset_name_unpaired=f'{ds_unpaired}_192px_fl-only')

    model_name = f'2LG_{loss.name}'

    # NOTE(review): the discriminator expects 128x128 inputs while get_data
    # yields 192x192 patches; presumably the Trainer crops - confirm.
    model = TwoLGAN(
        generator=CARE(final_activation=final_activation),
        generator2=CARE(final_activation=final_activation),
        discriminator=PatchGAN34(input_shape=(128, 128, 1)),
    )

    # Output directory: models/<paired dataset>/<model name>
    model.dir_model = f"models/{ds}/{model_name}"

    MyTrainer = Trainer(data=data, model=model)

    MyTrainer.train(loss=loss, epochs=10)
    MyTrainer.plot_examples(display=True)
    MyTrainer.evaluate()