In [None]:
import pandas as pd
import numpy as np
import os
import keras
import matplotlib.pyplot as plt
from keras.layers import Dense, GlobalAveragePooling2D, Dropout, Flatten
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.optimizers import Adam

In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader


# Load the ImageNet-pretrained ResNet-50 model.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of
# `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept as-is here.
resnet50 = models.resnet50(pretrained=True)

# Turn the classifier into a feature extractor.
# With only `fc` replaced, each image yields the 2048-value pooled embedding.
# Uncommenting the avgpool line below keeps the spatial map instead
# (100352 values per image); keep it commented out when comparing against the
# 2048-feature matrix.
#resnet50.avgpool = torch.nn.Identity() #Replace the avgpool layer with an Identity layer, which effectively does nothing
resnet50.fc = torch.nn.Identity()  # Remove the final fully connected layer

# Set the model to evaluation mode. As the cell's last expression this also
# displays the model architecture below.
resnet50.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
# Where the images live (one sub-directory per class, as ImageFolder expects)
# and the ImageNet channel statistics the pretrained ResNet-50 was trained with.
IMAGE_DIR = '/content/drive/MyDrive/Images'
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Standard ImageNet preprocessing: resize shorter side to 256, centre-crop
# 224x224, convert to a tensor, then normalise per channel.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(preprocessing_steps)

# One image per batch, no shuffling: feature rows come out in dataset order.
dataset = ImageFolder(IMAGE_DIR, transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)


# Function to extract features
def extract_features(model, data_loader, verbose=True):
    """
    Run ``model`` over every batch of ``data_loader`` and collect one flat
    feature vector per image.

    Parameters:
    - model (torch.nn.Module): The feature extractor. It is switched to eval
      mode as a side effect.
    - data_loader (iterable of (inputs, labels)): Typically a DataLoader; the
      labels are ignored. Any batch size works.
    - verbose (bool): If True (the default, matching earlier behaviour), print
      the raw output shape of every batch.

    Returns:
    - np.array: Shape ``(num_images, feature_dim)``, one row per image in
      loader order. For example, this is 2048 columns with ResNet-50's pooled
      output. Empty input gives an empty array.
    """
    model.eval()

    # Send inputs to wherever the model lives. Parameter-free models (e.g. a
    # bare Identity) get their inputs unchanged.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    features = []
    with torch.no_grad():
        for inputs, _ in data_loader:
            if device is not None:
                inputs = inputs.to(device)
            output = model(inputs)
            if verbose:
                print("Shape of feature", output.shape)
            # Flatten everything except the batch axis so each image becomes
            # one row. The old `.squeeze()` only worked for batch_size == 1;
            # larger batches were merged into a single (ragged) entry.
            per_image = output.reshape(output.shape[0], -1).cpu().numpy()
            features.extend(per_image)

    return np.array(features)

# Extract one feature vector per image; rows follow dataset order because the
# loader is built with shuffle=False.
features = extract_features(resnet50, data_loader)

Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])


In [None]:
# Show the feature matrix, then spot-check the length of its first two rows
# (each row is one image's embedding).
print(features)
for row_index in (0, 1):
    print(len(features[row_index]))

[[0.29308692 0.34076434 0.0606418  ... 0.5005807  0.19622953 0.16804181]
 [0.17378132 0.15040462 0.12912093 ... 0.29789406 0.12450052 0.17119877]
 [0.06553076 0.22242767 0.08264972 ... 0.27547953 0.13562112 0.14316227]
 [0.1805996  0.5667376  1.1920793  ... 0.64303255 0.03507868 0.10211677]
 [0.4666547  0.17642798 0.10143929 ... 0.37027115 0.11223205 0.10707459]
 [0.233478   0.5380204  0.5149275  ... 0.43153307 0.0828105  0.16118416]]
2048
2048


In [None]:
#Cosine distance and L2 norm of cat 1 with other images
from scipy.spatial import distance

def cosine_similarity(feature1, feature2):
    """
    Cosine *distance* between two 1-D feature vectors.

    Despite its name, this returns ``1 - cos(theta)``:
    - 0 for vectors pointing the same way;
    - 1 for orthogonal vectors;
    - up to 2 for opposite vectors.

    Smaller means more similar, and every caller in this notebook treats it
    as a distance. A zero vector makes the denominator 0, so the result is
    nan/inf (numpy warns rather than raising).
    """
    norm_product = np.linalg.norm(feature1) * np.linalg.norm(feature2)
    cos_theta = np.dot(feature1, feature2) / norm_product
    return 1 - cos_theta

def l2_norm(feature1, feature2):
    """
    Euclidean (L2) distance between two 1-D feature vectors.

    Despite the name, this is the L2 norm of the *difference*
    ``||feature1 - feature2||``. It is computed by
    ``scipy.spatial.distance.euclidean``.
    """
    euclidean_distance = distance.euclidean(feature1, feature2)
    return euclidean_distance

# Compare the first image's embedding with every other image's embedding,
# using both cosine distance and Euclidean (L2) distance.
first_image_features = features[0]
remaining_images = features[1:]

cosine_similarities = [cosine_similarity(first_image_features, other) for other in remaining_images]
l2_distances = [l2_norm(first_image_features, other) for other in remaining_images]

print("Cosine Distance with the first image:")
print(cosine_similarities)

print("\nL2 Distances with the first image:")
print(l2_distances)

Cosine Distance with the first image:
[0.23807841539382935, 0.16347575187683105, 0.28057265281677246, 0.3209238648414612, 0.332078218460083]

L2 Distances with the first image:
[18.465999603271484, 15.076408386230469, 21.9773006439209, 23.492877960205078, 23.480634689331055]


In [None]:
def information_decay(F, D, rng=None):
    """
    Applies 'information decay' to a feature vector by overwriting a random
    subset of its elements with uniform noise in [0, 1).

    The input is NOT modified; a decayed copy is returned. The previous
    version overwrote F in place. That silently corrupted every other
    reference to the same array, such as the extracted feature rows and the
    sensory-register entries the memory stores were built from.

    Parameters:
    - F (list or np.array): The 1-D feature vector to decay. Left untouched.
    - D (float): Decay constant in [0, 1]. ``int(len(F) * D)`` positions,
      chosen without replacement, are overwritten.
    - rng (np.random.Generator, optional): Source of randomness, for
      reproducible runs. Defaults to NumPy's global random state.

    Returns:
    - np.array: A decayed copy of F. Float dtypes are preserved. Integer or
      bool input is promoted to float so that the noise is not truncated to 0.

    Raises:
    - ValueError: If D is outside [0, 1].
    """
    if not 0 <= D <= 1:
        raise ValueError(f"Decay constant D must be in [0, 1], got {D!r}")

    decayed = np.array(F, copy=True)
    if not np.issubdtype(decayed.dtype, np.floating):
        decayed = decayed.astype(float)

    L = len(decayed)  # Length of the feature vector
    N = int(L * D)    # Number of elements to overwrite
    if N == 0:
        return decayed

    # Randomly choose N distinct positions and overwrite them with noise.
    if rng is None:
        idxes = np.random.choice(L, size=N, replace=False)
        noise = np.random.rand(N)
    else:
        idxes = rng.choice(L, size=N, replace=False)
        noise = rng.random(N)
    decayed[idxes] = noise

    print("Inside Information decay function")
    print("Elements replaced:", N, "of", L)
    return decayed

