In [1]:
import numpy as np
from utils import generate_and_save_eeg_for_all_images
import torch
from mne.time_frequency import psd_array_multitaper
import os
import torch.nn.functional as F
from scipy.signal import spectrogram
import random
import matplotlib.pyplot as plt
from scipy.special import softmax

In [2]:
# Prefer the second GPU when present, fall back to the first one, then to CPU
# (hard-coding "cuda:1" fails on single-GPU machines).
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
fs = 250  # EEG sampling frequency passed to the PSD / spectrogram helpers (Hz)
selected_channel_idxes = [3, 4, 5] # 'O1', 'Oz', 'O2'

# Candidate sampling, roulette selection and crossover windows are all random:
# seed every RNG used in this notebook so a Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# sub-08: load the encoder-predicted EEG for the target image and compute its PSD.
model_path = '/mnt/repo0/kyw/close-loop/sub_encoder_alexnet/sub-08/model_state_dict.pt'
lowest_path = ['/mnt/repo0/kyw/images_set/test_images/00183_tick/tick_06s.jpg']
save_path = '/mnt/repo0/kyw/close-loop/modulation'
label = ['target']
# One-off generation of the target EEG; commented out, the .npy below is loaded instead.
# generate_and_save_eeg_for_all_images(model_path, lowest_path, save_path, device, label)
target_path = os.path.join(save_path, 'target_1.npy')
target_signal = np.load(target_path, allow_pickle=True)
# Keep only the selected EEG rows before estimating the spectrum.
selected_target_signal = target_signal[selected_channel_idxes, :]
target_psd, target_freqs = psd_array_multitaper(
    selected_target_signal, fs, adaptive=True, normalization='full', verbose=0
)  # psd(3, 126)
# Flatten into a single (1, n_rows * n_freqs) row so it can be compared with
# F.cosine_similarity against other flattened PSDs.
target_psd = torch.from_numpy(target_psd.flatten()).unsqueeze(0)

In [4]:
def get_image_pool(image_set_path):
    """Collect every image from a two-level ``<image_set_path>/<category>/<image>`` tree.

    Hidden entries (names starting with '.') are skipped at both levels. Top-level
    entries that are not directories (e.g. a stray README) are skipped too; before,
    they made ``os.listdir`` raise ``NotADirectoryError``.

    Args:
        image_set_path: root directory holding one sub-directory per category.

    Returns:
        ``(test_images_path, labels)``: parallel lists with the full image paths and
        the image file names without extension, in sorted directory order.
    """
    test_images_path = []
    labels = []
    for category in sorted(os.listdir(image_set_path)):
        category_path = os.path.join(image_set_path, category)
        if category.startswith('.') or not os.path.isdir(category_path):
            continue
        for image in sorted(os.listdir(category_path)):
            if image.startswith('.'):
                continue
            labels.append(os.path.splitext(image)[0])
            test_images_path.append(os.path.join(category_path, image))
    return test_images_path, labels
# Candidate pool = every test image except the target itself.
image_set_path = '/mnt/repo0/kyw/images_set/test_images'
test_images_path, _ = get_image_pool(image_set_path)
target_path = os.path.join(image_set_path, '00183_tick', 'tick_06s.jpg')
# list.remove raises ValueError if the target is missing, which surfaces a wrong path early.
test_images_path.remove(target_path)

In [5]:
def get_avg_signal(signal, name, save_path):
    """Average ``signal`` over its first axis, plot the mean trace and save it as a JPEG.

    Args:
        signal: array-like whose axis 0 is averaged out (e.g. channel rows of an EEG).
        name: prefix of the output file name.
        save_path: string prepended verbatim to ``f'{name}_avg_signal.jpg'`` (no
            separator is inserted, so pass a trailing '/' when it is a directory).

    Returns:
        ``np.mean(signal, axis=0)``.
    """
    average_signals = np.mean(signal, axis=0)
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(average_signals)
    ax.set_title('Average Signal')
    ax.set_xlabel('Time (samples)')
    ax.set_ylabel('Amplitude')
    ax.grid()
    fig.savefig(save_path + f'{name}_avg_signal.jpg')
    plt.show()
    # Close explicitly so repeated calls don't pile up open figures under
    # non-inline backends (where plt.show does not release them).
    plt.close(fig)
    return average_signals

