In [2]:
import infrastructure as inf
import numpy as np
import torch.nn as nn

import torch 
torch.manual_seed(42)

from torchvision.models import resnet18



In [3]:
def models_to_filter_per_layer(models):
    """Collect convolution weights from several ResNet-18-style models, grouped per layer.

    Layer keys are the strings "0".."16" in forward order:
      * "0"     -> the stem ``model.conv1``;
      * "1".."16" -> for each stage ``layer1``..``layer4`` and for blocks 0 and 1
        of that stage, the block's ``conv1`` followed by its ``conv2``.
    Shortcut/downsample convolutions and the ``fc`` layer are not collected.
    Only blocks 0 and 1 of each stage are read (the ResNet-18 layout).

    :param models: sequence of models exposing ``conv1`` and ``layer1``..``layer4``
        (each indexable by block, with blocks exposing ``conv1``/``conv2``).
    :return: dict mapping layer key -> ``np.array`` built from the list of that
        layer's weights, one entry per model (so a leading ``len(models)`` axis
        when all models share the same weight shape).
    """
    print("Using",len(models),"models to create filters")

    # (stage attribute, block index, conv attribute) for every layer after the
    # stem, in the exact order the keys "1".."16" are assigned.
    block_convs = [
        (stage, block_idx, conv_name)
        for stage in ("layer1", "layer2", "layer3", "layer4")
        for block_idx in (0, 1)
        for conv_name in ("conv1", "conv2")
    ]

    def _weights(conv):
        # Move the raw weight tensor to the CPU and view it as a NumPy array.
        return conv.weight.data.cpu().numpy()

    filters_per_layer = {"0": np.array([_weights(model.conv1) for model in models])}
    for key, (stage, block_idx, conv_name) in enumerate(block_convs, start=1):
        filters_per_layer[str(key)] = np.array([
            _weights(getattr(getattr(model, stage)[block_idx], conv_name))
            for model in models
        ])

    return filters_per_layer

In [17]:
from sklearn.cluster import KMeans
def dft(weights):
    """Run a 2-D FFT on every entry along the first axis of ``weights``.

    :param weights: array indexed along axis 0; each ``weights[i]`` should be at
        least 2-D, since ``np.fft.fft2`` transforms its last two axes.
    :return: ``(real, imag)`` — real and imaginary parts of the spectra, stacked
        along a new leading axis.
    """
    spectra = np.array([np.fft.fft2(weights[idx]) for idx in range(weights.shape[0])])
    return spectra.real, spectra.imag

def inverse_dft(cluster_results):
    """Run an inverse 2-D FFT on every entry along the first axis.

    :param cluster_results: array of spectra indexed along axis 0; each entry
        should be at least 2-D, since ``np.fft.ifft2`` transforms its last two axes.
    :return: ``(real, imag)`` — real and imaginary parts of the spatial-domain
        results, stacked along a new leading axis.
    """
    spatial = np.array(
        [np.fft.ifft2(cluster_results[idx]) for idx in range(cluster_results.shape[0])]
    )
    return spatial.real, spatial.imag

def clustering_fourier_single_layer(filters, num_clusters, random_state=None):
    """Cluster 2-D filter slices with k-means on the real part of their 2-D DFT.

    Every trailing ``(kh, kw)`` slice of ``filters`` is one sample. Typical use is
    ``filters = filters_per_layer[layer_key]`` from ``models_to_filter_per_layer``,
    i.e. shape ``(n_models, out_ch, in_ch, kh, kw)``, giving
    ``n_models * out_ch * in_ch`` samples.

    NOTE(review): only the real part of each spectrum is used as the feature, and
    only the real part of the inverse DFT is returned. For real-valued filters the
    real part of the DFT corresponds to the circularly even component
    ``(x[n] + x[-n mod N]) / 2``, so the returned centroids are even-symmetric and
    any odd component (e.g. sine-phase edge detectors) is dropped. Confirm this
    is intended.

    :param filters: array whose last two axes are the spatial filter axes.
    :param num_clusters: number of k-means clusters.
    :param random_state: forwarded to ``KMeans``; ``None`` (default) keeps
        sklearn's default of drawing from the global NumPy RNG.
    :return: ``(cluster_labels, final_weights)``: one cluster label per slice, and
        the centroids mapped back to the spatial domain, shape
        ``(num_clusters, kh, kw)``.
    """
    height, width = filters.shape[-2:]
    slices = np.reshape(filters, (-1, height, width))
    dfts, _ = dft(slices)
    features = np.reshape(dfts, (dfts.shape[0], height * width))
    kmeans = KMeans(n_clusters=num_clusters, n_init='auto', random_state=random_state)
    cluster_labels = kmeans.fit_predict(features)
    # There is one centroid per cluster (not one per input slice), so the
    # leading dimension must come from cluster_centers_ itself.
    centroid_spectra = np.reshape(kmeans.cluster_centers_, (-1, height, width))
    final_weights, _ = inverse_dft(centroid_spectra)
    return cluster_labels, final_weights