In [None]:
#sensory register code

from collections import deque

class SensoryRegister:
    """
    Fixed-capacity FIFO buffer modelling the sensory register.

    Every pushed item is kept until newer items evict it. Items whose
    attention exceeds ``transfer_threshold`` are also handed back to the
    caller, so they can be forwarded to short-term memory.
    """

    def __init__(self, capacity=100, transfer_threshold=0.5):
        """
        Parameters:
        - capacity (int): Maximum number of items retained. When the register
          is full, the oldest item is dropped automatically
          (``deque(maxlen=capacity)``).
        - transfer_threshold (float): Items with attention strictly above this
          value are flagged for transfer to short-term memory. Defaults to 0.5.
        """
        self.capacity = capacity
        self.transfer_threshold = transfer_threshold
        self.queue = deque(maxlen=capacity)

    def push_information(self, data, info_type, attention):
        """
        Pushes information into the sensory register. If the queue is full,
        the oldest information is automatically removed.

        Parameters:
        - data (any): The raw data of the information, e.g., a feature vector
          from ResNet50. It is stored by reference, not copied.
        - info_type (str): Type of the information (e.g., 'visual', 'auditory').
        - attention (float): Attention value assigned to the information,
          expected in (0, 1].

        Returns:
        - dict or None: The stored information unit
          ``{'data': ..., 'type': ..., 'attention': ...}`` if
          ``attention > transfer_threshold``, meaning it should go to
          short-term memory. Otherwise None. The unit is stored in the
          register either way.
        """
        information_unit = {'data': data, 'type': info_type, 'attention': attention}
        self.queue.append(information_unit)

        # Only sufficiently attended items are passed on to short-term memory.
        if attention > self.transfer_threshold:
            return information_unit
        return None

    def get_all_information(self):
        """
        Returns a list of all current information units, oldest first.
        """
        return list(self.queue)

# Example Usage
# sensory = SensoryRegister()
# high_attention_data = sensory.push_information(data=features[0], info_type="visual", attention=0.6)
# sensory.push_information(data=features[1], info_type="visual", attention=0.4)

# print("High attention data (to transfer to STM):", high_attention_data)
# print("All current sensory information:", sensory.get_all_information())


In [None]:
#short term memory code

import heapq

class ShortTermStore:
    """
    Capacity-limited short-term store (STS) of memory traces.

    Traces live in ``self.memory`` as a heapq min-heap of
    ``(strength, insertion_time, data)`` tuples, so ``self.memory[0]`` is
    always the weakest trace. ``insertion_time`` comes from the ever-increasing
    ``self.current_time`` counter. Because it is unique, it breaks strength
    ties, so heapq never falls through to comparing the ``data`` payloads
    (numpy arrays cannot be ordered with ``<``).
    """

    def __init__(self, capacity=7, decay_rate=0.05):
        """
        Parameters:
        - capacity (int): Maximum number of traces held at once.
        - decay_rate (float): Strength subtracted from every trace per
          ``decay()`` call. It is also the fraction of elements that
          ``information_decay`` overwrites once a trace drops below 0.5.
        """
        self.capacity = capacity
        self.decay_rate = decay_rate
        self.memory = []
        self.current_time = 0

    def add_trace(self, data, strength):
        """
        Insert a new trace with the given strength.

        If the store is full, the weakest existing trace is evicted first.
        The new trace is inserted regardless of how its strength compares.
        """
        if len(self.memory) >= self.capacity:
            print("Removed trace due to insufficient capacity in short term store")
            heapq.heappop(self.memory)  # heap root == weakest trace
        heapq.heappush(self.memory, (strength, self.current_time, data))
        self.current_time += 1

    def decay(self):
        """
        Weaken every trace by ``decay_rate`` and forget those that reach <= 0.

        Survivors whose strength is now below 0.5 also have their content
        degraded via ``information_decay``. Its return value replaces the
        stored data.
        """
        survivors = []
        # Visit traces weakest-first, the same order repeated heappop gives,
        # so any random draws inside information_decay keep their sequence.
        for strength, stamp, data in sorted(self.memory):
            new_strength = strength - self.decay_rate
            if not new_strength > 0:
                continue  # fully decayed: drop the trace
            if new_strength < 0.5:
                print("Decaying memory trace in short term store")
                data = information_decay(data, self.decay_rate)
            survivors.append((new_strength, stamp, data))

        # Rebuild in place so outside references to self.memory stay valid.
        heapq.heapify(survivors)
        self.memory[:] = survivors

    def rehearse(self, index, boost=0.1):
        """
        Strengthen the trace at heap position ``index`` by ``boost``.

        ``index`` is a position in ``self.memory`` (heap layout, not insertion
        order). Out-of-range and negative indices are ignored. The previous
        version ignored ``index`` and always boosted the weakest trace.
        """
        if not 0 <= index < len(self.memory):
            return
        strength, stamp, data = self.memory[index]
        self.memory[index] = (strength + boost, stamp, data)
        # A larger key can break the heap property below this slot; restore it.
        heapq.heapify(self.memory)

    def transfer_to_lts(self, lts, threshold=0.7):
        """
        Move every trace whose strength exceeds ``threshold`` into ``lts``.

        Transferred traces are removed from this store and handed over
        weakest-first. ``lts`` only needs an ``add_trace(data, strength)``
        method.
        """
        kept = []
        for trace in sorted(self.memory):
            strength, _, data = trace
            if strength > threshold:
                lts.add_trace(data, strength)
                print(f"Transferred to LTS: {data}")
            else:
                kept.append(trace)
        # `kept` is in ascending order, which already satisfies the heap property.
        self.memory[:] = kept

    def recall(self, new_feature, lts, lts_threshold=0.5):
        """
        Return the stored trace closest to ``new_feature`` by cosine distance.

        This first calls ``lts.recall(new_feature, lts_threshold, self)``.
        A LongTermStore uses that call to copy a close-enough, strong-enough
        trace, and the traces connected to it, into this store. Those copies
        then become candidates here.

        Returns:
        - The ``data`` of the best-matching trace, or None if the store is
          empty (or every distance is NaN, e.g. for an all-zero vector).
        """
        lts.recall(new_feature, lts_threshold, self)

        query = np.array(new_feature)
        closest_match = None
        min_distance = float('inf')
        for _, _, data in self.memory:
            print("Length of data", len(data))
            # Named `dist` so it does not shadow the scipy `distance` module.
            dist = cosine_similarity(np.array(data), query)
            print("Distance during recall in short term store", dist)
            if dist < min_distance:
                min_distance = dist
                closest_match = data

        return closest_match


