In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pywt
from matplotlib import pyplot as plt

from sklearn.model_selection import train_test_split

from scipy.spatial.distance import cdist

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_probability as tfp

from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.python.ops.numpy_ops import np_config
from tensorflow.train import BytesList
from tensorflow.train import Example, Features, Feature
np_config.enable_numpy_behavior()

import librosa as lb

import requests
from tqdm import trange, tqdm
from pathlib import Path
import os

from model import WaveletAE
from utils import get_style_correlation_transform

from vggish_preprocessing.preprocess_sound import preprocess_sound
import vggish_preprocessing.vggish_params as vggish_params

# Turn on TensorFlow's memory-growth setting for every physical GPU that is
# listed; the loop body simply does not run when no GPU is found.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# IMuse - Image To Music Style Transfer
## Data EDA and Preprocessing

### Image Wiki Art Sentiment Data

In [None]:
# Load the tab-separated WikiArt info table, report its row count and preview it.
wikiart_images = pd.read_table('../data/image/WikiArt-Emotions/WikiArt-info.tsv')
print(f'{len(wikiart_images)} unique images')
wikiart_images.head()

In [None]:
# Keep only the image id and its download URL, under short lowercase names.
# rename() ignores labels that are not present, so re-running this cell on the
# already-renamed frame is a no-op instead of a KeyError on 'ID' / 'Image URL'.
wikiart_images = wikiart_images.rename(columns={'ID': 'id', 'Image URL': 'url'})[['id', 'url']]
wikiart_images.head()

In [None]:
# Number of rows whose URL ends with '.jpg' or '.JPG'.
jpg_mask = wikiart_images.url.str.endswith('.jpg') | wikiart_images.url.str.endswith('.JPG')
len(wikiart_images[jpg_mask])

Almost all of the images are .jpg files, so we drop the 4 .png files to keep the dataset consistent.

In [None]:
# Keep only the rows whose URL ends with '.jpg' or '.JPG'.
is_jpg = wikiart_images.url.str.endswith('.jpg') | wikiart_images.url.str.endswith('.JPG')
wikiart_images = wikiart_images.loc[is_jpg]