In [18]:
# Load a stacked filter array from disk (presumably filters_per_layer["0"]
# saved earlier — TODO confirm provenance in a markdown cell).
filters = np.load("saved_filters_for_key0.npy")
# Merge the three leading axes so each row is one (kh, kw) filter slice.
filters_reshaped = filters.reshape(
    filters.shape[0] * filters.shape[1] * filters.shape[2],
    filters.shape[3],
    filters.shape[4],
)
dfts, _ = dft(filters_reshaped)
# Shape of the k-means feature matrix: one flattened real spectrum per slice.
dfts.reshape(dfts.shape[0], dfts.shape[1] * dfts.shape[2]).shape

(384, 49)

In [16]:
def generate_gabor_filter(size, sigma, theta, Lambda, psi, gamma):
    """
    Generates a Gabor filter with given parameters.
    :param size: Size of the filter (size x size), a positive integer.
    :param sigma: Standard deviation of the Gaussian envelope (along x_theta).
    :param theta: Orientation of the Gabor filter, in radians.
    :param Lambda: Wavelength of the sinusoidal factor, in pixels.
    :param psi: Phase offset, in radians.
    :param gamma: Spatial aspect ratio (the envelope std-dev along y_theta is sigma / gamma).
    :return: Gabor filter as a 2D array of shape (size, size), centred on the
             middle of the array (rows index y, columns index x).
    """
    sigma_x = sigma
    sigma_y = sigma / gamma

    # Pixel-centred grid with unit spacing, symmetric about 0 (e.g. -3..3 for
    # size 7, -2.5..2.5 for size 6). Note that `-size // 2` would parse as
    # `(-size) // 2`, giving an asymmetric, non-unit-spaced grid.
    coords = np.arange(size) - (size - 1) / 2
    x, y = np.meshgrid(coords, coords)

    # Rotation
    x_theta = x * np.cos(theta) + y * np.sin(theta)
    y_theta = -x * np.sin(theta) + y * np.cos(theta)

    envelope = np.exp(-.5 * (x_theta ** 2 / sigma_x ** 2 + y_theta ** 2 / sigma_y ** 2))
    carrier = np.cos(2 * np.pi / Lambda * x_theta + psi)
    return envelope * carrier

def initialize_model_gabor(num_classes=10):
    """Build a ResNet-18 whose stem convolution is initialised with Gabor filters.

    Output channel ``i`` of ``conv1`` gets a Gabor filter whose parameters are
    drawn from a generator seeded with ``i``. Every call therefore produces the
    same filter bank (identical across models), while channels within one model
    differ. The filter is copied to every input channel of that output channel.

    :param num_classes: number of classes passed to ``resnet18`` (default 10).
    :return: the initialised ``resnet18`` model.
    """
    model = resnet18(num_classes=num_classes)
    conv1 = model.conv1
    kernel_size = conv1.kernel_size[0]  # 7 for torchvision's resnet18 stem

    for i in range(conv1.out_channels):
        # A private RandomState(i) yields exactly the draws np.random.seed(i)
        # would, without leaving the global NumPy RNG reset to a fixed seed
        # that later global-RNG users (e.g. KMeans with random_state=None)
        # would silently inherit.
        rng = np.random.RandomState(i)
        sigma = rng.uniform(1.5, 2.5)      # std-dev of the Gaussian envelope
        theta = rng.uniform(0, np.pi)      # orientation
        Lambda = rng.uniform(2, 13)        # wavelength of the sinusoid
        psi = rng.uniform(0, 2 * np.pi)    # phase offset
        gamma = rng.uniform(0.9, 1.1)      # spatial aspect ratio
        # The filter does not depend on the input channel: build it once and
        # broadcast it across all in_channels of output channel i.
        gabor_filter = generate_gabor_filter(kernel_size, sigma, theta, Lambda, psi, gamma)
        conv1.weight.data[i, :, :, :] = torch.from_numpy(gabor_filter)

    return model