In [None]:
#long term memory code

# import heapq

# class LongTermStore_1:
#     def __init__(self, decay_rate=0.0005):
#         self.memory = {}  # key: data, value: {'strength': float, 'connections': list, 'valid': bool}
#         self.decay_rate = decay_rate

#     def add_trace(self, data, strength, valid=True):
#         self.memory[data] = {'strength': strength, 'connections': [], 'valid': valid}

#     def connect_traces(self, data1, data2):
#         if data1 in self.memory and data2 in self.memory:
#             self.memory[data1]['connections'].append(data2)
#             self.memory[data2]['connections'].append(data1)

#     def decay(self):
#         # Placeholder for using information_decay function to apply decay, adjusted for dictionary structure.

#         """
#         Apply decay to each trace in the memory. If a trace's strength goes below zero,
#         it will still be retained but could be considered for cleanup if needed.
#         """
#         to_delete = []  # List to hold keys of traces to delete if necessary
#         for data, details in list(self.memory.items()):
#             # Reduce the strength
#             new_strength = details['strength'] - self.decay_rate
#             if new_strength > 0:
#               details['strength'] = new_strength
#               if(new_strength < 0.5):
#                 details['data'] = information_decay(details['data'], self.decay_rate)

#             else:
#                 # Optionally, mark the trace as invalid or delete it
#                 # details['valid'] = False
#                 # Or directly delete the trace if it's not needed
#                 to_delete.append(data)

#         # Remove traces that have decayed completely (if you choose to delete them)
#         for data in to_delete:
#             del self.memory[data]



#     def recall(self, feature, threshold, short_term_store):
#         # Find the closest match based on a feature
#         closest_match = None
#         min_distance = float('inf')

#         for data, details in self.memory.items():
#             distance = np.linalg.norm(feature - np.array(data))  # Assuming 'data' is also a vector-like feature
#             if distance < min_distance:
#                 min_distance = distance
#                 closest_match = data

#         # Check if the closest match is valid and the distance is below the threshold
#         if closest_match and self.memory[closest_match]['valid'] and min_distance < threshold:
#             # Transfer the memory trace to the short-term store
#             index=0
#             self._transfer_trace(closest_match, short_term_store, threshold, index)

#         return closest_match

#     def _transfer_trace(self, data, short_term_store, threshold, index):
#         # Recursively transfer the trace and its connected traces to the short-term store
#         stack = [data]
#         visited = set()

#         while stack:
#             current = stack.pop()
#             if current not in visited:
#                 visited.add(current)
#                 trace_details = self.memory[current]
#                 if trace_details['strength'] > threshold:
#                     short_term_store.add_trace(current, trace_details['strength'])
#                     # Rehearse the trace in STS to boost its strength, simulating memory reinforcement
#                     short_term_store.rehearse(current,index)
#                     index=index+1
#                     #trace_details['strength'] += 0.1
#                     # Add connected traces to the stack
#                     for neighbor in trace_details['connections']:
#                         stack.append(neighbor)



In [None]:
import numpy as np

#long term memory code
class LongTermStore:
    """
    Unbounded long-term store (LTS) of memory traces.

    ``self.memory`` is a list of dicts with these keys:
    - ``'data'``;
    - ``'strength'``;
    - ``'connections'``: indices of linked traces in this same list;
    - ``'valid'``.

    Each ``decay()`` call lowers strength by ``decay_rate``. A trace is
    deleted once its strength reaches 0.
    """

    def __init__(self, decay_rate=0.0005):
        """
        Parameters:
        - decay_rate (float): Strength subtracted from every trace per
          ``decay()`` call. It is also the fraction of elements
          ``information_decay`` overwrites once a trace is below 0.5.
        """
        self.memory = []  # list of trace dicts, see class docstring
        self.decay_rate = decay_rate

    def add_trace(self, data, strength, valid=True):
        """Append a new, unconnected trace. Its index is ``len(self.memory) - 1``."""
        self.memory.append({'data': data, 'strength': strength, 'connections': [], 'valid': valid})

    def connect_traces(self, index1, index2):
        """
        Link two traces in both directions by their positions in ``self.memory``.

        Negative indices count from the end, as in normal list indexing, and
        are stored normalised. Indices outside the list are ignored.
        """
        size = len(self.memory)
        if not (-size <= index1 < size and -size <= index2 < size):
            return
        index1 %= size
        index2 %= size
        self.memory[index1]['connections'].append(index2)
        self.memory[index2]['connections'].append(index1)

    def decay(self):
        """
        Weaken every trace by ``decay_rate`` and delete traces whose strength
        reaches <= 0.

        Surviving traces below strength 0.5 also have their data degraded via
        ``information_decay``. Deletions shift the remaining traces to new
        indices, so every ``connections`` list is remapped and links to
        deleted traces are dropped. Previously these links were left pointing
        at stale or out-of-range positions.
        """
        survivors = []
        new_index_of = {}  # old position -> position after deletions
        for old_index, trace in enumerate(self.memory):
            new_strength = trace['strength'] - self.decay_rate
            if not new_strength > 0:
                continue  # fully decayed: forget this trace
            trace['strength'] = new_strength
            if new_strength < 0.5:
                print("Decaying memory trace in long term store")
                trace['data'] = information_decay(trace['data'], self.decay_rate)
            new_index_of[old_index] = len(survivors)
            survivors.append(trace)

        for trace in survivors:
            trace['connections'] = [new_index_of[c] for c in trace['connections'] if c in new_index_of]

        self.memory[:] = survivors

    def recall(self, feature, threshold, short_term_store):
        """
        Find the stored trace nearest to ``feature`` by Euclidean distance.

        If that trace is valid and closer than ``threshold``, it and every
        trace reachable through its connections are copied into
        ``short_term_store`` (see ``_transfer_trace``).

        Parameters:
        - feature (array-like): The cue vector.
        - threshold (float): Used twice. The match must be closer than this
          (L2 distance), and only traces with strength above it are
          transferred. NOTE(review): L2 distances between raw ResNet features
          are typically far larger than 0.5, so transfers may rarely trigger.
        - short_term_store: A ShortTermStore-like object receiving the traces.

        Returns:
        - The nearest trace's data, whether or not it was transferred, or
          None if the store is empty.
        """
        query = np.asarray(feature)
        closest_index = -1
        min_distance = float('inf')

        for i, trace in enumerate(self.memory):
            dist = np.linalg.norm(query - np.asarray(trace['data']))
            if dist < min_distance:
                min_distance = dist
                closest_index = i

        if closest_index < 0:
            return None

        closest = self.memory[closest_index]
        if closest['valid'] and min_distance < threshold:
            self._transfer_trace(closest_index, short_term_store, threshold)
        return closest['data']

    def _transfer_trace(self, index, short_term_store, threshold):
        """
        Copy the trace at ``index``, plus every trace reachable through
        ``connections`` (depth-first), into ``short_term_store``.

        Only traces with strength above ``threshold`` are copied, and
        traversal does not continue past a weaker trace. Each copied trace is
        rehearsed once in the STS. The rehearsal targets the STS entry just
        inserted, not this store's index.
        """
        stack = [index]
        visited = set()

        while stack:
            current = stack.pop()
            if current in visited or not 0 <= current < len(self.memory):
                continue
            visited.add(current)
            trace = self.memory[current]
            if trace['strength'] <= threshold:
                continue

            stamp = getattr(short_term_store, 'current_time', None)
            short_term_store.add_trace(trace['data'], trace['strength'])
            # Boost the trace in STS to simulate memory reinforcement.
            position = self._find_sts_position(short_term_store, stamp)
            if position is not None:
                short_term_store.rehearse(position)

            stack.extend(trace['connections'])

    @staticmethod
    def _find_sts_position(short_term_store, stamp):
        """
        Return the heap position of the STS entry inserted at time ``stamp``,
        or None if it cannot be found.

        Assumes ShortTermStore's ``(strength, insertion_time, data)`` tuple
        layout.
        """
        if stamp is None:
            return None
        for position, entry in enumerate(short_term_store.memory):
            if entry[1] == stamp:
                return position
        return None


