# Create Centered Images

## Extract the dataset archive

In [None]:
import tarfile

already_extracted = False  # set to True once ./data has been populated to skip re-extraction

if not already_extracted:

    with tarfile.open('./data/O3a_SL_aux.tar.gz') as tar:
        # Prefer the 'data' extraction filter (PEP 706; Python 3.12+ and backported
        # security releases) so archive members cannot be written outside ./data
        # through absolute paths, '..' components or links. Older interpreters
        # without filter support fall back to the previous unfiltered behaviour.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path='./data', filter='data')
        else:
            tar.extractall(path='./data')

## Create the list of available files

In [None]:
import os

directory = './data/O3a_scatt_light_aux/'
os.makedirs(directory, exist_ok=True)

# List the HDF5 files of the extracted dataset. The list is sorted because
# os.listdir returns entries in arbitrary order, while the list index is later
# used to pick files and to name the generated images (ch_img{i}.png).
files_paths = sorted(
    os.path.join(directory, name)
    for name in os.listdir(directory)
    if name.endswith('.h5')
)

## Plot the q-scan

In [None]:
import h5py as h5 
from gwpy.timeseries import TimeSeries
import matplotlib.pyplot as plt
import numpy as np

### Gather file
file_dir = files_paths[0]
fid = h5.File(file_dir, 'r')

### Gather data of all channels from file
grp_name = file_dir[:-3].split('/')[-1]
group_data = fid[grp_name]

### Gather data of strain channel
strain_channel = group_data['V1:Hrec_hoft_16384Hz']

### Create TimeSeries object
t = TimeSeries(strain_channel[()])
t.t0 = strain_channel.attrs['t0']
t.dt = 1.0 / strain_channel.attrs['sample_rate']

### Create q-scan in the 2 seconds interval
t0_scan = (t.times[0] + (t.times[-1] - t.times[0]) / 2).value
dt_scan = 0.5
hq = t.q_transform(outseg=(t0_scan - dt_scan, t0_scan + dt_scan), frange=(10, 100))

### Create figure and mesh with this period scan
fig = plt.figure(frameon=False, figsize=(3.21, 2.56), dpi=100)  # figsize values are set accordingly to obtain a 256x256 image
ax = fig.add_axes([0, 0, 1, 1])
mesh = ax.pcolormesh(hq)
plt.close(fig)

### Get maximum energy value from the mesh
mesh_data = mesh.get_array()
mesh_max = np.amax(mesh_data)
indices = np.where(mesh_data == mesh_max)
index = indices[0][0] % hq.times.value.shape[0]
time_val = hq.times.value[index]
dt_max = 0.25 # Interval of time around maximum value

### Create q-scan in the proper time period
hq = t.q_transform(outseg=(time_val - dt_max, time_val + dt_max), frange=(10, 100))

### Create figure with proper size and no frame
fig = plt.figure(frameon=False, figsize=(3.21, 2.56), dpi=100)  # figsize values are set accordingly to obtain a 256x256 image
ax = fig.add_axes([0, 0, 1, 1])
mesh = ax.pcolormesh(hq)
mesh.set_clim(0, 26)
cbar = fig.colorbar(mesh, label="Normalised energy")
ax.grid(False)
ax.axis('off')
ax.set_yscale('log')
cbar.remove()

## Plot q-scan for all the auxiliary channels

In [None]:
figures = []

