Goal of this script:
Find mean vectors in the latent space, for each modality of our variables (age, gender)

in order to make some modifications to data, and challenge our classifiers

# 0. Path and import

In [5]:
import sys, os

#path = "drive/MyDrive/projet_digitale/Spectrogram_Reconstruction_Model"

#os.chdir(path)

!ls

sys.path.append(os.getcwd() + os.sep + 'src/')

__pycache__			models.py
data				results
latent_space_exploration.ipynb	spectrogram_classifier.ipynb
modality_shift.ipynb		src
models


In [6]:
import numpy as np

from scipy.io.wavfile import write as write_waveform
from collections import OrderedDict
import matplotlib.pyplot as plt
import pickle
import pandas as pd
import pathlib
import glob2

import torch
import torch.nn.functional as F
import torch.nn as nn


from spectrogram_stream import SpectrogramStream
from autoencoders import ConvolutionalAutoencoder
from encoders import ConvolutionalEncoder
from bottlenecks import ConvolutionalBottleneck
from reconstructors import ConvolutionalDecoder
from visualization import spectrogram_to_waveform, compute_reconstruction_plot

# 1. Load encoded data

In [3]:
#save/load encoded data
save = False
load = True
if save :
  pickle.dump(projection, open("data/" + "projection.pickle", "wb"))
  pickle.dump(sound_id, open("data/" + "sound_id.pickle", "wb"))

if load:
  projection = pickle.load(open("data/" + "projection.pickle", "rb"))
  sound_id = pickle.load(open("data/" + "sound_id.pickle", "rb"))
    
projection = np.array(projection)
projection.shape

# format in dataframe with ids
representation = pd.DataFrame(projection, index=sound_id)

In [7]:
#load labels file
label_df = pd.read_csv("data/labels.tsv", sep='\t').drop(columns="Unnamed: 0").set_index("sound_id")

In [8]:
# load model
def find_last_checkpoint(models_path, experiment_name):
    checkpoint_path_pattern = os.path.join(models_path, experiment_name + '_checkpoint_*.pth')
    print(f'Looking for checkpoints at {checkpoint_path_pattern}')
    checkpoints = glob2.glob(checkpoint_path_pattern)
    checkpoints = sorted(checkpoints, key=lambda x: int(os.path.splitext(x)[0].split('_')[-1]))
    print(f'Loading checkpoint {checkpoints[-1]}...')
    last_checkpoint = torch.load(checkpoints[-1], map_location=device)
    print(f'Checkpoint {checkpoints[-1]} loaded')
    return last_checkpoint

data_path = 'data'
models_path = 'models'
experiment_name = 'dataset2filtered_b64_baseline_larger_l1'
results_path = os.path.join('results', experiment_name)
frame_step = 46
n_iter = 300
sampling_rate = 16000
n_images = 10

pathlib.Path(results_path).mkdir(parents=True, exist_ok=True)

# Device
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

# Select available device
device = torch.device("cpu") #torch.device("cuda:0")#torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')


# Build and load model
print('Building model...')
model = ConvolutionalAutoencoder(
    encoder=ConvolutionalEncoder(),
    bottleneck=ConvolutionalBottleneck(),
    reconstructor=ConvolutionalDecoder()
).to(device)

print('Searching last checkpoint...')
checkpoint = find_last_checkpoint(models_path, experiment_name)

print('Loading checkpoint...')
new_state_dict = OrderedDict()
for mod in ['model', 'optimizer']:
    new_state_dict[mod] = {}
    for k, v in checkpoint[mod if mod in checkpoint else mod + '_exception'].items():
        new_state_dict[mod][k.replace('module.', '')] = v
checkpoint = new_state_dict

model.load_state_dict(checkpoint['model'])

Device: cpu
Building model...
Searching last checkpoint...
Looking for checkpoints at models/dataset2filtered_b64_baseline_larger_l1_checkpoint_*.pth
Loading checkpoint models/dataset2filtered_b64_baseline_larger_l1_checkpoint_126633.pth...
Checkpoint models/dataset2filtered_b64_baseline_larger_l1_checkpoint_126633.pth loaded
Loading checkpoint...


<All keys matched successfully>

# 2. Shift representation with mean modality vectors

In [4]:
def compute_mean_vector(var = "gender",mod = 1):
    """
    function to compute mean vector for chosen variable and modality
    """
    # select ids for which we have variable = modality
    ids = label_df.loc[label_df[var] == mod].index
    
    # select latent representation for selected ids
    latent_df = representation.loc[representation.index.isin(ids)]
    
    return np.array(latent_df.apply(np.mean,axis=0))

In [59]:
# compute and save mean vector for age and gender
for var in ["gender", "age"]:
    for mod in [0,1,2]:
        
        if var == "gender" and mod == 2:
            pass
        
        else:
            mean = compute_mean_vector(var, mod)
            pickle.dump(mean, open("data/" + "mean_" + var + str(mod) + ".pickle", "wb"))            

In [60]:
# shift our data & save it in "projection_shift_age01.pickle"
for var in ["gender", "age"]:
    for start_mod in [0,1,2]:
        for end_mod in [0,1,2]:

            if (var == "gender" and (start_mod == 2 or end_mod == 2)) or (start_mod == end_mod):
                pass

            else:
                # read file with mean vectors for the 2 modalities
                file0,file1 = "data/" + "mean_" + var + str(start_mod) + ".pickle", "data/" + "mean_" + var + str(end_mod) + ".pickle"
                start_vector, end_vector = pickle.load( open(file0, "rb")), pickle.load( open(file1, "rb"))

                # Shift our data in latent space 
                #(for convenience, we shift all our data using vectorial difference between 2 modalities)
                file = "data/" + "projection_shift_" + var + str(start_mod) + str(end_mod) + ".pickle"
                modified_representation = representation + (end_vector - start_vector)
                pickle.dump(modified_representation, open(file, "wb"))

# 3. Reconstruct spectrograms from modified representations

In [28]:
def reconstruct(var="age",start_mod=0, end_mod = 1,N=50):
    """
    load shifted data, reconstruct spectrograms and save results in pickle file
    """
    # load data
    file = "data/" + "projection_shift_" + var + str(start_mod) + str(end_mod) + ".pickle"
    representation = pickle.load(open(file,"rb"))
    
    # cast the latent representation into correct shape for reconstruction, and cast in torch tensor
    representation = torch.Tensor(representation.to_numpy().reshape(-1,1,24,11))
    representation = representation.to(device)
    #print(representation.shape)
    
    # reconstruct
    reconstruction_list= []
    for n in range(0, representation.shape[0]//N +1 ):
        with torch.no_grad():
            reconstruction = model.reconstructor.forward(representation[n*N:(n+1)*N])

        reconstruction_list.append(reconstruction)#reconstruction.to("cpu"))
        print(n)
        #del reconstruction
        if n%N == 0:
            print(n)
        #    torch.cuda.empty_cache()
    
    # save
    file_out = "data/" + "reconstruction_shifted_" + var + str(start_mod) + str(end_mod) + ".pickle"
    pickle.dump(torch.cat(tuple(reconstruction_list)).numpy(), open(file_out,"wb"))
    
    return "Done"

In [29]:
for var in ["gender", "age"]:
    for start_mod in [0,1,2]:
        for end_mod in [0,1,2]:

            if (var == "gender" and (start_mod == 2 or end_mod == 2)) or (start_mod == end_mod):
                pass

            else:
                reconstruct(var=var,start_mod=start_mod, end_mod=end_mod)

0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138


KeyboardInterrupt: 