In [None]:
import numpy as np
import torch
from sklearn.cluster import KMeans
from torchvision.models import resnet18

def dft(weights):
    """Return the real and imaginary parts of the 2-D DFT of every kernel slice.

    The 2-D FFT is taken over the last two axes of ``weights``. For a conv
    weight of shape ``(out_channels, in_channels, kH, kW)`` this transforms
    every ``kH x kW`` kernel independently, exactly like looping over the
    first two axes and calling ``np.fft.fft2`` on each slice.

    Args:
        weights: array-like or ``torch.Tensor`` with at least two dimensions.
            Tensors are detached and copied to CPU before conversion, so
            parameters with ``requires_grad=True`` and CUDA tensors are
            accepted.

    Returns:
        ``(real, imag)``: two NumPy arrays with the same shape as ``weights``.
    """
    if isinstance(weights, torch.Tensor):
        # np.asarray refuses tensors that require grad and cannot read CUDA
        # memory, so hand NumPy a detached CPU copy.
        weights = weights.detach().cpu().numpy()
    # np.fft.fft2 transforms axes (-2, -1) and broadcasts over all leading
    # axes, replacing the original per-slice Python double loop.
    spectrum = np.fft.fft2(np.asarray(weights))
    return np.real(spectrum), np.imag(spectrum)

def get_dfts_models(n_models=20,
                    path_template="/kaggle/input/models/models/model_{i}_{j}.pt"):
    """Collect the real part of the conv1 DFT for every saved pairwise model.

    Loads one checkpoint per unordered pair ``(i, j)`` with
    ``0 <= i < j < n_models`` (190 files for the default of 20), in the order
    (0, 1), (0, 2), ..., (n_models - 2, n_models - 1). From each it takes
    ``"conv1.weight"``, computes its per-kernel 2-D DFT with ``dft`` and keeps
    only the real part, flattened to one row per output filter.

    Args:
        n_models: number of model indices; every pair ``i < j`` below it is
            loaded.
        path_template: ``str.format`` template with ``{i}`` and ``{j}``
            placeholders giving each checkpoint path.

    Returns:
        NumPy array of shape ``(n_pairs, out_channels, in_channels * kH * kW)``.
        For a 64x3x7x7 conv1 this is ``(n_pairs, 64, 147)``.
    """
    dfts = []
    for i in range(n_models - 1):
        for j in range(i + 1, n_models):
            # NOTE: torch.load unpickles the file; only use it on trusted
            # checkpoints. Presumably each file is a state_dict — confirm.
            state = torch.load(path_template.format(i=i, j=j),
                               map_location=torch.device('cpu'))
            real_part, _ = dft(state["conv1.weight"])
            # One row per output filter; the row width is inferred instead of
            # hard-coding 3*7*7.
            dfts.append(np.reshape(real_part, (real_part.shape[0], -1)))
    return np.array(dfts)

def inverse_dft(cluster_results, in_channels=3, kernel_size=7):
    """Map flattened per-filter spectra back to spatial convolution kernels.

    Each row of ``cluster_results`` is read as ``in_channels`` consecutive
    blocks of ``kernel_size * kernel_size`` values, i.e. a row-major
    flattening of an ``(in_channels, kernel_size, kernel_size)`` spectrum.
    Each block is reshaped to ``(kernel_size, kernel_size)`` and
    inverse-2-D-FFT'd, and only the real part is kept. Entries beyond
    ``in_channels * kernel_size**2`` in a row are ignored.

    NOTE: when the rows hold only the *real* part of a forward ``fft2`` (as
    produced by ``get_dfts_models``), the output is not the original kernel.
    For a real kernel ``x``, the real part of ``fft2(x)`` is the transform of
    the even part of ``x``, so the result is
    ``(x[n, m] + x[-n % k, -m % k]) / 2`` and the antisymmetric component is
    lost. An exact round trip needs the imaginary part too.

    Args:
        cluster_results: 2-D array of shape ``(n_kernels, width)`` with
            ``width >= in_channels * kernel_size**2``. Real or complex values.
        in_channels: number of input-channel blocks per row.
        kernel_size: spatial side length of each square kernel.

    Returns:
        Real array of shape ``(n_kernels, in_channels, kernel_size, kernel_size)``.

    Raises:
        ValueError: if a row is shorter than ``in_channels * kernel_size**2``.
    """
    rows = np.asarray(cluster_results)
    width = in_channels * kernel_size * kernel_size
    # Slice before reshaping so wider rows behave like the original
    # per-channel slicing, which ignored trailing entries.
    spectra = np.reshape(rows[:, :width],
                         (rows.shape[0], in_channels, kernel_size, kernel_size))
    # ifft2 works on the last two axes, so every kernel is inverted at once.
    return np.real(np.fft.ifft2(spectra))

# Pipeline: pool the per-filter conv1 spectra of every pairwise model, cluster
# them into 64 centroids (one per conv1 output filter of ResNet-18), map the
# centroids back to spatial kernels and install them in a fresh 10-class
# ResNet-18.
dfts = get_dfts_models()
# (n_pairs, 64, 147) -> (n_pairs * 64, 147): every filter spectrum is a sample.
dfts_r = np.reshape(dfts, (-1, dfts.shape[2]))
# random_state pins the centroid initialisation so runs are reproducible.
kmeans = KMeans(n_clusters=64, n_init='auto', random_state=0).fit(dfts_r)
final_weights = inverse_dft(kmeans.cluster_centers_)
# num_classes must be passed by keyword. Positionally, 10 lands in `weights`,
# which newer torchvision rejects with TypeError. In older torchvision it lands
# in `pretrained`, which silently loads 1000-class ImageNet weights.
crn = resnet18(num_classes=10)
# KMeans/NumPy produce float64, but the model's parameters are float32. Cast to
# the layer's dtype so the forward pass does not fail with a dtype mismatch.
crn.conv1.weight = torch.nn.Parameter(
    torch.tensor(final_weights, dtype=crn.conv1.weight.dtype)
)