def get_time_freq(average_signals, fs, name, save_path, nperseg=50):
    """Plot and save the spectrogram (in dB) of a 1-D signal.

    Args:
        average_signals: 1-D signal, e.g. the return value of ``get_avg_signal``.
        fs: sampling frequency forwarded to ``scipy.signal.spectrogram``; the y-axis
            is limited to ``fs / 2``.
        name: prefix of the output file name.
        save_path: string prepended verbatim to ``f'{name}_time_freq.jpg'``.
        nperseg: spectrogram segment length (default keeps the previous 50).
    """
    frequencies, times, Sxx = spectrogram(average_signals, fs, nperseg=nperseg)
    fig, ax = plt.subplots(figsize=(10, 6))
    # +1e-10 keeps log10 finite for empty bins.
    mesh = ax.pcolormesh(times, frequencies, 10 * np.log10(Sxx + 1e-10), shading='gouraud')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_xlabel('Time (s)')
    ax.set_title('Time-Frequency')
    fig.colorbar(mesh, ax=ax, label='Intensity (dB)')
    ax.set_ylim(0, fs / 2)
    fig.savefig(save_path + f'{name}_time_freq.jpg')
    plt.show()
    # Release the figure so loops over many signals don't leak memory.
    plt.close(fig)

In [6]:
def get_eeg_pool(gene_eeg):
    """List the generated EEG files in ``gene_eeg`` in sorted order.

    Hidden entries (e.g. ``.ipynb_checkpoints``) are skipped, the same way
    ``get_image_pool`` skips them. A stray hidden entry would otherwise shift every
    index and break the positional pairing with the image pool.

    Args:
        gene_eeg: directory holding one EEG file per test image.

    Returns:
        Sorted list of full paths.
    """
    return [
        os.path.join(gene_eeg, name)
        for name in sorted(os.listdir(gene_eeg))
        if not name.startswith('.')
    ]
# EEG pool, paired by position with test_images_path (target removed from both).
gene_eeg = '/mnt/repo0/kyw/close-loop/sub_encoder_alexnet_test/sub-08'
eeg_paths = get_eeg_pool(gene_eeg)
target_eeg = os.path.join(gene_eeg, '00183_tick_183.npy')
eeg_paths.remove(target_eeg)

In [7]:
def get_prob_random_sample(test_images_path, eeg_paths, fs, selected_channel_idxes, processed_paths,
                           num_samples=10, num_chosen=2):
    """Draw unseen candidates at random and softmax-sample distinct parents by PSD similarity.

    ``num_samples`` image paths that are not yet in ``processed_paths`` are drawn
    uniformly. For each one, the EEG at the same index in ``eeg_paths`` is reduced to
    the selected rows, turned into a flattened multitaper PSD and compared with the
    global ``target_psd`` by cosine similarity. ``num_chosen`` *distinct* candidates
    are then drawn with probability ``softmax(similarities)``.

    Args:
        test_images_path: full candidate image pool.
        eeg_paths: EEG file per image; assumed index-aligned with ``test_images_path``.
        fs: sampling frequency for the PSD.
        selected_channel_idxes: EEG rows to keep.
        processed_paths: set of already-drawn image paths; updated in place.
        num_samples: candidates to draw (capped at the number still available).
        num_chosen: distinct candidates to return.

    Returns:
        ``(chosen_similarities, chosen_image_paths, chosen_eeg_paths)``, each of length
        ``num_chosen``.

    Raises:
        ValueError: fewer than ``num_chosen`` unseen images remain.
    """
    available_paths = [path for path in test_images_path if path not in processed_paths]
    sample_size = min(num_samples, len(available_paths))
    if sample_size < num_chosen:
        raise ValueError(
            f'only {len(available_paths)} unseen images left, need at least {num_chosen}')
    sample_image_paths = random.sample(available_paths, sample_size)
    processed_paths.update(sample_image_paths)
    # NOTE(review): relies on eeg_paths[i] belonging to test_images_path[i] (both are
    # built from sorted listings) - confirm the two pools stay aligned.
    idxes = [test_images_path.index(path) for path in sample_image_paths]
    sample_eeg_paths = [eeg_paths[idx] for idx in idxes]
    similarities = []
    for sample_eeg_path in sample_eeg_paths:
        selected_eeg = np.load(sample_eeg_path, allow_pickle=True)[selected_channel_idxes, :]
        psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
        psd = torch.from_numpy(psd.flatten()).unsqueeze(0)
        similarities.append(F.cosine_similarity(target_psd, psd).item())
    probabilities = softmax(similarities)
    # replace=False: with replacement the same candidate could be returned twice,
    # and fusing an image with itself makes the crossover a no-op.
    chosen_indices = np.random.choice(len(probabilities), size=num_chosen, replace=False, p=probabilities)
    chosen_similarities = [similarities[idx] for idx in chosen_indices.tolist()]
    chosen_image_paths = [sample_image_paths[idx] for idx in chosen_indices.tolist()]
    chosen_eeg_paths = [sample_eeg_paths[idx] for idx in chosen_indices.tolist()]
    return chosen_similarities, chosen_image_paths, chosen_eeg_paths
# Fresh bookkeeping: every image drawn as a candidate is recorded here.
processed_paths = set()
chosen_similarities, chosen_image_paths, chosen_eeg_paths = get_prob_random_sample(
    test_images_path=test_images_path,
    eeg_paths=eeg_paths,
    fs=fs,
    selected_channel_idxes=selected_channel_idxes,
    processed_paths=processed_paths,
)
print(chosen_similarities, chosen_image_paths, chosen_eeg_paths)