In [None]:
def download_img(url, directory, file_id, timeout=30):
    """Download one image and store it as ``<directory>/<file_id>.jpg``.

    Args:
        url: Address of the image to fetch.
        directory: Directory that receives the file; it must already exist.
        file_id: File name (without extension) to save the image under.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on HTTP errors; otherwise the body of an error page would be
    # written to disk under a .jpg name and only break later, at decode time.
    response.raise_for_status()

    with open(f'{directory}/{file_id}.jpg', 'wb') as handler:
        handler.write(response.content)

# for img_id, imd_url in tqdm(wikiart_images.values):
#     download_img(imd_url, '../data/image/wikiart', img_id)

### Emotion Votes

In [None]:
# Load the tab-separated WikiArt emotions table and preview it.
arts_emotions = pd.read_table('../data/image/WikiArt-Emotions/WikiArt-Emotions-All.tsv')
arts_emotions.head()

In [None]:
# Keeping the "Image Only" emotions (selected by position: columns 29..48).
image_only_cols = arts_emotions.columns[29:49]
# .copy() detaches the selection from the full table, so adding columns to it
# later does not raise SettingWithCopyWarning.
arts_emotions = arts_emotions[['ID', *image_only_cols]].copy()
# Column headers are reduced to the text after the last ':'; the id column is
# renamed to lowercase.
arts_emotions.columns = ['id', *(col.split(':')[-1].strip() for col in image_only_cols)]
arts_emotions.head()

In [None]:
# Create a column with the top emotion associated with each artwork.
# The order of this list matters only for ties: idxmax returns the first
# column that holds the row maximum.
emotion_columns = [
    'agreeableness', 'anger', 'anticipation', 'arrogance', 'disagreeableness',
    'disgust', 'fear', 'gratitude', 'happiness', 'humility', 'love',
    'optimism', 'pessimism', 'regret', 'sadness', 'shame', 'shyness',
    'surprise', 'trust', 'neutral',
]
prob_df4 = arts_emotions.loc[:, emotion_columns]
# assign() returns a new frame instead of writing into `arts_emotions`, which
# at this point is a column selection of the full table.
arts_emotions = arts_emotions.assign(emotion=prob_df4.idxmax(axis=1))
arts_emotions.head()

In [None]:
# Distinct values of the top-emotion column.
arts_emotions['emotion'].unique()

In [None]:
# Lookup table from an emotion label to a quadrant code ('Q1'..'Q4').
# set_quadrant() indexes this dict with each row's `emotion` value, so a label
# that is missing here raises KeyError.
# NOTE(review): 'regret' is one of the columns used to pick an artwork's top
# emotion but has no entry here - confirm it never wins, or add a mapping.
# NOTE(review): presumably Q1..Q4 are valence/arousal quadrants - TODO confirm,
# and double-check entries such as 'happy mod.' -> 'Q2' and
# 'energy neg. high' -> 'Q2' against that definition.
emotion_quadrants = {
    # Labels that also appear as emotion columns of the artwork table.
    'happiness': 'Q1',
    'surprise': 'Q1',
    'sadness': 'Q3',
    'disagreeableness': 'Q2',
    'fear': 'Q2',
    'trust': 'Q1',
    'anticipation': 'Q2',
    'humility': 'Q4',
    'shame': 'Q3',
    'arrogance': 'Q2',
    'love': 'Q1',
    'disgust': 'Q2',
    'optimism': 'Q1',
    'anger': 'Q2',
    'pessimism': 'Q3',
    'neutral': 'Q4',
    'gratitude': 'Q1',
    'agreeableness': 'Q4',
    'shyness': 'Q4',
    # Remaining labels: presumably the lower-cased `Emotion` values of the
    # soundtrack tracklists - verify against those CSV files.
    'happy': 'Q1',
    'sad': 'Q3',
    'tender': 'Q1',
    'high val.': 'Q1',
    'low val.': 'Q2',
    'high ener.': 'Q1',
    'low ener.': 'Q4',
    'high tens.': 'Q2',
    'low tens.': 'Q4',
    'anger high': 'Q2',
    'anger mod.': 'Q3',
    'fear high': 'Q2',
    'fear mod.': 'Q3',
    'happy high': 'Q1',
    'happy mod.': 'Q2',
    'sad high': 'Q3',
    'sad mod.': 'Q4',
    'tender high': 'Q4',
    'tender mod.': 'Q1',
    'valence pos. high': 'Q1',
    'valence pos. mod.': 'Q1',
    'valence neg. mod.': 'Q2',
    'valence neg. high': 'Q2',
    'energy pos. high': 'Q1',
    'energy pos. mod.': 'Q1',
    'energy neg. mod.': 'Q2',
    'energy neg. high': 'Q2',
    'tension pos. high': 'Q2',
    'tension pos. mod.': 'Q1',
    'tension neg. mod.': 'Q3',
    'tension neg. high': 'Q4',
}

In [None]:
def set_quadrant(df):
    """Return a copy of ``df`` with a ``quadrant`` column derived from ``emotion``.

    Each value of the ``emotion`` column is looked up in the module-level
    ``emotion_quadrants`` table. The input frame is left untouched.

    Args:
        df: DataFrame with an ``emotion`` column.

    Returns:
        A new DataFrame with the same rows and a ``quadrant`` column (added at
        the end, or overwritten if it already exists).

    Raises:
        KeyError: If an ``emotion`` value has no entry in ``emotion_quadrants``.
    """
    quadrants = df['emotion'].map(emotion_quadrants)

    # Series.map yields NaN for labels missing from the table; raise KeyError
    # instead so an unmapped emotion cannot slip through unnoticed.
    unmapped = df.loc[quadrants.isna(), 'emotion']
    if not unmapped.empty:
        raise KeyError(unmapped.iloc[0])

    return df.assign(quadrant=quadrants)

# Attach the quadrant to every artwork and keep only the columns used below.
# .copy() makes the result an independent frame: later cells add columns to
# it in place, which on a plain column selection raises SettingWithCopyWarning.
arts_emotions = set_quadrant(arts_emotions)[['id', 'quadrant']].copy()

### Music Data Preparation

In [None]:
# Load both soundtrack tracklists and record which set each row came from.
osts_set_1 = pd.read_csv('../data/music/OSTs/set1_tracklist.csv', index_col=0).assign(set=1)
osts_set_2 = pd.read_csv('../data/music/OSTs/set2_tracklist.csv', index_col=0).assign(set=2)

# rename()/assign() build new frames, so nothing is written into a column
# selection of the concatenated table (avoids SettingWithCopyWarning).
osts = pd.concat([osts_set_1, osts_set_2])[['Emotion', 'set']].rename(columns={'Emotion': 'emotion'})
# Labels are lower-cased before the quadrant lookup.
osts = osts.assign(emotion=osts.emotion.str.lower())
osts = set_quadrant(osts)

osts.head()

In [None]:
# Load the song annotations and relabel the columns as song / quadrant.
songs = pd.read_csv('../data/music/others/annotations.csv').set_axis(['song', 'quadrant'], axis=1)
songs.head()

In [None]:
# Build the audio file path of every soundtrack row and every annotated song.
# Soundtrack files are named after the row index formatted with '{0:0=3d}'.
ost_file_stems = osts.index.map(lambda track_id: '{0:0=3d}'.format(track_id))
osts['music'] = '../data/music/OSTs/Set' + osts.set.astype(str) + '/' + ost_file_stems + '.mp3'
songs['music'] = '../data/music/others/' + songs.quadrant + '/' + songs.song + '.mp3'

# Reduce both tables to (music, quadrant) and stack them under a fresh index.
osts, songs = osts[['music', 'quadrant']], songs[['music', 'quadrant']]
music_data = pd.concat([osts, songs], ignore_index=True)

### Map Music to Images

In [None]:
def set_related_music(df, already_set_df = None, use_second_set = False):
    """Give every row of ``df`` a random track from ``music_data`` with the same quadrant.

    ``df`` is modified in place: a ``music`` column is created (initially ``''``)
    and filled quadrant by quadrant, in the order Q1, Q2, Q3, Q4. Tracks are
    drawn with ``np.random.choice``; sampling is with replacement only when a
    quadrant has fewer tracks than rows. Rows whose quadrant is not one of
    Q1..Q4 keep the empty string.

    Args:
        df: DataFrame with a ``quadrant`` column.
        already_set_df: Not used by this function.
        use_second_set: Not used by this function.

    Returns:
        The same ``df`` object, now carrying the ``music`` column.
    """
    df['music'] = ''

    for quadrant in ('Q1', 'Q2', 'Q3', 'Q4'):
        candidates = music_data.loc[music_data.quadrant == quadrant, 'music'].values
        in_quadrant = df.quadrant == quadrant
        n_rows = int(in_quadrant.sum())

        df.loc[in_quadrant, 'music'] = np.random.choice(
            candidates,
            n_rows,
            replace=len(candidates) < n_rows,
        )

    return df

# Seed NumPy's global generator so the random image-to-music pairing is the
# same on every Restart & Run All.
np.random.seed(42)
data = set_related_music(arts_emotions)
data['img'] = '../data/image/wikiart/' + data.id + '.jpg'
# Keyword form instead of the deprecated positional axis argument; reassigning
# (rather than inplace=True) leaves `arts_emotions` with its `id` column, so
# this cell can be re-run.
data = data.drop(columns=['id'])
data.head()

### TF Dataset

In [None]:
# Fixed seed so the split is reproducible across runs.
SPLIT_SEED = 42

# Hold out 8% of the (music, image) pairs, stratified by quadrant.
x_train, x_test, y_train, y_test = train_test_split(
    data.music, data.img,
    test_size=0.08, stratify=data.quadrant, shuffle=True, random_state=SPLIT_SEED,
)
# 15% of the held-out pairs become the validation set; the rest stays as test.
x_test, x_val, y_test, y_val = train_test_split(
    x_test, y_test, test_size=0.15, random_state=SPLIT_SEED,
)

# reset_index() returns a new Series, so the result has to be assigned back;
# x and y of a split are reset together and therefore stay paired row by row.
x_train, y_train = x_train.reset_index(drop=True), y_train.reset_index(drop=True)
x_test, y_test = x_test.reset_index(drop=True), y_test.reset_index(drop=True)
x_val, y_val = x_val.reset_index(drop=True), y_val.reset_index(drop=True)

train_pairs = pd.DataFrame({'x': x_train, 'y': y_train})
test_pairs = pd.DataFrame({'x': x_test, 'y': y_test})
val_pairs = pd.DataFrame({'x': x_val, 'y': y_val})

# Each dataset element is one row of the frame: [music_path, image_path].
train_ds = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(train_pairs))
test_ds = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(test_pairs))
val_ds = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(val_pairs))

In [None]:
def _bytes_feature(value, raw_string = False):
    """Wrap one bytes value in a ``tf.train.Feature`` holding a one-item ``BytesList``.

    Args:
        value: The payload. Unless ``raw_string`` is set, ``value.numpy()`` is
            called to obtain the bytes.
        raw_string: When true, ``value`` is stored as-is.

    Returns:
        The ``tf.train.Feature`` built around the payload.
    """
    if raw_string:
        payload = value
    else:
        payload = value.numpy()

    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)

def normalized_wt_downsampling(x, wavelet, level):
    """Return the first coefficient array of a 2-D wavelet decomposition, scaled by its peak.

    ``pywt.wavedec2`` is run on ``x`` in 'periodization' mode and the first
    element of its result is divided by its largest absolute value.

    Args:
        x: 2-D input passed to ``pywt.wavedec2``.
        wavelet: Wavelet passed to ``pywt.wavedec2``.
        level: Decomposition level passed to ``pywt.wavedec2``.

    Returns:
        The scaled coefficient array. An all-zero array is returned unchanged.
    """
    approximation = pywt.wavedec2(x, wavelet, mode='periodization', level=level)[0]

    peak = np.abs(approximation).max()
    # An all-zero band has peak 0; dividing by it would fill the result with NaN.
    if peak == 0:
        return approximation

    return approximation / peak

def per_channel_wd(img, level=1, wavelet='haar'):
    """Apply ``normalized_wt_downsampling`` to each of the three channels of ``img``.

    Args:
        img: Image tensor that unstacks into exactly three channels along axis 2.
        level: Decomposition level forwarded to ``normalized_wt_downsampling``.
        wavelet: Wavelet name forwarded to ``normalized_wt_downsampling``.

    Returns:
        The three processed channels stacked back along axis 2.
    """
    red, green, blue = tf.unstack(img, axis=2)
    downsampled = [
        normalized_wt_downsampling(channel, wavelet, level)
        for channel in (red, green, blue)
    ]

    return tf.stack(downsampled, axis=2)

class DatasetGenerator:
    """Turns one (music, image) pair into a serialized ``tf.train.Example``.

    Call :meth:`process` with a music path and an image path, then
    :meth:`serialize_information` to get the record. The results of the last
    processed pair are kept as attributes on the instance.
    """

    # Length, in seconds, of the audio excerpt taken from every track.
    EXCERPT_SECONDS = 10
    # Number of style blocks written to each record (block1..block4).
    NUM_STYLE_BLOCKS = 4

    def __init__(self, max_dim_size = 128, vgg_input_max_size = 512):
        """Store the size settings and create the ``WaveletAE`` model.

        Args:
            max_dim_size: Stored on the instance; only the disabled wavelet
                downsampling path in ``load_image`` refers to it.
            vgg_input_max_size: Length in pixels of the longer image side
                after resizing.
        """
        self.max_dim_size = max_dim_size
        self.vgg_input_max_size = vgg_input_max_size
        self.wavelet_ae = WaveletAE()

    @staticmethod
    def _fit_longest_side(height, width, max_size):
        """Return ``[new_height, new_width]`` whose longer side equals ``max_size``.

        The shorter side follows the aspect ratio, is truncated to an int and
        clamped to at least 1 so extreme aspect ratios never give a zero size.
        """
        ar = height / width
        if ar > 1:
            return [max_size, max(1, int(max_size / ar))]
        return [max(1, int(ar * max_size)), max_size]

    @staticmethod
    def _random_window(samples, length):
        """Return exactly ``length`` samples from the 1-D array ``samples``.

        Longer inputs yield a window at a random offset (``np.random.randint``);
        inputs of ``length`` or fewer samples are zero-padded at the end
        instead of raising in ``randint``.
        """
        surplus = samples.shape[0] - length
        if surplus > 0:
            start = np.random.randint(surplus)
            return samples[start : start + length]
        return np.pad(samples, (0, -surplus))

    def load_image(self, img_path):
        """Read, decode and resize the image at ``img_path``.

        Sets ``self.img_path``, ``self.img_bytes`` and ``self.img_raw``; the
        latter is the float32 RGB image resized so that its longer side is
        ``self.vgg_input_max_size``.
        """
        self.img_path = img_path
        self.img_bytes = tf.io.read_file(img_path)
        self.img_raw = tf.image.decode_image(self.img_bytes, channels = 3, dtype=tf.float32)

        # Disabled alternative: wavelet-based downsampling to max_dim_size.
#         dwt_level = tf.experimental.numpy.log2(tf.reduce_max(self.img_raw.shape) / (self.max_dim_size + ((tf.reduce_max(self.img_raw.shape) - tf.reduce_min(self.img_raw.shape)) / 2)))
#         dwt_level = tf.round(dwt_level)
#         dwt_level = tf.cast(dwt_level, tf.uint8)

#         self.img_resized = per_channel_wd(self.img_raw, dwt_level)
#         self.img_resized = tfa.image.gaussian_filter2d(self.img_resized, (6, 6), sigma=6e-1)
#         self.img_resized = tf.image.resize_with_crop_or_pad(self.img_resized, self.max_dim_size, self.max_dim_size)

        size = self._fit_longest_side(
            self.img_raw.shape[0], self.img_raw.shape[1], self.vgg_input_max_size
        )
        self.img_raw = tf.image.resize(self.img_raw, size)

    def load_music(self, music_path):
        """Load a track, take a fixed-length excerpt and store its features in ``self.spec``.

        The excerpt is ``EXCERPT_SECONDS`` long. Tracks longer than that are
        cropped at a random offset; shorter ones are zero-padded.
        """
        audio_data, sr = lb.load(music_path)
        train_len = self.EXCERPT_SECONDS * sr
        # assumes lb.load returned a 1-D sample array - TODO confirm
        audio_data = self._random_window(audio_data, train_len)

        self.spec = preprocess_sound(audio_data, sr)

    def get_style_transformations(self):
        """Compute the style statistics of the loaded image.

        Fills ``self.style_ede`` and ``self.style_means`` from
        ``WaveletAE.get_style_correlations``. Every entry of ``style_ede`` is
        passed through ``fill_triangular_inverse(upper=True)``; both lists are
        cast to float16.
        """
        image_batch = tf.expand_dims(self.img_raw, 0)

        # NOTE(review): the result of get_features is never used. The call is
        # kept in case it has side effects on the model - confirm and remove.
        feat, _ = self.wavelet_ae.get_features(image_batch)
        self.style_ede, self.style_means = self.wavelet_ae.get_style_correlations(image_batch, ede=False)

        for i, correlation in enumerate(self.style_ede):
            self.style_ede[i] = tf.cast(tfp.math.fill_triangular_inverse(correlation, upper=True), tf.float16)

        for i, mean in enumerate(self.style_means):
            self.style_means[i] = tf.cast(mean, tf.float16)

    # Backward-compatible alias for the original (misspelled) method name.
    get_style_transormations = get_style_transformations

    def process(self, music, img):
        """Load the image and the music, then compute the image's style statistics."""
        self.load_image(img)
        self.load_music(music)
        self.get_style_transormations()

    def serialize_information(self):
        """Serialize the last processed pair as a ``tf.train.Example`` byte string.

        The record holds ``img_path``, ``block<k>_feat`` / ``block<k>_mean``
        for k = 1..4 (first batch item of each style tensor) and ``music_spec``.
        """