# Overview: one wide q-scan per channel of the file opened in the previous cell.
for key in group_data.keys():

    dataset = group_data[key]

    # Create TimeSeries
    t = TimeSeries(dataset[()])
    t.t0 = dataset.attrs['t0']
    t.dt = 1.0 / dataset.attrs['sample_rate']

    # Calculate q-transform over a 4 s window (t0_scan ± dt_scan) around the middle sample
    t0_scan = t.times[t.times.shape[0]//2].value
    dt_scan = 2.0
    hq = t.q_transform(outseg=(t0_scan - dt_scan, t0_scan + dt_scan), frange=(10, 100))

    # Create Figure.
    # dpi must stay at 100: 3.21 x 2.56 in then gives a 321 x 256 px canvas from
    # which figures_to_tensors keeps a 256 x 256 crop. At dpi=200 the canvas
    # would be 642 x 512 px and the crop would keep only its top-left quarter.
    fig = plt.figure(frameon=False, figsize=(3.21, 2.56), dpi=100)
    ax = fig.add_axes([0, 0, 1, 1])
    mesh = ax.pcolormesh(hq)
    mesh.set_clim(0, 30)
    cbar = fig.colorbar(mesh, label="Normalised energy")
    ax.grid(False)
    ax.axis('off')
    ax.set_yscale('log')
    cbar.remove()
    # Closing only detaches the figure from pyplot; it is still rendered later.
    plt.close(fig)

    figures.append(fig)

## Concatenate the images side by side into a single large one

In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image

def figures_to_tensors(figs):
    """Rasterise matplotlib figures and convert each one to a 3-channel tensor.

    Every figure is drawn, its RGBA pixel buffer is read back, and only the
    top-left 256 x 256 pixels of the RGB channels are kept before
    ``transforms.ToTensor`` turns the crop into a tensor.

    Args:
        figs: iterable of matplotlib figures (they may already be closed).

    Returns:
        list of tensors, one per figure, in the same order as ``figs``.
    """
    to_tensor = transforms.ToTensor()

    def _rasterise(figure):
        figure.canvas.draw()
        rgba = np.array(figure.canvas.renderer.buffer_rgba())
        # Figures in this notebook are 321 x 256 px (3.21 x 2.56 in at 100 dpi),
        # so this keeps the left 256 columns and drops the alpha channel.
        # NOTE(review): presumably this is the region the plot occupies once the
        # colorbar has been added/removed; whether that still holds depends on the
        # installed matplotlib version - confirm.
        rgb_crop = rgba[:256, :256, :3]
        return to_tensor(Image.fromarray(rgb_crop))

    return [_rasterise(figure) for figure in figs]

# Rasterise each per-channel figure and lay the tiles out left to right by
# concatenating along the last (width) dimension of the tensors.
tens = figures_to_tensors(figures)
final_fig = torch.cat(tens, dim=-1)

transform = transforms.ToPILImage()
final_image = transform(final_fig)

final_image

## Put all the preceding steps into a function

In [None]:
def _peak_time(hq):
    """Return the ``hq.times`` value of the highest-energy pixel of q-transform *hq*.

    Axis 0 of the q-transform output is indexed by ``hq.times``. The peak is read
    from the array itself rather than from a matplotlib QuadMesh, whose
    ``get_array()`` layout (flat or 2-D) differs between matplotlib versions.
    """
    time_idx = np.unravel_index(np.argmax(hq.value), hq.shape)[0]
    return hq.times.value[time_idx]


def _render_qscan(hq, clim):
    """Draw *hq* on a frameless, axis-free 3.21 x 2.56 in figure at 100 dpi; return it closed."""
    # 3.21 x 2.56 in at 100 dpi -> 321 x 256 px canvas; figures_to_tensors keeps a 256 x 256 crop.
    fig = plt.figure(frameon=False, figsize=(3.21, 2.56), dpi=100)
    ax = fig.add_axes([0, 0, 1, 1])
    mesh = ax.pcolormesh(hq)
    mesh.set_clim(*clim)
    # NOTE(review): the colorbar is added and immediately removed, as in the
    # original cells; its only lasting effect may be on the axes position,
    # which depends on the matplotlib version - confirm it is still needed.
    cbar = fig.colorbar(mesh, label="Normalised energy")
    ax.grid(False)
    ax.axis('off')
    ax.set_yscale('log')
    cbar.remove()
    # Closing only detaches the figure from pyplot; it can still be drawn later.
    plt.close(fig)
    return fig


def plots(file_number, *, dt_scan=0.5, dt_max=0.25, frange=(10, 100), clim=(0, 30)):
    """Build one horizontal strip of peak-centred q-scans for every channel of a file.

    For each dataset in the HDF5 group named after the file:
    1. A coarse q-transform over ``±dt_scan`` around the middle sample locates
       the time of maximum energy.
    2. A second q-transform over ``±dt_max`` around that time is rendered.
    The renders are rasterised with ``figures_to_tensors`` and concatenated left
    to right.

    Args:
        file_number: index into the global ``files_paths`` list.
        dt_scan: half-width of the coarse search window (same units as the series times).
        dt_max: half-width of the window kept around the peak.
        frange: frequency range passed to ``q_transform``.
        clim: colour limits applied to the normalised energy.

    Returns:
        PIL image with one tile per channel, in the group's key order.
    """
    file_dir = files_paths[file_number]
    # The HDF5 group is named after the file without its extension.
    grp_name = os.path.splitext(os.path.basename(file_dir))[0]

    figures = []
    # All data is copied into memory (dataset[()]) inside the loop, so the file
    # can be closed deterministically before rasterisation.
    with h5.File(file_dir, 'r') as fid:
        for dataset in fid[grp_name].values():
            # Create TimeSeries
            t = TimeSeries(dataset[()])
            t.t0 = dataset.attrs['t0']
            t.dt = 1.0 / dataset.attrs['sample_rate']

            # Coarse scan around the middle sample, used only to locate the peak
            t0_scan = t.times[t.times.shape[0] // 2].value
            hq = t.q_transform(outseg=(t0_scan - dt_scan, t0_scan + dt_scan), frange=frange)
            time_val = _peak_time(hq)

            # Final scan centred on the peak
            hq = t.q_transform(outseg=(time_val - dt_max, time_val + dt_max), frange=frange)
            figures.append(_render_qscan(hq, clim))

    tens = figures_to_tensors(figures)
    final_fig = torch.cat(tens, -1)  # tiles side by side along the width
    return transforms.ToPILImage()(final_fig)

## Save the images locally

In [None]:
savedir = './all_channels_images/'

if not os.path.exists(savedir):
    os.makedirs(savedir)

already_done = False  # set to True to skip regenerating images already on disk

# One strip image per HDF5 file, named after its index in files_paths.
if not already_done:
    for i, _ in enumerate(files_paths):
        final_image = plots(i)
        final_image.save(f'{savedir}ch_img{i}.png')

## Merge some images into a single big one

In [None]:
images_list = [savedir + x for x in os.listdir(savedir) if x.endswith('.png')]
# Order numerically by the index in 'ch_img<i>.png' (plain string sorting would
# put ch_img10 before ch_img2).
images_list.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0][len('ch_img'):]))
### How many files we want to include in the same picture?
images_list = images_list[:20]