[0.9245198083818775, 0.8688719949912524] ['/mnt/repo0/kyw/images_set/test_images/00107_lampshade/lampshade_05s.jpg', '/mnt/repo0/kyw/images_set/test_images/00189_tube_top/tube_top_11s.jpg'] ['/mnt/repo0/kyw/close-loop/sub_encoder_alexnet_test/sub-08/00107_lampshade_107.npy', '/mnt/repo0/kyw/close-loop/sub_encoder_alexnet_test/sub-08/00189_tube_top_189.npy']


In [None]:
def get_prob_sample(test_images_path, eeg_paths, fs, selected_channel_idxes, processed_paths):
    """Sample 10 unseen candidates, keep the best one and roulette-pick a second.

    Each candidate's EEG (``eeg_paths`` entry at the same index as its image) is
    reduced to the selected rows, turned into a flattened multitaper PSD and scored
    by cosine similarity against the global ``target_psd``. Scores go through a
    softmax. The top-1 candidate is always kept, and one more is drawn from the rest
    with renormalised probabilities. Intermediate values are printed for inspection.

    Args:
        test_images_path: full candidate image pool.
        eeg_paths: EEG file per image, in the same order as ``test_images_path``.
        fs: sampling frequency for the PSD.
        selected_channel_idxes: EEG rows to keep.
        processed_paths: set of already-used image paths; updated in place.

    Returns:
        ``(pairs, top_eeg_paths, processed_paths, top_image_paths, chosen_eeg_path,
        chosen_image_path)`` where ``pairs`` is
        ``[(top_pool_index, top_similarity), (chosen_pool_index, chosen_similarity)]``.
    """
    unused = [p for p in test_images_path if p not in processed_paths]
    drawn_images = random.sample(unused, 10)
    pool_idx = list(map(test_images_path.index, drawn_images))
    drawn_eegs = [eeg_paths[k] for k in pool_idx]
    print(pool_idx)
    print(drawn_images, drawn_eegs)

    scores = []
    for eeg_file in drawn_eegs:
        rows = np.load(eeg_file, allow_pickle=True)[selected_channel_idxes, :]
        spectrum, _ = psd_array_multitaper(rows, fs, adaptive=True, normalization='full', verbose=0)
        row_vec = torch.from_numpy(spectrum.flatten()).unsqueeze(0)
        scores.append(F.cosine_similarity(target_psd, row_vec).item())
    print(scores)

    probabilities = softmax(scores)
    print(probabilities)

    # Top-1 as a length-1 array so the "top_*" outputs stay lists.
    best = np.argsort(probabilities)[-1:]
    top_similarity = [scores[b] for b in best]
    top_original_indices = [pool_idx[b] for b in best]
    top_eeg_paths = [drawn_eegs[b] for b in best]
    top_image_paths = [drawn_images[b] for b in best]

    # Roulette draw among all candidates except the best one.
    rest = np.setdiff1d(np.arange(len(probabilities)), best)
    print(rest)
    rest_weights = probabilities[rest]
    rest_probs = rest_weights / rest_weights.sum()
    print(rest_probs)
    pick = np.random.choice(rest, p=rest_probs)
    chosen_original_index = pool_idx[pick]
    chosen_eeg_path = drawn_eegs[pick]
    chosen_image_path = drawn_images[pick]
    chosen_similarity = scores[pick]
    print('chosen_index:', pick)
    print("Chosen original index:", chosen_original_index)
    print("Chosen sample path:", chosen_eeg_path)
    print("Chosen image path:", chosen_image_path)
    print('chosen_similarity:', chosen_similarity)

    processed_paths.update(drawn_images)
    pairs = list(zip(top_original_indices, top_similarity))
    pairs.append((chosen_original_index, chosen_similarity))
    return pairs, top_eeg_paths, processed_paths, top_image_paths, chosen_eeg_path, chosen_image_path
# processed_paths = set()
# pair_cs, top_eeg_paths, processed_paths, top_image_paths, chosen_eeg_path, chosen_image_path= get_prob_sample(test_images_path, eeg_paths, fs, selected_channel_idxes, processed_paths)
# print(pair_cs)

In [None]:
# new_cs = []
# for top_sample_path in top_sample_paths:
#     eeg = np.load(top_sample_path, allow_pickle=True)
#     selected_eeg = eeg[selected_channel_idxes, :]
#     psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
#     psd = torch.from_numpy(psd.flatten())
#     psd = psd.unsqueeze(0)
#     sim = F.cosine_similarity(target_psd, psd)
#     new_cs.append(sim.item())
#     print(sim)
# print(new_cs)

# new_cs = []
# for _, similarity in pair_cs:
#     new_cs.append(float(similarity)) 
# print(new_cs)