In [None]:
# End-to-end demo of the memory model on one image (features[0]), followed by
# a recall that uses a different image (features[1]) as the cue.

# Create the sensory register, short-term store, and long-term store
sensory_register = SensoryRegister()
short_term_store = ShortTermStore()
long_term_store = LongTermStore()

# Push information into the sensory register. A copy is stored, so later
# decay of the trace can never overwrite the shared `features` matrix.
information_unit = sensory_register.push_information(
    data=features[0].copy(), info_type="visual", attention=0.6)

# Only what the sensory register flagged for transfer goes to the short-term
# store. push_information returns None for low-attention items.
if information_unit is not None:
    short_term_store.add_trace(data=information_unit['data'], strength=information_unit['attention'])

# Apply decay to short-term store
short_term_store.decay()

# Transfer information from short-term store to long-term store
short_term_store.transfer_to_lts(long_term_store)

# Print long term store memory
print("Long-Term Store:", long_term_store.memory)

# Recall with a new cue. ShortTermStore.recall already queries the long-term
# store before searching the short-term store. A separate
# long_term_store.recall(...) call here would copy any matching traces into
# the STS twice.
new_feature = features[1]
print("Length of new features", len(features[1]))
recovered_feature = short_term_store.recall(new_feature, long_term_store)

if recovered_feature is None:
    print("Nothing recalled: the short-term store is empty")
else:
    print("Recovered Features", recovered_feature)
    print("Size of recovered feature", len(recovered_feature))

# Print the current state of the memory stores
print("Sensory Register:", sensory_register.get_all_information())
print("Short-Term Store:", short_term_store.memory)
print("Long-Term Store:", long_term_store.memory)


Long-Term Store: []
Length of new features 2048
Length of data 2048
Distance during recall in short term store 0.23807841539382935
Recovered Features [0.29308692 0.34076434 0.0606418  ... 0.5005807  0.19622953 0.16804181]
Size of recovered feature 2048
Sensory Register: [{'data': array([0.29308692, 0.34076434, 0.0606418 , ..., 0.5005807 , 0.19622953,
       0.16804181], dtype=float32), 'type': 'visual', 'attention': 0.6}]
Short-Term Store: [(0.5499999999999999, 0, array([0.29308692, 0.34076434, 0.0606418 , ..., 0.5005807 , 0.19622953,
       0.16804181], dtype=float32))]
Long-Term Store: []


In [None]:
# Unpack the i-RevNet reference implementation (pytorch-i-revnet-master) from Drive into /content/.
!unzip '/content/drive/MyDrive/I_Revnet_zip/pytorch-i-revnet-master.zip' -d /content/

Archive:  /content/drive/MyDrive/I_Revnet_zip/pytorch-i-revnet-master.zip
307413043e33540cbe9c3746ef420261f8138315
   creating: /content/pytorch-i-revnet-master/
  inflating: /content/pytorch-i-revnet-master/.gitignore  
  inflating: /content/pytorch-i-revnet-master/CIFAR_main.py  
  inflating: /content/pytorch-i-revnet-master/ILSVRC_main.py  
  inflating: /content/pytorch-i-revnet-master/LICENSE  
  inflating: /content/pytorch-i-revnet-master/README.md  
   creating: /content/pytorch-i-revnet-master/checkpoint/
   creating: /content/pytorch-i-revnet-master/checkpoint/ilsvrc2012/
   creating: /content/pytorch-i-revnet-master/checkpoint/ilsvrc2012/pre-trained/
  inflating: /content/pytorch-i-revnet-master/checkpoint/ilsvrc2012/pre-trained/info.txt  
   creating: /content/pytorch-i-revnet-master/imgs/
  inflating: /content/pytorch-i-revnet-master/imgs/algorithm.jpg  
  inflating: /content/pytorch-i-revnet-master/imgs/inverted_val_samples.jpg  
   creating: /content/pytorch-i-revnet-maste

In [None]:
def flatten_features(features):
    """
    Return ``features`` as a new 1-D numpy array, in row-major order.

    Accepts numpy arrays, nested lists/tuples, or anything ``np.array`` can
    convert. The old version required the input to have a ``.reshape``
    method. The result never shares memory with the input.
    """
    return np.array(features).reshape(-1)