#         img_resized = tf.cast(self.img_resized * 255, tf.uint8)
#         img_resized = tf.image.encode_jpeg(img_resized)

        feature = {
            'img_path': _bytes_feature(self.img_path.encode('utf-8'), raw_string=True),
#             'resized_image': _bytes_feature(img_resized),
        }

        for block in range(self.NUM_STYLE_BLOCKS):
            feature[f'block{block + 1}_feat'] = _bytes_feature(tf.io.serialize_tensor(self.style_ede[block][0]))
            feature[f'block{block + 1}_mean'] = _bytes_feature(tf.io.serialize_tensor(self.style_means[block][0]))

        feature['music_spec'] = _bytes_feature(tf.io.serialize_tensor(self.spec))

        return Example(features=Features(feature=feature)).SerializeToString()

# NOTE(review): the next cell constructs DatasetGenerator() again and rebinds
# `datagen`, so this instance is replaced there - one of the two is redundant.
datagen = DatasetGenerator()

In [None]:
# Output location: <parent of the working directory>/data/tfrecords-whithout-ede
BASE_DATA_DIR = Path.cwd().parent / "data" / "tfrecords-whithout-ede"

# Every record file is written GZIP-compressed.
tf_record_options = tf.io.TFRecordOptions(compression_type="GZIP")