In [12]:
from PIL import Image
from custom_pipeline_tjh import *
from diffusion_prior_tjh import *
import open_clip
from utils import Proj_img

# Build CLIP ViT-H-14 without downloading weights, then load them from a local file.
vlmodel, preprocess_train, feature_extractor = open_clip.create_model_and_transforms(
    model_name = 'ViT-H-14', pretrained = None, precision='fp32', device=device
)

model_weights_path = "/mnt/repo0/kyw/open_clip_pytorch_model.bin"
# The loaded object goes straight into load_state_dict, so weights_only=True is enough.
# It refuses arbitrary pickled objects and silences torch.load's FutureWarning.
model_state_dict = torch.load(model_weights_path, map_location=device, weights_only=True)
vlmodel.load_state_dict(model_state_dict)
vlmodel.eval()

# sub-08 checkpoint: image projection head plus the embedding-conditioned generator.
diffusion_model_path = "/mnt/repo0/kyw/close-loop/sub_model/sub-08/diffusion_250hz/ATM_S_reconstruction_scale_0_1000_40.pth"
# NOTE(review): still a full unpickle because this checkpoint may hold more than
# tensors; switch to weights_only=True once its contents are confirmed.
checkpoint = torch.load(diffusion_model_path, map_location=device)
img_model = Proj_img()
img_model.load_state_dict(checkpoint['img_model_state_dict'])
generator = Generator4Embeds(num_inference_steps=4, device=device, guidance_scale=2.0)

# def image_to_images(image_gt_path, num_images, device, num_round, file_name):
#     img_model.eval()
#     gt_image_input = torch.stack([preprocess_train(Image.open(image_gt_path).convert("RGB"))]).to(device)
#     vlmodel.to(device)
#     img_embeds = vlmodel.encode_image(gt_image_input)
#     save_img_path = f'/mnt/repo0/kyw/close-loop/loop_random/loop{num_round}'
#     os.makedirs(save_img_path, exist_ok=True)
#     batch_size = 2 
#     for batch_start in range(0, num_images, batch_size):
#         batch_images = []
#         for idx in range(batch_start, min(batch_start + batch_size, num_images)):
#             with torch.no_grad(): 
#                 image = generator.generate(img_embeds, guidance_scale=5.0)
#             save_imgs_path = os.path.join(save_img_path, f'{file_name}_{idx}.jpg') 
#             image.save(save_imgs_path)
#             print(f"图片保存至: {save_imgs_path}")
#         del batch_images
#         torch.cuda.empty_cache()

def fusion_image_to_images(image_gt_paths, num_images, device, save_path, scale, guidance_scale=2.0):
    """Cross over the CLIP image embeddings of two images and decode the result.

    The first two paths in ``image_gt_paths`` are encoded with the global ``vlmodel``.
    A random contiguous window of ``scale`` positions along axis 1 of the first
    embedding is overwritten with the same window from the second. The global
    ``generator`` then decodes that fused embedding ``num_images`` times, writing
    ``save_path/{scale}_{idx}.jpg``.

    Args:
        image_gt_paths: image file paths; only the first two are used.
        num_images: number of images to generate from the fused embedding.
        device: torch device for the encoder inputs.
        save_path: output directory (created if missing).
        scale: crossover window width, ``0 <= scale <= embed1.size(1)``.
        guidance_scale: forwarded to ``generator.generate`` (default keeps 2.0).

    Raises:
        ValueError: fewer than two image paths, or ``scale`` outside the embedding width.
    """
    if len(image_gt_paths) < 2:
        raise ValueError(f'need at least two images to fuse, got {len(image_gt_paths)}')
    img_model.eval()
    vlmodel.to(device)

    embeds = []
    # Encoding is inference only: no_grad avoids building and holding an autograd graph.
    with torch.no_grad():
        for image_gt_path in image_gt_paths[:2]:
            with Image.open(image_gt_path) as img:  # close the file handle promptly
                gt_image_input = torch.stack([preprocess_train(img.convert("RGB"))]).to(device)
            embeds.append(vlmodel.encode_image(gt_image_input))
    embed1, embed2 = embeds

    embed_len = embed1.size(1)
    if not 0 <= scale <= embed_len:
        raise ValueError(f'scale must be in [0, {embed_len}], got {scale}')
    # random.randint includes both ends, so embed_len - scale is the last valid start.
    # The old "- 1" meant the window could never cover the final positions.
    start_idx = random.randint(0, embed_len - scale)
    end_idx = start_idx + scale
    fused = embed1.clone()
    fused[:, start_idx:end_idx] = embed2[:, start_idx:end_idx]

    os.makedirs(save_path, exist_ok=True)
    batch_size = 2  # free cached CUDA memory after every batch_size generations
    for idx in range(num_images):
        with torch.no_grad():
            image = generator.generate(fused, guidance_scale=guidance_scale)
        save_imgs_path = os.path.join(save_path, f'{scale}_{idx}.jpg')
        image.save(save_imgs_path)
        print(f"图片保存至: {save_imgs_path}")  # "Image saved to: ..."
        if (idx + 1) % batch_size == 0 or idx + 1 == num_images:
            torch.cuda.empty_cache()


    # file_name = os.path.basename(top_image_path)
    # file_name = os.path.splitext(file_name)[0]
    # image_to_images(top_image_path, 2, device, 1, file_name)


  model_state_dict = torch.load(model_weights_path, map_location=device)
  checkpoint = torch.load(diffusion_model_path, map_location=device)


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [11]:
# 1
save_path = '/mnt/repo0/kyw/close-loop/loop_random/loop1'
fusion_image_to_images(chosen_image_paths, 4, device, save_path, 256)

