In [1]:
from collections import OrderedDict
import os
from pathlib import Path
import shutil

from imageio.v3 import imread, imwrite
from PIL import Image
import pysaliency
from pysaliency.baseline_utils import BaselineModel, CrossvalidatedBaselineModel
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

from tqdm import tqdm


from deepgaze_pytorch.layers import (
    Conv2dMultiInput,
    LayerNorm,
    LayerNormMultiInput,
    Bias,
    FlexibleScanpathHistoryEncoding
)

from deepgaze_pytorch.modules import DeepGazeIII, FeatureExtractor
from deepgaze_pytorch.features.densenet import RGBDenseNet201
from deepgaze_pytorch.data import ImageDataset, ImageDatasetSampler, FixationDataset, FixationMaskTransform
from deepgaze_pytorch.training import _train


In [2]:
def build_saliency_network(input_channels):
    """Build the pointwise readout applied to the extracted image features.

    The network consists of three stages, each being
    LayerNorm -> 1x1 convolution (without bias) -> ``Bias`` -> Softplus,
    with channel widths ``input_channels -> 8 -> 16 -> 1``.

    Args:
        input_channels: number of channels of the incoming feature tensor.

    Returns:
        An ``nn.Sequential`` whose modules are named ``layernorm{i}``,
        ``conv{i}``, ``bias{i}`` and ``softplus{i}`` for ``i`` in 0, 1, 2.
        These names are referenced elsewhere (e.g. when freezing parameters
        by name prefix), so they must stay stable.
    """
    channel_widths = [input_channels, 8, 16, 1]
    layers = OrderedDict()
    stage_shapes = zip(channel_widths[:-1], channel_widths[1:])
    for stage, (in_channels, out_channels) in enumerate(stage_shapes):
        layers[f'layernorm{stage}'] = LayerNorm(in_channels)
        layers[f'conv{stage}'] = nn.Conv2d(in_channels, out_channels, (1, 1), bias=False)
        layers[f'bias{stage}'] = Bias(out_channels)
        layers[f'softplus{stage}'] = nn.Softplus()
    return nn.Sequential(layers)


def build_scanpath_network():
    """Build the network that encodes the scanpath history.

    A ``FlexibleScanpathHistoryEncoding`` for 4 fixations with 3 channels per
    fixation produces 128 channels, followed by Softplus, LayerNorm, a 1x1
    convolution down to 16 channels (without bias), a separate ``Bias`` and a
    final Softplus.

    Returns:
        An ``nn.Sequential`` with modules ``encoding0``, ``softplus0``,
        ``layernorm1``, ``conv1``, ``bias1`` and ``softplus1``; its output has
        16 channels.
    """
    layers = OrderedDict()

    # Stage 0: encode the previous fixations.
    layers['encoding0'] = FlexibleScanpathHistoryEncoding(
        in_fixations=4,
        channels_per_fixation=3,
        out_channels=128,
        kernel_size=[1, 1],
        bias=True,
    )
    layers['softplus0'] = nn.Softplus()

    # Stage 1: reduce to 16 scanpath feature channels.
    layers['layernorm1'] = LayerNorm(128)
    layers['conv1'] = nn.Conv2d(128, 16, (1, 1), bias=False)
    layers['bias1'] = Bias(16)
    layers['softplus1'] = nn.Softplus()

    return nn.Sequential(layers)


def build_fixation_selection_network(scanpath_features=16):
    """Build the network that merges the saliency and scanpath features.

    The first stage uses the multi-input layers with input channel counts
    ``[1, scanpath_features]`` (one saliency channel plus the scanpath
    features) and maps them to 128 channels. The second stage reduces to 16
    channels, and a final 1x1 convolution without bias or nonlinearity yields
    a single output channel.

    Args:
        scanpath_features: number of channels provided by the scanpath
            network. This notebook passes 0 for the purely spatial model.

    Returns:
        An ``nn.Sequential`` with modules ``layernorm0``, ``conv0``,
        ``bias0``, ``softplus0``, ``layernorm1``, ``conv1``, ``bias1``,
        ``softplus1`` and ``conv2``.
    """
    layers = OrderedDict()

    # Stage 0: joint processing of the two inputs.
    layers['layernorm0'] = LayerNormMultiInput([1, scanpath_features])
    layers['conv0'] = Conv2dMultiInput([1, scanpath_features], 128, (1, 1), bias=False)
    layers['bias0'] = Bias(128)
    layers['softplus0'] = nn.Softplus()

    # Stage 1: 128 -> 16 channels.
    layers['layernorm1'] = LayerNorm(128)
    layers['conv1'] = nn.Conv2d(128, 16, (1, 1), bias=False)
    layers['bias1'] = Bias(16)
    layers['softplus1'] = nn.Softplus()

    # Final linear readout to one channel.
    layers['conv2'] = nn.Conv2d(16, 1, (1, 1), bias=False)

    return nn.Sequential(layers)