# Feed one feature vector into a sensory register as visual input.
def process_through_sensory(flattened_features, sensory_register, attention):
    """
    Push ``flattened_features`` into ``sensory_register`` as 'visual' input.

    Parameters:
    - flattened_features: The feature vector to store. SensoryRegister keeps
      it by reference, so pass a copy if the caller's array must stay
      untouched.
    - sensory_register: Any object with
      ``push_information(data, info_type, attention)``, e.g. a SensoryRegister.
    - attention (float): Attention value for the item.

    Returns:
    - Whatever ``push_information`` returns. For SensoryRegister, that is the
      information unit when it should go to short-term memory, else None.
      Previously this result was discarded and None was always returned.
    """
    return sensory_register.push_information(data=flattened_features, info_type='visual', attention=attention)


# Route sensory-register contents into the short- and long-term stores.
def update_memory_models(sensory_register, short_term_store, long_term_store):
    """
    Route every item currently in the sensory register through the memory stores.

    - attention > 0.5: the item's data is added to the short-term store, with
      its attention as the trace strength.
    - attention > 0.7: right after adding it, call
      ``short_term_store.transfer_to_lts(long_term_store)``. That call moves
      traces above the STS's own strength threshold.

    NOTE(review): the register is not drained, so calling this twice adds the
    same items to the short-term store again.
    """
    for unit in sensory_register.get_all_information():
        attention = unit['attention']
        if not attention > 0.5:
            continue
        short_term_store.add_trace(unit['data'], attention)
        if attention > 0.7:
            short_term_store.transfer_to_lts(long_term_store)


# Periodic decay step for both memory stores.
def decay_memory_stores(short_term_store, long_term_store):
    """Apply one decay step to each store: short-term first, then long-term."""
    for store in (short_term_store, long_term_store):
        store.decay()


def recover_image(i_revnet_model, processed_features):
    """
    Feed ``processed_features`` through ``i_revnet_model`` without tracking
    gradients, and return the model's output.

    Side effect: the model is switched to eval mode. Device placement is the
    caller's job, so the model and the features must already be on the same
    device.

    NOTE(review): this runs the model's *forward* pass. The i-RevNet blocks
    defined in this notebook reconstruct their inputs through a separate
    ``inverse`` method. Image recovery therefore presumably needs the model's
    inverse rather than its forward. Confirm before treating the output as an
    image.

    Parameters:
    - i_revnet_model (torch.nn.Module): The trained i-RevNet model.
    - processed_features (torch.Tensor): Reshaped features, e.g. from
      ``reshape_features``.

    Returns:
    - Whatever the model's forward pass returns.
    """
    i_revnet_model.eval()
    with torch.no_grad():
        return i_revnet_model(processed_features)

def recall_and_recover(long_term_store, recall_feature, threshold, short_term_store):
    """
    Recall the stored trace closest to ``recall_feature``.

    Recall goes through ``short_term_store.recall``, which also consults
    ``long_term_store``. Image recovery via i-RevNet is not wired in yet, so
    the raw recalled features are returned.

    Parameters:
    - long_term_store: Passed through to ``short_term_store.recall``.
    - recall_feature: The cue vector.
    - threshold (float): Currently unused; kept for interface compatibility.
    - short_term_store: A ShortTermStore-like object with ``recall(feature, lts)``.

    Returns:
    - The recalled features, or None when nothing was recalled. Previously,
      an empty store made this raise ``TypeError`` via ``len(None)``.
    """
    recalled_data = short_term_store.recall(recall_feature, long_term_store)
    if recalled_data is None or len(recalled_data) == 0:
        return None
    # Future step: recover_image(i_revnet_model, reshape_features(np.array(recalled_data)))
    return recalled_data

def reshape_features(features, channels=2048, height=7, width=7):
    """
    Reshape a flat feature vector into a ``(1, channels, height, width)`` map.

    The defaults give 2048 * 7 * 7 = 100352 values, the per-image length
    produced when ResNet-50's avgpool is replaced by Identity. ``features``
    must support ``.reshape`` (numpy array or torch tensor) and contain
    exactly ``channels * height * width`` elements.
    """
    return features.reshape(1, channels, height, width)


In [None]:
# Full pipeline: push the extracted images through
# sensory register -> short-term store -> long-term store, apply one decay
# step, then recall using image 0 as the cue.

# Extract features
features_new = extract_features(resnet50, data_loader)

# Create the sensory register, short-term store, and long-term store. Set decay rates as required.
sensory_register_final = SensoryRegister()
short_term_store_final = ShortTermStore(decay_rate=0.05)
long_term_store_final = LongTermStore(decay_rate=0.0005)

# (image index, attention) pairs pushed into the sensory register.
# NOTE(review): image 4 is pushed twice (attention 0.2 and 0.1) and image 5
# never is, although the comparison cell later uses index 5 as the "dog"
# image. Confirm whether the last pair was meant to be (5, 0.1).
SENSORY_INPUTS = [(0, 0.53), (1, 0.5), (2, 0.56), (3, 0.75), (4, 0.2), (4, 0.1)]
for image_index, attention in SENSORY_INPUTS:
    # Store copies, so that decaying a stored trace can never overwrite
    # features_new, which is also the recall cue below.
    process_through_sensory(features_new[image_index].copy(), sensory_register_final, attention)

# Print content stored in sensory register
print(sensory_register_final.get_all_information())
# Update short term store and long term store to hold the traces
update_memory_models(sensory_register_final, short_term_store_final, long_term_store_final)

# Print memories in both stores
print("Short-Term Store:", short_term_store_final.memory)
print("Long-Term Store:", long_term_store_final.memory)
# Decay the memories in the short term store and long term store
decay_memory_stores(short_term_store_final, long_term_store_final)

# Print memories in both stores after decay
print("Short-Term Store:", short_term_store_final.memory)
print("Long-Term Store:", long_term_store_final.memory)

