In [1]:
# from functools import partial
from pathlib import Path
# from typing import Optional, Tuple
# import cv2
# import fire
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from PIL import Image
from scipy.sparse.linalg import eigsh
# from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA
# from torchvision.utils import draw_bounding_boxes
from tqdm import tqdm
import extract_utils as utils
from torch.utils.data import Dataset, DataLoader
from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead
from torch import nn
import torchvision
from scipy.ndimage import affine_transform
import datetime, dateutil

## Extract Eigenvectors

In [2]:
# --- Dataset locations ------------------------------------------------------
# Source images (only read when colour affinities are enabled).
images_root = "/home/phdcs2/Hard_Disk/Datasets/Deep-Spectral-Segmentation/data/object-segmentation/ECSSD/images"
# One saved feature dict per image; every entry of this directory is processed.
features_dir = "/home/phdcs2/Hard_Disk/Datasets/Deep-Spectral-Segmentation/data/object-segmentation/ECSSD/features/dino_vits16"
# Destination for the per-image eigenvalue / eigenvector files.
output_dir = "/home/phdcs2/Hard_Disk/Datasets/Deep-Spectral-Segmentation/data/object-segmentation/ECSSD/eigs_dot1PCApred_ds_model_10_jn"

# --- Spectral decomposition options -----------------------------------------
K = 5                          # number of eigenpairs requested from the solver
which_matrix = 'laplacian'     # selects the decomposition branch in the extraction cell
which_color_matrix = 'knn'     # colour affinity flavour: 'knn' or 'rw'
which_features = 'k'           # key of the feature tensor inside each saved dict
lapnorm = True                 # pass the degree matrix as M to eigsh (generalised problem)
normalize = True               # L2-normalise features along the last dimension
threshold_at_zero = True       # zero out negative affinities
image_downsample_factor = None  # None -> fall back to the patch size of the features
image_color_lambda = 0.0       # weight of the colour affinity; 0 disables that branch

# --- Training / misc --------------------------------------------------------
epochs = 10
batch_size = 2       # only referenced by commented-out code in the visible cells
multiprocessing = 0  # not referenced by any visible cell

## Incorporating SimSiam

In [3]:
# Thin Dataset adapter around an already-loaded collection of features.
class Feature_Dataset(Dataset):
    """Map-style dataset that serves items straight from an in-memory container.

    The wrapped object only needs to support ``len()`` and integer indexing.
    """

    def __init__(self, features):
        """Keep a reference to ``features``; nothing is copied or converted."""
        self.features = features

    def __len__(self):
        """Return the number of items in the wrapped container."""
        size = len(self.features)
        return size

    def __getitem__(self, index):
        """Return the item stored at position ``index`` unchanged."""
        item = self.features[index]
        return item

In [4]:
class SimSiam(nn.Module):
    """Projection head followed by a prediction head, SimSiam style.

    Both heads map ``input_dim -> hidden_dim -> input_dim``, so the outputs
    have the same width as the inputs.

    Args:
        input_dim: Width of the incoming feature vectors. If omitted, the
            legacy behaviour is kept: the width is read from a notebook-level
            ``feats`` tensor (``feats.shape[1]``), which must already exist.
        hidden_dim: Hidden width of both heads (defaults to the original 128).

    Raises:
        ValueError: If ``input_dim`` is not given and no global ``feats``
            is available to infer it from.
    """

    def __init__(self, input_dim=None, hidden_dim=128):
        super().__init__()
        if input_dim is None:
            # Backward-compatible fallback: the original constructor silently
            # depended on whatever ``feats`` happened to be in the kernel.
            legacy_feats = globals().get("feats")
            if legacy_feats is None:
                raise ValueError(
                    "SimSiam needs `input_dim`, or a global `feats` tensor to infer it from"
                )
            input_dim = legacy_feats.shape[1]
        self.projection_head = SimSiamProjectionHead(input_dim, hidden_dim, input_dim)
        self.prediction_head = SimSiamPredictionHead(input_dim, hidden_dim, input_dim)

    def forward(self, x):
        """Return ``(z, p)``: the detached projection and its prediction.

        ``p`` is computed from the attached projection, so gradients still
        reach the projection head through ``p``; only the returned ``z`` is
        cut from the graph (stop-gradient on the target side).
        """
        z = self.projection_head(x)
        p = self.prediction_head(z)
        z = z.detach()
        return z, p

## Model-Based Optimization

In [5]:
# Train a single linear layer on PCA-reduced patch features with a
# SimSiam-style negative-cosine objective, keeping the best-epoch weights.

# Timestamp the checkpoint name so repeated runs do not overwrite each other.
# Plain `datetime` already yields local time; `dateutil.tz` is avoided because
# a bare `import dateutil` does not guarantee that the `tz` submodule is loaded.
now = datetime.datetime.now().astimezone()
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
model_path="/home/phdcs2/Hard_Disk/Datasets/Deep-Spectral-Segmentation/data/object-segmentation/ECSSD/weights/dot1_pcapredlinear_ds_10_model_jn%s.pt" % (timestamp)
device = "cuda" if torch.cuda.is_available() else "cpu"
utils.make_output_dir(output_dir)
# torch.save does not create missing directories.
Path(model_path).parent.mkdir(parents=True, exist_ok=True)

inputs = list(enumerate(sorted(Path(features_dir).iterdir())))
if not inputs:
    # Fail early with a readable message instead of a ZeroDivisionError later.
    raise FileNotFoundError(f"No feature files found in {features_dir}")

pca_comp=128
pca = PCA(n_components=pca_comp)
# Use the detected device instead of an unconditional .cuda() so the cell
# also runs on CPU-only machines.
layer=nn.Linear(pca_comp,pca_comp).to(device)

criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(layer.parameters(), lr=0.06)

# NOTE(review): no random seed is set, so the augmentations below differ
# between runs.

print("Starting Model Based Training")
best_loss = float('inf')
for epoch in range(epochs):
    print("epoch", epoch+1)
    total_loss = 0.0
    for index, features_file in tqdm(inputs):
        # SECURITY NOTE: torch.load unpickles the file; only load feature
        # files that you produced yourself.
        data_dict = torch.load(features_file, map_location='cpu')
        feats = data_dict[which_features].squeeze().to(device)
        if normalize:
            feats = F.normalize(feats, p=2, dim=-1)

        # Get sizes
        B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict)
        # Resolve the default locally rather than overwriting the notebook-level
        # setting, so re-running cells never sees a stale value.
        downsample_factor = P if image_downsample_factor is None else image_downsample_factor
        H_pad_lr, W_pad_lr = H_pad // downsample_factor, W_pad // downsample_factor

        # Upscale features to match the resolution
        if (H_patch, W_patch) != (H_pad_lr, W_pad_lr):
            feats = F.interpolate(
                feats.T.reshape(1, -1, H_patch, W_patch),
                size=(H_pad_lr, W_pad_lr), mode='bilinear', align_corners=False
            ).reshape(-1, H_pad_lr * W_pad_lr).T

        ### Model-Based Optimization
        # First view: PCA projection of this image's features.
        # NOTE(review): PCA is re-fitted for every image, so the basis fed to
        # `layer` changes from image to image - confirm this is intended.
        x0_arr = feats.detach().cpu().numpy()
        z0_arr = pca.fit_transform(x0_arr)

        # Random affine parameters (drawn in the original order).
        scale = np.random.uniform(0.8, 1.2)  # Random scaling factor between 0.8 and 1.2
        translation = np.random.uniform(-10, 10, size=2)  # Random translation vector between -10 and 10 in both directions
        # The angle is drawn in degrees, but np.cos / np.sin expect radians.
        rotation = np.deg2rad(np.random.uniform(-15, 15))  # Random rotation angle between -15 and 15 degrees
        shear = np.random.uniform(-0.2, 0.2, size=2)  # Random shear factor between -0.2 and 0.2 in both directions

        # Homogeneous 3x3 matrix applied to the 2-D (rows x components) array.
        cos_r, sin_r = np.cos(rotation), np.sin(rotation)
        affine_matrix = np.array([[scale * cos_r, -shear[0] * scale * sin_r, translation[0]],
                                  [shear[1] * scale * sin_r, scale * cos_r, translation[1]],
                                  [0, 0, 1]])

        # Second view: warp the PCA array, then project again.
        # (fit_transform re-fits `pca` on the warped array.)
        z1_arr = affine_transform(z0_arr, affine_matrix)
        z1_arr = pca.fit_transform(z1_arr)

        z0 = torch.from_numpy(z0_arr).float().to(device)
        z1 = torch.from_numpy(z1_arr).float().to(device)
        p0 = layer(z0)
        p1 = layer(z1)

        # Symmetric loss; z0 / z1 come from NumPy and carry no gradient, so
        # they act as the stop-gradient targets.
        loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        # Accumulate a Python float so no device tensors are kept around.
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_val_loss = total_loss / len(inputs)
    print(f"epoch: {epoch:>02}, loss: {avg_val_loss:.5f}")
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(layer.state_dict(), model_path)
        print("Saved Best Model! in epoch", epoch+1)
    else:
        print("Weigh not updated in epoch", epoch+1)



Starting Model Based Training
epoch 1


  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explai

  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explai

  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explai

  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explai

  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explai

  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explai

  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explained_variance_ratio_ = explained_variance_ / total_var
  explai

In [None]:
# Compute and save the leading eigenpairs for every feature file.
#
# This cell relies on `layer` and `device` created by the training cell.
# NOTE(review): `layer` holds the weights of the last training epoch; the
# best-epoch checkpoint written to `model_path` is not loaded here - confirm
# which one is intended.
pca_comp=128
pca = PCA(n_components=pca_comp)
utils.make_output_dir(output_dir)
inputs = list(enumerate(sorted(Path(features_dir).iterdir())))
for inp in tqdm(inputs):
    index, features_file = inp
    print(index, features_file)
    # SECURITY NOTE: torch.load unpickles the file; only load feature files
    # that you produced yourself.
    data_dict = torch.load(features_file, map_location='cpu')
    print(data_dict.keys())   #['k', 'indices', 'file', 'id', 'model_name', 'patch_size', 'shape']
    image_id = data_dict['file'][:-4]
    print(image_id)

    output_file = str(Path(output_dir) / f'{image_id}.pth')
    if Path(output_file).is_file():
        print(f'Skipping existing file {str(output_file)}')
        # Actually skip: previously the message was printed but the file was
        # recomputed and overwritten anyway.
        continue

    # Load features on the detected device (no unconditional .cuda()).
    feats = data_dict[which_features].squeeze().to(device)
    if normalize:
        feats = F.normalize(feats, p=2, dim=-1)

    # Eigenvectors of affinity matrix
    if which_matrix == 'affinity_torch':
        W = feats @ feats.T
        if threshold_at_zero:
            W = (W * (W > 0))
        # NOTE(review): torch.eig has been removed from recent PyTorch
        # releases. It is left untouched because torch.linalg.eig / eigh
        # return eigenvalues in a different layout.
        eigenvalues, eigenvectors = torch.eig(W, eigenvectors=True)
        eigenvalues = eigenvalues.cpu()
        eigenvectors = eigenvectors.cpu()
        print("which matrix=",which_matrix, "eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)

    # Leading singular vectors of the feature matrix
    elif which_matrix == 'affinity_svd':
        USV = torch.linalg.svd(feats, full_matrices=False)
        eigenvectors = USV[0][:, :K].T.to('cpu', non_blocking=True)
        eigenvalues = USV[1][:K].to('cpu', non_blocking=True)
        print("which matrix=",which_matrix,"eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)

    # Eigenvectors of affinity matrix with scipy
    elif which_matrix == 'affinity':
        W = (feats @ feats.T)
        if threshold_at_zero:
            W = (W * (W > 0))
        W = W.cpu().numpy()
        eigenvalues, eigenvectors = eigsh(W, which='LM', k=K)
        # Only the eigenvectors are flipped and converted to a tensor here;
        # `eigenvalues` stays a NumPy array in eigsh's order.
        eigenvectors = torch.flip(torch.from_numpy(eigenvectors), dims=(-1,)).T
        print("which matrix=",which_matrix, "eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)

    # Eigenvectors of matting laplacian matrix
    elif which_matrix in ['matting_laplacian', 'laplacian']:

        # Get sizes
        B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict)
        # Resolve the default locally rather than overwriting the notebook-level
        # setting, so re-running cells never sees a stale value.
        downsample_factor = P if image_downsample_factor is None else image_downsample_factor
        H_pad_lr, W_pad_lr = H_pad // downsample_factor, W_pad // downsample_factor

        # Upscale features to match the resolution
        if (H_patch, W_patch) != (H_pad_lr, W_pad_lr):
            feats = F.interpolate(
                feats.T.reshape(1, -1, H_patch, W_patch),
                size=(H_pad_lr, W_pad_lr), mode='bilinear', align_corners=False
            ).reshape(-1, H_pad_lr * W_pad_lr).T

        ### Feature affinities
        # Affinity of the raw features.
        W_feat_ds = (feats @ feats.T)

        # Affinity of the PCA-reduced features after the trained linear layer.
        # (A per-image training experiment that used to live here as
        # commented-out code has been removed.)
        x0_arr = feats.detach().cpu().numpy()
        z0_arr = pca.fit_transform(x0_arr)
        z0 = torch.from_numpy(z0_arr).float().to(device)
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            projected_feature = layer(z0)
        print(projected_feature.shape)
        W_feat_siam = torch.matmul(projected_feature, projected_feature.t())

        # The learned affinity enters with a fixed weight of 0.1.
        W_feat=W_feat_ds + 0.1*W_feat_siam
        if threshold_at_zero:
            W_feat = (W_feat * (W_feat > 0))
        W_feat = W_feat / W_feat.max()  # rescale so the largest affinity is 1
        W_feat = W_feat.detach().cpu().numpy()

        ### Color affinities
        # If we are fusing with color affinities, then load the image and compute
        if image_color_lambda > 0:

            # Load image
            image_file = str(Path(images_root) / f'{image_id}.jpg')
            image_lr = Image.open(image_file).resize((W_pad_lr, H_pad_lr), Image.BILINEAR)
            image_lr = np.array(image_lr) / 255.

            # Color affinities (of type scipy.sparse.csr_matrix)
            # NOTE(review): image_lr was already divided by 255 above and is
            # divided again here. Left unchanged because fixing it would alter
            # results - confirm whether the double scaling is intended.
            if which_color_matrix == 'knn':
                W_lr = utils.knn_affinity(image_lr / 255)
            elif which_color_matrix == 'rw':
                W_lr = utils.rw_affinity(image_lr / 255)
            else:
                # Previously an unknown value surfaced as a NameError on W_lr
                # (or silently reused W_lr from an earlier iteration).
                raise ValueError(f"Unknown which_color_matrix: {which_color_matrix!r}")

            # Convert to dense numpy array
            W_color = np.array(W_lr.todense().astype(np.float32))

        else:

            # No color affinity
            W_color = 0

        # Combine
        W_comb = W_feat + W_color * image_color_lambda  # combination
        D_comb = np.array(utils.get_diagonal(W_comb).todense())  # is dense or sparse faster? not sure, should check

        # Try shift-invert around 0 first; on failure fall back to requesting
        # the smallest-magnitude eigenvalues directly. `except Exception`
        # (not a bare except) lets KeyboardInterrupt stop a long run.
        if lapnorm:
            # Generalised problem with D_comb passed as M.
            try:
                eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM', M=D_comb)
            except Exception:
                eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM', M=D_comb)
        else:
            try:
                eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM')
            except Exception:
                eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM')
        eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float()

    else:
        # Previously an unknown value either raised NameError below or saved
        # the previous image's eigenpairs under this image's name.
        raise ValueError(f"Unknown which_matrix: {which_matrix!r}")

    print("eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)
    # Sign ambiguity: negate a row when more than half, but not all, of its
    # entries are positive.
    for k in range(eigenvectors.shape[0]):
        if 0.5 < torch.mean((eigenvectors[k] > 0).float()).item() < 1.0:  # reverse segment
            eigenvectors[k] = 0 - eigenvectors[k]

    # Save dict
    output_dict = {'eigenvalues': eigenvalues, 'eigenvectors': eigenvectors}
    torch.save(output_dict, output_file)

In [None]:
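# NOTE: Everything in this cell is commented out; it appears to be an earlier draft of the
# eigenvector-extraction cell. It refers to `model_simsiam`, which is not defined in any
# visible cell, so it would need changes before it could be re-enabled.
# utils.make_output_dir(output_dir)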
# feat_list=[]
# inputs = list(enumerate(sorted(Path(features_dir).iterdir())))
# for inp in tqdm(inputs):
#     index, features_file = inp
#     print(index, features_file)
#      # Load
#     data_dict = torch.load(features_file, map_location='cpu')
#     print(data_dict.keys())   #['k', 'indices', 'file', 'id', 'model_name', 'patch_size', 'shape']
#     # print("shape=", data_dict['shape'], "k shape", data_dict['k'].shape, "patch_size=", data_dict['patch_size'])
#     image_id = data_dict['file'][:-4]
#     print(image_id)
#     # Load
#     output_file = str(Path(output_dir) / f'{image_id}.pth')
#     if Path(output_file).is_file():
#         print(f'Skipping existing file {str(output_file)}')
#         # break
#         # return  # skip because already generated

#     # Load affinity matrix
#     feats = data_dict[which_features].squeeze().cuda()
#     # print("Without normalizing, Features Shape is",feats.shape)
#     if normalize:
#         feats = F.normalize(feats, p=2, dim=-1)
#     # print("After normalization, Features Shape",feats.shape)
#     # print("which_matrix=", which_matrix)
#     # Eigenvectors of affinity matrix
#     if which_matrix == 'affinity_torch':
#         W = feats @ feats.T
#         # W_feat=contrastive_affinity(feats, feats.T)
#         # print("W shape=", W.shape)
#         if threshold_at_zero:
#             W = (W * (W > 0))
#             # print("W shape=", W.shape)
#         eigenvalues, eigenvectors = torch.eig(W, eigenvectors=True)
#         eigenvalues = eigenvalues.cpu()
#         eigenvectors = eigenvectors.cpu()
#         print("which matrix=",which_matrix, "eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)


#     # Eigenvectors of affinity matrix with scipy
#     elif which_matrix == 'affinity_svd':
#         USV = torch.linalg.svd(feats, full_matrices=False)
#         eigenvectors = USV[0][:, :K].T.to('cpu', non_blocking=True)
#         eigenvalues = USV[1][:K].to('cpu', non_blocking=True)
#         print("which matrix=",which_matrix,"eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)

#     # Eigenvectors of affinity matrix with scipy
#     elif which_matrix == 'affinity':
#         # print("Without normalizing, Features Shape is",feats.shape)
#         W = (feats @ feats.T)
#         # W_feat=contrastive_affinity(feats, feats.T)
#         # print("W shape=", W.shape)
#         if threshold_at_zero:
#             W = (W * (W > 0))
#         W = W.cpu().numpy()
#         # print("W shape=", W.shape)
#         eigenvalues, eigenvectors = eigsh(W, which='LM', k=K)
#         eigenvectors = torch.flip(torch.from_numpy(eigenvectors), dims=(-1,)).T
#         print("which matrix=",which_matrix, "eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)

#     # Eigenvectors of matting laplacian matrix
#     elif which_matrix in ['matting_laplacian', 'laplacian']:

#         # Get sizes
#         B, C, H, W, P, H_patch, W_patch, H_pad, W_pad = utils.get_image_sizes(data_dict)
#         if image_downsample_factor is None:
#             image_downsample_factor = P
#         H_pad_lr, W_pad_lr = H_pad // image_downsample_factor, W_pad // image_downsample_factor

#         # Upscale features to match the resolution
#         if (H_patch, W_patch) != (H_pad_lr, W_pad_lr):
#             feats = F.interpolate(
#                 feats.T.reshape(1, -1, H_patch, W_patch),
#                 size=(H_pad_lr, W_pad_lr), mode='bilinear', align_corners=False
#             ).reshape(-1, H_pad_lr * W_pad_lr).T

#         ### Feature affinities
#         # print("Without normalizing, Features Shape is",feats.shape)

#         W_feat_ds = (feats @ feats.T)
#         feats_arr=feats.cpu()
        
#         # feat_list.append(feats)
#         feat_dataset = Feature_Dataset(feats)
#         if feats.shape[0]%2==0:
#             features_dataloader = DataLoader(feat_dataset, batch_size=2, shuffle=True)
#         else:
#             features_dataloader = DataLoader(feat_dataset, batch_size=2, shuffle=True, drop_last=True)
#         layer=nn.Linear(pca_comp,pca_comp).cuda()
#         device = "cuda" if torch.cuda.is_available() else "cpu"
    
#         criterion = NegativeCosineSimilarity()
#         optimizer = torch.optim.SGD(model_simsiam.parameters(), lr=0.06)
#         print("Starting Training")
#         for epoch in range(1):
#             total_loss = 0
#             for x0 in features_dataloader:
#             # for (x0), _, _ in features_dataloader:
#                 print(x0.shape)
# #                 z0=pca.fit_transform()
#                 x0 = x0.unsqueeze(0).to(device)
#                 # print(x0.shape)
#                 x1=torchvision.transforms.RandomAffine(0)(x0)
#                 # print(x1.shape)
#                 x0 = x0.squeeze(0).to(device)
#                 # print("x0.shape=", x0.shape)
#                 x1 = x1.squeeze(0).to(device)
#                 # print("x1 shape=", x1.shape)

#                 z0, p0 = model_simsiam(x0)
#                 z1, p1 = model_simsiam(x1)
#                 loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
#                 total_loss += loss.detach()
#                 loss.backward()
#                 optimizer.step()
#                 optimizer.zero_grad()
#             avg_loss = total_loss / len(features_dataloader)
#             print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")
#         projected_feature=model_simsiam(feats)
#         W_feat_siam=torch.matmul(projected_feature[0], projected_feature[0].t())
#         W_feat=W_feat_ds + W_feat_siam
#         # print("W_feat.shape=", W_feat.shape)
#         # print("W_feat.shape=", W_feat.shape)
#         # W_feat=contrastive_affinity(feats, feats.T)
#         if threshold_at_zero:
#             W_feat = (W_feat * (W_feat > 0))
#         W_feat = W_feat / W_feat.max()  # NOTE: If features are normalized, this naturally does nothing
#         # W_feat = W_feat.cpu().numpy()
#         W_feat = W_feat.detach().cpu().numpy()
#         # print("W_feat shape=",W_feat.shape)

#         ### Color affinities
#         # If we are fusing with color affinites, then load the image and compute
#         if image_color_lambda > 0:

#             # Load image
#             image_file = str(Path(images_root) / f'{image_id}.jpg')
#             image_lr = Image.open(image_file).resize((W_pad_lr, H_pad_lr), Image.BILINEAR)
#             image_lr = np.array(image_lr) / 255.

#             # Color affinities (of type scipy.sparse.csr_matrix)
#             if which_color_matrix == 'knn':
#                 W_lr = utils.knn_affinity(image_lr / 255)
#             elif which_color_matrix == 'rw':
#                 W_lr = utils.rw_affinity(image_lr / 255)

#             # Convert to dense numpy array
#             W_color = np.array(W_lr.todense().astype(np.float32))
#             # print("W_color shape", W_color.shape)

#         else:

#             # No color affinity
#             W_color = 0

#         # Combine
#         W_comb = W_feat + W_color * image_color_lambda  # combination
#         D_comb = np.array(utils.get_diagonal(W_comb).todense())  # is dense or sparse faster? not sure, should check
#         # print("W_comb shape= ", W_comb.shape, "D_comb shape",  D_comb.shape)
#         if lapnorm:
#             try:
#                 eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM', M=D_comb)
#             except:
#                 eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM', M=D_comb)
#         else:
#             try:
#                 eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, sigma=0, which='LM')
#             except:
#                 eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=K, which='SM')
#         eigenvalues, eigenvectors = torch.from_numpy(eigenvalues), torch.from_numpy(eigenvectors.T).float()
#     print("eigenvalues shape", eigenvalues.shape, "eigenvectors shape", eigenvectors.shape)
#     # Sign ambiguity
#     for k in range(eigenvectors.shape[0]):
#         if 0.5 < torch.mean((eigenvectors[k] > 0).float()).item() < 1.0:  # reverse segment
#             eigenvectors[k] = 0 - eigenvectors[k]

#     # Save dict
#     output_dict = {'eigenvalues': eigenvalues, 'eigenvectors': eigenvectors}
#     torch.save(output_dict, output_file)