In [3]:
def prepare_spatial_dataset(stimuli, fixations, centerbias, batch_size, path=None):
    """Create a data loader for training the spatial (image-only) model.

    Args:
        stimuli: stimuli passed to ``ImageDataset``.
        fixations: fixations passed to ``ImageDataset``.
        centerbias: model passed as ``centerbias_model``.
        batch_size: batch size given to ``ImageDatasetSampler``.
        path: optional ``pathlib.Path``. If given, the directory is created
            and its string form is passed to the dataset as ``lmdb_path``;
            otherwise ``lmdb_path`` is ``None``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset, batched by an
        ``ImageDatasetSampler`` and loading in the main process
        (``num_workers=0``).
    """
    lmdb_path = None
    if path is not None:
        path.mkdir(parents=True, exist_ok=True)
        lmdb_path = str(path)

    spatial_dataset = ImageDataset(
        stimuli=stimuli,
        fixations=fixations,
        centerbias_model=centerbias,
        transform=FixationMaskTransform(sparse=False),
        average='image',
        lmdb_path=lmdb_path,
    )
    batch_sampler = ImageDatasetSampler(spatial_dataset, batch_size=batch_size)

    return torch.utils.data.DataLoader(
        spatial_dataset,
        batch_sampler=batch_sampler,
        pin_memory=False,
        num_workers=0,
    )

In [4]:
def prepare_scanpath_dataset(stimuli, fixations, centerbias, batch_size, path=None):
    """Create a data loader for training the scanpath model.

    The ``FixationDataset`` is configured with
    ``included_fixations=[-1, -2, -3, -4]`` and
    ``allow_missing_fixations=True``.
    NOTE(review): presumably this means each sample is conditioned on up to
    four previous fixations, with shorter histories allowed; confirm against
    ``FixationDataset``.

    Args:
        stimuli: stimuli passed to ``FixationDataset``.
        fixations: fixations passed to ``FixationDataset``.
        centerbias: model passed as ``centerbias_model``.
        batch_size: batch size given to ``ImageDatasetSampler``.
        path: optional ``pathlib.Path``. If given, the directory is created
            and its string form is passed to the dataset as ``lmdb_path``;
            otherwise ``lmdb_path`` is ``None``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset, batched by an
        ``ImageDatasetSampler`` and loading in the main process
        (``num_workers=0``).
    """
    lmdb_path = None
    if path is not None:
        path.mkdir(parents=True, exist_ok=True)
        lmdb_path = str(path)

    scanpath_dataset = FixationDataset(
        stimuli=stimuli,
        fixations=fixations,
        centerbias_model=centerbias,
        included_fixations=[-1, -2, -3, -4],
        allow_missing_fixations=True,
        transform=FixationMaskTransform(sparse=False),
        average='image',
        lmdb_path=lmdb_path,
    )
    batch_sampler = ImageDatasetSampler(scanpath_dataset, batch_size=batch_size)

    return torch.utils.data.DataLoader(
        scanpath_dataset,
        batch_sampler=batch_sampler,
        pin_memory=False,
        num_workers=0,
    )

In [5]:
# Directory handed to pysaliency as the `location` of the datasets.
dataset_directory = Path('pysaliency_datasets')
# Root directory for everything this notebook writes: the LMDB caches,
# the resized MIT1003 stimuli and the per-stage training directories.
train_directory = Path('train_deepgaze3')

In [6]:
# Device used for all models and passed to `_train`. Prefer the GPU, but fall
# back to the CPU so the notebook does not fail at the first `.to(device)` on a
# machine without CUDA (training on the CPU will be very slow).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pretraining on SALICON