NameError: name 'fusion_image_to_images' is not defined

In [None]:
# Crossover-width sweep on the same two parents.
# The fusion needs two images; top_image_paths holds only the single top-1 path
# (IndexError), and save_path must be a directory (makedirs(1) raises TypeError).
# Output goes to a separate directory so the round-1 listing below only sees
# round-1 images.
scale_sweep_path = '/mnt/repo0/kyw/close-loop/loop_random/scale_sweep'
for sweep_scale in (128, 256, 512):
    fusion_image_to_images(chosen_image_paths, 4, device, scale_sweep_path, sweep_scale)

In [None]:
# 1
# Start the round-1 pool from the two parents and their scores, as later rounds do
# with best/other. Without this the cell (and the similarity cell below) raise
# NameError on a fresh kernel.
new_sample_path = list(chosen_image_paths)
new_cs = list(chosen_similarities)
image_path_list = []
label_list = []
for image in sorted(os.listdir(save_path)):
    # generate_and_save_eeg_for_all_images writes .npy files into save_path too;
    # only pass real images back in (matters when the cell is re-run).
    if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    image_path = os.path.join(save_path, image)
    new_sample_path.append(image_path)
    image_path_list.append(image_path)
    label_list.append(os.path.splitext(image)[0])
print(new_sample_path)
generate_and_save_eeg_for_all_images(model_path, image_path_list, save_path, device, label_list)

