## Generate Data augmentations using encoding models

In [4]:
import numpy as np
import nibabel as nib
import nilearn 
import matplotlib.pyplot as plt
import os
from os.path import join as opj
import pandas as pd
import seaborn as sns
import glob
from nilearn import plotting
from nilearn.image import *
import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from nilearn.plotting import plot_stat_map
from nilearn.image import mean_img
from nilearn.plotting import plot_img, plot_epi
from nilearn.maskers import NiftiMasker
from sklearn.preprocessing import StandardScaler
import wandb
import pickle
from torch.utils.data import Dataset, DataLoader
from dataset import fMRI_Dataset, fMRI_Text_Dataset
import torch
from torch import nn
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint

from pytorch_lightning.loggers import WandbLogger
from network import Encoder, ContrastiveModel
import torch
import torch.nn as nn
import pytorch_lightning as pl

import pickle
import os
import himalaya
from himalaya.backend import set_backend


In [15]:


base_path="/home/matteo/brain-diffuser/data"
augment_imges_list=[]
for sub in tqdm.tqdm(["subj01","subj02","subj05","subj07"]):
    stim_captions_train_path=opj(base_path,"nsddata_stimuli","stimuli","nsd","annotations",f"captions_train2017.json")

    processed_data=opj(base_path,"processed_data",sub)

    sub_idx=int(sub.split("0")[-1])

    imgs_train_data=opj(processed_data,f"nsd_train_stim_sub{sub_idx}.npy")
    augment_images = np.load(imgs_train_data)
    augment_imges_list.append(augment_images)


100%|██████████| 4/4 [07:03<00:00, 105.77s/it]


In [16]:
augment_imges_list=np.concatenate(augment_imges_list,axis=0)

In [6]:
## Use CLIPVision from huggingface to extract features from images

from transformers import CLIPProcessor, CLIPModel, CLIPVisionModel, AutoProcessor
import torch

# Load the CLIP model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

model.to(device)
model.eval()




CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (position_embedding): Embedding(50, 768)
    )
    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
        

In [7]:
## Extract image features using CLIP in batches

def extract_image_features(images, model, processor, device, batch_size=128):

    image_features = []
    with torch.no_grad():
        for i in tqdm.trange(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]

            inputs = processor(images=batch_images, return_tensors="pt",padding=True)["pixel_values"].to(device)
            with torch.no_grad():
                outputs = model(inputs).pooler_output
                image_features.append(outputs)

        image_features = torch.cat(image_features, dim=0)

        return image_features

In [21]:
augment_image_features = extract_image_features(augment_imges_list, model, processor, device)


100%|██████████| 277/277 [03:32<00:00,  1.30it/s]


In [31]:
import torch
from sklearn.metrics.pairwise import cosine_similarity

def find_duplicate_indices(augment_embeddings, test_embeddings, threshold=0.95):
    """
    Find indices of potential duplicates in augment_embeddings based on similarity with test_embeddings.

    Args:
        augment_embeddings (torch.Tensor): Tensor of shape (num_augment, embedding_dim) for augment images.
        test_embeddings (torch.Tensor): Tensor of shape (num_test, embedding_dim) for test images.
        threshold (float): Cosine similarity threshold above which images are considered duplicates.

    Returns:
        List[int]: Indices of potential duplicates in augment_embeddings.
    """
    duplicate_indices = []
    
    # Calculate similarity between each augment_embedding and all test_embeddings
    for i, augment_embedding in tqdm.tqdm(enumerate(augment_embeddings),total=len(augment_embeddings)):
        similarity_scores = cosine_similarity(
            augment_embedding.unsqueeze(0), test_embeddings  # Convert augment_embedding to (1, embedding_dim) for comparison
        )
        
        # Check if any similarity score exceeds the threshold
        if similarity_scores.max() > threshold:
            duplicate_indices.append(i)
    
    return duplicate_indices



In [33]:
augment_image_features = augment_image_features.cpu()

In [39]:
indices_to_remove = []

for sub in ["CSI1","CSI2","CSI3","CSI4"]:

    data_path =  f"/home/matteo/storage/brain_tuning/{subj}"
    test_features = np.load(opj(data_path, "test_image_features.npy"))
    duplicate_indices = find_duplicate_indices(augment_image_features, torch.tensor(test_features), threshold=0.95)

    indices_to_remove.extend(duplicate_indices)
    

100%|██████████| 35436/35436 [01:21<00:00, 433.13it/s]
100%|██████████| 35436/35436 [01:21<00:00, 436.88it/s]
100%|██████████| 35436/35436 [01:19<00:00, 446.22it/s]
100%|██████████| 35436/35436 [01:19<00:00, 448.27it/s]


In [42]:
## Create a clean feature tensor

clean_augment_image_features = torch.cat([augment_image_features[i].unsqueeze(0) for i in range(len(augment_image_features)) if i not in indices_to_remove], dim=0)