In [7]:
# Load the SALICON training and validation splits from `dataset_directory`.
SALICON_train_stimuli, SALICON_train_fixations = pysaliency.get_SALICON_train(
    location=dataset_directory,
)
SALICON_val_stimuli, SALICON_val_fixations = pysaliency.get_SALICON_val(
    location=dataset_directory,
)

# Center bias baseline built from the SALICON training fixations. The bandwidth
# and eps are taken from an early fit for MIT1003; since SALICON has many more
# fixations, that bandwidth won't be too small here.
SALICON_centerbias = BaselineModel(
    stimuli=SALICON_train_stimuli,
    fixations=SALICON_train_fixations,
    bandwidth=0.0217,
    eps=2e-13,
    caching=False,
)

# Baseline values that are passed to `_train` for the SALICON pretraining.
# Computing them takes quite some time; feel free to set them to zero instead.
train_baseline_log_likelihood = SALICON_centerbias.information_gain(
    SALICON_train_stimuli,
    SALICON_train_fixations,
    verbose=True,
    average='image',
)
val_baseline_log_likelihood = SALICON_centerbias.information_gain(
    SALICON_val_stimuli,
    SALICON_val_fixations,
    verbose=True,
    average='image',
)

In [8]:
# Spatial-only DeepGaze III for the SALICON pretraining: no scanpath network,
# no conditioning fixations, and a fixation selection network that receives
# zero scanpath feature channels.
model = DeepGazeIII(
    features=FeatureExtractor(
        RGBDenseNet201(),
        [
            '1.features.denseblock4.denselayer32.norm1',
            '1.features.denseblock4.denselayer32.conv1',
            '1.features.denseblock4.denselayer31.conv2',
        ],
    ),
    saliency_network=build_saliency_network(2048),
    scanpath_network=None,
    fixation_selection_network=build_fixation_selection_network(scanpath_features=0),
    downsample=1.5,
    readout_factor=4,
    saliency_map_factor=4,
    included_fixations=[],
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Learning rate steps every 15 epochs: milestones 15, 30, ..., 120.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=list(range(15, 121, 15)),
)

Using cache found in /home/bethge/mkuemmerer31/.cache/torch/hub/pytorch_vision_v0.6.0


In [9]:
# Data loaders for the SALICON pretraining; each split gets its own LMDB
# directory below `train_directory / 'lmdb_cache'`.
train_loader = prepare_spatial_dataset(
    stimuli=SALICON_train_stimuli,
    fixations=SALICON_train_fixations,
    centerbias=SALICON_centerbias,
    batch_size=4,
    path=train_directory / 'lmdb_cache' / 'SALICON_train',
)
validation_loader = prepare_spatial_dataset(
    stimuli=SALICON_val_stimuli,
    fixations=SALICON_val_fixations,
    centerbias=SALICON_centerbias,
    batch_size=4,
    path=train_directory / 'lmdb_cache' / 'SALICON_val',
)

Generate LMDB to train_deepgaze3/lmdb_cache/SALICON_train


100%|██████████| 10000/10000 [00:58<00:00, 170.48it/s]


Flushing database ...
Populating fixations cache


100%|█████████▉| 68992354/68992355 [01:39<00:00, 692386.13it/s] 


Generate LMDB to train_deepgaze3/lmdb_cache/SALICON_val


100%|██████████| 5000/5000 [00:39<00:00, 127.86it/s]


Flushing database ...
Populating fixations cache


100%|█████████▉| 38846997/38846998 [00:22<00:00, 1732635.07it/s]


In [10]:
# Pretrain the spatial model on SALICON. The MIT1003 fine-tuning further down
# starts from `train_directory / 'pretraining' / 'final.pth'`.
_train(
    train_directory / 'pretraining',
    model,
    train_loader,
    train_baseline_log_likelihood,
    validation_loader,
    val_baseline_log_likelihood,
    optimizer,
    lr_scheduler,
    minimum_learning_rate=1e-7,
    device=device,
)


Training Already finished


# Preparing the MIT1003 dataset

In [11]:
# Load MIT1003 including the initial fixation of each scanpath.
load_mit1003 = pysaliency.external_datasets.mit.get_mit1003_with_initial_fixation
mit_stimuli_orig, mit_scanpaths_orig = load_mit1003(
    location=dataset_directory,
    replace_initial_invalid_fixations=True,
)