# Recall from the stores built in THIS cell when the same image is seen again.
# This previously passed the demo cell's `short_term_store`, so the result did
# not come from this pipeline.
encoded_features = recall_and_recover(long_term_store_final, features_new[0], 0.5, short_term_store_final)


Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
Shape of feature torch.Size([1, 2048])
[{'data': array([0.29308692, 0.34076434, 0.0606418 , ..., 0.5005807 , 0.19622953,
       0.16804181], dtype=float32), 'type': 'visual', 'attention': 0.53}, {'data': array([0.17378132, 0.15040462, 0.12912093, ..., 0.29789406, 0.12450052,
       0.17119877], dtype=float32), 'type': 'visual', 'attention': 0.5}, {'data': array([0.06553076, 0.22242767, 0.08264972, ..., 0.27547953, 0.13562112,
       0.14316227], dtype=float32), 'type': 'visual', 'attention': 0.56}, {'data': array([0.1805996 , 0.5667376 , 1.1920793 , ..., 0.64303255, 0.03507868,
       0.10211677], dtype=float32), 'type': 'visual', 'attention': 0.75}, {'data': array([0.4666547 , 0.17642798, 0.10143929, ..., 0.37027115, 0.11223205,
       0.10707459], dtype=float32), 'type': 'visual', 'attention'

In [None]:
# Compare the recalled trace with the cue image (index 0, "cat") and with
# image index 5 ("dog"), using cosine distance and Euclidean (L2) distance.
# encoded_features is None if nothing was recalled, and then the first line
# raises AttributeError.
# NOTE(review): the printed labels mix numbering schemes. "trace 1(cat)" is
# index 0, while "trace 5(dog)" is index 5.
print(encoded_features.shape)
print(encoded_features)
print(features_new[0])

cat_trace = features_new[0]
cos_dist_encode = cosine_similarity(encoded_features, cat_trace)
print("Cosine distance of retrieved memory trace with input trace 1(cat)", cos_dist_encode)
l2_norm_encode = l2_norm(encoded_features, cat_trace)
print("L2 Norm of retrieved memory trace with input trace 1(cat)", l2_norm_encode)

dog_trace = features_new[5]
cos_dist_encode2 = cosine_similarity(encoded_features, dog_trace)
print("Cosine distance of retrieved memory trace with input trace 5(dog)", cos_dist_encode2)
l2_norm_encode2 = l2_norm(encoded_features, dog_trace)
print("L2 Norm of retrieved memory trace with input trace 5(dog)", l2_norm_encode2)


(2048,)
[0.29308692 0.34076434 0.0606418  ... 0.5005807  0.19622953 0.16804181]
[0.29308692 0.34076434 0.0606418  ... 0.5005807  0.19622953 0.16804181]
Cosine distance of retrieved memory trace with input trace 1(cat) 0.023495972156524658
L2 Norm of retrieved memory trace with input trace 1(cat) 5.869805335998535
Cosine distance of retrieved memory trace with input trace 5(dog) 0.332078218460083
L2 Norm of retrieved memory trace with input trace 5(dog) 23.480634689331055


In [None]:
# Run the i-RevNet reference script in inversion mode (--invert) with the
# pretrained ILSVRC checkpoint.
# NOTE(review): `--data encoded_features` passes the literal string
# "encoded_features" to the script, not the Python variable from the cell
# above. IPython only interpolates `{...}`/`$...` in `!` commands, and a
# command line can only carry text, not an array. Confirm what `--data`
# expects in ILSVRC_main.py (presumably a dataset directory) before relying
# on this output.
!python /content/pytorch-i-revnet-master/ILSVRC_main.py --data encoded_features --nBlocks 6 16 72 6 --nStrides 2 2 2 2 \
                      --nChannels 24 96 384 1536 --init_ds 2 \
                      --resume /content/drive/MyDrive/I_Revnet_zip/ILSVRC_trained_irevnet.pth.tar --invert

In [None]:
#I_Revnet Model

"""
Code for "i-RevNet: Deep Invertible Networks"
https://openreview.net/pdf?id=HJsjkMb0Z
ICLR, 2018

(c) Joern-Henrik Jacobsen, 2018
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
#from .model_utils import split, merge, injective_pad, psi


class irevnet_block(nn.Module):
    """ One invertible i-RevNet bottleneck block.

    The block operates on a pair of tensors (x1, x2). Forward computes
    y1 = F(x2) + x1 and returns (x2, y1). Inverse undoes this with
    x1 = y1 - F(x2).

    NOTE(review): this relies on the helpers `split`, `merge`,
    `injective_pad` and `psi`. Their import (`from .model_utils import ...`)
    is commented out in this cell, so they must be defined or imported before
    this class is instantiated, otherwise __init__ raises NameError.
    """
    def __init__(self, in_ch, out_ch, stride=1, first=False, dropout_rate=0.,
                 affineBN=True, mult=4):
        """ Build the invertible bottleneck block.

        in_ch / out_ch: input / output channel counts. When they differ and
            stride == 1, the block becomes injective: 2 * out_ch - in_ch
            extra channels are added via injective_pad.
        stride: 1 or 2. With 2, the first conv is strided and both halves are
            passed through psi (presumably an invertible space-to-depth
            reshuffle; TODO confirm against model_utils).
        first: if True, the leading BatchNorm + ReLU of the bottleneck is
            omitted (used for the first block of the network).
        dropout_rate: dropout probability inside the bottleneck.
        affineBN: whether the BatchNorm layers have learnable affine params.
        mult: bottleneck width divisor; the inner convs use out_ch // mult
            channels.
        """
        super(irevnet_block, self).__init__()
        self.first = first
        # Number of channels to inject in the injective (widening) case.
        self.pad = 2 * out_ch - in_ch
        self.stride = stride
        self.inj_pad = injective_pad(self.pad)
        self.psi = psi(stride)
        if self.pad != 0 and stride == 1:
            # After padding, the input already has out_ch * 2 channels.
            in_ch = out_ch * 2
            print('')
            print('| Injective iRevNet |')
            print('')
        # Bottleneck F: (BN-ReLU) -> conv3x3 -> BN-ReLU -> conv3x3 -> dropout
        # -> BN-ReLU -> conv3x3. It acts on one half (in_ch // 2 channels).
        layers = []
        if not first:
            layers.append(nn.BatchNorm2d(in_ch//2, affine=affineBN))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_ch//2, int(out_ch//mult), kernel_size=3,
                      stride=stride, padding=1, bias=False))
        layers.append(nn.BatchNorm2d(int(out_ch//mult), affine=affineBN))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(int(out_ch//mult), int(out_ch//mult),
                      kernel_size=3, padding=1, bias=False))
        layers.append(nn.Dropout(p=dropout_rate))
        layers.append(nn.BatchNorm2d(int(out_ch//mult), affine=affineBN))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(int(out_ch//mult), out_ch, kernel_size=3,
                      padding=1, bias=False))
        self.bottleneck_block = nn.Sequential(*layers)

    def forward(self, x):
        """ Bijective or injective block forward.

        x: pair (x1, x2). Returns (x2, y1) with y1 = F(x2) + x1. When
        stride == 2, x1 and x2 are passed through psi after F(x2) is
        computed.
        """
        if self.pad != 0 and self.stride == 1:
            # Injective case: merge the halves, pad channels, split again.
            x = merge(x[0], x[1])
            x = self.inj_pad.forward(x)
            x1, x2 = split(x)
            x = (x1, x2)
        x1 = x[0]
        x2 = x[1]
        Fx2 = self.bottleneck_block(x2)
        if self.stride == 2:
            x1 = self.psi.forward(x1)
            x2 = self.psi.forward(x2)
        y1 = Fx2 + x1
        return (x2, y1)

    def inverse(self, x):
        """ Bijective or injective block inverse.

        x: pair (x2, y1) as produced by forward. Recovers (x1, x2) with
        x1 = y1 - F(x2), undoing psi and the injective padding where forward
        applied them.
        """
        x2, y1 = x[0], x[1]
        if self.stride == 2:
            x2 = self.psi.inverse(x2)
        Fx2 = - self.bottleneck_block(x2)
        x1 = Fx2 + y1
        if self.stride == 2:
            x1 = self.psi.inverse(x1)
        if self.pad != 0 and self.stride == 1:
            # Undo the injective padding applied in forward.
            x = merge(x1, x2)
            x = self.inj_pad.inverse(x)
            x1, x2 = split(x)
            x = (x1, x2)
        else:
            x = (x1, x2)
        return x


class iRevNet(nn.Module):
    """Invertible RevNet built from a stack of ``irevnet_block`` modules.

    ``forward`` returns both the class logits and the merged bijective
    representation ``out_bij``; ``inverse`` maps such a representation back
    through the block stack (and the initial ``psi``) to input space.

    Args:
        nBlocks: number of blocks per stage.
        nStrides: stride of the first block of each stage (later blocks use 1).
        nClasses: output size of the final linear layer.
        nChannels: per-stage block width. Defaults to ``in_ch//2 * 4**i`` for
            stage ``i``, with one entry per stage in ``nBlocks``.
        init_ds: argument of the initial ``psi`` (0 skips it in forward/inverse).
        dropout_rate, affineBN, mult: forwarded to every block.
        in_shape: input shape, e.g. ``[3, 224, 224]``. Required: ``in_shape[0]``
            sets the channel count and ``in_shape[2]`` the pooling size.

    Raises:
        ValueError: if ``in_shape`` is not given.
    """

    def __init__(self, nBlocks, nStrides, nClasses, nChannels=None, init_ds=2,
                 dropout_rate=0., affineBN=True, in_shape=None, mult=4):
        super(iRevNet, self).__init__()
        if in_shape is None:
            # Both the channel count and the final pooling size are derived
            # from in_shape, so the model cannot be built without it.
            raise ValueError('iRevNet requires in_shape, e.g. [3, 224, 224]')
        # Average-pooling kernel for the head: in_shape[2] divided by 2 per
        # stride-2 stage and by 2**(init_ds//2) for the initial psi.
        # Assumes forward() is called with inputs matching in_shape.
        self.ds = in_shape[2]//2**(nStrides.count(2)+init_ds//2)
        self.init_ds = init_ds
        self.in_ch = in_shape[0] * 2**self.init_ds
        self.nBlocks = nBlocks
        self.first = True

        print('')
        print(' == Building iRevNet %d == ' % (sum(nBlocks) * 3 + 1))
        if not nChannels:
            # One width per stage so that nChannels[-1] (used to size the head
            # below) is the width of the stage that is actually built last.
            nChannels = self._default_channels(self.in_ch//2, len(nBlocks))

        self.init_psi = psi(self.init_ds)
        self.stack = self.irevnet_stack(irevnet_block, nChannels, nBlocks,
                                        nStrides, dropout_rate=dropout_rate,
                                        affineBN=affineBN, in_ch=self.in_ch,
                                        mult=mult)
        # The merged pair at the end of the stack has 2 * (last stage width)
        # channels, mirroring `in_ch = 2 * channel` in irevnet_stack.
        self.bn1 = nn.BatchNorm2d(nChannels[-1]*2, momentum=0.9)
        self.linear = nn.Linear(nChannels[-1]*2, nClasses)

    @staticmethod
    def _default_channels(base, n_stages):
        """Default stage widths: ``base`` growing by a factor of 4 per stage."""
        return [base * 4**i for i in range(n_stages)]

    def irevnet_stack(self, _block, nChannels, nBlocks, nStrides, dropout_rate,
                      affineBN, in_ch, mult):
        """Create the stack of irevnet blocks.

        Each stage contributes ``depth`` blocks of width ``channel``; only the
        first block of a stage uses the stage stride, the rest use 1. The
        current ``self.first`` is passed to the first block built, after which
        ``self.first`` is set to False.

        Returns:
            nn.ModuleList with the constructed blocks in order.
        """
        block_list = nn.ModuleList()
        strides = []
        channels = []
        for channel, depth, stride in zip(nChannels, nBlocks, nStrides):
            strides.extend([stride] + [1]*(depth-1))
            channels.extend([channel]*depth)
        for channel, stride in zip(channels, strides):
            block_list.append(_block(in_ch, channel, stride,
                                     first=self.first,
                                     dropout_rate=dropout_rate,
                                     affineBN=affineBN, mult=mult))
            # The next block consumes the pair produced by this block.
            in_ch = 2 * channel
            self.first = False
        return block_list

    def forward(self, x):
        """irevnet forward.

        Returns:
            Tuple ``(logits, out_bij)``: the linear-head output and the merged
            output of the block stack (the input to ``inverse``).
        """
        n = self.in_ch//2
        if self.init_ds != 0:
            x = self.init_psi.forward(x)
        # Split the channels into the two halves processed by the blocks.
        out = (x[:, :n, :, :], x[:, n:, :, :])
        for block in self.stack:
            out = block.forward(out)
        out_bij = merge(out[0], out[1])
        out = F.relu(self.bn1(out_bij))
        out = F.avg_pool2d(out, self.ds)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out, out_bij

    def inverse(self, out_bij):
        """irevnet inverse: map a bijective representation back to input space."""
        out = split(out_bij)
        # Undo the blocks in reverse order.
        for i in range(len(self.stack)):
            out = self.stack[-1-i].inverse(out)
        out = merge(out[0], out[1])
        if self.init_ds != 0:
            x = self.init_psi.inverse(out)
        else:
            x = out
        return x


if __name__ == '__main__':
    # Smoke test: build the ImageNet-sized i-RevNet and push one random image
    # through it. Inside a notebook __name__ is always '__main__', so this
    # runs every time the cell is executed.
    model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2],
                    nChannels=[24, 96, 384, 1536], nClasses=1000, init_ds=2,
                    dropout_rate=0., affineBN=True, in_shape=[3, 224, 224],
                    mult=4)
    # torch.autograd.Variable is a deprecated pass-through; plain tensors work.
    # no_grad avoids holding the autograd graph of this deep network in memory.
    with torch.no_grad():
        y = model(torch.randn(1, 3, 224, 224))
    # forward returns (logits, out_bij), so report the size of each element.
    print([t.size() for t in y])


# Create an instance of the i-RevNet model with the configuration of the
# ILSVRC checkpoint loaded below.
i_revnet_model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2],
                         nChannels=[24, 96, 384, 1536], nClasses=1000, init_ds=2,
                         dropout_rate=0., affineBN=True, in_shape=[3, 224, 224])


# Load the weights from the .pth.tar file. Set the IREVNET_CHECKPOINT
# environment variable to point at your copy instead of editing this path.
model_path = os.environ.get(
    'IREVNET_CHECKPOINT',
    '/content/drive/MyDrive/I_Revnet_zip/ILSVRC_trained_irevnet.pth.tar')
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']
# Checkpoints saved from an nn.DataParallel wrapper prefix every key with
# 'module.'; strip it so the keys match this unwrapped model.
if state_dict and all(key.startswith('module.') for key in state_dict):
    state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
i_revnet_model.load_state_dict(state_dict)

# If the file also contains optimizer state, you can load it similarly
# optimizer.load_state_dict(checkpoint['optimizer'])

In [None]:
def flatten_features(features):
    """Collapse ``features`` into a single dimension.

    Works for any array-like exposing a numpy/torch style ``reshape``; the
    result has shape ``(features.size,)`` (numpy) / ``(features.numel(),)`` (torch).
    """
    flat_shape = (-1,)
    return features.reshape(flat_shape)

# Assuming sensory_register is an instance of SensoryRegister
def process_through_sensory(features, sensory_register, attention):
    """Flatten ``features`` and hand them to the sensory register as visual input.

    ``sensory_register`` must provide
    ``push_information(data=..., info_type=..., attention=...)``; nothing is returned.
    """
    flat = flatten_features(features)
    sensory_register.push_information(
        data=flat,
        info_type='visual',
        attention=attention,
    )


# This function simulates the decision-making for transferring data between the registers
def update_memory_models(sensory_register, short_term_store, long_term_store,
                         sts_threshold=0.5, lts_threshold=0.7):
    """Route sensory items into the short- and long-term stores by attention.

    Every item returned by ``sensory_register.get_all_information()`` whose
    ``'attention'`` is strictly greater than ``sts_threshold`` is added to the
    short-term store; items strictly above ``lts_threshold`` are added to the
    long-term store. The two checks are independent, so with the defaults an
    item with attention above 0.7 is written to both stores.

    Parameters:
    - sensory_register: provides ``get_all_information()`` yielding dicts with
      ``'data'`` and ``'attention'`` keys.
    - short_term_store, long_term_store: provide ``add_trace(data, attention)``.
    - sts_threshold (float): attention bound for the short-term store (default 0.5).
    - lts_threshold (float): attention bound for the long-term store (default 0.7).
    """
    for info in sensory_register.get_all_information():
        attention = info['attention']
        if attention > sts_threshold:
            short_term_store.add_trace(info['data'], attention)
        # Placeholder consolidation rule: sufficiently attended items also go to LTS.
        if attention > lts_threshold:
            long_term_store.add_trace(info['data'], attention)


# Periodic update functions to decay memory strengths
def decay_memory_stores(short_term_store, long_term_store):
    """Apply one ``decay()`` step to each store, short-term store first."""
    for store in (short_term_store, long_term_store):
        store.decay()


def recover_image(i_revnet_model, processed_features):
    """
    Uses the i-RevNet model to recover an image from the processed features.

    The features are expected to be a bijective representation like the
    ``out_bij`` returned by ``iRevNet.forward``; the image is obtained by
    running the model's ``inverse``. (Calling the model directly would run
    the forward classification pass and return ``(logits, out_bij)``.)

    Parameters:
    - i_revnet_model (torch.nn.Module): The trained i-RevNet model for image recovery.
      Must provide ``eval()`` and ``inverse()``.
    - processed_features (torch.Tensor or array-like): The reshaped features from which
      to recover the image. Non-tensor inputs (e.g. numpy arrays) are converted
      to a float32 tensor.

    Returns:
    - torch.Tensor: The recovered image.
    """
    # Ensure the model is in evaluation mode (BatchNorm uses running statistics).
    i_revnet_model.eval()

    if not torch.is_tensor(processed_features):
        # e.g. numpy input from recall; float64 arrays would not match float32 weights.
        processed_features = torch.as_tensor(processed_features, dtype=torch.float32)

    # NOTE: features are not moved to the model's device; when running on GPU,
    # move both the model and the features to the same device before calling.
    with torch.no_grad():  # Inversion is inference only; no gradients needed.
        recovered_image = i_revnet_model.inverse(processed_features)

    # Post-processing (denormalisation, clipping, ...) is left to the caller.
    return recovered_image

def recall_and_recover(i_revnet_model, long_term_store, recall_feature, threshold):
    """Recall features from long-term memory and invert them into an image.

    Parameters:
    - i_revnet_model: model passed to ``recover_image``.
    - long_term_store: provides ``recall(recall_feature, threshold)``; assumed to
      return None or a sequence/array of stored feature vectors — TODO confirm.
    - recall_feature: query passed to ``recall``.
    - threshold: similarity threshold passed to ``recall``.

    Returns:
    - The recovered image, or None when nothing was recalled.
    """
    recalled_data = long_term_store.recall(recall_feature, threshold)
    # Explicit emptiness check: `if recalled_data:` raises "truth value is
    # ambiguous" when recall returns a numpy array with several elements.
    if recalled_data is None or len(recalled_data) == 0:
        return None
    reshaped_features = reshape_features(np.array(recalled_data))
    return recover_image(i_revnet_model, reshaped_features)

def reshape_features(features, channels=3072, height=7, width=7):
    """Reshape stored features into a batch of i-RevNet bijective feature maps.

    The defaults match the ILSVRC i-RevNet built in this notebook: the merged
    output has ``2 * 1536 = 3072`` channels and ``224 // 2**5 = 7`` spatial size.
    The batch dimension is inferred, so any number of stored feature vectors
    can be reshaped as long as the total size is a multiple of
    ``channels * height * width`` (otherwise ``reshape`` raises).

    Returns:
    - Array/tensor of shape ``(batch, channels, height, width)``.
    """
    return features.reshape(-1, channels, height, width)


In [None]:
# Import the Keras ResNet50 convolutional base with ImageNet weights (no top classifier).
# NOTE(review): Keras reads input_shape as (height, width, channels); this passes
# (img_width, img_height, 3), which only matters for non-square images — confirm
# it matches the target_size used when loading the data.
base_model = ResNet50(
    weights='imagenet',
    include_top=False,
    input_shape=(img_width, img_height, 3),
)

In [None]:
# Print a summary of the layers of the ResNet50 base model built above.
base_model.summary()