In [None]:
# Score each generated image's predicted EEG against the target PSD.
similarities = []
for npy_name in sorted(os.listdir(save_path)):
    if not npy_name.endswith('npy'):
        continue
    eeg_path = os.path.join(save_path, npy_name)
    print(eeg_path)
    file_name = os.path.splitext(npy_name)[0]  # only used by the disabled plots below
    selected_eeg = np.load(eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
    # average_signals = get_avg_signal(selected_eeg, file_name, 1)
    # get_time_freq(average_signals, fs, file_name, 1)
print(similarities)
print(new_cs)

In [None]:
# Images never drawn before remain eligible.
available_paths = list(filter(lambda path: path not in processed_paths, test_images_path))
print(len(available_paths))

In [34]:
sample_image_paths = random.sample(population=available_paths, k=4)  # 4 fresh real images

In [None]:
# Add the freshly sampled real images to this round's candidate pool.
new_sample_path.extend(sample_image_paths)
print(new_sample_path)

In [None]:
print(f'{len(new_sample_path)}')  # current pool size

In [None]:
# In-place union: never sample these images again.
processed_paths |= set(sample_image_paths)
print(len(processed_paths))

In [None]:
# Score the sampled real images (their EEG sits at the same index as the image).
idxes = list(map(test_images_path.index, sample_image_paths))
sample_eeg_paths = [eeg_paths[i] for i in idxes]
similarities = []
for sample_eeg_path in sample_eeg_paths:
    print(sample_eeg_path)
    selected_eeg = np.load(sample_eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
print(similarities)
print(new_cs)

In [None]:
# Pool paths and scores must stay the same length (one score per path).
print(new_sample_path, new_cs, sep='\n')
print(len(new_sample_path), len(new_cs))

In [None]:
# Selection weights over the whole pool (name kept: later cells read `probilities`).
probilities = softmax(np.asarray(new_cs))
print(probilities)

In [None]:
# Roulette-wheel selection: keep the most probable candidate, then draw one more
# from the rest in proportion to its renormalised softmax weight.
max_index = np.argmax(probilities)
max_probability = probilities[max_index]
print("最大概率及其索引:", max_probability, max_index)  # "max probability and its index"

# Pool positions of every candidate except the best one.
remaining_indices = np.delete(np.arange(len(probilities)), max_index)
assert remaining_indices.size > 0, "roulette selection needs at least two candidates"
remaining_probs = probilities[remaining_indices]
normalized_remaining_probs = remaining_probs / remaining_probs.sum()

# Draw a pool position directly, so the result can index new_sample_path / new_cs.
chosen_index = np.random.choice(remaining_indices, p=normalized_remaining_probs)
random_probability = probilities[chosen_index]

print("轮盘赌选择的概率及其索引:", random_probability, chosen_index)  # "roulette pick and its index"

# The next cells read new_sample_path[random_index] / new_cs[random_index]. The old
# random_index was a position in the *reduced* array: off by one for every pick
# after max_index, and able to re-select the best sample. Alias it to the pool position.
random_index = chosen_index

In [None]:
# Resolve the two selected pool positions into image paths.
# NOTE(review): random_index must be a position in the full pool (see roulette cell).
best_sample_path, other_sample_path = new_sample_path[max_index], new_sample_path[random_index]
print(best_sample_path)
print(other_sample_path)

In [None]:
# Matching similarity scores for the two selected pool positions.
best_sample_cs, other_sample_cs = new_cs[max_index], new_cs[random_index]
print(best_sample_cs, other_sample_cs)

In [None]:
# Only the two selected parents carry over into the next round.
new_sample_path = [best_sample_path, other_sample_path]
print(new_sample_path)

new_cs = [best_sample_cs, other_sample_cs]
print(new_cs)

In [None]:
# Round 2: fuse the two carried-over parents with a 128-wide crossover window.
save_path = os.path.join('/mnt/repo0/kyw/close-loop/loop_signal_result_1', 'fusion_loop_2')
fusion_image_to_images(image_gt_paths=new_sample_path, num_images=4, device=device, save_path=save_path, scale=128)

In [None]:
# 2
# Add this round's fused images to the pool and predict an EEG for each.
image_path_list = []
label_list = []
for image in sorted(os.listdir(save_path)):
    # generate_and_save_eeg_for_all_images writes .npy files into save_path too;
    # only pass real images back in (matters when the cell is re-run).
    if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    image_path = os.path.join(save_path, image)
    new_sample_path.append(image_path)
    image_path_list.append(image_path)
    label_list.append(os.path.splitext(image)[0])
print(new_sample_path)
generate_and_save_eeg_for_all_images(model_path, image_path_list, save_path, device, label_list)

In [None]:
# Score each generated image's predicted EEG against the target PSD.
similarities = []
for npy_name in sorted(os.listdir(save_path)):
    if not npy_name.endswith('npy'):
        continue
    eeg_path = os.path.join(save_path, npy_name)
    print(eeg_path)
    file_name = os.path.splitext(npy_name)[0]  # only used by the disabled plots below
    selected_eeg = np.load(eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
    # average_signals = get_avg_signal(selected_eeg, file_name, 1)
    # get_time_freq(average_signals, fs, file_name, 1)
print(similarities)
print(new_cs)

In [None]:
# Images never drawn before remain eligible.
available_paths = list(filter(lambda path: path not in processed_paths, test_images_path))
print(len(available_paths))

In [53]:
sample_image_paths = random.sample(population=available_paths, k=4)  # 4 fresh real images

In [None]:
# Add the freshly sampled real images to this round's candidate pool.
new_sample_path.extend(sample_image_paths)
print(new_sample_path)

In [None]:
# In-place union: never sample these images again.
processed_paths |= set(sample_image_paths)
print(len(processed_paths))

In [None]:
# Score the sampled real images (their EEG sits at the same index as the image).
idxes = list(map(test_images_path.index, sample_image_paths))
sample_eeg_paths = [eeg_paths[i] for i in idxes]
similarities = []
for sample_eeg_path in sample_eeg_paths:
    print(sample_eeg_path)
    selected_eeg = np.load(sample_eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
print(similarities)
print(new_cs)

In [None]:
# Pool paths and scores must stay the same length (one score per path).
print(new_sample_path, new_cs, sep='\n')
print(len(new_sample_path), len(new_cs))

In [None]:
# Selection weights over the whole pool (name kept: later cells read `probilities`).
probilities = softmax(np.asarray(new_cs))
print(probilities)

In [None]:
# Roulette-wheel selection: keep the most probable candidate, then draw one more
# from the rest in proportion to its renormalised softmax weight.
max_index = np.argmax(probilities)
max_probability = probilities[max_index]
print("最大概率及其索引:", max_probability, max_index)  # "max probability and its index"

# Pool positions of every candidate except the best one.
remaining_indices = np.delete(np.arange(len(probilities)), max_index)
assert remaining_indices.size > 0, "roulette selection needs at least two candidates"
remaining_probs = probilities[remaining_indices]
normalized_remaining_probs = remaining_probs / remaining_probs.sum()

# Draw a pool position directly, so the result can index new_sample_path / new_cs.
chosen_index = np.random.choice(remaining_indices, p=normalized_remaining_probs)
random_probability = probilities[chosen_index]

print("轮盘赌选择的概率及其索引:", random_probability, chosen_index)  # "roulette pick and its index"

# The next cells read new_sample_path[random_index] / new_cs[random_index]. The old
# random_index was a position in the *reduced* array: off by one for every pick
# after max_index, and able to re-select the best sample. Alias it to the pool position.
random_index = chosen_index

In [None]:
# Resolve the two selected pool positions into image paths.
# NOTE(review): random_index must be a position in the full pool (see roulette cell).
best_sample_path, other_sample_path = new_sample_path[max_index], new_sample_path[random_index]
print(best_sample_path)
print(other_sample_path)

In [None]:
# Matching similarity scores for the two selected pool positions.
best_sample_cs, other_sample_cs = new_cs[max_index], new_cs[random_index]
print(best_sample_cs, other_sample_cs)

In [None]:
# Only the two selected parents carry over into the next round.
new_sample_path = [best_sample_path, other_sample_path]
print(new_sample_path)

new_cs = [best_sample_cs, other_sample_cs]
print(new_cs)

In [None]:
# Round 3: fuse the two carried-over parents with a 128-wide crossover window.
save_path = os.path.join('/mnt/repo0/kyw/close-loop/loop_signal_result_1', 'fusion_loop_3')
fusion_image_to_images(image_gt_paths=new_sample_path, num_images=4, device=device, save_path=save_path, scale=128)

In [None]:
# 3
# Add this round's fused images to the pool and predict an EEG for each.
image_path_list = []
label_list = []
for image in sorted(os.listdir(save_path)):
    # generate_and_save_eeg_for_all_images writes .npy files into save_path too;
    # only pass real images back in (matters when the cell is re-run).
    if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    image_path = os.path.join(save_path, image)
    new_sample_path.append(image_path)
    image_path_list.append(image_path)
    label_list.append(os.path.splitext(image)[0])
print(new_sample_path)
generate_and_save_eeg_for_all_images(model_path, image_path_list, save_path, device, label_list)

In [None]:
# Score each generated image's predicted EEG against the target PSD.
similarities = []
for npy_name in sorted(os.listdir(save_path)):
    if not npy_name.endswith('npy'):
        continue
    eeg_path = os.path.join(save_path, npy_name)
    print(eeg_path)
    file_name = os.path.splitext(npy_name)[0]  # only used by the disabled plots below
    selected_eeg = np.load(eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
    # average_signals = get_avg_signal(selected_eeg, file_name, 1)
    # get_time_freq(average_signals, fs, file_name, 1)
print(similarities)
print(new_cs)

In [None]:
# Draw 4 fresh real images, add them to the pool and mark them as used.
available_paths = list(filter(lambda path: path not in processed_paths, test_images_path))
print(len(available_paths))
sample_image_paths = random.sample(population=available_paths, k=4)
new_sample_path.extend(sample_image_paths)
print(new_sample_path)
processed_paths |= set(sample_image_paths)
print(len(processed_paths))

In [None]:
# Score the sampled real images (their EEG sits at the same index as the image).
idxes = list(map(test_images_path.index, sample_image_paths))
sample_eeg_paths = [eeg_paths[i] for i in idxes]
similarities = []
for sample_eeg_path in sample_eeg_paths:
    print(sample_eeg_path)
    selected_eeg = np.load(sample_eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
print(similarities)
print(new_cs)

In [None]:
# Selection weights over the whole pool (name kept: later cells read `probilities`).
probilities = softmax(np.asarray(new_cs))
print(probilities)

In [None]:
# Roulette-wheel selection: keep the most probable candidate, then draw one more
# from the rest in proportion to its renormalised softmax weight.
max_index = np.argmax(probilities)
max_probability = probilities[max_index]
print("最大概率及其索引:", max_probability, max_index)  # "max probability and its index"

# Pool positions of every candidate except the best one.
remaining_indices = np.delete(np.arange(len(probilities)), max_index)
assert remaining_indices.size > 0, "roulette selection needs at least two candidates"
remaining_probs = probilities[remaining_indices]
normalized_remaining_probs = remaining_probs / remaining_probs.sum()

# Draw a pool position directly, so the result can index new_sample_path / new_cs.
chosen_index = np.random.choice(remaining_indices, p=normalized_remaining_probs)
random_probability = probilities[chosen_index]

print("轮盘赌选择的概率及其索引:", random_probability, chosen_index)  # "roulette pick and its index"

# The next cell reads new_sample_path[random_index] / new_cs[random_index]. The old
# random_index was a position in the *reduced* array: off by one for every pick
# after max_index, and able to re-select the best sample. Alias it to the pool position.
random_index = chosen_index

In [None]:
# Resolve the two selected pool positions into paths and scores.
# NOTE(review): random_index must be a position in the full pool (see roulette cell).
best_sample_path, best_sample_cs = new_sample_path[max_index], new_cs[max_index]
other_sample_path, other_sample_cs = new_sample_path[random_index], new_cs[random_index]
print(best_sample_path)
print(other_sample_path)
print(best_sample_cs, other_sample_cs)

In [None]:
# Only the two selected parents carry over into the next round.
new_sample_path = [best_sample_path, other_sample_path]
print(new_sample_path)

new_cs = [best_sample_cs, other_sample_cs]
print(new_cs)

In [None]:
# Round 4: fuse the two carried-over parents with a 128-wide crossover window.
save_path = os.path.join('/mnt/repo0/kyw/close-loop/loop_signal_result_1', 'fusion_loop_4')
fusion_image_to_images(image_gt_paths=new_sample_path, num_images=4, device=device, save_path=save_path, scale=128)

In [None]:
# 4
# Add this round's fused images to the pool and predict an EEG for each.
image_path_list = []
label_list = []
for image in sorted(os.listdir(save_path)):
    # generate_and_save_eeg_for_all_images writes .npy files into save_path too;
    # only pass real images back in (matters when the cell is re-run).
    if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    image_path = os.path.join(save_path, image)
    new_sample_path.append(image_path)
    image_path_list.append(image_path)
    label_list.append(os.path.splitext(image)[0])
print(new_sample_path)
generate_and_save_eeg_for_all_images(model_path, image_path_list, save_path, device, label_list)

In [None]:
# Score each generated image's predicted EEG against the target PSD.
similarities = []
for npy_name in sorted(os.listdir(save_path)):
    if not npy_name.endswith('npy'):
        continue
    eeg_path = os.path.join(save_path, npy_name)
    print(eeg_path)
    file_name = os.path.splitext(npy_name)[0]  # only used by the disabled plots below
    selected_eeg = np.load(eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
    # average_signals = get_avg_signal(selected_eeg, file_name, 1)
    # get_time_freq(average_signals, fs, file_name, 1)
print(similarities)
print(new_cs)

In [None]:
# Draw 4 fresh real images, add them to the pool and mark them as used.
available_paths = list(filter(lambda path: path not in processed_paths, test_images_path))
print(len(available_paths))
sample_image_paths = random.sample(population=available_paths, k=4)
new_sample_path.extend(sample_image_paths)
print(new_sample_path)
processed_paths |= set(sample_image_paths)
print(len(processed_paths))

In [None]:
# Score the sampled real images (their EEG sits at the same index as the image).
idxes = list(map(test_images_path.index, sample_image_paths))
sample_eeg_paths = [eeg_paths[i] for i in idxes]
similarities = []
for sample_eeg_path in sample_eeg_paths:
    print(sample_eeg_path)
    selected_eeg = np.load(sample_eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
print(similarities)
print(new_cs)

In [None]:
save_path = '/mnt/repo0/kyw/close-loop/loop_signal_result_1/fusion_loop_5'
# Round-5 parents are picked by hand instead of through the roulette cells.
manual_parents = ['/mnt/repo0/kyw/images_set/test_images/00191_unicycle/unicycle_10s.jpg', '/mnt/repo0/kyw/close-loop/loop_signal_result_1/fusion_loop_4/128_1.jpg']
# Carry their scores over from the round-4 pool so new_cs stays aligned with
# new_sample_path. Previously new_cs kept every round-4 score while the paths were
# reset, so the two lists drifted apart. Raises ValueError if a parent is not in
# the current pool.
new_cs = [new_cs[new_sample_path.index(path)] for path in manual_parents]
new_sample_path = list(manual_parents)
fusion_image_to_images(new_sample_path, 4, device, save_path, 128)

In [None]:
# 5
# Add this round's fused images to the pool and predict an EEG for each.
image_path_list = []
label_list = []
for image in sorted(os.listdir(save_path)):
    # generate_and_save_eeg_for_all_images writes .npy files into save_path too;
    # only pass real images back in (matters when the cell is re-run).
    if not image.lower().endswith(('.jpg', '.jpeg', '.png')):
        continue
    image_path = os.path.join(save_path, image)
    new_sample_path.append(image_path)
    image_path_list.append(image_path)
    label_list.append(os.path.splitext(image)[0])
print(new_sample_path)
generate_and_save_eeg_for_all_images(model_path, image_path_list, save_path, device, label_list)

In [None]:
# Score each generated image's predicted EEG against the target PSD.
similarities = []
for npy_name in sorted(os.listdir(save_path)):
    if not npy_name.endswith('npy'):
        continue
    eeg_path = os.path.join(save_path, npy_name)
    print(eeg_path)
    file_name = os.path.splitext(npy_name)[0]  # only used by the disabled plots below
    selected_eeg = np.load(eeg_path, allow_pickle=True)[selected_channel_idxes, :]
    psd, _ = psd_array_multitaper(selected_eeg, fs, adaptive=True, normalization='full', verbose=0)
    score = F.cosine_similarity(target_psd, torch.from_numpy(psd.flatten()).unsqueeze(0)).item()
    new_cs.append(score)
    similarities.append(score)
    # average_signals = get_avg_signal(selected_eeg, file_name, 1)
    # get_time_freq(average_signals, fs, file_name, 1)
print(similarities)
print(new_cs)