imgs = [Image.open(p) for p in images_list]

min_img_width = min(img.width for img in imgs)

total_height = 0
for i, img in enumerate(imgs):
    # If the image is wider than the narrowest one, downscale it to that width
    # while preserving its aspect ratio.
    if img.width > min_img_width:
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
        imgs[i] = img.resize((min_img_width, int(img.height / img.width * min_img_width)), Image.LANCZOS)
    total_height += imgs[i].height

# Now that we know the total height of all of the resized images, we know the height of our final image
img_merge = Image.new(imgs[0].mode, (min_img_width, total_height))
y = 0
for img in imgs:
    # Stack the images vertically, top to bottom.
    img_merge.paste(img, (0, y))
    y += img.height
img_merge.save('multiple_files_img.png')

### Build Pix2Pix dataset

In [None]:
# Sort numerically by the index in 'ch_img<i>.png': os.listdir order is arbitrary,
# and the position in this list decides the train/test split and output names
# below, so sorting makes the datasets reproducible (img{i} <-> ch_img{i}.png).
ch_images = sorted(
    (savedir + x for x in os.listdir(savedir) if x.endswith('.png')),
    key=lambda p: int(os.path.splitext(os.path.basename(p))[0][len('ch_img'):]),
)

to_tensor = transforms.ToTensor()
to_image = transforms.ToPILImage()

In [None]:
num = 2  # The aux channel number that we want to consider: should go from 1 to 7

ch_img_dir = './aux_channel_two/'
train_size = 600  # the first train_size images go to train/, the rest to test/

# PIL's save() does not create parent directories, so create the split folders
# (and the root) up front; previously only the root was created.
for split in ('train', 'test'):
    os.makedirs(os.path.join(ch_img_dir, split), exist_ok=True)

for i, path in enumerate(ch_images):
    with Image.open(path) as img:
        tens = to_tensor(img)  # (C, H, W)

    # The source strips are 256-px-wide tiles laid out left to right; tile 0 is
    # paired with tile `num`.
    # NOTE(review): tile 0 is assumed to be the strain channel (tile order follows
    # the HDF5 group key order) - confirm.
    if tens.shape[2] < 256 * (num + 1):
        raise ValueError(f'{path} has no tile {num}: image is only {tens.shape[2]} px wide')
    first_tens = tens[:, :, :256]
    second_tens = tens[:, :, 256*num:256*(num+1)]

    # Side-by-side pair (tile 0 | tile num) as a single image
    final_tens = torch.cat((first_tens, second_tens), 2)

    image = to_image(final_tens)

    if i < train_size:
        image.save(ch_img_dir + f'train/img{i}.png')
    else:
        image.save(ch_img_dir + f'test/img{i-train_size}.png')

### Build CycleGAN dataset

In [None]:
num = 2  # The aux channel number that we want to consider

ch_img_dir = './aux_channel_two_cycle/'
train_size = 600  # the first train_size images go to train/, the rest to test/

# PIL's save() does not create parent directories, so create every split/domain
# folder (and the root) up front; previously only the root was created.
for split in ('train', 'test'):
    for domain in ('strain', 'aux'):
        os.makedirs(os.path.join(ch_img_dir, split, domain), exist_ok=True)

for i, path in enumerate(ch_images):
    with Image.open(path) as img:
        tens = to_tensor(img)  # (C, H, W)

    # The source strips are 256-px-wide tiles laid out left to right; tile 0 is
    # paired with tile `num`.
    # NOTE(review): tile 0 is assumed to be the strain channel (tile order follows
    # the HDF5 group key order) - confirm.
    if tens.shape[2] < 256 * (num + 1):
        raise ValueError(f'{path} has no tile {num}: image is only {tens.shape[2]} px wide')
    first_tens = tens[:, :, :256]
    second_tens = tens[:, :, 256*num:256*(num+1)]

    # The same pair in both orders: (tile 0 | tile num) and (tile num | tile 0).
    strain_tens = torch.cat((first_tens, second_tens), 2)
    aux_tens = torch.cat((second_tens, first_tens), 2)

    image1 = to_image(strain_tens)
    image2 = to_image(aux_tens)

    if i < train_size:
        image1.save(ch_img_dir + f'train/strain/img{i}.png')
        image2.save(ch_img_dir + f'train/aux/img{i}.png')
    else:
        image1.save(ch_img_dir + f'test/strain/img{i-train_size}.png')
        image2.save(ch_img_dir + f'test/aux/img{i-train_size}.png')