datagen = DatasetGenerator()

def write_as_TFRecords(dataset, target_dir, batch_size, datagen):
    """Write ``dataset`` to ``target_dir`` as numbered TFRecord files.

    The dataset is batched and each batch goes to its own file
    ``<target_dir>/<part_id>.tfrecord``, written with the module-level
    ``tf_record_options``. Pairs whose music or image file does not exist on
    disk are skipped.

    Args:
        dataset: ``tf.data.Dataset`` whose elements are (music path, image
            path) pairs of UTF-8 encoded strings.
        target_dir: Directory that receives the files; created if missing.
        batch_size: Number of pairs written to each file.
        datagen: Object providing ``process(music_path, image_path)`` and
            ``serialize_information()``.
    """
    target_dir = Path(target_dir)
    # Create the output directory up front so the writer does not fail on a
    # missing path.
    target_dir.mkdir(parents=True, exist_ok=True)

    for part_id, batch in enumerate(dataset.batch(batch_size)):
        filename = str(target_dir / f"{part_id}.tfrecord")
        # The context manager closes the writer; no explicit close() is needed.
        with tf.io.TFRecordWriter(filename, options = tf_record_options) as writer:
            for music, image in tqdm(batch):
                music_path = music.numpy().decode("utf-8")
                image_path = image.numpy().decode("utf-8")
                if not Path(music_path).is_file() or not Path(image_path).is_file():
                    continue
                datagen.process(music_path, image_path)
                writer.write(datagen.serialize_information())
            
# Train is written in parts of 1024 pairs; test and val use their full length
# as batch size, so each of them ends up in a single file.
for split_name, split_ds, split_batch_size in (
    ("train", train_ds, 1024),
    ("test", test_ds, len(list(test_ds))),
    ("val", val_ds, len(list(val_ds))),
):
    write_as_TFRecords(
        dataset = split_ds,
        target_dir = BASE_DATA_DIR / split_name,
        batch_size = split_batch_size,
        datagen = datagen,
    )