In [12]:
def convert_stimulus(input_image):
    """Resize an image array to one of two fixed sizes, depending on orientation.

    Images that are wider than high become 768 x 1024 (height x width); all
    other images (including square ones) become 1024 x 768.

    Args:
        input_image: image as a numpy array whose first two axes are
            (height, width).

    Returns:
        The bilinearly resized image as a numpy array.
    """
    height, width = input_image.shape[0], input_image.shape[1]
    if height < width:
        target_height, target_width = 768, 1024
    else:
        target_height, target_width = 1024, 768

    # Pillow expects the size as (width, height).
    resized = Image.fromarray(input_image).resize(
        (target_width, target_height),
        Image.BILINEAR,
    )
    return np.array(resized)

def convert_fixations(stimuli, fixations):
    """Rescale fixation coordinates to match the stimuli resized by ``convert_stimulus``.

    For every stimulus the target size is 768 x 1024 (height x width) if the
    original image is wider than high, and 1024 x 768 otherwise. The ``x``,
    ``y``, ``x_hist`` and ``y_hist`` values of all fixations on that stimulus
    are multiplied by the corresponding width/height ratio.

    Args:
        stimuli: stimuli object providing ``len()`` and ``shapes``.
        fixations: fixations object providing ``copy()``, ``n``, ``x``, ``y``,
            ``x_hist`` and ``y_hist``; it is not modified.

    Returns:
        A copy of ``fixations`` with rescaled coordinates.
    """
    new_fixations = fixations.copy()
    for n in tqdm(range(len(stimuli))):
        # Only the shape is needed, so read it from `stimuli.shapes` (as
        # `convert_fixation_trains` does) instead of loading the whole image.
        height, width = stimuli.shapes[n][0], stimuli.shapes[n][1]
        if height < width:
            new_height, new_width = 768, 1024
        else:
            new_height, new_width = 1024, 768
        x_factor = new_width / width
        y_factor = new_height / height

        # Boolean mask selecting the fixations that belong to stimulus n.
        inds = new_fixations.n == n
        new_fixations.x[inds] *= x_factor
        new_fixations.y[inds] *= y_factor
        new_fixations.x_hist[inds] *= x_factor
        new_fixations.y_hist[inds] *= y_factor

    return new_fixations

def convert_fixation_trains(stimuli, fixations):
    """Rescale scanpaths to match the stimuli resized by ``convert_stimulus``.

    Each scanpath's x and y coordinates are multiplied by the ratio between
    the target size of its stimulus (768 x 1024, height x width, if the image
    is wider than high, otherwise 1024 x 768) and the original size.

    Args:
        stimuli: stimuli object providing ``shapes``.
        fixations: ``pysaliency.FixationTrains`` to convert; it is not modified.

    Returns:
        A new ``pysaliency.FixationTrains`` with rescaled ``train_xs`` and
        ``train_ys`` and copies of the remaining data. All attributes listed
        in ``fixations.__attributes__`` are carried over, except ``subjects``
        and ``scanpath_index``.
    """
    scaled_xs = fixations.train_xs.copy()
    scaled_ys = fixations.train_ys.copy()

    for train_index in tqdm(range(len(scaled_xs))):
        stimulus_index = fixations.train_ns[train_index]
        height = stimuli.shapes[stimulus_index][0]
        width = stimuli.shapes[stimulus_index][1]

        if height < width:
            target_height, target_width = 768, 1024
        else:
            target_height, target_width = 1024, 768

        scaled_xs[train_index] *= target_width / width
        scaled_ys[train_index] *= target_height / height

    carried_attributes = {}
    for key in fixations.__attributes__:
        if key in ['subjects', 'scanpath_index']:
            continue
        carried_attributes[key] = getattr(fixations, key).copy()

    return pysaliency.FixationTrains(
        train_xs=scaled_xs,
        train_ys=scaled_ys,
        train_ts=fixations.train_ts.copy(),
        train_ns=fixations.train_ns.copy(),
        train_subjects=fixations.train_subjects.copy(),
        attributes=carried_attributes,
    )