In [44]:
## Save the clean tensor features and indices to remove here: /home/matteo/storage/brain_tuning/

np.save("/home/matteo/storage/brain_tuning/clean_augment_image_features.npy", clean_augment_image_features.numpy())
np.save("/home/matteo/storage/brain_tuning/indices_to_remove.npy", np.array(indices_to_remove))


In [8]:
# clean_augment_image_features = np.load("/home/matteo/storage/brain_tuning/clean_augment_image_features.npy")
# indices_to_remove = np.load("/home/matteo/storage/brain_tuning/indices_to_remove.npy")

## Load Image encoding model to produce data augmentations

In [11]:
import pickle
import os
import himalaya
from himalaya.backend import set_backend


## encode the augmented images using the image encoding model 
device_id = 0
torch.cuda.set_device(device_id)  # Set the current device

backend = set_backend("torch_cuda")

for subj in ["CSI1","CSI2","CSI3","CSI4"]:
    data_path =  f"/home/matteo/storage/brain_tuning/{subj}"
    top_voxels = np.load(os.path.join(data_path, "top_voxels.npy"))

    # Load the encoding model from the pickle file
    with open(os.path.join(data_path, "encoding_model.pkl"), "rb") as f:
        encoding_model = pickle.load(f)

    augmented_brain = encoding_model.predict(backend.asarray(clean_augment_image_features).to(f'cuda:{device_id}'))[:, top_voxels]

    np.save(os.path.join(data_path, "augmented_brain.npy"), augmented_brain.cpu().numpy())


## Save also augmented captions

In [72]:
base_path="/home/matteo/brain-diffuser/data"
augment_text_list=[]
for sub in tqdm.tqdm(["subj01","subj02","subj05","subj07"]):
    stim_captions_train_path=os.path.join(base_path,"nsddata_stimuli","stimuli","nsd","annotations",f"captions_train2017.json")
    sub_idx=int(sub.split("0")[-1])

    processed_data=os.path.join(base_path,"processed_data",sub)
    captions_train_data=os.path.join(processed_data, f"nsd_train_cap_sub{sub_idx}.npy")


    augment_texts = np.load(captions_train_data,allow_pickle=True)
    augment_text_list.append(augment_texts)

augment_text_list=np.concatenate(augment_text_list,axis=0)

100%|██████████| 4/4 [00:00<00:00, 112.99it/s]


In [78]:
# indices_to_remove
clean_augment_text_list = np.array([augment_text_list[i] for i in range(len(augment_text_list)) if i not in indices_to_remove])

In [82]:
## Save the clean augment text list here: /home/matteo/storage/brain_tuning/

np.save("/home/matteo/storage/brain_tuning/clean_augment_text_list.npy", clean_augment_text_list)

## Extract Text Features

In [1]:
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import tqdm

# Load the CLIP text model and tokenizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

text_model.to(device)
text_model.eval()

## Extract text features using CLIP in batches
def extract_text_features(texts, text_model, text_tokenizer, device, batch_size=32):
    text_features = []
    with torch.no_grad():
        for i in tqdm.trange(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Tokenize and process the text batch
            inputs = text_tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            outputs = text_model(**inputs).pooler_output  # Get the pooled output for the text features
            
            text_features.append(outputs)

        # Concatenate all features along the batch dimension
        text_features = torch.cat(text_features, dim=0)

    return text_features


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
clean_augment_text_features = extract_text_features(clean_augment_text_list[:,0].tolist(), text_model, text_tokenizer, device)

  0%|          | 0/1108 [00:00<?, ?it/s]

100%|██████████| 1108/1108 [00:14<00:00, 78.30it/s] 


In [4]:
## Save the clean text features here: /home/matteo/storage/brain_tuning/

np.save("/home/matteo/storage/brain_tuning/clean_augment_text_features.npy", clean_augment_text_features.cpu().numpy())

In [2]:
# clean_augment_text_features = np.load("/home/matteo/storage/brain_tuning/clean_augment_text_features.npy")

## Encode augmented text features with text encoding models

In [6]:
## encode the augmented images using the image encoding model 
device_id = 0
torch.cuda.set_device(device_id)  # Set the current device

backend = set_backend("torch_cuda")

for subj in ["CSI1","CSI2","CSI3","CSI4"]:
    data_path =  f"/home/matteo/storage/brain_tuning/{subj}"
    top_voxels = np.load(os.path.join(data_path, "TEXT_top_voxels.npy"))

    # Load the encoding model from the pickle file
    with open(os.path.join(data_path, "TEXT_encoding_model.pkl"), "rb") as f:
        encoding_model = pickle.load(f)

    augmented_brain = encoding_model.predict(backend.asarray(clean_augment_text_features).to(f'cuda:{device_id}'))[:, top_voxels]

    np.save(os.path.join(data_path, "TEXT_augmented_brain.npy"), augmented_brain.cpu().numpy())