def convert_stimuli(stimuli, new_location: Path):
    """Write resized copies of all stimuli and return them as new ``FileStimuli``.

    Every image is passed through ``convert_stimulus``. If that changed the
    image's shape, the resized image is written to
    ``new_location / 'stimuli'``; otherwise the original file is copied
    unchanged. File base names are kept.

    Args:
        stimuli: ``pysaliency.FileStimuli`` whose files should be converted.
        new_location: directory below which the ``stimuli`` folder is created.

    Returns:
        ``pysaliency.FileStimuli`` referring to the new files, in the order of
        ``stimuli.filenames``.

    Raises:
        AssertionError: if ``stimuli`` is not a ``pysaliency.FileStimuli``.
    """
    assert isinstance(stimuli, pysaliency.FileStimuli)
    new_stimuli_location = new_location / 'stimuli'
    new_stimuli_location.mkdir(parents=True, exist_ok=True)
    new_filenames = []
    for filename in tqdm(stimuli.filenames):
        stimulus = imread(filename)
        new_stimulus = convert_stimulus(stimulus)

        new_filename = new_stimuli_location / os.path.basename(filename)
        # Compare shapes, not element counts: `ndarray.size` is identical for
        # e.g. a 512x1536 and a 768x1024 image, which would wrongly keep the
        # unresized original.
        if new_stimulus.shape != stimulus.shape:
            imwrite(new_filename, new_stimulus)
        else:
            # Already at the target size: copy the file instead of re-encoding it.
            shutil.copy(filename, new_filename)
        new_filenames.append(new_filename)
    return pysaliency.FileStimuli(new_filenames)

# Bring MIT1003 to the two fixed stimulus sizes (768x1024 or 1024x768, height x width):
# the scanpaths are rescaled using the original stimulus shapes, and the resized
# images are written below `train_directory / 'MIT1003_twosize'`.
mit_scanpaths_twosize = convert_fixation_trains(mit_stimuli_orig, mit_scanpaths_orig)
mit_stimuli_twosize = convert_stimuli(mit_stimuli_orig, train_directory / 'MIT1003_twosize')

100%|██████████| 15045/15045 [00:00<00:00, 214631.98it/s]
100%|██████████| 1003/1003 [01:10<00:00, 14.30it/s]


In [13]:
# Remove the initial forced fixation from the training data; it is only used for
# conditioning. Keeping only entries with `lengths > 0` drops it.
# NOTE(review): presumably `lengths` is the number of previous fixations of each
# fixation; confirm against pysaliency.
mit_fixations_twosize = mit_scanpaths_twosize[mit_scanpaths_twosize.lengths > 0]

In [14]:
# Center bias baseline for MIT1003, built from the resized stimuli and the
# fixations without the initial forced fixation. The bandwidth and eps were
# optimized on MIT1003 for maximum leave-one-image-out crossvalidation
# log-likelihood.
MIT1003_centerbias = CrossvalidatedBaselineModel(
    mit_stimuli_twosize,
    mit_fixations_twosize,
    bandwidth=10**-1.6667673342543432,
    eps=10**-14.884189168516073,
    caching=False,
)

In [15]:
# Number of cross-validation folds; it is also part of the per-fold directory names.
CROSSVAL_FOLDS = 10

# Parameter name prefixes of the saliency network that are frozen during the
# first scanpath training stage (everything except the last stage, `*2`).
FROZEN_SALIENCY_SCOPES = [
    "saliency_network.layernorm0",
    "saliency_network.conv0",
    "saliency_network.bias0",
    "saliency_network.layernorm1",
    "saliency_network.conv1",
    "saliency_network.bias1",
]


def _build_mit1003_model(with_scanpath):
    """Build the DeepGaze III model used for MIT1003 (``downsample=2``).

    Args:
        with_scanpath: if true, the model gets a scanpath network, a fixation
            selection network with 16 scanpath feature channels and is
            conditioned on the fixations ``[-1, -2, -3, -4]``. If false, the
            purely spatial variant is built (no scanpath network, 0 scanpath
            feature channels, no included fixations).

    Returns:
        A freshly constructed ``DeepGazeIII`` instance (not yet moved to a device).
    """
    return DeepGazeIII(
        features=FeatureExtractor(RGBDenseNet201(), [
            '1.features.denseblock4.denselayer32.norm1',
            '1.features.denseblock4.denselayer32.conv1',
            '1.features.denseblock4.denselayer31.conv2',
        ]),
        saliency_network=build_saliency_network(2048),
        scanpath_network=build_scanpath_network() if with_scanpath else None,
        fixation_selection_network=build_fixation_selection_network(
            scanpath_features=16 if with_scanpath else 0,
        ),
        downsample=2,
        readout_factor=4,
        saliency_map_factor=4,
        included_fixations=[-1, -2, -3, -4] if with_scanpath else [],
    )


def _freeze_parameters(model, scopes):
    """Disable gradients for all parameters whose name starts with one of ``scopes``.

    Each affected parameter name is printed.
    """
    for scope in scopes:
        for parameter_name, parameter in model.named_parameters():
            if parameter_name.startswith(scope):
                print("Fixating parameter", parameter_name)
                parameter.requires_grad = False


for crossval_fold in range(CROSSVAL_FOLDS):
    fold_name = f'crossval-{CROSSVAL_FOLDS}-{crossval_fold}'
    lmdb_cache = train_directory / 'lmdb_cache'

    MIT1003_stimuli_train, MIT1003_fixations_train = pysaliency.dataset_config.train_split(
        mit_stimuli_twosize, mit_fixations_twosize,
        crossval_folds=CROSSVAL_FOLDS, fold_no=crossval_fold,
    )
    MIT1003_stimuli_val, MIT1003_fixations_val = pysaliency.dataset_config.validation_split(
        mit_stimuli_twosize, mit_fixations_twosize,
        crossval_folds=CROSSVAL_FOLDS, fold_no=crossval_fold,
    )

    # Baseline values for this fold; they are reused by all three training stages below.
    train_baseline_log_likelihood = MIT1003_centerbias.information_gain(
        MIT1003_stimuli_train, MIT1003_fixations_train, verbose=True, average='image')
    val_baseline_log_likelihood = MIT1003_centerbias.information_gain(
        MIT1003_stimuli_val, MIT1003_fixations_val, verbose=True, average='image')

    # Stage 1: finetune the spatial model on MIT1003, starting from the SALICON pretraining.

    model = _build_mit1003_model(with_scanpath=False)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9, 12, 15, 18, 21, 24])

    train_loader = prepare_spatial_dataset(
        MIT1003_stimuli_train, MIT1003_fixations_train, MIT1003_centerbias,
        batch_size=4, path=lmdb_cache / f'MIT1003_train_spatial_{crossval_fold}',
    )
    validation_loader = prepare_spatial_dataset(
        MIT1003_stimuli_val, MIT1003_fixations_val, MIT1003_centerbias,
        batch_size=4, path=lmdb_cache / f'MIT1003_val_spatial_{crossval_fold}',
    )

    _train(train_directory / 'MIT1003_spatial' / fold_name,
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'pretraining' / 'final.pth',
    )

    # Stages 2 and 3 train the scanpath model and share these data loaders.

    train_loader = prepare_scanpath_dataset(
        MIT1003_stimuli_train, MIT1003_fixations_train, MIT1003_centerbias,
        batch_size=4, path=lmdb_cache / f'MIT1003_train_scanpath_{crossval_fold}',
    )
    validation_loader = prepare_scanpath_dataset(
        MIT1003_stimuli_val, MIT1003_fixations_val, MIT1003_centerbias,
        batch_size=4, path=lmdb_cache / f'MIT1003_val_scanpath_{crossval_fold}',
    )

    # Stage 2: first train with a partially frozen saliency network,
    # starting from this fold's spatial model.

    model = _build_mit1003_model(with_scanpath=True)
    model = model.to(device)

    _freeze_parameters(model, FROZEN_SALIENCY_SCOPES)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 31, 32, 33, 34, 35])

    _train(train_directory / 'MIT1003_scanpath_partially_frozen_saliency_network' / fold_name,
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'MIT1003_spatial' / fold_name / 'final.pth'
    )

    # Stage 3: now finetune the full scanpath model (nothing frozen) with a
    # much smaller learning rate, starting from the result of stage 2.

    model = _build_mit1003_model(with_scanpath=True)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9, 12, 15, 18, 21, 24])

    _train(train_directory / 'MIT1003_scanpath' / fold_name,
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'MIT1003_scanpath_partially_frozen_saliency_network' / fold_name / 'final.pth'
    )


Using random shuffles for crossvalidation
Using random shuffles for crossvalidation


100%|██████████| 808/808 [02:11<00:00,  6.15it/s]
100%|██████████| 94/94 [00:13<00:00,  6.89it/s]
Using cache found in /home/bethge/mkuemmerer31/.cache/torch/hub/pytorch_vision_v0.6.0
